Transfer Learning

Transfer learning leverages knowledge from pre-trained models to improve learning on new tasks.

Transfer Learning

Transfer learning leverages knowledge from pre-trained models to improve learning on new tasks.

Pre-trained Models

Loading Pre-trained Models

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models

def load_pretrained_model(model_name='resnet50', num_classes=None):
    """Load a pre-trained torchvision model, optionally with a new head.

    Args:
        model_name: One of 'resnet50', 'vgg16' or 'densenet121'.
        num_classes: When given, the final classification layer is replaced
            by a freshly initialised ``nn.Linear`` with this many outputs.
            When ``None`` the model is returned with its original head.

    Returns:
        The loaded (and possibly modified) model.

    Raises:
        ValueError: If ``model_name`` is not one of the supported names.
    """
    supported = ('resnet50', 'vgg16', 'densenet121')
    # Validate first so an unknown name fails with a clear message instead of
    # an unbound-variable error further down.
    if model_name not in supported:
        raise ValueError(
            f"Unsupported model_name {model_name!r}; expected one of {supported}"
        )

    # Each supported name is also the name of the torchvision factory function.
    model = getattr(models, model_name)(pretrained=True)

    if num_classes is None:
        return model

    # Read the input width from the layer being replaced rather than
    # hard-coding it, so the new head always matches the backbone.
    if model_name == 'resnet50':
        model.fc = nn.Linear(model.fc.in_features, num_classes)
    elif model_name == 'vgg16':
        model.classifier[-1] = nn.Linear(
            model.classifier[-1].in_features, num_classes
        )
    else:  # densenet121
        model.classifier = nn.Linear(model.classifier.in_features, num_classes)

    return model

Feature Extraction

class FeatureExtractor(nn.Module):
    """Wrap a model and return the output of one of its named submodules.

    A forward hook is registered on the submodule whose ``named_modules()``
    name equals ``layer_name``; calling the extractor runs the whole base
    model and returns whatever that submodule produced.

    Args:
        base_model: The model to run.
        layer_name: Name of the submodule (as reported by
            ``base_model.named_modules()``) whose output should be captured.

    Raises:
        ValueError: If ``base_model`` has no submodule called ``layer_name``.
    """

    def __init__(self, base_model, layer_name):
        super().__init__()
        self.base_model = base_model
        self.layer_name = layer_name
        self.features = None

        modules = dict(self.base_model.named_modules())
        if layer_name not in modules:
            # Without this check a typo would silently make forward() return None.
            raise ValueError(
                f"Layer {layer_name!r} not found in base_model"
            )

        # Keep the handle so the hook can later be detached via `.remove()`.
        self._hook_handle = modules[layer_name].register_forward_hook(
            self._get_features
        )

    def _get_features(self, module, input, output):
        """Forward-hook callback: remember the hooked layer's output."""
        self.features = output

    def forward(self, x):
        """Run the base model on ``x`` and return the captured layer output."""
        # Clear the previous capture so stale activations are never returned.
        self.features = None
        _ = self.base_model(x)
        return self.features

def extract_features(model, dataloader):
    """Run ``model`` over every batch and collect its outputs and the targets.

    The model is switched to eval mode and gradients are disabled while the
    batches are processed.

    Args:
        model: Callable applied to each batch of inputs.
        dataloader: Iterable yielding ``(inputs, targets)`` pairs.

    Returns:
        A tuple ``(features, labels)`` where each element is the
        concatenation (via ``torch.cat``) of the per-batch tensors.
    """
    model.eval()

    collected = []
    with torch.no_grad():
        for batch_inputs, batch_targets in dataloader:
            collected.append((model(batch_inputs), batch_targets))

    feature_chunks = [feats for feats, _ in collected]
    label_chunks = [tgts for _, tgts in collected]
    return torch.cat(feature_chunks), torch.cat(label_chunks)

Fine-tuning Strategies

Gradual Unfreezing

class GradualUnfreezing:
    """Unfreeze a model's top-level layers one at a time, output side first.

    "Layer" here means the first component of a parameter name, i.e. a direct
    child of ``model``. This class only ever sets ``requires_grad = True``;
    it does not freeze anything itself, so the caller is responsible for
    freezing the parameters beforehand.

    Args:
        model: The model whose parameters will be unfrozen.
        num_epochs_per_layer: How many epochs to train between unfreezing
            steps. Stored for the training loop; not used by this class.
    """

    def __init__(self, model, num_epochs_per_layer=3):
        self.model = model
        self.num_epochs_per_layer = num_epochs_per_layer
        # Ordered from the last (output-side) layer to the first.
        self.frozen_layers = self._get_layers()

    def _get_layers(self):
        """Return the top-level layer names, last-registered layer first."""
        # dict.fromkeys de-duplicates while keeping first-seen order.
        ordered = list(dict.fromkeys(
            name.split('.')[0] for name, _ in self.model.named_parameters()
        ))
        return ordered[::-1]  # Reverse order

    def unfreeze_next_layer(self):
        """Make the next still-frozen layer trainable.

        Returns:
            ``True`` if a layer was unfrozen, ``False`` if none were left.
        """
        if not self.frozen_layers:
            return False

        # Take from the FRONT: the list is output-first, and gradual
        # unfreezing starts at the layers closest to the output.
        layer_to_unfreeze = self.frozen_layers.pop(0)
        for name, param in self.model.named_parameters():
            # Compare the exact top-level name; a prefix test would also
            # match siblings such as "layer1" vs "layer10".
            if name.split('.')[0] == layer_to_unfreeze:
                param.requires_grad = True

        return True

def train_with_gradual_unfreezing(model, train_loader, criterion, optimizer, num_epochs):
    """Train ``model`` while unfreezing one more layer every few epochs.

    Every ``num_epochs_per_layer`` epochs (starting at epoch 0) the next layer
    is unfrozen through ``GradualUnfreezing``. Newly trainable parameters that
    the optimizer does not yet track are appended to it as a new parameter
    group, so the caller's optimizer type, hyper-parameters and accumulated
    state are kept instead of being replaced.

    Args:
        model: The model to train.
        train_loader: Iterable yielding ``(inputs, targets)`` batches.
        criterion: Loss callable taking ``(outputs, targets)``.
        optimizer: Optimizer to step; extended in place as layers unfreeze.
        num_epochs: Total number of epochs to run.
    """
    unfreezer = GradualUnfreezing(model)

    for epoch in range(num_epochs):
        is_unfreeze_epoch = epoch % unfreezer.num_epochs_per_layer == 0
        if is_unfreeze_epoch and unfreezer.unfreeze_next_layer():
            # add_param_group rejects parameters the optimizer already
            # tracks, so only hand over the ones it has not seen.
            tracked = {
                id(p) for group in optimizer.param_groups for p in group['params']
            }
            new_params = [
                p for p in model.parameters()
                if p.requires_grad and id(p) not in tracked
            ]
            if new_params:
                optimizer.add_param_group({'params': new_params})

        model.train()
        for inputs, targets in train_loader:
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()

Layer-wise Learning Rates

def get_layer_wise_learning_rates(model, base_lr=0.001, decay_factor=0.9):
    """Build optimizer parameter groups with a per-layer learning rate.

    Parameters are grouped by top-level layer (the first component of their
    name). The last layer receives ``base_lr`` and every earlier layer is
    scaled down by another factor of ``decay_factor``, so layers nearer the
    input change more slowly than the task-specific layers near the output.

    Args:
        model: Model whose ``named_parameters()`` are grouped.
        base_lr: Learning rate of the last (output-side) layer.
        decay_factor: Multiplier applied once per layer moving toward the input.

    Returns:
        A list of ``{'params': [...], 'lr': float}`` dicts, one per layer, in
        model order, suitable for passing to a ``torch.optim`` optimizer.
    """
    # Group parameters by layer; dicts keep first-seen (model) order.
    grouped = {}
    for name, param in model.named_parameters():
        grouped.setdefault(name.split('.')[0], []).append(param)

    top = len(grouped) - 1
    return [
        # (top - depth) counts how many layers sit between this one and the output.
        {'params': params, 'lr': base_lr * decay_factor ** (top - depth)}
        for depth, params in enumerate(grouped.values())
    ]

def create_layer_wise_optimizer(model, base_lr=0.001, decay_factor=0.9):
    """Create an Adam optimizer that uses layer-wise learning rates.

    Args:
        model: Model whose parameters will be optimized.
        base_lr: Base learning rate forwarded to
            ``get_layer_wise_learning_rates``.
        decay_factor: Per-layer decay forwarded to
            ``get_layer_wise_learning_rates``. The default matches that
            function's own default, so existing calls behave as before.

    Returns:
        A ``torch.optim.Adam`` instance built from the generated parameter groups.
    """
    parameters = get_layer_wise_learning_rates(model, base_lr, decay_factor)
    return torch.optim.Adam(parameters)

Domain Adaptation

Adversarial Domain Adaptation

class DomainClassifier(nn.Module):
    """MLP head producing two logits (source vs. target domain) per feature row.

    Args:
        feature_dim: Size of the last dimension of the input features.
    """

    def __init__(self, feature_dim):
        super().__init__()
        hidden_width = 1024
        # Two hidden Linear -> ReLU -> Dropout stages, then a 2-way output layer.
        stack = [
            nn.Linear(feature_dim, hidden_width),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(hidden_width, hidden_width),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(hidden_width, 2),  # Source or target domain
        ]
        self.classifier = nn.Sequential(*stack)

    def forward(self, x):
        """Return the domain logits for ``x``."""
        logits = self.classifier(x)
        return logits

class GradientReversalLayer(torch.autograd.Function):
    """Autograd function: identity on the forward pass, negated and scaled
    gradient on the backward pass.

    Use as ``GradientReversalLayer.apply(x, alpha)``.
    """

    @staticmethod
    def forward(ctx, x, alpha):
        """Return ``x`` unchanged (as a view) and remember ``alpha``."""
        ctx.alpha = alpha
        passthrough = x.view_as(x)
        return passthrough

    @staticmethod
    def backward(ctx, grad_output):
        """Return ``-alpha * grad_output`` for ``x`` and no gradient for ``alpha``."""
        flipped = grad_output.neg() * ctx.alpha
        return flipped, None

def train_domain_adaptation(feature_extractor, task_classifier,
                            domain_classifier, source_loader, target_loader,
                            optimizer=None, alpha=1.0):
    """Run one pass of adversarial domain-adaptation training.

    For each paired source/target batch, the task loss is computed on the
    source features and the domain loss on the concatenated source and target
    features after they pass through ``GradientReversalLayer``. The summed
    loss is then back-propagated.

    Args:
        feature_extractor: Callable mapping a data batch to features.
        task_classifier: Callable mapping source features to task logits.
        domain_classifier: Callable mapping features to 2-way domain logits.
        source_loader: Iterable yielding ``(data, labels)`` source batches.
        target_loader: Iterable yielding ``(data, _)`` target batches.
        optimizer: Optional optimizer. When given, gradients are zeroed before
            each batch and a step is taken after ``backward()``. When ``None``
            only ``backward()`` is called, and gradients accumulate across
            batches until the caller handles them.
        alpha: Scale passed to the gradient reversal layer.
    """
    for (source_data, source_labels), (target_data, _) in zip(source_loader, target_loader):
        if optimizer is not None:
            optimizer.zero_grad()

        # Extract features
        source_features = feature_extractor(source_data)
        target_features = feature_extractor(target_data)

        # Task classification loss
        task_pred = task_classifier(source_features)
        task_loss = F.cross_entropy(task_pred, source_labels)

        # Domain classification loss (0 = source, 1 = target).
        domain_features = torch.cat([source_features, target_features])
        # cross_entropy needs integer class indices on the same device as the
        # logits; float labels of shape (N,) are rejected.
        label_kwargs = {'dtype': torch.long, 'device': domain_features.device}
        domain_labels = torch.cat([
            torch.zeros(source_features.size(0), **label_kwargs),
            torch.ones(target_features.size(0), **label_kwargs),
        ])

        reversed_features = GradientReversalLayer.apply(domain_features, alpha)
        domain_pred = domain_classifier(reversed_features)
        domain_loss = F.cross_entropy(domain_pred, domain_labels)

        # Total loss
        loss = task_loss + domain_loss
        loss.backward()

        if optimizer is not None:
            optimizer.step()

Few-shot Learning

Prototypical Networks

class PrototypicalNetwork(nn.Module):
    """Two-layer MLP encoder that classifies queries by distance to class prototypes.

    Args:
        input_dim: Size of each input vector.
        hidden_dim: Width of both encoder layers, and so of the embedding.
    """

    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU()
        )

    def forward(self, support_set, query_set, n_way, support_labels=None):
        """Return one logit per (query, class) pair.

        Args:
            support_set: Support samples to encode.
            query_set: Query samples to encode.
            n_way: Number of classes in the episode.
            support_labels: Optional tensor of class indices in
                ``range(n_way)``, one per support sample. When omitted, the
                support rows must be grouped by class with the same number
                of samples per class, because the prototypes are then
                computed by reshaping to ``(n_way, shots, dim)``.

        Returns:
            Tensor of negative Euclidean distances from each query embedding
            to each class prototype, usable as logits.
        """
        # Encode all samples
        support_features = self.encoder(support_set)
        query_features = self.encoder(query_set)

        # Calculate prototypes (the mean embedding of each class).
        if support_labels is None:
            # reshape, unlike view, also accepts non-contiguous features.
            prototypes = support_features.reshape(
                n_way, -1, support_features.size(-1)
            ).mean(1)
        else:
            # With explicit labels, the support rows may come in any order.
            prototypes = torch.stack([
                support_features[support_labels == class_idx].mean(0)
                for class_idx in range(n_way)
            ])

        # Calculate distances
        distances = torch.cdist(query_features, prototypes)

        return -distances  # Convert to logits

def train_prototypical(model, episodes_loader, optimizer, n_way=5):
    """Train ``model`` for one pass over a stream of few-shot episodes.

    Args:
        model: Callable invoked as ``model(support, query, n_way=n_way)`` and
            returning logits for the query samples.
        episodes_loader: Iterable yielding
            ``(support_images, support_labels, query_images, query_labels)``.
            ``support_labels`` is not used here.
        optimizer: Optimizer stepped once per episode.
        n_way: Number of classes per episode, forwarded to the model.
            Defaults to 5, the value that was previously fixed.
    """
    for support_images, _support_labels, query_images, query_labels in episodes_loader:
        optimizer.zero_grad()

        # Forward pass
        logits = model(support_images, query_images, n_way=n_way)
        loss = F.cross_entropy(logits, query_labels)

        # Backward pass
        loss.backward()
        optimizer.step()

Best Practices

  1. Model Selection
def select_pretrained_model(task_type, dataset_size, num_classes):
    """Pick a fine-tuning setup for a pre-trained ResNet-50 from the dataset size.

    * fewer than 1000 samples: feature extraction, only the classification
      head is trainable;
    * fewer than 10000 samples: the last residual stage (``layer4``) and the
      head are trainable;
    * otherwise: the whole model is trainable.

    Args:
        task_type: Currently unused; kept for interface compatibility.
        dataset_size: Number of training samples available.
        num_classes: Number of output classes for the new head, forwarded to
            ``load_pretrained_model``; ``None`` keeps the original head.

    Returns:
        The model with ``requires_grad`` set according to the chosen strategy.
    """
    # Forward num_classes so the head matches the new task.
    model = load_pretrained_model('resnet50', num_classes=num_classes)

    if dataset_size < 1000:
        # Small dataset: Use feature extraction
        trainable_parts = (model.fc,)
    elif dataset_size < 10000:
        # Medium dataset: Fine-tune last few layers
        trainable_parts = (model.layer4, model.fc)
    else:
        # Large dataset: Fine-tune entire model
        return model

    for param in model.parameters():
        param.requires_grad = False
    # Re-enable the selected parts. The head must always stay trainable,
    # otherwise the frozen backbone leaves nothing to adapt to the new task.
    for part in trainable_parts:
        for param in part.parameters():
            param.requires_grad = True

    return model
  2. Learning Rate Selection
def get_learning_rate_schedule(optimizer, num_epochs, steps_per_epoch=100,
                               max_lr=0.01, pct_start=0.3):
    """Create a one-cycle learning-rate schedule for ``optimizer``.

    Args:
        optimizer: Optimizer whose learning rate will be scheduled.
        num_epochs: Number of training epochs.
        steps_per_epoch: Scheduler steps (batches) per epoch. Set this to the
            length of the training loader; the default of 100 is the value
            that was previously fixed.
        max_lr: Peak learning rate of the cycle.
        pct_start: Fraction of the cycle spent increasing the learning rate.

    Returns:
        A ``torch.optim.lr_scheduler.OneCycleLR`` instance.
    """
    return torch.optim.lr_scheduler.OneCycleLR(
        optimizer,
        max_lr=max_lr,
        epochs=num_epochs,
        steps_per_epoch=steps_per_epoch,
        pct_start=pct_start,
    )
  3. Validation Strategy
def validate_transfer_learning(model, val_loader, source_val_loader,
                               baseline_source_accuracy=None):
    """Evaluate ``model`` on the target and source validation sets.

    Args:
        model: The model to evaluate.
        val_loader: Validation data for the target domain.
        source_val_loader: Validation data for the source domain.
        baseline_source_accuracy: Optional source-domain accuracy measured
            before fine-tuning. When given, ``'forgetting'`` is the drop from
            that baseline to the current source accuracy. When omitted, the
            previous definition is kept (current source accuracy minus current
            target accuracy), which is a domain gap rather than a measure of
            forgetting.

    Returns:
        Dict with ``'target_metrics'``, ``'source_metrics'`` and
        ``'forgetting'``. The metrics come from ``evaluate_model`` and are
        expected to contain an ``'accuracy'`` entry.
    """
    # Evaluate on target domain
    target_metrics = evaluate_model(model, val_loader)

    # Evaluate on source domain
    source_metrics = evaluate_model(model, source_val_loader)

    # Check for catastrophic forgetting
    if baseline_source_accuracy is not None:
        # True forgetting: how much source performance was lost by fine-tuning.
        forgetting = baseline_source_accuracy - source_metrics['accuracy']
    else:
        forgetting = source_metrics['accuracy'] - target_metrics['accuracy']

    return {
        'target_metrics': target_metrics,
        'source_metrics': source_metrics,
        'forgetting': forgetting
    }