Skip to content

PyTorch Image Classification Project

Project Overview

This chapter will demonstrate how to build an end-to-end deep learning solution using a complete image classification project. We will use the CIFAR-10 dataset to build a classifier capable of recognizing 10 different objects.

Project Structure

image_classification/
├── data/                   # Data directory
├── models/                 # Model definitions
│   ├── __init__.py
│   ├── resnet.py
│   └── densenet.py
├── utils/                  # Utility functions
│   ├── __init__.py
│   ├── data_loader.py
│   ├── transforms.py
│   └── metrics.py
├── configs/                # Configuration files
│   └── config.yaml
├── checkpoints/            # Model checkpoints
├── logs/                   # Training logs
├── train.py               # Training script
├── test.py                # Test script
├── inference.py           # Inference script
└── requirements.txt       # Dependencies

Data Preparation

1. Data Loading and Preprocessing

python
import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, random_split
import matplotlib.pyplot as plt
import numpy as np

class CIFAR10DataModule:
    """Data pipeline for CIFAR-10: download, transforms, split, and loaders.

    Improvements over the naive version:
      * the train/val split is driven by a seeded ``torch.Generator`` so the
        validation set is identical across runs (reproducible metrics);
      * the held-out validation fraction is a parameter instead of a
        hard-coded 0.9/0.1 split (default reproduces the original sizes).
    """

    def __init__(self, data_dir='./data', batch_size=128, num_workers=4,
                 val_split=0.1, seed=42):
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.val_split = val_split  # fraction of the training set held out for validation
        self.seed = seed            # seeds random_split for a reproducible split

        # CIFAR-10 class names, index-aligned with the dataset's integer labels
        self.classes = ['airplane', 'automobile', 'bird', 'cat', 'deer',
                       'dog', 'frog', 'horse', 'ship', 'truck']

        # Per-channel mean/std of the CIFAR-10 training images
        self.mean = (0.4914, 0.4822, 0.4465)
        self.std = (0.2023, 0.1994, 0.2010)

        self.setup_transforms()

    def setup_transforms(self):
        """Build the augmenting (train) and deterministic (val/test) pipelines."""
        # Data augmentation during training
        self.train_transform = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomRotation(degrees=15),
            transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
            transforms.ToTensor(),
            transforms.Normalize(self.mean, self.std),
            transforms.RandomErasing(p=0.1)  # occlusion-style regularization (must follow ToTensor)
        ])

        # No augmentation for validation/testing: evaluation stays deterministic
        self.val_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(self.mean, self.std)
        ])

    def prepare_data(self):
        """Download both CIFAR-10 splits (network I/O; call before setup())."""
        torchvision.datasets.CIFAR10(root=self.data_dir, train=True, download=True)
        torchvision.datasets.CIFAR10(root=self.data_dir, train=False, download=True)

    def setup(self, stage=None):
        """Instantiate datasets for the given stage ('fit', 'test', or both if None)."""
        if stage == 'fit' or stage is None:
            # Training set (with augmentation transform)
            full_train = torchvision.datasets.CIFAR10(
                root=self.data_dir, train=True, transform=self.train_transform
            )

            # Seeded split so the same images land in train/val every run
            val_size = int(self.val_split * len(full_train))
            train_size = len(full_train) - val_size
            generator = torch.Generator().manual_seed(self.seed)
            self.train_dataset, self.val_dataset = random_split(
                full_train, [train_size, val_size], generator=generator
            )

            # random_split returns Subsets sharing one underlying dataset; point
            # only the validation Subset at a copy using the augmentation-free
            # transform (the Subset's indices are kept, only .dataset swaps).
            self.val_dataset.dataset = torchvision.datasets.CIFAR10(
                root=self.data_dir, train=True, transform=self.val_transform
            )

        if stage == 'test' or stage is None:
            self.test_dataset = torchvision.datasets.CIFAR10(
                root=self.data_dir, train=False, transform=self.val_transform
            )

    def train_dataloader(self):
        """Shuffled training loader with pinned memory for fast host→GPU copies."""
        return DataLoader(
            self.train_dataset, 
            batch_size=self.batch_size, 
            shuffle=True, 
            num_workers=self.num_workers,
            pin_memory=True
        )

    def val_dataloader(self):
        """Deterministic (unshuffled) validation loader."""
        return DataLoader(
            self.val_dataset, 
            batch_size=self.batch_size, 
            shuffle=False, 
            num_workers=self.num_workers,
            pin_memory=True
        )

    def test_dataloader(self):
        """Deterministic (unshuffled) test loader."""
        return DataLoader(
            self.test_dataset, 
            batch_size=self.batch_size, 
            shuffle=False, 
            num_workers=self.num_workers,
            pin_memory=True
        )

# Create the data module, download CIFAR-10 (network I/O on first run),
# and build the train/val/test datasets for both 'fit' and 'test' stages.
data_module = CIFAR10DataModule(batch_size=128)
data_module.prepare_data()
data_module.setup()

Model Definition

1. Improved ResNet Model

python
import torch.nn as nn
import torch.nn.functional as F

class BasicBlock(nn.Module):
    """Two-convolution residual block (ResNet-18/34 style) with optional
    spatial dropout after each activation."""

    expansion = 1

    def __init__(self, in_planes, planes, stride=1, dropout_rate=0.0):
        super(BasicBlock, self).__init__()
        # Both convs are 3x3 and bias-free (each is followed by batch norm);
        # only the first may downsample via its stride.
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.dropout1 = nn.Dropout2d(dropout_rate)

        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.dropout2 = nn.Dropout2d(dropout_rate)

        # Identity shortcut unless spatial size or channel count changes; in
        # that case a 1x1 projection aligns the skip path with the residual.
        self.shortcut = nn.Sequential()
        needs_projection = stride != 1 or in_planes != self.expansion * planes
        if needs_projection:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion * planes)
            )

    def forward(self, x):
        """conv-bn-relu-drop → conv-bn, add skip, then relu-drop."""
        h = self.conv1(x)
        h = self.dropout1(F.relu(self.bn1(h)))
        h = self.bn2(self.conv2(h))
        h = h + self.shortcut(x)
        return self.dropout2(F.relu(h))

class ImprovedResNet(nn.Module):
    """ResNet variant for 32x32 inputs (3x3 stem, no max-pool) with per-block
    dropout and a dropout-regularized linear classification head."""

    def __init__(self, block, num_blocks, num_classes=10, dropout_rate=0.1):
        super(ImprovedResNet, self).__init__()
        self.in_planes = 64
        self.dropout_rate = dropout_rate

        # Stem: single 3x3 conv, stride 1 (keeps the small input resolution)
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)

        # Four stages; each stage after the first halves spatial size
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)

        # Head: global average pool, dropout, linear classifier
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.dropout = nn.Dropout(0.5)
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        self._initialize_weights()

    def _make_layer(self, block, planes, num_blocks, stride):
        """Stack ``num_blocks`` blocks; only the first one may downsample."""
        layers = []
        for i in range(num_blocks):
            block_stride = stride if i == 0 else 1
            layers.append(block(self.in_planes, planes, block_stride, self.dropout_rate))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def _initialize_weights(self):
        """He init for convs, unit-scale/zero-shift BN, small-normal classifier."""
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(module, nn.BatchNorm2d):
                nn.init.constant_(module.weight, 1)
                nn.init.constant_(module.bias, 0)
            elif isinstance(module, nn.Linear):
                nn.init.normal_(module.weight, 0, 0.01)
                nn.init.constant_(module.bias, 0)

    def forward(self, x):
        h = F.relu(self.bn1(self.conv1(x)))
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            h = stage(h)
        h = self.avgpool(h)
        h = torch.flatten(h, 1)
        return self.fc(self.dropout(h))

def create_resnet18(num_classes=10, dropout_rate=0.1):
    """Factory for an 18-layer ImprovedResNet (two BasicBlocks per stage)."""
    stage_depths = [2, 2, 2, 2]
    return ImprovedResNet(BasicBlock, stage_depths, num_classes, dropout_rate)

Training Framework

1. Trainer Class

python
import os
import time
from collections import defaultdict
import torch.optim as optim
from torch.optim.lr_scheduler import *

class ImageClassificationTrainer:
    """Training harness for image classifiers.

    Wires together the loss, optimizer, LR scheduler, per-epoch train and
    validation loops, metric history, and best-model checkpointing.

    Expected ``config`` keys:
        save_dir:      directory for checkpoints (created if missing)
        optimizer:     'adamw' or 'sgd'
        learning_rate: initial learning rate
        weight_decay:  L2 penalty
        scheduler:     'cosine', 'step', or 'plateau'
        epochs:        total number of training epochs
        step_size:     decay interval (only used by the 'step' scheduler)

    Raises:
        ValueError: on an unrecognized optimizer or scheduler name. (The
            original silently returned None and crashed later with an
            AttributeError on ``self.scheduler.step()``.)
    """

    def __init__(self, model, train_loader, val_loader, config):
        self.model = model
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.config = config

        # Prefer GPU when present
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model.to(self.device)

        # Label smoothing tempers overconfident logits
        self.criterion = nn.CrossEntropyLoss(label_smoothing=0.1)

        # Built eagerly so a bad config fails here, not mid-training
        self.optimizer = self._create_optimizer()
        self.scheduler = self._create_scheduler()

        # Per-epoch metrics: train/val loss & accuracy plus the LR trace
        self.history = defaultdict(list)

        # Best validation accuracy seen so far (drives best_model.pth)
        self.best_val_acc = 0.0
        self.best_epoch = 0

        # Ensure the checkpoint directory exists
        os.makedirs(config['save_dir'], exist_ok=True)

    def _create_optimizer(self):
        """Instantiate the optimizer named in config; raise on unknown names."""
        name = self.config['optimizer']
        if name == 'adamw':
            return optim.AdamW(
                self.model.parameters(),
                lr=self.config['learning_rate'],
                weight_decay=self.config['weight_decay'],
                betas=(0.9, 0.999)
            )
        if name == 'sgd':
            return optim.SGD(
                self.model.parameters(),
                lr=self.config['learning_rate'],
                momentum=0.9,
                weight_decay=self.config['weight_decay'],
                nesterov=True
            )
        raise ValueError(f"Unknown optimizer {name!r}; expected 'adamw' or 'sgd'")

    def _create_scheduler(self):
        """Instantiate the LR scheduler named in config; raise on unknown names."""
        name = self.config['scheduler']
        if name == 'cosine':
            return CosineAnnealingLR(
                self.optimizer, 
                T_max=self.config['epochs'],
                eta_min=1e-6
            )
        if name == 'step':
            return StepLR(
                self.optimizer,
                step_size=self.config['step_size'],
                gamma=0.1
            )
        if name == 'plateau':
            # NOTE: the 'verbose' kwarg was dropped — it is deprecated (and
            # later removed) in recent PyTorch releases.
            return ReduceLROnPlateau(
                self.optimizer,
                mode='max',
                factor=0.5,
                patience=10
            )
        raise ValueError(
            f"Unknown scheduler {name!r}; expected 'cosine', 'step', or 'plateau'"
        )

    def train_epoch(self, epoch):
        """Run one training epoch; return (mean batch loss, accuracy %)."""
        self.model.train()
        running_loss = 0.0
        correct = 0
        total = 0

        for batch_idx, (data, target) in enumerate(self.train_loader):
            data, target = data.to(self.device), target.to(self.device)

            self.optimizer.zero_grad()

            # Forward pass
            output = self.model(data)
            loss = self.criterion(output, target)

            # Backward pass
            loss.backward()

            # Clip exploding gradients before the parameter update
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)

            self.optimizer.step()

            # Statistics
            running_loss += loss.item()
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()
            total += target.size(0)

            # Print progress
            if batch_idx % 100 == 0:
                print(f'Epoch {epoch}, Batch {batch_idx}/{len(self.train_loader)}, '
                      f'Loss: {loss.item():.4f}, Acc: {100.*correct/total:.2f}%')

        epoch_loss = running_loss / len(self.train_loader)
        epoch_acc = 100. * correct / total

        return epoch_loss, epoch_acc

    def validate_epoch(self):
        """Evaluate on the validation loader; return (mean loss, accuracy %)."""
        self.model.eval()
        running_loss = 0.0
        correct = 0
        total = 0

        with torch.no_grad():
            for data, target in self.val_loader:
                data, target = data.to(self.device), target.to(self.device)
                output = self.model(data)
                loss = self.criterion(output, target)

                running_loss += loss.item()
                pred = output.argmax(dim=1, keepdim=True)
                correct += pred.eq(target.view_as(pred)).sum().item()
                total += target.size(0)

        val_loss = running_loss / len(self.val_loader)
        val_acc = 100. * correct / total

        return val_loss, val_acc

    def train(self):
        """Full training loop over config['epochs']; returns the metric history."""
        print(f"Starting training, device: {self.device}")
        print(f"Model parameter count: {sum(p.numel() for p in self.model.parameters()):,}")

        start_time = time.time()

        for epoch in range(self.config['epochs']):
            epoch_start = time.time()

            # Training
            train_loss, train_acc = self.train_epoch(epoch)

            # Validation
            val_loss, val_acc = self.validate_epoch()

            # Update learning rate (plateau scheduler needs the tracked metric)
            if isinstance(self.scheduler, ReduceLROnPlateau):
                self.scheduler.step(val_acc)
            else:
                self.scheduler.step()

            # Record history
            self.history['train_loss'].append(train_loss)
            self.history['train_acc'].append(train_acc)
            self.history['val_loss'].append(val_loss)
            self.history['val_acc'].append(val_acc)
            self.history['lr'].append(self.optimizer.param_groups[0]['lr'])

            epoch_time = time.time() - epoch_start

            # Print results
            print(f'Epoch {epoch+1}/{self.config["epochs"]}:')
            print(f'  Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%')
            print(f'  Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%')
            print(f'  LR: {self.optimizer.param_groups[0]["lr"]:.6f}')
            print(f'  Time: {epoch_time:.2f}s')

            # Save best model
            if val_acc > self.best_val_acc:
                self.best_val_acc = val_acc
                self.best_epoch = epoch
                self.save_checkpoint(epoch, is_best=True)
                print(f'  ✓ New best model! Validation accuracy: {val_acc:.2f}%')

            # Periodic saving
            if (epoch + 1) % 20 == 0:
                self.save_checkpoint(epoch)

            print('-' * 60)

        total_time = time.time() - start_time
        print(f'Training complete! Total time: {total_time/3600:.2f} hours')
        print(f'Best validation accuracy: {self.best_val_acc:.2f}% (Epoch {self.best_epoch+1})')

        return self.history

    def save_checkpoint(self, epoch, is_best=False):
        """Serialize model/optimizer/scheduler state and history to save_dir.

        Always writes checkpoint_epoch_{epoch+1}.pth; additionally writes
        best_model.pth when ``is_best`` is True.
        """
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scheduler_state_dict': self.scheduler.state_dict(),
            'best_val_acc': self.best_val_acc,
            'history': dict(self.history)
        }

        if is_best:
            torch.save(checkpoint, os.path.join(self.config['save_dir'], 'best_model.pth'))

        torch.save(checkpoint, os.path.join(self.config['save_dir'], f'checkpoint_epoch_{epoch+1}.pth'))

Model Evaluation

1. Detailed Evaluation

python
from sklearn.metrics import classification_report, confusion_matrix
import seaborn as sns

def evaluate_model(model, test_loader, device, classes):
    """Run the model over a test loader and report overall accuracy, a
    per-class classification report, and a confusion-matrix heatmap.

    Returns (accuracy, report, confusion_matrix, per-sample class probabilities).
    """
    model.eval()

    predictions = []
    targets = []
    probabilities = []

    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            logits = model(images)

            # Softmax probabilities and hard predictions for this batch
            batch_probs = F.softmax(logits, dim=1)
            batch_preds = logits.argmax(dim=1)

            predictions.extend(batch_preds.cpu().numpy())
            targets.extend(labels.cpu().numpy())
            probabilities.extend(batch_probs.cpu().numpy())

    # Overall accuracy: fraction of matching labels
    accuracy = (np.array(predictions) == np.array(targets)).mean()

    # Per-class precision/recall/F1 summary
    report = classification_report(targets, predictions, target_names=classes)

    # Count matrix of true label vs predicted label
    cm = confusion_matrix(targets, predictions)

    print(f"Test accuracy: {accuracy:.4f}")
    print("\nClassification report:")
    print(report)

    # Visualize the confusion matrix
    plt.figure(figsize=(10, 8))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', 
                xticklabels=classes, yticklabels=classes)
    plt.title('Confusion Matrix')
    plt.ylabel('True Label')
    plt.xlabel('Predicted Label')
    plt.show()

    return accuracy, report, cm, probabilities

Model Inference

1. Single Image Inference

python
def predict_single_image(model, image_path, transform, classes, device):
    """Classify a single image file.

    Returns a dict with the top-1 class name, its confidence, and the five
    most likely (class name, probability) pairs.
    """
    from PIL import Image

    # Load the file as RGB and add a batch dimension after preprocessing
    img = Image.open(image_path).convert('RGB')
    batch = transform(img).unsqueeze(0).to(device)

    model.eval()
    with torch.no_grad():
        logits = model(batch)
        probs = F.softmax(logits, dim=1)
        top_class = logits.argmax(dim=1).item()
        top_confidence = probs[0][top_class].item()

    # Top-5 labels with their probabilities
    top5_values, top5_indices = torch.topk(probs, 5)
    top5_labels = [classes[i] for i in top5_indices[0]]
    top5_scores = top5_values[0].tolist()

    return {
        'predicted_class': classes[top_class],
        'confidence': top_confidence,
        'top5_predictions': list(zip(top5_labels, top5_scores))
    }

Summary

This chapter demonstrated a complete image classification project:

  1. Project Structure: How to organize deep learning project code structure
  2. Data Processing: Complete workflow of data loading, preprocessing, and augmentation
  3. Model Design: an improved ResNet architecture with dropout regularization and explicit weight initialization
  4. Training Framework: Complete training, validation, and saving workflow
  5. Model Evaluation: Multiple evaluation metrics and error analysis methods
  6. Model Inference: single-image inference with top-5 predictions

This project template can serve as a foundation for other image classification tasks, adaptable to different application scenarios by modifying data loading and model structure.

Content is for learning and research only.