Skip to content

PyTorch Model Optimization

Optimization Overview

Model optimization is a critical aspect of deep learning projects, involving training efficiency, inference speed, memory usage, and model size. This chapter will introduce various optimization techniques in PyTorch.

Training Optimization

1. Mixed Precision Training

python
import torch
from torch.cuda.amp import GradScaler, autocast
import torch.nn as nn

class MixedPrecisionTrainer:
    """Wraps a model/optimizer/criterion triple to run AMP training steps.

    NOTE(review): ``torch.cuda.amp`` is deprecated in recent PyTorch in favor
    of ``torch.amp`` — confirm the target torch version before migrating.
    """

    def __init__(self, model, optimizer, criterion):
        self.model = model
        self.optimizer = optimizer
        self.criterion = criterion
        # Loss scaler that guards FP16 gradients against underflow.
        self.scaler = GradScaler()

    def train_step(self, data, target):
        """Run one mixed-precision optimization step; return the loss value."""
        self.optimizer.zero_grad()

        # Forward under autocast so eligible ops run in reduced precision.
        with autocast():
            predictions = self.model(data)
            step_loss = self.criterion(predictions, target)

        # Backpropagate on the scaled loss so small gradients stay representable.
        self.scaler.scale(step_loss).backward()

        # Unscale first so the clipping threshold applies to the true gradients.
        self.scaler.unscale_(self.optimizer)
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)

        # Apply the update and adapt the scale factor for the next step.
        self.scaler.step(self.optimizer)
        self.scaler.update()

        return step_loss.item()

# Usage example
# NOTE(review): `dataloader` is not defined in this file — this snippet
# assumes one is created elsewhere.
model = nn.Sequential(
    nn.Linear(784, 256),
    nn.ReLU(),
    nn.Linear(256, 10)
)
optimizer = torch.optim.AdamW(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()

trainer = MixedPrecisionTrainer(model, optimizer, criterion)

# Training loop
for data, target in dataloader:
    loss = trainer.train_step(data, target)

2. Gradient Accumulation

python
class GradientAccumulator:
    """Accumulates gradients over several micro-batches before each optimizer update."""

    def __init__(self, model, optimizer, criterion, accumulation_steps=4):
        self.model = model
        self.optimizer = optimizer
        self.criterion = criterion
        self.accumulation_steps = accumulation_steps
        self.step_count = 0

    def accumulate_step(self, data, target):
        """Process one micro-batch; step the optimizer every `accumulation_steps` calls.

        Returns the (unscaled) loss value for this micro-batch.
        """
        raw_loss = self.criterion(self.model(data), target)

        # Divide so the accumulated gradient matches one large-batch gradient.
        scaled_loss = raw_loss / self.accumulation_steps
        scaled_loss.backward()

        self.step_count += 1

        # Flush the accumulated gradients once enough micro-batches are seen.
        if self.step_count % self.accumulation_steps == 0:
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
            self.optimizer.step()
            self.optimizer.zero_grad()

        # Undo the scaling so callers observe the per-batch loss.
        return scaled_loss.item() * self.accumulation_steps

# Usage example
# NOTE(review): assumes `model`, `optimizer`, `criterion` and `dataloader`
# are defined elsewhere.
accumulator = GradientAccumulator(model, optimizer, criterion, accumulation_steps=8)

for data, target in dataloader:
    loss = accumulator.accumulate_step(data, target)

3. Learning Rate Scheduling Optimization

python
import torch.optim.lr_scheduler as lr_scheduler
import math

class CosineAnnealingWarmRestarts(lr_scheduler._LRScheduler):
    """Cosine annealing scheduler with linear warmup and warm restarts.

    During the first ``warmup_steps`` epochs the learning rate ramps linearly
    up to ``base_lr``.  Afterwards it follows cosine annealing cycles: the
    first cycle lasts ``T_0`` epochs and each subsequent cycle is ``T_mult``
    times longer than the previous one.

    Args:
        optimizer: Wrapped optimizer.
        T_0: Length (in epochs) of the first cosine cycle.
        T_mult: Cycle-length multiplier applied at every restart.
        eta_min: Minimum learning rate reached at the end of each cycle.
        warmup_steps: Number of linear warmup epochs before annealing starts.
        last_epoch: Index of the last epoch (-1 starts from scratch).
    """
    def __init__(self, optimizer, T_0, T_mult=1, eta_min=0, warmup_steps=0, last_epoch=-1):
        self.T_0 = T_0
        self.T_mult = T_mult
        self.eta_min = eta_min
        self.warmup_steps = warmup_steps
        super(CosineAnnealingWarmRestarts, self).__init__(optimizer, last_epoch)

    def get_lr(self):
        if self.last_epoch < self.warmup_steps:
            # Warmup phase: linear ramp from base_lr / warmup_steps to base_lr.
            return [base_lr * (self.last_epoch + 1) / self.warmup_steps
                    for base_lr in self.base_lrs]

        # Cosine annealing phase.  Locate the position inside the current
        # cycle, honoring T_mult: each cycle after a restart is T_mult times
        # longer than the previous one.  (Bug fix: the original stored T_mult
        # but never used it, so every cycle had length T_0.)
        T_cur = self.last_epoch - self.warmup_steps
        T_i = self.T_0
        while T_cur >= T_i:
            T_cur -= T_i
            T_i *= self.T_mult
        return [self.eta_min + (base_lr - self.eta_min) *
                (1 + math.cos(math.pi * T_cur / T_i)) / 2
                for base_lr in self.base_lrs]

# Usage example
scheduler = CosineAnnealingWarmRestarts(
    optimizer, T_0=50, T_mult=2, eta_min=1e-6, warmup_steps=10
)

# In training loop
# NOTE(review): `num_epochs` and `dataloader` are assumed to be defined
# elsewhere.  Stepping once per epoch means T_0 and warmup_steps are
# measured in epochs here.
for epoch in range(num_epochs):
    for data, target in dataloader:
        # Training step
        pass
    scheduler.step()

Memory Optimization

1. Gradient Checkpointing

python
import torch.utils.checkpoint as checkpoint

class CheckpointedModel(nn.Module):
    """Sequential container that recomputes each layer's activations in backward.

    Gradient checkpointing trades compute for memory: activations are not
    stored during the forward pass and are recomputed during backward.
    """

    def __init__(self, layers):
        super(CheckpointedModel, self).__init__()
        self.layers = nn.ModuleList(layers)

    def forward(self, x):
        # use_reentrant=False selects the non-reentrant implementation, the
        # recommended value in current PyTorch (omitting it is deprecated
        # and emits a warning).
        for layer in self.layers:
            x = checkpoint.checkpoint(layer, x, use_reentrant=False)
        return x

# Or use decorator
def checkpointed_forward(module, input):
    """Run ``module(input)`` under gradient checkpointing (activations recomputed in backward)."""
    # use_reentrant=False: the recommended non-reentrant implementation;
    # relying on the implicit default is deprecated in current PyTorch.
    return checkpoint.checkpoint(module, input, use_reentrant=False)

# In large models
class LargeTransformer(nn.Module):
    """Transformer stack that checkpoints every layer to bound activation memory.

    NOTE(review): ``TransformerLayer`` is not defined in this file — it is
    assumed to come from the surrounding project.
    """

    def __init__(self, config):
        super(LargeTransformer, self).__init__()
        self.layers = nn.ModuleList([
            TransformerLayer(config) for _ in range(config.num_layers)
        ])

    def forward(self, x):
        for layer in self.layers:
            # use_reentrant=False: recommended (non-deprecated) checkpoint mode.
            x = checkpoint.checkpoint(layer, x, use_reentrant=False)
        return x

2. Memory-Mapped Datasets

python
import mmap
import numpy as np
from torch.utils.data import Dataset

class MemoryMappedDataset(Dataset):
    """Dataset backed by a memory-mapped binary file plus an offset/size index.

    The index file is a NumPy array where row ``i`` holds ``(offset, size)``
    of sample ``i`` inside the data file.
    """

    def __init__(self, data_file, index_file):
        # Memory-map the data file so samples are paged in on demand rather
        # than loading the whole file into RAM.
        self.data_file = open(data_file, 'rb')
        self.data_mmap = mmap.mmap(self.data_file.fileno(), 0, access=mmap.ACCESS_READ)

        # Each row: (offset, size) of one record.
        self.indices = np.load(index_file)

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx):
        offset, size = self.indices[idx]

        # Slice the mmap instead of seek()+read(): slicing is stateless, so
        # concurrent readers (e.g. multi-threaded loaders) cannot race on a
        # shared file position (bug fix).
        data_bytes = self.data_mmap[offset:offset + size]

        return self._parse_data(data_bytes)

    def _parse_data(self, data_bytes):
        # Subclasses implement the record decoding; the base class returns None.
        pass

    def __del__(self):
        # Best-effort cleanup; attributes may be missing if __init__ failed.
        if hasattr(self, 'data_mmap'):
            self.data_mmap.close()
        if hasattr(self, 'data_file'):
            self.data_file.close()

3. Dynamic Batch Size

python
class DynamicBatchSampler:
    """Yields batches of dataset indices capped by total token count and batch size."""

    def __init__(self, dataset, max_tokens=4096, max_batch_size=32):
        self.dataset = dataset
        self.max_tokens = max_tokens
        self.max_batch_size = max_batch_size

    def __iter__(self):
        pending = []
        pending_tokens = 0

        for sample_idx in range(len(self.dataset)):
            length = len(self.dataset[sample_idx])

            # Flush the current batch when adding this sample would overflow
            # either the token budget or the batch-size cap.
            over_budget = pending_tokens + length > self.max_tokens
            at_capacity = len(pending) >= self.max_batch_size
            if pending and (over_budget or at_capacity):
                yield pending
                pending = []
                pending_tokens = 0

            pending.append(sample_idx)
            pending_tokens += length

        # Emit the final partial batch.
        if pending:
            yield pending

# Use dynamic batch sampler
# NOTE(review): `dataset` and `DataLoader` are assumed to be in scope —
# DataLoader is not imported in this file.
sampler = DynamicBatchSampler(dataset, max_tokens=4096)
dataloader = DataLoader(dataset, batch_sampler=sampler)

Compute Optimization

1. Model Compilation (PyTorch 2.0+)

python
import torch._dynamo as dynamo

# Compile model for better performance
class OptimizedModel(nn.Module):
    """Simple MLP intended to be accelerated with ``torch.compile``.

    Bug fix: the original applied ``@torch.compile`` to the *class*.
    ``torch.compile`` is documented for functions and ``nn.Module``
    instances; decorating the class wraps object construction rather than
    the forward pass.  Compile an instance instead::

        model = torch.compile(OptimizedModel(config))
    """

    def __init__(self, config):
        super(OptimizedModel, self).__init__()
        self.layers = nn.Sequential(
            nn.Linear(config.input_size, config.hidden_size),
            nn.ReLU(),
            nn.Linear(config.hidden_size, config.output_size)
        )

    def forward(self, x):
        return self.layers(x)

# Or compile existing model
# NOTE(review): `MyModel` is not defined in this file.
model = MyModel()
compiled_model = torch.compile(model, mode='max-autotune')

# Different compilation modes
# 'default': Balance compile time and runtime performance
# 'reduce-overhead': Reduce Python overhead
# 'max-autotune': Maximize performance optimization

2. Operator Fusion

python
# Manually fuse common operations
class FusedLinearReLU(nn.Module):
    """Linear layer followed by ReLU, expressed as one module."""

    def __init__(self, in_features, out_features):
        super(FusedLinearReLU, self).__init__()
        self.linear = nn.Linear(in_features, out_features)

    def forward(self, x):
        # Equivalent to relu(W x + b); keeping the pair in one module lets
        # downstream tooling treat it as a fusable unit.
        return self.linear(x).relu()

# Use TorchScript for automatic fusion
class ModelForFusion(nn.Module):
    """Conv -> BatchNorm -> ReLU stack used to demonstrate TorchScript fusion."""

    def __init__(self):
        super(ModelForFusion, self).__init__()
        self.conv = nn.Conv2d(3, 64, 3, padding=1)
        self.bn = nn.BatchNorm2d(64)
        self.relu = nn.ReLU()

    def forward(self, x):
        # Single chained expression; identical dataflow to conv -> bn -> relu.
        return self.relu(self.bn(self.conv(x)))

# Script model to enable fusion
model = ModelForFusion()
scripted_model = torch.jit.script(model)

# Freeze model for inference optimization
# freeze() requires the module to be in eval mode, hence the call above.
scripted_model.eval()
frozen_model = torch.jit.freeze(scripted_model)

3. Parallel Computing Optimization

python
# Data parallel
# NOTE(review): `model` is assumed to be defined elsewhere; DataParallel is
# generally superseded by DistributedDataParallel below.
model = nn.DataParallel(model)

# Distributed data parallel
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

def setup_distributed():
    """Initialize the default NCCL process group and bind this process to its local GPU.

    Expects the LOCAL_RANK environment variable set by the launcher
    (e.g. torchrun).
    """
    # Bug fix: `os` is used here but was never imported anywhere in this
    # file; keep the dependency local to the function.
    import os
    dist.init_process_group(backend='nccl')
    torch.cuda.set_device(int(os.environ['LOCAL_RANK']))

def cleanup_distributed():
    """Tear down the default process group created by setup_distributed()."""
    dist.destroy_process_group()

# Use DDP
# NOTE(review): `local_rank` is assumed to be defined elsewhere (e.g. read
# from the LOCAL_RANK environment variable).
model = DDP(model, device_ids=[local_rank])

# Model parallel (for large models)
class ModelParallelNet(nn.Module):
    """Splits a 3-layer MLP across two GPUs (layer1 on cuda:0, the rest on cuda:1).

    NOTE(review): requires at least two CUDA devices at construction time.
    """
    def __init__(self):
        super(ModelParallelNet, self).__init__()
        self.layer1 = nn.Linear(1000, 1000).to('cuda:0')
        self.layer2 = nn.Linear(1000, 1000).to('cuda:1')
        self.layer3 = nn.Linear(1000, 10).to('cuda:1')
    
    def forward(self, x):
        # Move activations between devices at the partition boundary.
        x = x.to('cuda:0')
        x = self.layer1(x)
        x = x.to('cuda:1')
        x = self.layer2(x)
        x = self.layer3(x)
        return x

Model Compression

1. Knowledge Distillation

python
class KnowledgeDistillation:
    """Hinton-style knowledge distillation from a frozen teacher to a student.

    Args:
        teacher_model: Trained model providing soft targets (frozen here).
        student_model: Model being trained.
        temperature: Softmax temperature for the soft-target term.
        alpha: Weight of the soft-target loss; (1 - alpha) weights the
            hard-label cross-entropy.
    """

    def __init__(self, teacher_model, student_model, temperature=4.0, alpha=0.7):
        self.teacher_model = teacher_model
        self.student_model = student_model
        self.temperature = temperature
        self.alpha = alpha

        # Freeze teacher model: it only provides targets, never trains.
        for param in self.teacher_model.parameters():
            param.requires_grad = False

    def distillation_loss(self, student_logits, teacher_logits, true_labels):
        """Combine the soft (KL to teacher) and hard (CE to labels) losses."""
        # Bug fix: this file never imported torch.nn.functional as F, so the
        # original code raised NameError; import it locally instead.
        import torch.nn.functional as F

        # Soft-label loss; the T^2 factor rescales gradients per Hinton et al.
        soft_loss = nn.KLDivLoss(reduction='batchmean')(
            F.log_softmax(student_logits / self.temperature, dim=1),
            F.softmax(teacher_logits / self.temperature, dim=1)
        ) * (self.temperature ** 2)

        # Hard-label loss against the ground-truth classes.
        hard_loss = F.cross_entropy(student_logits, true_labels)

        # Convex combination of the two terms.
        return self.alpha * soft_loss + (1 - self.alpha) * hard_loss

    def train_step(self, data, target, optimizer):
        """Run one distillation step; returns the scalar loss value."""
        self.teacher_model.eval()
        self.student_model.train()

        optimizer.zero_grad()

        # Teacher prediction: no gradients needed for the frozen teacher.
        with torch.no_grad():
            teacher_logits = self.teacher_model(data)

        # Student model prediction
        student_logits = self.student_model(data)

        # Compute distillation loss
        loss = self.distillation_loss(student_logits, teacher_logits, target)

        loss.backward()
        optimizer.step()

        return loss.item()

# Usage example
# NOTE(review): `LargeModel`, `SmallModel` and `dataloader` are not defined
# in this file.
teacher = LargeModel()  # Large teacher model
student = SmallModel()  # Small student model

distiller = KnowledgeDistillation(teacher, student)
optimizer = torch.optim.Adam(student.parameters(), lr=0.001)

for data, target in dataloader:
    loss = distiller.train_step(data, target, optimizer)

2. Model Pruning

python
import torch.nn.utils.prune as prune

class ModelPruner:
    """Utility wrapper around torch.nn.utils.prune for a whole model."""

    def __init__(self, model):
        self.model = model

    def structured_pruning(self, pruning_ratio=0.2):
        """L2 structured pruning along output channels/neurons of conv and linear layers."""
        for name, module in self.model.named_modules():
            if isinstance(module, nn.Conv2d):
                # Remove whole output channels (dim=0), ranked by L2 norm.
                prune.ln_structured(
                    module, name='weight', amount=pruning_ratio,
                    n=2, dim=0
                )
            elif isinstance(module, nn.Linear):
                # Remove whole output neurons (rows of the weight matrix).
                prune.ln_structured(
                    module, name='weight', amount=pruning_ratio,
                    n=2, dim=0
                )

    def unstructured_pruning(self, pruning_ratio=0.2):
        """Globally prune the smallest-magnitude weights across conv/linear layers."""
        parameters_to_prune = [
            (module, 'weight')
            for module in self.model.modules()
            if isinstance(module, (nn.Conv2d, nn.Linear))
        ]

        # Global unstructured pruning across all collected weights.
        prune.global_unstructured(
            parameters_to_prune,
            pruning_method=prune.L1Unstructured,
            amount=pruning_ratio,
        )

    def gradual_pruning(self, initial_sparsity=0.0, final_sparsity=0.8,
                       pruning_steps=100, pruning_frequency=10):
        """Prune in stages, ramping linearly from initial to final sparsity.

        Bug fix: the original divided by ``pruning_steps`` although the
        counter only advanced once per ``pruning_frequency`` epochs, so the
        schedule could never reach ``final_sparsity``.  Note that each call
        prunes a fraction of the *remaining* weights, so per-call amounts
        compound.
        """
        num_events = max(1, pruning_steps // pruning_frequency)
        event = 0

        for epoch in range(pruning_steps):
            if epoch % pruning_frequency == 0:
                event += 1
                # Linear ramp; reaches final_sparsity on the last event.
                current_sparsity = initial_sparsity + (
                    final_sparsity - initial_sparsity
                ) * (event / num_events)

                self.unstructured_pruning(current_sparsity)

    def remove_pruning(self):
        """Make pruning permanent by removing masks/reparameterization."""
        for name, module in self.model.named_modules():
            if isinstance(module, (nn.Conv2d, nn.Linear)):
                try:
                    prune.remove(module, 'weight')
                except ValueError:
                    # Module was never pruned; nothing to remove.  (Bug fix:
                    # the original bare `except:` also swallowed
                    # KeyboardInterrupt/SystemExit.)
                    pass

    def calculate_sparsity(self):
        """Return the fraction of zero-valued parameters, honoring active masks.

        Bug fix: the original iterated ``model.parameters()``, which after
        pruning yields the *unmasked* ``weight_orig`` tensors and therefore
        reported ~0 sparsity until ``remove_pruning()`` was called.  The
        effective (masked) tensors are measured here instead.
        """
        total_params = 0
        zero_params = 0

        for module in self.model.modules():
            for name, param in module.named_parameters(recurse=False):
                if name.endswith('_orig'):
                    # Pruned parameter: measure the masked attribute instead.
                    tensor = getattr(module, name[:-len('_orig')])
                else:
                    tensor = param
                total_params += tensor.numel()
                zero_params += (tensor == 0).sum().item()

        return zero_params / total_params

# Usage example
# NOTE(review): `model` and `fine_tune_epochs` are assumed to be defined
# elsewhere.
pruner = ModelPruner(model)

# Apply pruning
pruner.unstructured_pruning(pruning_ratio=0.3)

# Calculate sparsity
sparsity = pruner.calculate_sparsity()
print(f"Model sparsity: {sparsity:.2%}")

# Fine-tune pruned model
for epoch in range(fine_tune_epochs):
    # Training loop
    pass

# Remove pruning reparameterization
pruner.remove_pruning()

3. Quantization

python
import torch.quantization as quantization

class ModelQuantizer:
    """Convenience wrapper exposing the three standard quantization workflows.

    NOTE(review): `torch.quantization` is a legacy alias of
    `torch.ao.quantization` in current PyTorch, and the 'fbgemm' backend
    targets x86 servers — confirm both against the deployment environment.
    """
    def __init__(self, model):
        self.model = model
    
    def post_training_quantization(self, calibration_loader):
        """Post-training static quantization: prepare, calibrate on real data, convert."""
        # Set quantization configuration
        self.model.qconfig = quantization.get_default_qconfig('fbgemm')
        
        # Prepare for quantization (inserts observers)
        model_prepared = quantization.prepare(self.model, inplace=False)
        
        # Calibrate: run representative batches so observers collect ranges
        model_prepared.eval()
        with torch.no_grad():
            for data, _ in calibration_loader:
                model_prepared(data)
        
        # Convert to quantized model
        model_quantized = quantization.convert(model_prepared, inplace=False)
        
        return model_quantized
    
    def quantization_aware_training(self, train_loader, num_epochs=5):
        """Quantization-aware training: train with fake-quant modules, then convert."""
        # Set QAT configuration
        self.model.qconfig = quantization.get_default_qat_qconfig('fbgemm')
        
        # Prepare QAT (inserts fake-quantization modules)
        model_prepared = quantization.prepare_qat(self.model, inplace=False)
        
        # QAT training
        optimizer = torch.optim.Adam(model_prepared.parameters(), lr=0.0001)
        criterion = nn.CrossEntropyLoss()
        
        for epoch in range(num_epochs):
            model_prepared.train()
            for data, target in train_loader:
                optimizer.zero_grad()
                output = model_prepared(data)
                loss = criterion(output, target)
                loss.backward()
                optimizer.step()
        
        # Convert to quantized model (must be in eval mode first)
        model_prepared.eval()
        model_quantized = quantization.convert(model_prepared, inplace=False)
        
        return model_quantized
    
    def dynamic_quantization(self):
        """Dynamic quantization: int8 weights, activations quantized at runtime (Linear only here)."""
        model_quantized = quantization.quantize_dynamic(
            self.model, {nn.Linear}, dtype=torch.qint8
        )
        return model_quantized

# Usage example
# NOTE(review): `model` is assumed to be defined elsewhere.
quantizer = ModelQuantizer(model)

# Dynamic quantization (simplest)
quantized_model = quantizer.dynamic_quantization()

# Post-training quantization
# quantized_model = quantizer.post_training_quantization(calibration_loader)

# Quantization-aware training
# quantized_model = quantizer.quantization_aware_training(train_loader)

# Compare model sizes
def get_model_size(model):
    """Return the on-disk size, in bytes, of the model's serialized state_dict.

    Bug fixes: the original wrote to a fixed "temp.p" in the current
    directory (race-prone and clobbers any existing file) and used `os`
    without it being imported anywhere in this file.
    """
    import os
    import tempfile

    fd, path = tempfile.mkstemp(suffix='.pt')
    os.close(fd)
    try:
        torch.save(model.state_dict(), path)
        return os.path.getsize(path)
    finally:
        # Always clean up the temporary file, even if torch.save fails.
        os.remove(path)

# NOTE(review): `model` and `quantized_model` come from the snippets above.
original_size = get_model_size(model)
quantized_size = get_model_size(quantized_model)

print(f"Original model size: {original_size / 1024 / 1024:.2f} MB")
print(f"Quantized model size: {quantized_size / 1024 / 1024:.2f} MB")
print(f"Compression ratio: {original_size / quantized_size:.2f}x")

Inference Optimization

1. TorchScript Optimization

python
# Optimize TorchScript model
def optimize_torchscript_model(model, example_input):
    """Optimize TorchScript model"""
    model.eval()
    
    # Trace model
    traced_model = torch.jit.trace(model, example_input)
    
    # Optimize
    optimized_model = torch.jit.optimize_for_inference(traced_model)
    
    # Freeze model
    frozen_model = torch.jit.freeze(optimized_model)
    
    return frozen_model

# Usage example
# NOTE(review): `model` is assumed to be defined elsewhere.
example_input = torch.randn(1, 3, 224, 224)
optimized_model = optimize_torchscript_model(model, example_input)

# Save optimized model
optimized_model.save('optimized_model.pt')

2. Batch Inference Optimization

python
class BatchInferenceOptimizer:
    """Collects single inference requests and serves them in batches.

    Requests are queued; a batch runs once `max_batch_size` requests are
    pending, or after `timeout` seconds, whichever comes first.

    NOTE(review): the queue is a plain list with no lock, and every call
    that does not fill a batch spawns its own timeout task — under
    concurrent use this can trigger overlapping batch runs.  Inputs are
    combined with `torch.stack`, so all requests must share one tensor
    shape.  Verify both points before production use.
    """
    def __init__(self, model, max_batch_size=32, timeout=0.1):
        self.model = model
        self.max_batch_size = max_batch_size  # flush threshold
        self.timeout = timeout                # max seconds a request may wait
        self.batch_queue = []                 # pending (input, Future) pairs
        
    async def predict(self, input_data):
        """Queue one input and await its result from the next batched run."""
        import asyncio
        from concurrent.futures import Future
        
        # concurrent.futures.Future so the result can be awaited via
        # asyncio.wrap_future below.
        future = Future()
        self.batch_queue.append((input_data, future))
        
        # Check if batch needs processing
        if len(self.batch_queue) >= self.max_batch_size:
            await self._process_batch()
        else:
            # Set timeout handling (one task per request; see class note)
            asyncio.create_task(self._timeout_handler())
        
        return await asyncio.wrap_future(future)
    
    async def _process_batch(self):
        """Run the model on all queued inputs and resolve their futures."""
        if not self.batch_queue:
            return
        
        # Collect batch data
        batch_data = []
        futures = []
        
        for data, future in self.batch_queue:
            batch_data.append(data)
            futures.append(future)
        
        self.batch_queue.clear()
        
        # Batch inference
        try:
            batch_input = torch.stack(batch_data)
            with torch.no_grad():
                batch_output = self.model(batch_input)
            
            # Distribute results: row i of the output answers request i.
            for i, future in enumerate(futures):
                future.set_result(batch_output[i])
        
        except Exception as e:
            # Propagate the failure to every waiting caller.
            for future in futures:
                future.set_exception(e)
    
    async def _timeout_handler(self):
        """Flush whatever is queued once the timeout elapses."""
        import asyncio
        await asyncio.sleep(self.timeout)
        if self.batch_queue:
            await self._process_batch()

Performance Monitoring

1. Profiler

python
import torch.profiler as profiler

def profile_model(model, input_data, num_steps=100):
    """Profile model inference with torch.profiler and dump a TensorBoard trace.

    Writes trace files to ./log/profiler as a side effect and prints the
    top-10 ops table.  NOTE(review): the CUDA activity and the
    "cuda_time_total" sort key assume a CUDA run — confirm behavior on
    CPU-only machines.
    """
    model.eval()
    
    with profiler.profile(
        activities=[
            profiler.ProfilerActivity.CPU,
            profiler.ProfilerActivity.CUDA,
        ],
        # Wait 1 step, warm up 1, record 3; repeat the cycle twice.
        schedule=profiler.schedule(wait=1, warmup=1, active=3, repeat=2),
        on_trace_ready=profiler.tensorboard_trace_handler('./log/profiler'),
        record_shapes=True,
        profile_memory=True,
        with_stack=True
    ) as prof:
        for step in range(num_steps):
            with torch.no_grad():
                output = model(input_data)
            # Advance the profiler schedule (required when `schedule=` is set).
            prof.step()
    
    # Print performance report
    print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
    
    return prof

# Usage example
# NOTE(review): requires a CUDA device and assumes `model` is defined elsewhere.
input_data = torch.randn(32, 3, 224, 224).cuda()
prof = profile_model(model.cuda(), input_data)

2. Memory Analysis

python
def analyze_memory_usage(model, input_data):
    """Measure CUDA memory before/during/after one forward pass and print a report.

    Returns a dict with 'initial', 'peak' and 'final' allocated bytes.
    NOTE(review): uses torch.cuda.* only — requires a CUDA device and
    assumes model and input already live on the GPU.
    """
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    
    # Record initial memory
    initial_memory = torch.cuda.memory_allocated()
    
    # Forward pass
    model.eval()
    with torch.no_grad():
        output = model(input_data)
    
    # Record peak memory
    peak_memory = torch.cuda.max_memory_allocated()
    final_memory = torch.cuda.memory_allocated()
    
    print(f"Initial memory: {initial_memory / 1024**2:.2f} MB")
    print(f"Peak memory: {peak_memory / 1024**2:.2f} MB")
    print(f"Final memory: {final_memory / 1024**2:.2f} MB")
    print(f"Memory growth: {(final_memory - initial_memory) / 1024**2:.2f} MB")
    
    return {
        'initial': initial_memory,
        'peak': peak_memory,
        'final': final_memory
    }

# Usage example
# NOTE(review): `model` and `input_data` are assumed to be defined elsewhere.
memory_stats = analyze_memory_usage(model.cuda(), input_data.cuda())

Summary

PyTorch model optimization covers all aspects of training and inference:

  1. Training Optimization: Mixed precision, gradient accumulation, learning rate scheduling
  2. Memory Optimization: Gradient checkpointing, memory mapping, dynamic batching
  3. Compute Optimization: Model compilation, operator fusion, parallel computing
  4. Model Compression: Knowledge distillation, pruning, quantization
  5. Inference Optimization: TorchScript, batch inference
  6. Performance Monitoring: Profiling, memory analysis

Mastering these optimization techniques will significantly improve the efficiency and performance of your deep learning projects!

Content is for learning and research only.