Meta-Learning
Understanding and implementing meta-learning algorithms and techniques
Meta-learning, or "learning to learn," focuses on designing models that can learn new tasks quickly with minimal data.
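Most of the approaches below share a bi-level structure: adapt to each task on a small support set, then score the adapted model on a held-out query set and improve the shared parameters from that score. As a rough sketch in standard notation (not tied to any one paper), with \(\theta\) the meta-parameters, \(\alpha\) an inner-loop step size, and each task \(\mathcal{T}\) split into support and query sets:

$$
\min_{\theta} \; \mathbb{E}_{\mathcal{T} \sim p(\mathcal{T})}\left[ \mathcal{L}^{\text{query}}_{\mathcal{T}}\left( \theta'_{\mathcal{T}} \right) \right],
\qquad
\theta'_{\mathcal{T}} = \theta - \alpha \, \nabla_{\theta}\, \mathcal{L}^{\text{support}}_{\mathcal{T}}(\theta)
$$

The one-step gradient update shown for \(\theta'_{\mathcal{T}}\) is the MAML case covered next; metric-based methods such as Prototypical Networks instead build a non-parametric classifier from the support embeddings.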
Model-Agnostic Meta-Learning (MAML)
MAML Implementation
import torch
import torch.nn as nn
import torch.nn.functional as F
import higher


class MAMLModel(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim)
        )

    def forward(self, x):
        return self.net(x)


class MAML:
    def __init__(
        self,
        model,
        inner_lr=0.01,
        meta_lr=0.001,
        num_inner_steps=5
    ):
        self.model = model
        self.inner_lr = inner_lr
        self.meta_optimizer = torch.optim.Adam(
            model.parameters(),
            lr=meta_lr
        )
        self.num_inner_steps = num_inner_steps

    def inner_loop(self, support_x, support_y):
        # Create a differentiable optimizer for the inner loop
        with higher.innerloop_ctx(
            self.model,
            torch.optim.SGD(self.model.parameters(), lr=self.inner_lr),
            copy_initial_weights=False
        ) as (fmodel, diffopt):
            # Perform inner-loop adaptation on the support set
            for _ in range(self.num_inner_steps):
                support_pred = fmodel(support_x)
                support_loss = F.cross_entropy(support_pred, support_y)
                diffopt.step(support_loss)
        return fmodel

    def outer_loop(self, tasks):
        self.meta_optimizer.zero_grad()
        meta_loss = 0
        for task in tasks:
            support_x, support_y = task['support']
            query_x, query_y = task['query']
            # Adapt to the task on its support set
            adapted_model = self.inner_loop(support_x, support_y)
            # Evaluate the adapted model on the query set
            query_pred = adapted_model(query_x)
            query_loss = F.cross_entropy(query_pred, query_y)
            meta_loss += query_loss
        # Update meta-parameters with the average query loss
        meta_loss = meta_loss / len(tasks)
        meta_loss.backward()
        self.meta_optimizer.step()
        return meta_loss.item()
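The wrapper above can be exercised with synthetic tasks shaped like the task['support'] / task['query'] dictionaries that outer_loop expects. The task sampler and dimensions below are made up purely for illustration, a minimal smoke test rather than a real benchmark:

import torch

def random_task(n_way=5, n_shot=5, n_query=15, input_dim=32):
    # Build one synthetic N-way episode (random features, ordered labels)
    def split(n_per_class):
        x = torch.randn(n_way * n_per_class, input_dim)
        y = torch.arange(n_way).repeat_interleave(n_per_class)
        return x, y
    return {'support': split(n_shot), 'query': split(n_query)}

model = MAMLModel(input_dim=32, hidden_dim=64, output_dim=5)
maml = MAML(model, inner_lr=0.01, meta_lr=0.001, num_inner_steps=5)

for iteration in range(100):
    tasks = [random_task() for _ in range(4)]  # meta-batch of 4 tasks
    loss = maml.outer_loop(tasks)
    if iteration % 20 == 0:
        print(f"meta-iteration {iteration}: mean query loss {loss:.3f}")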
Prototypical Networks
Enhanced Prototypical Network
class EnhancedProtoNet(nn.Module):
    def __init__(self, input_dim, hidden_dim, embedding_dim):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.BatchNorm1d(hidden_dim),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.BatchNorm1d(hidden_dim),
            nn.Linear(hidden_dim, embedding_dim)
        )

    def compute_prototypes(self, support_set, support_labels):
        # Embed the support samples
        embeddings = self.encoder(support_set)
        # Compute one prototype per class as the mean support embedding
        unique_labels = torch.unique(support_labels)
        prototypes = []
        for label in unique_labels:
            mask = support_labels == label
            class_embeddings = embeddings[mask]
            prototype = class_embeddings.mean(dim=0)
            prototypes.append(prototype)
        return torch.stack(prototypes)

    def forward(self, support_set, support_labels, query_set):
        # Compute class prototypes from the support set
        prototypes = self.compute_prototypes(support_set, support_labels)
        # Embed the query samples
        query_embeddings = self.encoder(query_set)
        # Euclidean distances between queries and prototypes
        distances = torch.cdist(query_embeddings, prototypes)
        # Negative distances act as logits over classes
        logits = -distances
        return F.log_softmax(logits, dim=1)
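For reference, a single 5-way, 5-shot episode might be run through the network as follows. The feature dimension and episode sizes are illustrative, and the returned log-probabilities pair with F.nll_loss:

import torch
import torch.nn.functional as F

proto_net = EnhancedProtoNet(input_dim=32, hidden_dim=64, embedding_dim=16)

support_x = torch.randn(5 * 5, 32)                    # 5 classes x 5 shots
support_y = torch.arange(5).repeat_interleave(5)      # labels 0..4
query_x = torch.randn(5 * 15, 32)                     # 15 queries per class
query_y = torch.arange(5).repeat_interleave(15)

log_probs = proto_net(support_x, support_y, query_x)  # (75, 5) log-probabilities
loss = F.nll_loss(log_probs, query_y)
accuracy = (log_probs.argmax(dim=1) == query_y).float().mean()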
Relation Networks
class RelationNetwork(nn.Module):
    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        # Feature encoder
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU()
        )
        # Relation module: scores a (support, query) feature pair in [0, 1]
        self.relation = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid()
        )

    def forward(self, support_set, query_set):
        # Encode support and query sets
        support_features = self.encoder(support_set)
        query_features = self.encoder(query_set)
        n_queries = query_features.size(0)
        n_support = support_features.size(0)
        # Pair every query with every support sample
        support_features_ext = support_features.unsqueeze(0).repeat(n_queries, 1, 1)
        query_features_ext = query_features.unsqueeze(1).repeat(1, n_support, 1)
        # Concatenate feature pairs
        paired_features = torch.cat([support_features_ext, query_features_ext], dim=2)
        # Compute relation scores
        relations = self.relation(paired_features.view(-1, paired_features.size(2)))
        return relations.view(n_queries, n_support)
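The forward pass returns one score per (query, support sample) pair. A common way to train this kind of model, following the MSE-against-one-hot objective of the original Relation Network paper, is to average the scores of support samples that share a class and regress them toward one-hot targets. The episode below is synthetic and only meant to show the shapes:

import torch
import torch.nn.functional as F

relation_net = RelationNetwork(input_dim=32, hidden_dim=64)

support_x = torch.randn(5 * 5, 32)                    # 5-way, 5-shot
support_y = torch.arange(5).repeat_interleave(5)
query_x = torch.randn(30, 32)
query_y = torch.randint(0, 5, (30,))

scores = relation_net(support_x, query_x)             # (n_queries, n_support)

# Average the scores of support samples that share a class
one_hot_support = F.one_hot(support_y, num_classes=5).float()          # (25, 5)
class_scores = scores @ one_hot_support / one_hot_support.sum(dim=0)   # (30, 5)

targets = F.one_hot(query_y, num_classes=5).float()
loss = F.mse_loss(class_scores, targets)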
Memory-Augmented Neural Networks
class MANNController(nn.Module):
    def __init__(self, input_dim, hidden_dim, memory_size, memory_dim):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.memory_size = memory_size
        self.memory_dim = memory_dim
        # Controller network
        self.controller = nn.LSTM(
            input_dim,
            hidden_dim,
            batch_first=True
        )
        # Memory read/write heads
        self.read_head = nn.Linear(hidden_dim, memory_size)
        self.write_head = nn.Linear(hidden_dim, memory_dim)

    def forward(self, x, memory):
        # Process input through the controller
        controller_out, (h_n, c_n) = self.controller(x)
        # Attention weights over memory slots
        read_weights = F.softmax(self.read_head(controller_out), dim=2)
        # Read from memory: (batch, seq, memory_size) x (batch, memory_size, memory_dim)
        memory_read = torch.bmm(read_weights, memory)
        # Generate write content
        write_content = self.write_head(controller_out)
        # Write back, weighting each slot by how strongly it was addressed
        memory = memory + torch.bmm(
            read_weights.transpose(1, 2),
            write_content
        )
        return memory_read, memory
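A brief sketch of stepping the controller over a batch of sequences, with the external memory initialized to zeros at the start of each episode (all sizes below are illustrative):

import torch

batch_size, seq_len, input_dim = 8, 10, 16
memory_size, memory_dim, hidden_dim = 32, 20, 64

controller = MANNController(input_dim, hidden_dim, memory_size, memory_dim)

# Fresh (empty) memory for a new episode
memory = torch.zeros(batch_size, memory_size, memory_dim)
x = torch.randn(batch_size, seq_len, input_dim)

memory_read, memory = controller(x, memory)
print(memory_read.shape)   # (8, 10, 20): one read vector per time step
print(memory.shape)        # (8, 32, 20): updated memory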
Task Representation Learning
class TaskEncoder(nn.Module):
    def __init__(self, input_dim, hidden_dim, task_embedding_dim):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, task_embedding_dim)
        )

    def forward(self, support_set, support_labels):
        # Encode the support set (labels are accepted for interface
        # consistency; the aggregation below is label-agnostic)
        support_features = self.encoder(support_set)
        # Aggregate a permutation-invariant task representation.
        # Concatenating mean and std gives a vector of size 2 * task_embedding_dim.
        task_embedding = torch.cat([
            support_features.mean(dim=0),
            support_features.std(dim=0)
        ])
        return task_embedding


class TaskConditionedNetwork(nn.Module):
    def __init__(self, input_dim, task_embedding_dim, hidden_dim, output_dim):
        super().__init__()
        self.task_encoder = TaskEncoder(
            input_dim,
            hidden_dim,
            task_embedding_dim
        )
        self.network = nn.Sequential(
            # The task embedding concatenates mean and std, so it has
            # size 2 * task_embedding_dim
            nn.Linear(input_dim + 2 * task_embedding_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim)
        )

    def forward(self, x, support_set, support_labels):
        # Get the task embedding from the support set
        task_embedding = self.task_encoder(support_set, support_labels)
        # Concatenate each input with the (broadcast) task embedding
        x_with_task = torch.cat([
            x,
            task_embedding.expand(x.size(0), -1)
        ], dim=1)
        return self.network(x_with_task)
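A usage sketch under the dimension convention above (the task embedding has size 2 * task_embedding_dim because it concatenates mean and std); the sizes are illustrative:

import torch

net = TaskConditionedNetwork(
    input_dim=32, task_embedding_dim=16, hidden_dim=64, output_dim=5
)

support_x = torch.randn(25, 32)                   # 5-way, 5-shot support set
support_y = torch.arange(5).repeat_interleave(5)  # unused by the encoder, kept for the interface
query_x = torch.randn(75, 32)

logits = net(query_x, support_x, support_y)       # (75, 5)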
Meta-Learning Optimization
Meta-SGD
from torch.func import functional_call  # PyTorch 2.x


class MetaSGD(nn.Module):
    def __init__(self, model, init_lr=0.01):
        super().__init__()
        self.model = model
        # One learnable per-parameter learning rate, initialized to init_lr.
        # Keys are sanitized because nn.ParameterDict keys may not contain '.'.
        self.learning_rates = nn.ParameterDict({
            name.replace('.', '_'): nn.Parameter(init_lr * torch.ones_like(param))
            for name, param in model.named_parameters()
        })

    def adapt(self, loss_fn, x, y, num_steps=1):
        # Out-of-place updates keep the adaptation differentiable, so the
        # per-parameter learning rates receive meta-gradients.
        params = dict(self.model.named_parameters())
        for _ in range(num_steps):
            # Recompute the loss with the current (possibly adapted) parameters
            loss = loss_fn(functional_call(self.model, params, (x,)), y)
            grads = torch.autograd.grad(
                loss,
                list(params.values()),
                create_graph=True
            )
            params = {
                name: param - self.learning_rates[name.replace('.', '_')] * grad
                for (name, param), grad in zip(params.items(), grads)
            }
        return params
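One meta-training step with the functional adapt above might look as follows: adapt on the support set, evaluate the adapted parameters on the query set, and let the meta-gradient update both the initialization and the learned learning rates. The data is synthetic and MAMLModel is reused only for convenience:

import torch
import torch.nn.functional as F
from torch.func import functional_call

base_model = MAMLModel(input_dim=32, hidden_dim=64, output_dim=5)
meta_sgd = MetaSGD(base_model, init_lr=0.01)
meta_optimizer = torch.optim.Adam(meta_sgd.parameters(), lr=1e-3)

support_x, support_y = torch.randn(25, 32), torch.arange(5).repeat_interleave(5)
query_x, query_y = torch.randn(75, 32), torch.arange(5).repeat_interleave(15)

# Inner loop: task-specific adaptation with learned per-parameter step sizes
adapted_params = meta_sgd.adapt(F.cross_entropy, support_x, support_y, num_steps=1)

# Outer loop: the query loss through the adapted parameters updates both
# the initialization and the learning rates
query_logits = functional_call(base_model, adapted_params, (query_x,))
query_loss = F.cross_entropy(query_logits, query_y)

meta_optimizer.zero_grad()
query_loss.backward()
meta_optimizer.step()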
Best Practices
- Task Generation
import numpy as np

def create_task_batch(dataset, n_way, n_shot, n_query):
    # Builds one N-way episode; call repeatedly to form a meta-batch.
    # Sample n_way classes without replacement
    classes = np.random.choice(
        len(dataset.classes),
        n_way,
        replace=False
    )
    support_x, support_y, query_x, query_y = [], [], [], []
    for label, class_idx in enumerate(classes):
        # Get the samples belonging to this class
        class_samples = dataset.get_class_samples(class_idx)
        # Split into support and query, relabeling classes as 0..n_way-1
        support_x.append(class_samples[:n_shot])
        query_x.append(class_samples[n_shot:n_shot + n_query])
        support_y.extend([label] * n_shot)
        query_y.extend([label] * n_query)
    return {
        'support': (torch.cat(support_x), torch.tensor(support_y)),
        'query': (torch.cat(query_x), torch.tensor(query_y))
    }
- Evaluation Protocol
def evaluate_meta_learner(model, test_tasks, n_test_episodes=100):
    accuracies = []
    for _ in range(n_test_episodes):
        episode_accuracies = []
        for task in test_tasks:
            support_x, support_y = task['support']
            query_x, query_y = task['query']
            # Adapt to the task on its support set
            # (the adapt interface depends on the meta-learner)
            adapted_model = model.adapt(support_x, support_y)
            # Evaluate on the query set
            with torch.no_grad():
                predictions = adapted_model(query_x)
                acc = (predictions.argmax(dim=1) == query_y).float().mean()
            episode_accuracies.append(acc.item())
        accuracies.append(np.mean(episode_accuracies))
    return {
        'mean_accuracy': np.mean(accuracies),
        'std_accuracy': np.std(accuracies)
    }
- Model Selection
def select_meta_learning_algorithm(dataset_size, n_way, n_shot):
    # Rough heuristic only; the right choice also depends on the domain
    if n_shot < 5:
        # Very few shots per class: metric-based methods are a strong baseline
        return EnhancedProtoNet
    elif dataset_size < 1000:
        # Small meta-training set: a learned relation module adds flexibility
        return RelationNetwork
    else:
        # Large meta-training set: optimization-based adaptation pays off
        return MAML