PyTorch Model Training and Validation
Training Process Overview
Training deep learning models is an iterative optimization process that includes forward pass, loss calculation, backward pass, and parameter updates. PyTorch provides flexible tools to implement this process.
python
import os

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
# Basic training loop structure
def basic_training_loop(model, dataloader, criterion, optimizer, num_epochs=1):
    """Run the canonical PyTorch training loop.

    Bug fix: the original took no arguments and referenced ``model``,
    ``dataloader`` etc. as undefined globals, so calling it raised NameError.
    All collaborators are now explicit parameters.

    Args:
        model: the ``nn.Module`` to train (modified in place).
        dataloader: iterable of ``(data, target)`` mini-batches.
        criterion: loss function ``(output, target) -> scalar tensor``.
        optimizer: optimizer wrapping ``model``'s parameters.
        num_epochs: number of passes over ``dataloader``.

    Returns:
        List of average per-batch losses, one entry per epoch.
    """
    epoch_losses = []
    for epoch in range(num_epochs):
        running = 0.0
        for batch_idx, (data, target) in enumerate(dataloader):
            # 1. Forward pass
            output = model(data)
            # 2. Calculate loss
            loss = criterion(output, target)
            # 3. Zero gradients
            optimizer.zero_grad()
            # 4. Backward pass
            loss.backward()
            # 5. Update parameters
            optimizer.step()
            running += loss.item()
        # Guard against an empty dataloader to avoid ZeroDivisionError.
        epoch_losses.append(running / max(len(dataloader), 1))
    return epoch_losses

# Complete Training Framework
1. Training Function
python
def train_epoch(model, dataloader, criterion, optimizer, device, epoch):
    """Train one epoch.

    Args:
        model: ``nn.Module`` to train (set to train mode here).
        dataloader: iterable of ``(data, target)`` mini-batches.
        criterion: loss function.
        optimizer: optimizer wrapping ``model``'s parameters.
        device: device string/object the batches are moved to.
        epoch: 0-based epoch index (only used for the progress-bar label).

    Returns:
        ``(epoch_loss, epoch_acc)`` — mean batch loss and accuracy in percent.
    """
    model.train()  # Set to training mode
    running_loss = 0.0
    correct = 0
    total = 0
    # Progress bar. Bug fix: the original hard-imported tqdm and crashed when
    # it was not installed; fall back to the bare dataloader instead.
    try:
        from tqdm import tqdm
        pbar = tqdm(dataloader, desc=f'Epoch {epoch+1}')
        iterator = pbar
    except ImportError:
        pbar = None
        iterator = dataloader
    for batch_idx, (data, target) in enumerate(iterator):
        # Move data to device
        data, target = data.to(device), target.to(device)
        # Zero gradients
        optimizer.zero_grad()
        # Forward pass
        output = model(data)
        loss = criterion(output, target)
        # Backward pass
        loss.backward()
        # Gradient clipping (optional; guards against exploding gradients)
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        # Update parameters
        optimizer.step()
        # Statistics
        running_loss += loss.item()
        pred = output.argmax(dim=1, keepdim=True)
        correct += pred.eq(target.view_as(pred)).sum().item()
        total += target.size(0)
        # Update progress bar (only when tqdm is available)
        if pbar is not None:
            pbar.set_postfix({
                'Loss': f'{loss.item():.4f}',
                'Acc': f'{100.*correct/total:.2f}%'
            })
    epoch_loss = running_loss / len(dataloader)
    epoch_acc = 100. * correct / total
    return epoch_loss, epoch_acc

# 2. Validation Function
python
def validate_epoch(model, dataloader, criterion, device):
    """Evaluate the model on a validation loader.

    Returns ``(val_loss, val_acc)``: mean batch loss and accuracy in percent.
    """
    model.eval()  # evaluation mode: disables dropout, freezes batch-norm stats
    total_loss = 0.0
    n_correct = 0
    n_seen = 0
    # No gradients needed for evaluation — saves memory and time.
    with torch.no_grad():
        for inputs, labels in dataloader:
            inputs = inputs.to(device)
            labels = labels.to(device)
            logits = model(inputs)
            total_loss += criterion(logits, labels).item()
            predictions = logits.argmax(dim=1, keepdim=True)
            n_correct += predictions.eq(labels.view_as(predictions)).sum().item()
            n_seen += labels.size(0)
    return total_loss / len(dataloader), 100. * n_correct / n_seen

# 3. Complete Training Process
python
class Trainer:
    """Orchestrates the train/validate loop, checkpointing and history tracking.

    Relies on the module-level ``train_epoch``/``validate_epoch`` helpers and
    on ``os`` (imported at the top of the file) for checkpoint paths.
    """

    def __init__(self, model, train_loader, val_loader, criterion, optimizer,
                 device, save_dir='./checkpoints'):
        self.model = model
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.criterion = criterion
        self.optimizer = optimizer
        self.device = device
        self.save_dir = save_dir
        # Create save directory
        os.makedirs(save_dir, exist_ok=True)
        # Training history (one entry per completed epoch)
        self.train_losses = []
        self.train_accs = []
        self.val_losses = []
        self.val_accs = []
        # Best model tracking (by validation accuracy)
        self.best_val_acc = 0.0
        self.best_epoch = 0

    def train(self, num_epochs, scheduler=None, early_stopping=None):
        """Complete training process.

        Args:
            num_epochs: maximum number of epochs to run.
            scheduler: optional LR scheduler; ``ReduceLROnPlateau`` is stepped
                with the validation loss, all others unconditionally.
            early_stopping: optional callable ``(val_loss, model) -> bool``
                that returns True when training should stop.

        Returns:
            ``(train_losses, train_accs, val_losses, val_accs)`` history lists.
        """
        print(f"Starting training, {num_epochs} epochs")
        print(f"Device: {self.device}")
        print(f"Training set size: {len(self.train_loader.dataset)}")
        print(f"Validation set size: {len(self.val_loader.dataset)}")
        print("-" * 50)
        for epoch in range(num_epochs):
            # Train
            train_loss, train_acc = train_epoch(
                self.model, self.train_loader, self.criterion,
                self.optimizer, self.device, epoch
            )
            # Validate
            val_loss, val_acc = validate_epoch(
                self.model, self.val_loader, self.criterion, self.device
            )
            # Update learning rate (plateau scheduler needs the metric)
            if scheduler:
                if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
                    scheduler.step(val_loss)
                else:
                    scheduler.step()
            # Record history
            self.train_losses.append(train_loss)
            self.train_accs.append(train_acc)
            self.val_losses.append(val_loss)
            self.val_accs.append(val_acc)
            # Print results
            current_lr = self.optimizer.param_groups[0]['lr']
            print(f'Epoch {epoch+1}/{num_epochs}:')
            print(f' Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%')
            print(f' Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%')
            print(f' Learning Rate: {current_lr:.6f}')
            # Save best model
            if val_acc > self.best_val_acc:
                self.best_val_acc = val_acc
                self.best_epoch = epoch
                self.save_checkpoint(epoch, is_best=True)
                print(f' ✓ New best model! Validation accuracy: {val_acc:.2f}%')
            # Regular checkpoint saving every 10 epochs
            if (epoch + 1) % 10 == 0:
                self.save_checkpoint(epoch)
            # Early stopping check
            if early_stopping:
                if early_stopping(val_loss, self.model):
                    print(f'Early stopping triggered, stopping training at epoch {epoch+1}')
                    break
        print("-" * 50)
        print(f'Training complete! Best validation accuracy: {self.best_val_acc:.2f}% (Epoch {self.best_epoch+1})')
        # Load best model so callers get the best weights, not the last ones
        self.load_best_model()
        return self.train_losses, self.train_accs, self.val_losses, self.val_accs

    def save_checkpoint(self, epoch, is_best=False):
        """Save a checkpoint (model/optimizer state plus full history)."""
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'train_losses': self.train_losses,
            'train_accs': self.train_accs,
            'val_losses': self.val_losses,
            'val_accs': self.val_accs,
            'best_val_acc': self.best_val_acc,
            'best_epoch': self.best_epoch
        }
        # Save current checkpoint
        checkpoint_path = os.path.join(self.save_dir, f'checkpoint_epoch_{epoch+1}.pth')
        torch.save(checkpoint, checkpoint_path)
        # Save best model under a stable name
        if is_best:
            best_path = os.path.join(self.save_dir, 'best_model.pth')
            torch.save(checkpoint, best_path)

    def load_best_model(self):
        """Load the best model's weights from disk, if a best checkpoint exists."""
        best_path = os.path.join(self.save_dir, 'best_model.pth')
        if os.path.exists(best_path):
            checkpoint = torch.load(best_path, map_location=self.device)
            self.model.load_state_dict(checkpoint['model_state_dict'])
            print(f"Loaded best model (Epoch {checkpoint['best_epoch']+1})")

# Training Techniques and Optimization
1. Learning Rate Scheduling
python
# Explicit imports instead of the original wildcard `import *`, which pollutes
# the module namespace and hides which names are actually used.
from torch.optim.lr_scheduler import (StepLR, MultiStepLR, CosineAnnealingLR,
                                      ReduceLROnPlateau,
                                      CosineAnnealingWarmRestarts)

def create_scheduler(optimizer, scheduler_type='cosine', **kwargs):
    """Factory for common learning-rate schedulers.

    Args:
        optimizer: the optimizer the scheduler will drive.
        scheduler_type: one of 'step', 'multistep', 'cosine', 'plateau',
            'warmup_cosine'.
        **kwargs: scheduler-specific options (step_size, gamma, milestones,
            T_max, patience, T_0) with sensible defaults.

    Raises:
        ValueError: for an unknown ``scheduler_type``.
    """
    if scheduler_type == 'step':
        return StepLR(optimizer, step_size=kwargs.get('step_size', 30),
                      gamma=kwargs.get('gamma', 0.1))
    elif scheduler_type == 'multistep':
        return MultiStepLR(optimizer, milestones=kwargs.get('milestones', [30, 60, 90]),
                           gamma=kwargs.get('gamma', 0.1))
    elif scheduler_type == 'cosine':
        return CosineAnnealingLR(optimizer, T_max=kwargs.get('T_max', 100))
    elif scheduler_type == 'plateau':
        return ReduceLROnPlateau(optimizer, mode='min', factor=0.5,
                                 patience=kwargs.get('patience', 10))
    elif scheduler_type == 'warmup_cosine':
        # NOTE(review): this is cosine annealing *with restarts*, not a true
        # linear warmup — key name kept for backward compatibility.
        return CosineAnnealingWarmRestarts(optimizer, T_0=kwargs.get('T_0', 10))
    else:
        raise ValueError(f"Unsupported scheduler type: {scheduler_type}")

# Usage example. Bug fix: the original executed this at import time, but
# `optimizer` is only defined further down the file, so it raised NameError.
# scheduler = create_scheduler(optimizer, 'cosine', T_max=100)

# 2. Early Stopping Mechanism
python
class EarlyStopping:
    """Stop training when the validation loss stops improving.

    Call the instance once per epoch with the current validation loss; it
    returns True when training should stop. With ``restore_best_weights=True``
    the model is rolled back to its best-seen weights on stop.
    """

    def __init__(self, patience=7, min_delta=0, restore_best_weights=True, verbose=True):
        self.patience = patience              # epochs to tolerate without improvement
        self.min_delta = min_delta            # minimum decrease that counts as improvement
        self.restore_best_weights = restore_best_weights
        self.verbose = verbose
        self.best_loss = None                 # best validation loss seen so far
        self.counter = 0                      # epochs since last improvement
        self.best_weights = None              # deep snapshot of the best state_dict

    def __call__(self, val_loss, model):
        """Record this epoch's loss; return True if patience is exhausted."""
        if self.best_loss is None:
            # First call: everything is the best so far.
            self.best_loss = val_loss
            self.save_checkpoint(model)
        elif val_loss < self.best_loss - self.min_delta:
            self.best_loss = val_loss
            self.counter = 0
            self.save_checkpoint(model)
            if self.verbose:
                print(f'Validation loss improved to {val_loss:.6f}')
        else:
            self.counter += 1
            if self.verbose:
                print(f'Validation loss not improved ({self.counter}/{self.patience})')
            if self.counter >= self.patience:
                if self.restore_best_weights:
                    model.load_state_dict(self.best_weights)
                    if self.verbose:
                        print('Restored best weights')
                return True
        return False

    def save_checkpoint(self, model):
        """Snapshot the current weights.

        Bug fix: the original used ``model.state_dict().copy()``, a *shallow*
        copy whose tensors alias the live parameters — the "best" weights
        silently changed as training continued and restore was a no-op.
        Clone every tensor so the snapshot is independent of further updates.
        """
        self.best_weights = {k: v.detach().clone()
                             for k, v in model.state_dict().items()}

# Use early stopping
early_stopping = EarlyStopping(patience=10, min_delta=0.001)

# 3. Gradient Accumulation
python
def train_with_gradient_accumulation(model, dataloader, criterion, optimizer,
                                     device, accumulation_steps=4):
    """Training with gradient accumulation.

    Accumulates gradients over ``accumulation_steps`` mini-batches before each
    optimizer step, simulating a larger effective batch size.

    Returns:
        The average (unscaled) loss per batch.
    """
    model.train()
    optimizer.zero_grad()
    running_loss = 0.0
    pending = False  # True while gradients are accumulated but not yet applied
    for batch_idx, (data, target) in enumerate(dataloader):
        data, target = data.to(device), target.to(device)
        # Forward pass
        output = model(data)
        loss = criterion(output, target)
        # Scale loss so the accumulated gradient matches a full-batch average
        loss = loss / accumulation_steps
        # Backward pass (gradients add up across batches)
        loss.backward()
        pending = True
        running_loss += loss.item() * accumulation_steps
        # Update parameters every accumulation_steps batches
        if (batch_idx + 1) % accumulation_steps == 0:
            # Gradient clipping
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            # Update parameters
            optimizer.step()
            optimizer.zero_grad()
            pending = False
    # Bug fix: flush the final partial accumulation — the original silently
    # dropped the last len(dataloader) % accumulation_steps batches.
    if pending:
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        optimizer.zero_grad()
    return running_loss / len(dataloader)

# 4. Mixed Precision Training
python
from torch.cuda.amp import GradScaler, autocast
def train_with_mixed_precision(model, dataloader, criterion, optimizer, device,
                               scaler=None):
    """Mixed precision training for one epoch.

    Args:
        scaler: optional ``GradScaler`` to reuse across epochs. Bug fix: the
            original constructed a fresh scaler on every call, discarding the
            learned loss scale each epoch; pass one in to preserve its state
            (omitting it keeps the old per-call behavior).

    Returns:
        The average loss per batch.
    """
    model.train()
    if scaler is None:
        scaler = GradScaler()
    running_loss = 0.0
    for data, target in dataloader:
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        # Use autocast for the forward pass (fp16 where safe)
        with autocast():
            output = model(data)
            loss = criterion(output, target)
        # Scale loss and backward to avoid fp16 gradient underflow
        scaler.scale(loss).backward()
        # Gradient clipping — must unscale first so the norm is in true units
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        # Update parameters (skipped automatically on inf/nan gradients)
        scaler.step(optimizer)
        scaler.update()
        running_loss += loss.item()
    return running_loss / len(dataloader)

# Model Evaluation
1. Classification Metrics
python
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, confusion_matrix
import numpy as np
def evaluate_classification(model, dataloader, device, num_classes):
    """Evaluate a classifier: accuracy, weighted precision/recall/F1 and
    confusion matrix. Returns the metrics in a dict."""
    model.eval()
    predictions = []
    ground_truth = []
    # Collect hard predictions for the whole loader without gradients.
    with torch.no_grad():
        for inputs, labels in dataloader:
            inputs, labels = inputs.to(device), labels.to(device)
            logits = model(inputs)
            predictions.extend(logits.argmax(dim=1).cpu().numpy())
            ground_truth.extend(labels.cpu().numpy())
    # Aggregate metrics (weighted average over classes).
    accuracy = accuracy_score(ground_truth, predictions)
    precision, recall, f1, _ = precision_recall_fscore_support(
        ground_truth, predictions, average='weighted'
    )
    # Confusion matrix
    cm = confusion_matrix(ground_truth, predictions)
    print(f"Accuracy: {accuracy:.4f}")
    print(f"Precision: {precision:.4f}")
    print(f"Recall: {recall:.4f}")
    print(f"F1 Score: {f1:.4f}")
    return {
        'accuracy': accuracy,
        'precision': precision,
        'recall': recall,
        'f1': f1,
        'confusion_matrix': cm
    }

# 2. Visualize Training Process
python
import matplotlib.pyplot as plt
def plot_training_history(train_losses, train_accs, val_losses, val_accs):
    """Plot training history: loss and accuracy curves side by side."""
    fig, (loss_ax, acc_ax) = plt.subplots(1, 2, figsize=(15, 5))
    # Loss curves
    loss_ax.plot(train_losses, label='Training Loss', color='blue')
    loss_ax.plot(val_losses, label='Validation Loss', color='red')
    loss_ax.set_title('Loss Curves')
    loss_ax.set_xlabel('Epoch')
    loss_ax.set_ylabel('Loss')
    loss_ax.legend()
    loss_ax.grid(True)
    # Accuracy curves
    acc_ax.plot(train_accs, label='Training Accuracy', color='blue')
    acc_ax.plot(val_accs, label='Validation Accuracy', color='red')
    acc_ax.set_title('Accuracy Curves')
    acc_ax.set_xlabel('Epoch')
    acc_ax.set_ylabel('Accuracy (%)')
    acc_ax.legend()
    acc_ax.grid(True)
    plt.tight_layout()
    plt.show()
def plot_confusion_matrix(cm, class_names):
    """Render a confusion matrix as an annotated heat map."""
    plt.figure(figsize=(10, 8))
    plt.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
    plt.title('Confusion Matrix')
    plt.colorbar()
    ticks = np.arange(len(class_names))
    plt.xticks(ticks, class_names, rotation=45)
    plt.yticks(ticks, class_names)
    # Annotate each cell; choose the text color that contrasts with the cell.
    thresh = cm.max() / 2.
    for row, col in np.ndindex(cm.shape):
        plt.text(col, row, format(cm[row, col], 'd'),
                 horizontalalignment="center",
                 color="white" if cm[row, col] > thresh else "black")
    plt.ylabel('True Label')
    plt.xlabel('Predicted Label')
    plt.tight_layout()
    plt.show()

# Practical Application Example
1. CIFAR-10 Image Classification
python
import torchvision
import torchvision.transforms as transforms
# Data preparation
# Train-time augmentation (random crop + horizontal flip) followed by
# normalization with the standard per-channel CIFAR-10 mean/std.
transform_train = transforms.Compose([
transforms.RandomCrop(32, padding=4),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])
# Test-time preprocessing: normalization only, no augmentation.
transform_test = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])
# CIFAR-10 train split; downloads to ./data on first use.
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128,
shuffle=True, num_workers=2)
# Held-out test split, evaluated without shuffling.
testset = torchvision.datasets.CIFAR10(root='./data', train=False,
download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=100,
shuffle=False, num_workers=2)
# Model definition
class SimpleCNN(nn.Module):
    """Three conv blocks (conv-BN-ReLU-maxpool) followed by an MLP head.

    Input 3x32x32 images; each block halves the spatial resolution.
    """

    def __init__(self, num_classes=10):
        super(SimpleCNN, self).__init__()
        # Build the feature extractor as identical blocks with growing width.
        layers = []
        in_channels = 3
        for out_channels in (64, 128, 256):
            layers += [
                nn.Conv2d(in_channels, out_channels, 3, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
                nn.MaxPool2d(2, 2),
            ]
            in_channels = out_channels
        self.features = nn.Sequential(*layers)
        # Global average pool to 1x1, then a small MLP classifier.
        self.classifier = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten(),
            nn.Linear(256, 512),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(512, num_classes)
        )

    def forward(self, x):
        # Feature extractor then classification head.
        return self.classifier(self.features(x))
# Training setup
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = SimpleCNN(num_classes=10).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=0.001, weight_decay=0.01)
scheduler = create_scheduler(optimizer, 'cosine', T_max=100)

# Create trainer
trainer = Trainer(model, trainloader, testloader, criterion, optimizer, device)

# Start training (early stopping halts after 15 stagnant epochs)
train_losses, train_accs, val_losses, val_accs = trainer.train(
    num_epochs=100,
    scheduler=scheduler,
    early_stopping=EarlyStopping(patience=15)
)

# Visualize results
plot_training_history(train_losses, train_accs, val_losses, val_accs)

# Final evaluation on the CIFAR-10 test split
class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer',
               'dog', 'frog', 'horse', 'ship', 'truck']
metrics = evaluate_classification(model, testloader, device, 10)
plot_confusion_matrix(metrics['confusion_matrix'], class_names)

# Debugging and Troubleshooting
1. Common Issue Diagnosis
python
def diagnose_training_issues(model, dataloader, criterion, optimizer, device):
    """One-batch sanity check: shapes, value ranges, initial loss and
    per-parameter gradient norms.

    Side effect: leaves the diagnostic batch's gradients on the model.
    """
    model.train()
    # Check data
    data_batch, target_batch = next(iter(dataloader))
    print(f"Data shape: {data_batch.shape}")
    print(f"Target shape: {target_batch.shape}")
    print(f"Data range: [{data_batch.min():.3f}, {data_batch.max():.3f}]")
    print(f"Target range: [{target_batch.min()}, {target_batch.max()}]")
    # Check model output
    data_batch = data_batch.to(device)
    output = model(data_batch)
    print(f"Model output shape: {output.shape}")
    print(f"Output range: [{output.min():.3f}, {output.max():.3f}]")
    # Check loss
    target_batch = target_batch.to(device)
    loss = criterion(output, target_batch)
    print(f"Initial loss: {loss.item():.4f}")
    # Check gradients. Bug fix: clear any stale gradients first so the norms
    # below reflect only this batch (the original accumulated into old grads).
    optimizer.zero_grad()
    loss.backward()
    total_norm = 0
    param_count = 0
    for name, param in model.named_parameters():
        if param.grad is not None:
            # .item() turns the 0-dim norm tensor into a plain float.
            param_norm = param.grad.data.norm(2).item()
            total_norm += param_norm ** 2
            param_count += 1
            if param_norm > 10:  # Large gradient warning
                print(f"Warning: {name} gradient norm too large: {param_norm:.4f}")
    total_norm = total_norm ** (1. / 2)
    print(f"Total gradient norm: {total_norm:.4f}")
    print(f"Parameter count: {param_count}")
    # Check learning rate
    print(f"Current learning rate: {optimizer.param_groups[0]['lr']:.6f}")

# 2. Performance Monitoring
python
import time
import psutil
import GPUtil
class PerformanceMonitor:
    """Tracks per-epoch wall-clock time and reports system resource usage."""

    def __init__(self):
        self.start_time = None   # set by start_epoch()
        self.epoch_times = []    # seconds per completed epoch

    def start_epoch(self):
        """Mark the beginning of an epoch."""
        self.start_time = time.time()

    def end_epoch(self):
        """Record and return the elapsed epoch time (0 if never started)."""
        if self.start_time:
            epoch_time = time.time() - self.start_time
            self.epoch_times.append(epoch_time)
            return epoch_time
        return 0

    def get_system_info(self):
        """Snapshot CPU/memory usage, plus GPU stats when GPUtil can see one."""
        # CPU usage
        cpu_percent = psutil.cpu_percent()
        # Memory usage
        memory = psutil.virtual_memory()
        memory_percent = memory.percent
        memory_used = memory.used / (1024**3)  # GB
        # GPU usage: best effort — GPUtil may be missing or fail on this host.
        gpu_info = []
        try:
            gpus = GPUtil.getGPUs()
            for gpu in gpus:
                gpu_info.append({
                    'id': gpu.id,
                    'name': gpu.name,
                    'load': gpu.load * 100,
                    'memory_used': gpu.memoryUsed,
                    'memory_total': gpu.memoryTotal,
                    'temperature': gpu.temperature
                })
        except Exception:
            # Bug fix: the original bare `except:` also swallowed SystemExit
            # and KeyboardInterrupt; keep GPU stats best-effort but narrower.
            pass
        return {
            'cpu_percent': cpu_percent,
            'memory_percent': memory_percent,
            'memory_used_gb': memory_used,
            'gpu_info': gpu_info
        }

    def print_performance_summary(self, total_epochs=100):
        """Print average epoch time, an ETA and current resource usage.

        Args:
            total_epochs: planned epoch count used for the remaining-time
                estimate (the original hard-coded 100; default preserved).
        """
        if self.epoch_times:
            avg_time = sum(self.epoch_times) / len(self.epoch_times)
            print(f"Average epoch time: {avg_time:.2f} seconds")
            print(f"Estimated remaining time: {avg_time * (total_epochs - len(self.epoch_times)):.2f} seconds")
        system_info = self.get_system_info()
        print(f"CPU usage: {system_info['cpu_percent']:.1f}%")
        print(f"Memory usage: {system_info['memory_percent']:.1f}%")
        for gpu in system_info['gpu_info']:
            print(f"GPU {gpu['id']} ({gpu['name']}): "
                  f"Load {gpu['load']:.1f}%, "
                  f"Memory {gpu['memory_used']}/{gpu['memory_total']}MB, "
                  f"Temperature {gpu['temperature']}°C")
# Use performance monitoring
monitor = PerformanceMonitor()

# In the training loop:
for epoch in range(num_epochs):
    monitor.start_epoch()
    # Training code...
    epoch_time = monitor.end_epoch()
    print(f"Epoch {epoch+1} took: {epoch_time:.2f} seconds")
    # Emit a full resource summary every 10 epochs.
    if epoch % 10 == 0:
        monitor.print_performance_summary()

# Summary
Model training is the core of deep learning, and mastering it requires:
- Training Process: Understanding the complete flow of forward pass, loss calculation, backward pass, and parameter updates
- Training Techniques: Optimization techniques like learning rate scheduling, early stopping, gradient accumulation, and mixed precision training
- Model Evaluation: Using appropriate metrics to evaluate model performance
- Visualization Analysis: Analyzing training process and results through charts
- Debugging Skills: Diagnosing and solving common problems in training
Mastering these skills will help you train high-quality deep learning models!