Transfer Learning
Transfer learning leverages knowledge from pre-trained models to improve learning on new tasks.
Pre-trained Models
Loading Pre-trained Models
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models

def load_pretrained_model(model_name='resnet50', num_classes=None):
    # Load base model with ImageNet weights
    # (newer torchvision releases prefer the weights= argument over pretrained=True)
    if model_name == 'resnet50':
        model = models.resnet50(pretrained=True)
    elif model_name == 'vgg16':
        model = models.vgg16(pretrained=True)
    elif model_name == 'densenet121':
        model = models.densenet121(pretrained=True)
    else:
        raise ValueError(f'Unsupported model: {model_name}')

    # Replace the classification head for the new task if needed
    if num_classes is not None:
        if model_name == 'resnet50':
            model.fc = nn.Linear(model.fc.in_features, num_classes)
        elif model_name == 'vgg16':
            model.classifier[-1] = nn.Linear(model.classifier[-1].in_features, num_classes)
        elif model_name == 'densenet121':
            model.classifier = nn.Linear(model.classifier.in_features, num_classes)
    return model
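A quick usage sketch (the 10-class task, batch size, and input size here are illustrative assumptions):

# Hypothetical example: adapt ResNet-50 to a 10-class task
model = load_pretrained_model('resnet50', num_classes=10)

# Freeze the backbone; only the new head stays trainable
for name, param in model.named_parameters():
    param.requires_grad = name.startswith('fc')

dummy = torch.randn(4, 3, 224, 224)  # batch of ImageNet-sized images
print(model(dummy).shape)            # torch.Size([4, 10])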
Feature Extraction
class FeatureExtractor(nn.Module):
    def __init__(self, base_model, layer_name):
        super().__init__()
        self.base_model = base_model
        self.layer_name = layer_name
        self.features = None
        # Register a forward hook on the requested layer
        for name, layer in self.base_model.named_modules():
            if name == layer_name:
                layer.register_forward_hook(self._get_features)

    def _get_features(self, module, input, output):
        self.features = output

    def forward(self, x):
        _ = self.base_model(x)
        return self.features
def extract_features(model, dataloader):
    features = []
    labels = []
    model.eval()
    with torch.no_grad():
        for inputs, targets in dataloader:
            batch_features = model(inputs)
            features.append(batch_features)
            labels.append(targets)
    return torch.cat(features), torch.cat(labels)
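A common use of this pair is a linear probe: freeze everything, extract features once, and fit a small classifier on top. A minimal sketch, assuming a train_loader of (images, labels) batches and an illustrative 10-class task:

# Hypothetical linear probe on ResNet-50's pooled features
backbone = load_pretrained_model('resnet50')
extractor = FeatureExtractor(backbone, layer_name='avgpool')

feats, labels = extract_features(extractor, train_loader)
feats = feats.flatten(1)              # (N, 2048) after global average pooling

probe = nn.Linear(feats.size(1), 10)  # 10 classes assumed for illustration
# ...train `probe` on (feats, labels) with a standard classification loop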
Fine-tuning Strategies
Gradual Unfreezing
class GradualUnfreezing:
    def __init__(self, model, num_epochs_per_layer=3):
        self.model = model
        self.num_epochs_per_layer = num_epochs_per_layer
        # Start with everything frozen; layers are unfrozen head-first
        for param in self.model.parameters():
            param.requires_grad = False
        self.frozen_layers = self._get_layers()

    def _get_layers(self):
        layers = []
        for name, param in self.model.named_parameters():
            layer_name = name.split('.')[0]
            if layer_name not in layers:
                layers.append(layer_name)
        return layers[::-1]  # Head first, earliest layers last

    def unfreeze_next_layer(self):
        if not self.frozen_layers:
            return False
        # pop(0) unfreezes the layer closest to the output first
        layer_to_unfreeze = self.frozen_layers.pop(0)
        for name, param in self.model.named_parameters():
            if name.startswith(layer_to_unfreeze):
                param.requires_grad = True
        return True
def train_with_gradual_unfreezing(model, train_loader, criterion, num_epochs, lr=1e-4):
    unfreezer = GradualUnfreezing(model)
    optimizer = None
    for epoch in range(num_epochs):
        if epoch % unfreezer.num_epochs_per_layer == 0:
            if unfreezer.unfreeze_next_layer() or optimizer is None:
                # Rebuild the optimizer so newly unfrozen parameters are included
                optimizer = torch.optim.Adam(
                    [p for p in model.parameters() if p.requires_grad],
                    lr=lr
                )
        model.train()
        for inputs, targets in train_loader:
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()
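Putting the pieces together might look like this (the class count and epoch budget are illustrative):

# Hypothetical driver: unfreeze a ResNet-50 head-first, one layer every 3 epochs
model = load_pretrained_model('resnet50', num_classes=10)
criterion = nn.CrossEntropyLoss()
train_with_gradual_unfreezing(model, train_loader, criterion, num_epochs=12)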
Layer-wise Learning Rates
def get_layer_wise_learning_rates(model, base_lr=0.001, decay_factor=0.9):
    # Group parameters by top-level layer, decaying the lr once per layer
    layer_groups = {}
    for name, params in model.named_parameters():
        layer_name = name.split('.')[0]
        layer_groups.setdefault(layer_name, []).append(params)

    layer_parameters = []
    current_lr = base_lr
    # The head gets the base lr; earlier layers get progressively smaller rates
    for layer_name in reversed(list(layer_groups)):
        layer_parameters.append({
            'params': layer_groups[layer_name],
            'lr': current_lr
        })
        current_lr *= decay_factor
    return layer_parameters

def create_layer_wise_optimizer(model, base_lr=0.001):
    parameters = get_layer_wise_learning_rates(model, base_lr)
    optimizer = torch.optim.Adam(parameters)
    return optimizer
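A quick way to sanity-check the resulting parameter groups (the model and rates are illustrative):

# Hypothetical check: inspect the per-layer learning rates
model = load_pretrained_model('resnet50', num_classes=10)
optimizer = create_layer_wise_optimizer(model, base_lr=1e-3)
for group in optimizer.param_groups:
    print(f"{len(group['params'])} tensors at lr={group['lr']:.2e}")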
Domain Adaptation
Adversarial Domain Adaptation
class DomainClassifier(nn.Module):
    def __init__(self, feature_dim):
        super().__init__()
        self.classifier = nn.Sequential(
            nn.Linear(feature_dim, 1024),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(1024, 1024),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(1024, 2)  # Source or target domain
        )

    def forward(self, x):
        return self.classifier(x)
class GradientReversalLayer(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, alpha):
        ctx.alpha = alpha
        return x.view_as(x)  # Identity in the forward pass

    @staticmethod
    def backward(ctx, grad_output):
        # Reverse (and scale) gradients flowing back into the feature extractor
        return -ctx.alpha * grad_output, None
def train_domain_adaptation(feature_extractor, task_classifier, domain_classifier,
                            source_loader, target_loader, optimizer, alpha=1.0):
    for (source_data, source_labels), (target_data, _) in zip(source_loader, target_loader):
        optimizer.zero_grad()
        # Extract features
        source_features = feature_extractor(source_data)
        target_features = feature_extractor(target_data)
        # Task classification loss (labeled source data only)
        task_pred = task_classifier(source_features)
        task_loss = F.cross_entropy(task_pred, source_labels)
        # Domain classification loss; labels must be long for cross_entropy
        source_domain = torch.zeros(source_features.size(0), dtype=torch.long)
        target_domain = torch.ones(target_features.size(0), dtype=torch.long)
        domain_features = torch.cat([source_features, target_features])
        domain_labels = torch.cat([source_domain, target_domain])
        # The reversal layer makes the extractor fight the domain classifier
        reversed_features = GradientReversalLayer.apply(domain_features, alpha)
        domain_pred = domain_classifier(reversed_features)
        domain_loss = F.cross_entropy(domain_pred, domain_labels)
        # Total loss
        loss = task_loss + domain_loss
        loss.backward()
        optimizer.step()
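Wiring the three components together might look like the sketch below; the 2048-dim feature size matches a ResNet-50 backbone with its head removed, and num_classes plus the two loaders are assumed to exist:

# Hypothetical wiring: strip the head so the backbone emits 2048-dim features
backbone = load_pretrained_model('resnet50')
backbone.fc = nn.Identity()

task_classifier = nn.Linear(2048, num_classes)
domain_classifier = DomainClassifier(feature_dim=2048)

params = (list(backbone.parameters())
          + list(task_classifier.parameters())
          + list(domain_classifier.parameters()))
optimizer = torch.optim.Adam(params, lr=1e-4)

train_domain_adaptation(backbone, task_classifier, domain_classifier,
                        source_loader, target_loader, optimizer)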
Few-shot Learning
Prototypical Networks
class PrototypicalNetwork(nn.Module):
    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU()
        )

    def forward(self, support_set, query_set, n_way):
        # Encode all samples
        support_features = self.encoder(support_set)
        query_features = self.encoder(query_set)
        # Class prototypes: mean embedding per class; this view assumes the
        # support set is grouped by class (n_way blocks of k_shot samples)
        prototypes = support_features.view(n_way, -1, support_features.size(-1)).mean(1)
        # Negative distance to each prototype serves as the logit
        distances = torch.cdist(query_features, prototypes)
        return -distances
def train_prototypical(model, episodes_loader, optimizer, n_way=5):
    for support_images, support_labels, query_images, query_labels in episodes_loader:
        optimizer.zero_grad()
        # Forward pass; query_labels must index prototypes (0..n_way-1)
        logits = model(support_images, query_images, n_way=n_way)
        loss = F.cross_entropy(logits, query_labels)
        # Backward pass
        loss.backward()
        optimizer.step()
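Because the prototype computation assumes class-grouped support samples, episode tensors have to be built accordingly. A toy 5-way episode on random data:

# Hypothetical 5-way, 1-shot episode on random 64-dim features
model = PrototypicalNetwork(input_dim=64, hidden_dim=128)
support = torch.randn(5 * 1, 64)             # classes 0..4, one sample each, in order
query = torch.randn(5 * 3, 64)               # 3 queries per class, same ordering
query_labels = torch.arange(5).repeat_interleave(3)

logits = model(support, query, n_way=5)
print(logits.argmax(dim=1))                  # predicted class per query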
Best Practices
- Model Selection
def select_pretrained_model(task_type, dataset_size, num_classes):
    model = load_pretrained_model('resnet50', num_classes=num_classes)
    if dataset_size < 1000:
        # Small dataset: freeze the backbone and train only the new head
        for param in model.parameters():
            param.requires_grad = False
        for param in model.fc.parameters():
            param.requires_grad = True
    elif dataset_size < 10000:
        # Medium dataset: fine-tune the last block plus the head
        for param in model.parameters():
            param.requires_grad = False
        for param in model.layer4.parameters():
            param.requires_grad = True
        for param in model.fc.parameters():
            param.requires_grad = True
    else:
        # Large dataset: fine-tune the entire model
        pass
    return model
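For example (the dataset size and class count are illustrative):

# Hypothetical example: a 5,000-image dataset triggers partial fine-tuning
model = select_pretrained_model('classification', dataset_size=5000, num_classes=10)
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f'{trainable:,} trainable parameters')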
- Learning Rate Selection
def get_learning_rate_schedule(optimizer, num_epochs, steps_per_epoch):
    # OneCycleLR steps once per batch, so it needs the exact step count
    return torch.optim.lr_scheduler.OneCycleLR(
        optimizer,
        max_lr=0.01,
        epochs=num_epochs,
        steps_per_epoch=steps_per_epoch,
        pct_start=0.3
    )
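Note that OneCycleLR is stepped per batch, not per epoch. A usage sketch, assuming a train_loader and the usual forward/backward steps:

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = get_learning_rate_schedule(optimizer, num_epochs=10,
                                       steps_per_epoch=len(train_loader))
for epoch in range(10):
    for inputs, targets in train_loader:
        ...  # forward, loss, backward, optimizer.step()
        scheduler.step()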
- Validation Strategy
def validate_transfer_learning(model, val_loader, source_val_loader,
                               baseline_source_accuracy=None):
    # Evaluate on target domain
    target_metrics = evaluate_model(model, val_loader)
    # Evaluate on source domain
    source_metrics = evaluate_model(model, source_val_loader)
    # Catastrophic forgetting is the drop in source accuracy relative to the
    # accuracy the model had before fine-tuning, if a baseline is available
    forgetting = None
    if baseline_source_accuracy is not None:
        forgetting = baseline_source_accuracy - source_metrics['accuracy']
    return {
        'target_metrics': target_metrics,
        'source_metrics': source_metrics,
        'forgetting': forgetting
    }
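evaluate_model is referenced above but not defined in this section; a minimal sketch returning top-1 accuracy could be:

def evaluate_model(model, dataloader):
    # Minimal sketch: top-1 accuracy over a labeled dataloader
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for inputs, targets in dataloader:
            preds = model(inputs).argmax(dim=1)
            correct += (preds == targets).sum().item()
            total += targets.size(0)
    return {'accuracy': correct / total}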