#PyTorch Data Processing
#Data Processing Overview
In deep learning, data processing is a crucial step. PyTorch provides powerful data processing tools, mainly including:
- torch.utils.data.Dataset: Dataset abstraction class
- torch.utils.data.DataLoader: Data loader
- torchvision.transforms: Data transformation tools
#Dataset Class
#1. Custom Dataset
import torch
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import numpy as np
from PIL import Image
import os
class CustomDataset(Dataset):
    """Image/label dataset described by a CSV file.

    The CSV is loaded with pandas; each row is one sample and must provide
    an ``image_path`` column and a ``label`` column.
    """

    def __init__(self, data_file, transform=None):
        """Load the sample table.

        Args:
            data_file: Path of the CSV file handed to ``pd.read_csv``.
            transform: Optional callable applied to every opened image.
        """
        self.data = pd.read_csv(data_file)
        self.transform = transform

    def __len__(self):
        """Return the number of rows in the sample table."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for the row at ``idx``."""
        # Tensor indices are converted to plain Python values before indexing.
        position = idx.tolist() if torch.is_tensor(idx) else idx

        row = self.data.iloc[position]
        image_path = row['image_path']
        label = row['label']

        image = Image.open(image_path)
        if self.transform:
            image = self.transform(image)
        return image, label
# Usage example
# dataset = CustomDataset('data.csv', transform=transforms.ToTensor())

# 2. Image Dataset Example
class ImageDataset(Dataset):
    """Images below a root directory, listed by an annotations CSV.

    Column 0 of the CSV is the image path relative to ``root_dir``;
    column 1 is the label.
    """

    def __init__(self, root_dir, annotations_file, transform=None):
        """Store the root directory, read the annotations, keep the transform."""
        self.root_dir = root_dir
        self.annotations = pd.read_csv(annotations_file)
        self.transform = transform

    def __len__(self):
        """Return the number of annotated images."""
        return len(self.annotations)

    def __getitem__(self, idx):
        """Return ``(image, label)`` with the image converted to RGB."""
        relative_path = self.annotations.iloc[idx, 0]
        full_path = os.path.join(self.root_dir, relative_path)
        image = Image.open(full_path).convert('RGB')
        label = self.annotations.iloc[idx, 1]

        if self.transform:
            image = self.transform(image)
        return image, label
# Text dataset example
class TextDataset(Dataset):
    """Text/label pairs that are tokenized when a sample is requested."""

    def __init__(self, texts, labels, tokenizer, max_length=128):
        """Keep the raw texts, labels, tokenizer callable and maximum length."""
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        """Return the number of texts."""
        return len(self.texts)

    def __getitem__(self, idx):
        """Return a dict with ``input_ids``, ``attention_mask`` and ``label``."""
        raw_text = str(self.texts[idx])
        label = self.labels[idx]

        # The tokenizer is asked for fixed-length, truncated PyTorch tensors.
        encoded = self.tokenizer(
            raw_text,
            truncation=True,
            padding='max_length',
            max_length=self.max_length,
            return_tensors='pt',
        )

        input_ids = encoded['input_ids'].flatten()
        attention_mask = encoded['attention_mask'].flatten()
        label_tensor = torch.tensor(label, dtype=torch.long)
        return {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'label': label_tensor,
        }


# 3. In-Memory Dataset
class MemoryDataset(Dataset):
    """Dataset over samples and labels that are already held in memory.

    Suitable for small datasets.
    """

    def __init__(self, data, labels, transform=None):
        """Keep references to the samples, the labels and an optional transform."""
        self.data = data
        self.labels = labels
        self.transform = transform

    def __len__(self):
        """Return the number of samples."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(sample, label)``; the transform, if any, is applied to the sample."""
        item, target = self.data[idx], self.labels[idx]
        if self.transform:
            item = self.transform(item)
        return item, target
# Example data: 1000 random tensors of shape 3x32x32 with labels in [0, 10)
data = torch.randn((1000, 3, 32, 32))
labels = torch.randint(low=0, high=10, size=(1000,))
dataset = MemoryDataset(data=data, labels=labels)

# DataLoader
#1. Basic Usage
from torch.utils.data import DataLoader
# Build a loader over the example dataset
dataloader = DataLoader(
    dataset,
    batch_size=32,     # samples per batch
    shuffle=True,      # reshuffle the data
    num_workers=4,     # parallel loading processes
    pin_memory=True,   # pin host memory for CUDA transfers
    drop_last=True,    # discard a trailing incomplete batch
)

# Print the shapes of the first three batches, then stop
for batch_idx, (data, target) in enumerate(dataloader):
    print(f"Batch {batch_idx}: data shape {data.shape}, label shape {target.shape}")
    if batch_idx >= 2:
        break

# 2. Custom Collate Function
def custom_collate_fn(batch):
    """Collate variable-length sequences into one zero-padded batch.

    Args:
        batch: Non-empty list of samples whose first element is a sequence
            tensor (length along dimension 0) and whose second element is
            the label. All sequences must share dtype and trailing
            dimensions.

    Returns:
        Tuple ``(padded, labels, lengths)``:
            padded: Tensor of shape ``(batch, max_length, *trailing_dims)``
                holding the sequences padded with zeros, in the dtype of
                the input sequences.
            labels: Tensor built from the sample labels.
            lengths: Tensor with the original length of every sequence.

    Raises:
        ValueError: If ``batch`` is empty.
    """
    if not batch:
        raise ValueError("custom_collate_fn received an empty batch")

    # Separate sequences and labels
    sequences = [item[0] for item in batch]
    labels = [item[1] for item in batch]
    lengths = [len(seq) for seq in sequences]

    # pad_sequence keeps dtype/device of the inputs and supports any number
    # of trailing dimensions; a hand-allocated float buffer would silently
    # cast integer token ids to float and mis-shape 1-D sequences.
    padded = torch.nn.utils.rnn.pad_sequence(sequences, batch_first=True)

    return padded, torch.tensor(labels), torch.tensor(lengths)
# Plug the custom collate function into a loader
dataloader = DataLoader(dataset, batch_size=32, collate_fn=custom_collate_fn)

# 3. Distributed Data Loading
from torch.utils.data.distributed import DistributedSampler
# Data loading during distributed training
def create_distributed_dataloader(dataset, batch_size, world_size, rank):
    """Build a DataLoader driven by a shuffling ``DistributedSampler``.

    Args:
        dataset: Dataset to load from.
        batch_size: Batch size passed to the DataLoader.
        world_size: Passed to the sampler as ``num_replicas``.
        rank: Passed to the sampler as ``rank``.

    Returns:
        Tuple ``(dataloader, sampler)``.
    """
    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank, shuffle=True)

    loader_options = {
        'batch_size': batch_size,
        'sampler': sampler,
        'num_workers': 4,
        'pin_memory': True,
    }
    return DataLoader(dataset, **loader_options), sampler


# Data Transforms (Transforms)
#1. Image Transforms
import torchvision.transforms as transforms
from torchvision.transforms import functional as F
# Basic pipeline: resize, light augmentation, tensor conversion, normalization
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomRotation(degrees=10),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Advanced pipeline: random crop, occasional blur / grayscale, random erasing
advanced_transform = transforms.Compose([
    transforms.RandomResizedCrop(224, scale=(0.8, 1.0)),
    transforms.RandomApply([transforms.GaussianBlur(kernel_size=3)], p=0.3),
    transforms.RandomGrayscale(p=0.1),
    transforms.ToTensor(),
    transforms.RandomErasing(p=0.2),  # applied after ToTensor
])

# 2. Custom Transforms
class AddGaussianNoise:
    """Add Gaussian noise to a tensor.

    Calling the transform returns ``tensor + noise * std + mean`` where
    ``noise`` is drawn from ``torch.randn`` with the shape of the input.

    Args:
        mean: Offset added to the scaled noise.
        std: Scale applied to the standard-normal noise.
    """

    def __init__(self, mean=0., std=1.):
        self.std = std
        self.mean = mean

    def __call__(self, tensor):
        """Return a noisy copy of ``tensor``; the input is not modified."""
        # Draw the noise on the input's device; a CPU-only noise tensor
        # cannot be added to a tensor that lives on another device.
        noise = torch.randn(tensor.size(), device=tensor.device)
        return tensor + noise * self.std + self.mean

    def __repr__(self):
        return f'{self.__class__.__name__}(mean={self.mean}, std={self.std})'
class Cutout:
    """Random occlusion: zero out square patches of a tensor image.

    Dimensions 1 and 2 of the input are used as height and width.
    """

    def __init__(self, n_holes, length):
        """Store the number of patches and their side length."""
        self.n_holes = n_holes
        self.length = length

    def __call__(self, img):
        """Return ``img`` multiplied by a mask holding zeros inside the patches."""
        height, width = img.size(1), img.size(2)
        half = self.length // 2
        keep = np.ones((height, width), np.float32)

        for _ in range(self.n_holes):
            # Random patch centre; the patch is clipped at the image border.
            center_y = np.random.randint(height)
            center_x = np.random.randint(width)
            top = np.clip(center_y - half, 0, height)
            bottom = np.clip(center_y + half, 0, height)
            left = np.clip(center_x - half, 0, width)
            right = np.clip(center_x + half, 0, width)
            keep[top:bottom, left:right] = 0.

        # The 2-D mask is expanded to the shape of the image.
        mask = torch.from_numpy(keep).expand_as(img)
        return img * mask
# Chain the custom transforms after the tensor conversion
custom_transform = transforms.Compose([
    transforms.ToTensor(),
    AddGaussianNoise(mean=0., std=0.1),
    Cutout(n_holes=1, length=16),
])

# 3. Data Augmentation Strategy
# Training: random crop / flip / rotation / color jitter, then normalize
train_transform = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(15),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Validation: resize and center crop only, then normalize
val_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Test-time augmentation (TTA): the validation view plus a flipped view
tta_transforms = [
    transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]),
    transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.RandomHorizontalFlip(p=1.0),  # p=1.0: flip is always applied
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]),
]

# Built-in Datasets
#1. Computer Vision Datasets
import torchvision.datasets as datasets
# CIFAR-10: training and test splits, downloaded into ./data
cifar10_train = datasets.CIFAR10(
    root='./data', train=True, download=True, transform=train_transform
)
cifar10_test = datasets.CIFAR10(
    root='./data', train=False, download=True, transform=val_transform
)

# ImageNet: training split read from ./data/imagenet
imagenet_train = datasets.ImageNet(
    root='./data/imagenet', split='train', transform=train_transform
)

# MNIST: single-channel normalization constants
mnist_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])
mnist_train = datasets.MNIST(
    root='./data', train=True, download=True, transform=mnist_transform
)

# 2. Natural Language Processing Datasets
import torchtext.datasets as text_datasets
from torchtext.data.utils import get_tokenizer
# IMDB movie reviews: request the train and test splits together
imdb_splits = ('train', 'test')
train_iter, test_iter = text_datasets.IMDB(split=imdb_splits)

# Tokenizer used by the text processing below
tokenizer = get_tokenizer('basic_english')
def process_text(text_iter):
    """Tokenize every ``(label, text)`` pair yielded by ``text_iter``.

    Returns:
        List of ``(tokens, label)`` tuples in iteration order.
    """
    return [(tokenizer(text), label) for label, text in text_iter]
# Tokenize the training iterator into a list of (tokens, label) pairs
train_data = process_text(train_iter)

# Data Preprocessing Techniques
#1. Data Standardization
def compute_mean_std(dataset):
    """Compute the per-channel mean and standard deviation of a dataset.

    Statistics are pooled over every element of every sample, so ``std`` is
    the population standard deviation of the whole dataset rather than an
    average of per-sample deviations. The number of channels is taken from
    the data instead of being fixed at three.

    Args:
        dataset: Dataset yielding ``(tensor, label)`` samples. Dimension 0
            of each tensor is the channel dimension, and all tensors must
            share one shape so they can be batched.

    Returns:
        Tuple ``(mean, std)`` of float32 tensors with one entry per channel.

    Raises:
        ValueError: If the dataset provides no values to aggregate.
    """
    dataloader = DataLoader(dataset, batch_size=100, shuffle=False)

    channel_sum = None
    channel_sq_sum = None
    values_per_channel = 0

    for data, _ in dataloader:
        # (batch, C, ...) -> (batch, C, values); float64 keeps the running
        # sums accurate over many batches.
        flat = data.reshape(data.size(0), data.size(1), -1).double()
        if channel_sum is None:
            channel_sum = torch.zeros(flat.size(1), dtype=torch.float64)
            channel_sq_sum = torch.zeros(flat.size(1), dtype=torch.float64)
        channel_sum += flat.sum(dim=(0, 2))
        channel_sq_sum += (flat * flat).sum(dim=(0, 2))
        values_per_channel += flat.size(0) * flat.size(2)

    if values_per_channel == 0:
        raise ValueError("cannot compute mean/std of an empty dataset")

    mean = channel_sum / values_per_channel
    # Var[x] = E[x^2] - E[x]^2; the clamp removes tiny negative values
    # caused by floating-point cancellation.
    variance = (channel_sq_sum / values_per_channel - mean * mean).clamp(min=0.0)
    return mean.float(), variance.sqrt().float()
# Usage example
# mean, std = compute_mean_std(dataset)
# print(f"Mean: {mean}, Std: {std}")

# 2. Data Balancing
from torch.utils.data import WeightedRandomSampler
from collections import Counter
def create_balanced_sampler(dataset):
    """Create a sampler that weights every sample inversely to its class size.

    Args:
        dataset: Indexable dataset whose samples carry the class label at
            position 1. Every sample is read once to collect the labels.

    Returns:
        A ``WeightedRandomSampler`` drawing ``len(dataset)`` samples with
        replacement, where each sample's weight is
        ``total_samples / samples_in_its_class``.
    """
    labels = []
    for index in range(len(dataset)):
        label = dataset[index][1]
        # Tensors hash by identity, so equal tensor labels would each be
        # counted as a separate class; unwrap scalar tensors to Python values.
        if torch.is_tensor(label) and label.numel() == 1:
            label = label.item()
        labels.append(label)

    # Count samples per class
    class_counts = Counter(labels)

    # Rarer classes get proportionally larger weights
    total_samples = len(labels)
    class_weights = {cls: total_samples / count for cls, count in class_counts.items()}
    sample_weights = [class_weights[label] for label in labels]

    return WeightedRandomSampler(
        weights=sample_weights,
        num_samples=len(sample_weights),
        replacement=True,
    )
# Use balanced sampler
# sampler = create_balanced_sampler(dataset)
# dataloader = DataLoader(dataset, batch_size=32, sampler=sampler)

# 3. Data Splitting
from torch.utils.data import random_split
def split_dataset(dataset, train_ratio=0.8, val_ratio=0.1, test_ratio=0.1):
    """Randomly split a dataset into train, validation and test subsets.

    Train and validation sizes are ``int(ratio * len(dataset))``; the test
    subset receives all remaining samples, so no sample is dropped by
    rounding.

    Args:
        dataset: Dataset to split.
        train_ratio: Fraction of samples for training.
        val_ratio: Fraction of samples for validation.
        test_ratio: Fraction of samples for testing.

    Returns:
        Tuple ``(train_dataset, val_dataset, test_dataset)``.

    Raises:
        ValueError: If the three ratios do not sum to 1.
    """
    ratio_sum = train_ratio + val_ratio + test_ratio
    # Compare with a tolerance: an exact == check rejects valid inputs such
    # as 0.7 + 0.2 + 0.1, which does not sum to exactly 1.0 in floating point.
    if abs(ratio_sum - 1.0) > 1e-6:
        raise ValueError(
            f"train_ratio + val_ratio + test_ratio must equal 1, got {ratio_sum}"
        )

    total_size = len(dataset)
    train_size = int(train_ratio * total_size)
    val_size = int(val_ratio * total_size)
    test_size = total_size - train_size - val_size

    train_dataset, val_dataset, test_dataset = random_split(
        dataset, [train_size, val_size, test_size]
    )
    return train_dataset, val_dataset, test_dataset
# Usage example
# train_data, val_data, test_data = split_dataset(dataset)

# Advanced Data Processing
#1. Multi-Modal Data Processing
class MultiModalDataset(Dataset):
    """Dataset pairing an image and a text with one label per sample."""

    def __init__(self, image_paths, texts, labels, image_transform=None, text_tokenizer=None):
        """Keep the parallel sequences and the optional per-modality processors."""
        self.image_paths = image_paths
        self.texts = texts
        self.labels = labels
        self.image_transform = image_transform
        self.text_tokenizer = text_tokenizer

    def __len__(self):
        """Return the number of labels."""
        return len(self.labels)

    def __getitem__(self, idx):
        """Return a dict with keys ``image``, ``text`` and ``label``."""
        # Image branch: open as RGB, then apply the optional transform.
        image = Image.open(self.image_paths[idx]).convert('RGB')
        if self.image_transform:
            image = self.image_transform(image)

        # Text branch: tokenize when a tokenizer was given, else pass the raw text.
        text = self.texts[idx]
        text_tokens = self.text_tokenizer(text) if self.text_tokenizer else text

        return {
            'image': image,
            'text': text_tokens,
            'label': self.labels[idx],
        }


# 2. Online Data Augmentation
class OnlineAugmentation:
    """Apply each transform independently with its own probability."""

    def __init__(self, transforms_list, probabilities):
        """Store the transforms and their matching probabilities."""
        self.transforms = transforms_list
        self.probs = probabilities

    def __call__(self, image):
        """Run ``image`` through the randomly selected transforms, in order."""
        result = image
        for step, chance in zip(self.transforms, self.probs):
            # One random draw per transform decides whether it is applied.
            if not torch.rand(1) < chance:
                continue
            result = step(result)
        return result
# Online augmentation: each transform is paired with its probability by position
online_aug = OnlineAugmentation(
    transforms_list=[
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(10),
        transforms.ColorJitter(0.2, 0.2, 0.2, 0.1),
    ],
    probabilities=[0.5, 0.3, 0.4],
)

# 3. Caching Mechanism
class CachedDataset(Dataset):
    """Wrap a dataset with a bounded in-memory cache of loaded samples.

    When the cache is full, the entry with the fewest recorded accesses is
    evicted (least-frequently-used policy).

    Args:
        dataset: Wrapped dataset that performs the actual loading.
        cache_size: Maximum number of cached samples. A value of zero or
            less disables caching.
    """

    def __init__(self, dataset, cache_size=1000):
        self.dataset = dataset
        self.cache = {}
        self.cache_size = cache_size
        self.access_count = {}

    def __len__(self):
        """Return the length of the wrapped dataset."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return the sample at ``idx``, loading it only on a cache miss."""
        if idx in self.cache:
            self.access_count[idx] += 1
            return self.cache[idx]

        # Cache miss: load from the wrapped dataset
        data = self.dataset[idx]

        # With no capacity there is nothing to evict or store; without this
        # guard the eviction below would call min() on an empty dict.
        if self.cache_size <= 0:
            return data

        if len(self.cache) >= self.cache_size:
            # Evict the entry that has been accessed the fewest times
            least_used_idx = min(self.access_count, key=self.access_count.get)
            del self.cache[least_used_idx]
            del self.access_count[least_used_idx]

        self.cache[idx] = data
        self.access_count[idx] = 1
        return data


# Performance Optimization
#1. Data Loading Optimization
# Optimize data loading performance
def create_optimized_dataloader(dataset, batch_size, num_workers=None):
    """Create a shuffling DataLoader tuned for loading throughput.

    Args:
        dataset: Dataset to load from.
        batch_size: Number of samples per batch.
        num_workers: Number of worker processes. ``None`` selects
            ``min(8, CPU count)``.

    Returns:
        A ``DataLoader`` with shuffling enabled and memory pinning switched
        on when CUDA is available. Persistent workers and a prefetch factor
        of 2 are configured only when worker processes are used.
    """
    if num_workers is None:
        # os.cpu_count() returns None when the count cannot be determined
        num_workers = min(8, os.cpu_count() or 1)

    loader_options = {
        'batch_size': batch_size,
        'shuffle': True,
        'num_workers': num_workers,
        'pin_memory': torch.cuda.is_available(),
    }
    if num_workers > 0:
        # These options only apply to multi-process loading; DataLoader
        # raises ValueError if they are set while num_workers == 0.
        loader_options['persistent_workers'] = True  # keep worker processes
        loader_options['prefetch_factor'] = 2        # batches prefetched per worker

    return DataLoader(dataset, **loader_options)


# 2. Memory Mapping
import mmap
class MemoryMappedDataset(Dataset):
    """Skeleton dataset that memory-maps a data file and seeks to per-sample offsets.

    The index file holds one integer offset per line. Parsing of the sample
    itself is left as a placeholder.
    """

    def __init__(self, data_file, index_file):
        """Map ``data_file`` read-only and read the offsets from ``index_file``."""
        # The file object is kept so it can be closed together with the mapping.
        self.data_file = open(data_file, 'rb')
        self.data_mmap = mmap.mmap(self.data_file.fileno(), 0, access=mmap.ACCESS_READ)

        with open(index_file, 'r') as index_handle:
            self.indices = [int(raw_line.strip()) for raw_line in index_handle]

    def __len__(self):
        """Return the number of offsets read from the index file."""
        return len(self.indices)

    def __getitem__(self, idx):
        """Position the mapping at the sample's offset; returns ``None`` for now."""
        self.data_mmap.seek(self.indices[idx])
        # Placeholder: reading and parsing of the sample goes here.
        return None

    def __del__(self):
        """Close the mapping and then the file, if they were created."""
        for attribute in ('data_mmap', 'data_file'):
            if hasattr(self, attribute):
                getattr(self, attribute).close()


# Practical Application Example
#1. Complete Image Classification Data Pipeline
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import torchvision.datasets as datasets
def create_image_dataloaders(data_dir, batch_size=32, num_workers=4):
    """Build train and validation loaders for an ImageFolder-style directory.

    Images are read from ``{data_dir}/train`` and ``{data_dir}/val``.

    Args:
        data_dir: Directory containing the ``train`` and ``val`` folders.
        batch_size: Batch size used by both loaders.
        num_workers: Worker process count used by both loaders.

    Returns:
        Tuple ``(train_loader, val_loader, num_classes)`` where
        ``num_classes`` is the number of classes of the training dataset.
    """

    def build_loader(image_dataset, shuffle):
        # Both loaders share every option except shuffling.
        return DataLoader(
            image_dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=num_workers,
            pin_memory=True,
        )

    # Training images are randomly augmented
    train_transform = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(0.2, 0.2, 0.2, 0.1),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    # Validation images are only resized and center-cropped
    val_transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])

    train_dataset = datasets.ImageFolder(root=f'{data_dir}/train', transform=train_transform)
    val_dataset = datasets.ImageFolder(root=f'{data_dir}/val', transform=val_transform)

    train_loader = build_loader(train_dataset, shuffle=True)
    val_loader = build_loader(val_dataset, shuffle=False)
    return train_loader, val_loader, len(train_dataset.classes)
# Usage example
# train_loader, val_loader, num_classes = create_image_dataloaders('./data')

# Summary
Data processing is the foundation of deep learning projects, and mastering PyTorch's data processing tools is crucial:
- Dataset Class: Learn to create custom datasets and handle different types of data
- DataLoader: Master batch loading, parallel processing, data shuffling, etc.
- Data Transforms: Proficiently use built-in and custom transforms for data augmentation
- Performance Optimization: Understand caching, memory mapping, parallel loading, etc.
- Practical Applications: Be able to build complete data processing pipelines
Good data processing not only improves model performance but also significantly speeds up the training process!