PyTorch Best Practices
Code Organization and Project Structure
1. Recommended Project Structure
project/
├── data/ # Data directory
│ ├── raw/ # Raw data
│ ├── processed/ # Processed data
│ └── external/ # External data
├── models/ # Model definitions
│ ├── __init__.py
│ ├── base_model.py # Base model class
│ ├── resnet.py # Specific model implementation
│ └── transformer.py
├── src/ # Source code
│ ├── __init__.py
│ ├── data/ # Data processing
│ │ ├── __init__.py
│ │ ├── dataset.py
│ │ └── transforms.py
│ ├── training/ # Training related
│ │ ├── __init__.py
│ │ ├── trainer.py
│ │ └── losses.py
│ └── utils/ # Utility functions
│ ├── __init__.py
│ ├── metrics.py
│ └── visualization.py
├── configs/ # Configuration files
│ ├── base_config.yaml
│ └── experiment_configs/
├── experiments/ # Experiment records
├── notebooks/ # Jupyter notebooks
├── tests/ # Test code
├── requirements.txt # Dependencies
├── setup.py # Installation script
└── README.md # Project description
2. Base Model Class Design
python
import torch
import torch.nn as nn
from abc import ABC, abstractmethod
from typing import Dict, Any, Optional
class BaseModel(nn.Module, ABC):
    """Abstract base class for all models.

    Subclasses implement `_build_model` (architecture construction from
    `self.config`) and `forward`. Provides parameter counting,
    freezing/unfreezing, and checkpoint save/load helpers.
    """

    def __init__(self, config: Dict[str, Any]):
        super().__init__()  # modern zero-argument super()
        self.config = config
        self._build_model()

    @abstractmethod
    def _build_model(self) -> None:
        """Build the model architecture from self.config."""

    @abstractmethod
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the forward pass."""

    def get_num_parameters(self) -> int:
        """Return the total number of parameters."""
        return sum(p.numel() for p in self.parameters())

    def get_num_trainable_parameters(self) -> int:
        """Return the number of trainable (requires_grad) parameters."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)

    def freeze_parameters(self, module_names: Optional[list] = None) -> None:
        """Freeze parameters: all of them when module_names is None,
        otherwise only modules whose qualified name contains any entry."""
        self._set_requires_grad(False, module_names)

    def unfreeze_parameters(self, module_names: Optional[list] = None) -> None:
        """Unfreeze parameters (inverse of freeze_parameters)."""
        self._set_requires_grad(True, module_names)

    def _set_requires_grad(self, flag: bool, module_names: Optional[list]) -> None:
        """Shared implementation for freeze/unfreeze (removes duplication)."""
        if module_names is None:
            for param in self.parameters():
                param.requires_grad = flag
        else:
            # Substring match lets callers pass e.g. 'backbone' to hit
            # 'backbone.layer1', 'backbone.fc', ...
            for name, module in self.named_modules():
                if any(module_name in name for module_name in module_names):
                    for param in module.parameters():
                        param.requires_grad = flag

    def save_checkpoint(self, filepath: str, epoch: int, optimizer_state: Dict = None,
                        scheduler_state: Dict = None, **kwargs) -> None:
        """Save a training checkpoint.

        Fix: uses `is not None` so that empty (but valid) state dicts are
        still persisted instead of being silently dropped by truthiness.
        """
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': self.state_dict(),
            'config': self.config,
            'num_parameters': self.get_num_parameters(),
            **kwargs
        }
        if optimizer_state is not None:
            checkpoint['optimizer_state_dict'] = optimizer_state
        if scheduler_state is not None:
            checkpoint['scheduler_state_dict'] = scheduler_state
        torch.save(checkpoint, filepath)

    @classmethod
    def load_checkpoint(cls, filepath: str, map_location: str = 'cpu'):
        """Rebuild a model from a checkpoint; returns (model, checkpoint).

        NOTE(review): torch.load unpickles arbitrary objects — only load
        checkpoints from trusted sources.
        """
        checkpoint = torch.load(filepath, map_location=map_location)
        model = cls(checkpoint['config'])
        model.load_state_dict(checkpoint['model_state_dict'])
        return model, checkpoint
# Specific model implementation example
class ResNetClassifier(BaseModel):
    """ResNet-18 image classifier built on the BaseModel skeleton.

    Config keys: 'num_classes' (required), 'pretrained' (optional, True by default).
    """

    def _build_model(self):
        # Local import keeps torchvision an optional dependency
        from torchvision.models import resnet18
        # NOTE(review): `pretrained=` is deprecated in newer torchvision;
        # switch to `weights=ResNet18_Weights.DEFAULT` once the minimum
        # torchvision version allows it.
        self.backbone = resnet18(pretrained=self.config.get('pretrained', True))
        # Replace the 1000-way ImageNet head with a task-specific classifier
        self.backbone.fc = nn.Linear(
            self.backbone.fc.in_features,
            self.config['num_classes']
        )

    def forward(self, x):
        return self.backbone(x)

# Data Processing Best Practices
1. Efficient Data Loading
python
import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np
from typing import Tuple, List, Optional
import multiprocessing as mp
class OptimizedDataset(Dataset):
    """Dataset with a small in-memory cache in front of disk loads.

    Eviction is least-frequently-used: the entry with the smallest
    access count is dropped when the cache is full.
    """

    def __init__(self, data_path: str, transform=None, cache_size: int = 1000):
        self.data_path = data_path
        self.transform = transform
        self.cache_size = cache_size
        self.cache = {}          # idx -> raw (untransformed) sample
        self.access_count = {}   # idx -> number of times the entry was used
        # Build the sample index up front so __getitem__ never scans files
        self._load_index()

    def _load_index(self):
        """Load the data index (implemented by concrete subclasses)."""
        pass

    def __getitem__(self, idx):
        try:
            data = self.cache[idx]
        except KeyError:
            # Cache miss: load from storage and remember the raw sample
            data = self._load_data(idx)
            self._update_cache(idx, data)
        else:
            self.access_count[idx] += 1
        # Transforms run on every access so augmentation stays random
        return self.transform(data) if self.transform else data

    def _load_data(self, idx):
        """Load a single raw sample (implemented by concrete subclasses)."""
        pass

    def _update_cache(self, idx, data):
        """Insert a sample, evicting the least-frequently-used entry when full."""
        if len(self.cache) >= self.cache_size:
            evict_idx = min(self.access_count, key=self.access_count.get)
            del self.cache[evict_idx]
            del self.access_count[evict_idx]
        self.cache[idx] = data
        self.access_count[idx] = 1
def create_optimized_dataloader(dataset, batch_size: int, num_workers: Optional[int] = None,
                                pin_memory: bool = True, persistent_workers: bool = True):
    """Create a DataLoader with sensible performance defaults.

    num_workers defaults to min(8, cpu_count). pin_memory and
    persistent_workers are downgraded automatically when they do not
    apply (no CUDA / no worker processes).

    Fix: prefetch_factor must be None (not 2) when num_workers == 0 —
    DataLoader raises ValueError if it is specified without workers.
    """
    if num_workers is None:
        num_workers = min(8, mp.cpu_count())
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        # pin_memory only helps (and is only meaningful) when CUDA is present
        pin_memory=pin_memory and torch.cuda.is_available(),
        # persistent workers require at least one worker process
        persistent_workers=persistent_workers and num_workers > 0,
        # prefetch_factor may only be set when worker processes exist
        prefetch_factor=2 if num_workers > 0 else None,
        drop_last=True  # keep every batch the same size (helps e.g. BatchNorm)
    )

# 2. Data Preprocessing Pipeline
python
import torchvision.transforms as transforms
from typing import Union, List
class DataPreprocessor:
    """Builds torchvision transform pipelines for training and validation
    from a config dict.

    Recognised config keys: resize, random_crop {size, padding},
    random_horizontal_flip (flip probability), color_jitter (kwargs),
    normalize {mean, std}, random_erasing (kwargs), center_crop.
    """

    def __init__(self, config: Dict[str, Any]):
        self.config = config
        self.train_transform = self._build_train_transform()
        self.val_transform = self._build_val_transform()

    def _to_tensor_and_normalize(self) -> list:
        """Shared tail of both pipelines: tensor conversion + optional Normalize.

        Fix: 'normalize' used to be mandatory and raised KeyError when
        absent from the config; it is now optional.
        """
        tail = [transforms.ToTensor()]
        if self.config.get('normalize'):
            tail.append(transforms.Normalize(
                mean=self.config['normalize']['mean'],
                std=self.config['normalize']['std']
            ))
        return tail

    def _build_train_transform(self):
        """Compose the (augmenting) training pipeline."""
        transforms_list = []
        if self.config.get('resize'):
            transforms_list.append(transforms.Resize(self.config['resize']))
        # Geometric / photometric augmentation (operates on PIL images)
        if self.config.get('random_crop'):
            transforms_list.append(transforms.RandomCrop(
                self.config['random_crop']['size'],
                padding=self.config['random_crop'].get('padding', 4)
            ))
        if self.config.get('random_horizontal_flip'):
            transforms_list.append(transforms.RandomHorizontalFlip(
                p=self.config['random_horizontal_flip']
            ))
        if self.config.get('color_jitter'):
            transforms_list.append(transforms.ColorJitter(**self.config['color_jitter']))
        transforms_list.extend(self._to_tensor_and_normalize())
        # RandomErasing operates on tensors, so it must come after ToTensor
        if self.config.get('random_erasing'):
            transforms_list.append(transforms.RandomErasing(**self.config['random_erasing']))
        return transforms.Compose(transforms_list)

    def _build_val_transform(self):
        """Compose the deterministic validation pipeline (no augmentation)."""
        transforms_list = []
        if self.config.get('resize'):
            transforms_list.append(transforms.Resize(self.config['resize']))
        if self.config.get('center_crop'):
            transforms_list.append(transforms.CenterCrop(self.config['center_crop']))
        transforms_list.extend(self._to_tensor_and_normalize())
        return transforms.Compose(transforms_list)

# Training Optimization Techniques
1. Mixed Precision Training
python
from torch.cuda.amp import GradScaler, autocast
import torch.nn.functional as F
class MixedPrecisionTrainer:
    """Trainer that runs forward passes under autocast and scales gradients
    with GradScaler for numerically stable FP16 training."""

    def __init__(self, model, optimizer, criterion, device):
        self.model = model
        self.optimizer = optimizer
        self.criterion = criterion
        self.device = device
        self.scaler = GradScaler()

    def train_step(self, data, target):
        """Run one optimization step on a batch; returns the loss value."""
        data = data.to(self.device)
        target = target.to(self.device)
        self.optimizer.zero_grad()
        # Forward in reduced precision where it is safe to do so
        with autocast():
            loss = self.criterion(self.model(data), target)
        # Backward on the scaled loss to avoid FP16 gradient underflow
        self.scaler.scale(loss).backward()
        # Unscale first so the clipping threshold applies to true gradients
        self.scaler.unscale_(self.optimizer)
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
        self.scaler.step(self.optimizer)
        self.scaler.update()
        return loss.item()

    def validate_step(self, data, target):
        """Evaluate one batch without gradients; returns (loss_value, output)."""
        data = data.to(self.device)
        target = target.to(self.device)
        with torch.no_grad(), autocast():
            output = self.model(data)
            loss = self.criterion(output, target)
        return loss.item(), output

# 2. Gradient Accumulation
python
class GradientAccumulationTrainer:
    """Trainer that accumulates gradients over several batches to simulate
    a larger effective batch size."""

    def __init__(self, model, optimizer, criterion, device, accumulation_steps=4):
        self.model = model
        self.optimizer = optimizer
        self.criterion = criterion
        self.device = device
        self.accumulation_steps = accumulation_steps

    def train_epoch(self, dataloader):
        """Train one epoch; returns the mean (unscaled) loss per batch.

        Fixes:
        * the original discarded the gradients accumulated after the last
          full group whenever len(dataloader) % accumulation_steps != 0 —
          the trailing partial group now also triggers an optimizer step;
        * an empty dataloader no longer raises ZeroDivisionError.
        """
        self.model.train()
        total_loss = 0.0
        num_batches = 0
        self.optimizer.zero_grad()
        for batch_idx, (data, target) in enumerate(dataloader):
            data, target = data.to(self.device), target.to(self.device)
            output = self.model(data)
            loss = self.criterion(output, target)
            # Scale so the summed gradient matches a large-batch gradient
            loss = loss / self.accumulation_steps
            loss.backward()
            total_loss += loss.item() * self.accumulation_steps
            num_batches = batch_idx + 1
            if num_batches % self.accumulation_steps == 0:
                self._apply_step()
        # Flush the trailing partial accumulation group (fix)
        if num_batches % self.accumulation_steps != 0:
            self._apply_step()
        return total_loss / num_batches if num_batches else 0.0

    def _apply_step(self):
        """Clip gradients, apply the optimizer step, reset gradients."""
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
        self.optimizer.step()
        self.optimizer.zero_grad()

# Model Optimization and Deployment
1. Model Quantization
python
import torch.quantization as quantization
def quantize_model(model, calibration_dataloader, device):
    """Post-training static quantization: configure, observe, convert.

    Runs the calibration batches through an observer-instrumented copy of
    the model, then returns the converted (quantized) copy; the input
    model itself is not modified (inplace=False throughout).
    """
    # Attach the default server-side (fbgemm) quantization recipe
    model.qconfig = quantization.get_default_qconfig('fbgemm')
    # Insert observers that will record activation statistics
    prepared = quantization.prepare(model, inplace=False)
    # Calibration: feed representative data so the observers see real ranges
    prepared.eval()
    with torch.no_grad():
        for data, _ in calibration_dataloader:
            prepared(data.to(device))
    # Swap observed modules for their quantized counterparts
    return quantization.convert(prepared, inplace=False)
def compare_model_sizes(model_fp32, model_quantized):
    """Print the serialized sizes of both models and their compression ratio.

    Fix: the original wrote a temporary file and called `os.path.getsize`
    without ever importing `os` (NameError at runtime); the size is now
    measured on an in-memory buffer, which also avoids filesystem churn.
    """
    import io

    def get_model_size(model) -> int:
        """Return the size in bytes of the model's serialized state_dict."""
        buffer = io.BytesIO()
        torch.save(model.state_dict(), buffer)
        return buffer.getbuffer().nbytes

    fp32_size = get_model_size(model_fp32)
    quantized_size = get_model_size(model_quantized)
    print(f"FP32 model size: {fp32_size / 1024 / 1024:.2f} MB")
    print(f"Quantized model size: {quantized_size / 1024 / 1024:.2f} MB")
    print(f"Compression ratio: {fp32_size / quantized_size:.2f}x")

# 2. Model Pruning
python
import torch.nn.utils.prune as prune
def prune_model(model, pruning_ratio=0.2):
    """Apply global unstructured L1 pruning to every Conv2d/Linear weight.

    The pruning reparameterization is removed afterwards, so the zeros are
    baked directly into the weight tensors. Returns the same model object.
    """
    # Gather (module, parameter-name) pairs for all prunable layers
    prunable = [
        (module, 'weight')
        for module in model.modules()
        if isinstance(module, (nn.Conv2d, nn.Linear))
    ]
    # Rank weights globally by magnitude and zero the smallest fraction
    prune.global_unstructured(
        prunable,
        pruning_method=prune.L1Unstructured,
        amount=pruning_ratio,
    )
    # Fold weight_orig * mask back into plain .weight tensors
    for module, param_name in prunable:
        prune.remove(module, param_name)
    return model
def calculate_sparsity(model):
    """Print and return the fraction of parameters that are exactly zero."""
    counts = [(param.numel(), (param == 0).sum().item()) for param in model.parameters()]
    total_params = sum(numel for numel, _ in counts)
    zero_params = sum(zeros for _, zeros in counts)
    sparsity = zero_params / total_params
    print(f"Model sparsity: {sparsity:.2%}")
    return sparsity

# Debugging and Monitoring
1. Training Monitoring
python
import wandb
from torch.utils.tensorboard import SummaryWriter
import time
class TrainingMonitor:
    """Unified logging front-end for Weights & Biases and TensorBoard.

    Backends are enabled via the config flags 'use_wandb' /
    'use_tensorboard'; with both disabled every log_* call is a no-op.
    """

    def __init__(self, project_name, experiment_name, config):
        self.config = config
        # Fix: remember the wandb flag — the original checked
        # hasattr(self, 'wandb'), which was never set anywhere, so
        # wandb metric logging silently never happened.
        self.use_wandb = config.get('use_wandb', False)
        if self.use_wandb:
            wandb.init(project=project_name, name=experiment_name, config=config)
        if config.get('use_tensorboard', False):
            self.writer = SummaryWriter(f'runs/{experiment_name}')
        self.metrics = {}
        self.start_time = time.time()

    def log_metrics(self, metrics, step, prefix=''):
        """Log a dict of scalar metrics at the given step, optionally prefixed."""
        for key, value in metrics.items():
            metric_name = f"{prefix}/{key}" if prefix else key
            if self.use_wandb and wandb.run:
                wandb.log({metric_name: value}, step=step)
            if hasattr(self, 'writer'):
                self.writer.add_scalar(metric_name, value, step)

    def log_model_graph(self, model, input_sample):
        """Log the model graph to TensorBoard (no-op without a writer)."""
        if hasattr(self, 'writer'):
            self.writer.add_graph(model, input_sample)

    def log_gradients(self, model, step):
        """Log per-parameter gradient histograms and norms to TensorBoard."""
        if hasattr(self, 'writer'):
            for name, param in model.named_parameters():
                if param.grad is not None:
                    self.writer.add_histogram(f'gradients/{name}', param.grad, step)
                    self.writer.add_scalar(f'gradient_norms/{name}',
                                           param.grad.norm().item(), step)

    def log_learning_rate(self, optimizer, step):
        """Log the learning rate of every optimizer param group."""
        for i, param_group in enumerate(optimizer.param_groups):
            lr = param_group['lr']
            if hasattr(self, 'writer'):
                self.writer.add_scalar(f'learning_rate/group_{i}', lr, step)

    def close(self):
        """Flush and close all active logging backends.

        Fix: the wandb access is now guarded by the flag so close() does
        not touch wandb at all when the backend was never enabled.
        """
        if hasattr(self, 'writer'):
            self.writer.close()
        if self.use_wandb and wandb.run:
            wandb.finish()

# 2. Model Diagnostics
python
class ModelDiagnostics:
    """Static helpers for spotting numerical problems in a model."""

    @staticmethod
    def check_gradients(model, threshold=1e-7):
        """Return a list of human-readable gradient problems.

        Flags exploding (norm > 100) and vanishing (norm < threshold)
        gradients as well as NaN/Inf entries. Parameters without a
        gradient are skipped.
        """
        issues = []
        for name, param in model.named_parameters():
            grad = param.grad
            if grad is None:
                continue
            grad_norm = grad.norm().item()
            if grad_norm > 100:
                issues.append(f"Gradient explosion: {name}, norm={grad_norm:.2f}")
            elif grad_norm < threshold:
                issues.append(f"Gradient vanishing: {name}, norm={grad_norm:.2e}")
            if torch.isnan(grad).any():
                issues.append(f"NaN gradient: {name}")
            if torch.isinf(grad).any():
                issues.append(f"Inf gradient: {name}")
        return issues

    @staticmethod
    def check_weights(model):
        """Return a list of weight problems: NaN/Inf entries and
        suspiciously small (< 1e-6) or large (> 10) standard deviations."""
        issues = []
        for name, param in model.named_parameters():
            if torch.isnan(param).any():
                issues.append(f"NaN weight: {name}")
            if torch.isinf(param).any():
                issues.append(f"Inf weight: {name}")
            weight_std = param.std().item()
            if weight_std < 1e-6:
                issues.append(f"Too small weight variance: {name}, std={weight_std:.2e}")
            elif weight_std > 10:
                issues.append(f"Too large weight variance: {name}, std={weight_std:.2f}")
        return issues

    @staticmethod
    def analyze_activations(model, input_data):
        """Run one forward pass and return per-leaf-module activation stats
        (mean/std/min/max and NaN/Inf flags), keyed by module name."""
        stats = {}

        def make_hook(name):
            def hook(module, inputs, output):
                if isinstance(output, torch.Tensor):
                    stats[name] = {
                        'mean': output.mean().item(),
                        'std': output.std().item(),
                        'min': output.min().item(),
                        'max': output.max().item(),
                        'has_nan': torch.isnan(output).any().item(),
                        'has_inf': torch.isinf(output).any().item(),
                    }
            return hook

        # Hook only leaf modules — containers would double-report activations
        handles = [
            module.register_forward_hook(make_hook(name))
            for name, module in model.named_modules()
            if not list(module.children())
        ]
        model.eval()
        with torch.no_grad():
            model(input_data)
        for handle in handles:
            handle.remove()
        return stats

# Performance Optimization
1. Memory Optimization
python
import gc
import torch
class MemoryOptimizer:
    """Static helpers for reducing GPU/CPU memory pressure."""

    @staticmethod
    def clear_cache():
        """Release cached GPU blocks and run the Python garbage collector."""
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        gc.collect()

    @staticmethod
    def get_memory_usage():
        """Return a human-readable memory usage string (GPU if available).

        NOTE(review): the CPU branch lazily imports the third-party
        `psutil` package; it raises ImportError when psutil is absent.
        """
        if torch.cuda.is_available():
            allocated = torch.cuda.memory_allocated() / 1024**3  # GB
            cached = torch.cuda.memory_reserved() / 1024**3  # GB
            return f"GPU memory: allocated {allocated:.2f}GB, cached {cached:.2f}GB"
        else:
            import psutil
            memory = psutil.virtual_memory()
            return f"CPU memory: used {memory.percent:.1f}%"

    @staticmethod
    def optimize_dataloader_memory(dataloader):
        """Lower prefetching and worker count on an existing DataLoader;
        the new values take effect on the next iterator it creates."""
        dataloader.prefetch_factor = 1
        if dataloader.num_workers > 4:
            dataloader.num_workers = 4
        return dataloader

    @staticmethod
    def use_gradient_checkpointing(model):
        """Wrap transformer layers' forward in activation checkpointing.

        Fixes two bugs in the original:
        * the patched lambda closed over the loop variable `module`, so
          every layer ended up checkpointing the *last* layer (classic
          late-binding closure bug);
        * checkpoint was invoked on the module itself, whose forward had
          just been patched, causing infinite recursion — the original
          (unpatched) forward is checkpointed instead.
        """
        from torch.utils.checkpoint import checkpoint
        for name, module in model.named_modules():
            if isinstance(module, (nn.TransformerEncoderLayer, nn.TransformerDecoderLayer)):
                # Bind the *original* bound forward as a default argument so
                # each layer checkpoints its own computation.
                module.forward = (
                    lambda *args, _fwd=module.forward, **kwargs:
                        checkpoint(_fwd, *args, use_reentrant=False, **kwargs)
                )
        return model

# 2. Computation Optimization
python
class ComputationOptimizer:
    """Static helpers for speeding up model execution."""

    @staticmethod
    def optimize_model_for_inference(model):
        """Script the model with TorchScript, switch it to eval mode and
        freeze all parameters; returns the scripted module."""
        scripted = torch.jit.script(model)
        scripted.eval()
        # Disable autograd bookkeeping for pure inference
        for param in scripted.parameters():
            param.requires_grad = False
        return scripted

    @staticmethod
    def enable_cudnn_benchmark():
        """Let cuDNN auto-tune kernels; trades determinism for speed."""
        if torch.cuda.is_available():
            torch.backends.cudnn.benchmark = True
            torch.backends.cudnn.deterministic = False

    @staticmethod
    def compile_model(model, mode='default'):
        """Compile with torch.compile when available (PyTorch 2.0+);
        otherwise return the model unchanged."""
        return torch.compile(model, mode=mode) if hasattr(torch, 'compile') else model

# Summary
PyTorch best practices cover all aspects of deep learning projects:
- Code Organization: Clear project structure and modular design
- Data Processing: Efficient data loading and preprocessing pipelines
- Training Optimization: Advanced techniques like mixed precision and gradient accumulation
- Model Optimization: Model compression methods like quantization and pruning
- Debugging Monitoring: Complete training monitoring and model diagnostic tools
- Performance Optimization: Effective use of memory and computing resources
Following these best practices will help you build more efficient and reliable deep learning systems!