Skip to content

Generative Adversarial Networks (GAN)

Generative Adversarial Networks (GAN) are a deep learning model proposed by Ian Goodfellow in 2014. GANs learn data distribution through adversarial training of two neural networks, capable of generating high-quality synthetic data.

GAN Basics

What is a GAN?

A GAN consists of two networks:

  • Generator: Learns to generate fake data that resembles real data
  • Discriminator: Learns to distinguish between real and generated data

These two networks compete against each other during training, ideally converging to a Nash equilibrium in which the generator's samples are indistinguishable from real data and the discriminator can do no better than chance.

python
import tensorflow as tf
from tensorflow import keras
import numpy as np
import matplotlib.pyplot as plt
from IPython.display import clear_output
import os

# Set random seeds so weight initialization and noise sampling are reproducible
tf.random.set_seed(42)
np.random.seed(42)

# Basic GAN architecture example
def create_generator(latent_dim, output_shape):
    """
    Build a fully-connected generator.

    Maps a latent vector of size `latent_dim` through three widening
    Dense+BatchNorm stages (128 -> 256 -> 512), then projects to a
    tanh-activated flat image and reshapes it to `output_shape`.
    """
    stack = [
        keras.layers.Dense(128, activation='relu', input_shape=(latent_dim,)),
        keras.layers.BatchNormalization(),
    ]
    for units in (256, 512):
        stack.append(keras.layers.Dense(units, activation='relu'))
        stack.append(keras.layers.BatchNormalization())
    # tanh output in [-1, 1]; reshape flat vector back to image dimensions
    stack.append(keras.layers.Dense(np.prod(output_shape), activation='tanh'))
    stack.append(keras.layers.Reshape(output_shape))
    return keras.Sequential(stack)

def create_discriminator(input_shape):
    """
    Build a fully-connected discriminator.

    Flattens the input image and scores it through two Dense+Dropout
    stages (512 -> 256), ending in a sigmoid real/fake probability.
    """
    net = keras.Sequential()
    net.add(keras.layers.Flatten(input_shape=input_shape))
    for units in (512, 256):
        net.add(keras.layers.Dense(units, activation='relu'))
        net.add(keras.layers.Dropout(0.3))
    # single sigmoid unit: probability that the input is real
    net.add(keras.layers.Dense(1, activation='sigmoid'))
    return net

# Example: Create simple GAN
latent_dim = 100  # size of the noise vector fed to the generator
img_shape = (28, 28, 1)  # MNIST-style 28x28 grayscale images

generator = create_generator(latent_dim, img_shape)
discriminator = create_discriminator(img_shape)

# Print layer-by-layer summaries as a quick architecture sanity check
print("Generator structure:")
generator.summary()
print("\nDiscriminator structure:")
discriminator.summary()

DCGAN Implementation

python
def create_dcgan_generator(latent_dim):
    """
    Build the DCGAN generator.

    Projects the latent vector to a 7x7x256 feature map, then applies
    three transposed convolutions (7x7 -> 7x7 -> 14x14 -> 28x28) ending
    in a single-channel tanh image in [-1, 1].
    """
    model = keras.Sequential()

    # Project and reshape the latent vector to a 7x7x256 tensor
    model.add(keras.layers.Dense(7 * 7 * 256, use_bias=False,
                                 input_shape=(latent_dim,)))
    model.add(keras.layers.BatchNormalization())
    model.add(keras.layers.LeakyReLU())
    model.add(keras.layers.Reshape((7, 7, 256)))

    # stride 1: resolution stays 7x7, channels drop to 128
    model.add(keras.layers.Conv2DTranspose(128, (5, 5), strides=(1, 1),
                                           padding='same', use_bias=False))
    model.add(keras.layers.BatchNormalization())
    model.add(keras.layers.LeakyReLU())

    # stride 2: upsample 7x7 -> 14x14 with 64 channels
    model.add(keras.layers.Conv2DTranspose(64, (5, 5), strides=(2, 2),
                                           padding='same', use_bias=False))
    model.add(keras.layers.BatchNormalization())
    model.add(keras.layers.LeakyReLU())

    # stride 2: upsample 14x14 -> 28x28, single-channel tanh output
    model.add(keras.layers.Conv2DTranspose(1, (5, 5), strides=(2, 2),
                                           padding='same', use_bias=False,
                                           activation='tanh'))

    return model

def create_dcgan_discriminator():
    """
    Build the DCGAN discriminator.

    Two strided convolutions downsample a 28x28x1 image to 7x7x128,
    then a Dense layer emits a single unbounded real/fake logit
    (no sigmoid — pair with a from_logits loss).
    """
    stack = [
        # 28x28x1 -> 14x14x64
        keras.layers.Conv2D(64, (5, 5), strides=(2, 2), padding='same',
                            input_shape=[28, 28, 1]),
        keras.layers.LeakyReLU(),
        keras.layers.Dropout(0.3),
        # 14x14x64 -> 7x7x128
        keras.layers.Conv2D(128, (5, 5), strides=(2, 2), padding='same'),
        keras.layers.LeakyReLU(),
        keras.layers.Dropout(0.3),
        # collapse to a single logit
        keras.layers.Flatten(),
        keras.layers.Dense(1),
    ]
    return keras.Sequential(stack)

# Create DCGAN model with a 100-dimensional latent space
dcgan_generator = create_dcgan_generator(100)
dcgan_discriminator = create_dcgan_discriminator()

# Inspect both architectures
print("DCGAN Generator:")
dcgan_generator.summary()
print("\nDCGAN Discriminator:")
dcgan_discriminator.summary()

Loss Functions

python
# Binary cross-entropy on raw logits (the DCGAN discriminator has no sigmoid)
cross_entropy = keras.losses.BinaryCrossentropy(from_logits=True)

def discriminator_loss(real_output, fake_output):
    """
    Standard GAN discriminator loss.

    Real samples are pushed toward label 1 and generated samples toward
    label 0; the two binary cross-entropy terms are summed.
    """
    loss_on_real = cross_entropy(tf.ones_like(real_output), real_output)
    loss_on_fake = cross_entropy(tf.zeros_like(fake_output), fake_output)
    return loss_on_real + loss_on_fake

def generator_loss(fake_output):
    """
    Standard (non-saturating) generator loss: score generated samples
    against the "real" label, rewarding the generator for fooling the
    discriminator.
    """
    real_labels = tf.ones_like(fake_output)
    return cross_entropy(real_labels, fake_output)

# WGAN loss function
def wasserstein_discriminator_loss(real_output, fake_output):
    """
    WGAN critic loss: mean critic score on fakes minus mean score on
    reals. Minimizing it drives real scores up and fake scores down.
    """
    fake_score = tf.reduce_mean(fake_output)
    real_score = tf.reduce_mean(real_output)
    return fake_score - real_score

def wasserstein_generator_loss(fake_output):
    """
    WGAN generator loss: the negated mean critic score of generated
    samples (the generator tries to raise the critic's score on fakes).
    """
    return tf.negative(tf.reduce_mean(fake_output))

# LSGAN loss function
def lsgan_discriminator_loss(real_output, fake_output):
    """
    Least-squares GAN discriminator loss: real outputs are regressed
    toward 1 and fake outputs toward 0; the two MSE terms are halved.
    """
    real_term = tf.reduce_mean(tf.square(real_output - 1))
    fake_term = tf.reduce_mean(tf.square(fake_output))
    return (real_term + fake_term) * 0.5

def lsgan_generator_loss(fake_output):
    """
    Least-squares GAN generator loss: fake outputs regressed toward 1.
    """
    return tf.reduce_mean(tf.square(fake_output - 1)) * 0.5

Training Loop

python
class GAN:
    """
    Minimal GAN trainer: wraps a generator/discriminator pair with their
    optimizers and running loss metrics, performing one alternating
    D-then-G update per `train_step` call.
    """

    def __init__(self, generator, discriminator, latent_dim):
        self.generator = generator
        self.discriminator = discriminator
        self.latent_dim = latent_dim

        # Optimizers — separate Adam instances so G and D keep independent state
        self.generator_optimizer = keras.optimizers.Adam(1e-4)
        self.discriminator_optimizer = keras.optimizers.Adam(1e-4)

        # Loss tracking — running means reported back by train_step
        self.gen_loss_tracker = keras.metrics.Mean(name="generator_loss")
        self.disc_loss_tracker = keras.metrics.Mean(name="discriminator_loss")

    @tf.function
    def train_step(self, real_images):
        """
        Run one training step on a batch: update D on real+fake, then
        update G. The same latent batch is reused for both passes; the
        generator is re-run under the second tape so its gradients are
        recorded. Returns a dict of the running mean losses.
        """
        batch_size = tf.shape(real_images)[0]

        # Generate random noise
        random_latent_vectors = tf.random.normal(shape=(batch_size, self.latent_dim))

        # Train discriminator
        with tf.GradientTape() as disc_tape:
            # Generate fake images
            fake_images = self.generator(random_latent_vectors, training=True)

            # Discriminator predictions
            real_predictions = self.discriminator(real_images, training=True)
            fake_predictions = self.discriminator(fake_images, training=True)

            # Calculate discriminator loss
            disc_loss = discriminator_loss(real_predictions, fake_predictions)

        # Calculate discriminator gradients and update
        disc_gradients = disc_tape.gradient(disc_loss, self.discriminator.trainable_variables)
        self.discriminator_optimizer.apply_gradients(
            zip(disc_gradients, self.discriminator.trainable_variables)
        )

        # Train generator (the freshly-updated D now scores the fakes)
        with tf.GradientTape() as gen_tape:
            # Generate fake images
            fake_images = self.generator(random_latent_vectors, training=True)

            # Discriminator predictions
            fake_predictions = self.discriminator(fake_images, training=True)

            # Calculate generator loss
            gen_loss = generator_loss(fake_predictions)

        # Calculate generator gradients and update
        gen_gradients = gen_tape.gradient(gen_loss, self.generator.trainable_variables)
        self.generator_optimizer.apply_gradients(
            zip(gen_gradients, self.generator.trainable_variables)
        )

        # Update loss tracking
        self.gen_loss_tracker.update_state(gen_loss)
        self.disc_loss_tracker.update_state(disc_loss)

        return {
            "generator_loss": self.gen_loss_tracker.result(),
            "discriminator_loss": self.disc_loss_tracker.result(),
        }

# Create GAN instance wired to the DCGAN networks built above
gan = GAN(dcgan_generator, dcgan_discriminator, latent_dim=100)

Data Preparation and Training

python
def prepare_mnist_data():
    """
    Load MNIST training images, scale pixels to [-1, 1] (matching the
    generator's tanh output range), and append a channel axis.
    """
    (x_train, _), _ = keras.datasets.mnist.load_data()

    # Map uint8 [0, 255] to float32 [-1, 1]
    scaled = (x_train.astype('float32') - 127.5) / 127.5

    # (N, 28, 28) -> (N, 28, 28, 1)
    return scaled[..., np.newaxis]

def create_dataset(data, batch_size=256, buffer_size=60000):
    """
    Wrap an in-memory array in a shuffled, batched tf.data pipeline.
    """
    return (tf.data.Dataset.from_tensor_slices(data)
            .shuffle(buffer_size)
            .batch(batch_size))

# Prepare data: normalized MNIST images and a shuffled, batched pipeline
train_images = prepare_mnist_data()
train_dataset = create_dataset(train_images)

print(f"Training data shape: {train_images.shape}")

def train_gan(gan, dataset, epochs=100):
    """
    Run the GAN training loop for `epochs` passes over `dataset`,
    previewing generated samples (and the last batch's running losses)
    every 10 epochs.
    """
    for epoch_idx in range(1, epochs + 1):
        print(f"Epoch {epoch_idx}/{epochs}")

        # One full pass; keep the metrics returned by the final batch
        for batch in dataset:
            losses = gan.train_step(batch)

        if epoch_idx % 10 == 0:
            generate_and_save_images(gan.generator, epoch_idx)
            print(f"Generator Loss: {losses['generator_loss']:.4f}")
            print(f"Discriminator Loss: {losses['discriminator_loss']:.4f}")

def generate_and_save_images(generator, epoch, num_examples=16, latent_dim=100):
    """
    Sample images from `generator` and display them in a 4x4 grid.

    NOTE: despite the name, images are only displayed with plt.show(),
    not written to disk.

    Args:
        generator: model mapping (N, latent_dim) noise to (N, H, W, 1)
            images in [-1, 1] (tanh output assumed — values are rescaled
            to [0, 255] for display).
        epoch: epoch number used in the figure title.
        num_examples: number of samples to draw (the grid fits up to 16).
        latent_dim: size of the generator's noise input. Previously this
            was hard-coded to 100; the default keeps old behavior.
    """
    # Sample latent noise and run the generator in inference mode
    noise = tf.random.normal([num_examples, latent_dim])
    generated_images = generator(noise, training=False)

    fig = plt.figure(figsize=(4, 4))

    for i in range(generated_images.shape[0]):
        plt.subplot(4, 4, i + 1)
        # Map tanh output [-1, 1] back to pixel values [0, 255]
        plt.imshow(generated_images[i, :, :, 0] * 127.5 + 127.5, cmap='gray')
        plt.axis('off')

    fig.suptitle(f'Epoch {epoch}')
    plt.tight_layout()
    plt.show()

# Start training (example, actual training needs more epochs)
# train_gan(gan, train_dataset, epochs=50)

Conditional GAN (cGAN)

python
def create_conditional_generator(latent_dim, num_classes, img_shape):
    """
    Build a class-conditional generator.

    A 50-dimensional label embedding is concatenated with the noise
    vector, and the joint code is decoded to an image through a
    transposed-convolution stack ending in a tanh output.

    NOTE: img_shape is accepted for API symmetry but not used — the
    output is fixed at 28x28x1 by the deconv stack.
    """
    z_in = keras.layers.Input(shape=(latent_dim,))

    # Embed the integer class label and flatten to a vector
    y_in = keras.layers.Input(shape=(1,))
    y_vec = keras.layers.Embedding(num_classes, 50)(y_in)
    y_vec = keras.layers.Flatten()(y_vec)

    # Joint latent code: noise + label embedding
    code = keras.layers.Concatenate()([z_in, y_vec])

    # Project to a 7x7x256 feature map
    h = keras.layers.Dense(7 * 7 * 256, use_bias=False)(code)
    h = keras.layers.BatchNormalization()(h)
    h = keras.layers.LeakyReLU()(h)
    h = keras.layers.Reshape((7, 7, 256))(h)

    # Two hidden deconv stages: 7x7 (stride 1) then 7x7 -> 14x14
    for filters, stride in ((128, 1), (64, 2)):
        h = keras.layers.Conv2DTranspose(filters, (5, 5), strides=(stride, stride),
                                         padding='same', use_bias=False)(h)
        h = keras.layers.BatchNormalization()(h)
        h = keras.layers.LeakyReLU()(h)

    # Final deconv: 14x14 -> 28x28, single-channel tanh image
    img_out = keras.layers.Conv2DTranspose(1, (5, 5), strides=(2, 2),
                                           padding='same', use_bias=False,
                                           activation='tanh')(h)

    return keras.Model([z_in, y_in], img_out)

def create_conditional_discriminator(img_shape, num_classes):
    """
    Build a class-conditional discriminator.

    The label is embedded, projected to an image-shaped tensor, and
    concatenated with the input image along the channel axis before a
    strided convolution stack produces a single real/fake logit.
    """
    image_in = keras.layers.Input(shape=img_shape)

    # Embed the label and reshape it into an extra image "channel"
    label_in = keras.layers.Input(shape=(1,))
    lbl = keras.layers.Embedding(num_classes, 50)(label_in)
    lbl = keras.layers.Flatten()(lbl)
    lbl = keras.layers.Dense(np.prod(img_shape))(lbl)
    lbl = keras.layers.Reshape(img_shape)(lbl)

    # Stack image and label plane together
    h = keras.layers.Concatenate()([image_in, lbl])

    # Two stride-2 conv stages with dropout
    for n_filters in (64, 128):
        h = keras.layers.Conv2D(n_filters, (5, 5), strides=(2, 2), padding='same')(h)
        h = keras.layers.LeakyReLU()(h)
        h = keras.layers.Dropout(0.3)(h)

    # Collapse to a single unbounded logit
    h = keras.layers.Flatten()(h)
    logit = keras.layers.Dense(1)(h)

    return keras.Model([image_in, label_in], logit)

# Create conditional GAN for 10-class (e.g. MNIST digit) generation
cond_generator = create_conditional_generator(100, 10, (28, 28, 1))
cond_discriminator = create_conditional_discriminator((28, 28, 1), 10)

CycleGAN Implementation

python
def create_cyclegan_generator():
    """
    Build a CycleGAN generator: conv encoder, nine residual blocks at
    1/4 resolution, and a transposed-conv decoder with a tanh RGB output.

    NOTE(review): this uses BatchNormalization throughout; confirm
    whether instance normalization (as in the CycleGAN paper) is
    intended instead.
    """
    def conv_bn_relu(t, filters, kernel, stride=1):
        # Shared Conv -> BN -> ReLU building block
        t = keras.layers.Conv2D(filters, kernel, strides=stride, padding='same')(t)
        t = keras.layers.BatchNormalization()(t)
        return keras.layers.ReLU()(t)

    def res_block(t, filters):
        # Two 3x3 convs with an additive skip, ReLU after the add
        skip = t
        t = keras.layers.Conv2D(filters, 3, padding='same')(t)
        t = keras.layers.BatchNormalization()(t)
        t = keras.layers.ReLU()(t)
        t = keras.layers.Conv2D(filters, 3, padding='same')(t)
        t = keras.layers.BatchNormalization()(t)
        t = keras.layers.Add()([skip, t])
        return keras.layers.ReLU()(t)

    inputs = keras.layers.Input(shape=(256, 256, 3))

    # Encoder: 7x7 stem, then two stride-2 downsamples (256 -> 64 spatial)
    x = conv_bn_relu(inputs, 64, 7)
    x = conv_bn_relu(x, 128, 3, stride=2)
    x = conv_bn_relu(x, 256, 3, stride=2)

    # Nine residual blocks at the bottleneck resolution
    for _ in range(9):
        x = res_block(x, 256)

    # Decoder: two stride-2 transposed convs back to full resolution
    for n_filters in (128, 64):
        x = keras.layers.Conv2DTranspose(n_filters, 3, strides=2, padding='same')(x)
        x = keras.layers.BatchNormalization()(x)
        x = keras.layers.ReLU()(x)

    # 7x7 tanh head producing the RGB output in [-1, 1]
    outputs = keras.layers.Conv2D(3, 7, padding='same', activation='tanh')(x)

    return keras.Model(inputs, outputs)

def create_cyclegan_discriminator():
    """
    Build a PatchGAN discriminator: the output is a spatial map of
    unbounded realness logits, one per receptive-field patch of the
    256x256 RGB input.
    """
    inputs = keras.layers.Input(shape=(256, 256, 3))

    # First stage: no BatchNorm, just conv + leaky ReLU
    x = keras.layers.Conv2D(64, 4, strides=2, padding='same')(inputs)
    x = keras.layers.LeakyReLU(0.2)(x)

    # Deeper stages with BatchNorm; the last one keeps stride 1
    for n_filters, stride in ((128, 2), (256, 2), (512, 1)):
        x = keras.layers.Conv2D(n_filters, 4, strides=stride, padding='same')(x)
        x = keras.layers.BatchNormalization()(x)
        x = keras.layers.LeakyReLU(0.2)(x)

    # Per-patch logit map (no sigmoid)
    outputs = keras.layers.Conv2D(1, 4, strides=1, padding='same')(x)

    return keras.Model(inputs, outputs)

# CycleGAN loss functions
def cycle_consistency_loss(real_image, cycled_image, lambda_cycle=10.0):
    """
    Cycle consistency loss
    """
    loss = tf.reduce_mean(tf.abs(real_image - cycled_image))
    return lambda_cycle * loss

def identity_loss(real_image, same_image, lambda_identity=0.5):
    """
    Scaled L1 identity loss: feeding a generator an image already in its
    target domain should leave the image unchanged.
    """
    l1 = tf.reduce_mean(tf.abs(real_image - same_image))
    return l1 * lambda_identity

StyleGAN Basics

python
def create_stylegan_mapping_network(latent_dim=512, num_layers=8):
    """
    Build the StyleGAN mapping network: a stack of `num_layers` equally
    sized Dense layers mapping latent z to intermediate style code w.
    """
    mlp = keras.Sequential(
        [keras.layers.Dense(latent_dim, activation='relu') for _ in range(num_layers)]
    )
    # Build eagerly so the weights exist before the first call
    mlp.build(input_shape=(None, latent_dim))
    return mlp

def adaptive_instance_normalization(content_features, style_features):
    """
    AdaIN: normalize content features per-channel over the spatial axes,
    then rescale/shift them with the style features' statistics.
    """
    spatial_axes = [1, 2]

    def moments(feats):
        # Per-channel mean and std over height/width, keeping dims for broadcast
        mean = tf.reduce_mean(feats, axis=spatial_axes, keepdims=True)
        std = tf.math.reduce_std(feats, axis=spatial_axes, keepdims=True)
        return mean, std

    c_mean, c_std = moments(content_features)
    s_mean, s_std = moments(style_features)

    # Epsilon guards against division by zero on constant feature maps
    normalized = (content_features - c_mean) / (c_std + 1e-8)
    return normalized * s_std + s_mean

class StyleGANGenerator(keras.Model):
    """
    Simplified StyleGAN generator.

    A mapping network turns latent z into a style code w; synthesis
    starts from a learned 4x4x512 constant and alternates convolution,
    AdaIN (styled by w), and 2x upsampling.
    """
    def __init__(self, latent_dim=512, num_layers=6):
        super(StyleGANGenerator, self).__init__()
        self.latent_dim = latent_dim
        self.num_layers = num_layers

        # Mapping network: z -> w, shared across all synthesis layers
        self.mapping_network = create_stylegan_mapping_network(latent_dim)

        # Learned constant that seeds synthesis (one copy, tiled per batch in call)
        self.constant_input = self.add_weight(
            shape=(1, 4, 4, 512),
            initializer='random_normal',
            trainable=True,
            name='constant_input'
        )

        # Generator layers: one conv + one style-projection Dense per level
        self.conv_layers = []
        self.adain_layers = []

        channels = [512, 512, 256, 128, 64, 32]
        for i in range(num_layers):
            self.conv_layers.append(
                keras.layers.Conv2D(channels[i], 3, padding='same', activation='relu')
            )
            self.adain_layers.append(
                keras.layers.Dense(channels[i] * 2)  # For generating mean and std
            )

    def call(self, latent_codes):
        """
        Synthesize feature maps from a batch of latent codes.

        Output spatial size is 4 * 2**(num_layers - 1), since upsampling
        runs after every layer except the last.
        """
        batch_size = tf.shape(latent_codes)[0]

        # Mapping network
        w = self.mapping_network(latent_codes)

        # Start from constant, tiled across the batch
        x = tf.tile(self.constant_input, [batch_size, 1, 1, 1])

        # Generate layer by layer
        for i in range(self.num_layers):
            # Convolution
            x = self.conv_layers[i](x)

            # Project w into per-channel AdaIN scale (std) and shift (mean)
            style_params = self.adain_layers[i](w)
            style_mean, style_std = tf.split(style_params, 2, axis=-1)

            # Apply AdaIN: instance-normalize x, then restyle with w's statistics
            x_mean = tf.reduce_mean(x, axis=[1, 2], keepdims=True)
            x_std = tf.math.reduce_std(x, axis=[1, 2], keepdims=True)
            x = (x - x_mean) / (x_std + 1e-8)
            x = x * tf.expand_dims(tf.expand_dims(style_std, 1), 1) + \
                tf.expand_dims(tf.expand_dims(style_mean, 1), 1)

            # Upsampling (except last layer)
            if i < self.num_layers - 1:
                x = keras.layers.UpSampling2D()(x)

        return x

Evaluation Metrics

python
def calculate_fid_score(real_images, generated_images, model_name='inception_v3'):
    """
    Calculate the FID score (Fréchet Inception Distance) between two
    image batches.

    Both batches are resized to 299x299, run through an average-pooled
    InceptionV3 feature extractor, and the Fréchet distance between the
    Gaussian fits of the two feature sets is returned (lower is better).

    Args:
        real_images: batch of real images — assumed 3-channel, since
            InceptionV3 expects RGB input (TODO confirm callers convert
            grayscale data).
        generated_images: batch of generated images, same format.
        model_name: unused; kept for backward compatibility.

    Returns:
        The FID value as a float.
    """
    # Fix: `scipy` was referenced below but never imported anywhere in
    # this module, which raised NameError at the sqrtm call.
    from scipy import linalg

    # Load pre-trained Inception model (global-average-pooled features)
    inception_model = keras.applications.InceptionV3(
        include_top=False,
        pooling='avg',
        input_shape=(299, 299, 3)
    )

    def preprocess_images(images):
        # Resize to InceptionV3's expected input and apply its pixel scaling
        images = tf.image.resize(images, [299, 299])
        images = keras.applications.inception_v3.preprocess_input(images)
        return images

    # Preprocess images
    real_images = preprocess_images(real_images)
    generated_images = preprocess_images(generated_images)

    # Extract feature vectors for each batch
    real_features = inception_model.predict(real_images)
    generated_features = inception_model.predict(generated_images)

    # Fit a Gaussian (mean + covariance) to each feature set
    mu_real = np.mean(real_features, axis=0)
    sigma_real = np.cov(real_features, rowvar=False)

    mu_gen = np.mean(generated_features, axis=0)
    sigma_gen = np.cov(generated_features, rowvar=False)

    # Fréchet distance: |mu_r - mu_g|^2 + Tr(S_r + S_g - 2*sqrt(S_r S_g))
    diff = mu_real - mu_gen
    covmean = linalg.sqrtm(sigma_real.dot(sigma_gen))

    # sqrtm can return small imaginary parts from numerical error; drop them
    if np.iscomplexobj(covmean):
        covmean = covmean.real

    fid = diff.dot(diff) + np.trace(sigma_real + sigma_gen - 2 * covmean)
    return fid

def calculate_inception_score(generated_images, num_splits=10):
    """
    Calculate the Inception Score: exp of the mean KL divergence between
    per-image class distributions p(y|x) and the chunk marginal p(y),
    computed over `num_splits` chunks. Returns (mean, std) across chunks.
    """
    # Classifier with the final softmax head attached
    inception_model = keras.applications.InceptionV3(
        include_top=True,
        input_shape=(299, 299, 3)
    )

    # Resize and scale inputs for InceptionV3
    resized = tf.image.resize(generated_images, [299, 299])
    resized = keras.applications.inception_v3.preprocess_input(resized)

    # Class-probability predictions for every image
    predictions = inception_model.predict(resized)

    total = len(predictions)
    scores = []
    for split_idx in range(num_splits):
        lo = split_idx * total // num_splits
        hi = (split_idx + 1) * total // num_splits
        part = predictions[lo:hi]

        # KL(p(y|x) || p(y)) per image, averaged over the chunk
        marginal = np.mean(part, axis=0)
        kl = part * (np.log(part + 1e-8) - np.log(marginal + 1e-8))
        scores.append(np.exp(np.mean(np.sum(kl, axis=1))))

    return np.mean(scores), np.std(scores)

Training Tips and Best Practices

python
# 1. Progressive training
class ProgressiveGAN:
    """
    Skeleton of progressive-growing GAN training: the working resolution
    starts at 4x4 and doubles step by step, capped at 256x256.
    """

    def __init__(self):
        self.current_resolution = 4
        self.max_resolution = 256

    def grow_network(self):
        """
        Double the working resolution (no-op once max_resolution is hit).
        """
        if self.current_resolution < self.max_resolution:
            self.current_resolution *= 2
            # Add new layers to generator and discriminator

    def fade_in_new_layers(self, alpha):
        """
        Fade in newly-added layers by alpha-blending old and new outputs.
        """
        # Use alpha to blend old and new outputs
        pass

# 2. Spectral normalization
def spectral_normalization(layer):
    """
    Wrap `layer` with a spectral-normalization wrapper looked up from
    Keras' custom-object registry.

    NOTE(review): 'SpectralNormalization' is not present in the stock
    Keras custom-object registry, so this raises KeyError unless an
    add-on package (e.g. TensorFlow Addons) has registered it first —
    confirm the intended dependency.
    """
    return keras.utils.get_custom_objects()['SpectralNormalization'](layer)

# 3. Self-attention mechanism
class SelfAttention(keras.layers.Layer):
    """
    Self-attention block: 1x1 convs produce query/key/value maps,
    attention is computed over all spatial positions, and the result is
    added back through a learned, zero-initialized residual gate `gamma`
    (so the layer starts out as an identity mapping).
    """
    def __init__(self, channels):
        super(SelfAttention, self).__init__()
        self.channels = channels
        # Query/key are reduced to channels // 8 to keep attention cheap
        self.query_conv = keras.layers.Conv2D(channels // 8, 1)
        self.key_conv = keras.layers.Conv2D(channels // 8, 1)
        self.value_conv = keras.layers.Conv2D(channels, 1)
        # gamma starts at 0: attention output is faded in during training
        self.gamma = self.add_weight(shape=(), initializer='zeros', trainable=True)

    def call(self, x):
        # Dynamic dims from the input tensor (NHWC layout assumed — TODO confirm)
        batch_size, height, width, channels = tf.shape(x)[0], tf.shape(x)[1], tf.shape(x)[2], tf.shape(x)[3]

        # Calculate query, key, value
        query = self.query_conv(x)
        key = self.key_conv(x)
        value = self.value_conv(x)

        # Flatten spatial dims to (B, H*W, C') matrices for matmul
        query = tf.reshape(query, [batch_size, -1, channels // 8])
        key = tf.reshape(key, [batch_size, -1, channels // 8])
        value = tf.reshape(value, [batch_size, -1, channels])

        # Attention weights: softmax over position-pair similarity (B, HW, HW)
        attention = tf.nn.softmax(tf.matmul(query, key, transpose_b=True))
        out = tf.matmul(attention, value)
        out = tf.reshape(out, [batch_size, height, width, channels])

        # Residual connection, gated by gamma
        out = self.gamma * out + x
        return out

# 4. Gradient penalty (WGAN-GP)
def gradient_penalty(discriminator, real_images, fake_images, batch_size):
    """
    WGAN-GP gradient penalty: penalize the critic's gradient norm at
    random interpolations between real and fake images, pushing the norm
    toward 1 (a soft 1-Lipschitz constraint).
    """
    # Per-sample mixing coefficient, broadcast over H/W/C
    mix = tf.random.uniform([batch_size, 1, 1, 1], 0., 1.)
    blended = mix * real_images + (1 - mix) * fake_images

    with tf.GradientTape() as tape:
        # Interpolated images are inputs, not variables — watch explicitly
        tape.watch(blended)
        critic_scores = discriminator(blended, training=True)

    grads = tape.gradient(critic_scores, blended)
    grad_norms = tf.sqrt(tf.reduce_sum(tf.square(grads), axis=[1, 2, 3]))
    return tf.reduce_mean((grad_norms - 1.) ** 2)

Real Application Examples

python
def image_to_image_translation():
    """
    Image-to-image translation example: builds and returns a
    Pix2Pix-style U-Net generator for 256x256 RGB inputs.
    """
    # Create Pix2Pix model
    def create_pix2pix_generator():
        # U-Net architecture: conv encoder with skip connections feeding
        # a transposed-conv decoder
        inputs = keras.layers.Input(shape=(256, 256, 3))

        # Encoder: four stride-2 convs, halving resolution each stage
        down_stack = [
            keras.layers.Conv2D(64, 4, strides=2, padding='same', use_bias=False),
            keras.layers.Conv2D(128, 4, strides=2, padding='same', use_bias=False),
            keras.layers.Conv2D(256, 4, strides=2, padding='same', use_bias=False),
            keras.layers.Conv2D(512, 4, strides=2, padding='same', use_bias=False),
        ]

        # Decoder: three stride-2 transposed convs (the fourth upsample
        # is the output layer below)
        up_stack = [
            keras.layers.Conv2DTranspose(256, 4, strides=2, padding='same', use_bias=False),
            keras.layers.Conv2DTranspose(128, 4, strides=2, padding='same', use_bias=False),
            keras.layers.Conv2DTranspose(64, 4, strides=2, padding='same', use_bias=False),
        ]

        x = inputs

        # Downsampling, recording each stage's activation for skips
        skips = []
        for down in down_stack:
            x = down(x)
            x = keras.layers.BatchNormalization()(x)
            x = keras.layers.LeakyReLU()(x)
            skips.append(x)

        # Drop the bottleneck (it is `x` itself) and reverse so the
        # deepest remaining skip pairs with the first upsampling stage
        skips = reversed(skips[:-1])

        # Upsampling with channel-wise concatenation of the matching skip
        for up, skip in zip(up_stack, skips):
            x = up(x)
            x = keras.layers.BatchNormalization()(x)
            x = keras.layers.ReLU()(x)
            x = keras.layers.Concatenate()([x, skip])

        # Last layer: restore full 256x256 resolution, tanh RGB output
        last = keras.layers.Conv2DTranspose(3, 4, strides=2, padding='same', activation='tanh')
        x = last(x)

        return keras.Model(inputs=inputs, outputs=x)

    return create_pix2pix_generator()

def super_resolution_gan():
    """
    Super-resolution GAN example: builds and returns an SRGAN-style
    generator that upscales an RGB image 4x via two 2x pixel-shuffle
    (depth_to_space) stages.
    """
    def create_srgan_generator():
        def res_block(t):
            # Conv-BN-PReLU-Conv-BN with an additive skip
            skip = t
            t = keras.layers.Conv2D(64, 3, padding='same')(t)
            t = keras.layers.BatchNormalization()(t)
            t = keras.layers.PReLU()(t)
            t = keras.layers.Conv2D(64, 3, padding='same')(t)
            t = keras.layers.BatchNormalization()(t)
            return keras.layers.Add()([skip, t])

        # Fully convolutional: accepts any input resolution
        inputs = keras.layers.Input(shape=(None, None, 3))

        # Large-kernel stem
        x = keras.layers.Conv2D(64, 9, padding='same')(inputs)
        x = keras.layers.PReLU()(x)

        # Sixteen residual blocks
        for _ in range(16):
            x = res_block(x)

        # Two 2x upsampling stages: conv to 256 channels, then pixel shuffle
        for _ in range(2):
            x = keras.layers.Conv2D(256, 3, padding='same')(x)
            x = keras.layers.Lambda(lambda t: tf.nn.depth_to_space(t, 2))(x)
            x = keras.layers.PReLU()(x)

        # Large-kernel tanh output head
        outputs = keras.layers.Conv2D(3, 9, padding='same', activation='tanh')(x)

        return keras.Model(inputs, outputs)

    return create_srgan_generator()

Summary

GAN is one of the most innovative technologies in deep learning, ushering in a new era of generative models. From basic GAN to advanced variants like StyleGAN and BigGAN, GANs have wide applications in image generation, style transfer, data augmentation, and other fields.

Key Points:

  1. Adversarial Training: The game process between generator and discriminator
  2. Loss Function Design: Different loss functions for different scenarios
  3. Training Stability: Requires careful tuning of hyperparameters and training strategies
  4. Evaluation Metrics: FID, IS, and other metrics help evaluate generation quality

In the next chapter we will work through practical projects, applying this theoretical knowledge to concrete implementations.

Content is for learning and research only.