PyTorch Loss Functions and Optimizers

Loss Function Overview

Loss functions measure the discrepancy between a model's predictions and the true labels, and they provide the signal that drives neural network training. PyTorch offers a rich set of loss functions in the torch.nn module.

Common Loss Functions

1. Classification Task Loss Functions

CrossEntropyLoss

python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Multi-class cross-entropy loss
# Note: CrossEntropyLoss expects raw logits and applies log-softmax internally
criterion = nn.CrossEntropyLoss()

# Example data
logits = torch.randn(32, 10)  # Batch size 32, 10 classes
targets = torch.randint(0, 10, (32,))  # True labels

loss = criterion(logits, targets)
print(f"Cross-entropy loss: {loss.item():.4f}")

# Weighted cross-entropy (handle class imbalance)
class_weights = torch.tensor([1.0, 2.0, 1.5, 1.0, 3.0, 1.0, 1.0, 2.0, 1.0, 1.0])
weighted_criterion = nn.CrossEntropyLoss(weight=class_weights)
weighted_loss = weighted_criterion(logits, targets)
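
Since torch.nn.functional is imported above, the same loss is also available as a stateless function; a quick sanity check against the criterion, reusing the tensors defined above:

python
# F.cross_entropy fuses log-softmax and negative log-likelihood,
# matching nn.CrossEntropyLoss on identical inputs
functional_loss = F.cross_entropy(logits, targets)
print(torch.isclose(loss, functional_loss))  # tensor(True)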

BCE Loss (BCELoss)

python
# Binary cross-entropy loss
bce_criterion = nn.BCELoss()

# BCELoss expects probabilities in [0, 1], so apply sigmoid first
sigmoid_outputs = torch.sigmoid(torch.randn(32, 1))
binary_targets = torch.randint(0, 2, (32, 1)).float()

bce_loss = bce_criterion(sigmoid_outputs, binary_targets)

# BCEWithLogitsLoss (built-in sigmoid, more numerically stable)
bce_logits_criterion = nn.BCEWithLogitsLoss()
raw_logits = torch.randn(32, 1)
bce_logits_loss = bce_logits_criterion(raw_logits, binary_targets)

print(f"BCE loss: {bce_loss.item():.4f}")
print(f"BCE with Logits loss: {bce_logits_loss.item():.4f}")

2. Regression Task Loss Functions

MSE Loss (MSELoss)

python
# Mean squared error loss
mse_criterion = nn.MSELoss()

predictions = torch.randn(32, 1)
targets = torch.randn(32, 1)

mse_loss = mse_criterion(predictions, targets)
print(f"MSE loss: {mse_loss.item():.4f}")

# Mean absolute error loss
mae_criterion = nn.L1Loss()
mae_loss = mae_criterion(predictions, targets)
print(f"MAE loss: {mae_loss.item():.4f}")

Optimizers

1. Basic Optimizers

SGD (Stochastic Gradient Descent)

python
import torch.optim as optim

# Create model
model = nn.Sequential(
    nn.Linear(784, 128),
    nn.ReLU(),
    nn.Linear(128, 10)
)

# SGD optimizer
sgd_optimizer = optim.SGD(
    model.parameters(),
    lr=0.01,            # Learning rate
    momentum=0.9,       # Momentum
    weight_decay=1e-4,  # Weight decay (L2 regularization)
    nesterov=True       # Nesterov momentum
)

print(f"SGD optimizer: {sgd_optimizer}")

Adam Optimizer

python
# Adam optimizer
adam_optimizer = optim.Adam(
    model.parameters(),
    lr=0.001,                    # Learning rate
    betas=(0.9, 0.999),          # Decay rates for the first and second moment estimates
    eps=1e-8,                    # Numerical stability parameter
    weight_decay=1e-4,           # Weight decay
    amsgrad=False                # Whether to use AMSGrad variant
)

# AdamW optimizer (decoupled weight decay)
adamw_optimizer = optim.AdamW(
    model.parameters(),
    lr=0.001,
    betas=(0.9, 0.999),
    eps=1e-8,
    weight_decay=0.01,  # Decoupled decay, applied directly to the weights
    amsgrad=False
)
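
A common convention with AdamW is to exclude biases (and normalization parameters) from weight decay; a minimal sketch of that heuristic:

python
# Apply weight decay only to weight matrices, not biases
# (a widespread heuristic, not a PyTorch requirement)
decay, no_decay = [], []
for name, param in model.named_parameters():
    (no_decay if name.endswith('bias') else decay).append(param)

adamw_grouped = optim.AdamW([
    {'params': decay, 'weight_decay': 0.01},
    {'params': no_decay, 'weight_decay': 0.0},
], lr=0.001)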

Learning Rate Scheduling

1. Basic Schedulers

python
from torch.optim.lr_scheduler import (
    StepLR, MultiStepLR, ExponentialLR,
    ReduceLROnPlateau, CosineAnnealingLR, CosineAnnealingWarmRestarts,
)

# Step decay
step_scheduler = StepLR(
    optimizer=adam_optimizer,
    step_size=30,    # Decay every 30 epochs
    gamma=0.1        # Decay factor
)

# Multi-step decay
multistep_scheduler = MultiStepLR(
    optimizer=adam_optimizer,
    milestones=[30, 60, 90],  # Decay at these epochs
    gamma=0.1
)

# Exponential decay
exp_scheduler = ExponentialLR(
    optimizer=adam_optimizer,
    gamma=0.95  # Multiply by 0.95 each epoch
)
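
None of these schedulers changes the learning rate on its own; scheduler.step() must be called once per epoch, after the optimizer updates. A minimal sketch of the pattern (the batch loop is elided):

python
# Per-epoch pattern: optimizer.step() runs inside the batch loop,
# scheduler.step() is called once per epoch afterwards
for epoch in range(100):
    # ... batch loop with optimizer.step() ...
    step_scheduler.step()
    print(f"epoch {epoch}: lr = {step_scheduler.get_last_lr()[0]:.6f}")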

2. Adaptive Schedulers

python
# Scheduler driven by a monitored metric (e.g. validation loss)
plateau_scheduler = ReduceLROnPlateau(
    optimizer=adam_optimizer,
    mode='min',        # 'min': the monitored metric should decrease
    factor=0.5,        # Multiply the lr by this factor on a plateau
    patience=10,       # Epochs without improvement before decaying
    threshold=0.0001,  # Minimum change that counts as improvement
    min_lr=0,          # Lower bound on the learning rate
    eps=1e-8
)
# Step with the monitored value: plateau_scheduler.step(val_loss).
# The verbose argument is deprecated in recent PyTorch releases.

# Cosine annealing scheduler
cosine_scheduler = CosineAnnealingLR(
    optimizer=adam_optimizer,
    T_max=100,    # Epochs until the lr first reaches eta_min
    eta_min=0      # Minimum learning rate
)

# Cosine annealing with restarts
cosine_restart_scheduler = CosineAnnealingWarmRestarts(
    optimizer=adam_optimizer,
    T_0=10,       # Epochs for first restart
    T_mult=2,      # Multiplier for restart periods
    eta_min=0
)
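
CosineAnnealingWarmRestarts can also be stepped per batch with a fractional epoch index, so the cosine curve advances smoothly within an epoch; a sketch of that pattern (iters_per_epoch stands in for len(train_loader)):

python
# Fractional-epoch stepping: the scheduler sees epoch + batch fraction
iters_per_epoch = 100  # illustrative; use len(train_loader) in practice
for epoch in range(30):
    for i in range(iters_per_epoch):
        # ... forward / backward / optimizer.step() ...
        cosine_restart_scheduler.step(epoch + i / iters_per_epoch)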

Training Loop Example

1. Basic Training Loop

python
def train_epoch(model, dataloader, criterion, optimizer, device):
    model.train()
    total_loss = 0
    correct = 0
    total = 0
    
    for batch_idx, (data, target) in enumerate(dataloader):
        data, target = data.to(device), target.to(device)
        
        # Zero gradients
        optimizer.zero_grad()
        
        # Forward pass
        output = model(data)
        loss = criterion(output, target)
        
        # Backward pass
        loss.backward()
        
        # Gradient clipping (optional)
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        
        # Update parameters
        optimizer.step()
        
        # Statistics
        total_loss += loss.item()
        pred = output.argmax(dim=1, keepdim=True)
        correct += pred.eq(target.view_as(pred)).sum().item()
        total += target.size(0)
    
    avg_loss = total_loss / len(dataloader)
    accuracy = 100. * correct / total
    
    return avg_loss, accuracy

def validate(model, dataloader, criterion, device):
    model.eval()
    total_loss = 0
    correct = 0
    total = 0
    
    with torch.no_grad():
        for data, target in dataloader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            loss = criterion(output, target)
            
            total_loss += loss.item()
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()
            total += target.size(0)
    
    avg_loss = total_loss / len(dataloader)
    accuracy = 100. * correct / total
    
    return avg_loss, accuracy
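
Wiring the two functions together with a scheduler gives the usual epoch loop; a minimal sketch (the loaders and epoch count are assumed/illustrative):

python
# Assumed setup: train_loader and val_loader are defined elsewhere
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
scheduler = StepLR(optimizer, step_size=30, gamma=0.1)

num_epochs = 100  # illustrative upper bound
for epoch in range(num_epochs):
    train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, device)
    val_loss, val_acc = validate(model, val_loader, criterion, device)
    scheduler.step()
    print(f"Epoch {epoch}: train {train_loss:.4f}/{train_acc:.2f}% | "
          f"val {val_loss:.4f}/{val_acc:.2f}%")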

Advanced Training Techniques

1. Gradient Accumulation

python
def train_with_gradient_accumulation(model, dataloader, criterion, optimizer,
                                     device, accumulation_steps=4):
    model.train()
    optimizer.zero_grad()
    
    for batch_idx, (data, target) in enumerate(dataloader):
        data, target = data.to(device), target.to(device)
        
        # Forward pass
        output = model(data)
        loss = criterion(output, target)
        
        # Scale the loss so accumulated gradients match one large batch
        loss = loss / accumulation_steps
        
        # Backward pass (gradients accumulate across iterations)
        loss.backward()
        
        # Update parameters every accumulation_steps batches
        if (batch_idx + 1) % accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()
    
    # Flush leftover gradients when the batch count is not divisible
    # by accumulation_steps
    if len(dataloader) % accumulation_steps != 0:
        optimizer.step()
        optimizer.zero_grad()

2. Mixed Precision Training

python
from torch.amp import GradScaler, autocast  # torch.cuda.amp is the older, now-deprecated path

def train_with_mixed_precision(model, dataloader, criterion, optimizer, device):
    scaler = GradScaler('cuda')
    model.train()
    
    for data, target in dataloader:
        data, target = data.to(device), target.to(device)
        
        optimizer.zero_grad()
        
        # Run the forward pass in reduced precision
        with autocast('cuda'):
            output = model(data)
            loss = criterion(output, target)
        
        # Scale the loss to prevent fp16 gradient underflow, then backprop
        scaler.scale(loss).backward()
        
        # step() unscales the gradients and skips the update on inf/nan
        scaler.step(optimizer)
        scaler.update()
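
Gradient clipping under mixed precision requires unscaling first, otherwise clip_grad_norm_ would see scaled gradients; the documented pattern replaces the last three lines of the loop above:

python
# Unscale gradients in place so clipping sees their true magnitudes
scaler.scale(loss).backward()
scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
scaler.step(optimizer)  # does not unscale a second time
scaler.update()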

3. Early Stopping Mechanism

python
class EarlyStopping:
    def __init__(self, patience=7, min_delta=0, restore_best_weights=True):
        self.patience = patience
        self.min_delta = min_delta
        self.restore_best_weights = restore_best_weights
        self.best_loss = None
        self.counter = 0
        self.best_weights = None
    
    def __call__(self, val_loss, model):
        if self.best_loss is None:
            self.best_loss = val_loss
            self.save_checkpoint(model)
        elif val_loss < self.best_loss - self.min_delta:
            self.best_loss = val_loss
            self.counter = 0
            self.save_checkpoint(model)
        else:
            self.counter += 1
        
        if self.counter >= self.patience:
            if self.restore_best_weights:
                model.load_state_dict(self.best_weights)
            return True
        return False
    
    def save_checkpoint(self, model):
        # Clone the tensors: state_dict().copy() is shallow, so the saved
        # weights would otherwise keep changing as training continues
        self.best_weights = {k: v.detach().clone() for k, v in model.state_dict().items()}

# Use early stopping
early_stopping = EarlyStopping(patience=10)

for epoch in range(num_epochs):
    train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, device)
    val_loss, val_acc = validate(model, val_loader, criterion, device)
    
    if early_stopping(val_loss, model):
        print(f"Early stopping at epoch {epoch}")
        break

Summary

Loss functions and optimizers are core components of deep learning training:

  1. Loss Function Selection: Choose appropriate loss functions based on task type
  2. Optimizer Selection: Understand characteristics and application scenarios of different optimizers
  3. Learning Rate Scheduling: Use appropriate learning rate scheduling strategies
  4. Training Techniques: Master gradient accumulation, mixed precision, early stopping, etc.
  5. Debugging Optimization: Learn to diagnose and solve training problems

Mastering these concepts will help you train better deep learning models!
