Generative Models
Understanding and implementing various types of generative models
Generative models learn to generate new data samples that resemble their training data.
Variational Autoencoders (VAEs)
Basic VAE Architecture
import torch
import torch.nn as nn
import torch.nn.functional as F
class VAE(nn.Module):
    """Fully connected variational autoencoder.

    The encoder maps an input vector to a hidden representation from which
    two linear heads produce the mean and the log-variance of the latent
    distribution. The decoder maps a latent vector back to ``input_dim``
    values and ends in a sigmoid, so every output lies in ``[0, 1]``.

    Args:
        input_dim: Size of each (flattened) input vector.
        latent_dim: Size of the latent vector.
        hidden_dim: Width of the hidden layers in encoder and decoder.
    """

    def __init__(self, input_dim, latent_dim, hidden_dim=400):
        super().__init__()

        # Encoder trunk shared by both latent heads.
        encoder_layers = [
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
        ]
        self.encoder = nn.Sequential(*encoder_layers)

        # Latent heads. `fc_var` is consumed as a log-variance (see
        # `reparameterize`), despite its name.
        self.fc_mu = nn.Linear(hidden_dim, latent_dim)
        self.fc_var = nn.Linear(hidden_dim, latent_dim)

        # Decoder mirrors the encoder and squashes outputs with a sigmoid.
        decoder_layers = [
            nn.Linear(latent_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, input_dim),
            nn.Sigmoid(),
        ]
        self.decoder = nn.Sequential(*decoder_layers)

    def encode(self, x):
        """Return ``(mu, log_var)`` of the latent distribution for ``x``."""
        hidden = self.encoder(x)
        mu = self.fc_mu(hidden)
        log_var = self.fc_var(hidden)
        return mu, log_var

    def reparameterize(self, mu, log_var):
        """Draw ``z = mu + eps * std`` with ``eps`` sampled from N(0, I).

        Sampling the noise separately keeps ``z`` differentiable with
        respect to ``mu`` and ``log_var``.
        """
        sigma = (0.5 * log_var).exp()
        noise = torch.randn_like(sigma)
        return mu + noise * sigma

    def decode(self, z):
        """Map latent vectors ``z`` back to input space."""
        return self.decoder(z)

    def forward(self, x):
        """Return ``(reconstruction, mu, log_var)`` for the batch ``x``."""
        mu, log_var = self.encode(x)
        latent = self.reparameterize(mu, log_var)
        reconstruction = self.decode(latent)
        return reconstruction, mu, log_var
def vae_loss(recon_x, x, mu, log_var):
    """Return the summed VAE objective: reconstruction term plus KL term.

    Args:
        recon_x: Reconstructed batch with values in ``[0, 1]``.
        x: Target batch with the same shape as ``recon_x``.
        mu: Latent means.
        log_var: Latent log-variances.

    Returns:
        Scalar tensor: binary cross-entropy summed over all elements plus
        ``-0.5 * sum(1 + log_var - mu^2 - exp(log_var))``.
    """
    # Reconstruction term, summed (not averaged) over every element.
    reconstruction = F.binary_cross_entropy(recon_x, x, reduction='sum')

    # Closed-form KL divergence between N(mu, exp(log_var)) and N(0, I).
    kl_terms = 1 + log_var - mu.pow(2) - log_var.exp()
    kl_divergence = -0.5 * torch.sum(kl_terms)

    return reconstruction + kl_divergence
Training VAE
def train_vae(model, train_loader, optimizer, epochs=100):
    """Train ``model`` with ``vae_loss`` and print the per-sample loss each epoch.

    Args:
        model: Callable returning ``(reconstruction, mu, log_var)`` for a
            flattened batch.
        train_loader: Iterable of ``(data, target)`` batches that also
            exposes ``.dataset`` (its length normalises the epoch loss).
            The targets are ignored.
        optimizer: Optimizer already bound to ``model``'s parameters.
        epochs: Number of passes over ``train_loader``.
    """
    model.train()

    for epoch in range(epochs):
        running_loss = 0
        for data, _ in train_loader:
            # Flatten everything except the batch dimension.
            flat = data.view(data.size(0), -1)

            optimizer.zero_grad()
            recon_batch, mu, log_var = model(flat)
            loss = vae_loss(recon_batch, flat, mu, log_var)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        # The loss is summed per batch, so divide by the dataset size.
        avg_loss = running_loss / len(train_loader.dataset)
        print(f'Epoch [{epoch+1}/{epochs}], Loss: {avg_loss:.4f}')
Generative Adversarial Networks (GANs)
DCGAN Implementation
class Generator(nn.Module):
    """DCGAN-style generator built from transposed convolutions.

    Given a latent tensor of shape ``(N, latent_dim, 1, 1)``, the kernel,
    stride and padding settings grow the spatial size 1 -> 4 -> 8 -> 16 -> 32,
    and the final ``Tanh`` bounds the output to ``[-1, 1]``.

    Args:
        latent_dim: Number of channels of the latent input.
        channels: Number of channels of the generated image.
    """

    def __init__(self, latent_dim, channels=3):
        super().__init__()

        def upsample_block(in_ch, out_ch, stride, padding):
            # Transposed conv (no bias, BatchNorm follows) + BN + in-place ReLU.
            return [
                nn.ConvTranspose2d(in_ch, out_ch, 4, stride, padding, bias=False),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(True),
            ]

        layers = []
        layers += upsample_block(latent_dim, 512, 1, 0)  # -> 512 x 4 x 4
        layers += upsample_block(512, 256, 2, 1)         # -> 256 x 8 x 8
        layers += upsample_block(256, 128, 2, 1)         # -> 128 x 16 x 16
        # Output layer: no normalisation, Tanh squashes to [-1, 1].
        layers += [
            nn.ConvTranspose2d(128, channels, 4, 2, 1, bias=False),
            nn.Tanh(),
        ]
        self.main = nn.Sequential(*layers)

    def forward(self, x):
        """Generate images from the latent batch ``x``."""
        return self.main(x)
class Discriminator(nn.Module):
    """DCGAN-style discriminator built from strided convolutions.

    For a 32 x 32 input the spatial size shrinks 32 -> 16 -> 8 -> 4 -> 1, and
    the final ``Sigmoid`` yields one probability-like score per image.

    Args:
        channels: Number of channels of the input image.
    """

    def __init__(self, channels=3):
        super().__init__()

        # First stage has no BatchNorm.
        layers = [
            nn.Conv2d(channels, 64, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
        ]
        # Two downsampling stages: conv + BatchNorm + LeakyReLU.
        for in_ch, out_ch in ((64, 128), (128, 256)):
            layers.append(nn.Conv2d(in_ch, out_ch, 4, 2, 1, bias=False))
            layers.append(nn.BatchNorm2d(out_ch))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
        # Collapse the remaining 4 x 4 map to a single score.
        layers.append(nn.Conv2d(256, 1, 4, 1, 0, bias=False))
        layers.append(nn.Sigmoid())

        self.main = nn.Sequential(*layers)

    def forward(self, x):
        """Return a flat 1-D tensor of scores for the image batch ``x``."""
        scores = self.main(x)
        return scores.view(-1)
Training GAN
def _infer_latent_dim(generator):
    """Return the input width of the first conv/linear layer of ``generator``."""
    for module in generator.modules():
        for attr in ("in_channels", "in_features"):
            size = getattr(module, attr, None)
            if isinstance(size, int):
                return size
    raise ValueError(
        "latent_dim was not given and could not be inferred from the generator"
    )


def train_gan(generator, discriminator, dataloader, num_epochs=100, latent_dim=None):
    """Train a generator/discriminator pair with the standard BCE GAN loss.

    Each batch performs one discriminator step (real batch labelled 1,
    detached fake batch labelled 0) followed by one generator step (fake
    batch labelled 1). Both networks use Adam with ``lr=0.0002`` and
    ``betas=(0.5, 0.999)``.

    Args:
        generator: Module mapping noise of shape ``(N, latent_dim, 1, 1)``
            to images.
        discriminator: Module mapping images to a 1-D tensor of scores in
            ``[0, 1]`` (as required by ``BCELoss``).
        dataloader: Iterable of ``(real_images, target)`` batches. The
            targets are ignored.
        num_epochs: Number of passes over ``dataloader``.
        latent_dim: Channel count of the noise fed to ``generator``. When
            omitted it is taken from the ``in_channels``/``in_features`` of
            the generator's first layer that has one.

    Raises:
        ValueError: If ``latent_dim`` is omitted and cannot be inferred.
    """
    if latent_dim is None:
        latent_dim = _infer_latent_dim(generator)

    criterion = nn.BCELoss()
    g_optimizer = torch.optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
    d_optimizer = torch.optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))

    for epoch in range(num_epochs):
        d_loss = g_loss = None

        for real_images, _ in dataloader:
            batch_size = real_images.size(0)
            # Create labels and noise next to the data so the loop also works
            # when the batches live on an accelerator.
            device = real_images.device
            label_real = torch.ones(batch_size, device=device)
            label_fake = torch.zeros(batch_size, device=device)

            # ---- Discriminator step ----
            d_optimizer.zero_grad()
            output_real = discriminator(real_images)
            d_loss_real = criterion(output_real, label_real)

            noise = torch.randn(batch_size, latent_dim, 1, 1, device=device)
            fake_images = generator(noise)
            # Detach so this step does not back-propagate into the generator.
            output_fake = discriminator(fake_images.detach())
            d_loss_fake = criterion(output_fake, label_fake)

            d_loss = d_loss_real + d_loss_fake
            d_loss.backward()
            d_optimizer.step()

            # ---- Generator step ----
            g_optimizer.zero_grad()
            # Re-score the fakes with the updated discriminator; the generator
            # is rewarded when they are classified as real.
            output_fake = discriminator(fake_images)
            g_loss = criterion(output_fake, label_real)
            g_loss.backward()
            g_optimizer.step()

        # Report the last batch's losses once per epoch; skip the report when
        # the dataloader produced no batches (the losses would be unbound).
        if d_loss is not None:
            print(f'Epoch [{epoch+1}/{num_epochs}], d_loss: {d_loss.item():.4f}, g_loss: {g_loss.item():.4f}')
Flow-based Models
Normalizing Flows
class PlanarFlow(nn.Module):
    """Single planar flow layer: ``f(z) = z + u * tanh(w^T z + b)``.

    Args:
        dim: Dimensionality of the vectors being transformed.
    """

    def __init__(self, dim):
        super().__init__()
        # All parameters start from standard-normal draws.
        self.w = nn.Parameter(torch.randn(1, dim))
        self.u = nn.Parameter(torch.randn(1, dim))
        self.b = nn.Parameter(torch.randn(1))

    def _activation(self, z):
        # tanh(w^T z + b) for a 2-D batch z; result has one column.
        pre_activation = z.mm(self.w.t()) + self.b
        return torch.tanh(pre_activation)

    def forward(self, z):
        """Apply the flow to the 2-D batch ``z`` and return the result."""
        return z + self.u * self._activation(z)

    def log_det_jacobian(self, z):
        """Return ``log(|1 + psi(z) u^T| + 1e-6)`` as a column, one row per sample.

        ``psi(z) = (1 - tanh^2(w^T z + b)) * w``; the small constant keeps
        the logarithm finite when the determinant is zero.
        """
        hidden = self._activation(z)
        psi = (1 - hidden ** 2) * self.w
        determinant = 1 + psi.mm(self.u.t())
        return (determinant.abs() + 1e-6).log()
class NormalizingFlow(nn.Module):
    """Chain of ``PlanarFlow`` layers applied one after another.

    Args:
        dim: Dimensionality of the vectors being transformed.
        n_flows: Number of planar layers in the chain.
    """

    def __init__(self, dim, n_flows=4):
        super().__init__()
        self.flows = nn.ModuleList(PlanarFlow(dim) for _ in range(n_flows))

    def forward(self, z):
        """Push ``z`` through every layer.

        Returns:
            Tuple ``(z_out, log_det_sum)``, where ``log_det_sum`` is the sum
            of each layer's log-determinant evaluated at that layer's input
            (the integer ``0`` when the chain is empty).
        """
        log_dets = []
        for layer in self.flows:
            # The log-determinant must be taken before z is transformed.
            log_dets.append(layer.log_det_jacobian(z))
            z = layer(z)
        return z, sum(log_dets)
Diffusion Models
Diffusion Process
class DiffusionModel(nn.Module):
    """Forward/reverse diffusion steps with a linear beta schedule.

    The schedule tensors ``beta``, ``alpha = 1 - beta`` and
    ``alpha_bar = cumprod(alpha)`` are registered as non-persistent buffers:
    they follow the module across ``.to(device)`` calls but are not written
    to the ``state_dict``.

    Args:
        n_steps: Number of diffusion timesteps.
        beta_start: First value of the linear beta schedule.
        beta_end: Last value of the linear beta schedule.
    """

    def __init__(self, n_steps=1000, beta_start=1e-4, beta_end=0.02):
        super().__init__()
        self.n_steps = n_steps
        beta = torch.linspace(beta_start, beta_end, n_steps)
        alpha = 1 - beta
        self.register_buffer("beta", beta, persistent=False)
        self.register_buffer("alpha", alpha, persistent=False)
        self.register_buffer("alpha_bar", torch.cumprod(alpha, dim=0), persistent=False)

    @staticmethod
    def _extract(schedule, t, like):
        """Look up ``schedule[t]`` and shape it to broadcast against ``like``.

        ``t`` may be a Python int, a 0-dim tensor, or a 1-D tensor holding
        one timestep per sample. Per-sample values are reshaped to
        ``(N, 1, ..., 1)`` so they broadcast over the non-batch dimensions.
        """
        index = torch.as_tensor(t, dtype=torch.long, device=schedule.device)
        values = schedule[index].to(like.device)
        if values.dim() > 0:
            values = values.reshape(-1, *([1] * (like.dim() - 1)))
        return values

    def forward_diffusion(self, x_0, t):
        """Sample ``x_t`` from ``x_0`` in closed form.

        Args:
            x_0: Clean batch.
            t: Timestep index, either a scalar or one index per sample.

        Returns:
            Tuple ``(x_t, epsilon)`` where ``epsilon`` is the noise that was
            mixed in, i.e. ``x_t = sqrt(alpha_bar_t) * x_0 +
            sqrt(1 - alpha_bar_t) * epsilon``.
        """
        epsilon = torch.randn_like(x_0)
        alpha_bar_t = self._extract(self.alpha_bar, t, x_0)

        # Coefficients are computed in the schedule's precision, then cast so
        # the result keeps the dtype of x_0.
        signal_scale = torch.sqrt(alpha_bar_t).to(x_0.dtype)
        noise_scale = torch.sqrt(1 - alpha_bar_t).to(x_0.dtype)

        x_t = signal_scale * x_0 + noise_scale * epsilon
        return x_t, epsilon

    def reverse_diffusion(self, model, x_t, t):
        """Take one denoising step from ``x_t`` towards ``x_{t-1}``.

        Args:
            model: Callable ``model(x_t, t)`` returning the predicted noise.
            x_t: Noisy batch at timestep ``t``.
            t: Timestep index, either a scalar or one index per sample.

        Returns:
            The estimate of ``x_{t-1}``. Fresh Gaussian noise scaled by
            ``sqrt(beta_t)`` is added only for samples whose timestep is
            greater than zero.
        """
        predicted_noise = model(x_t, t)

        alpha_t = self._extract(self.alpha, t, x_t)
        alpha_bar_t = self._extract(self.alpha_bar, t, x_t)
        beta_t = self._extract(self.beta, t, x_t)

        noise_coeff = (beta_t / torch.sqrt(1 - alpha_bar_t)).to(x_t.dtype)
        inv_sqrt_alpha = (1 / torch.sqrt(alpha_t)).to(x_t.dtype)
        x_t_prev = inv_sqrt_alpha * (x_t - noise_coeff * predicted_noise)

        # Mask instead of `if t > 0` so a batch may mix final and non-final
        # steps; no noise is drawn at all when every timestep is zero.
        is_not_last = torch.as_tensor(t, device=x_t.device) > 0
        if is_not_last.dim() > 0:
            is_not_last = is_not_last.reshape(-1, *([1] * (x_t.dim() - 1)))
        if bool(is_not_last.any()):
            noise = torch.randn_like(x_t)
            sigma = torch.sqrt(beta_t).to(x_t.dtype)
            x_t_prev = x_t_prev + is_not_last.to(x_t.dtype) * sigma * noise

        return x_t_prev
Evaluation Metrics
def calculate_inception_score(generated_images, model):
    """Compute an Inception-Score-style metric over batches of images.

    For every batch, the class probabilities ``p(y|x)`` are obtained from
    ``softmax(model(images))`` and the marginal ``p(y)`` is their mean over
    that batch. The score is ``exp`` of the mean per-sample KL divergence
    ``KL(p(y|x) || p(y))`` taken over all samples of all batches.

    Args:
        generated_images: Iterable of image batches. Batches may have
            different sizes.
        model: Callable returning class logits of shape ``(N, classes)``.

    Returns:
        Scalar tensor with the score.
    """
    per_sample_kl = []
    for images in generated_images:
        with torch.no_grad():
            # log_softmax stays finite where softmax underflows to zero.
            log_preds = F.log_softmax(model(images), dim=1)
            preds = log_preds.exp()
            p_y = preds.mean(dim=0)
            # Terms with p(y|x) == 0 contribute exactly 0 to the divergence;
            # masking them avoids the 0 * inf = nan of the naive formula.
            kl_div = torch.where(
                preds > 0,
                preds * (log_preds - torch.log(p_y)),
                torch.zeros_like(preds),
            )
        per_sample_kl.append(kl_div.sum(dim=1))
    # cat (rather than stack) tolerates a smaller final batch.
    return torch.exp(torch.mean(torch.cat(per_sample_kl)))
def calculate_fid_score(real_features, generated_features):
    """Compute the Frechet distance between two sets of feature vectors.

    Evaluates ``||mu1 - mu2||^2 + Tr(sigma1) + Tr(sigma2) - 2 Tr(covmean)``
    with ``covmean = sqrtm(sigma1 @ sigma2)``.

    NOTE(review): ``torch_cov`` and ``sqrtm`` are not defined in this block;
    presumably they return a covariance matrix and a matrix square root as
    torch tensors -- confirm against their definitions.

    Args:
        real_features: 2-D tensor, one feature row per real sample.
        generated_features: 2-D tensor, one feature row per generated sample.

    Returns:
        Scalar tensor with the distance.
    """
    # First- and second-order statistics of each feature set.
    mu1 = real_features.mean(dim=0)
    sigma1 = torch_cov(real_features)
    mu2 = generated_features.mean(dim=0)
    sigma2 = torch_cov(generated_features)

    # Squared distance between the means.
    diff = mu1 - mu2
    mean_term = torch.dot(diff, diff)

    # Square root of the covariance product couples the two distributions.
    covmean = sqrtm(torch.mm(sigma1, sigma2))

    return (mean_term + torch.trace(sigma1) +
            torch.trace(sigma2) - 2 * torch.trace(covmean))
Best Practices
- Model Architecture
def initialize_weights(model):
    """Re-initialise conv and batch-norm layers of ``model`` in place.

    ``Conv2d`` and ``ConvTranspose2d`` weights are drawn from N(0, 0.02);
    ``BatchNorm2d`` weights are drawn from N(1, 0.02) and their biases set
    to zero. All other modules are left untouched.

    Args:
        model: Module whose sub-modules are visited via ``model.modules()``.
    """
    conv_types = (nn.Conv2d, nn.ConvTranspose2d)
    for layer in model.modules():
        if isinstance(layer, conv_types):
            nn.init.normal_(layer.weight.data, 0.0, 0.02)
            continue
        if isinstance(layer, nn.BatchNorm2d):
            nn.init.normal_(layer.weight.data, 1.0, 0.02)
            nn.init.constant_(layer.bias.data, 0)
- Training Stability
def gradient_penalty(discriminator, real_samples, fake_samples):
    """Penalise discriminator gradient norms that differ from 1.

    Each real sample is mixed with its fake counterpart using a random
    per-sample weight, and the penalty is the mean of
    ``(||d D(x_hat) / d x_hat||_2 - 1)^2`` over the batch.

    Args:
        discriminator: Callable mapping a batch to one score per sample.
        real_samples: Batch of real samples; any rank >= 1 with the batch
            dimension first.
        fake_samples: Batch of generated samples, same shape as
            ``real_samples``.

    Returns:
        Scalar tensor. It is built with ``create_graph=True`` so it can be
        back-propagated as part of the discriminator loss.
    """
    batch_size = real_samples.size(0)

    # One mixing weight per sample, shaped (N, 1, ..., 1) to broadcast over
    # every non-batch dimension, and created next to the data.
    alpha_shape = (batch_size,) + (1,) * (real_samples.dim() - 1)
    alpha_dtype = (
        real_samples.dtype
        if real_samples.is_floating_point()
        else torch.get_default_dtype()
    )
    alpha = torch.rand(alpha_shape, device=real_samples.device, dtype=alpha_dtype)

    interpolates = alpha * real_samples + (1 - alpha) * fake_samples
    interpolates.requires_grad_(True)

    # Gradient of the scores with respect to the interpolated inputs.
    d_interpolates = discriminator(interpolates)
    gradients = torch.autograd.grad(
        outputs=d_interpolates,
        inputs=interpolates,
        grad_outputs=torch.ones_like(d_interpolates),
        create_graph=True,
        retain_graph=True,
    )[0]

    # reshape (not view) because autograd may hand back non-contiguous grads.
    gradients = gradients.reshape(batch_size, -1)
    gradient_norm = gradients.norm(2, dim=1)
    return ((gradient_norm - 1) ** 2).mean()
- Sample Generation
def generate_samples(model, n_samples, latent_dim):
    """Draw ``n_samples`` outputs from ``model`` using standard-normal latents.

    The model is switched to eval mode (and left there), the latent batch is
    created on the device and floating dtype of the model's first parameter
    when it has one, and the output is rescaled with ``(x + 1) / 2``, which
    maps ``[-1, 1]`` to ``[0, 1]``.

    Args:
        model: Module mapping a ``(n_samples, latent_dim)`` latent batch to
            samples.
        n_samples: Number of latent vectors to draw.
        latent_dim: Size of each latent vector.

    Returns:
        Tensor of rescaled samples, detached from the autograd graph.
    """
    model.eval()

    # Match the latent batch to the model so generation also works when the
    # model lives on an accelerator or uses a non-default float dtype.
    factory_kwargs = {}
    parameters = getattr(model, "parameters", None)
    first_param = next(iter(parameters()), None) if callable(parameters) else None
    if first_param is not None:
        factory_kwargs["device"] = first_param.device
        if first_param.is_floating_point():
            factory_kwargs["dtype"] = first_param.dtype

    with torch.no_grad():
        # Sample from the latent prior.
        z = torch.randn(n_samples, latent_dim, **factory_kwargs)
        samples = model(z)
        # Scale from [-1, 1] to [0, 1].
        samples = (samples + 1) / 2
    return samples