PyTorch Model Optimization
Optimization Overview
Model optimization is a critical aspect of deep learning projects, involving training efficiency, inference speed, memory usage, and model size. This chapter will introduce various optimization techniques in PyTorch.
Training Optimization
1. Mixed Precision Training
python
import os
import tempfile

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.cuda.amp import GradScaler, autocast
class MixedPrecisionTrainer:
def __init__(self, model, optimizer, criterion):
self.model = model
self.optimizer = optimizer
self.criterion = criterion
self.scaler = GradScaler()
def train_step(self, data, target):
"""Mixed precision training step"""
self.optimizer.zero_grad()
# Use autocast for forward pass
with autocast():
output = self.model(data)
loss = self.criterion(output, target)
# Scale loss and backpropagate
self.scaler.scale(loss).backward()
# Gradient clipping
self.scaler.unscale_(self.optimizer)
torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
# Update parameters
self.scaler.step(self.optimizer)
self.scaler.update()
return loss.item()
# Usage example
model = nn.Sequential(
nn.Linear(784, 256),
nn.ReLU(),
nn.Linear(256, 10)
)
optimizer = torch.optim.AdamW(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()
trainer = MixedPrecisionTrainer(model, optimizer, criterion)
# Training loop
for data, target in dataloader:
loss = trainer.train_step(data, target)2. Gradient Accumulation
python
class GradientAccumulator:
def __init__(self, model, optimizer, criterion, accumulation_steps=4):
self.model = model
self.optimizer = optimizer
self.criterion = criterion
self.accumulation_steps = accumulation_steps
self.step_count = 0
def accumulate_step(self, data, target):
"""Gradient accumulation step"""
# Forward pass
output = self.model(data)
loss = self.criterion(output, target)
# Scale loss
loss = loss / self.accumulation_steps
# Backward pass
loss.backward()
self.step_count += 1
# Update parameters every accumulation_steps steps
if self.step_count % self.accumulation_steps == 0:
# Gradient clipping
torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
# Update parameters
self.optimizer.step()
self.optimizer.zero_grad()
return loss.item() * self.accumulation_steps
# Usage example
accumulator = GradientAccumulator(model, optimizer, criterion, accumulation_steps=8)
for data, target in dataloader:
loss = accumulator.accumulate_step(data, target)3. Learning Rate Scheduling Optimization
python
import torch.optim.lr_scheduler as lr_scheduler
import math
class CosineAnnealingWarmRestarts(lr_scheduler._LRScheduler):
    """Cosine annealing scheduler with warm restarts and linear warmup.

    For the first ``warmup_steps`` epochs the learning rate ramps linearly from
    ``base_lr / warmup_steps`` up to ``base_lr``.  Afterwards it follows a
    cosine curve from ``base_lr`` down to ``eta_min``; each time a cycle ends
    the schedule restarts and the next cycle is ``T_mult`` times longer.

    Args:
        optimizer: wrapped optimizer.
        T_0: length (in epochs) of the first cosine cycle.
        T_mult: multiplicative growth factor of successive cycle lengths.
        eta_min: minimum learning rate.
        warmup_steps: number of linear warmup epochs before annealing.
        last_epoch: index of the last epoch (as in the base class).
    """

    def __init__(self, optimizer, T_0, T_mult=1, eta_min=0, warmup_steps=0, last_epoch=-1):
        if T_0 <= 0:
            raise ValueError("T_0 must be a positive integer")
        if T_mult < 1:
            raise ValueError("T_mult must be >= 1")
        self.T_0 = T_0
        self.T_mult = T_mult
        self.eta_min = eta_min
        self.warmup_steps = warmup_steps
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        if self.last_epoch < self.warmup_steps:
            # Linear warmup phase.
            return [base_lr * (self.last_epoch + 1) / self.warmup_steps
                    for base_lr in self.base_lrs]
        # Locate the position inside the current cosine cycle.  The original
        # accepted T_mult but ignored it (plain modulo restart); cycles now
        # grow by T_mult after every restart.
        t = self.last_epoch - self.warmup_steps
        cycle_len = self.T_0
        if self.T_mult == 1:
            t %= cycle_len
        else:
            while t >= cycle_len:
                t -= cycle_len
                cycle_len *= self.T_mult
        return [self.eta_min + (base_lr - self.eta_min) *
                (1 + math.cos(math.pi * t / cycle_len)) / 2
                for base_lr in self.base_lrs]
# Usage example
scheduler = CosineAnnealingWarmRestarts(
optimizer, T_0=50, T_mult=2, eta_min=1e-6, warmup_steps=10
)
# In training loop
for epoch in range(num_epochs):
for data, target in dataloader:
# Training step
pass
scheduler.step()Memory Optimization
1. Gradient Checkpointing
python
import torch.utils.checkpoint as checkpoint
class CheckpointedModel(nn.Module):
    """Sequential wrapper that runs every layer under gradient checkpointing.

    Activations are not stored during the forward pass; they are recomputed
    during backward, trading compute for a large reduction in peak memory.
    """

    def __init__(self, layers):
        super().__init__()
        self.layers = nn.ModuleList(layers)

    def forward(self, x):
        for layer in self.layers:
            # use_reentrant=False selects the recommended non-reentrant
            # implementation; omitting it is deprecated and warns.
            x = checkpoint.checkpoint(layer, x, use_reentrant=False)
        return x
# Or use a helper function
def checkpointed_forward(module, input):
    """Run ``module(input)`` under non-reentrant gradient checkpointing."""
    # use_reentrant=False is the recommended variant; the default is deprecated.
    return checkpoint.checkpoint(module, input, use_reentrant=False)
# In large models
class LargeTransformer(nn.Module):
def __init__(self, config):
super(LargeTransformer, self).__init__()
self.layers = nn.ModuleList([
TransformerLayer(config) for _ in range(config.num_layers)
])
def forward(self, x):
for layer in self.layers:
# Use checkpointing to save memory
x = checkpoint.checkpoint(layer, x)
return x2. Memory-Mapped Datasets
python
import mmap
import numpy as np
from torch.utils.data import Dataset
class MemoryMappedDataset(Dataset):
def __init__(self, data_file, index_file):
# Use memory mapping to read large files
self.data_file = open(data_file, 'rb')
self.data_mmap = mmap.mmap(self.data_file.fileno(), 0, access=mmap.ACCESS_READ)
# Load indices
self.indices = np.load(index_file)
def __len__(self):
return len(self.indices)
def __getitem__(self, idx):
offset, size = self.indices[idx]
# Read data from memory map
self.data_mmap.seek(offset)
data_bytes = self.data_mmap.read(size)
# Parse data
data = self._parse_data(data_bytes)
return data
def _parse_data(self, data_bytes):
# Implement specific data parsing logic
pass
def __del__(self):
if hasattr(self, 'data_mmap'):
self.data_mmap.close()
if hasattr(self, 'data_file'):
self.data_file.close()3. Dynamic Batch Size
python
class DynamicBatchSampler:
    """Yields batches of dataset indices capped by token count and batch size.

    Samples are taken in order; a batch is emitted as soon as adding the next
    sample would exceed ``max_tokens`` total length or ``max_batch_size`` items.
    """

    def __init__(self, dataset, max_tokens=4096, max_batch_size=32):
        self.dataset = dataset
        self.max_tokens = max_tokens
        self.max_batch_size = max_batch_size

    def __iter__(self):
        pending, token_count = [], 0
        for index in range(len(self.dataset)):
            length = len(self.dataset[index])
            # Flush the current batch before it would exceed either limit.
            batch_full = (token_count + length > self.max_tokens
                          or len(pending) >= self.max_batch_size)
            if batch_full and pending:
                yield pending
                pending, token_count = [], 0
            pending.append(index)
            token_count += length
        if pending:
            # Flush the final partial batch.
            yield pending
# Use dynamic batch sampler
sampler = DynamicBatchSampler(dataset, max_tokens=4096)
dataloader = DataLoader(dataset, batch_sampler=sampler)Compute Optimization
1. Model Compilation (PyTorch 2.0+)
python
import torch._dynamo as dynamo
# Compile model for better performance
# NOTE(review): the original applied @torch.compile directly to the class;
# torch.compile expects a function or an nn.Module *instance*, so the class
# is defined normally and compiled after instantiation instead.
class OptimizedModel(nn.Module):
    """Small MLP; compile with ``torch.compile(OptimizedModel(config))``."""

    def __init__(self, config):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(config.input_size, config.hidden_size),
            nn.ReLU(),
            nn.Linear(config.hidden_size, config.output_size),
        )

    def forward(self, x):
        return self.layers(x)
# Or compile existing model
model = MyModel()
compiled_model = torch.compile(model, mode='max-autotune')
# Different compilation modes
# 'default': Balance compile time and runtime performance
# 'reduce-overhead': Reduce Python overhead
# 'max-autotune': Maximize performance optimization2. Operator Fusion
python
# Manually fuse common operations
class FusedLinearReLU(nn.Module):
    """Linear layer immediately followed by ReLU, expressed as one module."""

    def __init__(self, in_features, out_features):
        super().__init__()
        self.linear = nn.Linear(in_features, out_features)

    def forward(self, x):
        # Apply the affine transform, then clamp negatives to zero.
        return torch.relu(self.linear(x))
# Use TorchScript for automatic fusion
class ModelForFusion(nn.Module):
    """Conv -> BatchNorm -> ReLU pipeline; a classic fusion candidate for the JIT."""

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 64, 3, padding=1)
        self.bn = nn.BatchNorm2d(64)
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.relu(self.bn(self.conv(x)))
# Script model to enable fusion
model = ModelForFusion()
scripted_model = torch.jit.script(model)
# Freeze model for inference optimization
scripted_model.eval()
frozen_model = torch.jit.freeze(scripted_model)3. Parallel Computing Optimization
python
# Data parallel
model = nn.DataParallel(model)

# Distributed data parallel
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

def setup_distributed():
    """Join the NCCL process group and pin this process to its local GPU."""
    dist.init_process_group(backend='nccl')
    torch.cuda.set_device(int(os.environ['LOCAL_RANK']))

def cleanup_distributed():
    """Tear down the process group at the end of training."""
    dist.destroy_process_group()

# Use DDP
model = DDP(model, device_ids=[local_rank])
# Model parallel (for large models)
class ModelParallelNet(nn.Module):
def __init__(self):
super(ModelParallelNet, self).__init__()
self.layer1 = nn.Linear(1000, 1000).to('cuda:0')
self.layer2 = nn.Linear(1000, 1000).to('cuda:1')
self.layer3 = nn.Linear(1000, 10).to('cuda:1')
def forward(self, x):
x = x.to('cuda:0')
x = self.layer1(x)
x = x.to('cuda:1')
x = self.layer2(x)
x = self.layer3(x)
return xModel Compression
1. Knowledge Distillation
python
class KnowledgeDistillation:
    """Distills a frozen teacher model into a smaller student model.

    The student trains on a blend of (a) the KL divergence between the
    temperature-softened teacher and student distributions and (b) ordinary
    cross-entropy against the ground-truth labels.

    Args:
        teacher_model: pretrained model providing soft targets (frozen here).
        student_model: smaller model being trained.
        temperature: softmax temperature; higher values soften the targets.
        alpha: weight of the soft loss; the hard loss gets ``1 - alpha``.
    """

    def __init__(self, teacher_model, student_model, temperature=4.0, alpha=0.7):
        self.teacher_model = teacher_model
        self.student_model = student_model
        self.temperature = temperature
        self.alpha = alpha
        # Freeze teacher model: it only provides targets and must not train.
        for param in self.teacher_model.parameters():
            param.requires_grad = False

    def distillation_loss(self, student_logits, teacher_logits, true_labels):
        """Return ``alpha * soft KL loss + (1 - alpha) * hard CE loss``."""
        # The T^2 factor keeps soft-loss gradient magnitudes comparable
        # across temperatures (Hinton et al., 2015).
        soft_loss = nn.KLDivLoss(reduction='batchmean')(
            F.log_softmax(student_logits / self.temperature, dim=1),
            F.softmax(teacher_logits / self.temperature, dim=1)
        ) * (self.temperature ** 2)
        hard_loss = F.cross_entropy(student_logits, true_labels)
        return self.alpha * soft_loss + (1 - self.alpha) * hard_loss

    def train_step(self, data, target, optimizer):
        """Run one distillation step and return the scalar loss value."""
        self.teacher_model.eval()
        self.student_model.train()
        optimizer.zero_grad()
        # Teacher predictions are targets only; no autograd graph needed.
        with torch.no_grad():
            teacher_logits = self.teacher_model(data)
        student_logits = self.student_model(data)
        loss = self.distillation_loss(student_logits, teacher_logits, target)
        loss.backward()
        optimizer.step()
        return loss.item()
# Usage example
teacher = LargeModel() # Large teacher model
student = SmallModel() # Small student model
distiller = KnowledgeDistillation(teacher, student)
optimizer = torch.optim.Adam(student.parameters(), lr=0.001)
for data, target in dataloader:
loss = distiller.train_step(data, target, optimizer)2. Model Pruning
python
import torch.nn.utils.prune as prune
class ModelPruner:
    """Applies torch.nn.utils.prune strategies to a model and reports sparsity."""

    def __init__(self, model):
        self.model = model

    def structured_pruning(self, pruning_ratio=0.2):
        """L2 structured pruning: remove whole output channels / neurons."""
        for name, module in self.model.named_modules():
            if isinstance(module, nn.Conv2d):
                # Prune entire output channels (dim=0) by L2 norm.
                prune.ln_structured(
                    module, name='weight', amount=pruning_ratio,
                    n=2, dim=0
                )
            elif isinstance(module, nn.Linear):
                # Prune entire output neurons.
                prune.ln_structured(
                    module, name='weight', amount=pruning_ratio,
                    n=2, dim=0
                )

    def unstructured_pruning(self, pruning_ratio=0.2):
        """Global L1 unstructured pruning across all conv/linear weights."""
        parameters_to_prune = [
            (module, 'weight')
            for module in self.model.modules()
            if isinstance(module, (nn.Conv2d, nn.Linear))
        ]
        prune.global_unstructured(
            parameters_to_prune,
            pruning_method=prune.L1Unstructured,
            amount=pruning_ratio,
        )

    def gradual_pruning(self, initial_sparsity=0.0, final_sparsity=0.8,
                        pruning_steps=100, pruning_frequency=10):
        """Raise sparsity linearly from initial to final over pruning_steps.

        NOTE(review): each call prunes a fraction of the *remaining* weights,
        so repeated application compounds — the schedule is approximate.
        """
        current_step = 0
        for epoch in range(pruning_steps):
            if epoch % pruning_frequency == 0:
                current_sparsity = initial_sparsity + (
                    final_sparsity - initial_sparsity
                ) * (current_step / pruning_steps)
                self.unstructured_pruning(current_sparsity)
            current_step += 1

    def remove_pruning(self):
        """Make pruning permanent by folding masks into the weights."""
        for name, module in self.model.named_modules():
            if isinstance(module, (nn.Conv2d, nn.Linear)):
                try:
                    prune.remove(module, 'weight')
                except ValueError:
                    # prune.remove raises ValueError for never-pruned modules;
                    # the former bare `except:` hid every other error too.
                    pass

    def calculate_sparsity(self):
        """Return the fraction of parameters that are exactly zero.

        Only meaningful after remove_pruning(): while masks are still attached,
        model.parameters() yields the unmasked ``weight_orig`` tensors.
        """
        total_params = 0
        zero_params = 0
        for param in self.model.parameters():
            total_params += param.numel()
            zero_params += (param == 0).sum().item()
        return zero_params / total_params
# Usage example
pruner = ModelPruner(model)
# Apply pruning
pruner.unstructured_pruning(pruning_ratio=0.3)
# Calculate sparsity
sparsity = pruner.calculate_sparsity()
print(f"Model sparsity: {sparsity:.2%}")
# Fine-tune pruned model
for epoch in range(fine_tune_epochs):
# Training loop
pass
# Remove pruning reparameterization
pruner.remove_pruning()3. Quantization
python
import torch.quantization as quantization
class ModelQuantizer:
    """Wraps the three standard PyTorch quantization workflows for one model."""

    def __init__(self, model):
        self.model = model

    def post_training_quantization(self, calibration_loader):
        """Post-training static quantization: calibrate, then convert to int8."""
        # fbgemm is the x86 server backend.
        self.model.qconfig = quantization.get_default_qconfig('fbgemm')
        prepared = quantization.prepare(self.model, inplace=False)
        # Run calibration data through the inserted observers.
        prepared.eval()
        with torch.no_grad():
            for data, _ in calibration_loader:
                prepared(data)
        return quantization.convert(prepared, inplace=False)

    def quantization_aware_training(self, train_loader, num_epochs=5):
        """QAT: train with fake-quantized ops, then convert to a real int8 model."""
        self.model.qconfig = quantization.get_default_qat_qconfig('fbgemm')
        prepared = quantization.prepare_qat(self.model, inplace=False)

        optimizer = torch.optim.Adam(prepared.parameters(), lr=0.0001)
        criterion = nn.CrossEntropyLoss()
        for epoch in range(num_epochs):
            prepared.train()
            for data, target in train_loader:
                optimizer.zero_grad()
                loss = criterion(prepared(data), target)
                loss.backward()
                optimizer.step()

        prepared.eval()
        return quantization.convert(prepared, inplace=False)

    def dynamic_quantization(self):
        """Dynamically quantize Linear layers (int8 weights, on-the-fly activations)."""
        return quantization.quantize_dynamic(
            self.model, {nn.Linear}, dtype=torch.qint8
        )
# Usage example
quantizer = ModelQuantizer(model)

# Dynamic quantization (simplest)
quantized_model = quantizer.dynamic_quantization()

# Post-training quantization
# quantized_model = quantizer.post_training_quantization(calibration_loader)

# Quantization-aware training
# quantized_model = quantizer.quantization_aware_training(train_loader)

# Compare model sizes
def get_model_size(model):
    """Return the on-disk size, in bytes, of the model's serialized state_dict.

    Serializes to a uniquely named temporary file (the original wrote a fixed
    "temp.p" into the CWD, racing concurrent callers) and always cleans up,
    even if torch.save fails.
    """
    fd, path = tempfile.mkstemp(suffix='.pt')
    os.close(fd)
    try:
        torch.save(model.state_dict(), path)
        return os.path.getsize(path)
    finally:
        os.remove(path)
original_size = get_model_size(model)
quantized_size = get_model_size(quantized_model)
print(f"Original model size: {original_size / 1024 / 1024:.2f} MB")
print(f"Quantized model size: {quantized_size / 1024 / 1024:.2f} MB")
print(f"Compression ratio: {original_size / quantized_size:.2f}x")Inference Optimization
1. TorchScript Optimization
python
# Optimize TorchScript model
def optimize_torchscript_model(model, example_input):
    """Trace, freeze, and optimize a model for inference.

    Returns an inference-optimized ScriptModule.  The original froze *after*
    ``optimize_for_inference``; that call already freezes its input, and
    freezing an already-frozen module can fail, so the correct order is
    trace -> freeze -> optimize_for_inference.
    """
    model.eval()
    # Trace model
    traced_model = torch.jit.trace(model, example_input)
    # Freeze model (inlines parameters/attributes as constants)
    frozen_model = torch.jit.freeze(traced_model)
    # Apply inference-specific optimizations (e.g. conv/bn folding)
    return torch.jit.optimize_for_inference(frozen_model)
# Usage example
example_input = torch.randn(1, 3, 224, 224)
optimized_model = optimize_torchscript_model(model, example_input)
# Save optimized model
optimized_model.save('optimized_model.pt')2. Batch Inference Optimization
python
class BatchInferenceOptimizer:
def __init__(self, model, max_batch_size=32, timeout=0.1):
self.model = model
self.max_batch_size = max_batch_size
self.timeout = timeout
self.batch_queue = []
async def predict(self, input_data):
"""Asynchronous batch inference"""
import asyncio
from concurrent.futures import Future
future = Future()
self.batch_queue.append((input_data, future))
# Check if batch needs processing
if len(self.batch_queue) >= self.max_batch_size:
await self._process_batch()
else:
# Set timeout handling
asyncio.create_task(self._timeout_handler())
return await asyncio.wrap_future(future)
async def _process_batch(self):
"""Process batch data"""
if not self.batch_queue:
return
# Collect batch data
batch_data = []
futures = []
for data, future in self.batch_queue:
batch_data.append(data)
futures.append(future)
self.batch_queue.clear()
# Batch inference
try:
batch_input = torch.stack(batch_data)
with torch.no_grad():
batch_output = self.model(batch_input)
# Distribute results
for i, future in enumerate(futures):
future.set_result(batch_output[i])
except Exception as e:
for future in futures:
future.set_exception(e)
async def _timeout_handler(self):
"""Timeout handling"""
import asyncio
await asyncio.sleep(self.timeout)
if self.batch_queue:
await self._process_batch()Performance Monitoring
1. Profiler
python
import torch.profiler as profiler
def profile_model(model, input_data, num_steps=100):
    """Profile model performance.

    Records CPU and CUDA activity over ``num_steps`` no-grad forward passes,
    streams a TensorBoard trace to ./log/profiler, prints a summary table
    sorted by total CUDA time, and returns the profiler object.
    """
    model.eval()
    with profiler.profile(
        activities=[
            profiler.ProfilerActivity.CPU,
            profiler.ProfilerActivity.CUDA,
        ],
        # Schedule: skip 1 step, warm up 1, record 3 — cycle repeated twice.
        schedule=profiler.schedule(wait=1, warmup=1, active=3, repeat=2),
        on_trace_ready=profiler.tensorboard_trace_handler('./log/profiler'),
        record_shapes=True,
        profile_memory=True,
        with_stack=True
    ) as prof:
        for step in range(num_steps):
            with torch.no_grad():
                output = model(input_data)
            # prof.step() advances the profiling schedule each iteration.
            prof.step()
    # Print performance report
    print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
    return prof
# Usage example
input_data = torch.randn(32, 3, 224, 224).cuda()
prof = profile_model(model.cuda(), input_data)2. Memory Analysis
python
def analyze_memory_usage(model, input_data):
    """Analyze memory usage.

    Runs one no-grad forward pass and reports CUDA memory allocated before,
    at peak, and after; returns the three readings (in bytes) as a dict.
    Requires a CUDA device — every counter comes from torch.cuda.
    """
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    # Record initial memory
    initial_memory = torch.cuda.memory_allocated()
    # Forward pass
    model.eval()
    with torch.no_grad():
        output = model(input_data)
    # Record peak memory (includes transient activations freed after forward)
    peak_memory = torch.cuda.max_memory_allocated()
    final_memory = torch.cuda.memory_allocated()
    print(f"Initial memory: {initial_memory / 1024**2:.2f} MB")
    print(f"Peak memory: {peak_memory / 1024**2:.2f} MB")
    print(f"Final memory: {final_memory / 1024**2:.2f} MB")
    print(f"Memory growth: {(final_memory - initial_memory) / 1024**2:.2f} MB")
    return {
        'initial': initial_memory,
        'peak': peak_memory,
        'final': final_memory
    }
# Usage example
memory_stats = analyze_memory_usage(model.cuda(), input_data.cuda())

Summary
PyTorch model optimization covers all aspects of training and inference:
- Training Optimization: Mixed precision, gradient accumulation, learning rate scheduling
- Memory Optimization: Gradient checkpointing, memory mapping, dynamic batching
- Compute Optimization: Model compilation, operator fusion, parallel computing
- Model Compression: Knowledge distillation, pruning, quantization
- Inference Optimization: TorchScript, batch inference
- Performance Monitoring: Profiling, memory analysis
Mastering these optimization techniques will significantly improve the efficiency and performance of your deep learning projects!