Meta-Learning

Understanding and implementing meta-learning algorithms and techniques

Meta-learning, or "learning to learn," focuses on designing models that can learn new tasks quickly with minimal data.

Model-Agnostic Meta-Learning (MAML)

MAML learns a parameter initialization from which a few gradient steps on a new task's support set yield good performance on that task's query set: the inner loop adapts to each task, and the outer loop updates the shared initialization.

MAML Implementation

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import higher  # differentiable inner-loop optimization (pip install higher)

class MAMLModel(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim)
        )
    
    def forward(self, x):
        return self.net(x)

class MAML:
    def __init__(
        self,
        model,
        inner_lr=0.01,
        meta_lr=0.001,
        num_inner_steps=5
    ):
        self.model = model
        self.inner_lr = inner_lr
        self.meta_optimizer = torch.optim.Adam(
            model.parameters(),
            lr=meta_lr
        )
        self.num_inner_steps = num_inner_steps
    
    def inner_loop(self, support_x, support_y, query_x, query_y):
        # Create a differentiable optimizer for the inner loop. The query loss
        # is computed inside the context so the graph through the adapted
        # (fast) weights stays connected to the meta-parameters.
        inner_opt = torch.optim.SGD(self.model.parameters(), lr=self.inner_lr)
        with higher.innerloop_ctx(
            self.model,
            inner_opt,
            copy_initial_weights=False
        ) as (fmodel, diffopt):
            # Adapt to the support set
            for _ in range(self.num_inner_steps):
                support_pred = fmodel(support_x)
                support_loss = F.cross_entropy(support_pred, support_y)
                diffopt.step(support_loss)
            
            # Evaluate the adapted model on the query set
            query_pred = fmodel(query_x)
            return F.cross_entropy(query_pred, query_y)
    
    def outer_loop(self, tasks):
        self.meta_optimizer.zero_grad()
        meta_loss = 0.0
        
        for task in tasks:
            support_x, support_y = task['support']
            query_x, query_y = task['query']
            
            # Adapt on the support set, then evaluate on the query set
            meta_loss = meta_loss + self.inner_loop(
                support_x, support_y, query_x, query_y
            )
        
        # Update meta-parameters with the query loss averaged over tasks
        meta_loss = meta_loss / len(tasks)
        meta_loss.backward()
        self.meta_optimizer.step()
        
        return meta_loss.item()
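
A minimal meta-training sketch, using synthetic 5-way tasks built from random tensors; random_task is a hypothetical stand-in for a real episode sampler (see create_task_batch under Best Practices):

model = MAMLModel(input_dim=20, hidden_dim=64, output_dim=5)
maml = MAML(model, inner_lr=0.01, meta_lr=0.001, num_inner_steps=5)

def random_task(n_way=5, n_shot=5, n_query=15, input_dim=20):
    # Placeholder task with random features; labels run 0..n_way-1
    return {
        'support': (torch.randn(n_way * n_shot, input_dim),
                    torch.arange(n_way).repeat_interleave(n_shot)),
        'query': (torch.randn(n_way * n_query, input_dim),
                  torch.arange(n_way).repeat_interleave(n_query))
    }

for iteration in range(1000):
    tasks = [random_task() for _ in range(4)]  # meta-batch of 4 tasks
    loss = maml.outer_loop(tasks)
    if iteration % 100 == 0:
        print(f"iteration {iteration}: meta-loss {loss:.4f}")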

Prototypical Networks

Prototypical networks classify a query example by its distance to class prototypes, where each prototype is the mean embedding of that class's support examples.

Enhanced Prototypical Network

class EnhancedProtoNet(nn.Module):
    def __init__(self, input_dim, hidden_dim, embedding_dim):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.BatchNorm1d(hidden_dim),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.BatchNorm1d(hidden_dim),
            nn.Linear(hidden_dim, embedding_dim)
        )
        
    def compute_prototypes(self, support_set, support_labels):
        # Get embeddings
        embeddings = self.encoder(support_set)
        
        # Compute class prototypes. torch.unique returns sorted labels, so
        # the output logits follow sorted label order; query targets should
        # be remapped to 0..n_way-1 accordingly.
        unique_labels = torch.unique(support_labels)
        prototypes = []
        
        for label in unique_labels:
            mask = support_labels == label
            class_embeddings = embeddings[mask]
            prototype = class_embeddings.mean(dim=0)
            prototypes.append(prototype)
        
        return torch.stack(prototypes)
    
    def forward(self, support_set, support_labels, query_set):
        # Compute prototypes
        prototypes = self.compute_prototypes(support_set, support_labels)
        
        # Encode query samples
        query_embeddings = self.encoder(query_set)
        
        # Compute distances
        distances = torch.cdist(query_embeddings, prototypes)
        
        # Convert distances to probabilities
        logits = -distances
        return F.log_softmax(logits, dim=1)
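
A sketch of a single training episode with synthetic data; since forward returns log-probabilities, the negative log-likelihood loss is the natural pairing:

proto_net = EnhancedProtoNet(input_dim=20, hidden_dim=64, embedding_dim=32)
optimizer = torch.optim.Adam(proto_net.parameters(), lr=1e-3)

n_way, n_shot, n_query = 5, 5, 15
support_x = torch.randn(n_way * n_shot, 20)
support_y = torch.arange(n_way).repeat_interleave(n_shot)
query_x = torch.randn(n_way * n_query, 20)
query_y = torch.arange(n_way).repeat_interleave(n_query)

# Labels are already 0..n_way-1, matching the sorted prototype order
log_probs = proto_net(support_x, support_y, query_x)
loss = F.nll_loss(log_probs, query_y)
optimizer.zero_grad()
loss.backward()
optimizer.step()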

Relation Networks

Relation networks replace a fixed distance metric with a learned one: a relation module scores the similarity of each (query, support) feature pair.

class RelationNetwork(nn.Module):
    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        
        # Feature encoder
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU()
        )
        
        # Relation module
        self.relation = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid()
        )
    
    def forward(self, support_set, query_set):
        # Encode support and query sets
        support_features = self.encoder(support_set)
        query_features = self.encoder(query_set)
        
        # Compute relations
        n_queries = query_features.size(0)
        n_support = support_features.size(0)
        
        # Combine features
        support_features_ext = support_features.unsqueeze(0).repeat(n_queries, 1, 1)
        query_features_ext = query_features.unsqueeze(1).repeat(1, n_support, 1)
        
        # Concatenate feature pairs
        paired_features = torch.cat([support_features_ext, query_features_ext], dim=2)
        
        # Compute relation scores
        relations = self.relation(paired_features.view(-1, paired_features.size(2)))
        return relations.view(n_queries, n_support)
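
Relation networks are typically trained by regressing the relation score toward 1 for same-class (query, support) pairs and 0 otherwise, using mean squared error. A sketch of one 5-way 1-shot episode under that convention, with synthetic data:

rel_net = RelationNetwork(input_dim=20, hidden_dim=64)
optimizer = torch.optim.Adam(rel_net.parameters(), lr=1e-3)

support_x = torch.randn(5, 20)                # one example per class
support_y = torch.arange(5)
query_x = torch.randn(15, 20)
query_y = torch.arange(5).repeat_interleave(3)

scores = rel_net(support_x, query_x)          # (n_queries, n_support)
# Target is 1 where the query and support example share a class
targets = (query_y.unsqueeze(1) == support_y.unsqueeze(0)).float()
loss = F.mse_loss(scores, targets)
optimizer.zero_grad()
loss.backward()
optimizer.step()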

Memory-Augmented Neural Networks

Memory-augmented networks pair a recurrent controller with an external memory matrix, reading and writing through soft attention so that task information can be stored and retrieved across timesteps.

class MANNController(nn.Module):
    def __init__(self, input_dim, hidden_dim, memory_size, memory_dim):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.memory_size = memory_size
        self.memory_dim = memory_dim
        
        # Controller network
        self.controller = nn.LSTM(
            input_dim,
            hidden_dim,
            batch_first=True
        )
        
        # Memory read/write heads
        self.read_head = nn.Linear(hidden_dim, memory_size)
        self.write_head = nn.Linear(hidden_dim, memory_dim)
        
    def forward(self, x, memory):
        # x: (batch, seq_len, input_dim); memory: (batch, memory_size, memory_dim)
        # Process the input sequence through the LSTM controller
        controller_out, (h_n, c_n) = self.controller(x)
        
        # Generate read weights
        read_weights = F.softmax(self.read_head(controller_out), dim=2)
        
        # Read from memory
        memory_read = torch.bmm(read_weights, memory)
        
        # Generate write content
        write_content = self.write_head(controller_out)
        
        # Update memory
        memory = memory + torch.bmm(
            read_weights.transpose(1, 2),
            write_content
        )
        
        return memory_read, memory
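
A shape-checking sketch of a single forward pass; the memory matrix is assumed to be initialized by the caller (here to zeros) and threaded through successive calls:

controller = MANNController(
    input_dim=20, hidden_dim=64, memory_size=32, memory_dim=64
)

batch_size, seq_len = 4, 10
x = torch.randn(batch_size, seq_len, 20)
memory = torch.zeros(batch_size, 32, 64)      # (batch, memory_size, memory_dim)

memory_read, memory = controller(x, memory)
print(memory_read.shape)                      # torch.Size([4, 10, 64])
print(memory.shape)                           # torch.Size([4, 32, 64])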

Task Representation Learning

An alternative to gradient-based adaptation is to summarize the support set into a task embedding and condition the predictor on it.

class TaskEncoder(nn.Module):
    def __init__(self, input_dim, hidden_dim, task_embedding_dim):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, task_embedding_dim)
        )
        
    def forward(self, support_set, support_labels):
        # Encode the support set (labels are accepted for interface
        # compatibility but unused by this simple set encoder)
        support_features = self.encoder(support_set)
        
        # Aggregate a permutation-invariant task representation from the mean
        # and std of the support embeddings (output size: 2 * task_embedding_dim)
        task_embedding = torch.cat([
            support_features.mean(dim=0),
            support_features.std(dim=0)
        ])
        
        return task_embedding

class TaskConditionedNetwork(nn.Module):
    def __init__(self, input_dim, task_embedding_dim, hidden_dim, output_dim):
        super().__init__()
        
        self.task_encoder = TaskEncoder(
            input_dim,
            hidden_dim,
            task_embedding_dim
        )
        
        self.network = nn.Sequential(
            # The task embedding is 2 * task_embedding_dim wide (mean + std)
            nn.Linear(input_dim + 2 * task_embedding_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim)
        )
    
    def forward(self, x, support_set, support_labels):
        # Get task embedding
        task_embedding = self.task_encoder(support_set, support_labels)
        
        # Concatenate input with task embedding
        x_with_task = torch.cat([
            x,
            task_embedding.expand(x.size(0), -1)
        ], dim=1)
        
        return self.network(x_with_task)
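
A forward-pass sketch with synthetic data; note the first linear layer expects input_dim + 2 * task_embedding_dim features because the task embedding concatenates mean and standard deviation:

net = TaskConditionedNetwork(
    input_dim=20, task_embedding_dim=16, hidden_dim=64, output_dim=5
)

support_x = torch.randn(25, 20)               # 5-way, 5-shot support set
support_y = torch.arange(5).repeat_interleave(5)
x = torch.randn(15, 20)                       # query batch

logits = net(x, support_x, support_y)
print(logits.shape)                           # torch.Size([15, 5])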

Meta-Learning Optimization

Meta-SGD

Meta-SGD extends MAML by learning a per-parameter learning rate jointly with the initialization, so the adaptation rule itself is meta-learned.

class MetaSGD(nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model
        # nn.ParameterDict keys may not contain '.', so escape parameter
        # names; learning rates start small, as in the Meta-SGD paper
        self.learning_rates = nn.ParameterDict({
            name.replace('.', '_'): nn.Parameter(0.01 * torch.ones_like(param))
            for name, param in model.named_parameters()
        })
    
    def adapt(self, loss):
        # Differentiable one-step update: return adapted parameters instead
        # of mutating the model in place, so gradients can flow back into
        # the learned learning rates during meta-training
        grads = torch.autograd.grad(
            loss,
            self.model.parameters(),
            create_graph=True
        )
        return {
            name: param - self.learning_rates[name.replace('.', '_')] * grad
            for (name, param), grad in zip(
                self.model.named_parameters(), grads
            )
        }
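
Because adapt returns a dictionary of adapted parameters rather than mutating the model, they can be applied functionally. A minimal sketch using torch.func.functional_call (available in PyTorch 2.0+), with synthetic data:

model = MAMLModel(input_dim=20, hidden_dim=64, output_dim=5)
meta_sgd = MetaSGD(model)

support_x = torch.randn(25, 20)
support_y = torch.arange(5).repeat_interleave(5)
query_x = torch.randn(15, 20)

# Adapt on the support loss, then run a forward pass with the adapted
# parameters without touching the base model's weights
support_loss = F.cross_entropy(model(support_x), support_y)
adapted_params = meta_sgd.adapt(support_loss)
query_logits = torch.func.functional_call(model, adapted_params, (query_x,))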

Best Practices

  1. Task Generation
def create_task_batch(dataset, n_way, n_shot, n_query):
    # Sample n_way classes without replacement
    classes = np.random.choice(
        len(dataset.classes),
        n_way,
        replace=False
    )
    
    support_x, support_y, query_x, query_y = [], [], [], []
    for episode_label, class_idx in enumerate(classes):
        # get_class_samples is assumed to return shuffled samples of one class
        class_samples = dataset.get_class_samples(class_idx)
        
        # Split into support and query, relabeling classes to 0..n_way-1
        support_x.append(class_samples[:n_shot])
        query_x.append(class_samples[n_shot:n_shot + n_query])
        support_y += [episode_label] * n_shot
        query_y += [episode_label] * n_query
    
    # Return one episode in the format consumed by MAML.outer_loop
    return {
        'support': (torch.cat(support_x), torch.tensor(support_y)),
        'query': (torch.cat(query_x), torch.tensor(query_y))
    }
  2. Evaluation Protocol
def evaluate_meta_learner(model, test_tasks, n_test_episodes=100):
    # Assumes the meta-learner exposes adapt(support_x, support_y) and
    # returns a task-adapted model
    accuracies = []
    
    for _ in range(n_test_episodes):
        episode_accuracies = []
        
        for task in test_tasks:
            support_x, support_y = task['support']
            query_x, query_y = task['query']
            
            # Adapt to the task on its support set
            adapted_model = model.adapt(support_x, support_y)
            
            # Evaluate on the query set
            with torch.no_grad():
                predictions = adapted_model(query_x)
                acc = (predictions.argmax(dim=1) == query_y).float().mean()
                episode_accuracies.append(acc.item())
        
        accuracies.append(np.mean(episode_accuracies))
    
    # Report the mean, spread, and 95% confidence interval over episodes
    return {
        'mean_accuracy': np.mean(accuracies),
        'std_accuracy': np.std(accuracies),
        'ci95_accuracy': 1.96 * np.std(accuracies) / np.sqrt(len(accuracies))
    }
  3. Model Selection
def select_meta_learning_algorithm(dataset_size, n_way, n_shot):
    # A rough heuristic rather than a hard rule
    if n_shot < 5:
        # Few-shot scenario: prototype-based classification is simple and strong
        return EnhancedProtoNet
    elif dataset_size < 1000:
        # Small meta-training set: a learned similarity metric
        return RelationNetwork
    else:
        # Large meta-training set: gradient-based adaptation pays off
        return MAML