Generative Models

Understanding and implementing various types of generative models

Generative models learn the underlying distribution of a dataset so they can sample new data points that resemble the training data.

Variational Autoencoders (VAEs)

Basic VAE Architecture

import torch
import torch.nn as nn
import torch.nn.functional as F

class VAE(nn.Module):
    def __init__(self, input_dim, latent_dim, hidden_dim=400):
        super(VAE, self).__init__()
        
        # Encoder
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU()
        )
        
        # Latent heads: mean and log-variance of q(z|x)
        self.fc_mu = nn.Linear(hidden_dim, latent_dim)
        self.fc_var = nn.Linear(hidden_dim, latent_dim)  # outputs log-variance
        
        # Decoder
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, input_dim),
            nn.Sigmoid()
        )
        
    def encode(self, x):
        h = self.encoder(x)
        return self.fc_mu(h), self.fc_var(h)
    
    def reparameterize(self, mu, log_var):
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return mu + eps * std
    
    def decode(self, z):
        return self.decoder(z)
    
    def forward(self, x):
        mu, log_var = self.encode(x)
        z = self.reparameterize(mu, log_var)
        return self.decode(z), mu, log_var

def vae_loss(recon_x, x, mu, log_var):
    # Reconstruction loss (BCE assumes inputs scaled to [0, 1])
    BCE = F.binary_cross_entropy(recon_x, x, reduction='sum')
    
    # KL divergence
    KLD = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
    
    return BCE + KLD
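
The loss above is the negative evidence lower bound (ELBO). For a Gaussian encoder q(z|x) = N(mu, diag(sigma^2)) and a standard normal prior, the KL term has the closed form used in vae_loss, and reparameterize implements the differentiable sampling trick:

\mathrm{KLD} = -\frac{1}{2} \sum_{j=1}^{d} \left(1 + \log \sigma_j^2 - \mu_j^2 - \sigma_j^2\right), \qquad z = \mu + \sigma \odot \epsilon, \quad \epsilon \sim \mathcal{N}(0, I)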

Training a VAE

def train_vae(model, train_loader, optimizer, epochs=100):
    model.train()
    for epoch in range(epochs):
        total_loss = 0
        for batch_idx, (data, _) in enumerate(train_loader):
            data = data.view(data.size(0), -1)  # flatten images to vectors
            optimizer.zero_grad()
            
            recon_batch, mu, log_var = model(data)
            loss = vae_loss(recon_batch, data, mu, log_var)
            
            loss.backward()
            optimizer.step()
            
            total_loss += loss.item()
            
        avg_loss = total_loss / len(train_loader.dataset)
        print(f'Epoch [{epoch+1}/{epochs}], Loss: {avg_loss:.4f}')
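
A minimal usage sketch, assuming 28x28 images flattened to 784-dim vectors and an MNIST-style DataLoader named train_loader (both assumptions, not defined above):

model = VAE(input_dim=784, latent_dim=20)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# train_loader is assumed to yield (images, labels) batches with pixels in [0, 1]
train_vae(model, train_loader, optimizer, epochs=10)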

Generative Adversarial Networks (GANs)

DCGAN Implementation

class Generator(nn.Module):
    def __init__(self, latent_dim, channels=3):
        super(Generator, self).__init__()
        
        self.main = nn.Sequential(
            # Input is latent vector
            nn.ConvTranspose2d(latent_dim, 512, 4, 1, 0, bias=False),
            nn.BatchNorm2d(512),
            nn.ReLU(True),
            
            # State size: 512 x 4 x 4
            nn.ConvTranspose2d(512, 256, 4, 2, 1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(True),
            
            # State size: 256 x 8 x 8
            nn.ConvTranspose2d(256, 128, 4, 2, 1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            
            # State size: 128 x 16 x 16
            nn.ConvTranspose2d(128, channels, 4, 2, 1, bias=False),
            nn.Tanh()
        )
        
    def forward(self, x):
        return self.main(x)

class Discriminator(nn.Module):
    def __init__(self, channels=3):
        super(Discriminator, self).__init__()
        
        self.main = nn.Sequential(
            # Input is image
            nn.Conv2d(channels, 64, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            
            # State size: 64 x 16 x 16
            nn.Conv2d(64, 128, 4, 2, 1, bias=False),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            
            # State size: 128 x 8 x 8
            nn.Conv2d(128, 256, 4, 2, 1, bias=False),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),
            
            # State size: 256 x 4 x 4
            nn.Conv2d(256, 1, 4, 1, 0, bias=False),
            nn.Sigmoid()
        )
        
    def forward(self, x):
        return self.main(x).view(-1, 1).squeeze(1)
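
A quick shape sanity check (a sketch; with these strides the generator produces 32x32 images, which the discriminator maps back to a single probability per image):

latent_dim = 100
G = Generator(latent_dim)
D = Discriminator()

z = torch.randn(8, latent_dim, 1, 1)  # DCGAN feeds noise as a 1x1 spatial map
fake = G(z)                           # -> (8, 3, 32, 32), in [-1, 1] from Tanh
scores = D(fake)                      # -> (8,) probabilities from Sigmoid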

Training a GAN

def train_gan(generator, discriminator, dataloader, latent_dim, num_epochs=100):
    criterion = nn.BCELoss()
    g_optimizer = torch.optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
    d_optimizer = torch.optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))
    
    for epoch in range(num_epochs):
        for i, (real_images, _) in enumerate(dataloader):
            batch_size = real_images.size(0)
            
            # Train Discriminator
            d_optimizer.zero_grad()
            label_real = torch.ones(batch_size)
            label_fake = torch.zeros(batch_size)
            
            # Real images
            output_real = discriminator(real_images)
            d_loss_real = criterion(output_real, label_real)
            
            # Fake images: the conv generator expects noise shaped (N, latent_dim, 1, 1)
            noise = torch.randn(batch_size, latent_dim, 1, 1)
            fake_images = generator(noise)
            output_fake = discriminator(fake_images.detach())
            d_loss_fake = criterion(output_fake, label_fake)
            
            d_loss = d_loss_real + d_loss_fake
            d_loss.backward()
            d_optimizer.step()
            
            # Train Generator
            g_optimizer.zero_grad()
            output_fake = discriminator(fake_images)
            g_loss = criterion(output_fake, label_real)
            
            g_loss.backward()
            g_optimizer.step()
            
        print(f'Epoch [{epoch+1}/{num_epochs}], d_loss: {d_loss.item():.4f}, g_loss: {g_loss.item():.4f}')
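
Putting the pieces together (a sketch; dataloader is an assumed DataLoader yielding 32x32 images scaled to [-1, 1] to match the generator's Tanh output):

latent_dim = 100
generator = Generator(latent_dim)
discriminator = Discriminator()

train_gan(generator, discriminator, dataloader, latent_dim, num_epochs=25)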

Flow-based Models

Normalizing Flows

class PlanarFlow(nn.Module):
    def __init__(self, dim):
        super(PlanarFlow, self).__init__()
        self.w = nn.Parameter(torch.randn(1, dim))
        self.u = nn.Parameter(torch.randn(1, dim))  # invertibility requires w·u >= -1; not enforced here
        self.b = nn.Parameter(torch.randn(1))
        
    def forward(self, z):
        # f(z) = z + u * h(w^T z + b)
        activation = torch.tanh(torch.mm(z, self.w.t()) + self.b)
        return z + self.u * activation
    
    def log_det_jacobian(self, z):
        activation = torch.tanh(torch.mm(z, self.w.t()) + self.b)
        psi = (1 - activation**2) * self.w
        det = 1 + torch.mm(psi, self.u.t())
        return torch.log(torch.abs(det) + 1e-6)

class NormalizingFlow(nn.Module):
    def __init__(self, dim, n_flows=4):
        super(NormalizingFlow, self).__init__()
        self.flows = nn.ModuleList([PlanarFlow(dim) for _ in range(n_flows)])
        
    def forward(self, z):
        log_det_sum = 0
        
        for flow in self.flows:
            log_det_sum += flow.log_det_jacobian(z)
            z = flow(z)
            
        return z, log_det_sum
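
The log_det_sum returned here is exactly what the change-of-variables formula requires: for z_K = f_K ∘ ... ∘ f_1(z_0) with base density q_0,

\log q_K(z_K) = \log q_0(z_0) - \sum_{k=1}^{K} \log \left| \det \frac{\partial f_k}{\partial z_{k-1}} \right|

so the log-density of a transformed sample is the base log-density of z_0 minus log_det_sum.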

Diffusion Models

Diffusion Process

class DiffusionModel(nn.Module):
    def __init__(self, n_steps=1000, beta_start=1e-4, beta_end=0.02):
        super().__init__()
        self.n_steps = n_steps
        # Register the schedule as buffers so it moves with the module across devices
        self.register_buffer('beta', torch.linspace(beta_start, beta_end, n_steps))
        self.register_buffer('alpha', 1 - self.beta)
        self.register_buffer('alpha_bar', torch.cumprod(self.alpha, dim=0))
        
    def forward_diffusion(self, x_0, t):
        # Sample noise
        epsilon = torch.randn_like(x_0)
        
        # Get alpha_bar for timestep t (scalar t assumed; for a batch of timesteps,
        # reshape with .view(-1, 1, 1, 1) so it broadcasts over image dimensions)
        alpha_bar_t = self.alpha_bar[t]
        
        # Forward process
        x_t = torch.sqrt(alpha_bar_t) * x_0 + \
              torch.sqrt(1 - alpha_bar_t) * epsilon
              
        return x_t, epsilon
    
    def reverse_diffusion(self, model, x_t, t):
        # Predict noise
        predicted_noise = model(x_t, t)
        
        # Get parameters for this timestep
        alpha_t = self.alpha[t]
        alpha_bar_t = self.alpha_bar[t]
        beta_t = self.beta[t]
        
        # Reverse process
        x_t_prev = (1 / torch.sqrt(alpha_t)) * \
                   (x_t - (beta_t / torch.sqrt(1 - alpha_bar_t)) * predicted_noise)
                   
        if t > 0:
            noise = torch.randn_like(x_t)
            x_t_prev += torch.sqrt(beta_t) * noise
            
        return x_t_prev
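
Training reduces to noise prediction (a sketch of one DDPM-style step; noise_model is a placeholder for any network taking (x_t, t), and the forward process is inlined so a batch of timesteps broadcasts over image dimensions, which the scalar-t method above does not handle):

def diffusion_train_step(diffusion, noise_model, optimizer, x_0):
    # One random timestep per example in the batch
    t = torch.randint(0, diffusion.n_steps, (x_0.size(0),))

    # Noise the clean samples: x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps
    alpha_bar_t = diffusion.alpha_bar[t].view(-1, 1, 1, 1)
    epsilon = torch.randn_like(x_0)
    x_t = torch.sqrt(alpha_bar_t) * x_0 + torch.sqrt(1 - alpha_bar_t) * epsilon

    # Regress the predicted noise onto the true noise
    loss = F.mse_loss(noise_model(x_t, t), epsilon)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()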

Evaluation Metrics

def calculate_inception_score(generated_images, model, eps=1e-8):
    # Per-image KL between the conditional p(y|x) and the marginal p(y)
    scores = []
    for images in generated_images:
        with torch.no_grad():
            preds = F.softmax(model(images), dim=1)
            p_y = preds.mean(dim=0)
            kl_div = preds * (torch.log(preds + eps) - torch.log(p_y + eps))
            scores.append(kl_div.sum(dim=1))
    
    return torch.exp(torch.mean(torch.cat(scores)))
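
This implements the standard definition

\mathrm{IS} = \exp\left( \mathbb{E}_{x \sim p_g} \left[ \mathrm{KL}\big( p(y \mid x) \,\|\, p(y) \big) \right] \right)

which is high only when individual predictions p(y|x) are confident and the marginal p(y) is spread across many classes.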

import numpy as np
from scipy.linalg import sqrtm

def calculate_fid_score(real_features, generated_features):
    # Per-set statistics: feature means and covariance matrices
    real = real_features.cpu().numpy()
    gen = generated_features.cpu().numpy()
    mu1, sigma1 = real.mean(axis=0), np.cov(real, rowvar=False)
    mu2, sigma2 = gen.mean(axis=0), np.cov(gen, rowvar=False)
    
    # FID = ||mu1 - mu2||^2 + Tr(sigma1 + sigma2 - 2 sqrt(sigma1 sigma2))
    diff = mu1 - mu2
    covmean = sqrtm(sigma1 @ sigma2)
    if np.iscomplexobj(covmean):
        covmean = covmean.real  # drop tiny imaginary parts from numerical error
    
    return diff @ diff + np.trace(sigma1) + np.trace(sigma2) - 2 * np.trace(covmean)

Best Practices

  1. Model Architecture
def initialize_weights(model):
    for m in model.modules():
        if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
            nn.init.normal_(m.weight.data, 0.0, 0.02)
        elif isinstance(m, nn.BatchNorm2d):
            nn.init.normal_(m.weight.data, 1.0, 0.02)
            nn.init.constant_(m.bias.data, 0)
  2. Training Stability
def gradient_penalty(discriminator, real_samples, fake_samples):
    # Random interpolation between real and fake samples
    alpha = torch.rand(real_samples.size(0), 1, 1, 1, device=real_samples.device)
    interpolates = alpha * real_samples + (1 - alpha) * fake_samples
    interpolates.requires_grad_(True)
    
    # Calculate gradients
    d_interpolates = discriminator(interpolates)
    gradients = torch.autograd.grad(
        outputs=d_interpolates,
        inputs=interpolates,
        grad_outputs=torch.ones_like(d_interpolates),
        create_graph=True,
        retain_graph=True
    )[0]
    
    # Calculate penalty
    gradients = gradients.view(gradients.size(0), -1)
    gradient_norm = gradients.norm(2, dim=1)
    penalty = ((gradient_norm - 1) ** 2).mean()
    
    return penalty
  3. Sample Generation
def generate_samples(model, n_samples, latent_dim):
    model.eval()
    with torch.no_grad():
        # Sample from the latent prior (for a conv generator, reshape to (n, latent_dim, 1, 1))
        z = torch.randn(n_samples, latent_dim)
        
        # Generate samples
        samples = model(z)
        
        # Post-process if needed
        samples = (samples + 1) / 2  # Scale from [-1,1] to [0,1]
        
    return samples
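
For example, combining these utilities with the DCGAN generator above (a sketch; note that the convolutional generator expects noise shaped (n, latent_dim, 1, 1) rather than the flat vector generate_samples draws):

latent_dim = 100
generator = Generator(latent_dim)
initialize_weights(generator)  # DCGAN-style N(0, 0.02) initialization

generator.eval()
with torch.no_grad():
    z = torch.randn(16, latent_dim, 1, 1)
    samples = (generator(z) + 1) / 2  # map Tanh output from [-1, 1] to [0, 1]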