PyTorch Generative Adversarial Networks

Introduction to GAN

Generative Adversarial Networks (GANs) are a deep learning architecture proposed by Ian Goodfellow in 2014. A GAN trains two neural networks against each other — a generator and a discriminator — so that the generator learns to produce realistic data.

python
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
from torch.utils.data import DataLoader

Basic GAN Implementation

1. Generator Network

python
class Generator(nn.Module):
    """GAN generator: maps a latent noise vector to a (img_channels, img_size, img_size) image.

    The noise is projected to a small feature map and then upsampled twice
    (4x in total) back to the target resolution, so img_size is expected to
    be divisible by 4.
    """

    def __init__(self, latent_dim=100, img_channels=1, img_size=28):
        super(Generator, self).__init__()
        self.latent_dim = latent_dim
        self.img_channels = img_channels
        self.img_size = img_size

        # Spatial size of the initial feature map (it is upsampled 2x twice below).
        self.init_size = img_size // 4
        self.l1 = nn.Sequential(
            nn.Linear(latent_dim, 128 * self.init_size ** 2)
        )

        # Upsampling layers.
        # FIX: BatchNorm2d's second positional argument is `eps`, so the
        # original `nn.BatchNorm2d(128, 0.8)` set eps=0.8 (a huge epsilon
        # that flattens the normalization). The intended hyperparameter is
        # momentum=0.8, passed by keyword here.
        self.conv_blocks = nn.Sequential(
            nn.BatchNorm2d(128),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 128, 3, stride=1, padding=1),
            nn.BatchNorm2d(128, momentum=0.8),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 64, 3, stride=1, padding=1),
            nn.BatchNorm2d(64, momentum=0.8),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(64, img_channels, 3, stride=1, padding=1),
            nn.Tanh()  # output pixels in [-1, 1], matching images normalized to that range
        )

    def forward(self, z):
        """Generate a batch of images from noise `z` of shape (batch, latent_dim)."""
        # Project the noise and reshape it into a 128-channel feature map.
        out = self.l1(z)
        out = out.view(out.shape[0], 128, self.init_size, self.init_size)

        # Upsample through the convolutional stack to produce the final image.
        img = self.conv_blocks(out)

        return img

# Test generator
# Smoke test: build a generator for 28x28 single-channel images and push a
# small batch of noise through it to confirm the output shape.
latent_dim = 100
generator = Generator(latent_dim=latent_dim, img_channels=1, img_size=28)

# Generate random noise
z = torch.randn(4, latent_dim)  # batch of 4 latent vectors
fake_imgs = generator(z)  # expected shape: (4, 1, 28, 28)
print(f"Generated image shape: {fake_imgs.shape}")

2. Discriminator Network

python
class Discriminator(nn.Module):
    """GAN discriminator: maps an image to a real/fake probability in (0, 1).

    Four stride-2 conv blocks shrink the spatial size, then a linear layer
    plus sigmoid produces a single validity score per image.
    """

    def __init__(self, img_channels=1, img_size=28):
        super(Discriminator, self).__init__()

        def discriminator_block(in_filters, out_filters, bn=True):
            """Conv(stride 2) [+ BatchNorm] + LeakyReLU + Dropout building block."""
            block = [nn.Conv2d(in_filters, out_filters, 3, 2, 1)]
            if bn:
                # FIX: BatchNorm2d's second positional argument is `eps`, so
                # the original `nn.BatchNorm2d(out_filters, 0.8)` set eps=0.8
                # by mistake; the intended hyperparameter is momentum=0.8.
                block.append(nn.BatchNorm2d(out_filters, momentum=0.8))
            block.extend([nn.LeakyReLU(0.2, inplace=True), nn.Dropout2d(0.25)])
            return block

        self.model = nn.Sequential(
            *discriminator_block(img_channels, 16, bn=False),
            *discriminator_block(16, 32),
            *discriminator_block(32, 64),
            *discriminator_block(64, 128),
        )

        # FIX: each Conv2d(k=3, s=2, p=1) maps size n -> ceil(n / 2), so after
        # four blocks the spatial size is ceil(img_size / 16). The original
        # floor division (img_size // 2 ** 4) computed 1 for img_size=28 while
        # the convs actually produce 2x2, making the Linear layer crash at
        # runtime. Ceil division (via negated floor division) fixes this and
        # is identical for sizes divisible by 16 (e.g. 32 or 64).
        ds_size = -(-img_size // 2 ** 4)
        self.adv_layer = nn.Sequential(
            nn.Linear(128 * ds_size ** 2, 1),
            nn.Sigmoid()  # probability that the input image is real
        )

    def forward(self, img):
        """Return a (batch, 1) validity score for each image in `img`."""
        out = self.model(img)
        out = out.view(out.shape[0], -1)
        validity = self.adv_layer(out)

        return validity

# Test discriminator
# Smoke test: score the generator's fake images; each image gets one
# validity value (a probability after the sigmoid).
discriminator = Discriminator(img_channels=1, img_size=28)
validity = discriminator(fake_imgs)  # `fake_imgs` comes from the generator test above
print(f"Discriminator output shape: {validity.shape}")

DCGAN Implementation

1. DCGAN Generator

python
class DCGANGenerator(nn.Module):
    """DCGAN generator: turns a (latent_dim, 1, 1) noise tensor into a 64x64 image.

    A stack of transposed convolutions doubles the spatial size at every
    stage (4 -> 8 -> 16 -> 32 -> 64) while shrinking the channel count,
    finishing with Tanh so pixel values land in [-1, 1].
    """

    def __init__(self, latent_dim=100, img_channels=3, feature_maps=64):
        super(DCGANGenerator, self).__init__()

        layers = []
        # Project the noise vector to a (feature_maps*8) x 4 x 4 feature map.
        layers += [
            nn.ConvTranspose2d(latent_dim, feature_maps * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(feature_maps * 8),
            nn.ReLU(True),
        ]
        # Three upsampling stages, each doubling spatial size and halving channels:
        # (fm*8)x4x4 -> (fm*4)x8x8 -> (fm*2)x16x16 -> (fm)x32x32
        for mult in (8, 4, 2):
            layers += [
                nn.ConvTranspose2d(feature_maps * mult, feature_maps * (mult // 2),
                                   4, 2, 1, bias=False),
                nn.BatchNorm2d(feature_maps * (mult // 2)),
                nn.ReLU(True),
            ]
        # Final stage: (fm)x32x32 -> (img_channels)x64x64, squashed to [-1, 1].
        layers += [
            nn.ConvTranspose2d(feature_maps, img_channels, 4, 2, 1, bias=False),
            nn.Tanh(),
        ]
        self.main = nn.Sequential(*layers)

    def forward(self, input):
        """Map noise of shape (batch, latent_dim, 1, 1) to (batch, img_channels, 64, 64) images."""
        return self.main(input)

WGAN Implementation

1. WGAN Loss Function

python
class WGAN:
    """Wasserstein GAN training wrapper (weight-clipping variant).

    Holds a generator/critic pair plus their RMSprop optimizers and exposes
    one training-step method for each network.
    """

    def __init__(self, generator, discriminator, lr=0.00005, clip_value=0.01,
                 latent_dim=100):
        """
        Args:
            generator: module mapping (batch, latent_dim, 1, 1) noise to images.
            discriminator: critic module scoring images (unbounded output, no sigmoid).
            lr: RMSprop learning rate (WGAN uses RMSprop rather than Adam).
            clip_value: critic weights are clamped to [-clip_value, clip_value]
                after every step to (crudely) enforce the Lipschitz constraint.
            latent_dim: dimensionality of the noise vector. Previously hard-coded
                to 100; kept as the default so existing callers are unaffected.
        """
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        self.generator = generator.to(self.device)
        self.discriminator = discriminator.to(self.device)

        # WGAN uses the RMSprop optimizer.
        self.optimizer_G = optim.RMSprop(self.generator.parameters(), lr=lr)
        self.optimizer_D = optim.RMSprop(self.discriminator.parameters(), lr=lr)

        self.clip_value = clip_value
        self.latent_dim = latent_dim

    def train_discriminator(self, real_imgs, n_critic=5):
        """Run `n_critic` critic updates on one batch of real images.

        Returns the mean critic loss over the updates. The critic maximizes
        E[D(real)] - E[D(fake)], i.e. minimizes the negated estimate below.
        """
        d_losses = []

        for _ in range(n_critic):
            self.optimizer_D.zero_grad()

            batch_size = real_imgs.size(0)

            # Critic score on real images.
            real_validity = self.discriminator(real_imgs)

            # Score freshly generated fakes; detach so no gradient reaches the generator.
            z = torch.randn(batch_size, self.latent_dim, 1, 1, device=self.device)
            fake_imgs = self.generator(z).detach()
            fake_validity = self.discriminator(fake_imgs)

            # WGAN critic loss (negated Wasserstein distance estimate).
            d_loss = -torch.mean(real_validity) + torch.mean(fake_validity)

            d_loss.backward()
            self.optimizer_D.step()

            # Weight clipping enforces the 1-Lipschitz constraint.
            for p in self.discriminator.parameters():
                p.data.clamp_(-self.clip_value, self.clip_value)

            d_losses.append(d_loss.item())

        return np.mean(d_losses)

    def train_generator(self, batch_size):
        """Run one generator update; returns (loss_value, generated_images)."""
        self.optimizer_G.zero_grad()

        # Generate fake images.
        z = torch.randn(batch_size, self.latent_dim, 1, 1, device=self.device)
        fake_imgs = self.generator(z)

        # The generator maximizes the critic's score on its fakes.
        fake_validity = self.discriminator(fake_imgs)
        g_loss = -torch.mean(fake_validity)

        g_loss.backward()
        self.optimizer_G.step()

        return g_loss.item(), fake_imgs

Conditional GAN (cGAN)

1. Conditional Generator

python
class ConditionalGenerator(nn.Module):
    """Conditional GAN generator: generates an image conditioned on a class label.

    The class label is embedded and concatenated with the noise vector so
    the MLP can produce a sample of the requested class.
    """

    def __init__(self, latent_dim=100, num_classes=10, img_channels=1, img_size=28):
        super(ConditionalGenerator, self).__init__()

        # One num_classes-dimensional embedding per class label.
        self.label_emb = nn.Embedding(num_classes, num_classes)

        def block(in_feat, out_feat, normalize=True):
            """Linear [+ BatchNorm] + LeakyReLU building block."""
            layers = [nn.Linear(in_feat, out_feat)]
            if normalize:
                # FIX: BatchNorm1d's second positional argument is `eps`, so
                # the original `nn.BatchNorm1d(out_feat, 0.8)` set eps=0.8 by
                # mistake; the intended hyperparameter is momentum=0.8.
                layers.append(nn.BatchNorm1d(out_feat, momentum=0.8))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        self.model = nn.Sequential(
            *block(latent_dim + num_classes, 128, normalize=False),
            *block(128, 256),
            *block(256, 512),
            *block(512, 1024),
            nn.Linear(1024, img_channels * img_size * img_size),
            nn.Tanh()  # pixel values in [-1, 1]
        )

        self.img_size = img_size
        self.img_channels = img_channels

    def forward(self, noise, labels):
        """Generate images; `noise` is (batch, latent_dim), `labels` is a (batch,) long tensor."""
        # Condition on the class by concatenating its embedding with the noise.
        gen_input = torch.cat((self.label_emb(labels), noise), -1)
        img = self.model(gen_input)
        # Reshape the flat MLP output into image tensors.
        img = img.view(img.size(0), self.img_channels, self.img_size, self.img_size)
        return img

Training Techniques and Optimization

1. Gradient Penalty

python
def gradient_penalty(self, real_imgs, fake_imgs):
    """WGAN-GP penalty: E[(||grad D(x_hat)||_2 - 1)^2] at random interpolates.

    Each x_hat is sampled uniformly on the segment between a real image and
    its paired fake image; the penalty pushes the critic's gradient norm
    toward 1 there.
    """
    n = real_imgs.size(0)

    # One interpolation coefficient per sample, broadcast over C/H/W.
    mix = torch.rand(n, 1, 1, 1, device=self.device)
    x_hat = mix * real_imgs + (1 - mix) * fake_imgs
    x_hat.requires_grad_(True)

    # Critic score at the interpolated points.
    score = self.discriminator(x_hat)

    # d(score)/d(x_hat); create_graph keeps the penalty itself differentiable.
    (grads,) = torch.autograd.grad(
        outputs=score,
        inputs=x_hat,
        grad_outputs=torch.ones_like(score),
        create_graph=True,
        retain_graph=True,
        only_inputs=True,
    )

    # Penalize deviation of each sample's gradient norm from 1.
    flat = grads.view(n, -1)
    return ((flat.norm(2, dim=1) - 1) ** 2).mean()

Evaluation Metrics

1. FID Score

python
import scipy.linalg
from torchvision.models import inception_v3

class FIDCalculator:
    """Frechet Inception Distance (FID) between two sets of images.

    NOTE(review): this implementation uses Inception v3 *logits* (1000-d) as
    features; the standard FID uses the 2048-d pool3 activations — confirm
    which is intended before comparing scores against published numbers.
    """

    def __init__(self, device):
        self.device = device
        # Pretrained Inception v3 in eval mode; transform_input=False expects
        # inputs already normalized the usual torchvision way.
        self.inception = inception_v3(pretrained=True, transform_input=False).to(device)
        self.inception.eval()

    def get_activations(self, images):
        """Return Inception features for a batch of images as a numpy array."""
        with torch.no_grad():
            # Inception v3 requires 299x299 inputs; resize when needed.
            # FIX: the original called F.interpolate, but `F` was never
            # imported anywhere in this file, so any non-299 input raised a
            # NameError at runtime. Use the fully qualified name instead.
            if images.size(-1) != 299:
                images = torch.nn.functional.interpolate(
                    images, size=299, mode='bilinear', align_corners=False
                )

            # Extract features.
            features = self.inception(images)

        return features.cpu().numpy()

    def calculate_fid(self, real_images, fake_images):
        """Compute FID = ||mu_r - mu_f||^2 + Tr(S_r + S_f - 2 (S_r S_f)^(1/2))."""
        # Features of the real and the generated images.
        real_features = self.get_activations(real_images)
        fake_features = self.get_activations(fake_images)

        # Per-set feature mean and covariance.
        mu_real = np.mean(real_features, axis=0)
        sigma_real = np.cov(real_features, rowvar=False)

        mu_fake = np.mean(fake_features, axis=0)
        sigma_fake = np.cov(fake_features, rowvar=False)

        # Frechet distance between the two Gaussians.
        diff = mu_real - mu_fake
        # With disp=False, sqrtm returns (matrix_sqrt, error_estimate).
        covmean, _ = scipy.linalg.sqrtm(sigma_real.dot(sigma_fake), disp=False)

        # sqrtm may return a complex matrix with tiny imaginary noise; drop it.
        if np.iscomplexobj(covmean):
            covmean = covmean.real

        fid = diff.dot(diff) + np.trace(sigma_real + sigma_fake - 2 * covmean)

        return fid

# Usage example
# NOTE(review): `device`, `real_images` and `fake_images` are assumed to be
# defined by surrounding training code — they are not created anywhere in
# this chapter's snippets.
fid_calculator = FIDCalculator(device)
fid_score = fid_calculator.calculate_fid(real_images, fake_images)
print(f"FID score: {fid_score:.2f}")

Summary

Generative Adversarial Networks are important techniques in deep learning. This chapter introduced:

  1. Basic GAN: Basic structures and training process of generator and discriminator
  2. DCGAN: Implementation of deep convolutional GAN
  3. WGAN: Wasserstein GAN and its improved versions
  4. Conditional GAN: Conditional generation models
  5. Training Techniques: gradient penalty and related training stabilization methods
  6. Evaluation Metrics: FID and other generation quality evaluation methods

Mastering GAN techniques will help you innovate in image generation, data augmentation, and other fields!

Content is for learning and research only.