PyTorch Generative Adversarial Networks
Introduction to GAN
The Generative Adversarial Network (GAN) is a deep learning architecture proposed by Ian Goodfellow in 2014. A GAN trains two neural networks against each other — a generator and a discriminator — so that the generator learns to produce realistic data.
python
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
from torch.utils.data import DataLoader

Basic GAN Implementation
1. Generator Network
python
class Generator(nn.Module):
    """GAN generator: maps a latent noise vector to an image in [-1, 1].

    The noise is projected by a linear layer to a 128-channel feature map of
    spatial size ``img_size // 4``, then upsampled twice (x2 each) through
    conv blocks, ending in Tanh.

    Args:
        latent_dim: dimensionality of the input noise vector.
        img_channels: number of channels of the generated image.
        img_size: output image height/width; must be divisible by 4 so that
            the two x2 upsamples restore it exactly.
    """

    def __init__(self, latent_dim=100, img_channels=1, img_size=28):
        super(Generator, self).__init__()
        self.latent_dim = latent_dim
        self.img_channels = img_channels
        self.img_size = img_size
        # Spatial size of the initial feature map (two x2 upsamples follow)
        self.init_size = img_size // 4
        self.l1 = nn.Sequential(
            nn.Linear(latent_dim, 128 * self.init_size ** 2)
        )
        # Upsampling conv stack
        self.conv_blocks = nn.Sequential(
            nn.BatchNorm2d(128),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 128, 3, stride=1, padding=1),
            # BUG FIX: 0.8 was passed positionally, i.e. as `eps`; it was
            # clearly intended as the BatchNorm `momentum` (eps=0.8 would
            # destroy normalization).
            nn.BatchNorm2d(128, momentum=0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 64, 3, stride=1, padding=1),
            nn.BatchNorm2d(64, momentum=0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, img_channels, 3, stride=1, padding=1),
            nn.Tanh()
        )

    def forward(self, z):
        """Generate images from noise ``z`` of shape (N, latent_dim).

        Returns:
            Tensor of shape (N, img_channels, img_size, img_size) in [-1, 1].
        """
        out = self.l1(z)
        # Reshape the flat projection to (N, 128, init_size, init_size)
        out = out.view(out.shape[0], 128, self.init_size, self.init_size)
        img = self.conv_blocks(out)
        return img
# Smoke-test the generator with a small batch of random latent vectors.
latent_dim = 100
generator = Generator(latent_dim=latent_dim, img_channels=1, img_size=28)
z = torch.randn(4, latent_dim)  # 4 samples of standard-normal noise
fake_imgs = generator(z)  # expected shape: (4, 1, 28, 28)
print(f"Generated image shape: {fake_imgs.shape}")

2. Discriminator Network
python
class Discriminator(nn.Module):
    """GAN discriminator: maps an image to a real/fake probability in (0, 1).

    Four stride-2 conv blocks downsample the input, then a linear layer with
    Sigmoid produces a single validity score per image.

    Args:
        img_channels: number of channels of the input image.
        img_size: input image height/width.
    """

    def __init__(self, img_channels=1, img_size=28):
        super(Discriminator, self).__init__()

        def discriminator_block(in_filters, out_filters, bn=True):
            """Conv (stride 2) -> optional BatchNorm -> LeakyReLU -> Dropout."""
            block = [nn.Conv2d(in_filters, out_filters, 3, 2, 1)]
            if bn:
                # BUG FIX: 0.8 was passed positionally, i.e. as `eps`;
                # it was intended as the BatchNorm `momentum`.
                block.append(nn.BatchNorm2d(out_filters, momentum=0.8))
            block.extend([nn.LeakyReLU(0.2, inplace=True), nn.Dropout2d(0.25)])
            return block

        self.model = nn.Sequential(
            *discriminator_block(img_channels, 16, bn=False),
            *discriminator_block(16, 32),
            *discriminator_block(32, 64),
            *discriminator_block(64, 128),
        )
        # BUG FIX: `img_size // 2 ** 4` underestimates the conv output size
        # when img_size is not a multiple of 16 (e.g. 28 -> actual 2x2, not
        # 1x1), which crashed the linear layer with a shape mismatch.
        # Each stride-2 conv with kernel 3 / padding 1 maps n -> (n + 1) // 2.
        ds_size = img_size
        for _ in range(4):
            ds_size = (ds_size + 1) // 2
        self.adv_layer = nn.Sequential(
            nn.Linear(128 * ds_size ** 2, 1),
            nn.Sigmoid()
        )

    def forward(self, img):
        """Return validity scores of shape (N, 1) for a batch of images."""
        out = self.model(img)
        # Flatten the conv features before the classifier head
        out = out.view(out.shape[0], -1)
        validity = self.adv_layer(out)
        return validity
# Smoke-test the discriminator on the generator's fake images from above.
discriminator = Discriminator(img_channels=1, img_size=28)
validity = discriminator(fake_imgs)  # expected shape: (4, 1)
print(f"Discriminator output shape: {validity.shape}")

DCGAN Implementation
1. DCGAN Generator
python
class DCGANGenerator(nn.Module):
    """DCGAN generator: maps latent input shaped (N, latent_dim, 1, 1) to a
    (N, img_channels, 64, 64) image in [-1, 1] via transposed convolutions.
    """

    def __init__(self, latent_dim=100, img_channels=3, feature_maps=64):
        super(DCGANGenerator, self).__init__()

        def up_block(in_ch, out_ch, stride, padding):
            # ConvTranspose2d -> BatchNorm -> ReLU, the standard DCGAN trio.
            return [
                nn.ConvTranspose2d(in_ch, out_ch, 4, stride, padding, bias=False),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(True),
            ]

        self.main = nn.Sequential(
            # latent (N, latent_dim, 1, 1) -> (feature_maps*8) x 4 x 4
            *up_block(latent_dim, feature_maps * 8, 1, 0),
            # -> (feature_maps*4) x 8 x 8
            *up_block(feature_maps * 8, feature_maps * 4, 2, 1),
            # -> (feature_maps*2) x 16 x 16
            *up_block(feature_maps * 4, feature_maps * 2, 2, 1),
            # -> (feature_maps) x 32 x 32
            *up_block(feature_maps * 2, feature_maps, 2, 1),
            # -> (img_channels) x 64 x 64, squashed to [-1, 1]
            nn.ConvTranspose2d(feature_maps, img_channels, 4, 2, 1, bias=False),
            nn.Tanh()
        )
def forward(self, input):
"""Map latent input of shape (N, latent_dim, 1, 1) through `main` to a (N, img_channels, 64, 64) image."""
return self.main(input)

WGAN Implementation
1. WGAN Loss Function
python
class WGAN:
    """Wasserstein GAN training wrapper (weight-clipping variant).

    Args:
        generator: module mapping noise (N, latent_dim, 1, 1) to samples.
        discriminator: critic module mapping samples to scalar scores.
        lr: RMSprop learning rate for both networks.
        clip_value: critic weights are clamped to [-clip_value, clip_value]
            after every update (Lipschitz constraint).
        latent_dim: noise dimensionality; GENERALIZED from a hard-coded 100
            to a keyword argument with the same default (backward compatible).
    """

    def __init__(self, generator, discriminator, lr=0.00005, clip_value=0.01,
                 latent_dim=100):
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.generator = generator.to(self.device)
        self.discriminator = discriminator.to(self.device)
        # The WGAN paper recommends RMSprop; momentum-based optimizers
        # destabilize critic training.
        self.optimizer_G = optim.RMSprop(self.generator.parameters(), lr=lr)
        self.optimizer_D = optim.RMSprop(self.discriminator.parameters(), lr=lr)
        self.clip_value = clip_value
        self.latent_dim = latent_dim

    def train_discriminator(self, real_imgs, n_critic=5):
        """Train the critic for `n_critic` steps; return the mean critic loss.

        Loss per step: -E[D(real)] + E[D(fake)]. After each optimizer step
        every critic weight is clamped to [-clip_value, clip_value].
        """
        d_losses = []
        for _ in range(n_critic):
            self.optimizer_D.zero_grad()
            batch_size = real_imgs.size(0)
            real_validity = self.discriminator(real_imgs)
            # Fresh noise each critic step; detach so no generator grads flow
            z = torch.randn(batch_size, self.latent_dim, 1, 1, device=self.device)
            fake_imgs = self.generator(z).detach()
            fake_validity = self.discriminator(fake_imgs)
            # Wasserstein critic loss
            d_loss = -torch.mean(real_validity) + torch.mean(fake_validity)
            d_loss.backward()
            self.optimizer_D.step()
            # Weight clipping enforces the Lipschitz constraint
            for p in self.discriminator.parameters():
                p.data.clamp_(-self.clip_value, self.clip_value)
            d_losses.append(d_loss.item())
        return np.mean(d_losses)
def train_generator(self, batch_size):
"""Train the generator for one step.

Minimizes g_loss = -mean(D(G(z))), i.e. pushes the critic score on
generated samples up. Returns (g_loss_value, fake_imgs).
"""
self.optimizer_G.zero_grad()
# Sample latent noise shaped (batch_size, latent_dim, 1, 1)
z = torch.randn(batch_size, self.latent_dim, 1, 1, device=self.device)
fake_imgs = self.generator(z)
# WGAN generator loss: negate the critic's mean score on fakes
fake_validity = self.discriminator(fake_imgs)
g_loss = -torch.mean(fake_validity)
g_loss.backward()
self.optimizer_G.step()
return g_loss.item(), fake_imgs

Conditional GAN (cGAN)
1. Conditional Generator
python
class ConditionalGenerator(nn.Module):
    """Conditional GAN generator (MLP): generates an image for a class label.

    The label is embedded into a ``num_classes``-dim vector and concatenated
    with the noise before the fully-connected stack; the final Tanh keeps
    outputs in [-1, 1].

    Args:
        latent_dim: noise dimensionality.
        num_classes: number of class labels (embedding size equals this).
        img_channels: channels of the generated image.
        img_size: generated image height/width.
    """

    def __init__(self, latent_dim=100, num_classes=10, img_channels=1, img_size=28):
        super(ConditionalGenerator, self).__init__()
        # One num_classes-dim embedding vector per class label
        self.label_emb = nn.Embedding(num_classes, num_classes)

        def block(in_feat, out_feat, normalize=True):
            """Linear -> optional BatchNorm1d -> LeakyReLU."""
            layers = [nn.Linear(in_feat, out_feat)]
            if normalize:
                # BUG FIX: 0.8 was passed positionally, i.e. as `eps`;
                # it was intended as the BatchNorm `momentum`.
                layers.append(nn.BatchNorm1d(out_feat, momentum=0.8))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        self.model = nn.Sequential(
            *block(latent_dim + num_classes, 128, normalize=False),
            *block(128, 256),
            *block(256, 512),
            *block(512, 1024),
            nn.Linear(1024, img_channels * img_size * img_size),
            nn.Tanh()
        )
        self.img_size = img_size
        self.img_channels = img_channels
def forward(self, noise, labels):
"""Generate images for `noise` (N, latent_dim) and integer `labels` (N,)."""
# Condition generation by concatenating the label embedding with the noise
gen_input = torch.cat((self.label_emb(labels), noise), -1)
img = self.model(gen_input)
# Reshape the flat MLP output to an image tensor (N, C, H, W)
img = img.view(img.size(0), self.img_channels, self.img_size, self.img_size)
return img

Training Techniques and Optimization
1. Gradient Penalty
python
def gradient_penalty(self, real_imgs, fake_imgs):
"""Compute the WGAN-GP gradient penalty.

Penalizes the critic when the gradient norm at random interpolates
between real and fake samples deviates from 1 (soft Lipschitz
constraint). Assumes `self.discriminator` and `self.device` exist —
this snippet is written as a method of a WGAN-GP trainer class.
"""
batch_size = real_imgs.size(0)
# One random mixing coefficient per sample, broadcast over C/H/W
alpha = torch.rand(batch_size, 1, 1, 1, device=self.device)
interpolates = alpha * real_imgs + (1 - alpha) * fake_imgs
interpolates.requires_grad_(True)
# Critic score at the interpolated points
d_interpolates = self.discriminator(interpolates)
# Gradients of the critic output w.r.t. the interpolated inputs;
# create_graph=True keeps the penalty itself differentiable
gradients = torch.autograd.grad(
outputs=d_interpolates,
inputs=interpolates,
grad_outputs=torch.ones_like(d_interpolates),
create_graph=True,
retain_graph=True,
only_inputs=True
)[0]
# Penalty: mean over the batch of (||grad||_2 - 1)^2
gradients = gradients.view(batch_size, -1)
gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()
return gradient_penalty

Evaluation Metrics
1. FID Score
python
import scipy.linalg
from torchvision.models import inception_v3
class FIDCalculator:
    """Fréchet Inception Distance (FID) between two batches of images.

    FID fits a Gaussian to Inception features of real and generated images
    and returns ||mu_r - mu_f||^2 + Tr(S_r + S_f - 2*sqrt(S_r S_f)).
    Lower is better; 0 means identical feature distributions.
    """

    def __init__(self, device):
        self.device = device
        # NOTE(review): this uses the network's final 1000-d logits as
        # features; canonical FID uses the 2048-d pool3 activations —
        # confirm which is intended for your comparisons.
        self.inception = inception_v3(pretrained=True, transform_input=False).to(device)
        self.inception.eval()

    def get_activations(self, images):
        """Run `images` through Inception and return features as a numpy array."""
        # BUG FIX: `F` was used without ever being imported in this file,
        # raising NameError at runtime; import torch.nn.functional locally.
        import torch.nn.functional as F
        with torch.no_grad():
            # Inception-v3 expects 299x299 inputs
            if images.size(-1) != 299:
                images = F.interpolate(images, size=299, mode='bilinear', align_corners=False)
            features = self.inception(images)
        return features.cpu().numpy()

    def calculate_fid(self, real_images, fake_images):
        """Compute the FID score between `real_images` and `fake_images`."""
        real_features = self.get_activations(real_images)
        fake_features = self.get_activations(fake_images)
        # Gaussian statistics of each feature set
        mu_real = np.mean(real_features, axis=0)
        sigma_real = np.cov(real_features, rowvar=False)
        mu_fake = np.mean(fake_features, axis=0)
        sigma_fake = np.cov(fake_features, rowvar=False)
        # Fréchet distance between the two Gaussians
        diff = mu_real - mu_fake
        covmean, _ = scipy.linalg.sqrtm(sigma_real.dot(sigma_fake), disp=False)
        if np.iscomplexobj(covmean):
            # sqrtm can return tiny imaginary parts from numerical error
            covmean = covmean.real
        fid = diff.dot(diff) + np.trace(sigma_real + sigma_fake - 2 * covmean)
        return fid
# Usage example — assumes `device`, `real_images`, and `fake_images`
# were defined earlier (not shown in this snippet).
fid_calculator = FIDCalculator(device)
# NOTE(review): downloads pretrained Inception weights on first use
fid_score = fid_calculator.calculate_fid(real_images, fake_images)
print(f"FID score: {fid_score:.2f}")

Summary
Generative Adversarial Networks are important techniques in deep learning. This chapter introduced:
- Basic GAN: Basic structures and training process of generator and discriminator
- DCGAN: Implementation of deep convolutional GAN
- WGAN: Wasserstein GAN and its improved versions
- Conditional GAN: Conditional generation models
- Training Techniques: Progressive training, spectral normalization, self-attention, etc.
- Evaluation Metrics: FID, IS and other generation quality evaluation methods
Mastering GAN techniques will help you innovate in image generation, data augmentation, and other fields!