
PyTorch Best Practices

Code Organization and Project Structure

project/
├── data/                   # Data directory
│   ├── raw/               # Raw data
│   ├── processed/         # Processed data
│   └── external/          # External data
├── models/                # Model definitions
│   ├── __init__.py
│   ├── base_model.py      # Base model class
│   ├── resnet.py          # Specific model implementation
│   └── transformer.py
├── src/                   # Source code
│   ├── __init__.py
│   ├── data/              # Data processing
│   │   ├── __init__.py
│   │   ├── dataset.py
│   │   └── transforms.py
│   ├── training/          # Training related
│   │   ├── __init__.py
│   │   ├── trainer.py
│   │   └── losses.py
│   └── utils/             # Utility functions
│       ├── __init__.py
│       ├── metrics.py
│       └── visualization.py
├── configs/               # Configuration files
│   ├── base_config.yaml
│   └── experiment_configs/
├── experiments/           # Experiment records
├── notebooks/             # Jupyter notebooks
├── tests/                 # Test code
├── requirements.txt       # Dependencies
├── setup.py              # Installation script
└── README.md             # Project description

2. Base Model Class Design

python
import torch
import torch.nn as nn
from abc import ABC, abstractmethod
from typing import Dict, Any, Optional

class BaseModel(nn.Module, ABC):
    """Abstract base class for config-driven models.

    Subclasses implement ``_build_model`` (invoked from ``__init__``) and
    ``forward``. Provides parameter counting, freezing/unfreezing by module
    name, and checkpoint save/load helpers.
    """

    def __init__(self, config: Dict[str, Any]):
        super().__init__()  # zero-arg super is the modern Python 3 form
        self.config = config
        self._build_model()

    @abstractmethod
    def _build_model(self):
        """Build the model architecture from ``self.config``."""

    @abstractmethod
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass."""

    def get_num_parameters(self) -> int:
        """Return the total parameter count."""
        return sum(p.numel() for p in self.parameters())

    def get_num_trainable_parameters(self) -> int:
        """Return the count of parameters with ``requires_grad=True``."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)

    def freeze_parameters(self, module_names: Optional[list] = None):
        """Freeze all parameters, or only modules whose qualified name
        contains any entry of ``module_names`` (substring match)."""
        self._set_requires_grad(False, module_names)

    def unfreeze_parameters(self, module_names: Optional[list] = None):
        """Unfreeze all parameters, or only matching modules (see
        :meth:`freeze_parameters`)."""
        self._set_requires_grad(True, module_names)

    def _set_requires_grad(self, flag: bool, module_names: Optional[list]):
        """Shared implementation for freeze/unfreeze (removes the two
        previously duplicated loops)."""
        if module_names is None:
            for param in self.parameters():
                param.requires_grad = flag
        else:
            for name, module in self.named_modules():
                if any(target in name for target in module_names):
                    for param in module.parameters():
                        param.requires_grad = flag

    def save_checkpoint(self, filepath: str, epoch: int, optimizer_state: Dict = None,
                        scheduler_state: Dict = None, **kwargs):
        """Save model state, config and optional optimizer/scheduler state.

        Bug fix: states are compared against ``None`` instead of truthiness,
        so a legitimately empty state dict is no longer silently dropped.
        """
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': self.state_dict(),
            'config': self.config,
            'num_parameters': self.get_num_parameters(),
            **kwargs
        }

        if optimizer_state is not None:
            checkpoint['optimizer_state_dict'] = optimizer_state
        if scheduler_state is not None:
            checkpoint['scheduler_state_dict'] = scheduler_state

        torch.save(checkpoint, filepath)

    @classmethod
    def load_checkpoint(cls, filepath: str, map_location: str = 'cpu'):
        """Rebuild a model from a checkpoint; returns ``(model, checkpoint)``.

        NOTE(review): ``torch.load`` deserializes arbitrary objects — load
        trusted files only. PyTorch >= 2.6 defaults to ``weights_only=True``,
        which may reject the embedded config dict; confirm against the torch
        version in use.
        """
        checkpoint = torch.load(filepath, map_location=map_location)
        model = cls(checkpoint['config'])
        model.load_state_dict(checkpoint['model_state_dict'])
        return model, checkpoint

# Specific model implementation example
class ResNetClassifier(BaseModel):
    """ResNet-18 backbone with a classification head sized from config.

    Config keys: ``num_classes`` (required), ``pretrained`` (default True).
    """

    def _build_model(self):
        from torchvision.models import resnet18

        use_pretrained = self.config.get('pretrained', True)
        try:
            # torchvision >= 0.13: the `pretrained` flag is deprecated in
            # favor of explicit weight enums.
            from torchvision.models import ResNet18_Weights
            weights = ResNet18_Weights.DEFAULT if use_pretrained else None
            self.backbone = resnet18(weights=weights)
        except ImportError:
            # Older torchvision: fall back to the legacy flag.
            self.backbone = resnet18(pretrained=use_pretrained)

        # Replace the final FC layer to match the task's class count.
        self.backbone.fc = nn.Linear(
            self.backbone.fc.in_features,
            self.config['num_classes']
        )

    def forward(self, x):
        return self.backbone(x)

Data Processing Best Practices

1. Efficient Data Loading

python
import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np
from typing import Tuple, List, Optional
import multiprocessing as mp

class OptimizedDataset(Dataset):
    """Dataset with a bounded in-memory cache of loaded samples.

    Raw samples are cached *before* the transform runs, so random
    augmentations stay random across epochs. When the cache is full, the
    least-frequently-accessed entry is evicted.

    NOTE(review): with DataLoader ``num_workers > 0`` every worker process
    holds its own copy of the cache; subclasses must also provide
    ``__len__`` for map-style use — confirm both against the caller.
    """

    def __init__(self, data_path: str, transform=None, cache_size: int = 1000):
        self.data_path = data_path
        self.transform = transform
        self.cache_size = cache_size
        self.cache = {}
        self.access_count = {}

        # Read the sample index once up front instead of on every access.
        self._load_index()

    def _load_index(self):
        """Load data index to avoid reading files every time."""
        # Implement specific index loading logic
        pass

    def __getitem__(self, idx):
        # EAFP: a cache hit bumps the access counter; a miss loads and caches.
        try:
            sample = self.cache[idx]
            self.access_count[idx] += 1
        except KeyError:
            sample = self._load_data(idx)
            self._update_cache(idx, sample)

        return self.transform(sample) if self.transform else sample

    def _load_data(self, idx):
        """Load single data sample"""
        # Implement specific data loading logic
        pass

    def _update_cache(self, idx, data):
        """Insert ``data`` for ``idx``, evicting the least-used entry when full."""
        if len(self.cache) >= self.cache_size:
            evict = min(self.access_count, key=self.access_count.get)
            del self.cache[evict]
            del self.access_count[evict]

        self.cache[idx] = data
        self.access_count[idx] = 1

def create_optimized_dataloader(dataset, batch_size: int, num_workers: Optional[int] = None,
                               pin_memory: bool = True, persistent_workers: bool = True):
    """Create a DataLoader with throughput-oriented defaults.

    Args:
        dataset: map-style dataset to iterate.
        batch_size: samples per batch (incomplete final batch is dropped).
        num_workers: worker processes; defaults to ``min(8, cpu_count)``.
        pin_memory: pin host memory (only honored when CUDA is available).
        persistent_workers: keep workers alive between epochs (only valid
            with ``num_workers > 0``).

    Fix: the original ternary ``2 if num_workers > 0 else 2`` always yielded
    2, and passing ``prefetch_factor`` while ``num_workers == 0`` makes
    recent PyTorch raise "prefetch_factor option could only be specified in
    multiprocessing". The option is now supplied only when workers exist.
    """
    if num_workers is None:
        num_workers = min(8, mp.cpu_count())

    loader_kwargs = dict(
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=pin_memory and torch.cuda.is_available(),
        persistent_workers=persistent_workers and num_workers > 0,
        drop_last=True,  # keep batch size consistent across steps
    )
    if num_workers > 0:
        # Only valid (and useful) when worker processes do the loading.
        loader_kwargs['prefetch_factor'] = 2

    return DataLoader(dataset, **loader_kwargs)

2. Data Preprocessing Pipeline

python
import torchvision.transforms as transforms
from typing import Union, List

class DataPreprocessor:
    """Builds train/validation torchvision transform pipelines from config.

    Required config key: ``normalize`` (dict with ``mean`` and ``std``).
    Optional keys: ``resize``, ``random_crop``, ``random_horizontal_flip``
    (a flip probability), ``color_jitter``, ``random_erasing``,
    ``center_crop``.

    Fix: the constructor annotation used ``Dict[str, Any]``, which is not
    imported in this module and raised NameError at class-definition time;
    the builtin ``dict`` is used instead.
    """

    def __init__(self, config: dict):
        self.config = config
        self.train_transform = self._build_train_transform()
        self.val_transform = self._build_val_transform()

    def _build_train_transform(self):
        """Compose the augmenting transform pipeline used for training."""
        transforms_list = []

        # Geometric resize first, while the input is still a PIL image.
        if self.config.get('resize'):
            transforms_list.append(transforms.Resize(self.config['resize']))

        # PIL-space data augmentation
        if self.config.get('random_crop'):
            transforms_list.append(
                transforms.RandomCrop(
                    self.config['random_crop']['size'],
                    padding=self.config['random_crop'].get('padding', 4)
                )
            )

        if self.config.get('random_horizontal_flip'):
            # The config value is used as the flip probability.
            transforms_list.append(
                transforms.RandomHorizontalFlip(
                    p=self.config['random_horizontal_flip']
                )
            )

        if self.config.get('color_jitter'):
            transforms_list.append(
                transforms.ColorJitter(**self.config['color_jitter'])
            )

        # Convert to tensor and normalize
        transforms_list.extend([
            transforms.ToTensor(),
            transforms.Normalize(
                mean=self.config['normalize']['mean'],
                std=self.config['normalize']['std']
            )
        ])

        # RandomErasing operates on tensors, so it must follow ToTensor.
        if self.config.get('random_erasing'):
            transforms_list.append(
                transforms.RandomErasing(**self.config['random_erasing'])
            )

        return transforms.Compose(transforms_list)

    def _build_val_transform(self):
        """Compose the deterministic transform pipeline used for validation."""
        transforms_list = []

        if self.config.get('resize'):
            transforms_list.append(transforms.Resize(self.config['resize']))

        if self.config.get('center_crop'):
            transforms_list.append(transforms.CenterCrop(self.config['center_crop']))

        transforms_list.extend([
            transforms.ToTensor(),
            transforms.Normalize(
                mean=self.config['normalize']['mean'],
                std=self.config['normalize']['std']
            )
        ])

        return transforms.Compose(transforms_list)

Training Optimization Techniques

1. Mixed Precision Training

python
from torch.cuda.amp import GradScaler, autocast
import torch.nn.functional as F

class MixedPrecisionTrainer:
    """Trainer that runs forward passes under automatic mixed precision.

    Wraps a model/optimizer/criterion triple and performs loss scaling via
    ``GradScaler`` so FP16 gradients do not underflow.
    """

    def __init__(self, model, optimizer, criterion, device):
        self.model = model
        self.optimizer = optimizer
        self.criterion = criterion
        self.device = device
        self.scaler = GradScaler()

    def train_step(self, data, target):
        """Run one optimization step; returns the (unscaled) loss value."""
        inputs = data.to(self.device)
        labels = target.to(self.device)

        self.optimizer.zero_grad()

        # Forward under autocast so eligible ops run in half precision.
        with autocast():
            predictions = self.model(inputs)
            loss = self.criterion(predictions, labels)

        # Backpropagate on the scaled loss to avoid FP16 underflow.
        self.scaler.scale(loss).backward()

        # Unscale before clipping so the norm threshold sees true gradients.
        self.scaler.unscale_(self.optimizer)
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)

        # Apply the update and adjust the scale factor for the next step.
        self.scaler.step(self.optimizer)
        self.scaler.update()

        return loss.item()

    def validate_step(self, data, target):
        """Evaluate one batch without gradients; returns (loss, output)."""
        inputs = data.to(self.device)
        labels = target.to(self.device)

        with torch.no_grad(), autocast():
            predictions = self.model(inputs)
            loss = self.criterion(predictions, labels)

        return loss.item(), predictions

2. Gradient Accumulation

python
class GradientAccumulationTrainer:
    """Trainer that accumulates gradients over several mini-batches to
    emulate a larger effective batch size.

    Fix: the original only stepped the optimizer every
    ``accumulation_steps`` batches, so when the number of batches was not a
    multiple of ``accumulation_steps`` the final partial window's gradients
    were silently discarded. They are now flushed at the end of the epoch.
    """

    def __init__(self, model, optimizer, criterion, device, accumulation_steps=4):
        self.model = model
        self.optimizer = optimizer
        self.criterion = criterion
        self.device = device
        self.accumulation_steps = accumulation_steps

    def train_epoch(self, dataloader):
        """Train one epoch; returns the mean (unscaled) per-batch loss.

        ``dataloader`` must be sized (``len`` is used for averaging).
        """
        self.model.train()
        total_loss = 0.0
        self.optimizer.zero_grad()

        pending = False  # True while gradients are accumulated but unapplied
        for batch_idx, (data, target) in enumerate(dataloader):
            data, target = data.to(self.device), target.to(self.device)

            output = self.model(data)
            loss = self.criterion(output, target)

            # Divide so accumulated gradients match one large batch.
            (loss / self.accumulation_steps).backward()

            total_loss += loss.item()
            pending = True

            # Apply an update every accumulation_steps batches.
            if (batch_idx + 1) % self.accumulation_steps == 0:
                self._apply_step()
                pending = False

        # Flush leftover gradients from a final partial window.
        if pending:
            self._apply_step()

        return total_loss / len(dataloader)

    def _apply_step(self):
        """Clip accumulated gradients, step the optimizer, reset gradients."""
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
        self.optimizer.step()
        self.optimizer.zero_grad()

Model Optimization and Deployment

1. Model Quantization

python
import torch.quantization as quantization

def quantize_model(model, calibration_dataloader, device):
    """Post-training static quantization via the eager-mode API.

    Inserts observers, calibrates them on ``calibration_dataloader``, and
    converts the model to a quantized version.

    NOTE(review): the 'fbgemm' backend targets x86 CPUs and the converted
    model runs on CPU; if ``device`` is a CUDA device while the prepared
    model stays on CPU, calibration inputs would mismatch — confirm
    ``device`` matches the model's device. Conventionally ``model.eval()``
    (and any Conv/BN fusion) happens *before* ``prepare``; verify this
    ordering suits your model.
    """
    # Set quantization configuration
    model.qconfig = quantization.get_default_qconfig('fbgemm')
    
    # Prepare for quantization
    model_prepared = quantization.prepare(model, inplace=False)
    
    # Calibrate: run representative data through so observers record ranges
    model_prepared.eval()
    with torch.no_grad():
        for data, _ in calibration_dataloader:
            data = data.to(device)
            model_prepared(data)
    
    # Convert to quantized model
    model_quantized = quantization.convert(model_prepared, inplace=False)
    
    return model_quantized

def compare_model_sizes(model_fp32, model_quantized):
    """Print the serialized size of both models and their compression ratio.

    Fixes: the original used the ``os`` module without importing it and
    wrote a hard-coded ``temp.p`` file into the current working directory;
    it now serializes into a proper temporary file and always cleans up.
    """
    import os
    import tempfile

    def get_model_size(model):
        # Serialize the state dict to a temp file and measure its size.
        fd, path = tempfile.mkstemp(suffix=".pt")
        os.close(fd)
        try:
            torch.save(model.state_dict(), path)
            return os.path.getsize(path)
        finally:
            os.remove(path)

    fp32_size = get_model_size(model_fp32)
    quantized_size = get_model_size(model_quantized)

    print(f"FP32 model size: {fp32_size / 1024 / 1024:.2f} MB")
    print(f"Quantized model size: {quantized_size / 1024 / 1024:.2f} MB")
    print(f"Compression ratio: {fp32_size / quantized_size:.2f}x")

2. Model Pruning

python
import torch.nn.utils.prune as prune

def prune_model(model, pruning_ratio=0.2):
    """Apply global L1 unstructured pruning to all Conv2d/Linear weights.

    Zeroes out the ``pruning_ratio`` fraction of weights with the smallest
    absolute value across every eligible layer, then makes the pruning
    permanent by removing the reparameterization masks.
    """
    prunable = [
        (module, 'weight')
        for module in model.modules()
        if isinstance(module, (nn.Conv2d, nn.Linear))
    ]

    prune.global_unstructured(
        prunable,
        pruning_method=prune.L1Unstructured,
        amount=pruning_ratio,
    )

    # Bake the masks into the weights so the model is a plain module again.
    for module, parameter_name in prunable:
        prune.remove(module, parameter_name)

    return model

def calculate_sparsity(model):
    """Compute and print the fraction of exactly-zero parameters."""
    counts = [(p.numel(), int((p == 0).sum().item())) for p in model.parameters()]
    total_params = sum(total for total, _ in counts)
    zero_params = sum(zeros for _, zeros in counts)

    sparsity = zero_params / total_params
    print(f"Model sparsity: {sparsity:.2%}")
    return sparsity

Debugging and Monitoring

1. Training Monitoring

python
import wandb
from torch.utils.tensorboard import SummaryWriter
import time

class TrainingMonitor:
    """Logs training metrics to Weights & Biases and/or TensorBoard.

    Bug fix: wandb logging was gated on ``hasattr(self, 'wandb')``, an
    attribute that was never assigned anywhere, so metrics never reached
    wandb even with ``use_wandb`` enabled. An explicit ``use_wandb`` flag
    is now tracked instead.
    """

    def __init__(self, project_name, experiment_name, config):
        self.config = config
        self.use_wandb = bool(config.get('use_wandb', False))

        # Initialize wandb
        if self.use_wandb:
            wandb.init(project=project_name, name=experiment_name, config=config)

        # Initialize tensorboard (self.writer only exists when enabled)
        if config.get('use_tensorboard', False):
            self.writer = SummaryWriter(f'runs/{experiment_name}')

        self.metrics = {}
        self.start_time = time.time()

    def log_metrics(self, metrics, step, prefix=''):
        """Log a dict of scalar metrics at ``step``, optionally name-prefixed."""
        for key, value in metrics.items():
            metric_name = f"{prefix}/{key}" if prefix else key

            # Log to wandb (fixed: this branch was previously unreachable)
            if self.use_wandb and wandb.run:
                wandb.log({metric_name: value}, step=step)

            # Log to tensorboard
            if hasattr(self, 'writer'):
                self.writer.add_scalar(metric_name, value, step)

    def log_model_graph(self, model, input_sample):
        """Write the model graph to TensorBoard, if enabled."""
        if hasattr(self, 'writer'):
            self.writer.add_graph(model, input_sample)

    def log_gradients(self, model, step):
        """Log a histogram and norm of every parameter gradient (TensorBoard)."""
        if hasattr(self, 'writer'):
            for name, param in model.named_parameters():
                if param.grad is not None:
                    self.writer.add_histogram(f'gradients/{name}', param.grad, step)
                    self.writer.add_scalar(f'gradient_norms/{name}', 
                                         param.grad.norm().item(), step)

    def log_learning_rate(self, optimizer, step):
        """Log the current learning rate of every param group (TensorBoard)."""
        for i, param_group in enumerate(optimizer.param_groups):
            lr = param_group['lr']
            if hasattr(self, 'writer'):
                self.writer.add_scalar(f'learning_rate/group_{i}', lr, step)

    def close(self):
        """Flush and close all logging backends."""
        if hasattr(self, 'writer'):
            self.writer.close()

        if wandb.run:
            wandb.finish()

2. Model Diagnostics

python
class ModelDiagnostics:
    """Static helpers for diagnosing gradient, weight and activation health."""

    @staticmethod
    def check_gradients(model, threshold=1e-7):
        """Scan parameter gradients for explosion, vanishing, NaN and Inf.

        Returns a list of human-readable issue strings (empty when healthy).
        """
        issues = []

        for name, param in model.named_parameters():
            grad = param.grad
            if grad is None:
                continue

            norm_value = grad.norm().item()

            # Exploding vs. vanishing are mutually exclusive per parameter.
            if norm_value > 100:
                issues.append(f"Gradient explosion: {name}, norm={norm_value:.2f}")
            elif norm_value < threshold:
                issues.append(f"Gradient vanishing: {name}, norm={norm_value:.2e}")

            # Non-finite values are reported independently of the norm checks.
            if torch.isnan(grad).any():
                issues.append(f"NaN gradient: {name}")
            if torch.isinf(grad).any():
                issues.append(f"Inf gradient: {name}")

        return issues

    @staticmethod
    def check_weights(model):
        """Scan parameters for NaN/Inf values and degenerate variance."""
        issues = []

        for name, param in model.named_parameters():
            if torch.isnan(param).any():
                issues.append(f"NaN weight: {name}")
            if torch.isinf(param).any():
                issues.append(f"Inf weight: {name}")

            # Flag collapsed or blown-up weight distributions.
            weight_std = param.std().item()
            if weight_std < 1e-6:
                issues.append(f"Too small weight variance: {name}, std={weight_std:.2e}")
            elif weight_std > 10:
                issues.append(f"Too large weight variance: {name}, std={weight_std:.2f}")

        return issues

    @staticmethod
    def analyze_activations(model, input_data):
        """Record summary statistics of every leaf module's output.

        Runs one gradient-free forward pass with temporary hooks and returns
        a dict keyed by module name.
        """
        stats = {}

        def make_hook(layer_name):
            def hook(module, inputs, output):
                if isinstance(output, torch.Tensor):
                    stats[layer_name] = {
                        'mean': output.mean().item(),
                        'std': output.std().item(),
                        'min': output.min().item(),
                        'max': output.max().item(),
                        'has_nan': torch.isnan(output).any().item(),
                        'has_inf': torch.isinf(output).any().item()
                    }
            return hook

        # Hook only leaf modules (those without children).
        handles = [
            module.register_forward_hook(make_hook(name))
            for name, module in model.named_modules()
            if next(module.children(), None) is None
        ]

        model.eval()
        with torch.no_grad():
            model(input_data)

        # Always detach the hooks afterwards.
        for handle in handles:
            handle.remove()

        return stats

Performance Optimization

1. Memory Optimization

python
import gc
import torch

class MemoryOptimizer:
    """Memory optimization tools"""
    
    @staticmethod
    def clear_cache():
        """Clear GPU cache"""
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        gc.collect()
    
    @staticmethod
    def get_memory_usage():
        """Get memory usage"""
        if torch.cuda.is_available():
            allocated = torch.cuda.memory_allocated() / 1024**3  # GB
            cached = torch.cuda.memory_reserved() / 1024**3     # GB
            return f"GPU memory: allocated {allocated:.2f}GB, cached {cached:.2f}GB"
        else:
            import psutil
            memory = psutil.virtual_memory()
            return f"CPU memory: used {memory.percent:.1f}%"
    
    @staticmethod
    def optimize_dataloader_memory(dataloader):
        """Optimize data loader memory usage"""
        # Reduce prefetch factor
        dataloader.prefetch_factor = 1
        
        # Use fewer workers
        if dataloader.num_workers > 4:
            dataloader.num_workers = 4
        
        return dataloader
    
    @staticmethod
    def use_gradient_checkpointing(model):
        """Use gradient checkpointing"""
        from torch.utils.checkpoint import checkpoint
        
        # Add checkpoint functionality to model
        def checkpoint_forward(module, input):
            return checkpoint(module, input)
        
        # Apply to specified layers
        for name, module in model.named_modules():
            if isinstance(module, (nn.TransformerEncoderLayer, nn.TransformerDecoderLayer)):
                module.forward = lambda x: checkpoint_forward(module, x)
        
        return model

2. Computation Optimization

python
class ComputationOptimizer:
    """Static helpers for speeding up model execution."""

    @staticmethod
    def optimize_model_for_inference(model):
        """Script the model with TorchScript and freeze it for inference.

        Note: ``torch.jit.script`` compiles the model; it does not by
        itself fuse Conv/BatchNorm pairs.
        """
        scripted = torch.jit.script(model)

        # Switch to eval mode and detach all parameters from autograd.
        scripted.eval()
        for weight in scripted.parameters():
            weight.requires_grad = False

        return scripted

    @staticmethod
    def enable_cudnn_benchmark():
        """Let cuDNN auto-tune kernels (best with fixed input shapes;
        trades determinism for speed). No-op without CUDA."""
        if not torch.cuda.is_available():
            return
        torch.backends.cudnn.benchmark = True
        torch.backends.cudnn.deterministic = False

    @staticmethod
    def compile_model(model, mode='default'):
        """Wrap the model with ``torch.compile`` when available (PyTorch 2+);
        returns the model unchanged on older versions."""
        if not hasattr(torch, 'compile'):
            return model
        return torch.compile(model, mode=mode)

Summary

PyTorch best practices cover all aspects of deep learning projects:

  1. Code Organization: Clear project structure and modular design
  2. Data Processing: Efficient data loading and preprocessing pipelines
  3. Training Optimization: Advanced techniques like mixed precision and gradient accumulation
  4. Model Optimization: Model compression methods like quantization and pruning
  5. Debugging and Monitoring: Complete training monitoring and model diagnostic tools
  6. Performance Optimization: Effective use of memory and computing resources

Following these best practices will help you build more efficient and reliable deep learning systems!

Content is for learning and research only.