#PyTorch Best Practices
#Code Organization and Project Structure
#1. Recommended Project Structure
project/
├── data/                     # Data directory
│   ├── raw/                  # Raw data
│   ├── processed/            # Processed data
│   └── external/             # External data
├── models/                   # Model definitions
│   ├── __init__.py
│   ├── base_model.py         # Base model class
│   ├── resnet.py             # Specific model implementation
│   └── transformer.py
├── src/                      # Source code
│   ├── __init__.py
│   ├── data/                 # Data processing
│   │   ├── __init__.py
│   │   ├── dataset.py
│   │   └── transforms.py
│   ├── training/             # Training code
│   │   ├── __init__.py
│   │   ├── trainer.py
│   │   └── losses.py
│   └── utils/                # Utility functions
│       ├── __init__.py
│       ├── metrics.py
│       └── visualization.py
├── configs/                  # Configuration files
│   ├── base_config.yaml
│   └── experiment_configs/
├── experiments/              # Experiment records
├── notebooks/                # Jupyter notebooks
├── tests/                    # Test code
├── requirements.txt          # Dependencies
├── setup.py                  # Installation script
└── README.md                 # Project description

#2. Base Model Class Design
import torch
import torch.nn as nn
from abc import ABC, abstractmethod
from typing import Dict, Any, Optional
class BaseModel(nn.Module, ABC):
    """Abstract base class for project models.

    Subclasses implement ``_build_model`` (create submodules, reading
    hyper-parameters from ``self.config``) and ``forward``. The base class
    adds parameter counting, freezing/unfreezing by module name, and
    checkpoint save/load.

    Args:
        config: Model configuration. Stored as ``self.config`` and written
            into every checkpoint so ``load_checkpoint`` can rebuild the model.
    """

    def __init__(self, config: Dict[str, Any]):
        super().__init__()
        self.config = config
        # Called last so _build_model can read self.config.
        self._build_model()

    @abstractmethod
    def _build_model(self):
        """Create the model's submodules and assign them as attributes."""

    @abstractmethod
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the forward pass on ``x``."""

    def get_num_parameters(self) -> int:
        """Return the total number of parameter elements."""
        return sum(p.numel() for p in self.parameters())

    def get_num_trainable_parameters(self) -> int:
        """Return the number of parameter elements with ``requires_grad=True``."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)

    def _set_requires_grad(self, requires_grad: bool, module_names: Optional[list]) -> None:
        """Set ``requires_grad`` on every parameter, or only on matching modules.

        A module matches when any entry of ``module_names`` is a *substring*
        of its dotted name from ``named_modules()``. For example, ``'fc'``
        matches ``'fc'`` and ``'backbone.fc'``, but also ``'fc1'``, so pass
        specific names.
        """
        if module_names is None:
            targets = [self]
        else:
            targets = [
                module for name, module in self.named_modules()
                if any(module_name in name for module_name in module_names)
            ]
        for module in targets:
            for param in module.parameters():
                param.requires_grad = requires_grad

    def freeze_parameters(self, module_names: Optional[list] = None):
        """Freeze all parameters, or only those of modules matching ``module_names``."""
        self._set_requires_grad(False, module_names)

    def unfreeze_parameters(self, module_names: Optional[list] = None):
        """Unfreeze all parameters, or only those of modules matching ``module_names``."""
        self._set_requires_grad(True, module_names)

    def save_checkpoint(self, filepath: str, epoch: int, optimizer_state: Optional[Dict] = None,
                        scheduler_state: Optional[Dict] = None, **kwargs):
        """Save weights, config and optional training state with ``torch.save``.

        Args:
            filepath: Destination path.
            epoch: Stored under ``'epoch'``.
            optimizer_state: Optional optimizer state dict, stored under
                ``'optimizer_state_dict'`` whenever it is not None.
            scheduler_state: Optional scheduler state dict, stored under
                ``'scheduler_state_dict'`` whenever it is not None.
            **kwargs: Extra entries merged into the checkpoint. They override
                the built-in keys on a name clash.
        """
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': self.state_dict(),
            'config': self.config,
            'num_parameters': self.get_num_parameters(),
            **kwargs,
        }
        # `is not None` so an empty (but provided) state is still recorded.
        if optimizer_state is not None:
            checkpoint['optimizer_state_dict'] = optimizer_state
        if scheduler_state is not None:
            checkpoint['scheduler_state_dict'] = scheduler_state
        torch.save(checkpoint, filepath)

    @classmethod
    def load_checkpoint(cls, filepath: str, map_location: str = 'cpu'):
        """Rebuild a model from a file written by ``save_checkpoint``.

        Returns:
            ``(model, checkpoint)``: the model constructed from the stored
            config with its weights loaded, and the raw checkpoint dict
            (epoch, optimizer/scheduler state, extra entries).
        """
        # SECURITY: torch.load is pickle-based; only load checkpoints from
        # trusted sources.
        checkpoint = torch.load(filepath, map_location=map_location)
        model = cls(checkpoint['config'])
        model.load_state_dict(checkpoint['model_state_dict'])
        return model, checkpoint
# Specific model implementation example
class ResNetClassifier(BaseModel):
    """torchvision ResNet-18 with its final layer resized to ``config['num_classes']``.

    Config keys:
        num_classes (required): number of output logits.
        pretrained (default True): load torchvision's pretrained weights.
    """

    def _build_model(self):
        import torchvision.models as tv_models

        pretrained = self.config.get('pretrained', True)
        weights_enum = getattr(tv_models, 'ResNet18_Weights', None)
        if weights_enum is not None:
            # torchvision >= 0.13 deprecates `pretrained=` in favour of `weights=`.
            weights = weights_enum.DEFAULT if pretrained else None
            self.backbone = tv_models.resnet18(weights=weights)
        else:
            # Older torchvision only understands the legacy flag.
            self.backbone = tv_models.resnet18(pretrained=pretrained)
        self.backbone.fc = nn.Linear(
            self.backbone.fc.in_features,
            self.config['num_classes']
        )

    def forward(self, x):
        """Return class logits for the input batch ``x``."""
        return self.backbone(x)


#Data Processing Best Practices
#1. Efficient Data Loading
import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np
from typing import Tuple, List, Optional
import multiprocessing as mp
class OptimizedDataset(Dataset):
    """Map-style dataset template with a bounded LRU cache of raw samples.

    Subclasses implement ``_load_index`` (fill ``self.index``, which backs
    ``__len__``) and ``_load_data``. Samples are cached *before* ``transform``
    runs, so random augmentations still differ between epochs.

    Note: with ``DataLoader(num_workers > 0)`` each worker process holds its
    own copy of the cache, so memory use grows with the worker count.

    Args:
        data_path: Location of the data, for use by the loading hooks.
        transform: Optional callable applied to every returned sample.
        cache_size: Maximum number of cached samples; ``<= 0`` disables caching.
    """

    def __init__(self, data_path: str, transform=None, cache_size: int = 1000):
        self.data_path = data_path
        self.transform = transform
        self.cache_size = cache_size
        # Insertion-ordered dict used as an LRU: the first key is the least
        # recently used entry.
        self.cache = {}
        # Hit counter per cached index, kept for inspection.
        self.access_count = {}
        # Sample index filled by _load_index; its length is the dataset length.
        self.index = []
        self._load_index()

    def _load_index(self):
        """Populate ``self.index`` up front so files need not be scanned per item."""
        # Implement specific index loading logic
        pass

    def __len__(self):
        """Return the number of entries in ``self.index``."""
        return len(self.index)

    def __getitem__(self, idx):
        """Return sample ``idx``, served from the cache when possible."""
        if idx in self.cache:
            # Re-insert to mark the entry as most recently used.
            data = self.cache.pop(idx)
            self.cache[idx] = data
            self.access_count[idx] += 1
        else:
            data = self._load_data(idx)
            self._update_cache(idx, data)
        if self.transform:
            data = self.transform(data)
        return data

    def _load_data(self, idx):
        """Load one raw sample from storage."""
        # Implement specific data loading logic
        pass

    def _update_cache(self, idx, data):
        """Cache ``data`` under ``idx``, evicting the least recently used entry when full."""
        if self.cache_size <= 0:
            return
        if len(self.cache) >= self.cache_size:
            lru_idx = next(iter(self.cache))
            del self.cache[lru_idx]
            del self.access_count[lru_idx]
        self.cache[idx] = data
        self.access_count[idx] = 1
def create_optimized_dataloader(dataset, batch_size: int, num_workers: Optional[int] = None,
                                pin_memory: bool = True, persistent_workers: bool = True,
                                shuffle: bool = True, drop_last: bool = True,
                                prefetch_factor: int = 2):
    """Create a DataLoader with throughput-oriented defaults.

    Args:
        dataset: Dataset to load from.
        batch_size: Samples per batch.
        num_workers: Worker processes; defaults to ``min(8, cpu_count())``.
            ``0`` loads in the main process.
        pin_memory: Use pinned host memory (honoured only when CUDA is available).
        persistent_workers: Keep workers alive between epochs (only when
            ``num_workers > 0``).
        shuffle: Reshuffle every epoch.
        drop_last: Drop the final incomplete batch so all batches have
            ``batch_size`` samples.
        prefetch_factor: Batches loaded in advance per worker (only used
            when ``num_workers > 0``).

    Returns:
        The configured ``DataLoader``.
    """
    if num_workers is None:
        num_workers = min(8, mp.cpu_count())
    extra_kwargs = {}
    if num_workers > 0:
        # PyTorch 2.x raises ValueError if prefetch_factor is given while
        # num_workers == 0, so only pass it for multi-process loading.
        extra_kwargs['prefetch_factor'] = prefetch_factor
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        pin_memory=pin_memory and torch.cuda.is_available(),
        persistent_workers=persistent_workers and num_workers > 0,
        drop_last=drop_last,
        **extra_kwargs,
    )


#2. Data Preprocessing Pipeline
import torchvision.transforms as transforms
from typing import Union, List
class DataPreprocessor:
    """Build torchvision train/validation transform pipelines from a config dict.

    Recognised config keys (all optional):
        resize: size for ``Resize`` (train and val).
        random_crop: ``{'size': ..., 'padding': 4}`` for ``RandomCrop`` (train).
        random_horizontal_flip: flip probability; ``True`` means 0.5 (train).
        color_jitter: keyword arguments for ``ColorJitter`` (train).
        center_crop: size for ``CenterCrop`` (val).
        normalize: ``{'mean': [...], 'std': [...]}`` for ``Normalize``; when
            absent, images are only converted with ``ToTensor``.
        random_erasing: keyword arguments for ``RandomErasing`` (train).
    """

    def __init__(self, config: Dict[str, Any]):
        self.config = config
        self.train_transform = self._build_train_transform()
        self.val_transform = self._build_val_transform()

    def _tensor_transforms(self) -> list:
        """Return ``ToTensor`` followed by ``Normalize`` when it is configured."""
        tail = [transforms.ToTensor()]
        normalize = self.config.get('normalize')
        if normalize:
            tail.append(transforms.Normalize(mean=normalize['mean'], std=normalize['std']))
        return tail

    def _build_train_transform(self):
        """Compose the training pipeline: resize, augment, tensorize, erase."""
        cfg = self.config
        steps = []
        if cfg.get('resize'):
            steps.append(transforms.Resize(cfg['resize']))
        if cfg.get('random_crop'):
            crop = cfg['random_crop']
            steps.append(transforms.RandomCrop(crop['size'], padding=crop.get('padding', 4)))
        flip = cfg.get('random_horizontal_flip')
        if flip:
            # A bare `True` would otherwise be used as p=1.0 and flip every image.
            steps.append(transforms.RandomHorizontalFlip(p=0.5 if flip is True else flip))
        if cfg.get('color_jitter'):
            steps.append(transforms.ColorJitter(**cfg['color_jitter']))
        steps.extend(self._tensor_transforms())
        # RandomErasing works on tensors, so it has to follow ToTensor.
        if cfg.get('random_erasing'):
            steps.append(transforms.RandomErasing(**cfg['random_erasing']))
        return transforms.Compose(steps)

    def _build_val_transform(self):
        """Compose the deterministic validation pipeline: resize, center-crop, tensorize."""
        cfg = self.config
        steps = []
        if cfg.get('resize'):
            steps.append(transforms.Resize(cfg['resize']))
        if cfg.get('center_crop'):
            steps.append(transforms.CenterCrop(cfg['center_crop']))
        steps.extend(self._tensor_transforms())
        return transforms.Compose(steps)


#Training Optimization Techniques
#1. Mixed Precision Training
from torch.cuda.amp import GradScaler, autocast
import torch.nn.functional as F
class MixedPrecisionTrainer:
    """Train and evaluate a model with CUDA automatic mixed precision (AMP).

    Forward passes run under ``autocast``. The loss is scaled by a
    ``GradScaler`` so small fp16 gradients do not underflow. On machines
    without CUDA, PyTorch disables both (with a warning) and the steps run in
    full precision.

    The caller is responsible for calling ``model.train()`` / ``model.eval()``.

    Args:
        model: Module to optimize; assumed to already live on ``device``.
        optimizer: Optimizer over ``model``'s parameters.
        criterion: Callable ``criterion(output, target)`` returning a scalar loss.
        device: Device that batches are moved to.
        max_grad_norm: Gradient-norm clipping threshold; ``None`` disables clipping.
    """

    def __init__(self, model, optimizer, criterion, device, max_grad_norm: Optional[float] = 1.0):
        self.model = model
        self.optimizer = optimizer
        self.criterion = criterion
        self.device = device
        self.max_grad_norm = max_grad_norm
        self.scaler = GradScaler()

    def train_step(self, data, target):
        """Run one optimization step and return the loss as a Python float."""
        data, target = data.to(self.device), target.to(self.device)
        # set_to_none releases gradient memory instead of zero-filling it.
        self.optimizer.zero_grad(set_to_none=True)
        with autocast():
            output = self.model(data)
            loss = self.criterion(output, target)
        self.scaler.scale(loss).backward()
        if self.max_grad_norm is not None:
            # Unscale first so the threshold applies to the true gradient values.
            self.scaler.unscale_(self.optimizer)
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=self.max_grad_norm)
        # step() skips the update when the gradients contain inf/NaN.
        self.scaler.step(self.optimizer)
        self.scaler.update()
        return loss.item()

    def validate_step(self, data, target):
        """Evaluate one batch without gradients; return ``(loss, output)``."""
        data, target = data.to(self.device), target.to(self.device)
        with torch.no_grad(), autocast():
            output = self.model(data)
            loss = self.criterion(output, target)
        return loss.item(), output


#2. Gradient Accumulation
class GradientAccumulationTrainer:
    """Simulate a larger batch by accumulating gradients over several mini-batches.

    Args:
        model: Module to optimize; assumed to already live on ``device``.
        optimizer: Optimizer over ``model``'s parameters.
        criterion: Callable ``criterion(output, target)`` returning a scalar loss.
        device: Device that batches are moved to.
        accumulation_steps: Mini-batches per optimizer step (must be >= 1).
        max_grad_norm: Clipping threshold applied before each step; ``None``
            disables clipping.

    Raises:
        ValueError: If ``accumulation_steps < 1``.
    """

    def __init__(self, model, optimizer, criterion, device, accumulation_steps=4,
                 max_grad_norm: Optional[float] = 1.0):
        if accumulation_steps < 1:
            raise ValueError(f"accumulation_steps must be >= 1, got {accumulation_steps}")
        self.model = model
        self.optimizer = optimizer
        self.criterion = criterion
        self.device = device
        self.accumulation_steps = accumulation_steps
        self.max_grad_norm = max_grad_norm

    def train_epoch(self, dataloader):
        """Train for one pass over ``dataloader`` and return the mean per-batch loss.

        The optimizer steps after every ``accumulation_steps`` batches and once
        more after the final batch, so a trailing partial group is applied
        rather than discarded. Each loss is divided by the size of its group,
        so every step uses the average gradient of the batches it covers.

        Returns:
            Mean of the unscaled per-batch losses, or ``0.0`` if ``dataloader``
            is empty.
        """
        self.model.train()
        num_batches = len(dataloader)
        if num_batches == 0:
            return 0.0
        steps = self.accumulation_steps
        total_loss = 0.0
        self.optimizer.zero_grad()
        for batch_idx, (data, target) in enumerate(dataloader):
            data, target = data.to(self.device), target.to(self.device)
            loss = self.criterion(self.model(data), target)
            # The final group may be shorter than `steps`.
            group_start = batch_idx - batch_idx % steps
            group_size = min(steps, num_batches - group_start)
            (loss / group_size).backward()
            total_loss += loss.item()
            is_group_end = (batch_idx + 1) % steps == 0 or batch_idx + 1 == num_batches
            if is_group_end:
                self._apply_accumulated_gradients()
        return total_loss / num_batches

    def _apply_accumulated_gradients(self):
        """Clip (if enabled), take an optimizer step, and clear the gradients."""
        if self.max_grad_norm is not None:
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=self.max_grad_norm)
        self.optimizer.step()
        self.optimizer.zero_grad()


#Model Optimization and Deployment
#1. Model Quantization
import torch.quantization as quantization
def quantize_model(model, calibration_dataloader, device, backend: str = 'fbgemm'):
    """Apply eager-mode post-training static quantization to a copy of ``model``.

    The model must already be quantization-ready for eager mode: float
    sections wrapped in ``QuantStub``/``DeQuantStub`` and modules fused
    beforehand. This function only prepares, calibrates and converts.

    Args:
        model: Float model; left unmodified because a deep copy is quantized.
        calibration_dataloader: Iterable of ``(data, _)`` batches used to
            collect activation statistics.
        device: Device that calibration batches are moved to. The model is
            presumably expected to already be there. Note that the
            fbgemm/qnnpack kernels are CPU backends.
        backend: qconfig backend passed to ``get_default_qconfig``, e.g.
            ``'fbgemm'`` (x86) or ``'qnnpack'`` (ARM).

    Returns:
        The converted quantized model.
    """
    import copy

    # Quantize a copy: assigning qconfig to the caller's model is a side effect.
    model_to_quantize = copy.deepcopy(model)
    model_to_quantize.qconfig = quantization.get_default_qconfig(backend)
    # Observers are inserted into the copy in place.
    model_prepared = quantization.prepare(model_to_quantize, inplace=True)
    model_prepared.eval()
    with torch.no_grad():
        for data, _ in calibration_dataloader:
            model_prepared(data.to(device))
    return quantization.convert(model_prepared, inplace=True)
def compare_model_sizes(model_fp32, model_quantized):
    """Print and return the serialized ``state_dict`` sizes of two models.

    Each state dict is ``torch.save``-d into an in-memory buffer, so no
    temporary file is written to the working directory.

    Returns:
        ``(fp32_size, quantized_size)`` in bytes.
    """
    import io

    def get_model_size(model):
        buffer = io.BytesIO()
        torch.save(model.state_dict(), buffer)
        return buffer.getbuffer().nbytes

    fp32_size = get_model_size(model_fp32)
    quantized_size = get_model_size(model_quantized)
    print(f"FP32 model size: {fp32_size / 1024 / 1024:.2f} MB")
    print(f"Quantized model size: {quantized_size / 1024 / 1024:.2f} MB")
    print(f"Compression ratio: {fp32_size / quantized_size:.2f}x")
    return fp32_size, quantized_size


#2. Model Pruning
import torch.nn.utils.prune as prune
def prune_model(model, pruning_ratio=0.2, layer_types=(nn.Conv2d, nn.Linear)):
    """Globally prune the smallest-magnitude weights of ``model`` in place.

    L1 magnitudes are ranked across *all* selected layers at once, so
    individual layers end up with different sparsities. The pruning is made
    permanent: masks are folded back into ``weight``.

    Args:
        model: Module to prune; modified in place.
        pruning_ratio: Fraction (float) or absolute count (int) of weights to
            remove, as accepted by ``prune.global_unstructured``.
        layer_types: Module classes whose ``weight`` is pruned.

    Returns:
        The same ``model`` object.
    """
    modules_to_prune = [
        (module, 'weight') for module in model.modules() if isinstance(module, layer_types)
    ]
    if not modules_to_prune:
        # global_unstructured cannot operate on an empty parameter list.
        return model
    prune.global_unstructured(
        modules_to_prune,
        pruning_method=prune.L1Unstructured,
        amount=pruning_ratio,
    )
    # Drop the weight_orig/weight_mask reparameterization, keeping the zeros.
    for module, param_name in modules_to_prune:
        prune.remove(module, param_name)
    return model
def calculate_sparsity(model):
    """Print and return the fraction of ``model``'s parameter elements that are exactly zero.

    Every parameter is counted (weights, biases, normalization affine terms),
    so parameters that were never pruned dilute the figure. A model with no
    parameters reports ``0.0``.
    """
    total_params = 0
    zero_params = 0
    for param in model.parameters():
        total_params += param.numel()
        zero_params += (param == 0).sum().item()
    sparsity = zero_params / total_params if total_params else 0.0
    print(f"Model sparsity: {sparsity:.2%}")
    return sparsity


#Debugging and Monitoring
#1. Training Monitoring
import wandb
from torch.utils.tensorboard import SummaryWriter
import time
class TrainingMonitor:
    """Send training metrics to Weights & Biases and/or TensorBoard.

    Backends are enabled with ``config['use_wandb']`` and
    ``config['use_tensorboard']`` (both default to False). The ``log_*``
    methods skip disabled backends.

    Args:
        project_name: wandb project name.
        experiment_name: wandb run name; TensorBoard logs go to
            ``runs/<experiment_name>``.
        config: Run configuration; also passed to ``wandb.init``.
    """

    def __init__(self, project_name, experiment_name, config):
        self.config = config
        # Whether this monitor started (and therefore owns) a wandb run.
        self.use_wandb = bool(config.get('use_wandb', False))
        if self.use_wandb:
            wandb.init(project=project_name, name=experiment_name, config=config)
        if config.get('use_tensorboard', False):
            self.writer = SummaryWriter(f'runs/{experiment_name}')
        self.metrics = {}
        self.start_time = time.time()

    def log_metrics(self, metrics, step, prefix=''):
        """Log scalar ``metrics`` at ``step``, named ``prefix/key`` when a prefix is given."""
        named = {(f"{prefix}/{key}" if prefix else key): value for key, value in metrics.items()}
        if not named:
            return
        if self.use_wandb and wandb.run:
            # A single call keeps all of this step's metrics in one wandb row.
            wandb.log(named, step=step)
        if hasattr(self, 'writer'):
            for metric_name, value in named.items():
                self.writer.add_scalar(metric_name, value, step)

    def log_model_graph(self, model, input_sample):
        """Write the model graph to TensorBoard (no-op when TensorBoard is disabled)."""
        if hasattr(self, 'writer'):
            self.writer.add_graph(model, input_sample)

    def log_gradients(self, model, step):
        """Write per-parameter gradient histograms and norms to TensorBoard."""
        if not hasattr(self, 'writer'):
            return
        for name, param in model.named_parameters():
            if param.grad is not None:
                self.writer.add_histogram(f'gradients/{name}', param.grad, step)
                self.writer.add_scalar(f'gradient_norms/{name}',
                                       param.grad.norm().item(), step)

    def log_learning_rate(self, optimizer, step):
        """Log each param group's ``lr`` as ``learning_rate/group_<i>`` to all enabled backends."""
        lrs = {f'group_{i}': group['lr'] for i, group in enumerate(optimizer.param_groups)}
        self.log_metrics(lrs, step, prefix='learning_rate')

    def close(self):
        """Close the TensorBoard writer and finish the wandb run this monitor started."""
        if hasattr(self, 'writer'):
            self.writer.close()
        if self.use_wandb and wandb.run:
            wandb.finish()


#2. Model Diagnostics
class ModelDiagnostics:
    """Static helpers that scan a model for unhealthy gradients, weights and activations."""

    @staticmethod
    def check_gradients(model, threshold=1e-7, explosion_threshold=100.0):
        """Return warning strings about each parameter's current ``.grad``.

        A gradient is flagged when its L2 norm exceeds ``explosion_threshold``
        or falls below ``threshold``, and separately when it contains NaN or
        Inf. Parameters without a gradient are skipped.
        """
        issues = []
        for name, param in model.named_parameters():
            grad = param.grad
            if grad is None:
                continue
            grad_norm = grad.norm().item()
            if grad_norm > explosion_threshold:
                issues.append(f"Gradient explosion: {name}, norm={grad_norm:.2f}")
            elif grad_norm < threshold:
                issues.append(f"Gradient vanishing: {name}, norm={grad_norm:.2e}")
            if torch.isnan(grad).any():
                issues.append(f"NaN gradient: {name}")
            if torch.isinf(grad).any():
                issues.append(f"Inf gradient: {name}")
        return issues

    @staticmethod
    def check_weights(model):
        """Return warning strings for parameters with NaN/Inf values or extreme spread.

        A parameter is flagged when its standard deviation is below 1e-6
        (which includes freshly zero-initialised biases) or above 10. The
        spread check is skipped for parameters with fewer than two elements,
        whose sample std is undefined.
        """
        issues = []
        # Read-only scan: don't build autograd history for the statistics.
        with torch.no_grad():
            for name, param in model.named_parameters():
                if torch.isnan(param).any():
                    issues.append(f"NaN weight: {name}")
                if torch.isinf(param).any():
                    issues.append(f"Inf weight: {name}")
                if param.numel() < 2:
                    continue
                weight_std = param.std().item()
                if weight_std < 1e-6:
                    issues.append(f"Too small weight variance: {name}, std={weight_std:.2e}")
                elif weight_std > 10:
                    issues.append(f"Too large weight variance: {name}, std={weight_std:.2f}")
        return issues

    @staticmethod
    def analyze_activations(model, input_data):
        """Run one no-grad forward pass and summarise every leaf module's output.

        Returns a dict mapping leaf-module name to ``mean``/``std``/``min``/
        ``max`` and ``has_nan``/``has_inf``. Only tensor outputs are recorded,
        and a module called several times keeps the stats of its last call.
        The pass runs in eval mode. Every module's previous train/eval flag is
        restored afterwards, and the hooks are removed even if the forward
        pass raises.
        """
        activations = {}

        def make_hook(name):
            def hook(module, inputs, output):
                if not isinstance(output, torch.Tensor):
                    return
                values = output.detach()
                if not values.is_floating_point():
                    # mean()/std() are not defined for integer or bool tensors.
                    values = values.float()
                activations[name] = {
                    'mean': values.mean().item(),
                    'std': values.std().item(),
                    'min': values.min().item(),
                    'max': values.max().item(),
                    'has_nan': torch.isnan(values).any().item(),
                    'has_inf': torch.isinf(values).any().item(),
                }
            return hook

        # Record per-module modes so mixed train/eval setups are restored exactly.
        saved_modes = {module: module.training for module in model.modules()}
        handles = [
            module.register_forward_hook(make_hook(name))
            for name, module in model.named_modules()
            if next(module.children(), None) is None  # leaf modules only
        ]
        try:
            model.eval()
            with torch.no_grad():
                model(input_data)
        finally:
            for handle in handles:
                handle.remove()
            for module, was_training in saved_modes.items():
                module.training = was_training
        return activations


#Performance Optimization
#1. Memory Optimization
import gc
import torch
class MemoryOptimizer:
    """Helpers for inspecting and reducing host/GPU memory use."""

    @staticmethod
    def clear_cache():
        """Run Python garbage collection, then release cached CUDA memory (if CUDA is available)."""
        # Collect first so blocks freed by GC can be returned by empty_cache().
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    @staticmethod
    def get_memory_usage():
        """Return a one-line summary: CUDA allocated/reserved GB, otherwise host RAM usage %."""
        if torch.cuda.is_available():
            allocated = torch.cuda.memory_allocated() / 1024**3  # GB
            cached = torch.cuda.memory_reserved() / 1024**3  # GB
            return f"GPU memory: allocated {allocated:.2f}GB, cached {cached:.2f}GB"
        import psutil  # only needed on hosts without CUDA
        memory = psutil.virtual_memory()
        return f"CPU memory: used {memory.percent:.1f}%"

    @staticmethod
    def optimize_dataloader_memory(dataloader):
        """Reduce an existing DataLoader's worker count and prefetching in place.

        Caps ``num_workers`` at 4 and, for multi-process loading, sets
        ``prefetch_factor`` to 1. The new values apply to iterators created
        afterwards. Returns the same ``dataloader``.
        """
        if dataloader.num_workers > 4:
            dataloader.num_workers = 4
        if dataloader.num_workers > 0:
            # prefetch_factor only matters when worker processes exist.
            dataloader.prefetch_factor = 1
        return dataloader

    @staticmethod
    def use_gradient_checkpointing(model):
        """Wrap every Transformer encoder/decoder layer's ``forward`` in activation checkpointing.

        Checkpointed layers drop intermediate activations in the forward pass
        and recompute them during backward, trading compute for memory.
        Checkpointing is applied only while autograd is recording; under
        ``torch.no_grad()`` the layers run normally. Modifies ``model`` in
        place and returns it.
        """
        from torch.utils.checkpoint import checkpoint

        def wrap(original_forward):
            # The factory binds `original_forward` per layer. A lambda in the
            # loop would late-bind to the last module and, by calling the
            # module itself, recurse into the patched forward forever.
            def checkpointed_forward(*args, **kwargs):
                if torch.is_grad_enabled():
                    # Non-reentrant mode forwards keyword arguments such as
                    # src_mask / src_key_padding_mask to the layer.
                    return checkpoint(original_forward, *args, use_reentrant=False, **kwargs)
                return original_forward(*args, **kwargs)
            return checkpointed_forward

        for module in model.modules():
            if isinstance(module, (nn.TransformerEncoderLayer, nn.TransformerDecoderLayer)):
                module.forward = wrap(module.forward)
        return model


#2. Computation Optimization
class ComputationOptimizer:
    """Helpers that trade flexibility for faster inference and kernels."""

    @staticmethod
    def optimize_model_for_inference(model, freeze: bool = False):
        """TorchScript ``model``, switch it to eval mode, and disable gradients.

        Scripting alone does not fuse layers. Pass ``freeze=True`` to also
        apply ``torch.jit.freeze``, which inlines parameters as constants and
        enables graph optimizations such as Conv-BatchNorm folding. A frozen
        module no longer exposes its submodules/parameters as attributes.

        Args:
            model: A scriptable ``nn.Module``.
            freeze: Whether to freeze the scripted module.

        Returns:
            The scripted (optionally frozen) module.
        """
        scripted = torch.jit.script(model)
        # Eval mode is required by torch.jit.freeze and matches inference use.
        scripted.eval()
        for param in scripted.parameters():
            param.requires_grad = False
        if freeze:
            scripted = torch.jit.freeze(scripted)
        return scripted

    @staticmethod
    def enable_cudnn_benchmark():
        """Let cuDNN auto-tune convolution algorithms (non-deterministic; CUDA only)."""
        if torch.cuda.is_available():
            torch.backends.cudnn.benchmark = True
            torch.backends.cudnn.deterministic = False

    @staticmethod
    def compile_model(model, mode='default'):
        """Return ``torch.compile(model, mode=mode)`` when available (PyTorch 2.0+), else ``model``."""
        if hasattr(torch, 'compile'):
            return torch.compile(model, mode=mode)
        return model


#Summary
PyTorch best practices span every stage of a deep learning project:
- Code Organization: Clear project structure and modular design
- Data Processing: Efficient data loading and preprocessing pipelines
- Training Optimization: Advanced techniques like mixed precision and gradient accumulation
- Model Optimization: Model compression methods like quantization and pruning
- Debugging and Monitoring: Comprehensive training monitoring and model diagnostic tools
- Performance Optimization: Effective use of memory and computing resources
Following these best practices will help you build more efficient and reliable deep learning systems!