PyTorch Image Classification Project
Project Overview
This chapter will demonstrate how to build an end-to-end deep learning solution using a complete image classification project. We will use the CIFAR-10 dataset to build a classifier capable of recognizing 10 different objects.
Project Structure
image_classification/
├── data/ # Data directory
├── models/ # Model definitions
│ ├── __init__.py
│ ├── resnet.py
│ └── densenet.py
├── utils/ # Utility functions
│ ├── __init__.py
│ ├── data_loader.py
│ ├── transforms.py
│ └── metrics.py
├── configs/ # Configuration files
│ └── config.yaml
├── checkpoints/ # Model checkpoints
├── logs/ # Training logs
├── train.py # Training script
├── test.py # Test script
├── inference.py # Inference script
└── requirements.txt # Dependencies

Data Preparation
1. Data Loading and Preprocessing
python
import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, random_split
import matplotlib.pyplot as plt
import numpy as np
class CIFAR10DataModule:
    """Bundles CIFAR-10 download, transforms, train/val split, and dataloaders.

    Args:
        data_dir: Directory where CIFAR-10 is stored / downloaded.
        batch_size: Batch size used by all three dataloaders.
        num_workers: Worker processes per dataloader.
        seed: Seed for the train/val split so it is reproducible across runs.
    """

    def __init__(self, data_dir='./data', batch_size=128, num_workers=4, seed=42):
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.num_workers = num_workers
        # Fixed seed -> the same 90/10 split on every run (the original
        # unseeded random_split produced a different split each process).
        self.seed = seed
        # CIFAR-10 class names; index i is the name for label i.
        self.classes = ['airplane', 'automobile', 'bird', 'cat', 'deer',
                        'dog', 'frog', 'horse', 'ship', 'truck']
        # Per-channel mean/std of the CIFAR-10 training set (standard values).
        self.mean = (0.4914, 0.4822, 0.4465)
        self.std = (0.2023, 0.1994, 0.2010)
        self.setup_transforms()

    def setup_transforms(self):
        """Build the augmenting train transform and the plain eval transform."""
        # Data augmentation during training.
        self.train_transform = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomRotation(degrees=15),
            transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
            transforms.ToTensor(),
            transforms.Normalize(self.mean, self.std),
            # RandomErasing must come after ToTensor: it operates on tensors.
            transforms.RandomErasing(p=0.1)
        ])
        # Validation/test: no augmentation, only normalization.
        self.val_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(self.mean, self.std)
        ])

    def prepare_data(self):
        """Download both CIFAR-10 splits if they are not already on disk."""
        torchvision.datasets.CIFAR10(root=self.data_dir, train=True, download=True)
        torchvision.datasets.CIFAR10(root=self.data_dir, train=False, download=True)

    def setup(self, stage=None):
        """Create the datasets. `stage` is 'fit', 'test', or None (both)."""
        if stage == 'fit' or stage is None:
            full_train = torchvision.datasets.CIFAR10(
                root=self.data_dir, train=True, transform=self.train_transform
            )
            # 90/10 train/validation split, seeded for reproducibility.
            train_size = int(0.9 * len(full_train))
            val_size = len(full_train) - train_size
            self.train_dataset, self.val_dataset = random_split(
                full_train, [train_size, val_size],
                generator=torch.Generator().manual_seed(self.seed)
            )
            # The val Subset initially points at the augmented dataset; repoint
            # it at a second CIFAR10 instance that applies only the eval
            # transform. NOTE: relies on torch.utils.data.Subset exposing
            # `.dataset`; the split indices themselves are unchanged.
            self.val_dataset.dataset = torchvision.datasets.CIFAR10(
                root=self.data_dir, train=True, transform=self.val_transform
            )
        if stage == 'test' or stage is None:
            self.test_dataset = torchvision.datasets.CIFAR10(
                root=self.data_dir, train=False, transform=self.val_transform
            )

    def train_dataloader(self):
        """Shuffled training loader; pin_memory speeds host-to-GPU copies."""
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
            pin_memory=True
        )

    def val_dataloader(self):
        """Deterministic (unshuffled) validation loader."""
        return DataLoader(
            self.val_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            pin_memory=True
        )

    def test_dataloader(self):
        """Deterministic (unshuffled) test loader."""
        return DataLoader(
            self.test_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            pin_memory=True
        )
# Create the data module, then download CIFAR-10 and build the datasets.
data_module = CIFAR10DataModule(batch_size=128)
data_module.prepare_data()
data_module.setup()

Model Definition
1. Improved ResNet Model
python
import torch.nn as nn
import torch.nn.functional as F
class BasicBlock(nn.Module):
    """Two-3x3-conv residual block (ResNet-18/34 style) with optional
    spatial dropout after each activation."""

    expansion = 1  # output channels = planes * expansion

    def __init__(self, in_planes, planes, stride=1, dropout_rate=0.0):
        super(BasicBlock, self).__init__()
        # Main path: conv-bn-(relu+dropout)-conv-bn.
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3,
                               stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.dropout1 = nn.Dropout2d(dropout_rate)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.dropout2 = nn.Dropout2d(dropout_rate)
        # Identity shortcut unless shape changes; then a 1x1 projection.
        needs_projection = stride != 1 or in_planes != self.expansion * planes
        if needs_projection:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion * planes,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion * planes)
            )
        else:
            self.shortcut = nn.Sequential()

    def forward(self, x):
        """Residual forward: main path plus (possibly projected) shortcut."""
        residual = self.shortcut(x)
        y = self.dropout1(F.relu(self.bn1(self.conv1(x))))
        y = self.bn2(self.conv2(y))
        y = y + residual
        # Post-addition activation, then spatial dropout.
        return self.dropout2(F.relu(y))
class ImprovedResNet(nn.Module):
    """ResNet backbone for 32x32 inputs (3x3 stem, no max-pool) with
    in-block Dropout2d and a Dropout layer before the final classifier."""

    def __init__(self, block, num_blocks, num_classes=10, dropout_rate=0.1):
        """
        Args:
            block: Residual block class; must expose an `expansion` attribute.
            num_blocks: Blocks per stage, e.g. [2, 2, 2, 2] for ResNet-18.
            num_classes: Size of the output logit vector.
            dropout_rate: Spatial dropout rate forwarded to every block.
        """
        super(ImprovedResNet, self).__init__()
        self.in_planes = 64
        self.dropout_rate = dropout_rate
        # Stem: single 3x3 conv, stride 1 (CIFAR images are too small for 7x7+pool).
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        # Four stages; every stage after the first halves spatial resolution.
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        # Head: global average pool -> dropout -> linear classifier.
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.dropout = nn.Dropout(0.5)
        self.fc = nn.Linear(512 * block.expansion, num_classes)
        self._initialize_weights()

    def _make_layer(self, block, planes, num_blocks, stride):
        """Stack `num_blocks` blocks; only the first may downsample."""
        stage = []
        for s in [stride] + [1] * (num_blocks - 1):
            stage.append(block(self.in_planes, planes, s, self.dropout_rate))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*stage)

    def _initialize_weights(self):
        """He init for convs, unit/zero for batch-norm, small normal for linear."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        """Map a (N, 3, H, W) batch to (N, num_classes) logits."""
        h = F.relu(self.bn1(self.conv1(x)))
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            h = stage(h)
        h = torch.flatten(self.avgpool(h), 1)
        return self.fc(self.dropout(h))
def create_resnet18(num_classes=10, dropout_rate=0.1):
return ImprovedResNet(BasicBlock, [2, 2, 2, 2], num_classes, dropout_rate)Training Framework
1. Trainer Class
python
import os
import time
from collections import defaultdict
import torch.optim as optim
from torch.optim.lr_scheduler import *
class ImageClassificationTrainer:
    """Drives training, validation, checkpointing, and metric tracking.

    Expected `config` keys:
        save_dir (str), epochs (int),
        optimizer ('adamw' | 'sgd'), learning_rate (float), weight_decay (float),
        scheduler ('cosine' | 'step' | 'plateau'),
        step_size (int, required only by the 'step' scheduler).
    """

    def __init__(self, model, train_loader, val_loader, config):
        self.model = model
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.config = config
        # Use the GPU when one is available.
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model.to(self.device)
        # Label smoothing (0.1) regularizes overconfident logits.
        self.criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
        self.optimizer = self._create_optimizer()
        self.scheduler = self._create_scheduler()
        # Per-epoch metrics keyed by name ('train_loss', 'val_acc', 'lr', ...).
        self.history = defaultdict(list)
        # Best validation accuracy (%) seen so far and the epoch it occurred.
        self.best_val_acc = 0.0
        self.best_epoch = 0
        os.makedirs(config['save_dir'], exist_ok=True)

    def _create_optimizer(self):
        """Build the optimizer named by config['optimizer'].

        Raises:
            ValueError: for unrecognized names. (The original returned None,
                which only surfaced later as an opaque AttributeError.)
        """
        name = self.config['optimizer']
        if name == 'adamw':
            return optim.AdamW(
                self.model.parameters(),
                lr=self.config['learning_rate'],
                weight_decay=self.config['weight_decay'],
                betas=(0.9, 0.999)
            )
        if name == 'sgd':
            return optim.SGD(
                self.model.parameters(),
                lr=self.config['learning_rate'],
                momentum=0.9,
                weight_decay=self.config['weight_decay'],
                nesterov=True
            )
        raise ValueError(f"Unknown optimizer: {name!r}")

    def _create_scheduler(self):
        """Build the LR scheduler named by config['scheduler'].

        Raises:
            ValueError: for unrecognized names.
        """
        name = self.config['scheduler']
        if name == 'cosine':
            # Anneal the LR down to eta_min over the whole run.
            return CosineAnnealingLR(
                self.optimizer,
                T_max=self.config['epochs'],
                eta_min=1e-6
            )
        if name == 'step':
            return StepLR(
                self.optimizer,
                step_size=self.config['step_size'],
                gamma=0.1
            )
        if name == 'plateau':
            # mode='max' because train() steps this on validation accuracy.
            return ReduceLROnPlateau(
                self.optimizer,
                mode='max',
                factor=0.5,
                patience=10,
                verbose=True
            )
        raise ValueError(f"Unknown scheduler: {name!r}")

    def train_epoch(self, epoch):
        """Run one full pass over the training set.

        Returns:
            (mean batch loss, accuracy in percent) for the epoch.
        """
        self.model.train()
        running_loss = 0.0
        correct = 0
        total = 0
        for batch_idx, (data, target) in enumerate(self.train_loader):
            data, target = data.to(self.device), target.to(self.device)
            self.optimizer.zero_grad()
            # Forward pass
            output = self.model(data)
            loss = self.criterion(output, target)
            # Backward pass
            loss.backward()
            # Clip gradient norm to 1.0 to guard against unstable updates.
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
            self.optimizer.step()
            # Running statistics
            running_loss += loss.item()
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()
            total += target.size(0)
            # Progress line every 100 batches.
            if batch_idx % 100 == 0:
                print(f'Epoch {epoch}, Batch {batch_idx}/{len(self.train_loader)}, '
                      f'Loss: {loss.item():.4f}, Acc: {100.*correct/total:.2f}%')
        epoch_loss = running_loss / len(self.train_loader)
        epoch_acc = 100. * correct / total
        return epoch_loss, epoch_acc

    def validate_epoch(self):
        """Evaluate on the validation set; returns (mean loss, accuracy %)."""
        self.model.eval()
        running_loss = 0.0
        correct = 0
        total = 0
        with torch.no_grad():
            for data, target in self.val_loader:
                data, target = data.to(self.device), target.to(self.device)
                output = self.model(data)
                loss = self.criterion(output, target)
                running_loss += loss.item()
                pred = output.argmax(dim=1, keepdim=True)
                correct += pred.eq(target.view_as(pred)).sum().item()
                total += target.size(0)
        val_loss = running_loss / len(self.val_loader)
        val_acc = 100. * correct / total
        return val_loss, val_acc

    def train(self):
        """Full training loop over config['epochs'] epochs.

        Records metrics in `self.history`, saves the best model (by validation
        accuracy) plus a checkpoint every 20 epochs, and returns the history.
        """
        print(f"Starting training, device: {self.device}")
        print(f"Model parameter count: {sum(p.numel() for p in self.model.parameters()):,}")
        start_time = time.time()
        for epoch in range(self.config['epochs']):
            epoch_start = time.time()
            train_loss, train_acc = self.train_epoch(epoch)
            val_loss, val_acc = self.validate_epoch()
            # ReduceLROnPlateau needs the monitored metric; others step blindly.
            if isinstance(self.scheduler, ReduceLROnPlateau):
                self.scheduler.step(val_acc)
            else:
                self.scheduler.step()
            # Record history for later plotting/analysis.
            self.history['train_loss'].append(train_loss)
            self.history['train_acc'].append(train_acc)
            self.history['val_loss'].append(val_loss)
            self.history['val_acc'].append(val_acc)
            self.history['lr'].append(self.optimizer.param_groups[0]['lr'])
            epoch_time = time.time() - epoch_start
            # Per-epoch summary.
            print(f'Epoch {epoch+1}/{self.config["epochs"]}:')
            print(f' Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%')
            print(f' Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%')
            print(f' LR: {self.optimizer.param_groups[0]["lr"]:.6f}')
            print(f' Time: {epoch_time:.2f}s')
            # Save the best model by validation accuracy.
            if val_acc > self.best_val_acc:
                self.best_val_acc = val_acc
                self.best_epoch = epoch
                self.save_checkpoint(epoch, is_best=True)
                print(f' ✓ New best model! Validation accuracy: {val_acc:.2f}%')
            # Periodic checkpoint regardless of validation performance.
            if (epoch + 1) % 20 == 0:
                self.save_checkpoint(epoch)
            print('-' * 60)
        total_time = time.time() - start_time
        print(f'Training complete! Total time: {total_time/3600:.2f} hours')
        print(f'Best validation accuracy: {self.best_val_acc:.2f}% (Epoch {self.best_epoch+1})')
        return self.history

    def save_checkpoint(self, epoch, is_best=False):
        """Write model/optimizer/scheduler state plus history to disk.

        Always writes checkpoint_epoch_{epoch+1}.pth; additionally writes
        best_model.pth when `is_best` is True.
        """
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scheduler_state_dict': self.scheduler.state_dict(),
            'best_val_acc': self.best_val_acc,
            'history': dict(self.history)
        }
        if is_best:
            torch.save(checkpoint, os.path.join(self.config['save_dir'], 'best_model.pth'))
        torch.save(checkpoint, os.path.join(self.config['save_dir'], f'checkpoint_epoch_{epoch+1}.pth'))

# Model Evaluation
1. Detailed Evaluation
python
from sklearn.metrics import classification_report, confusion_matrix
import seaborn as sns
def evaluate_model(model, test_loader, device, classes):
"""Detailed model performance evaluation"""
model.eval()
all_preds = []
all_targets = []
all_probs = []
with torch.no_grad():
for data, target in test_loader:
data, target = data.to(device), target.to(device)
output = model(data)
# Get predictions and probabilities
probs = F.softmax(output, dim=1)
pred = output.argmax(dim=1)
all_preds.extend(pred.cpu().numpy())
all_targets.extend(target.cpu().numpy())
all_probs.extend(probs.cpu().numpy())
# Compute accuracy
accuracy = (np.array(all_preds) == np.array(all_targets)).mean()
# Classification report
report = classification_report(all_targets, all_preds, target_names=classes)
# Confusion matrix
cm = confusion_matrix(all_targets, all_preds)
print(f"Test accuracy: {accuracy:.4f}")
print("\nClassification report:")
print(report)
# Plot confusion matrix
plt.figure(figsize=(10, 8))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
xticklabels=classes, yticklabels=classes)
plt.title('Confusion Matrix')
plt.ylabel('True Label')
plt.xlabel('Predicted Label')
plt.show()
return accuracy, report, cm, all_probsModel Inference
1. Single Image Inference
python
def predict_single_image(model, image_path, transform, classes, device):
    """Classify a single image file.

    Returns a dict with the predicted class name, its confidence, and the
    top-5 (class, probability) pairs, highest probability first.
    """
    from PIL import Image
    # Force RGB (handles grayscale/alpha) and add a batch dimension.
    img = Image.open(image_path).convert('RGB')
    batch = transform(img).unsqueeze(0).to(device)
    model.eval()
    with torch.no_grad():
        logits = model(batch)
        probabilities = F.softmax(logits, dim=1)
        top_idx = logits.argmax(dim=1).item()
        top_conf = probabilities[0][top_idx].item()
    # Top-5 most probable classes.
    top5_prob, top5_ids = torch.topk(probabilities, 5)
    ranked_names = [classes[i] for i in top5_ids[0]]
    ranked_probs = top5_prob[0].tolist()
    return {
        'predicted_class': classes[top_idx],
        'confidence': top_conf,
        'top5_predictions': list(zip(ranked_names, ranked_probs))
    }

Summary
This chapter demonstrated a complete image classification project:
- Project Structure: How to organize deep learning project code structure
- Data Processing: Complete workflow of data loading, preprocessing, and augmentation
- Model Design: Improved ResNet architecture and ensemble methods
- Training Framework: Complete training, validation, and saving workflow
- Model Evaluation: Multiple evaluation metrics and error analysis methods
- Model Inference: Single and batch inference implementation
This project template can serve as a foundation for other image classification tasks, adaptable to different application scenarios by modifying data loading and model structure.