
PyTorch Distributed Training

Distributed Training Overview

Distributed training is a key technique for large-scale deep learning: it accelerates training by spreading the work across multiple GPUs or machines. PyTorch provides several distributed training solutions.

Distributed Training Basics

1. Basic Concepts

python
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
import os

# Key concepts in distributed training:
# - World Size: Total number of processes
# - Rank: Global rank of current process
# - Local Rank: Process rank within current node
# - Backend: Communication backend (nccl, gloo, mpi)

def setup_distributed(rank, world_size, backend='nccl'):
    """Initialize distributed environment"""
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'
    
    # Initialize process group
    dist.init_process_group(backend, rank=rank, world_size=world_size)
    
    # Bind this process to its GPU (assumes a single node, where rank == local rank)
    torch.cuda.set_device(rank)

def cleanup_distributed():
    """Clean up distributed environment"""
    dist.destroy_process_group()
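
The same information can be queried at runtime once the process group is initialized. A minimal sketch, using the imports above (the LOCAL_RANK environment variable is set by launchers such as torchrun and is an assumption here):

python
def print_process_info():
    """Report where this process sits in the distributed job"""
    rank = dist.get_rank()              # global rank of this process
    world_size = dist.get_world_size()  # total number of processes
    backend = dist.get_backend()        # e.g. 'nccl' or 'gloo'
    # LOCAL_RANK is provided by launchers such as torchrun; default to 0
    local_rank = int(os.environ.get('LOCAL_RANK', '0'))
    print(f'rank={rank}/{world_size}, local_rank={local_rank}, backend={backend}')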

2. Data Parallel (DataParallel)

python
import torch.nn as nn

# Simple data parallel (single machine, multiple GPUs)
class SimpleDataParallel:
    def __init__(self, model, device_ids=None):
        if device_ids is None:
            device_ids = list(range(torch.cuda.device_count()))
        
        self.model = nn.DataParallel(model, device_ids=device_ids)
        self.device_ids = device_ids
    
    def train_step(self, data, target, optimizer, criterion):
        """Training step"""
        # Data will be automatically distributed to multiple GPUs
        data, target = data.cuda(), target.cuda()
        
        optimizer.zero_grad()
        output = self.model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        
        return loss.item()

# Usage example (MyModel is a placeholder for your own nn.Module)
model = MyModel()
dp_trainer = SimpleDataParallel(model)

optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()

for data, target in dataloader:
    loss = dp_trainer.train_step(data, target, optimizer, criterion)
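
nn.DataParallel is convenient because it needs no launcher or process-group setup, but it runs in a single process and replicates the model on every forward pass, so it scales poorly; the PyTorch documentation recommends DistributedDataParallel even on a single machine, which the next section covers.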

Distributed Data Parallel (DDP)

1. Basic DDP Implementation

python
def train_ddp(rank, world_size, model_class, train_dataset, num_epochs):
    """DDP training function"""
    # Set up distributed environment
    setup_distributed(rank, world_size)
    
    # Create model and move to GPU
    model = model_class().cuda(rank)
    model = DDP(model, device_ids=[rank])
    
    # Create distributed sampler
    train_sampler = torch.utils.data.distributed.DistributedSampler(
        train_dataset, num_replicas=world_size, rank=rank
    )
    
    # Create data loader
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=32,
        sampler=train_sampler,
        num_workers=4,
        pin_memory=True
    )
    
    # Optimizer and loss function
    optimizer = torch.optim.AdamW(model.parameters(), lr=0.001)
    criterion = nn.CrossEntropyLoss()
    
    # Training loop
    for epoch in range(num_epochs):
        # Set sampler epoch (for data shuffling)
        train_sampler.set_epoch(epoch)
        
        model.train()
        total_loss = 0
        
        for batch_idx, (data, target) in enumerate(train_loader):
            data, target = data.cuda(rank), target.cuda(rank)
            
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()
            
            total_loss += loss.item()
            
            if batch_idx % 100 == 0 and rank == 0:
                print(f'Epoch {epoch}, Batch {batch_idx}, Loss: {loss.item():.4f}')
        
        # Only print and save in main process
        if rank == 0:
            avg_loss = total_loss / len(train_loader)
            print(f'Epoch {epoch} completed, Average Loss: {avg_loss:.4f}')
            
            # Save model
            torch.save(model.module.state_dict(), f'model_epoch_{epoch}.pth')
    
    # Cleanup
    cleanup_distributed()

# Launch multi-process training
def main():
    world_size = torch.cuda.device_count()
    mp.spawn(
        train_ddp,
        args=(world_size, MyModel, train_dataset, 10),
        nprocs=world_size,
        join=True
    )

if __name__ == '__main__':
    main()
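
For very large effective batch sizes, gradients can be accumulated over several micro-steps before calling optimizer.step(). With DDP, the no_sync() context manager skips the gradient all-reduce on intermediate steps, so communication happens only once per accumulation window. A minimal sketch of the inner loop, assuming the DDP model, sampler-backed loader, optimizer, and criterion from the example above:

python
import contextlib

def train_with_accumulation(model, train_loader, optimizer, criterion,
                            rank, accumulation_steps=4):
    """One epoch of DDP training with gradient accumulation"""
    model.train()
    optimizer.zero_grad()
    
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.cuda(rank), target.cuda(rank)
        
        is_update_step = (batch_idx + 1) % accumulation_steps == 0
        # Skip the all-reduce on intermediate micro-steps; the final
        # backward of the window synchronizes the accumulated gradients
        context = contextlib.nullcontext() if is_update_step else model.no_sync()
        
        with context:
            output = model(data)
            # Scale the loss so the accumulated gradient matches a large batch
            loss = criterion(output, target) / accumulation_steps
            loss.backward()
        
        if is_update_step:
            optimizer.step()
            optimizer.zero_grad()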

Model Parallelism

1. Pipeline Parallelism

python
class PipelineParallelModel(nn.Module):
    def __init__(self, num_layers, hidden_size, num_devices):
        super(PipelineParallelModel, self).__init__()
        self.num_devices = num_devices
        # Assumes num_layers is evenly divisible by num_devices
        self.layers_per_device = num_layers // num_devices
        
        # Distribute layers to different devices
        self.device_layers = nn.ModuleList()
        for device_id in range(num_devices):
            device_layers = nn.ModuleList()
            for _ in range(self.layers_per_device):
                device_layers.append(
                    nn.Linear(hidden_size, hidden_size).to(f'cuda:{device_id}')
                )
            self.device_layers.append(device_layers)
    
    def forward(self, x):
        # Naive model parallelism: the whole batch flows through the devices
        # one after another, so only one GPU is active at a time; true
        # pipelining also splits the batch into micro-batches (see the
        # sketch after this example)
        for device_id in range(self.num_devices):
            x = x.to(f'cuda:{device_id}')
            for layer in self.device_layers[device_id]:
                x = torch.relu(layer(x))
        
        return x

# Use pipeline parallelism
pipeline_model = PipelineParallelModel(
    num_layers=12, hidden_size=512, num_devices=4
)
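
The forward pass above pushes the whole batch through the devices one after another, so only one GPU is active at a time. The pipelining idea is to split the batch into micro-batches so the devices can work on different micro-batches concurrently. A rough sketch of that idea (the overlap relies on asynchronous CUDA kernel launches; production systems use dedicated micro-batch schedulers such as GPipe-style implementations):

python
def microbatch_forward(pipeline_model, x, num_microbatches=4):
    """Feed the pipeline micro-batch by micro-batch.
    
    Because CUDA kernels are launched asynchronously, the first device
    can start on the next micro-batch while later devices are still
    busy with the previous one, which is the core idea of pipelining.
    """
    micro_batches = x.chunk(num_microbatches, dim=0)
    outputs = [pipeline_model(mb) for mb in micro_batches]
    return torch.cat(outputs, dim=0)

# Example: a batch of 64 samples split into 4 micro-batches of 16
# output = microbatch_forward(pipeline_model, torch.randn(64, 512))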

Mixed Precision Distributed Training

1. DDP Training with AMP

python
from torch.cuda.amp import GradScaler, autocast

# Extends a DistributedTrainer base class (DDP setup, data loaders, optimizer),
# assumed to be defined elsewhere in this guide
class AMPDistributedTrainer(DistributedTrainer):
    def __init__(self, model, train_dataset, val_dataset, config):
        super().__init__(model, train_dataset, val_dataset, config)
        self.scaler = GradScaler()
    
    def train_epoch(self):
        """Training epoch with mixed precision"""
        self.model.train()
        
        if hasattr(self.train_loader.sampler, 'set_epoch'):
            self.train_loader.sampler.set_epoch(self.current_epoch)
        
        total_loss = 0
        correct = 0
        total = 0
        
        for batch_idx, (data, target) in enumerate(self.train_loader):
            data, target = data.to(self.device), target.to(self.device)
            
            self.optimizer.zero_grad()
            
            # Use autocast for forward pass
            with autocast():
                output = self.model(data)
                loss = self.criterion(output, target)
            
            # Scale loss and backpropagate
            self.scaler.scale(loss).backward()
            
            # Gradient clipping
            self.scaler.unscale_(self.optimizer)
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
            
            # Update parameters
            self.scaler.step(self.optimizer)
            self.scaler.update()
            
            # Statistics
            total_loss += loss.item()
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()
            total += target.size(0)
        
        # Average metrics across all processes (see the helper sketch below)
        avg_loss = self._reduce_metric(total_loss / len(self.train_loader))
        accuracy = self._reduce_metric(correct / total)
        
        return avg_loss, accuracy
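
_reduce_metric comes from the DistributedTrainer base class and averages a scalar across all processes so that every rank reports the same value. A minimal standalone sketch of such a helper, assuming the process group is initialized and device is the local GPU:

python
def reduce_metric(value, device):
    """Average a Python scalar across all processes"""
    tensor = torch.tensor(value, dtype=torch.float32, device=device)
    dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
    tensor /= dist.get_world_size()
    return tensor.item()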

Launch Scripts

1. Single-Machine Multi-GPU Launch

bash
#!/bin/bash
# launch_single_node.sh

export CUDA_VISIBLE_DEVICES=0,1,2,3
export MASTER_ADDR=localhost
export MASTER_PORT=12355

python -m torch.distributed.launch \
    --nproc_per_node=4 \
    --master_port=12355 \
    train_distributed.py \
    --batch_size=32 \
    --learning_rate=0.001 \
    --num_epochs=100
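
In recent PyTorch releases, torch.distributed.launch is deprecated in favor of torchrun, which sets RANK, LOCAL_RANK, and WORLD_SIZE automatically; the equivalent single-node command is torchrun --nproc_per_node=4 train_distributed.py --batch_size=32 --learning_rate=0.001 --num_epochs=100. Note that when a launcher starts the processes, the training script reads its rank from the environment instead of spawning processes itself with mp.spawn.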

2. Multi-Machine Multi-GPU Launch

bash
#!/bin/bash
# launch_multi_node.sh

# Run the following on node 0 (the master node):
export CUDA_VISIBLE_DEVICES=0,1,2,3
export MASTER_ADDR=192.168.1.100
export MASTER_PORT=12355

python -m torch.distributed.launch \
    --nproc_per_node=4 \
    --nnodes=2 \
    --node_rank=0 \
    --master_addr=192.168.1.100 \
    --master_port=12355 \
    train_distributed.py

# Run the following on node 1:
export CUDA_VISIBLE_DEVICES=0,1,2,3
export MASTER_ADDR=192.168.1.100
export MASTER_PORT=12355

python -m torch.distributed.launch \
    --nproc_per_node=4 \
    --nnodes=2 \
    --node_rank=1 \
    --master_addr=192.168.1.100 \
    --master_port=12355 \
    train_distributed.py
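
With torchrun, the same multi-node job is launched by passing a rendezvous endpoint instead of master address flags, e.g. torchrun --nnodes=2 --node_rank=0 --nproc_per_node=4 --rdzv_backend=c10d --rdzv_endpoint=192.168.1.100:12355 train_distributed.py on node 0, and the same command with --node_rank=1 on node 1.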

Summary

PyTorch distributed training covers multiple parallel strategies:

  1. Data Parallelism: DataParallel and DistributedDataParallel
  2. Model Parallelism: Pipeline parallelism and tensor parallelism
  3. Mixed Precision: Distributed training combined with AMP
  4. Large-Scale Optimization: Gradient accumulation, dynamic loss scaling
  5. Deployment Solutions: Single-machine multi-GPU, multi-machine multi-GPU, cluster deployment
  6. Performance Tuning: Communication optimization, computation optimization, monitoring and debugging

Mastering distributed training techniques will help you handle large-scale deep learning tasks and significantly improve training efficiency!

Content is for learning and research only.