Generative Adversarial Networks (GAN)
Generative Adversarial Networks (GANs) are a class of deep learning models proposed by Ian Goodfellow and colleagues in 2014. A GAN learns a data distribution through the adversarial training of two neural networks and can generate high-quality synthetic data.
GAN Basics
What is a GAN?
A GAN consists of two networks:
- Generator: Learns to generate fake data that resembles real data
- Discriminator: Learns to distinguish between real and generated data
These two networks compete against each other during training, eventually reaching Nash equilibrium.
python
import os

import matplotlib.pyplot as plt
import numpy as np
import scipy.linalg
import tensorflow as tf
from IPython.display import clear_output
from tensorflow import keras
# Set random seeds so noise sampling and weight initialization are reproducible
tf.random.set_seed(42)
np.random.seed(42)
# Basic GAN architecture example
def create_generator(latent_dim, output_shape):
    """Build a fully connected generator mapping a latent vector to an image.

    Three widening Dense+BatchNorm stages, then a tanh projection reshaped
    to `output_shape` (pixel values in [-1, 1]).
    """
    flat_size = np.prod(output_shape)
    stack = [keras.layers.Dense(128, activation='relu', input_shape=(latent_dim,)),
             keras.layers.BatchNormalization()]
    for units in (256, 512):
        stack.append(keras.layers.Dense(units, activation='relu'))
        stack.append(keras.layers.BatchNormalization())
    stack.append(keras.layers.Dense(flat_size, activation='tanh'))
    stack.append(keras.layers.Reshape(output_shape))
    return keras.Sequential(stack)
def create_discriminator(input_shape):
    """Build a fully connected discriminator emitting a real/fake probability
    (final sigmoid), with dropout after each hidden layer."""
    net = keras.Sequential()
    net.add(keras.layers.Flatten(input_shape=input_shape))
    for units in (512, 256):
        net.add(keras.layers.Dense(units, activation='relu'))
        net.add(keras.layers.Dropout(0.3))
    net.add(keras.layers.Dense(1, activation='sigmoid'))
    return net
# Example: Create simple GAN
latent_dim = 100          # dimensionality of the generator's noise input
img_shape = (28, 28, 1)   # MNIST-sized grayscale images
generator = create_generator(latent_dim, img_shape)
discriminator = create_discriminator(img_shape)
print("Generator structure:")
generator.summary()
print("\nDiscriminator structure:")
discriminator.summary()

DCGAN Implementation
python
def create_dcgan_generator(latent_dim):
    """
    Create DCGAN generator

    Maps a latent vector to a 28x28x1 image in [-1, 1] (tanh output)
    using transposed convolutions with BatchNorm + LeakyReLU.
    """
    model = keras.Sequential([
        # Project the latent vector to a 7*7*256 feature volume
        keras.layers.Dense(7*7*256, use_bias=False, input_shape=(latent_dim,)),
        keras.layers.BatchNormalization(),
        keras.layers.LeakyReLU(),
        # Reshape to 7x7x256
        keras.layers.Reshape((7, 7, 256)),
        # strides=(1, 1): spatial size stays 7x7, channels -> 128
        # (the original comment claimed 14x14 here, which was wrong)
        keras.layers.Conv2DTranspose(128, (5, 5), strides=(1, 1),
                                     padding='same', use_bias=False),
        keras.layers.BatchNormalization(),
        keras.layers.LeakyReLU(),
        # Upsample to 14x14x64
        keras.layers.Conv2DTranspose(64, (5, 5), strides=(2, 2),
                                     padding='same', use_bias=False),
        keras.layers.BatchNormalization(),
        keras.layers.LeakyReLU(),
        # Upsample to 28x28x1
        keras.layers.Conv2DTranspose(1, (5, 5), strides=(2, 2),
                                     padding='same', use_bias=False,
                                     activation='tanh')
    ])
    return model
def create_dcgan_discriminator():
    """Build the DCGAN discriminator: two strided conv stages, then a
    single-logit Dense head (no sigmoid; pair with a from_logits loss)."""
    net = keras.Sequential()
    # 28x28x1 -> 14x14x64
    net.add(keras.layers.Conv2D(64, (5, 5), strides=(2, 2), padding='same',
                                input_shape=[28, 28, 1]))
    net.add(keras.layers.LeakyReLU())
    net.add(keras.layers.Dropout(0.3))
    # 14x14x64 -> 7x7x128
    net.add(keras.layers.Conv2D(128, (5, 5), strides=(2, 2), padding='same'))
    net.add(keras.layers.LeakyReLU())
    net.add(keras.layers.Dropout(0.3))
    # Flatten and emit one raw score per image
    net.add(keras.layers.Flatten())
    net.add(keras.layers.Dense(1))
    return net
# Create DCGAN model
dcgan_generator = create_dcgan_generator(100)   # 100-dim latent space
dcgan_discriminator = create_dcgan_discriminator()
print("DCGAN Generator:")
dcgan_generator.summary()
print("\nDCGAN Discriminator:")
dcgan_discriminator.summary()

Loss Functions
python
# Binary cross-entropy on raw logits (the DCGAN discriminator has no sigmoid)
cross_entropy = keras.losses.BinaryCrossentropy(from_logits=True)

def discriminator_loss(real_output, fake_output):
    """
    Discriminator loss: real samples should score 1, generated samples 0.
    """
    loss_on_real = cross_entropy(tf.ones_like(real_output), real_output)
    loss_on_fake = cross_entropy(tf.zeros_like(fake_output), fake_output)
    return loss_on_real + loss_on_fake
def generator_loss(fake_output):
    """
    Generator loss: reward the generator when the discriminator
    scores its samples as real (target label 1).
    """
    real_targets = tf.ones_like(fake_output)
    return cross_entropy(real_targets, fake_output)
# WGAN loss function
def wasserstein_discriminator_loss(real_output, fake_output):
    """
    Wasserstein critic loss: maximize the score gap between real and
    fake samples, expressed here as a quantity to minimize.
    """
    real_score = tf.reduce_mean(real_output)
    fake_score = tf.reduce_mean(fake_output)
    return fake_score - real_score
def wasserstein_generator_loss(fake_output):
    """
    Wasserstein generator loss: push the critic's mean score on
    generated samples upward.
    """
    mean_fake_score = tf.reduce_mean(fake_output)
    return -mean_fake_score
# LSGAN loss function
def lsgan_discriminator_loss(real_output, fake_output):
    """
    Least-squares discriminator loss: regress real scores toward 1 and
    fake scores toward 0.
    """
    mse_real = tf.reduce_mean(tf.square(real_output - 1))
    mse_fake = tf.reduce_mean(tf.square(fake_output))
    return 0.5 * (mse_real + mse_fake)
def lsgan_generator_loss(fake_output):
"""
LSGAN generator loss
"""
return 0.5 * tf.reduce_mean(tf.square(fake_output - 1))Training Loop
python
class GAN:
    """Wraps a generator/discriminator pair with a joint adversarial
    training step (discriminator update first, then generator)."""

    def __init__(self, generator, discriminator, latent_dim):
        self.generator = generator
        self.discriminator = discriminator
        self.latent_dim = latent_dim
        # Optimizers (separate instances: each network is updated independently)
        self.generator_optimizer = keras.optimizers.Adam(1e-4)
        self.discriminator_optimizer = keras.optimizers.Adam(1e-4)
        # Loss tracking (running means; accumulate across train_step calls
        # until explicitly reset)
        self.gen_loss_tracker = keras.metrics.Mean(name="generator_loss")
        self.disc_loss_tracker = keras.metrics.Mean(name="discriminator_loss")

    @tf.function
    def train_step(self, real_images):
        """One adversarial update on a batch; returns the running-mean losses."""
        batch_size = tf.shape(real_images)[0]
        # Generate random noise
        random_latent_vectors = tf.random.normal(shape=(batch_size, self.latent_dim))
        # Train discriminator
        with tf.GradientTape() as disc_tape:
            # Generate fake images
            fake_images = self.generator(random_latent_vectors, training=True)
            # Discriminator predictions
            real_predictions = self.discriminator(real_images, training=True)
            fake_predictions = self.discriminator(fake_images, training=True)
            # Calculate discriminator loss
            disc_loss = discriminator_loss(real_predictions, fake_predictions)
        # Calculate discriminator gradients and update
        disc_gradients = disc_tape.gradient(disc_loss, self.discriminator.trainable_variables)
        self.discriminator_optimizer.apply_gradients(
            zip(disc_gradients, self.discriminator.trainable_variables)
        )
        # Train generator (fresh forward pass so the just-updated
        # discriminator scores the fakes)
        with tf.GradientTape() as gen_tape:
            # Generate fake images
            fake_images = self.generator(random_latent_vectors, training=True)
            # Discriminator predictions
            fake_predictions = self.discriminator(fake_images, training=True)
            # Calculate generator loss
            gen_loss = generator_loss(fake_predictions)
        # Calculate generator gradients and update
        gen_gradients = gen_tape.gradient(gen_loss, self.generator.trainable_variables)
        self.generator_optimizer.apply_gradients(
            zip(gen_gradients, self.generator.trainable_variables)
        )
        # Update loss tracking
        self.gen_loss_tracker.update_state(gen_loss)
        self.disc_loss_tracker.update_state(disc_loss)
        return {
            "generator_loss": self.gen_loss_tracker.result(),
            "discriminator_loss": self.disc_loss_tracker.result(),
        }
# Create GAN instance (latent_dim must match the generator's input size)
gan = GAN(dcgan_generator, dcgan_discriminator, latent_dim=100)

Data Preparation and Training
python
def prepare_mnist_data():
    """
    Load MNIST digits and scale them for GAN training.

    Returns an array of shape (num_images, 28, 28, 1) with values in
    [-1, 1], matching the generator's tanh output range.
    """
    (x_train, _), (_, _) = keras.datasets.mnist.load_data()
    # Normalize to [-1, 1]
    images = (x_train.astype('float32') - 127.5) / 127.5
    # Add channel dimension
    return images[..., np.newaxis]
def create_dataset(data, batch_size=256, buffer_size=60000):
    """
    Create a shuffled, batched tf.data pipeline over `data`.

    Args:
        data: array of training images.
        batch_size: samples per batch.
        buffer_size: shuffle buffer size (>= len(data) gives a full shuffle).

    Returns:
        A tf.data.Dataset yielding shuffled batches.
    """
    dataset = tf.data.Dataset.from_tensor_slices(data)
    dataset = dataset.shuffle(buffer_size).batch(batch_size)
    # Overlap input preparation with training for better throughput
    return dataset.prefetch(tf.data.AUTOTUNE)
# Prepare data (expected shape: (60000, 28, 28, 1))
train_images = prepare_mnist_data()
train_dataset = create_dataset(train_images)
print(f"Training data shape: {train_images.shape}")
def train_gan(gan, dataset, epochs=100):
    """
    Train GAN model for `epochs` passes over `dataset`.

    Fixes:
    - The Mean trackers inside `gan` accumulate across train_step calls,
      so they are reset each epoch to report true per-epoch losses.
    - Guards against an empty dataset, where `losses` would be unbound.
    """
    for epoch in range(epochs):
        print(f"Epoch {epoch + 1}/{epochs}")
        # Reset running means so the reported losses are per-epoch
        gan.gen_loss_tracker.reset_states()
        gan.disc_loss_tracker.reset_states()
        losses = None
        # Train one epoch
        for batch in dataset:
            losses = gan.train_step(batch)
        # Generate samples every 10 epochs
        if (epoch + 1) % 10 == 0:
            generate_and_save_images(gan.generator, epoch + 1)
            if losses is not None:  # empty dataset -> nothing to report
                print(f"Generator Loss: {losses['generator_loss']:.4f}")
                print(f"Discriminator Loss: {losses['discriminator_loss']:.4f}")
def generate_and_save_images(generator, epoch, num_examples=16, latent_dim=100):
    """
    Sample images from `generator` and display them in a square grid.

    Args:
        generator: generator model to sample from.
        epoch: epoch number shown in the figure title.
        num_examples: number of images to sample; the grid is sized to fit
            (previously hard-coded to a 4x4 layout).
        latent_dim: dimensionality of the generator's noise input
            (previously hard-coded to 100).
    """
    # Generate random noise
    noise = tf.random.normal([num_examples, latent_dim])
    # Generate images (inference mode: BatchNorm uses moving statistics)
    generated_images = generator(noise, training=False)
    # Smallest square grid that fits num_examples images
    grid = int(np.ceil(np.sqrt(num_examples)))
    fig = plt.figure(figsize=(4, 4))
    for i in range(generated_images.shape[0]):
        plt.subplot(grid, grid, i + 1)
        # Map tanh output [-1, 1] back to pixel range [0, 255]
        plt.imshow(generated_images[i, :, :, 0] * 127.5 + 127.5, cmap='gray')
        plt.axis('off')
    plt.suptitle(f'Epoch {epoch}')
    plt.tight_layout()
    plt.show()
# Start training (example; real training needs many more epochs)
# train_gan(gan, train_dataset, epochs=50)

Conditional GAN (cGAN)
python
def create_conditional_generator(latent_dim, num_classes, img_shape):
    """
    Create conditional generator

    Conditions generation on a class label: the label is embedded,
    concatenated with the noise vector, and the merged vector drives a
    DCGAN-style upsampling stack ending in a 28x28x1 tanh image.
    NOTE(review): `img_shape` is accepted but never used — the
    architecture is fixed to 28x28x1.
    """
    # Noise input
    noise_input = keras.layers.Input(shape=(latent_dim,))
    # Label input: integer class id -> 50-dim learned embedding
    label_input = keras.layers.Input(shape=(1,))
    label_embedding = keras.layers.Embedding(num_classes, 50)(label_input)
    label_embedding = keras.layers.Flatten()(label_embedding)
    # Merge noise and label
    merged_input = keras.layers.Concatenate()([noise_input, label_embedding])
    # Generator network (same trunk as the DCGAN generator)
    x = keras.layers.Dense(7*7*256, use_bias=False)(merged_input)
    x = keras.layers.BatchNormalization()(x)
    x = keras.layers.LeakyReLU()(x)
    x = keras.layers.Reshape((7, 7, 256))(x)
    # strides=(1, 1): stays 7x7, channels -> 128
    x = keras.layers.Conv2DTranspose(128, (5, 5), strides=(1, 1),
                                     padding='same', use_bias=False)(x)
    x = keras.layers.BatchNormalization()(x)
    x = keras.layers.LeakyReLU()(x)
    # 7x7 -> 14x14
    x = keras.layers.Conv2DTranspose(64, (5, 5), strides=(2, 2),
                                     padding='same', use_bias=False)(x)
    x = keras.layers.BatchNormalization()(x)
    x = keras.layers.LeakyReLU()(x)
    # 14x14 -> 28x28, tanh output in [-1, 1]
    output = keras.layers.Conv2DTranspose(1, (5, 5), strides=(2, 2),
                                          padding='same', use_bias=False,
                                          activation='tanh')(x)
    model = keras.Model([noise_input, label_input], output)
    return model
def create_conditional_discriminator(img_shape, num_classes):
    """
    Create conditional discriminator

    The label is embedded, projected to an image-shaped tensor, and
    concatenated with the input image along the channel axis, so the
    discriminator judges (image, label) pairs. Outputs one raw logit.
    """
    # Image input
    img_input = keras.layers.Input(shape=img_shape)
    # Label input: embed, then reshape to the image's spatial shape
    label_input = keras.layers.Input(shape=(1,))
    label_embedding = keras.layers.Embedding(num_classes, 50)(label_input)
    label_embedding = keras.layers.Flatten()(label_embedding)
    label_embedding = keras.layers.Dense(np.prod(img_shape))(label_embedding)
    label_embedding = keras.layers.Reshape(img_shape)(label_embedding)
    # Merge image and label (channel-wise concatenation)
    merged_input = keras.layers.Concatenate()([img_input, label_embedding])
    # Discriminator network (same trunk as the DCGAN discriminator)
    x = keras.layers.Conv2D(64, (5, 5), strides=(2, 2), padding='same')(merged_input)
    x = keras.layers.LeakyReLU()(x)
    x = keras.layers.Dropout(0.3)(x)
    x = keras.layers.Conv2D(128, (5, 5), strides=(2, 2), padding='same')(x)
    x = keras.layers.LeakyReLU()(x)
    x = keras.layers.Dropout(0.3)(x)
    x = keras.layers.Flatten()(x)
    # Single raw score (no sigmoid; use a from_logits loss)
    output = keras.layers.Dense(1)(x)
    model = keras.Model([img_input, label_input], output)
    return model
# Create conditional GAN (10 classes, e.g. the MNIST digits)
cond_generator = create_conditional_generator(100, 10, (28, 28, 1))
cond_discriminator = create_conditional_discriminator((28, 28, 1), 10)

CycleGAN Implementation
python
def create_cyclegan_generator():
    """
    Create CycleGAN generator (ResNet architecture)

    Encoder (2x downsampling) -> 9 residual blocks -> decoder
    (2x upsampling); maps a 256x256x3 image to a 256x256x3 image
    in [-1, 1].
    """
    def residual_block(x, filters):
        # Two 3x3 convs with an additive skip connection
        shortcut = x
        x = keras.layers.Conv2D(filters, 3, padding='same')(x)
        x = keras.layers.BatchNormalization()(x)
        x = keras.layers.ReLU()(x)
        x = keras.layers.Conv2D(filters, 3, padding='same')(x)
        x = keras.layers.BatchNormalization()(x)
        x = keras.layers.Add()([shortcut, x])
        x = keras.layers.ReLU()(x)
        return x

    inputs = keras.layers.Input(shape=(256, 256, 3))
    # Encoder: large 7x7 stem conv
    x = keras.layers.Conv2D(64, 7, padding='same')(inputs)
    x = keras.layers.BatchNormalization()(x)
    x = keras.layers.ReLU()(x)
    # Downsampling: 256 -> 128 -> 64 spatial resolution
    x = keras.layers.Conv2D(128, 3, strides=2, padding='same')(x)
    x = keras.layers.BatchNormalization()(x)
    x = keras.layers.ReLU()(x)
    x = keras.layers.Conv2D(256, 3, strides=2, padding='same')(x)
    x = keras.layers.BatchNormalization()(x)
    x = keras.layers.ReLU()(x)
    # Residual blocks: transformation at the bottleneck resolution
    for _ in range(9):
        x = residual_block(x, 256)
    # Upsampling back to the input resolution
    x = keras.layers.Conv2DTranspose(128, 3, strides=2, padding='same')(x)
    x = keras.layers.BatchNormalization()(x)
    x = keras.layers.ReLU()(x)
    x = keras.layers.Conv2DTranspose(64, 3, strides=2, padding='same')(x)
    x = keras.layers.BatchNormalization()(x)
    x = keras.layers.ReLU()(x)
    # 7x7 output conv; tanh keeps pixels in [-1, 1]
    outputs = keras.layers.Conv2D(3, 7, padding='same', activation='tanh')(x)
    model = keras.Model(inputs, outputs)
    return model
def create_cyclegan_discriminator():
    """
    Create CycleGAN discriminator (PatchGAN): emits a grid of per-patch
    real/fake logits instead of a single score.
    """
    image = keras.layers.Input(shape=(256, 256, 3))
    # First stage has no BatchNorm, matching the layout used here
    h = keras.layers.Conv2D(64, 4, strides=2, padding='same')(image)
    h = keras.layers.LeakyReLU(0.2)(h)
    # Strided stages: 128 -> 256 channels, halving resolution each time
    for filters in (128, 256):
        h = keras.layers.Conv2D(filters, 4, strides=2, padding='same')(h)
        h = keras.layers.BatchNormalization()(h)
        h = keras.layers.LeakyReLU(0.2)(h)
    # Stride-1 stage deepens features without shrinking the map
    h = keras.layers.Conv2D(512, 4, strides=1, padding='same')(h)
    h = keras.layers.BatchNormalization()(h)
    h = keras.layers.LeakyReLU(0.2)(h)
    # One raw logit per patch
    patch_logits = keras.layers.Conv2D(1, 4, strides=1, padding='same')(h)
    return keras.Model(image, patch_logits)
# CycleGAN loss functions
def cycle_consistency_loss(real_image, cycled_image, lambda_cycle=10.0):
    """
    Cycle consistency loss: L1 distance between an image and its
    round-trip translation, scaled by `lambda_cycle`.
    """
    l1 = tf.reduce_mean(tf.abs(real_image - cycled_image))
    return lambda_cycle * l1
def identity_loss(real_image, same_image, lambda_identity=0.5):
"""
Identity loss
"""
loss = tf.reduce_mean(tf.abs(real_image - same_image))
return lambda_identity * lossStyleGAN Basics
python
def create_stylegan_mapping_network(latent_dim=512, num_layers=8):
    """
    Create the StyleGAN mapping network: an MLP transforming a latent
    code z into an intermediate style code w of the same dimensionality.
    """
    mlp = keras.Sequential(
        [keras.layers.Dense(latent_dim, activation='relu') for _ in range(num_layers)]
    )
    mlp.build(input_shape=(None, latent_dim))
    return mlp
def adaptive_instance_normalization(content_features, style_features):
    """
    Adaptive instance normalization (AdaIN): whiten the content features
    per channel over the spatial axes, then rescale them with the style
    features' per-channel statistics.
    """
    spatial_axes = [1, 2]
    # Per-channel statistics over the spatial dimensions
    content_mean = tf.reduce_mean(content_features, axis=spatial_axes, keepdims=True)
    content_std = tf.math.reduce_std(content_features, axis=spatial_axes, keepdims=True)
    style_mean = tf.reduce_mean(style_features, axis=spatial_axes, keepdims=True)
    style_std = tf.math.reduce_std(style_features, axis=spatial_axes, keepdims=True)
    # Whiten content (epsilon avoids division by zero), then re-color
    normalized = (content_features - content_mean) / (content_std + 1e-8)
    return normalized * style_std + style_mean
class StyleGANGenerator(keras.Model):
    """
    Simplified StyleGAN generator

    A mapping network turns the latent code z into a style vector w; a
    learned constant 4x4x512 tensor is then convolved and styled (AdaIN)
    layer by layer, upsampling 2x between layers.
    """
    def __init__(self, latent_dim=512, num_layers=6):
        super(StyleGANGenerator, self).__init__()
        self.latent_dim = latent_dim
        self.num_layers = num_layers
        # Mapping network: z -> w
        self.mapping_network = create_stylegan_mapping_network(latent_dim)
        # Constant input (learned; shared across all samples)
        self.constant_input = self.add_weight(
            shape=(1, 4, 4, 512),
            initializer='random_normal',
            trainable=True,
            name='constant_input'
        )
        # Generator layers: one conv + one style projection per stage
        self.conv_layers = []
        self.adain_layers = []
        channels = [512, 512, 256, 128, 64, 32]
        for i in range(num_layers):
            self.conv_layers.append(
                keras.layers.Conv2D(channels[i], 3, padding='same', activation='relu')
            )
            self.adain_layers.append(
                keras.layers.Dense(channels[i] * 2)  # For generating mean and std
            )

    def call(self, latent_codes):
        batch_size = tf.shape(latent_codes)[0]
        # Mapping network
        w = self.mapping_network(latent_codes)
        # Start from the constant, replicated across the batch
        x = tf.tile(self.constant_input, [batch_size, 1, 1, 1])
        # Generate layer by layer
        for i in range(self.num_layers):
            # Convolution
            x = self.conv_layers[i](x)
            # AdaIN parameters predicted from the style vector w
            style_params = self.adain_layers[i](w)
            style_mean, style_std = tf.split(style_params, 2, axis=-1)
            # Apply AdaIN: whiten x, then rescale with the style stats
            x_mean = tf.reduce_mean(x, axis=[1, 2], keepdims=True)
            x_std = tf.math.reduce_std(x, axis=[1, 2], keepdims=True)
            x = (x - x_mean) / (x_std + 1e-8)
            x = x * tf.expand_dims(tf.expand_dims(style_std, 1), 1) + \
                tf.expand_dims(tf.expand_dims(style_mean, 1), 1)
            # Upsampling (except last layer)
            if i < self.num_layers - 1:
                x = keras.layers.UpSampling2D()(x)
        return x

Evaluation Metrics
python
def calculate_fid_score(real_images, generated_images, model_name='inception_v3'):
    """
    Calculate FID score (Fréchet Inception Distance); lower is better.

    Fix: `scipy.linalg` was used without being imported anywhere in the
    file; it is now imported explicitly.

    Args:
        real_images / generated_images: image batches; resized to 299x299
            and preprocessed for InceptionV3 internally.
        model_name: unused; kept for backward compatibility.

    Returns:
        The FID between the two feature distributions.
    """
    import scipy.linalg  # needed for the matrix square root below

    # Load pre-trained Inception model (pooled features, no classifier head)
    inception_model = keras.applications.InceptionV3(
        include_top=False,
        pooling='avg',
        input_shape=(299, 299, 3)
    )

    def preprocess_images(images):
        # Resize to Inception's expected input size, then scale per its spec
        images = tf.image.resize(images, [299, 299])
        return keras.applications.inception_v3.preprocess_input(images)

    # Extract features
    real_features = inception_model.predict(preprocess_images(real_images))
    generated_features = inception_model.predict(preprocess_images(generated_images))
    # Gaussian statistics of each feature distribution
    mu_real = np.mean(real_features, axis=0)
    sigma_real = np.cov(real_features, rowvar=False)
    mu_gen = np.mean(generated_features, axis=0)
    sigma_gen = np.cov(generated_features, rowvar=False)
    # Fréchet distance between the two Gaussians
    diff = mu_real - mu_gen
    covmean = scipy.linalg.sqrtm(sigma_real.dot(sigma_gen))
    if np.iscomplexobj(covmean):
        # sqrtm can return tiny imaginary parts from numerical error
        covmean = covmean.real
    fid = diff.dot(diff) + np.trace(sigma_real + sigma_gen - 2 * covmean)
    return fid
def calculate_inception_score(generated_images, num_splits=10):
    """
    Calculate Inception Score; higher is better.

    Rewards confident per-image class predictions and a diverse marginal
    class distribution. Returns (mean, std) over `num_splits` splits.
    """
    # Load Inception model with its classifier head (class probabilities)
    inception_model = keras.applications.InceptionV3(
        include_top=True,
        input_shape=(299, 299, 3)
    )
    # Preprocess images to Inception's expected input
    images = tf.image.resize(generated_images, [299, 299])
    images = keras.applications.inception_v3.preprocess_input(images)
    # Get predictions
    predictions = inception_model.predict(images)
    # Calculate IS per split
    scores = []
    for i in range(num_splits):
        part = predictions[i * len(predictions) // num_splits:
                           (i + 1) * len(predictions) // num_splits]
        # KL divergence between per-image p(y|x) and the marginal p(y)
        py = np.mean(part, axis=0)
        kl_div = part * (np.log(part + 1e-8) - np.log(py + 1e-8))
        kl_div = np.mean(np.sum(kl_div, axis=1))
        scores.append(np.exp(kl_div))
    return np.mean(scores), np.std(scores)

Training Tips and Best Practices
python
# 1. Progressive training
class ProgressiveGAN:
    """Sketch of progressive growing (ProGAN): train at low resolution,
    then repeatedly double it, fading new layers in smoothly."""

    def __init__(self):
        self.current_resolution = 4     # training starts at 4x4
        self.max_resolution = 256       # target final resolution

    def grow_network(self):
        """
        Gradually increase network resolution

        Doubles the working resolution until the maximum is reached.
        (Layer insertion is left as a sketch here.)
        """
        if self.current_resolution < self.max_resolution:
            self.current_resolution *= 2
            # Add new layers to generator and discriminator

    def fade_in_new_layers(self, alpha):
        """
        Fade in new layers

        `alpha` in [0, 1] blends the old low-res output with the new
        layer's output during the transition phase (sketch only).
        """
        # Use alpha to blend old and new outputs
        pass
# 2. Spectral normalization
def spectral_normalization(layer):
    """
    Spectral normalization decorator

    NOTE(review): this looks up 'SpectralNormalization' in Keras's
    custom-object registry, which only works if such a wrapper class was
    registered elsewhere — nothing in this file registers it, so as
    written this raises KeyError. Consider a real implementation such as
    tensorflow-addons' `tfa.layers.SpectralNormalization` — TODO confirm.
    """
    return keras.utils.get_custom_objects()['SpectralNormalization'](layer)
# 3. Self-attention mechanism
class SelfAttention(keras.layers.Layer):
    """Self-attention block in the SAGAN style: attends over all spatial
    positions and adds the result back through a learned gate `gamma`."""

    def __init__(self, channels):
        super(SelfAttention, self).__init__()
        self.channels = channels
        # 1x1 convs: query/key in a reduced (C/8) space, value keeps C channels
        self.query_conv = keras.layers.Conv2D(channels // 8, 1)
        self.key_conv = keras.layers.Conv2D(channels // 8, 1)
        self.value_conv = keras.layers.Conv2D(channels, 1)
        # Gate starts at 0, so the block is initially an identity mapping
        self.gamma = self.add_weight(shape=(), initializer='zeros', trainable=True)

    def call(self, x):
        # Dynamic dims so the layer works for any input resolution
        batch_size, height, width, channels = tf.shape(x)[0], tf.shape(x)[1], tf.shape(x)[2], tf.shape(x)[3]
        # Calculate query, key, value
        query = self.query_conv(x)
        key = self.key_conv(x)
        value = self.value_conv(x)
        # Flatten spatial dims into (B, H*W, C') matrices
        query = tf.reshape(query, [batch_size, -1, channels // 8])
        key = tf.reshape(key, [batch_size, -1, channels // 8])
        value = tf.reshape(value, [batch_size, -1, channels])
        # Attention over all positions: softmax of query·keyᵀ
        attention = tf.nn.softmax(tf.matmul(query, key, transpose_b=True))
        out = tf.matmul(attention, value)
        out = tf.reshape(out, [batch_size, height, width, channels])
        # Residual connection, gated by gamma
        out = self.gamma * out + x
        return out
# 4. Gradient penalty (WGAN-GP)
def gradient_penalty(discriminator, real_images, fake_images, batch_size):
"""
Calculate gradient penalty
"""
alpha = tf.random.uniform([batch_size, 1, 1, 1], 0., 1.)
interpolated = alpha * real_images + (1 - alpha) * fake_images
with tf.GradientTape() as tape:
tape.watch(interpolated)
pred = discriminator(interpolated, training=True)
gradients = tape.gradient(pred, interpolated)
gradients_norm = tf.sqrt(tf.reduce_sum(tf.square(gradients), axis=[1, 2, 3]))
gradient_penalty = tf.reduce_mean((gradients_norm - 1.) ** 2)
return gradient_penaltyReal Application Examples
python
def image_to_image_translation():
    """
    Image-to-image translation example

    Builds and returns a Pix2Pix-style U-Net generator: four strided
    encoder convs, three transposed-conv decoder stages with skip
    connections, and a final tanh layer back to 256x256x3.
    """
    # Create Pix2Pix model
    def create_pix2pix_generator():
        # U-Net architecture
        inputs = keras.layers.Input(shape=(256, 256, 3))
        # Encoder
        down_stack = [
            keras.layers.Conv2D(64, 4, strides=2, padding='same', use_bias=False),
            keras.layers.Conv2D(128, 4, strides=2, padding='same', use_bias=False),
            keras.layers.Conv2D(256, 4, strides=2, padding='same', use_bias=False),
            keras.layers.Conv2D(512, 4, strides=2, padding='same', use_bias=False),
        ]
        # Decoder
        up_stack = [
            keras.layers.Conv2DTranspose(256, 4, strides=2, padding='same', use_bias=False),
            keras.layers.Conv2DTranspose(128, 4, strides=2, padding='same', use_bias=False),
            keras.layers.Conv2DTranspose(64, 4, strides=2, padding='same', use_bias=False),
        ]
        x = inputs
        # Downsampling, remembering each stage's output for skip connections
        skips = []
        for down in down_stack:
            x = down(x)
            x = keras.layers.BatchNormalization()(x)
            x = keras.layers.LeakyReLU()(x)
            skips.append(x)
        # Deepest output feeds the decoder directly, so drop it from skips
        skips = reversed(skips[:-1])
        # Upsampling, concatenating the matching encoder activation
        for up, skip in zip(up_stack, skips):
            x = up(x)
            x = keras.layers.BatchNormalization()(x)
            x = keras.layers.ReLU()(x)
            x = keras.layers.Concatenate()([x, skip])
        # Last layer restores the input resolution; tanh keeps [-1, 1]
        last = keras.layers.Conv2DTranspose(3, 4, strides=2, padding='same', activation='tanh')
        x = last(x)
        return keras.Model(inputs=inputs, outputs=x)

    return create_pix2pix_generator()
def super_resolution_gan():
    """
    Super-resolution GAN example

    Builds and returns an SRGAN-style generator: a 9x9 stem conv, 16
    residual blocks, then two 2x pixel-shuffle upsampling stages (4x
    total). Accepts any input size (shape=(None, None, 3)).
    """
    def create_srgan_generator():
        def residual_block(x):
            # Conv-BN-PReLU-Conv-BN with an additive skip connection
            shortcut = x
            x = keras.layers.Conv2D(64, 3, padding='same')(x)
            x = keras.layers.BatchNormalization()(x)
            x = keras.layers.PReLU()(x)
            x = keras.layers.Conv2D(64, 3, padding='same')(x)
            x = keras.layers.BatchNormalization()(x)
            x = keras.layers.Add()([shortcut, x])
            return x

        inputs = keras.layers.Input(shape=(None, None, 3))
        # Initial convolution (large 9x9 receptive field)
        x = keras.layers.Conv2D(64, 9, padding='same')(inputs)
        x = keras.layers.PReLU()(x)
        # Residual blocks
        for _ in range(16):
            x = residual_block(x)
        # Upsampling: 256 channels -> depth_to_space doubles H and W (twice)
        x = keras.layers.Conv2D(256, 3, padding='same')(x)
        x = keras.layers.Lambda(lambda x: tf.nn.depth_to_space(x, 2))(x)
        x = keras.layers.PReLU()(x)
        x = keras.layers.Conv2D(256, 3, padding='same')(x)
        x = keras.layers.Lambda(lambda x: tf.nn.depth_to_space(x, 2))(x)
        x = keras.layers.PReLU()(x)
        # Output: 3-channel image, tanh in [-1, 1]
        outputs = keras.layers.Conv2D(3, 9, padding='same', activation='tanh')(x)
        return keras.Model(inputs, outputs)

    return create_srgan_generator()

Summary
GANs are among the most innovative ideas in deep learning and ushered in a new era of generative models. From the basic GAN to advanced variants like StyleGAN and BigGAN, GANs are widely applied to image generation, style transfer, data augmentation, and other fields.
Key Points:
- Adversarial Training: The game process between generator and discriminator
- Loss Function Design: Different loss functions for different scenarios
- Training Stability: Requires careful tuning of hyperparameters and training strategies
- Evaluation Metrics: FID, IS, and other metrics help evaluate generation quality
Next chapter we will learn practical projects and apply theoretical knowledge to specific projects.