#PyTorch Model Optimization
#Optimization Overview
Model optimization is a critical part of any deep learning project: it determines training efficiency, inference speed, memory usage, and model size. This chapter introduces the main optimization techniques available in PyTorch.
#Training Optimization
#1. Mixed Precision Training
import copy
import os
import tempfile

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.cuda.amp import GradScaler, autocast
from torch.utils.data import DataLoader
class MixedPrecisionTrainer:
    """Run single optimisation steps under automatic mixed precision.

    Bundles a model, an optimizer and a loss function with a ``GradScaler`` so
    that each ``train_step`` call performs one scaled forward/backward/update
    cycle with gradient-norm clipping.
    """

    def __init__(self, model, optimizer, criterion):
        self.model, self.optimizer, self.criterion = model, optimizer, criterion
        self.scaler = GradScaler()

    def train_step(self, data, target):
        """Perform one mixed-precision update and return the loss as a float."""
        scaler, optimizer = self.scaler, self.optimizer
        optimizer.zero_grad()

        # Forward pass and loss are both evaluated inside the autocast region.
        with autocast():
            loss = self.criterion(self.model(data), target)

        # Backpropagate through the scaled loss.
        scaler.scale(loss).backward()

        # Gradients are unscaled first, then clipped to a maximum norm of 1.0.
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)

        # Apply the parameter update and refresh the scale factor.
        scaler.step(optimizer)
        scaler.update()
        return loss.item()
# Usage example: a small MLP classifier trained with AdamW.
model = nn.Sequential(nn.Linear(784, 256), nn.ReLU(), nn.Linear(256, 10))
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=0.001)
trainer = MixedPrecisionTrainer(model, optimizer, criterion)

# Training loop (`dataloader` is not defined in this snippet).
for data, target in dataloader:
    loss = trainer.train_step(data, target)

#2. Gradient Accumulation
class GradientAccumulator:
    """Accumulate gradients over several mini-batches before each update.

    Every call to ``accumulate_step`` back-propagates one mini-batch with its
    loss divided by ``accumulation_steps``. After ``accumulation_steps`` calls
    the gradients are clipped (max norm 1.0), the optimizer steps and the
    gradients are cleared. Call ``flush`` at the end of an epoch so that a
    trailing, incomplete window is not silently dropped.

    Args:
        model: Module being trained.
        optimizer: Optimizer that owns ``model``'s parameters.
        criterion: Callable ``criterion(output, target)`` returning the loss.
        accumulation_steps: Number of mini-batches per optimizer update.

    Raises:
        ValueError: If ``accumulation_steps`` is smaller than 1.
    """

    def __init__(self, model, optimizer, criterion, accumulation_steps=4):
        if accumulation_steps < 1:
            raise ValueError("accumulation_steps must be >= 1")
        self.model = model
        self.optimizer = optimizer
        self.criterion = criterion
        self.accumulation_steps = accumulation_steps
        self.step_count = 0

    def accumulate_step(self, data, target):
        """Back-propagate one mini-batch and update once the window is full.

        Returns:
            float: The unscaled loss of this mini-batch.
        """
        output = self.model(data)
        loss = self.criterion(output, target)

        # Read the loss before scaling so the reported value carries no
        # divide-then-multiply rounding error.
        loss_value = loss.item()

        # Scale so the accumulated gradient is the mean over the window.
        (loss / self.accumulation_steps).backward()
        self.step_count += 1

        if self.step_count % self.accumulation_steps == 0:
            self._apply_update()
        return loss_value

    def flush(self):
        """Apply gradients left over from an incomplete accumulation window.

        The pending gradients were scaled by ``1 / accumulation_steps``; they
        are rescaled to the mean over the batches actually seen before the
        update. ``step_count`` is reset to 0 so the next window starts fresh.

        Returns:
            bool: True if an update was applied, False if nothing was pending.
        """
        pending = self.step_count % self.accumulation_steps
        if pending == 0:
            return False

        rescale = self.accumulation_steps / pending
        for param in self.model.parameters():
            if param.grad is not None:
                param.grad.mul_(rescale)

        self._apply_update()
        self.step_count = 0
        return True

    def _apply_update(self):
        """Clip gradients, step the optimizer and clear the gradients."""
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
        self.optimizer.step()
        self.optimizer.zero_grad()
# Usage example: accumulate over 8 mini-batches (`dataloader` is not defined
# in this snippet).
accumulator = GradientAccumulator(
    model,
    optimizer,
    criterion,
    accumulation_steps=8,
)
for data, target in dataloader:
    loss = accumulator.accumulate_step(data, target)

#3. Learning Rate Scheduling Optimization
import torch.optim.lr_scheduler as lr_scheduler
import math
class CosineAnnealingWarmRestarts(lr_scheduler._LRScheduler):
    """Cosine annealing with warm restarts, preceded by a linear warmup.

    For the first ``warmup_steps`` epochs the learning rate rises linearly to
    each group's base rate. Afterwards it follows a cosine curve from the base
    rate down to ``eta_min`` and restarts at the end of every cycle. The first
    cycle lasts ``T_0`` epochs and each following cycle is ``T_mult`` times
    longer than the previous one.

    Args:
        optimizer: Wrapped optimizer.
        T_0: Length of the first cosine cycle, in epochs.
        T_mult: Factor by which the cycle length grows after each restart.
        eta_min: Minimum learning rate reached at the end of a cycle.
        warmup_steps: Number of linear warmup epochs before the first cycle.
        last_epoch: Index of the last epoch, ``-1`` to start from scratch.

    Raises:
        ValueError: If ``T_0 <= 0``, ``T_mult < 1`` or ``warmup_steps < 0``.
    """

    def __init__(self, optimizer, T_0, T_mult=1, eta_min=0, warmup_steps=0, last_epoch=-1):
        if T_0 <= 0:
            raise ValueError(f"T_0 must be positive, got {T_0}")
        if T_mult < 1:
            raise ValueError(f"T_mult must be >= 1, got {T_mult}")
        if warmup_steps < 0:
            raise ValueError(f"warmup_steps must be >= 0, got {warmup_steps}")
        # Attributes must exist before the base constructor runs, because it
        # performs an initial step that calls get_lr().
        self.T_0 = T_0
        self.T_mult = T_mult
        self.eta_min = eta_min
        self.warmup_steps = warmup_steps
        super(CosineAnnealingWarmRestarts, self).__init__(optimizer, last_epoch)

    def _cycle_position(self, epoch):
        """Return ``(T_cur, T_i)``: position inside the cycle and its length."""
        if self.T_mult == 1:
            return epoch % self.T_0, self.T_0
        period = self.T_0
        # Walk through the completed cycles; each one is T_mult times longer.
        while epoch >= period:
            epoch -= period
            period *= self.T_mult
        return epoch, period

    def get_lr(self):
        """Compute the learning rate of every parameter group for this epoch."""
        if self.last_epoch < self.warmup_steps:
            # Warmup phase: linear ramp up to the base learning rate.
            return [base_lr * (self.last_epoch + 1) / self.warmup_steps
                    for base_lr in self.base_lrs]

        # Cosine annealing phase, measured from the end of warmup.
        t_cur, t_i = self._cycle_position(self.last_epoch - self.warmup_steps)
        cosine = (1 + math.cos(math.pi * t_cur / t_i)) / 2
        return [self.eta_min + (base_lr - self.eta_min) * cosine
                for base_lr in self.base_lrs]
# Usage example
scheduler = CosineAnnealingWarmRestarts(
    optimizer,
    T_0=50,
    T_mult=2,
    eta_min=1e-6,
    warmup_steps=10,
)

# In training loop (`num_epochs` and `dataloader` are not defined in this
# snippet); the scheduler is stepped once per epoch.
for epoch in range(num_epochs):
    for data, target in dataloader:
        pass  # Training step
    scheduler.step()

#Memory Optimization
#1. Gradient Checkpointing
import torch.utils.checkpoint as checkpoint
class CheckpointedModel(nn.Module):
    """Apply a sequence of layers, checkpointing each one to save memory.

    Activations inside every layer are recomputed during the backward pass
    instead of being stored.

    Args:
        layers: Iterable of modules applied in order.
    """

    def __init__(self, layers):
        super(CheckpointedModel, self).__init__()
        self.layers = nn.ModuleList(layers)

    def forward(self, x):
        """Run ``x`` through every layer under gradient checkpointing."""
        for layer in self.layers:
            # use_reentrant=False: the non-reentrant implementation still
            # produces parameter gradients when the input itself does not
            # require grad, and passing the flag explicitly avoids the
            # deprecation warning for the implicit default.
            x = checkpoint.checkpoint(layer, x, use_reentrant=False)
        return x
# Or use a helper function
def checkpointed_forward(module, input):
    """Call ``module(input)`` under gradient checkpointing.

    Args:
        module: Callable module whose activations should be recomputed in the
            backward pass.
        input: Tensor passed to ``module``.

    Returns:
        The output of ``module(input)``.
    """
    # Non-reentrant mode keeps parameter gradients flowing even when `input`
    # does not require grad.
    return checkpoint.checkpoint(module, input, use_reentrant=False)
# In large models
class LargeTransformer(nn.Module):
    """Stack of ``TransformerLayer`` blocks run with gradient checkpointing.

    ``TransformerLayer`` is not defined in this snippet and must be provided
    by the surrounding project.

    Args:
        config: Object exposing ``num_layers``; it is also forwarded to every
            ``TransformerLayer``.
    """

    def __init__(self, config):
        super(LargeTransformer, self).__init__()
        self.layers = nn.ModuleList(
            [TransformerLayer(config) for _ in range(config.num_layers)]
        )

    def forward(self, x):
        """Run ``x`` through every layer, recomputing activations on backward."""
        for layer in self.layers:
            # Non-reentrant checkpointing keeps parameter gradients available
            # even if `x` does not require grad.
            x = checkpoint.checkpoint(layer, x, use_reentrant=False)
        return x

#2. Memory-Mapped Datasets
import mmap
import numpy as np
from torch.utils.data import Dataset
class MemoryMappedDataset(Dataset):
    """Dataset that reads variable-length records from a memory-mapped file.

    Args:
        data_file: Path of the binary file holding all records back to back.
        index_file: Path of a NumPy file loadable with ``np.load``; row ``i``
            holds ``(offset, size)`` of record ``i`` in bytes.

    Subclasses override ``_parse_data`` to turn the raw bytes into a sample.
    """

    def __init__(self, data_file, index_file):
        # Use memory mapping to read large files
        self.data_file = open(data_file, 'rb')
        try:
            self.data_mmap = mmap.mmap(
                self.data_file.fileno(), 0, access=mmap.ACCESS_READ
            )
            # Load indices
            self.indices = np.load(index_file)
        except Exception:
            # Do not leak the file handle / mapping if setup fails half-way.
            self.close()
            raise

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx):
        """Return the parsed record at position ``idx``.

        Raises:
            ValueError: If the indexed byte range lies outside the data file.
        """
        offset, size = (int(value) for value in self.indices[idx])
        end = offset + size
        if offset < 0 or size < 0 or end > len(self.data_mmap):
            raise ValueError(
                f"record {idx} spans bytes [{offset}, {end}) outside the data file"
            )
        # Slicing does not move the mapping's file position, so reads do not
        # depend on the order in which items are requested.
        data_bytes = self.data_mmap[offset:end]
        return self._parse_data(data_bytes)

    def _parse_data(self, data_bytes):
        """Hook for subclasses: convert raw bytes into a sample.

        The base implementation returns ``None``.
        """
        # Implement specific data parsing logic
        pass

    def close(self):
        """Release the memory map and the file handle; safe to call twice."""
        mapped = getattr(self, 'data_mmap', None)
        if mapped is not None:
            mapped.close()
        handle = getattr(self, 'data_file', None)
        if handle is not None:
            handle.close()

    def __del__(self):
        self.close()

#3. Dynamic Batch Size
class DynamicBatchSampler:
    """Yield lists of dataset indices bounded by a token budget.

    Indices are visited in order. A batch is emitted as soon as adding the
    next sample would push the summed sample lengths above ``max_tokens`` or
    the batch already holds ``max_batch_size`` indices. A single sample longer
    than ``max_tokens`` still forms a batch of its own.
    """

    def __init__(self, dataset, max_tokens=4096, max_batch_size=32):
        self.dataset = dataset
        self.max_tokens = max_tokens
        self.max_batch_size = max_batch_size

    def __iter__(self):
        pending, token_total = [], 0
        for position in range(len(self.dataset)):
            length = len(self.dataset[position])
            over_budget = token_total + length > self.max_tokens
            at_capacity = len(pending) >= self.max_batch_size
            # Close the current batch before it would exceed either limit.
            if pending and (over_budget or at_capacity):
                yield pending
                pending, token_total = [], 0
            pending.append(position)
            token_total += length
        # Emit whatever is left after the last sample.
        if pending:
            yield pending
# Use dynamic batch sampler: the sampler's index lists are handed to DataLoader
# through `batch_sampler` (`dataset` is not defined in this snippet).
sampler = DynamicBatchSampler(dataset, max_tokens=4096)
dataloader = DataLoader(dataset, batch_sampler=sampler)
#Compute Optimization
#1. Model Compilation (PyTorch 2.0+)
import torch._dynamo as dynamo
# Compile model for better performance
@torch.compile
class OptimizedModel(nn.Module):
def __init__(self, config):
super(OptimizedModel, self).__init__()
self.layers = nn.Sequential(
nn.Linear(config.input_size, config.hidden_size),
nn.ReLU(),
nn.Linear(config.hidden_size, config.output_size)
)
def forward(self, x):
return self.layers(x)
# Or compile an existing model instance (`MyModel` is not defined in this snippet)
model = MyModel()
compiled_model = torch.compile(model, mode='max-autotune')
# Different compilation modes
# 'default': Balance compile time and runtime performance
# 'reduce-overhead': Reduce Python overhead
# 'max-autotune': Maximize performance optimization
#2. Operator Fusion
# Manually combine common operations in one module
class FusedLinearReLU(nn.Module):
    """Linear layer followed by a ReLU, exposed as a single module."""

    def __init__(self, in_features, out_features):
        super(FusedLinearReLU, self).__init__()
        self.linear = nn.Linear(in_features, out_features)

    def forward(self, x):
        # Linear projection first, then the ReLU activation on its output.
        projected = self.linear(x)
        return torch.relu(projected)
# Use TorchScript for automatic fusion
class ModelForFusion(nn.Module):
    """Conv2d(3 -> 64, 3x3, padding 1) followed by BatchNorm2d and ReLU."""

    def __init__(self):
        super(ModelForFusion, self).__init__()
        self.conv = nn.Conv2d(3, 64, 3, padding=1)
        self.bn = nn.BatchNorm2d(64)
        self.relu = nn.ReLU()

    def forward(self, x):
        # Convolution -> batch norm -> activation, applied in sequence.
        return self.relu(self.bn(self.conv(x)))
# Script model to enable fusion
model = ModelForFusion()
scripted_model = torch.jit.script(model)

# Freeze model for inference optimization; the scripted module is switched to
# eval mode first (eval() returns the module itself).
frozen_model = torch.jit.freeze(scripted_model.eval())

#3. Parallel Computing Optimization
# Data parallel: wrap the existing `model` in nn.DataParallel
model = nn.DataParallel(model)
# Distributed data parallel
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
def setup_distributed():
    """Bind this process to its GPU and initialise the NCCL process group.

    The GPU index is read from the ``LOCAL_RANK`` environment variable.

    Returns:
        int: The local rank read from the environment.

    Raises:
        KeyError: If ``LOCAL_RANK`` is not set.
    """
    # Read the rank first so a missing variable fails before any process
    # group has been created.
    local_rank = int(os.environ['LOCAL_RANK'])
    # Select the device before initialising the process group.
    torch.cuda.set_device(local_rank)
    dist.init_process_group(backend='nccl')
    return local_rank
def cleanup_distributed():
    """Destroy the default process group if one has been initialised.

    Calling this without a prior ``init_process_group`` (for example after a
    failed setup, or in a single-process run) is a no-op.
    """
    if dist.is_available() and dist.is_initialized():
        dist.destroy_process_group()
# Use DDP: wrap the model for the GPU given by this process's LOCAL_RANK
local_rank = int(os.environ['LOCAL_RANK'])
model = DDP(model, device_ids=[local_rank])
# Model parallel (for large models)
class ModelParallelNet(nn.Module):
    """Three linear layers split across two devices.

    ``layer1`` lives on the first device; ``layer2`` and ``layer3`` live on
    the second. The activation is moved between devices inside ``forward``.

    Args:
        devices: Pair ``(first, second)`` of device specifiers. Defaults to
            ``('cuda:0', 'cuda:1')``.
    """

    def __init__(self, devices=('cuda:0', 'cuda:1')):
        super(ModelParallelNet, self).__init__()
        first, second = devices
        self.devices = (torch.device(first), torch.device(second))
        self.layer1 = nn.Linear(1000, 1000).to(self.devices[0])
        self.layer2 = nn.Linear(1000, 1000).to(self.devices[1])
        self.layer3 = nn.Linear(1000, 10).to(self.devices[1])

    def forward(self, x):
        """Return the ``(batch, 10)`` output, located on the second device."""
        x = self.layer1(x.to(self.devices[0]))
        # Hand the activation over to the device that holds layers 2 and 3.
        x = self.layer2(x.to(self.devices[1]))
        return self.layer3(x)

#Model Compression
#1. Knowledge Distillation
class KnowledgeDistillation:
    """Train a student model against a frozen teacher.

    The loss is ``alpha * soft + (1 - alpha) * hard`` where ``soft`` is the
    temperature-scaled KL divergence between student and teacher
    distributions and ``hard`` is the cross entropy against the true labels.
    """

    def __init__(self, teacher_model, student_model, temperature=4.0, alpha=0.7):
        self.teacher_model = teacher_model
        self.student_model = student_model
        self.temperature = temperature
        self.alpha = alpha
        # The teacher is never updated, so its parameters need no gradients.
        for weight in self.teacher_model.parameters():
            weight.requires_grad_(False)

    def distillation_loss(self, student_logits, teacher_logits, true_labels):
        """Return the weighted sum of the soft-label and hard-label losses."""
        temp = self.temperature
        student_log_probs = F.log_softmax(student_logits / temp, dim=1)
        teacher_probs = F.softmax(teacher_logits / temp, dim=1)
        # The soft-label term is multiplied by temperature squared.
        soft_loss = F.kl_div(
            student_log_probs, teacher_probs, reduction='batchmean'
        ) * (temp ** 2)
        hard_loss = F.cross_entropy(student_logits, true_labels)
        return self.alpha * soft_loss + (1 - self.alpha) * hard_loss

    def train_step(self, data, target, optimizer):
        """Run one student update on a batch and return the loss as a float."""
        self.teacher_model.eval()
        self.student_model.train()
        optimizer.zero_grad()

        # Teacher predictions are computed without building a graph.
        with torch.no_grad():
            teacher_logits = self.teacher_model(data)

        loss = self.distillation_loss(
            self.student_model(data), teacher_logits, target
        )
        loss.backward()
        optimizer.step()
        return loss.item()
# Usage example (`LargeModel`, `SmallModel` and `dataloader` are not defined
# in this snippet)
teacher = LargeModel()  # Large teacher model
student = SmallModel()  # Small student model
optimizer = torch.optim.Adam(student.parameters(), lr=0.001)
distiller = KnowledgeDistillation(teacher, student)

for data, target in dataloader:
    loss = distiller.train_step(data, target, optimizer)

#2. Model Pruning
import torch.nn.utils.prune as prune
class ModelPruner:
    """Prune the ``weight`` tensors of a model's Conv2d and Linear layers.

    Args:
        model: Module whose ``nn.Conv2d`` / ``nn.Linear`` submodules are pruned
            in place through ``torch.nn.utils.prune``.
    """

    # Layer types whose `weight` is subject to pruning.
    _PRUNABLE = (nn.Conv2d, nn.Linear)

    def __init__(self, model):
        self.model = model

    def _prunable_modules(self):
        """Return the Conv2d/Linear submodules in ``named_modules`` order."""
        return [m for m in self.model.modules() if isinstance(m, self._PRUNABLE)]

    def structured_pruning(self, pruning_ratio=0.2):
        """Prune whole output channels / neurons by their L2 norm (dim 0)."""
        for module in self._prunable_modules():
            prune.ln_structured(
                module, name='weight', amount=pruning_ratio,
                n=2, dim=0  # dim 0: output channels (Conv2d) / neurons (Linear)
            )

    def unstructured_pruning(self, pruning_ratio=0.2):
        """Globally prune the smallest-magnitude weights (L1).

        When pruning has already been applied, ``pruning_ratio`` is taken
        relative to the weights that are still unpruned. A model without any
        Conv2d/Linear layer is left untouched.
        """
        parameters_to_prune = [
            (module, 'weight') for module in self._prunable_modules()
        ]
        if not parameters_to_prune:
            return
        # Global unstructured pruning
        prune.global_unstructured(
            parameters_to_prune,
            pruning_method=prune.L1Unstructured,
            amount=pruning_ratio,
        )

    def _masked_fraction(self):
        """Return the fraction of prunable weights currently masked out."""
        total = 0
        masked = 0
        for module in self._prunable_modules():
            total += module.weight.numel()
            mask = getattr(module, 'weight_mask', None)
            if mask is not None:
                masked += int((mask == 0).sum().item())
        return masked / total if total else 0.0

    def gradual_pruning(self, initial_sparsity=0.0, final_sparsity=0.8,
                        pruning_steps=100, pruning_frequency=10, train_fn=None):
        """Raise weight sparsity linearly from initial to final sparsity.

        Pruning happens every ``pruning_frequency`` epochs; the last pruning
        event reaches ``final_sparsity``.

        Args:
            initial_sparsity: Target sparsity at the first pruning event.
            final_sparsity: Target sparsity at the last pruning event.
            pruning_steps: Total number of epochs to iterate over.
            pruning_frequency: Number of epochs between pruning events.
            train_fn: Optional callable ``train_fn(epoch)`` invoked once per
                epoch, after any pruning for that epoch.

        Raises:
            ValueError: If ``pruning_frequency`` is smaller than 1.
        """
        if pruning_frequency < 1:
            raise ValueError("pruning_frequency must be >= 1")

        last_event = len(range(0, pruning_steps, pruning_frequency)) - 1
        event = 0
        for epoch in range(pruning_steps):
            if epoch % pruning_frequency == 0:
                progress = event / last_event if last_event > 0 else 1.0
                target = initial_sparsity + (
                    final_sparsity - initial_sparsity
                ) * progress
                current = self._masked_fraction()
                if current < target and current < 1.0:
                    # Repeated pruning acts on the remaining weights only, so
                    # convert the absolute target into a relative amount.
                    amount = (target - current) / (1.0 - current)
                    self.unstructured_pruning(min(1.0, amount))
                event += 1
            if train_fn is not None:
                train_fn(epoch)

    def remove_pruning(self):
        """Make pruning permanent by removing the reparameterization."""
        for module in self._prunable_modules():
            try:
                prune.remove(module, 'weight')
            except ValueError:
                # Raised when this module's weight was never pruned.
                continue

    def calculate_sparsity(self):
        """Return the fraction of parameter entries that are exactly zero.

        While a pruning reparameterization is active the parameter is stored
        as ``<name>_orig`` and still holds the unpruned values, so the masked
        tensor ``<name>`` is inspected instead. Returns 0.0 for a model
        without parameters.
        """
        total_params = 0
        zero_params = 0
        seen = set()
        for module in self.model.modules():
            for name, param in module.named_parameters(recurse=False):
                if id(param) in seen:
                    continue  # shared parameter already counted
                seen.add(id(param))
                tensor = param
                if name.endswith('_orig'):
                    effective = getattr(module, name[:-len('_orig')], None)
                    if torch.is_tensor(effective):
                        tensor = effective
                total_params += tensor.numel()
                zero_params += int((tensor == 0).sum().item())
        return zero_params / total_params if total_params else 0.0
# Usage example
pruner = ModelPruner(model)

# Apply pruning, then report the resulting sparsity
pruner.unstructured_pruning(pruning_ratio=0.3)
sparsity = pruner.calculate_sparsity()
print(f"Model sparsity: {sparsity:.2%}")

# Fine-tune pruned model (`fine_tune_epochs` is not defined in this snippet)
for epoch in range(fine_tune_epochs):
    pass  # Training loop

# Remove pruning reparameterization
pruner.remove_pruning()

#3. Quantization
import torch.quantization as quantization
class ModelQuantizer:
    """Produce quantized copies of a float model.

    The wrapped model is never modified: every method works on a deep copy.

    NOTE(review): the static paths (post-training and QAT) only insert
    observers and convert modules; whether the converted model accepts float
    inputs depends on the model containing quant/dequant stubs, which this
    class does not add. Confirm against the models it is used with.

    Args:
        model: Float ``nn.Module`` to quantize.
    """

    def __init__(self, model):
        self.model = model

    def post_training_quantization(self, calibration_loader, backend='fbgemm'):
        """Post-training static quantization.

        Args:
            calibration_loader: Iterable of ``(data, _)`` pairs fed through the
                observed model to collect activation statistics.
            backend: Name passed to ``get_default_qconfig``.

        Returns:
            The converted (quantized) copy of the model.
        """
        # Work on a private copy so the caller's model gains no qconfig.
        model = copy.deepcopy(self.model)
        model.eval()
        model.qconfig = quantization.get_default_qconfig(backend)

        # Insert observers.
        model_prepared = quantization.prepare(model, inplace=True)

        # Calibrate
        with torch.no_grad():
            for data, _ in calibration_loader:
                model_prepared(data)

        # Convert to quantized model
        return quantization.convert(model_prepared, inplace=True)

    def quantization_aware_training(self, train_loader, num_epochs=5, backend='fbgemm'):
        """Quantization-aware training with Adam (lr 1e-4) and cross entropy.

        Args:
            train_loader: Iterable of ``(data, target)`` batches.
            num_epochs: Number of passes over ``train_loader``.
            backend: Name passed to ``get_default_qat_qconfig``.

        Returns:
            The converted (quantized) copy of the model.
        """
        model = copy.deepcopy(self.model)
        # prepare_qat only accepts modules in training mode.
        model.train()
        model.qconfig = quantization.get_default_qat_qconfig(backend)
        model_prepared = quantization.prepare_qat(model, inplace=True)

        # QAT training
        optimizer = torch.optim.Adam(model_prepared.parameters(), lr=0.0001)
        criterion = nn.CrossEntropyLoss()
        for _ in range(num_epochs):
            model_prepared.train()
            for data, target in train_loader:
                optimizer.zero_grad()
                output = model_prepared(data)
                loss = criterion(output, target)
                loss.backward()
                optimizer.step()

        # Convert to quantized model
        model_prepared.eval()
        return quantization.convert(model_prepared, inplace=True)

    def dynamic_quantization(self):
        """Dynamically quantize the ``nn.Linear`` layers to ``qint8``."""
        return quantization.quantize_dynamic(
            self.model, {nn.Linear}, dtype=torch.qint8
        )
# Usage example: wrap the existing `model` and pick one quantization strategy
quantizer = ModelQuantizer(model)
# Dynamic quantization (simplest)
quantized_model = quantizer.dynamic_quantization()
# Post-training quantization (needs a `calibration_loader`, not defined here)
# quantized_model = quantizer.post_training_quantization(calibration_loader)
# Quantization-aware training (needs a `train_loader`, not defined here)
# quantized_model = quantizer.quantization_aware_training(train_loader)
# Compare model sizes
def get_model_size(model):
    """Return the size in bytes of the model's serialized ``state_dict``.

    The state dict is written to a file inside a private temporary directory,
    so no file in the working directory is created or overwritten, and the
    scratch file is removed even if saving fails.

    Args:
        model: Module whose ``state_dict()`` is measured.

    Returns:
        int: Size of the serialized state dict in bytes.
    """
    with tempfile.TemporaryDirectory() as scratch_dir:
        scratch_path = os.path.join(scratch_dir, "temp.p")
        torch.save(model.state_dict(), scratch_path)
        return os.path.getsize(scratch_path)
# Measure both models and report sizes in MB (bytes / 1024 / 1024)
original_size = get_model_size(model)
quantized_size = get_model_size(quantized_model)
print(f"Original model size: {original_size / 1024 / 1024:.2f} MB")
print(f"Quantized model size: {quantized_size / 1024 / 1024:.2f} MB")
print(f"Compression ratio: {original_size / quantized_size:.2f}x")
#Inference Optimization
#1. TorchScript Optimization
# Optimize TorchScript model
def optimize_torchscript_model(model, example_input):
    """Trace, freeze and optimize a model for inference.

    Args:
        model: Module to optimize; it is switched to eval mode in place.
        example_input: Example input used by ``torch.jit.trace``.

    Returns:
        The frozen, inference-optimized TorchScript module.
    """
    model.eval()
    # Trace model
    traced_model = torch.jit.trace(model, example_input)
    # Freeze first: torch.jit.freeze expects a module that has not been frozen
    # yet, whereas optimize_for_inference accepts an already-frozen module.
    frozen_model = torch.jit.freeze(traced_model)
    # Optimize
    return torch.jit.optimize_for_inference(frozen_model)
# Usage example: trace with a single 3x224x224 example input
example_input = torch.randn(1, 3, 224, 224)
optimized_model = optimize_torchscript_model(model, example_input)
# Save optimized model
optimized_model.save('optimized_model.pt')
#2. Batch Inference Optimization
class BatchInferenceOptimizer:
    """Group concurrent ``predict`` calls into batched model invocations.

    Requests are queued until either ``max_batch_size`` requests are waiting
    or ``timeout`` seconds have passed since the first request of the current
    batch; the queued inputs are then stacked and run through the model once.

    Args:
        model: Callable taking a stacked batch tensor and returning an output
            indexable along its first dimension.
        max_batch_size: Queue length that triggers immediate processing.
        timeout: Seconds to wait before flushing an incomplete batch.
    """

    def __init__(self, model, max_batch_size=32, timeout=0.1):
        self.model = model
        self.max_batch_size = max_batch_size
        self.timeout = timeout
        self.batch_queue = []
        # At most one flush timer is alive at a time; keeping the reference
        # also prevents the task from being garbage-collected mid-flight.
        self._timeout_task = None

    async def predict(self, input_data):
        """Queue ``input_data`` and return its row of the batched output.

        Raises:
            Exception: Whatever stacking the inputs or running the model
                raised for the batch containing this request.
        """
        import asyncio

        future = asyncio.get_running_loop().create_future()
        self.batch_queue.append((input_data, future))

        if len(self.batch_queue) >= self.max_batch_size:
            # The batch is full: the pending timer is no longer needed.
            self._cancel_timeout()
            await self._process_batch()
        elif self._timeout_task is None or self._timeout_task.done():
            # First request of a new batch starts the flush timer.
            self._timeout_task = asyncio.create_task(self._timeout_handler())
        return await future

    def _cancel_timeout(self):
        """Cancel the flush timer, if one is still pending."""
        task = self._timeout_task
        if task is not None and not task.done():
            task.cancel()
        self._timeout_task = None

    async def _process_batch(self):
        """Run the model on everything queued and resolve the futures."""
        if not self.batch_queue:
            return

        # Take ownership of the queued requests; new ones start a new batch.
        pending = list(self.batch_queue)
        self.batch_queue.clear()
        batch_data = [data for data, _ in pending]
        futures = [future for _, future in pending]

        try:
            batch_input = torch.stack(batch_data)
            with torch.no_grad():
                batch_output = self.model(batch_input)
        except Exception as e:
            for future in futures:
                if not future.done():
                    future.set_exception(e)
            return

        # Distribute results; skip futures whose caller was cancelled.
        for i, future in enumerate(futures):
            if not future.done():
                future.set_result(batch_output[i])

    async def _timeout_handler(self):
        """Flush the queue after ``timeout`` seconds."""
        import asyncio

        await asyncio.sleep(self.timeout)
        if self.batch_queue:
            await self._process_batch()

#Performance Monitoring
#1. Profiler
import torch.profiler as profiler
def profile_model(model, input_data, num_steps=100, log_dir='./log/profiler'):
    """Profile repeated forward passes of ``model`` and print a summary.

    Args:
        model: Module to profile; it is switched to eval mode in place.
        input_data: Input passed to ``model`` on every step.
        num_steps: Number of forward passes to run.
        log_dir: Directory handed to ``tensorboard_trace_handler`` for the
            trace files.

    Returns:
        The ``torch.profiler.profile`` object after profiling has finished.
    """
    model.eval()

    # Only request CUDA activity (and sort by CUDA time) when a GPU exists.
    cuda_available = torch.cuda.is_available()
    activities = [profiler.ProfilerActivity.CPU]
    if cuda_available:
        activities.append(profiler.ProfilerActivity.CUDA)

    with profiler.profile(
        activities=activities,
        schedule=profiler.schedule(wait=1, warmup=1, active=3, repeat=2),
        on_trace_ready=profiler.tensorboard_trace_handler(log_dir),
        record_shapes=True,
        profile_memory=True,
        with_stack=True,
    ) as prof:
        for _ in range(num_steps):
            with torch.no_grad():
                model(input_data)
            # Advance the profiler schedule once per forward pass.
            prof.step()

    # Print performance report
    sort_key = "cuda_time_total" if cuda_available else "cpu_time_total"
    print(prof.key_averages().table(sort_by=sort_key, row_limit=10))
    return prof
# Usage example: both the input batch and the model are moved to the GPU
input_data = torch.randn(32, 3, 224, 224).cuda()
prof = profile_model(model.cuda(), input_data)
#2. Memory Analysis
def analyze_memory_usage(model, input_data):
    """Measure CUDA memory allocated around one forward pass.

    Args:
        model: Module to run; it is switched to eval mode in place.
        input_data: Input passed to ``model``.

    Returns:
        dict: ``'initial'``, ``'peak'`` and ``'final'`` allocated bytes as
        reported by ``torch.cuda``.

    Raises:
        RuntimeError: If CUDA is not available.
    """
    if not torch.cuda.is_available():
        raise RuntimeError("analyze_memory_usage requires a CUDA device")

    # Start from a clean slate so the peak reflects this forward pass only.
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()

    # Record initial memory
    initial_memory = torch.cuda.memory_allocated()

    # Forward pass; `output` stays referenced so it counts towards 'final'.
    model.eval()
    with torch.no_grad():
        output = model(input_data)

    # Record peak memory
    peak_memory = torch.cuda.max_memory_allocated()
    final_memory = torch.cuda.memory_allocated()

    print(f"Initial memory: {initial_memory / 1024**2:.2f} MB")
    print(f"Peak memory: {peak_memory / 1024**2:.2f} MB")
    print(f"Final memory: {final_memory / 1024**2:.2f} MB")
    print(f"Memory growth: {(final_memory - initial_memory) / 1024**2:.2f} MB")

    return {
        'initial': initial_memory,
        'peak': peak_memory,
        'final': final_memory
    }
# Usage example: model and input are moved to the GPU before measuring
memory_stats = analyze_memory_usage(model.cuda(), input_data.cuda())
#Summary
PyTorch model optimization spans both training and inference. This chapter covered:
- Training Optimization: Mixed precision, gradient accumulation, learning rate scheduling
- Memory Optimization: Gradient checkpointing, memory mapping, dynamic batching
- Compute Optimization: Model compilation, operator fusion, parallel computing
- Model Compression: Knowledge distillation, pruning, quantization
- Inference Optimization: TorchScript, batch inference
- Performance Monitoring: Profiling, memory analysis
Mastering these optimization techniques will significantly improve the efficiency and performance of your deep learning projects!