PyTorch Model Training and Validation

Training Process Overview

Training a deep learning model is an iterative optimization process: each step consists of a forward pass, a loss calculation, a backward pass, and a parameter update. PyTorch provides flexible tools for implementing this loop.

python
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

# Basic training loop structure
def basic_training_loop(model, dataloader, criterion, optimizer, num_epochs):
    for epoch in range(num_epochs):
        for batch_idx, (data, target) in enumerate(dataloader):
            # 1. Forward pass
            output = model(data)
            
            # 2. Calculate loss
            loss = criterion(output, target)
            
            # 3. Zero gradients
            optimizer.zero_grad()
            
            # 4. Backward pass
            loss.backward()
            
            # 5. Update parameters
            optimizer.step()

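The skeleton above expects the model, loss, optimizer, data loader, and epoch count to be supplied by the caller. A minimal, self-contained way to exercise it (the toy data and sizes below are illustrative, not part of the original example):

python
# Hypothetical setup for the skeleton above -- a toy classification task
model = nn.Linear(10, 2)                       # tiny stand-in model
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)

# Fake dataset: 64 samples, 10 features, 2 classes
X = torch.randn(64, 10)
y = torch.randint(0, 2, (64,))
dataloader = DataLoader(torch.utils.data.TensorDataset(X, y),
                        batch_size=16, shuffle=True)

basic_training_loop(model, dataloader, criterion, optimizer, num_epochs=5)
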
Complete Training Framework

1. Training Function

python
from tqdm import tqdm

def train_epoch(model, dataloader, criterion, optimizer, device, epoch):
    """Train one epoch"""
    model.train()  # Set to training mode
    
    running_loss = 0.0
    correct = 0
    total = 0
    
    # Progress bar
    pbar = tqdm(dataloader, desc=f'Epoch {epoch+1}')
    
    for batch_idx, (data, target) in enumerate(pbar):
        # Move data to device
        data, target = data.to(device), target.to(device)
        
        # Zero gradients
        optimizer.zero_grad()
        
        # Forward pass
        output = model(data)
        loss = criterion(output, target)
        
        # Backward pass
        loss.backward()
        
        # Gradient clipping (optional)
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        
        # Update parameters
        optimizer.step()
        
        # Statistics
        running_loss += loss.item()
        pred = output.argmax(dim=1, keepdim=True)
        correct += pred.eq(target.view_as(pred)).sum().item()
        total += target.size(0)
        
        # Update progress bar
        pbar.set_postfix({
            'Loss': f'{loss.item():.4f}',
            'Acc': f'{100.*correct/total:.2f}%'
        })
    
    epoch_loss = running_loss / len(dataloader)
    epoch_acc = 100. * correct / total
    
    return epoch_loss, epoch_acc

2. Validation Function

python
def validate_epoch(model, dataloader, criterion, device):
    """Validate model"""
    model.eval()  # Set to evaluation mode
    
    running_loss = 0.0
    correct = 0
    total = 0
    
    with torch.no_grad():  # Disable gradient computation
        for data, target in dataloader:
            data, target = data.to(device), target.to(device)
            
            output = model(data)
            loss = criterion(output, target)
            
            running_loss += loss.item()
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()
            total += target.size(0)
    
    val_loss = running_loss / len(dataloader)
    val_acc = 100. * correct / total
    
    return val_loss, val_acc

3. Complete Training Process

python
import os

class Trainer:
    def __init__(self, model, train_loader, val_loader, criterion, optimizer, 
                 device, save_dir='./checkpoints'):
        self.model = model
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.criterion = criterion
        self.optimizer = optimizer
        self.device = device
        self.save_dir = save_dir
        
        # Create save directory
        os.makedirs(save_dir, exist_ok=True)
        
        # Training history
        self.train_losses = []
        self.train_accs = []
        self.val_losses = []
        self.val_accs = []
        
        # Best model tracking
        self.best_val_acc = 0.0
        self.best_epoch = 0
    
    def train(self, num_epochs, scheduler=None, early_stopping=None):
        """Complete training process"""
        print(f"Starting training, {num_epochs} epochs")
        print(f"Device: {self.device}")
        print(f"Training set size: {len(self.train_loader.dataset)}")
        print(f"Validation set size: {len(self.val_loader.dataset)}")
        print("-" * 50)
        
        for epoch in range(num_epochs):
            # Train
            train_loss, train_acc = train_epoch(
                self.model, self.train_loader, self.criterion, 
                self.optimizer, self.device, epoch
            )
            
            # Validate
            val_loss, val_acc = validate_epoch(
                self.model, self.val_loader, self.criterion, self.device
            )
            
            # Update learning rate
            if scheduler:
                if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
                    scheduler.step(val_loss)
                else:
                    scheduler.step()
            
            # Record history
            self.train_losses.append(train_loss)
            self.train_accs.append(train_acc)
            self.val_losses.append(val_loss)
            self.val_accs.append(val_acc)
            
            # Print results
            current_lr = self.optimizer.param_groups[0]['lr']
            print(f'Epoch {epoch+1}/{num_epochs}:')
            print(f'  Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%')
            print(f'  Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%')
            print(f'  Learning Rate: {current_lr:.6f}')
            
            # Save best model
            if val_acc > self.best_val_acc:
                self.best_val_acc = val_acc
                self.best_epoch = epoch
                self.save_checkpoint(epoch, is_best=True)
                print(f'  ✓ New best model! Validation accuracy: {val_acc:.2f}%')
            
            # Regular checkpoint saving
            if (epoch + 1) % 10 == 0:
                self.save_checkpoint(epoch)
            
            # Early stopping check
            if early_stopping:
                if early_stopping(val_loss, self.model):
                    print(f'Early stopping triggered at epoch {epoch+1}')
                    break
            
            print("-" * 50)
        
        print(f'Training complete! Best validation accuracy: {self.best_val_acc:.2f}% (Epoch {self.best_epoch+1})')
        
        # Load best model
        self.load_best_model()
        
        return self.train_losses, self.train_accs, self.val_losses, self.val_accs
    
    def save_checkpoint(self, epoch, is_best=False):
        """Save checkpoint"""
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'train_losses': self.train_losses,
            'train_accs': self.train_accs,
            'val_losses': self.val_losses,
            'val_accs': self.val_accs,
            'best_val_acc': self.best_val_acc,
            'best_epoch': self.best_epoch
        }
        
        # Save current checkpoint
        checkpoint_path = os.path.join(self.save_dir, f'checkpoint_epoch_{epoch+1}.pth')
        torch.save(checkpoint, checkpoint_path)
        
        # Save best model
        if is_best:
            best_path = os.path.join(self.save_dir, 'best_model.pth')
            torch.save(checkpoint, best_path)
    
    def load_best_model(self):
        """Load best model"""
        best_path = os.path.join(self.save_dir, 'best_model.pth')
        if os.path.exists(best_path):
            checkpoint = torch.load(best_path, map_location=self.device)
            self.model.load_state_dict(checkpoint['model_state_dict'])
            print(f"Loaded best model (Epoch {checkpoint['best_epoch']+1})")

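Resuming an interrupted run is the mirror image of `save_checkpoint`. A minimal sketch, assuming a checkpoint file written by the Trainer above and a freshly constructed model and optimizer of the same shape:

python
# Resume from a saved checkpoint (the path is illustrative)
checkpoint = torch.load('./checkpoints/checkpoint_epoch_10.pth',
                        map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
start_epoch = checkpoint['epoch'] + 1  # continue from the next epoch
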
Training Techniques and Optimization

1. Learning Rate Scheduling

python
from torch.optim.lr_scheduler import (StepLR, MultiStepLR, CosineAnnealingLR,
                                      ReduceLROnPlateau, CosineAnnealingWarmRestarts)

def create_scheduler(optimizer, scheduler_type='cosine', **kwargs):
    """Create learning rate scheduler"""
    if scheduler_type == 'step':
        return StepLR(optimizer, step_size=kwargs.get('step_size', 30), 
                     gamma=kwargs.get('gamma', 0.1))
    
    elif scheduler_type == 'multistep':
        return MultiStepLR(optimizer, milestones=kwargs.get('milestones', [30, 60, 90]), 
                          gamma=kwargs.get('gamma', 0.1))
    
    elif scheduler_type == 'cosine':
        return CosineAnnealingLR(optimizer, T_max=kwargs.get('T_max', 100))
    
    elif scheduler_type == 'plateau':
        return ReduceLROnPlateau(optimizer, mode='min', factor=0.5, 
                               patience=kwargs.get('patience', 10))
    
    elif scheduler_type == 'cosine_restarts':
        # Cosine annealing with warm *restarts* (periodic LR resets);
        # this is not a linear warmup schedule
        return CosineAnnealingWarmRestarts(optimizer, T_0=kwargs.get('T_0', 10))
    
    else:
        raise ValueError(f"Unsupported scheduler type: {scheduler_type}")

# Usage example
scheduler = create_scheduler(optimizer, 'cosine', T_max=100)

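Because every PyTorch scheduler exposes `get_last_lr()`, a schedule can be previewed against a dummy optimizer before committing to a long run. A small sketch:

python
# Preview the first 100 LR values of the cosine schedule (dummy optimizer)
dummy_opt = optim.SGD([torch.zeros(1, requires_grad=True)], lr=0.1)
preview = create_scheduler(dummy_opt, 'cosine', T_max=100)

lrs = []
for _ in range(100):
    lrs.append(preview.get_last_lr()[0])
    dummy_opt.step()   # step the optimizer before the scheduler
    preview.step()

print(f"LR starts at {lrs[0]:.4f} and decays to {lrs[-1]:.6f}")
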
2. Early Stopping Mechanism

python
class EarlyStopping:
    def __init__(self, patience=7, min_delta=0, restore_best_weights=True, verbose=True):
        self.patience = patience
        self.min_delta = min_delta
        self.restore_best_weights = restore_best_weights
        self.verbose = verbose
        
        self.best_loss = None
        self.counter = 0
        self.best_weights = None
    
    def __call__(self, val_loss, model):
        if self.best_loss is None:
            self.best_loss = val_loss
            self.save_checkpoint(model)
        elif val_loss < self.best_loss - self.min_delta:
            self.best_loss = val_loss
            self.counter = 0
            self.save_checkpoint(model)
            if self.verbose:
                print(f'Validation loss improved to {val_loss:.6f}')
        else:
            self.counter += 1
            if self.verbose:
                print(f'Validation loss did not improve ({self.counter}/{self.patience})')
        
        if self.counter >= self.patience:
            if self.restore_best_weights:
                model.load_state_dict(self.best_weights)
                if self.verbose:
                    print('Restored best weights')
            return True
        return False
    
    def save_checkpoint(self, model):
        # state_dict() holds references to live tensors; clone them so the
        # saved weights are not mutated by subsequent optimizer steps
        self.best_weights = {k: v.detach().clone() for k, v in model.state_dict().items()}

# Use early stopping
early_stopping = EarlyStopping(patience=10, min_delta=0.001)

3. Gradient Accumulation

python
def train_with_gradient_accumulation(model, dataloader, criterion, optimizer,
                                     device, accumulation_steps=4):
    """Training with gradient accumulation"""
    model.train()
    optimizer.zero_grad()
    
    running_loss = 0.0
    
    for batch_idx, (data, target) in enumerate(dataloader):
        data, target = data.to(device), target.to(device)
        
        # Forward pass
        output = model(data)
        loss = criterion(output, target)
        
        # Scale loss
        loss = loss / accumulation_steps
        
        # Backward pass
        loss.backward()
        
        running_loss += loss.item() * accumulation_steps
        
        # Update parameters every accumulation_steps steps (and on the
        # final batch, so leftover gradients are not dropped)
        if (batch_idx + 1) % accumulation_steps == 0 or (batch_idx + 1) == len(dataloader):
            # Gradient clipping
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            
            # Update parameters
            optimizer.step()
            optimizer.zero_grad()
    
    return running_loss / len(dataloader)

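Gradient accumulation makes the effective batch size `batch_size × accumulation_steps`: with a per-step batch of 32 and `accumulation_steps=4`, each update behaves roughly like a batch of 128. Usage follows the same pattern as the other training functions (names assumed defined as in the earlier sections):

python
# Effective batch size = dataloader batch size * accumulation_steps
avg_loss = train_with_gradient_accumulation(
    model, train_loader, criterion, optimizer, device, accumulation_steps=4
)
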
4. Mixed Precision Training

python
# Note: newer PyTorch releases prefer the device-generic torch.amp namespace
# (torch.amp.autocast / torch.amp.GradScaler); torch.cuda.amp still works
from torch.cuda.amp import GradScaler, autocast

def train_with_mixed_precision(model, dataloader, criterion, optimizer, device):
    """Mixed precision training"""
    model.train()
    scaler = GradScaler()
    
    running_loss = 0.0
    
    for data, target in dataloader:
        data, target = data.to(device), target.to(device)
        
        optimizer.zero_grad()
        
        # Use autocast for forward pass
        with autocast():
            output = model(data)
            loss = criterion(output, target)
        
        # Scale loss and backward
        scaler.scale(loss).backward()
        
        # Gradient clipping
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        
        # Update parameters
        scaler.step(optimizer)
        scaler.update()
        
        running_loss += loss.item()
    
    return running_loss / len(dataloader)

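The function above uses fp16 with a loss scaler. On GPUs with bfloat16 support (Ampere and newer), a common alternative is `torch.autocast` with `dtype=torch.bfloat16`, which usually needs no `GradScaler` because bf16 keeps fp32's exponent range. A sketch under that hardware assumption:

python
def train_with_bf16(model, dataloader, criterion, optimizer, device):
    """bfloat16 training sketch -- assumes a GPU with bf16 support"""
    model.train()
    running_loss = 0.0
    for data, target in dataloader:
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        # bf16 has fp32's exponent range, so no GradScaler is needed
        with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
            output = model(data)
            loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    return running_loss / len(dataloader)
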
Model Evaluation

1. Classification Metrics

python
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, confusion_matrix
import numpy as np

def evaluate_classification(model, dataloader, device, num_classes):
    """Evaluate classification model"""
    model.eval()
    
    all_preds = []
    all_targets = []
    
    with torch.no_grad():
        for data, target in dataloader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            pred = output.argmax(dim=1)
            
            all_preds.extend(pred.cpu().numpy())
            all_targets.extend(target.cpu().numpy())
    
    # Calculate metrics
    accuracy = accuracy_score(all_targets, all_preds)
    precision, recall, f1, _ = precision_recall_fscore_support(
        all_targets, all_preds, average='weighted', zero_division=0
    )
    
    # Confusion matrix (fix the label set so every class appears,
    # even if it is missing from this evaluation run)
    cm = confusion_matrix(all_targets, all_preds, labels=list(range(num_classes)))
    
    print(f"Accuracy: {accuracy:.4f}")
    print(f"Precision: {precision:.4f}")
    print(f"Recall: {recall:.4f}")
    print(f"F1 Score: {f1:.4f}")
    
    return {
        'accuracy': accuracy,
        'precision': precision,
        'recall': recall,
        'f1': f1,
        'confusion_matrix': cm
    }

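The returned confusion matrix also yields per-class accuracy directly: row i holds the true counts for class i, and the diagonal holds the correct predictions. A short sketch (the loader and class count are placeholders):

python
# Per-class accuracy from the confusion matrix
metrics = evaluate_classification(model, val_loader, device, num_classes=10)
cm = metrics['confusion_matrix']
per_class_acc = cm.diagonal() / cm.sum(axis=1)  # row i = true class i
for i, acc in enumerate(per_class_acc):
    print(f"Class {i}: {acc:.2%}")
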
2. Visualize Training Process

python
import matplotlib.pyplot as plt

def plot_training_history(train_losses, train_accs, val_losses, val_accs):
    """Plot training history"""
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))
    
    # Loss curves
    ax1.plot(train_losses, label='Training Loss', color='blue')
    ax1.plot(val_losses, label='Validation Loss', color='red')
    ax1.set_title('Loss Curves')
    ax1.set_xlabel('Epoch')
    ax1.set_ylabel('Loss')
    ax1.legend()
    ax1.grid(True)
    
    # Accuracy curves
    ax2.plot(train_accs, label='Training Accuracy', color='blue')
    ax2.plot(val_accs, label='Validation Accuracy', color='red')
    ax2.set_title('Accuracy Curves')
    ax2.set_xlabel('Epoch')
    ax2.set_ylabel('Accuracy (%)')
    ax2.legend()
    ax2.grid(True)
    
    plt.tight_layout()
    plt.show()

def plot_confusion_matrix(cm, class_names):
    """Plot confusion matrix"""
    plt.figure(figsize=(10, 8))
    plt.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
    plt.title('Confusion Matrix')
    plt.colorbar()
    
    tick_marks = np.arange(len(class_names))
    plt.xticks(tick_marks, class_names, rotation=45)
    plt.yticks(tick_marks, class_names)
    
    # Add numerical labels
    thresh = cm.max() / 2.
    for i, j in np.ndindex(cm.shape):
        plt.text(j, i, format(cm[i, j], 'd'),
                horizontalalignment="center",
                color="white" if cm[i, j] > thresh else "black")
    
    plt.ylabel('True Label')
    plt.xlabel('Predicted Label')
    plt.tight_layout()
    plt.show()

Practical Application Example

1. CIFAR-10 Image Classification

python
import torchvision
import torchvision.transforms as transforms

# Data preparation
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True, 
                                       download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, 
                                         shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root='./data', train=False, 
                                      download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=100, 
                                         shuffle=False, num_workers=2)

# Model definition
class SimpleCNN(nn.Module):
    def __init__(self, num_classes=10):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 64, 3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),
            
            nn.Conv2d(64, 128, 3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),
            
            nn.Conv2d(128, 256, 3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),
        )
        
        self.classifier = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten(),
            nn.Linear(256, 512),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(512, num_classes)
        )
    
    def forward(self, x):
        x = self.features(x)
        x = self.classifier(x)
        return x

# Training setup
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = SimpleCNN(num_classes=10).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=0.001, weight_decay=0.01)
scheduler = create_scheduler(optimizer, 'cosine', T_max=100)

# Create trainer
trainer = Trainer(model, trainloader, testloader, criterion, optimizer, device)

# Start training
train_losses, train_accs, val_losses, val_accs = trainer.train(
    num_epochs=100, 
    scheduler=scheduler,
    early_stopping=EarlyStopping(patience=15)
)

# Visualize results
plot_training_history(train_losses, train_accs, val_losses, val_accs)

# Final evaluation
class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer', 
               'dog', 'frog', 'horse', 'ship', 'truck']
metrics = evaluate_classification(model, testloader, device, 10)
plot_confusion_matrix(metrics['confusion_matrix'], class_names)

Debugging and Troubleshooting

1. Common Issue Diagnosis

python
def diagnose_training_issues(model, dataloader, criterion, optimizer, device):
    """Diagnose training issues"""
    model.train()
    
    # Check data
    data_batch, target_batch = next(iter(dataloader))
    print(f"Data shape: {data_batch.shape}")
    print(f"Target shape: {target_batch.shape}")
    print(f"Data range: [{data_batch.min():.3f}, {data_batch.max():.3f}]")
    print(f"Target range: [{target_batch.min()}, {target_batch.max()}]")
    
    # Check model output
    data_batch = data_batch.to(device)
    output = model(data_batch)
    print(f"Model output shape: {output.shape}")
    print(f"Output range: [{output.min():.3f}, {output.max():.3f}]")
    
    # Check loss
    target_batch = target_batch.to(device)
    loss = criterion(output, target_batch)
    print(f"Initial loss: {loss.item():.4f}")
    
    # Check gradients (clear any stale gradients first)
    model.zero_grad()
    loss.backward()
    total_norm = 0
    param_count = 0
    for name, param in model.named_parameters():
        if param.grad is not None:
            param_norm = param.grad.data.norm(2)
            total_norm += param_norm.item() ** 2
            param_count += 1
            if param_norm > 10:  # Large gradient warning
                print(f"Warning: {name} gradient norm too large: {param_norm:.4f}")
    
    total_norm = total_norm ** (1. / 2)
    print(f"Total gradient norm: {total_norm:.4f}")
    print(f"Parameter count: {param_count}")
    
    # Check learning rate
    print(f"Current learning rate: {optimizer.param_groups[0]['lr']:.6f}")

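A complementary sanity check, not part of the function above but common practice, is to overfit a single batch: a correctly wired model and loss should drive the loss close to zero within a few hundred steps, and failure to do so usually points to a bug in the model, loss, or labels. A minimal sketch:

python
def overfit_single_batch(model, dataloader, criterion, optimizer, device, steps=200):
    """Sanity check: loss on one fixed batch should approach zero"""
    model.train()
    data, target = next(iter(dataloader))
    data, target = data.to(device), target.to(device)
    for step in range(steps):
        optimizer.zero_grad()
        loss = criterion(model(data), target)
        loss.backward()
        optimizer.step()
        if (step + 1) % 50 == 0:
            print(f"Step {step+1}: loss = {loss.item():.4f}")
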
2. Performance Monitoring

python
import time
import psutil

try:
    import GPUtil  # optional dependency for GPU monitoring
except ImportError:
    GPUtil = None

class PerformanceMonitor:
    def __init__(self):
        self.start_time = None
        self.epoch_times = []
    
    def start_epoch(self):
        self.start_time = time.time()
    
    def end_epoch(self):
        if self.start_time:
            epoch_time = time.time() - self.start_time
            self.epoch_times.append(epoch_time)
            return epoch_time
        return 0
    
    def get_system_info(self):
        # CPU usage
        cpu_percent = psutil.cpu_percent()
        
        # Memory usage
        memory = psutil.virtual_memory()
        memory_percent = memory.percent
        memory_used = memory.used / (1024**3)  # GB
        
        # GPU usage (if available)
        gpu_info = []
        if GPUtil is not None:
            try:
                for gpu in GPUtil.getGPUs():
                    gpu_info.append({
                        'id': gpu.id,
                        'name': gpu.name,
                        'load': gpu.load * 100,
                        'memory_used': gpu.memoryUsed,
                        'memory_total': gpu.memoryTotal,
                        'temperature': gpu.temperature
                    })
            except Exception:
                pass  # no NVIDIA GPU or driver unavailable
        
        return {
            'cpu_percent': cpu_percent,
            'memory_percent': memory_percent,
            'memory_used_gb': memory_used,
            'gpu_info': gpu_info
        }
    
    def print_performance_summary(self, total_epochs=None):
        if self.epoch_times:
            avg_time = sum(self.epoch_times) / len(self.epoch_times)
            print(f"Average epoch time: {avg_time:.2f} seconds")
            if total_epochs:
                remaining_epochs = total_epochs - len(self.epoch_times)
                print(f"Estimated remaining time: {avg_time * remaining_epochs:.2f} seconds")
        
        system_info = self.get_system_info()
        print(f"CPU usage: {system_info['cpu_percent']:.1f}%")
        print(f"Memory usage: {system_info['memory_percent']:.1f}%")
        
        for gpu in system_info['gpu_info']:
            print(f"GPU {gpu['id']} ({gpu['name']}): "
                  f"Load {gpu['load']:.1f}%, "
                  f"Memory {gpu['memory_used']}/{gpu['memory_total']}MB, "
                  f"Temperature {gpu['temperature']}°C")

# Use performance monitoring
monitor = PerformanceMonitor()

# In training loop
for epoch in range(num_epochs):
    monitor.start_epoch()
    
    # Training code...
    
    epoch_time = monitor.end_epoch()
    print(f"Epoch {epoch+1} took: {epoch_time:.2f} seconds")
    
    if (epoch + 1) % 10 == 0:
        monitor.print_performance_summary(total_epochs=num_epochs)

Summary

Model training is the core of deep learning, and mastering it requires:

  1. Training Process: Understanding the complete flow of forward pass, loss calculation, backward pass, and parameter updates
  2. Training Techniques: Optimization techniques like learning rate scheduling, early stopping, gradient accumulation, and mixed precision training
  3. Model Evaluation: Using appropriate metrics to evaluate model performance
  4. Visualization Analysis: Analyzing training process and results through charts
  5. Debugging Skills: Diagnosing and solving common problems in training

Mastering these skills will help you train high-quality deep learning models!
