PyTorch Data Processing

Data Processing Overview

In deep learning, data processing is a crucial step. PyTorch provides powerful data processing tools, chiefly the following (a minimal example combining all three appears right after this list):

  • torch.utils.data.Dataset: Dataset abstraction class
  • torch.utils.data.DataLoader: Data loader
  • torchvision.transforms: Data transformation tools
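
The three pieces compose naturally: a Dataset describes individual samples, transforms preprocess them, and a DataLoader batches and shuffles them. A minimal sketch using the built-in FashionMNIST dataset as a stand-in:

python
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# Dataset with a transform attached: PIL image -> float tensor
dataset = datasets.FashionMNIST(
    root='./data', train=True, download=True,
    transform=transforms.ToTensor()
)

# DataLoader batches and shuffles the samples
loader = DataLoader(dataset, batch_size=64, shuffle=True)

images, labels = next(iter(loader))
print(images.shape)  # torch.Size([64, 1, 28, 28])
print(labels.shape)  # torch.Size([64])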

Dataset Class

1. Custom Dataset

python
import torch
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import numpy as np
from PIL import Image
import os

class CustomDataset(Dataset):
    def __init__(self, data_file, transform=None):
        """
        Custom dataset
        Args:
            data_file: Data file path
            transform: Data transformation
        """
        self.data = pd.read_csv(data_file)
        self.transform = transform
    
    def __len__(self):
        """Return dataset size"""
        return len(self.data)
    
    def __getitem__(self, idx):
        """Get single sample"""
        if torch.is_tensor(idx):
            idx = idx.tolist()
        
        # Get data
        sample = self.data.iloc[idx]
        image_path = sample['image_path']
        label = sample['label']
        
        # Load image
        image = Image.open(image_path)
        
        # Apply transformation
        if self.transform:
            image = self.transform(image)
        
        return image, label

# Usage example
# dataset = CustomDataset('data.csv', transform=transforms.ToTensor())

2. Image and Text Dataset Examples

python
class ImageDataset(Dataset):
    def __init__(self, root_dir, annotations_file, transform=None):
        self.root_dir = root_dir
        self.annotations = pd.read_csv(annotations_file)
        self.transform = transform
    
    def __len__(self):
        return len(self.annotations)
    
    def __getitem__(self, idx):
        img_path = os.path.join(self.root_dir, self.annotations.iloc[idx, 0])
        image = Image.open(img_path).convert('RGB')
        label = self.annotations.iloc[idx, 1]
        
        if self.transform:
            image = self.transform(image)
        
        return image, label

# Text dataset example
class TextDataset(Dataset):
    def __init__(self, texts, labels, tokenizer, max_length=128):
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_length = max_length
    
    def __len__(self):
        return len(self.texts)
    
    def __getitem__(self, idx):
        text = str(self.texts[idx])
        label = self.labels[idx]
        
        # Text encoding
        encoding = self.tokenizer(
            text,
            truncation=True,
            padding='max_length',
            max_length=self.max_length,
            return_tensors='pt'
        )
        
        return {
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'label': torch.tensor(label, dtype=torch.long)
        }

3. In-Memory Dataset

python
class MemoryDataset(Dataset):
    def __init__(self, data, labels, transform=None):
        """
        In-memory dataset (suitable for small datasets)
        """
        self.data = data
        self.labels = labels
        self.transform = transform
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        sample = self.data[idx]
        label = self.labels[idx]
        
        if self.transform:
            sample = self.transform(sample)
        
        return sample, label

# Create example data
data = torch.randn(1000, 3, 32, 32)  # 1000 32x32 RGB images
labels = torch.randint(0, 10, (1000,))  # 10 class labels

dataset = MemoryDataset(data, labels)

DataLoader

1. Basic Usage

python
from torch.utils.data import DataLoader

# Create data loader
dataloader = DataLoader(
    dataset,
    batch_size=32,      # Batch size
    shuffle=True,       # Whether to shuffle data
    num_workers=4,      # Number of parallel loading processes
    pin_memory=True,    # Whether to pin memory for CUDA
    drop_last=True      # Whether to drop last incomplete batch
)

# Iterate through data
for batch_idx, (data, target) in enumerate(dataloader):
    print(f"Batch {batch_idx}: data shape {data.shape}, label shape {target.shape}")
    if batch_idx >= 2:  # Only show first 3 batches
        break

2. Custom Collate Function

python
def custom_collate_fn(batch):
    """
    Custom batch collation function
    """
    # Separate data and labels
    data = [item[0] for item in batch]
    labels = [item[1] for item in batch]
    
    # Handle variable-length sequences
    # Assume data is variable-length sequences
    lengths = [len(seq) for seq in data]
    max_length = max(lengths)
    
    # Pad sequences
    padded_data = []
    for seq in data:
        padded = torch.zeros(max_length, seq.size(-1))
        padded[:len(seq)] = seq
        padded_data.append(padded)
    
    return torch.stack(padded_data), torch.tensor(labels), torch.tensor(lengths)

# Use custom collate function
dataloader = DataLoader(
    dataset,
    batch_size=32,
    collate_fn=custom_collate_fn
)

3. Distributed Data Loading

python
from torch.utils.data.distributed import DistributedSampler

# Data loading during distributed training
def create_distributed_dataloader(dataset, batch_size, world_size, rank):
    sampler = DistributedSampler(
        dataset,
        num_replicas=world_size,
        rank=rank,
        shuffle=True
    )
    
    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        sampler=sampler,
        num_workers=4,
        pin_memory=True
    )
    
    return dataloader, sampler
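
One easy-to-miss detail: DistributedSampler shuffles deterministically per epoch, so call set_epoch() at the start of every epoch, otherwise each epoch sees the same sample order. A sketch of the training loop (num_epochs is assumed to be defined):

python
# dataloader, sampler = create_distributed_dataloader(dataset, 32, world_size, rank)
for epoch in range(num_epochs):
    sampler.set_epoch(epoch)  # re-seeds the per-epoch shuffle
    for data, target in dataloader:
        pass  # forward/backward/step as usual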

Data Transforms

1. Image Transforms

python
import torchvision.transforms as transforms
from torchvision.transforms import functional as F

# Basic transforms
transform = transforms.Compose([
    transforms.Resize((224, 224)),          # Resize
    transforms.RandomHorizontalFlip(0.5),   # Random horizontal flip
    transforms.RandomRotation(10),          # Random rotation
    transforms.ColorJitter(                 # Color jitter
        brightness=0.2,
        contrast=0.2,
        saturation=0.2,
        hue=0.1
    ),
    transforms.ToTensor(),                  # Convert to tensor
    transforms.Normalize(                   # Normalize
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    )
])

# Advanced transforms
advanced_transform = transforms.Compose([
    transforms.RandomResizedCrop(224, scale=(0.8, 1.0)),
    transforms.RandomApply([
        transforms.GaussianBlur(kernel_size=3)
    ], p=0.3),
    transforms.RandomGrayscale(p=0.1),
    transforms.ToTensor(),
    transforms.RandomErasing(p=0.2)  # Random erasing
])

2. Custom Transforms

python
class AddGaussianNoise:
    """Add Gaussian noise"""
    def __init__(self, mean=0., std=1.):
        self.std = std
        self.mean = mean
        
    def __call__(self, tensor):
        return tensor + torch.randn(tensor.size()) * self.std + self.mean
    
    def __repr__(self):
        return self.__class__.__name__ + f'(mean={self.mean}, std={self.std})'

class Cutout:
    """Random occlusion"""
    def __init__(self, n_holes, length):
        self.n_holes = n_holes
        self.length = length

    def __call__(self, img):
        h, w = img.size(1), img.size(2)
        mask = np.ones((h, w), np.float32)

        for n in range(self.n_holes):
            y = np.random.randint(h)
            x = np.random.randint(w)

            y1 = np.clip(y - self.length // 2, 0, h)
            y2 = np.clip(y + self.length // 2, 0, h)
            x1 = np.clip(x - self.length // 2, 0, w)
            x2 = np.clip(x + self.length // 2, 0, w)

            mask[y1: y2, x1: x2] = 0.

        mask = torch.from_numpy(mask)
        mask = mask.expand_as(img)
        img = img * mask

        return img

# Use custom transforms
custom_transform = transforms.Compose([
    transforms.ToTensor(),
    AddGaussianNoise(0., 0.1),
    Cutout(n_holes=1, length=16)
])

3. Data Augmentation Strategy

python
# Training data augmentation
train_transform = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(15),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Validation data processing
val_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Test-time augmentation (TTA)
tta_transforms = [
    transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ]),
    transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.RandomHorizontalFlip(p=1.0),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
]
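
At inference time, each transform in tta_transforms is applied to the same input and the model's predictions are averaged. A minimal sketch (model and pil_image are assumed to exist):

python
@torch.no_grad()
def predict_tta(model, pil_image, tta_transforms, device='cpu'):
    """Average softmax outputs over all TTA views of one image."""
    model.eval()
    probs = []
    for tta in tta_transforms:
        x = tta(pil_image).unsqueeze(0).to(device)  # add a batch dimension
        probs.append(torch.softmax(model(x), dim=1))
    return torch.stack(probs).mean(dim=0)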

Built-in Datasets

1. Computer Vision Datasets

python
import torchvision.datasets as datasets

# CIFAR-10
cifar10_train = datasets.CIFAR10(
    root='./data',
    train=True,
    download=True,
    transform=train_transform
)

cifar10_test = datasets.CIFAR10(
    root='./data',
    train=False,
    download=True,
    transform=val_transform
)

# ImageNet
imagenet_train = datasets.ImageNet(
    root='./data/imagenet',
    split='train',
    transform=train_transform
)

# MNIST
mnist_train = datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])
)

2. Natural Language Processing Datasets

python
import torchtext.datasets as text_datasets
from torchtext.data.utils import get_tokenizer

# IMDB movie review dataset
train_iter, test_iter = text_datasets.IMDB(split=('train', 'test'))

# Process text data
tokenizer = get_tokenizer('basic_english')

def process_text(text_iter):
    data = []
    for label, text in text_iter:
        tokens = tokenizer(text)
        data.append((tokens, label))
    return data

train_data = process_text(train_iter)
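
To feed the tokens to a model they still need to be mapped to integer ids. torchtext provides build_vocab_from_iterator for this (a sketch; note that torchtext has since been deprecated upstream, so pin a compatible version):

python
from torchtext.vocab import build_vocab_from_iterator

def yield_tokens(data):
    for tokens, _ in data:
        yield tokens

vocab = build_vocab_from_iterator(yield_tokens(train_data), specials=['<unk>'])
vocab.set_default_index(vocab['<unk>'])  # map unknown tokens to <unk>

# Numericalize the first example
ids = torch.tensor(vocab(train_data[0][0]), dtype=torch.long)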

Data Preprocessing Techniques

1. Data Standardization

python
def compute_mean_std(dataset):
    """Compute mean and standard deviation of dataset"""
    dataloader = DataLoader(dataset, batch_size=100, shuffle=False)
    
    mean = torch.zeros(3)
    std = torch.zeros(3)
    total_samples = 0
    
    for data, _ in dataloader:
        batch_samples = data.size(0)
        data = data.view(batch_samples, data.size(1), -1)
        mean += data.mean(2).sum(0)
        std += data.std(2).sum(0)
        total_samples += batch_samples
    
    mean /= total_samples
    std /= total_samples
    
    return mean, std

# Usage example
# mean, std = compute_mean_std(dataset)
# print(f"Mean: {mean}, Std: {std}")

2. Data Balancing

python
from torch.utils.data import WeightedRandomSampler
from collections import Counter

def create_balanced_sampler(dataset):
    """Create balanced sampler"""
    # Count samples per class (note: this indexes every sample, which can be
    # slow for image datasets; use e.g. dataset.targets directly if available)
    labels = [dataset[i][1] for i in range(len(dataset))]
    class_counts = Counter(labels)
    
    # Compute weights
    total_samples = len(labels)
    class_weights = {cls: total_samples / count for cls, count in class_counts.items()}
    
    # Assign weight to each sample
    sample_weights = [class_weights[label] for label in labels]
    
    # Create sampler
    sampler = WeightedRandomSampler(
        weights=sample_weights,
        num_samples=len(sample_weights),
        replacement=True
    )
    
    return sampler

# Use balanced sampler
# sampler = create_balanced_sampler(dataset)
# dataloader = DataLoader(dataset, batch_size=32, sampler=sampler)
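
Note that sampler is mutually exclusive with shuffle=True: the sampler itself randomizes the order, and DataLoader raises an error if both are given.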

3. Data Splitting

python
from torch.utils.data import random_split

def split_dataset(dataset, train_ratio=0.8, val_ratio=0.1, test_ratio=0.1):
    """Split dataset"""
    # Use a tolerance: e.g. 0.8 + 0.1 + 0.1 != 1.0 exactly in floating point
    assert abs(train_ratio + val_ratio + test_ratio - 1.0) < 1e-9, "ratios must sum to 1"
    
    total_size = len(dataset)
    train_size = int(train_ratio * total_size)
    val_size = int(val_ratio * total_size)
    test_size = total_size - train_size - val_size
    
    train_dataset, val_dataset, test_dataset = random_split(
        dataset, [train_size, val_size, test_size]
    )
    
    return train_dataset, val_dataset, test_dataset

# Usage example
# train_data, val_data, test_data = split_dataset(dataset)
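
random_split also accepts a torch.Generator, which is worth passing so the split is reproducible across runs (the lengths below assume the 1,000-sample dataset from earlier):

python
g = torch.Generator().manual_seed(42)
train_data, val_data, test_data = random_split(dataset, [800, 100, 100], generator=g)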

Advanced Data Processing

1. Multi-Modal Data Processing

python
class MultiModalDataset(Dataset):
    def __init__(self, image_paths, texts, labels, image_transform=None, text_tokenizer=None):
        self.image_paths = image_paths
        self.texts = texts
        self.labels = labels
        self.image_transform = image_transform
        self.text_tokenizer = text_tokenizer
    
    def __len__(self):
        return len(self.labels)
    
    def __getitem__(self, idx):
        # Load image
        image = Image.open(self.image_paths[idx]).convert('RGB')
        if self.image_transform:
            image = self.image_transform(image)
        
        # Process text
        text = self.texts[idx]
        if self.text_tokenizer:
            text_tokens = self.text_tokenizer(text)
        else:
            text_tokens = text
        
        label = self.labels[idx]
        
        return {
            'image': image,
            'text': text_tokens,
            'label': label
        }
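
Because __getitem__ returns a dict, the default collate function batches each key separately: tensors are stacked, plain strings are gathered into a list, and integer labels become a tensor. A usage sketch with hypothetical file paths and texts:

python
# Hypothetical inputs for illustration
image_paths = ['img_0.jpg', 'img_1.jpg']
texts = ['a red car', 'a black cat']
labels = [0, 1]

mm_dataset = MultiModalDataset(
    image_paths, texts, labels,
    image_transform=transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor()
    ])
)
mm_loader = DataLoader(mm_dataset, batch_size=2)

batch = next(iter(mm_loader))
# batch['image']: tensor of shape [2, 3, 224, 224]
# batch['text']:  list of 2 strings (no tokenizer supplied)
# batch['label']: tensor of shape [2]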

2. Online Data Augmentation

python
class OnlineAugmentation:
    def __init__(self, transforms_list, probabilities):
        self.transforms = transforms_list
        self.probs = probabilities
    
    def __call__(self, image):
        for transform, prob in zip(self.transforms, self.probs):
            if torch.rand(1) < prob:
                image = transform(image)
        return image

# Use online augmentation
online_aug = OnlineAugmentation(
    transforms_list=[
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(10),
        transforms.ColorJitter(0.2, 0.2, 0.2, 0.1)
    ],
    probabilities=[0.5, 0.3, 0.4]
)

3. Caching Mechanism

python
class CachedDataset(Dataset):
    def __init__(self, dataset, cache_size=1000):
        self.dataset = dataset
        self.cache = {}
        self.cache_size = cache_size
        self.access_count = {}
    
    def __len__(self):
        return len(self.dataset)
    
    def __getitem__(self, idx):
        if idx in self.cache:
            self.access_count[idx] += 1
            return self.cache[idx]
        
        # Load data
        data = self.dataset[idx]
        
        # Cache management
        if len(self.cache) >= self.cache_size:
            # Evict the least frequently used item (LFU, based on access counts)
            lfu_idx = min(self.access_count, key=self.access_count.get)
            del self.cache[lfu_idx]
            del self.access_count[lfu_idx]
        
        # Add to cache
        self.cache[idx] = data
        self.access_count[idx] = 1
        
        return data
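
One caveat: with num_workers > 0 every worker process gets its own copy of the cache, so this wrapper pays off mainly with num_workers=0 or when __getitem__ is very expensive. Usage:

python
cached_dataset = CachedDataset(dataset, cache_size=500)
dataloader = DataLoader(cached_dataset, batch_size=32, num_workers=0)  # cache stays in-process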

Performance Optimization

1. Data Loading Optimization

python
import os

# Optimize data loading performance
def create_optimized_dataloader(dataset, batch_size, num_workers=None):
    if num_workers is None:
        num_workers = min(8, os.cpu_count() or 1)  # os.cpu_count() may return None
    
    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=torch.cuda.is_available(),
        persistent_workers=True,  # Keep workers alive between epochs
        prefetch_factor=2,        # Batches prefetched per worker
    )
    
    return dataloader

2. Memory Mapping

python
import mmap

class MemoryMappedDataset(Dataset):
    def __init__(self, data_file, index_file):
        # Use memory mapping to read large files
        self.data_file = open(data_file, 'rb')
        self.data_mmap = mmap.mmap(self.data_file.fileno(), 0, access=mmap.ACCESS_READ)
        
        # Load indices
        with open(index_file, 'r') as f:
            self.indices = [int(line.strip()) for line in f]
    
    def __len__(self):
        return len(self.indices)
    
    def __getitem__(self, idx):
        # Read the raw bytes between this record's offset and the next one
        start = self.indices[idx]
        end = self.indices[idx + 1] if idx + 1 < len(self.indices) else len(self.data_mmap)
        raw_bytes = self.data_mmap[start:end]
        # Parsing raw_bytes into tensors depends on the file format
        return raw_bytes
    
    def __del__(self):
        if hasattr(self, 'data_mmap'):
            self.data_mmap.close()
        if hasattr(self, 'data_file'):
            self.data_file.close()

Practical Application Example

1. Complete Image Classification Data Pipeline

python
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import torchvision.datasets as datasets

def create_image_dataloaders(data_dir, batch_size=32, num_workers=4):
    # Data transforms
    train_transform = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(0.2, 0.2, 0.2, 0.1),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
    
    val_transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
    
    # Create datasets
    train_dataset = datasets.ImageFolder(
        root=f'{data_dir}/train',
        transform=train_transform
    )
    
    val_dataset = datasets.ImageFolder(
        root=f'{data_dir}/val',
        transform=val_transform
    )
    
    # Create data loaders
    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=True
    )
    
    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=True
    )
    
    return train_loader, val_loader, len(train_dataset.classes)

# Usage example
# train_loader, val_loader, num_classes = create_image_dataloaders('./data')

Summary

Data processing is the foundation of deep learning projects, and mastering PyTorch's data processing tools is crucial:

  1. Dataset Class: Learn to create custom datasets and handle different types of data
  2. DataLoader: Master batch loading, parallel processing, data shuffling, etc.
  3. Data Transforms: Proficiently use built-in and custom transforms for data augmentation
  4. Performance Optimization: Understand caching, memory mapping, parallel loading, etc.
  5. Practical Applications: Be able to build complete data processing pipelines

Good data processing not only improves model performance but also significantly speeds up the training process!
