PyTorch Data Processing
Data Processing Overview
In deep learning, data processing is a crucial step. PyTorch provides powerful data processing tools, mainly including:
- torch.utils.data.Dataset: dataset abstraction class
- torch.utils.data.DataLoader: batch data loader
- torchvision.transforms: data transformation tools
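These three pieces compose a standard pipeline: a Dataset yields individual samples, transforms preprocess each sample, and a DataLoader batches them. A minimal sketch with toy tensors (the shapes and TensorDataset usage here are purely illustrative):

```python
import torch
from torch.utils.data import TensorDataset, DataLoader

# Toy data: 100 feature vectors with integer class labels
features = torch.randn(100, 16)
labels = torch.randint(0, 4, (100,))

dataset = TensorDataset(features, labels)   # wraps tensors as a Dataset
loader = DataLoader(dataset, batch_size=10, shuffle=True)

for x, y in loader:   # x: (10, 16), y: (10,)
    pass              # a training step would go here
```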
Dataset Class
1. Custom Dataset
```python
import torch
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import numpy as np
from PIL import Image
import os

class CustomDataset(Dataset):
    def __init__(self, data_file, transform=None):
        """
        Custom dataset

        Args:
            data_file: Data file path
            transform: Data transformation
        """
        self.data = pd.read_csv(data_file)
        self.transform = transform

    def __len__(self):
        """Return dataset size"""
        return len(self.data)

    def __getitem__(self, idx):
        """Get a single sample"""
        if torch.is_tensor(idx):
            idx = idx.tolist()

        # Get data
        sample = self.data.iloc[idx]
        image_path = sample['image_path']
        label = sample['label']

        # Load image
        image = Image.open(image_path)

        # Apply transformation
        if self.transform:
            image = self.transform(image)

        return image, label

# Usage example
# dataset = CustomDataset('data.csv', transform=transforms.ToTensor())
```
2. Image Dataset Example
```python
class ImageDataset(Dataset):
    def __init__(self, root_dir, annotations_file, transform=None):
        self.root_dir = root_dir
        self.annotations = pd.read_csv(annotations_file)
        self.transform = transform

    def __len__(self):
        return len(self.annotations)

    def __getitem__(self, idx):
        img_path = os.path.join(self.root_dir, self.annotations.iloc[idx, 0])
        image = Image.open(img_path).convert('RGB')
        label = self.annotations.iloc[idx, 1]

        if self.transform:
            image = self.transform(image)

        return image, label

# Text dataset example
class TextDataset(Dataset):
    def __init__(self, texts, labels, tokenizer, max_length=128):
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        text = str(self.texts[idx])
        label = self.labels[idx]

        # Text encoding
        encoding = self.tokenizer(
            text,
            truncation=True,
            padding='max_length',
            max_length=self.max_length,
            return_tensors='pt'
        )

        return {
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'label': torch.tensor(label, dtype=torch.long)
        }
```
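TextDataset assumes a tokenizer with the Hugging Face call signature (keyword arguments truncation, padding, max_length, return_tensors). A hypothetical usage sketch, assuming the transformers package is installed; the checkpoint name is only an example:

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
dataset = TextDataset(texts, labels, tokenizer, max_length=128)
sample = dataset[0]  # dict with 'input_ids', 'attention_mask', 'label'
```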
3. In-Memory Dataset
```python
class MemoryDataset(Dataset):
    def __init__(self, data, labels, transform=None):
        """
        In-memory dataset (suitable for small datasets)
        """
        self.data = data
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        sample = self.data[idx]
        label = self.labels[idx]

        if self.transform:
            sample = self.transform(sample)

        return sample, label

# Create example data
data = torch.randn(1000, 3, 32, 32)     # 1000 32x32 RGB images
labels = torch.randint(0, 10, (1000,))  # labels for 10 classes
dataset = MemoryDataset(data, labels)
```
DataLoader
1. Basic Usage
```python
from torch.utils.data import DataLoader

# Create data loader
dataloader = DataLoader(
    dataset,
    batch_size=32,     # Batch size
    shuffle=True,      # Whether to shuffle data
    num_workers=4,     # Number of parallel loading processes
    pin_memory=True,   # Whether to pin memory for CUDA
    drop_last=True     # Whether to drop last incomplete batch
)

# Iterate through data
for batch_idx, (data, target) in enumerate(dataloader):
    print(f"Batch {batch_idx}: data shape {data.shape}, label shape {target.shape}")
    if batch_idx >= 2:  # Only show first 3 batches
        break
```
2. Custom Collate Function
```python
def custom_collate_fn(batch):
    """
    Custom batch collation function
    """
    # Separate data and labels
    data = [item[0] for item in batch]
    labels = [item[1] for item in batch]

    # Handle variable-length sequences
    # (assumes each sample is a 2D tensor of shape (seq_len, features))
    lengths = [len(seq) for seq in data]
    max_length = max(lengths)

    # Zero-pad each sequence up to the longest one in the batch
    padded_data = []
    for seq in data:
        padded = torch.zeros(max_length, seq.size(-1))
        padded[:len(seq)] = seq
        padded_data.append(padded)

    return torch.stack(padded_data), torch.tensor(labels), torch.tensor(lengths)

# Use custom collate function
dataloader = DataLoader(
    dataset,
    batch_size=32,
    collate_fn=custom_collate_fn
)
```
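The manual padding loop above can also be written with PyTorch's built-in helper. A sketch of an equivalent collate function; pad_sequence with batch_first=True yields the same (batch, max_len, features) layout:

```python
from torch.nn.utils.rnn import pad_sequence

def pad_collate_fn(batch):
    data = [item[0] for item in batch]
    labels = torch.tensor([item[1] for item in batch])
    lengths = torch.tensor([len(seq) for seq in data])
    # pad_sequence zero-pads to the longest sequence in the batch
    return pad_sequence(data, batch_first=True), labels, lengths
```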
3. Distributed Data Loading
```python
from torch.utils.data.distributed import DistributedSampler

# Data loading during distributed training
def create_distributed_dataloader(dataset, batch_size, world_size, rank):
    sampler = DistributedSampler(
        dataset,
        num_replicas=world_size,
        rank=rank,
        shuffle=True
    )

    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        sampler=sampler,
        num_workers=4,
        pin_memory=True
    )

    return dataloader, sampler
```
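One detail worth remembering: DistributedSampler.set_epoch should be called at the start of every epoch, otherwise each replica reuses the same shuffle order across epochs. A usage sketch (world_size, rank, and num_epochs are assumed to come from the surrounding training script):

```python
dataloader, sampler = create_distributed_dataloader(dataset, 32, world_size, rank)

for epoch in range(num_epochs):
    sampler.set_epoch(epoch)  # reseed the shuffle for this epoch
    for data, target in dataloader:
        pass  # forward/backward/optimizer step goes here
```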
Data Transforms (Transforms)
1. Image Transforms
```python
import torchvision.transforms as transforms
from torchvision.transforms import functional as F

# Basic transforms
transform = transforms.Compose([
    transforms.Resize((224, 224)),         # Resize
    transforms.RandomHorizontalFlip(0.5),  # Random horizontal flip
    transforms.RandomRotation(10),         # Random rotation
    transforms.ColorJitter(                # Color jitter
        brightness=0.2,
        contrast=0.2,
        saturation=0.2,
        hue=0.1
    ),
    transforms.ToTensor(),                 # Convert to tensor
    transforms.Normalize(                  # Normalize
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    )
])

# Advanced transforms
advanced_transform = transforms.Compose([
    transforms.RandomResizedCrop(224, scale=(0.8, 1.0)),
    transforms.RandomApply([
        transforms.GaussianBlur(kernel_size=3)
    ], p=0.3),
    transforms.RandomGrayscale(p=0.1),
    transforms.ToTensor(),
    transforms.RandomErasing(p=0.2)        # Random erasing (operates on tensors)
])
```
2. Custom Transforms
```python
class AddGaussianNoise:
    """Add Gaussian noise"""
    def __init__(self, mean=0., std=1.):
        self.std = std
        self.mean = mean

    def __call__(self, tensor):
        return tensor + torch.randn(tensor.size()) * self.std + self.mean

    def __repr__(self):
        return self.__class__.__name__ + f'(mean={self.mean}, std={self.std})'

class Cutout:
    """Random occlusion"""
    def __init__(self, n_holes, length):
        self.n_holes = n_holes
        self.length = length

    def __call__(self, img):
        # Expects a tensor of shape (C, H, W), i.e. apply after ToTensor
        h, w = img.size(1), img.size(2)
        mask = np.ones((h, w), np.float32)

        for n in range(self.n_holes):
            y = np.random.randint(h)
            x = np.random.randint(w)

            y1 = np.clip(y - self.length // 2, 0, h)
            y2 = np.clip(y + self.length // 2, 0, h)
            x1 = np.clip(x - self.length // 2, 0, w)
            x2 = np.clip(x + self.length // 2, 0, w)

            mask[y1:y2, x1:x2] = 0.

        mask = torch.from_numpy(mask)
        mask = mask.expand_as(img)
        img = img * mask
        return img

# Use custom transforms
custom_transform = transforms.Compose([
    transforms.ToTensor(),
    AddGaussianNoise(0., 0.1),
    Cutout(n_holes=1, length=16)
])
```
3. Data Augmentation Strategy
```python
# Training data augmentation
train_transform = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(15),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Validation data processing
val_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Test-time augmentation (TTA)
tta_transforms = [
    transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ]),
    transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.RandomHorizontalFlip(p=1.0),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
]
```
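At inference time each TTA variant is applied to the same input and the model outputs are averaged. A minimal sketch, assuming a trained model and a PIL input image (averaging softmax probabilities is one common choice, not the only one):

```python
def tta_predict(model, pil_image, tta_transforms):
    model.eval()
    with torch.no_grad():
        probs = [
            torch.softmax(model(t(pil_image).unsqueeze(0)), dim=1)
            for t in tta_transforms
        ]
    return torch.stack(probs).mean(dim=0)  # averaged class probabilities
```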
Built-in Datasets
1. Computer Vision Datasets
```python
import torchvision.datasets as datasets

# CIFAR-10
cifar10_train = datasets.CIFAR10(
    root='./data',
    train=True,
    download=True,
    transform=train_transform
)

cifar10_test = datasets.CIFAR10(
    root='./data',
    train=False,
    download=True,
    transform=val_transform
)

# ImageNet (note: torchvision cannot download ImageNet automatically;
# the dataset archives must already be present under root)
imagenet_train = datasets.ImageNet(
    root='./data/imagenet',
    split='train',
    transform=train_transform
)

# MNIST
mnist_train = datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])
)
```
2. Natural Language Processing Datasets
```python
import torchtext.datasets as text_datasets
from torchtext.data.utils import get_tokenizer

# IMDB movie review dataset
train_iter, test_iter = text_datasets.IMDB(split=('train', 'test'))

# Process text data
tokenizer = get_tokenizer('basic_english')

def process_text(text_iter):
    data = []
    for label, text in text_iter:
        tokens = tokenizer(text)
        data.append((tokens, label))
    return data

train_data = process_text(train_iter)
```
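A common next step is building a vocabulary from the tokenized reviews. A plain-Python sketch; the special tokens and the frequency cutoff of 2 are arbitrary choices:

```python
from collections import Counter

counter = Counter()
for tokens, _ in train_data:
    counter.update(tokens)

# Indices 0 and 1 reserved for padding and unknown tokens
vocab = {'<pad>': 0, '<unk>': 1}
for token, freq in counter.most_common():
    if freq >= 2:  # drop rare tokens
        vocab[token] = len(vocab)

def encode(tokens):
    return [vocab.get(t, vocab['<unk>']) for t in tokens]
```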
Data Preprocessing Techniques
1. Data Standardization
```python
def compute_mean_std(dataset):
    """Compute the per-channel mean and standard deviation of a dataset"""
    dataloader = DataLoader(dataset, batch_size=100, shuffle=False)

    mean = torch.zeros(3)
    std = torch.zeros(3)
    total_samples = 0

    for data, _ in dataloader:
        batch_samples = data.size(0)
        data = data.view(batch_samples, data.size(1), -1)
        mean += data.mean(2).sum(0)
        # Note: this averages the per-image std, which only approximates the
        # true dataset-wide std; it is a common, usually adequate shortcut
        std += data.std(2).sum(0)
        total_samples += batch_samples

    mean /= total_samples
    std /= total_samples
    return mean, std

# Usage example
# mean, std = compute_mean_std(dataset)
# print(f"Mean: {mean}, Std: {std}")
```
2. Data Balancing
```python
from torch.utils.data import WeightedRandomSampler
from collections import Counter

def create_balanced_sampler(dataset):
    """Create a class-balanced sampler"""
    # Count samples per class
    labels = [dataset[i][1] for i in range(len(dataset))]
    class_counts = Counter(labels)

    # Compute per-class weights (inverse frequency)
    total_samples = len(labels)
    class_weights = {cls: total_samples / count for cls, count in class_counts.items()}

    # Assign a weight to each sample
    sample_weights = [class_weights[label] for label in labels]

    # Create sampler
    sampler = WeightedRandomSampler(
        weights=sample_weights,
        num_samples=len(sample_weights),
        replacement=True
    )
    return sampler

# Use balanced sampler (note: sampler and shuffle=True are mutually exclusive)
# sampler = create_balanced_sampler(dataset)
# dataloader = DataLoader(dataset, batch_size=32, sampler=sampler)
```
3. Data Splitting
```python
from torch.utils.data import random_split

def split_dataset(dataset, train_ratio=0.8, val_ratio=0.1, test_ratio=0.1):
    """Split a dataset into train/val/test subsets"""
    # Compare with a tolerance rather than == to avoid float rounding issues
    assert abs(train_ratio + val_ratio + test_ratio - 1.0) < 1e-6

    total_size = len(dataset)
    train_size = int(train_ratio * total_size)
    val_size = int(val_ratio * total_size)
    test_size = total_size - train_size - val_size

    train_dataset, val_dataset, test_dataset = random_split(
        dataset, [train_size, val_size, test_size]
    )
    return train_dataset, val_dataset, test_dataset

# Usage example
# train_data, val_data, test_data = split_dataset(dataset)
```
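random_split also accepts a generator argument, which makes the split reproducible across runs. A short sketch; the lengths and seed value are arbitrary:

```python
train_data, val_data, test_data = random_split(
    dataset, [800, 100, 100],
    generator=torch.Generator().manual_seed(42)
)
```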
Advanced Data Processing
1. Multi-Modal Data Processing
```python
class MultiModalDataset(Dataset):
    def __init__(self, image_paths, texts, labels, image_transform=None, text_tokenizer=None):
        self.image_paths = image_paths
        self.texts = texts
        self.labels = labels
        self.image_transform = image_transform
        self.text_tokenizer = text_tokenizer

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        # Load image
        image = Image.open(self.image_paths[idx]).convert('RGB')
        if self.image_transform:
            image = self.image_transform(image)

        # Process text
        text = self.texts[idx]
        if self.text_tokenizer:
            text_tokens = self.text_tokenizer(text)
        else:
            text_tokens = text

        label = self.labels[idx]

        return {
            'image': image,
            'text': text_tokens,
            'label': label
        }
```
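Because each sample is a dict, the DataLoader's default collate function batches it field by field: tensor fields are stacked, and string fields come back as lists. A hypothetical usage sketch (assumes image_transform ends with ToTensor so images stack cleanly):

```python
loader = DataLoader(multimodal_dataset, batch_size=8, shuffle=True)

for batch in loader:
    images = batch['image']  # stacked tensor, e.g. (8, 3, 224, 224)
    texts = batch['text']    # list of 8 strings if no tokenizer was given
    labels = batch['label']
```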
2. Online Data Augmentation
```python
class OnlineAugmentation:
    def __init__(self, transforms_list, probabilities):
        self.transforms = transforms_list
        self.probs = probabilities

    def __call__(self, image):
        for transform, prob in zip(self.transforms, self.probs):
            if torch.rand(1) < prob:
                image = transform(image)
        return image

# Use online augmentation
online_aug = OnlineAugmentation(
    transforms_list=[
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(10),
        transforms.ColorJitter(0.2, 0.2, 0.2, 0.1)
    ],
    probabilities=[0.5, 0.3, 0.4]
)
```
3. Caching Mechanism
```python
class CachedDataset(Dataset):
    def __init__(self, dataset, cache_size=1000):
        self.dataset = dataset
        self.cache = {}
        self.cache_size = cache_size
        self.access_count = {}

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        if idx in self.cache:
            self.access_count[idx] += 1
            return self.cache[idx]

        # Load data
        data = self.dataset[idx]

        # Cache management: evict the least-accessed item when full
        if len(self.cache) >= self.cache_size:
            lru_idx = min(self.access_count, key=self.access_count.get)
            del self.cache[lru_idx]
            del self.access_count[lru_idx]

        # Add to cache
        self.cache[idx] = data
        self.access_count[idx] = 1
        return data
```
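A hypothetical usage sketch. Two caveats: the eviction policy above is frequency-based (closer to LFU than true LRU), and each DataLoader worker process holds its own copy of the cache, so the wrapper is most effective with num_workers=0:

```python
cached_dataset = CachedDataset(dataset, cache_size=500)
loader = DataLoader(cached_dataset, batch_size=32, num_workers=0)
```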
Performance Optimization
1. Data Loading Optimization
```python
# Optimize data loading performance
def create_optimized_dataloader(dataset, batch_size, num_workers=None):
    if num_workers is None:
        num_workers = min(8, os.cpu_count())

    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=torch.cuda.is_available(),
        persistent_workers=True,  # Keep worker processes alive between epochs
        prefetch_factor=2,        # Batches prefetched per worker
    )
    return dataloader
```
2. Memory Mapping
```python
import mmap

class MemoryMappedDataset(Dataset):
    def __init__(self, data_file, index_file):
        # Memory-map the large data file for random access
        # without loading it all into RAM
        self.data_file = open(data_file, 'rb')
        self.data_mmap = mmap.mmap(self.data_file.fileno(), 0, access=mmap.ACCESS_READ)

        # Load the byte offset of each record
        with open(index_file, 'r') as f:
            self.indices = [int(line.strip()) for line in f]

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx):
        offset = self.indices[idx]
        # Read data from the memory map
        self.data_mmap.seek(offset)
        # Read-and-parse logic depends on the file format
        # ...
        pass

    def __del__(self):
        if hasattr(self, 'data_mmap'):
            self.data_mmap.close()
        if hasattr(self, 'data_file'):
            self.data_file.close()
```
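The parsing step is left open above because it depends entirely on the file format. As one illustration, a hypothetical subclass assuming fixed-length float32 records (RECORD_SIZE is a made-up, format-specific byte count):

```python
import numpy as np
import torch

class FloatRecordDataset(MemoryMappedDataset):
    """Hypothetical concrete variant: fixed-length float32 records."""
    RECORD_SIZE = 4 * 1024  # bytes per record; format-specific assumption

    def __getitem__(self, idx):
        offset = self.indices[idx]
        raw = self.data_mmap[offset: offset + self.RECORD_SIZE]
        # .copy() because arrays built on an mmap buffer are read-only
        return torch.from_numpy(np.frombuffer(raw, dtype=np.float32).copy())
```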
Practical Application Example
1. Complete Image Classification Data Pipeline
```python
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import torchvision.datasets as datasets

def create_image_dataloaders(data_dir, batch_size=32, num_workers=4):
    # Data transforms
    train_transform = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(0.2, 0.2, 0.2, 0.1),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])

    val_transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])

    # Create datasets
    train_dataset = datasets.ImageFolder(
        root=f'{data_dir}/train',
        transform=train_transform
    )

    val_dataset = datasets.ImageFolder(
        root=f'{data_dir}/val',
        transform=val_transform
    )

    # Create data loaders
    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=True
    )

    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=True
    )

    return train_loader, val_loader, len(train_dataset.classes)

# Usage example
# train_loader, val_loader, num_classes = create_image_dataloaders('./data')
```
Summary
Data processing is the foundation of deep learning projects, and mastering PyTorch's data processing tools is crucial:
- Dataset Class: Learn to create custom datasets and handle different types of data
- DataLoader: Master batch loading, parallel processing, data shuffling, etc.
- Data Transforms: Proficiently use built-in and custom transforms for data augmentation
- Performance Optimization: Understand caching, memory mapping, parallel loading, etc.
- Practical Applications: Be able to build complete data processing pipelines
Good data processing not only improves model performance but also significantly speeds up the training process!