PyTorch Loss Functions and Optimizers
Loss Function Overview
Loss functions measure the discrepancy between model predictions and ground-truth labels and are a core component of neural network training. PyTorch provides a rich set of loss functions in the torch.nn module.
Common Loss Functions
1. Classification Task Loss Functions
CrossEntropyLoss
python
import torch
import torch.nn as nn
import torch.nn.functional as F
# Multi-class cross-entropy loss
criterion = nn.CrossEntropyLoss()
# Example data
logits = torch.randn(32, 10) # Batch size 32, 10 classes
targets = torch.randint(0, 10, (32,)) # True labels
loss = criterion(logits, targets)
print(f"Cross-entropy loss: {loss.item():.4f}")
# Weighted cross-entropy (handle class imbalance)
class_weights = torch.tensor([1.0, 2.0, 1.5, 1.0, 3.0, 1.0, 1.0, 2.0, 1.0, 1.0])
weighted_criterion = nn.CrossEntropyLoss(weight=class_weights)
weighted_loss = weighted_criterion(logits, targets)
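Note that CrossEntropyLoss expects raw, unnormalized logits and applies log-softmax internally, so the model's final layer should not apply softmax. As a quick sanity check (reusing the logits, targets, and loss from above), it matches NLLLoss applied to log-softmax outputs:
python
# CrossEntropyLoss is equivalent to log-softmax followed by NLLLoss
log_probs = F.log_softmax(logits, dim=1)
nll_loss = nn.NLLLoss()(log_probs, targets)
print(torch.allclose(loss, nll_loss))  # True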
BCE Loss (BCELoss)
python
# Binary cross-entropy loss
bce_criterion = nn.BCELoss()
# BCELoss expects probabilities in [0, 1], so apply sigmoid first
sigmoid_outputs = torch.sigmoid(torch.randn(32, 1))
binary_targets = torch.randint(0, 2, (32, 1)).float()
bce_loss = bce_criterion(sigmoid_outputs, binary_targets)
# BCEWithLogitsLoss (built-in sigmoid, more numerically stable)
bce_logits_criterion = nn.BCEWithLogitsLoss()
raw_logits = torch.randn(32, 1)
bce_logits_loss = bce_logits_criterion(raw_logits, binary_targets)
print(f"BCE loss: {bce_loss.item():.4f}")
print(f"BCE with Logits loss: {bce_logits_loss.item():.4f}")2. Regression Task Loss Functions
2. Regression Task Loss Functions
MSE Loss (MSELoss)
python
# Mean squared error loss
mse_criterion = nn.MSELoss()
predictions = torch.randn(32, 1)
targets = torch.randn(32, 1)
mse_loss = mse_criterion(predictions, targets)
print(f"MSE loss: {mse_loss.item():.4f}")
# Mean absolute error loss
mae_criterion = nn.L1Loss()
mae_loss = mae_criterion(predictions, targets)
print(f"MAE loss: {mae_loss.item():.4f}")Optimizers
Optimizers
1. Basic Optimizers
SGD (Stochastic Gradient Descent)
python
import torch.optim as optim
# Create model
model = nn.Sequential(
    nn.Linear(784, 128),
    nn.ReLU(),
    nn.Linear(128, 10)
)
# SGD optimizer
sgd_optimizer = optim.SGD(
    model.parameters(),
    lr=0.01,            # Learning rate
    momentum=0.9,       # Momentum
    weight_decay=1e-4,  # Weight decay (L2 regularization)
    nesterov=True       # Nesterov momentum
)
print(f"SGD optimizer: {sgd_optimizer}")
Adam Optimizer
python
# Adam optimizer
adam_optimizer = optim.Adam(
    model.parameters(),
    lr=0.001,            # Learning rate
    betas=(0.9, 0.999),  # Coefficients for the running averages of the gradient and its square
    eps=1e-8,            # Numerical stability term
    weight_decay=1e-4,   # Weight decay
    amsgrad=False        # Whether to use the AMSGrad variant
)
# AdamW optimizer (decoupled weight decay)
adamw_optimizer = optim.AdamW(
    model.parameters(),
    lr=0.001,
    betas=(0.9, 0.999),
    eps=1e-8,
    weight_decay=0.01,   # Decoupled weight decay; typically set higher than Adam's L2 penalty
    amsgrad=False
)
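A common refinement with AdamW is to exclude biases (and normalization parameters, when present) from weight decay by splitting them into separate parameter groups. A minimal sketch; the check on the parameter name ending in "bias" is a simplification:
python
decay_params, no_decay_params = [], []
for name, param in model.named_parameters():
    (no_decay_params if name.endswith("bias") else decay_params).append(param)
adamw_optimizer = optim.AdamW([
    {"params": decay_params, "weight_decay": 0.01},
    {"params": no_decay_params, "weight_decay": 0.0},
], lr=0.001)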
Learning Rate Scheduling
1. Basic Schedulers
python
from torch.optim.lr_scheduler import (
    StepLR, MultiStepLR, ExponentialLR, ReduceLROnPlateau,
    CosineAnnealingLR, CosineAnnealingWarmRestarts
)
# Step decay
step_scheduler = StepLR(
    optimizer=adam_optimizer,
    step_size=30,  # Decay every 30 epochs
    gamma=0.1      # Decay factor
)
# Multi-step decay
multistep_scheduler = MultiStepLR(
    optimizer=adam_optimizer,
    milestones=[30, 60, 90],  # Decay at these epochs
    gamma=0.1
)
# Exponential decay
exp_scheduler = ExponentialLR(
    optimizer=adam_optimizer,
    gamma=0.95  # Multiply the learning rate by 0.95 each epoch
)
2. Adaptive Schedulers
python
# Scheduler based on validation loss
plateau_scheduler = ReduceLROnPlateau(
    optimizer=adam_optimizer,
    mode='min',        # The monitored metric (e.g. validation loss) should decrease
    factor=0.5,        # Decay factor
    patience=10,       # Epochs without improvement to wait before decaying
    verbose=True,      # Print a message when the learning rate is reduced
    threshold=0.0001,  # Minimum change that counts as an improvement
    min_lr=0,          # Lower bound on the learning rate
    eps=1e-8
)
# Cosine annealing scheduler
cosine_scheduler = CosineAnnealingLR(
    optimizer=adam_optimizer,
    T_max=100,  # Number of epochs over which the LR anneals to eta_min
    eta_min=0   # Minimum learning rate
)
# Cosine annealing with restarts
cosine_restart_scheduler = CosineAnnealingWarmRestarts(
    optimizer=adam_optimizer,
    T_0=10,    # Epochs until the first restart
    T_mult=2,  # Factor by which the restart period grows after each restart
    eta_min=0
)
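Schedulers are advanced by calling step() once per epoch, after the optimizer updates for that epoch; ReduceLROnPlateau is the exception and takes the monitored metric as an argument. A sketch of how this slots into an epoch loop (the training/validation phases themselves are omitted):
python
num_epochs = 100
for epoch in range(num_epochs):
    # ... run the training and validation phases for this epoch ...
    cosine_scheduler.step()  # epoch-based schedulers: call once per epoch
    # For ReduceLROnPlateau, pass the monitored metric instead:
    # plateau_scheduler.step(val_loss)
    current_lr = adam_optimizer.param_groups[0]["lr"]
    print(f"Epoch {epoch}: lr = {current_lr:.6f}")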
Training Loop Example
1. Basic Training Loop
python
def train_epoch(model, dataloader, criterion, optimizer, device):
    model.train()
    total_loss = 0
    correct = 0
    total = 0
    for batch_idx, (data, target) in enumerate(dataloader):
        data, target = data.to(device), target.to(device)
        # Zero gradients
        optimizer.zero_grad()
        # Forward pass
        output = model(data)
        loss = criterion(output, target)
        # Backward pass
        loss.backward()
        # Gradient clipping (optional)
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        # Update parameters
        optimizer.step()
        # Statistics
        total_loss += loss.item()
        pred = output.argmax(dim=1, keepdim=True)
        correct += pred.eq(target.view_as(pred)).sum().item()
        total += target.size(0)
    avg_loss = total_loss / len(dataloader)
    accuracy = 100. * correct / total
    return avg_loss, accuracy
def validate(model, dataloader, criterion, device):
    model.eval()
    total_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        for data, target in dataloader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            loss = criterion(output, target)
            total_loss += loss.item()
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()
            total += target.size(0)
    avg_loss = total_loss / len(dataloader)
    accuracy = 100. * correct / total
    return avg_loss, accuracy
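A short sketch of how these two functions are typically driven for a classification task (train_loader and val_loader are assumed to be existing DataLoaders; the epoch count is illustrative):
python
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)
for epoch in range(10):
    train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, device)
    val_loss, val_acc = validate(model, val_loader, criterion, device)
    print(f"Epoch {epoch}: train {train_loss:.4f}/{train_acc:.2f}%, "
          f"val {val_loss:.4f}/{val_acc:.2f}%")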
Advanced Training Techniques
1. Gradient Accumulation
python
def train_with_gradient_accumulation(model, dataloader, criterion, optimizer,
                                     device, accumulation_steps=4):
    # Effective batch size = dataloader batch size * accumulation_steps
    model.train()
    optimizer.zero_grad()
    for batch_idx, (data, target) in enumerate(dataloader):
        data, target = data.to(device), target.to(device)
        # Forward pass
        output = model(data)
        loss = criterion(output, target)
        # Scale the loss so the accumulated gradient averages over the accumulation window
        loss = loss / accumulation_steps
        # Backward pass (gradients accumulate across iterations)
        loss.backward()
        # Update parameters every accumulation_steps steps
        if (batch_idx + 1) % accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()
    # Flush leftover gradients if the number of batches is not divisible by accumulation_steps
    if len(dataloader) % accumulation_steps != 0:
        optimizer.step()
        optimizer.zero_grad()
2. Mixed Precision Training
python
from torch.cuda.amp import GradScaler, autocast
def train_with_mixed_precision(model, dataloader, criterion, optimizer, device):
    scaler = GradScaler()
    model.train()
    for data, target in dataloader:
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        # Run the forward pass under autocast (mixed precision)
        with autocast():
            output = model(data)
            loss = criterion(output, target)
        # Scale the loss and backpropagate
        scaler.scale(loss).backward()
        # Update parameters (the scaler skips the step if gradients overflowed)
        scaler.step(optimizer)
        scaler.update()
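To combine gradient clipping with mixed precision, the gradients must be unscaled before clipping. A sketch of the ordering that would replace the backward/step portion of the loop above:
python
scaler.scale(loss).backward()
scaler.unscale_(optimizer)  # bring gradients back to their true scale before clipping
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
scaler.step(optimizer)      # skips the update if gradients contain inf/NaN
scaler.update()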
3. Early Stopping Mechanism
python
import copy

class EarlyStopping:
    def __init__(self, patience=7, min_delta=0, restore_best_weights=True):
        self.patience = patience
        self.min_delta = min_delta
        self.restore_best_weights = restore_best_weights
        self.best_loss = None
        self.counter = 0
        self.best_weights = None

    def __call__(self, val_loss, model):
        if self.best_loss is None:
            self.best_loss = val_loss
            self.save_checkpoint(model)
        elif val_loss < self.best_loss - self.min_delta:
            self.best_loss = val_loss
            self.counter = 0
            self.save_checkpoint(model)
        else:
            self.counter += 1
            if self.counter >= self.patience:
                if self.restore_best_weights:
                    model.load_state_dict(self.best_weights)
                return True
        return False

    def save_checkpoint(self, model):
        # Deep-copy so later training steps do not overwrite the stored best weights
        self.best_weights = copy.deepcopy(model.state_dict())
# Use early stopping
early_stopping = EarlyStopping(patience=10)
num_epochs = 100
for epoch in range(num_epochs):
    train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, device)
    val_loss, val_acc = validate(model, val_loader, criterion, device)
    if early_stopping(val_loss, model):
        print(f"Early stopping at epoch {epoch}")
        break
Summary
Loss functions and optimizers are core components of deep learning training:
- Loss Function Selection: Choose appropriate loss functions based on task type
- Optimizer Selection: Understand characteristics and application scenarios of different optimizers
- Learning Rate Scheduling: Use appropriate learning rate scheduling strategies
- Training Techniques: Master gradient accumulation, mixed precision, early stopping, etc.
- Debugging Optimization: Learn to diagnose and solve training problems
Mastering these concepts will help you train better deep learning models!