PyTorch Custom Operations
Custom Operations Overview
PyTorch provides powerful extension mechanisms for creating custom operations, layers, and functions. These mechanisms are useful for implementing new algorithms, optimizing performance, and integrating third-party libraries.
python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Function
import numpy as np
import math
Custom Function
1. Basic Custom Function
python
class SquareFunction(Function):
"""Custom square function"""
@staticmethod
def forward(ctx, input):
"""Forward pass"""
# Save input for backpropagation
ctx.save_for_backward(input)
return input ** 2
@staticmethod
def backward(ctx, grad_output):
"""Backward pass"""
# Get saved input
input, = ctx.saved_tensors
# Compute gradient: d(x^2)/dx = 2x
grad_input = grad_output * 2 * input
return grad_input
# Create function interface
def square(input):
return SquareFunction.apply(input)
# Test custom function
x = torch.randn(3, 4, requires_grad=True)
y = square(x)
loss = y.sum()
loss.backward()
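# Verify the custom backward numerically with torch.autograd.gradcheck
# (an added sanity check; double precision is assumed, since gradcheck is
# unreliable in float32)
from torch.autograd import gradcheck
test_input = torch.randn(3, 4, dtype=torch.double, requires_grad=True)
print(f"gradcheck passed: {gradcheck(SquareFunction.apply, (test_input,))}")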
print(f"Input: {x}")
print(f"Output: {y}")
print(f"Gradient: {x.grad}")2. Multiple Input Multiple Output Function
python
class LinearFunction(Function):
"""Custom linear transform function"""
@staticmethod
def forward(ctx, input, weight, bias=None):
"""Forward pass: y = xW^T + b"""
ctx.save_for_backward(input, weight, bias)
output = input.mm(weight.t())
if bias is not None:
output += bias.unsqueeze(0).expand_as(output)
return output
@staticmethod
def backward(ctx, grad_output):
"""Backward pass"""
input, weight, bias = ctx.saved_tensors
grad_input = grad_weight = grad_bias = None
# Compute input gradient
if ctx.needs_input_grad[0]:
grad_input = grad_output.mm(weight)
# Compute weight gradient
if ctx.needs_input_grad[1]:
grad_weight = grad_output.t().mm(input)
# Compute bias gradient
if bias is not None and ctx.needs_input_grad[2]:
grad_bias = grad_output.sum(0)
return grad_input, grad_weight, grad_bias
def linear(input, weight, bias=None):
return LinearFunction.apply(input, weight, bias)
# Test multi-input function
input = torch.randn(5, 3, requires_grad=True)
weight = torch.randn(4, 3, requires_grad=True)
bias = torch.randn(4, requires_grad=True)
output = linear(input, weight, bias)
loss = output.sum()
loss.backward()
print(f"Input gradient shape: {input.grad.shape}")
print(f"Weight gradient shape: {weight.grad.shape}")
print(f"Bias gradient shape: {bias.grad.shape}")3. Function with Context
python
class DropoutFunction(Function):
"""Custom Dropout function"""
@staticmethod
def forward(ctx, input, p=0.5, training=True):
if training:
# Generate random mask
mask = torch.bernoulli(torch.full_like(input, 1 - p))
ctx.save_for_backward(mask)
ctx.p = p
# Scale output to maintain expected value
return input * mask / (1 - p)
else:
    # Save a pass-through mask so backward also works in eval mode
    ctx.save_for_backward(torch.ones_like(input))
    ctx.p = 0.0
    return input
@staticmethod
def backward(ctx, grad_output):
mask, = ctx.saved_tensors
p = ctx.p
# Apply same mask and scaling
grad_input = grad_output * mask / (1 - p)
return grad_input, None, None
def dropout(input, p=0.5, training=True):
return DropoutFunction.apply(input, p, training)
# Test Dropout
x = torch.randn(10, 5, requires_grad=True)
y = dropout(x, p=0.3, training=True)
loss = y.sum()
loss.backward()
print(f"Input: {x}")
print(f"Dropout output: {y}")
print(f"Gradient: {x.grad}")Custom Module
1. Basic Custom Module
python
class CustomLinear(nn.Module):
"""Custom linear layer"""
def __init__(self, in_features, out_features, bias=True):
super(CustomLinear, self).__init__()
self.in_features = in_features
self.out_features = out_features
# Define parameters
self.weight = nn.Parameter(torch.randn(out_features, in_features))
if bias:
self.bias = nn.Parameter(torch.randn(out_features))
else:
self.register_parameter('bias', None)
# Initialize parameters
self.reset_parameters()
def reset_parameters(self):
"""Initialize parameters"""
nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
if self.bias is not None:
fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
bound = 1 / math.sqrt(fan_in)
nn.init.uniform_(self.bias, -bound, bound)
def forward(self, input):
return linear(input, self.weight, self.bias)
def extra_repr(self):
"""Extra string representation"""
return f'in_features={self.in_features}, out_features={self.out_features}, bias={self.bias is not None}'
# Test custom linear layer
custom_linear = CustomLinear(10, 5)
x = torch.randn(3, 10)
y = custom_linear(x)
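# Sanity check (an added illustration, not part of the original layer): the
# custom layer should match F.linear for the same weight and bias
reference = F.linear(x, custom_linear.weight, custom_linear.bias)
print(f"Matches F.linear: {torch.allclose(y, reference, atol=1e-6)}")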
print(f"Custom linear layer: {custom_linear}")
print(f"Output shape: {y.shape}")2. Complex Custom Module
python
class MultiHeadSelfAttention(nn.Module):
"""Custom multi-head self-attention layer"""
def __init__(self, embed_dim, num_heads, dropout=0.1):
super(MultiHeadSelfAttention, self).__init__()
assert embed_dim % num_heads == 0
self.embed_dim = embed_dim
self.num_heads = num_heads
self.head_dim = embed_dim // num_heads
# Linear transform layers
self.q_proj = nn.Linear(embed_dim, embed_dim)
self.k_proj = nn.Linear(embed_dim, embed_dim)
self.v_proj = nn.Linear(embed_dim, embed_dim)
self.out_proj = nn.Linear(embed_dim, embed_dim)
self.dropout = nn.Dropout(dropout)
self.scale = math.sqrt(self.head_dim)
def forward(self, x, mask=None):
batch_size, seq_len, embed_dim = x.size()
# Compute Q, K, V
Q = self.q_proj(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
K = self.k_proj(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
V = self.v_proj(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
# Compute attention scores
scores = torch.matmul(Q, K.transpose(-2, -1)) / self.scale
# Apply mask
if mask is not None:
scores = scores.masked_fill(mask == 0, -1e9)
# Compute attention weights
attn_weights = F.softmax(scores, dim=-1)
attn_weights = self.dropout(attn_weights)
# Apply attention
attn_output = torch.matmul(attn_weights, V)
# Reshape and project output
attn_output = attn_output.transpose(1, 2).contiguous().view(
batch_size, seq_len, embed_dim
)
output = self.out_proj(attn_output)
return output, attn_weights
# Test multi-head self-attention
attention = MultiHeadSelfAttention(embed_dim=256, num_heads=8)
x = torch.randn(2, 10, 256)
output, weights = attention(x)
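# Optional causal mask (an illustrative assumption): zeros mark blocked
# positions, and a (seq_len, seq_len) mask broadcasts over batch and heads
seq_len = x.size(1)
causal_mask = torch.tril(torch.ones(seq_len, seq_len))
masked_output, masked_weights = attention(x, mask=causal_mask)
print(f"Causal weights above the diagonal are ~0: {torch.allclose(masked_weights[0, 0].triu(1), torch.zeros(seq_len, seq_len))}")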
print(f"Input shape: {x.shape}")
print(f"Output shape: {output.shape}")
print(f"Attention weights shape: {weights.shape}")3. Module with State
python
class RunningBatchNorm(nn.Module):
"""Custom batch normalization (with running statistics)"""
def __init__(self, num_features, eps=1e-5, momentum=0.1):
super(RunningBatchNorm, self).__init__()
self.num_features = num_features
self.eps = eps
self.momentum = momentum
# Learnable parameters
self.weight = nn.Parameter(torch.ones(num_features))
self.bias = nn.Parameter(torch.zeros(num_features))
# Running statistics (not parameters)
self.register_buffer('running_mean', torch.zeros(num_features))
self.register_buffer('running_var', torch.ones(num_features))
self.register_buffer('num_batches_tracked', torch.tensor(0, dtype=torch.long))
def forward(self, input):
if self.training:
# Training mode: compute batch statistics
batch_mean = input.mean(dim=0)
batch_var = input.var(dim=0, unbiased=False)
# Update running statistics
with torch.no_grad():
self.num_batches_tracked += 1
if self.momentum is None:
exponential_average_factor = 1.0 / float(self.num_batches_tracked)
else:
exponential_average_factor = self.momentum
self.running_mean = (1 - exponential_average_factor) * self.running_mean + \
exponential_average_factor * batch_mean
self.running_var = (1 - exponential_average_factor) * self.running_var + \
exponential_average_factor * batch_var
# Normalize using batch statistics
normalized = (input - batch_mean) / torch.sqrt(batch_var + self.eps)
else:
# Evaluation mode: use running statistics
normalized = (input - self.running_mean) / torch.sqrt(self.running_var + self.eps)
# Apply scale and shift
return self.weight * normalized + self.bias
# Test custom batch normalization
bn = RunningBatchNorm(10)
x = torch.randn(32, 10)
# Training mode
bn.train()
y_train = bn(x)
# Evaluation mode
bn.eval()
y_eval = bn(x)
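# Cross-check against the built-in layer (an added illustration; assumes the
# default init, eps, and momentum match the custom implementation)
ref_bn = nn.BatchNorm1d(10, momentum=0.1)
ref_bn.train()
ref_out = ref_bn(x)
print(f"Matches nn.BatchNorm1d in training mode: {torch.allclose(y_train, ref_out, atol=1e-5)}")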
print(f"Training mode output: {y_train.mean():.4f}, {y_train.std():.4f}")
print(f"Evaluation mode output: {y_eval.mean():.4f}, {y_eval.std():.4f}")Custom Loss Functions
1. Basic Custom Loss
python
class FocalLoss(nn.Module):
"""Focal Loss for handling class imbalance"""
def __init__(self, alpha=1, gamma=2, reduction='mean'):
super(FocalLoss, self).__init__()
self.alpha = alpha
self.gamma = gamma
self.reduction = reduction
def forward(self, inputs, targets):
ce_loss = F.cross_entropy(inputs, targets, reduction='none')
pt = torch.exp(-ce_loss)
focal_loss = self.alpha * (1 - pt) ** self.gamma * ce_loss
if self.reduction == 'mean':
return focal_loss.mean()
elif self.reduction == 'sum':
return focal_loss.sum()
else:
return focal_loss
class DiceLoss(nn.Module):
"""Dice Loss for segmentation tasks"""
def __init__(self, smooth=1e-6):
super(DiceLoss, self).__init__()
self.smooth = smooth
def forward(self, inputs, targets):
# Convert inputs to probabilities
inputs = torch.sigmoid(inputs)
# Flatten tensors
inputs = inputs.view(-1)
targets = targets.view(-1)
# Compute Dice coefficient
intersection = (inputs * targets).sum()
dice = (2. * intersection + self.smooth) / (inputs.sum() + targets.sum() + self.smooth)
return 1 - dice
class ContrastiveLoss(nn.Module):
"""Contrastive loss for similarity learning"""
def __init__(self, margin=1.0):
super(ContrastiveLoss, self).__init__()
self.margin = margin
def forward(self, output1, output2, label):
euclidean_distance = F.pairwise_distance(output1, output2)
loss_contrastive = torch.mean(
(1 - label) * torch.pow(euclidean_distance, 2) +
label * torch.pow(torch.clamp(self.margin - euclidean_distance, min=0.0), 2)
)
return loss_contrastive
# Test custom loss functions
focal_loss = FocalLoss(alpha=1, gamma=2)
dice_loss = DiceLoss()
contrastive_loss = ContrastiveLoss(margin=1.0)
# Test Focal Loss
logits = torch.randn(10, 5)
targets = torch.randint(0, 5, (10,))
focal_loss_value = focal_loss(logits, targets)
print(f"Focal Loss: {focal_loss_value.item():.4f}")
# Test Dice Loss
pred_masks = torch.randn(2, 1, 64, 64)
true_masks = torch.randint(0, 2, (2, 1, 64, 64)).float()
dice_loss_value = dice_loss(pred_masks, true_masks)
print(f"Dice Loss: {dice_loss_value.item():.4f}")Custom Optimizers
1. Basic Custom Optimizer
python
class CustomSGD(torch.optim.Optimizer):
"""Custom SGD optimizer"""
def __init__(self, params, lr=1e-3, momentum=0, dampening=0, weight_decay=0):
if lr < 0.0:
raise ValueError(f"Invalid learning rate: {lr}")
if momentum < 0.0:
raise ValueError(f"Invalid momentum value: {momentum}")
if weight_decay < 0.0:
raise ValueError(f"Invalid weight_decay value: {weight_decay}")
defaults = dict(lr=lr, momentum=momentum, dampening=dampening, weight_decay=weight_decay)
super(CustomSGD, self).__init__(params, defaults)
def step(self, closure=None):
"""Perform single optimization step"""
loss = None
if closure is not None:
loss = closure()
for group in self.param_groups:
weight_decay = group['weight_decay']
momentum = group['momentum']
dampening = group['dampening']
for p in group['params']:
if p.grad is None:
continue
d_p = p.grad.data
# Add weight decay
if weight_decay != 0:
d_p = d_p.add(p.data, alpha=weight_decay)
# Add momentum
if momentum != 0:
param_state = self.state[p]
if len(param_state) == 0:
param_state['momentum_buffer'] = torch.zeros_like(p.data)
buf = param_state['momentum_buffer']
buf.mul_(momentum).add_(d_p, alpha=1 - dampening)
d_p = buf
# Update parameters
p.data.add_(d_p, alpha=-group['lr'])
return loss
class AdamW(torch.optim.Optimizer):
"""Custom AdamW optimizer"""
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2):
if not 0.0 <= lr:
raise ValueError(f"Invalid learning rate: {lr}")
if not 0.0 <= eps:
raise ValueError(f"Invalid epsilon value: {eps}")
if not 0.0 <= betas[0] < 1.0:
raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}")
if not 0.0 <= betas[1] < 1.0:
raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")
defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
super(AdamW, self).__init__(params, defaults)
def step(self, closure=None):
"""Perform single optimization step"""
loss = None
if closure is not None:
loss = closure()
for group in self.param_groups:
for p in group['params']:
if p.grad is None:
continue
grad = p.grad.data
if grad.is_sparse:
raise RuntimeError('AdamW does not support sparse gradients')
state = self.state[p]
# Initialize state
if len(state) == 0:
state['step'] = 0
state['exp_avg'] = torch.zeros_like(p.data)
state['exp_avg_sq'] = torch.zeros_like(p.data)
exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
beta1, beta2 = group['betas']
state['step'] += 1
# Exponential moving average
exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
# Bias correction
bias_correction1 = 1 - beta1 ** state['step']
bias_correction2 = 1 - beta2 ** state['step']
# Compute step size
step_size = group['lr'] / bias_correction1
bias_correction2_sqrt = math.sqrt(bias_correction2)
# Weight decay (decoupled)
p.data.mul_(1 - group['lr'] * group['weight_decay'])
# Update parameters
denom = (exp_avg_sq.sqrt() / bias_correction2_sqrt).add_(group['eps'])
p.data.addcdiv_(exp_avg, denom, value=-step_size)
return loss
# Test custom optimizers
model = nn.Linear(10, 1)
custom_sgd = CustomSGD(model.parameters(), lr=0.01, momentum=0.9)
custom_adamw = AdamW(model.parameters(), lr=0.001, weight_decay=0.01)
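# One optimization step with the custom optimizer (the random data and MSE
# loss are assumptions, only to show it drops into the usual training loop)
data = torch.randn(8, 10)
target = torch.randn(8, 1)
custom_sgd.zero_grad()
loss = F.mse_loss(model(data), target)
loss.backward()
custom_sgd.step()
print(f"Loss after one CustomSGD step: {loss.item():.4f}")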
print(f"Custom SGD: {custom_sgd}")
print(f"Custom AdamW: {custom_adamw}")Custom Data Types and Operations
1. Custom Tensor Operations
python
class ComplexTensor:
"""Custom complex tensor class"""
def __init__(self, real, imag):
self.real = real
self.imag = imag
def __add__(self, other):
if isinstance(other, ComplexTensor):
return ComplexTensor(self.real + other.real, self.imag + other.imag)
else:
return ComplexTensor(self.real + other, self.imag)
def __mul__(self, other):
if isinstance(other, ComplexTensor):
# (a + bi) * (c + di) = (ac - bd) + (ad + bc)i
real = self.real * other.real - self.imag * other.imag
imag = self.real * other.imag + self.imag * other.real
return ComplexTensor(real, imag)
else:
return ComplexTensor(self.real * other, self.imag * other)
def abs(self):
"""Compute magnitude of complex number"""
return torch.sqrt(self.real ** 2 + self.imag ** 2)
def conjugate(self):
"""Compute conjugate of complex number"""
return ComplexTensor(self.real, -self.imag)
def __repr__(self):
return f"ComplexTensor(real={self.real}, imag={self.imag})"
# Test custom complex tensor
real1 = torch.tensor([1.0, 2.0])
imag1 = torch.tensor([3.0, 4.0])
complex1 = ComplexTensor(real1, imag1)
real2 = torch.tensor([5.0, 6.0])
imag2 = torch.tensor([7.0, 8.0])
complex2 = ComplexTensor(real2, imag2)
# Complex operations
result_add = complex1 + complex2
result_mul = complex1 * complex2
result_abs = complex1.abs()
result_conj = complex1.conjugate()
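# Cross-check against PyTorch's native complex dtype (an added sanity check,
# not part of the custom class)
native1 = torch.complex(real1, imag1)
native2 = torch.complex(real2, imag2)
native_mul = native1 * native2
print(f"Matches torch.complex multiply: {torch.allclose(native_mul.real, result_mul.real) and torch.allclose(native_mul.imag, result_mul.imag)}")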
print(f"Complex 1: {complex1}")
print(f"Complex 2: {complex2}")
print(f"Addition: {result_add}")
print(f"Multiplication: {result_mul}")
print(f"Magnitude: {result_abs}")
print(f"Conjugate: {result_conj}")2. Custom Activation Functions
python
class Swish(nn.Module):
"""Swish activation function: f(x) = x * sigmoid(βx)"""
def __init__(self, beta=1.0):
super(Swish, self).__init__()
self.beta = nn.Parameter(torch.tensor(beta))
def forward(self, x):
return x * torch.sigmoid(self.beta * x)
class Mish(nn.Module):
"""Mish activation function: f(x) = x * tanh(softplus(x))"""
def forward(self, x):
return x * torch.tanh(F.softplus(x))
class GELU(nn.Module):
"""Custom GELU activation function"""
def forward(self, x):
return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
class PReLU(nn.Module):
"""Parametric ReLU"""
def __init__(self, num_parameters=1, init=0.25):
super(PReLU, self).__init__()
self.num_parameters = num_parameters
self.weight = nn.Parameter(torch.full((num_parameters,), init))
def forward(self, x):
return F.prelu(x, self.weight)
# Test custom activation functions
x = torch.randn(10, 5)
swish = Swish(beta=1.0)
mish = Mish()
gelu = GELU()
prelu = PReLU(num_parameters=5)
print(f"Input: {x[0]}")
print(f"Swish: {swish(x)[0]}")
print(f"Mish: {mish(x)[0]}")
print(f"GELU: {gelu(x)[0]}")
print(f"PReLU: {prelu(x)[0]}")C++ Extensions
1. Basic C++ Extension
cpp
// custom_ops.cpp
#include <torch/extension.h>
#include <vector>
torch::Tensor add_forward(torch::Tensor input1, torch::Tensor input2) {
return input1 + input2;
}
std::vector<torch::Tensor> add_backward(torch::Tensor grad_output) {
return {grad_output, grad_output};
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("add_forward", &add_forward, "Add forward");
m.def("add_backward", &add_backward, "Add backward");
}
python
# setup.py for C++ extension
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CppExtension
ext_modules = [
CppExtension(
"custom_ops",
["custom_ops.cpp"],
),
]
setup(
name="custom_ops",
ext_modules=ext_modules,
cmdclass={"build_ext": BuildExtension},
)
2. CUDA Extension Example
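Once the C++ extension above is built (for example with pip install . or python setup.py build_ext --inplace in the extension's directory), the compiled module imports like any other Python package. A minimal usage sketch, assuming the build produced the custom_ops module defined in setup.py:
python
import torch
import custom_ops

a = torch.randn(5)
b = torch.randn(5)
print(f"add_forward result: {custom_ops.add_forward(a, b)}")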
python
# Using JIT-compiled CUDA extension
from torch.utils.cpp_extension import load
cuda_source = """
#include <torch/extension.h>
#include <cuda.h>
#include <cuda_runtime.h>
__global__ void add_kernel(float* a, float* b, float* c, int n) {
int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx < n) {
c[idx] = a[idx] + b[idx];
}
}
torch::Tensor cuda_add(torch::Tensor a, torch::Tensor b) {
auto c = torch::zeros_like(a);
int n = a.numel();
int threads = 256;
int blocks = (n + threads - 1) / threads;
add_kernel<<<blocks, threads>>>(
a.data_ptr<float>(),
b.data_ptr<float>(),
c.data_ptr<float>(),
n
);
return c;
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("cuda_add", &cuda_add, "CUDA add");
}
"""
# JIT-compile the CUDA extension: write the source string to a .cu file, then build it
with open("cuda_ops.cu", "w") as f:
    f.write(cuda_source)
cuda_ops = load(
    name="cuda_ops",
    sources=["cuda_ops.cu"],
    verbose=True
)
# Use CUDA extension
if torch.cuda.is_available():
a = torch.randn(1000, device='cuda')
b = torch.randn(1000, device='cuda')
c = cuda_ops.cuda_add(a, b)
print(f"CUDA addition result: {c[:5]}")Performance Optimization Techniques
1. Memory Optimization
python
class MemoryEfficientFunction(Function):
    """Memory-efficient custom function"""
    @staticmethod
    def forward(ctx, input, weight):
        # Save only the tensors that backward actually needs; avoid stashing
        # intermediate results that can be recomputed from them
        ctx.save_for_backward(input, weight)
        return torch.mm(input, weight.t())
    @staticmethod
    def backward(ctx, grad_output):
        input, weight = ctx.saved_tensors
        grad_input = grad_weight = None
        # Compute gradients only for the inputs that require them
        if ctx.needs_input_grad[0]:
            grad_input = grad_output.mm(weight)
        if ctx.needs_input_grad[1]:
            grad_weight = grad_output.t().mm(input)
        return grad_input, grad_weight
# Helpers used by CheckpointFunction below (defined in torch.utils.checkpoint)
from torch.utils.checkpoint import check_backward_validity, get_device_states, set_device_states
class CheckpointFunction(Function):
"""Function using checkpoints"""
@staticmethod
def forward(ctx, run_function, preserve_rng_state, *args):
check_backward_validity(args)
ctx.run_function = run_function
ctx.preserve_rng_state = preserve_rng_state
if preserve_rng_state:
ctx.fwd_cpu_state = torch.get_rng_state()
ctx.had_cuda_in_fwd = False
if torch.cuda._initialized:
ctx.had_cuda_in_fwd = True
ctx.fwd_gpu_devices, ctx.fwd_gpu_states = get_device_states(*args)
ctx.save_for_backward(*args)
with torch.no_grad():
outputs = run_function(*args)
return outputs
@staticmethod
def backward(ctx, *args):
if not torch.autograd._is_checkpoint_valid():
raise RuntimeError("Checkpointing is not compatible with .grad()")
inputs = ctx.saved_tensors
# Restore RNG state
if ctx.preserve_rng_state:
if ctx.had_cuda_in_fwd:
set_device_states(ctx.fwd_gpu_devices, ctx.fwd_gpu_states)
torch.set_rng_state(ctx.fwd_cpu_state)
with torch.enable_grad():
outputs = ctx.run_function(*inputs)
if isinstance(outputs, torch.Tensor):
outputs = (outputs,)
torch.autograd.backward(outputs, args)
return (None, None) + tuple(inp.grad for inp in inputs)
def checkpoint(function, *args, **kwargs):
"""Checkpoint wrapper"""
preserve = kwargs.pop('preserve_rng_state', True)
return CheckpointFunction.apply(function, preserve, *args)
2. Numerical Stability
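For completeness, PyTorch ships the checkpointing pattern shown above as torch.utils.checkpoint.checkpoint, which re-runs the forward pass during backward to trade compute for memory. A minimal usage sketch (the layer sizes and input shape are assumptions):
python
from torch.utils.checkpoint import checkpoint as torch_checkpoint

block = nn.Sequential(nn.Linear(128, 128), nn.ReLU(), nn.Linear(128, 128))
inp = torch.randn(16, 128, requires_grad=True)
out = torch_checkpoint(block, inp, use_reentrant=False)  # activations are recomputed in backward
out.sum().backward()
print(f"Gradient shape through checkpoint: {inp.grad.shape}")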
python
class NumericallyStableFunction(Function):
"""Numerically stable custom function"""
@staticmethod
def forward(ctx, input):
# Use numerically stable implementation
# For example: log-sum-exp trick
max_val = input.max(dim=-1, keepdim=True)[0]
shifted_input = input - max_val
exp_shifted = torch.exp(shifted_input)
sum_exp = exp_shifted.sum(dim=-1, keepdim=True)
log_sum_exp = torch.log(sum_exp) + max_val
ctx.save_for_backward(exp_shifted, sum_exp)
return log_sum_exp
@staticmethod
def backward(ctx, grad_output):
exp_shifted, sum_exp = ctx.saved_tensors
# Compute softmax gradient
softmax = exp_shifted / sum_exp
grad_input = grad_output * softmax
return grad_input
def stable_log_sum_exp(input):
return NumericallyStableFunction.apply(input)
# Test numerical stability
x = torch.tensor([1000.0, 1001.0, 1002.0]) # Large values
stable_result = stable_log_sum_exp(x)
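# Comparison (added for illustration): the naive formulation overflows to inf,
# while the built-in torch.logsumexp agrees with the stable implementation
naive_result = torch.log(torch.exp(x).sum())
print(f"Naive log-sum-exp: {naive_result}")
print(f"torch.logsumexp: {torch.logsumexp(x, dim=-1)}")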
print(f"Numerically stable result: {stable_result}")Summary
Custom operations are a powerful feature of PyTorch. This chapter introduced:
- Custom Function: Custom functions implementing forward and backward propagation
- Custom Module: Creating reusable neural network components
- Custom Loss Functions: Implementing loss functions for specific tasks
- Custom Optimizers: Implementing new optimization algorithms
- Custom Data Types: Extending PyTorch's data processing capabilities
- C++/CUDA Extensions: High-performance low-level implementations
- Performance Optimization: Memory efficiency and numerical stability techniques
Mastering these techniques will help you extend PyTorch's functionality and implement innovative deep learning algorithms!