Skip to content

PyTorch Custom Operations

Custom Operations Overview

PyTorch provides powerful extension mechanisms that allow developers to create custom operations, layers, and functions. This is very useful for implementing new algorithms, optimizing performance, or integrating third-party libraries.

python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Function
import numpy as np
import math

Custom Function

1. Basic Custom Function

python
class SquareFunction(Function):
    """Autograd function computing the elementwise square of its input."""

    @staticmethod
    def forward(ctx, input):
        """Compute input**2, stashing the input for the backward pass."""
        ctx.save_for_backward(input)
        return input * input

    @staticmethod
    def backward(ctx, grad_output):
        """Chain rule: d(x^2)/dx = 2x, scaled by the incoming gradient."""
        (saved_input,) = ctx.saved_tensors
        return grad_output.mul(saved_input).mul(2)

# Convenience wrapper so callers don't touch the Function class directly.
def square(input):
    """Elementwise square with a custom autograd backward."""
    return SquareFunction.apply(input)

# Test custom function
inputs = torch.randn(3, 4, requires_grad=True)
outputs = square(inputs)
outputs.sum().backward()

print(f"Input: {inputs}")
print(f"Output: {outputs}")
print(f"Gradient: {inputs.grad}")

2. Multiple Input Multiple Output Function

python
class LinearFunction(Function):
    """Autograd function for an affine map: y = x W^T + b."""

    @staticmethod
    def forward(ctx, input, weight, bias=None):
        """Compute x @ W^T (+ b broadcast over the batch)."""
        ctx.save_for_backward(input, weight, bias)
        out = input.mm(weight.t())
        if bias is None:
            return out
        out += bias.unsqueeze(0).expand_as(out)
        return out

    @staticmethod
    def backward(ctx, grad_output):
        """Return gradients w.r.t. (input, weight, bias); None where not needed."""
        input, weight, bias = ctx.saved_tensors
        need_input, need_weight, need_bias = ctx.needs_input_grad[:3]

        # dL/dx = dL/dy @ W
        grad_input = grad_output.mm(weight) if need_input else None
        # dL/dW = (dL/dy)^T @ x
        grad_weight = grad_output.t().mm(input) if need_weight else None
        grad_bias = None
        if bias is not None and need_bias:
            # Bias is broadcast over the batch, so its grad sums over it.
            grad_bias = grad_output.sum(0)

        return grad_input, grad_weight, grad_bias

def linear(input, weight, bias=None):
    """Functional entry point for LinearFunction (y = x W^T + b)."""
    return LinearFunction.apply(input, weight, bias)

# Test multi-input function
x_in = torch.randn(5, 3, requires_grad=True)
w_param = torch.randn(4, 3, requires_grad=True)
b_param = torch.randn(4, requires_grad=True)

linear(x_in, w_param, b_param).sum().backward()

print(f"Input gradient shape: {x_in.grad.shape}")
print(f"Weight gradient shape: {w_param.grad.shape}")
print(f"Bias gradient shape: {b_param.grad.shape}")

3. Function with Context

python
class DropoutFunction(Function):
    """Autograd dropout with inverted scaling (output keeps E[x] unchanged).

    Fixes over the naive version:
    - backward no longer crashes when forward ran with training=False
      (the mask was never saved in that branch);
    - p == 1 is handled explicitly instead of dividing by zero;
    - p is validated to lie in [0, 1].
    """

    @staticmethod
    def forward(ctx, input, p=0.5, training=True):
        if not 0.0 <= p <= 1.0:
            raise ValueError(f"dropout probability has to be between 0 and 1, but got {p}")
        # Remember which branch was taken so backward can mirror it.
        ctx.training = training
        if not training:
            # Inference: identity, nothing to save.
            return input
        keep_prob = 1.0 - p
        # Bernoulli mask: 1 with probability keep_prob, 0 otherwise.
        mask = torch.bernoulli(torch.full_like(input, keep_prob))
        # Inverted dropout: scale survivors so the expectation matches eval mode.
        # p == 1 drops everything; scale 0 avoids a division by zero.
        scale = 1.0 / keep_prob if keep_prob > 0 else 0.0
        ctx.save_for_backward(mask)
        ctx.scale = scale
        return input * mask * scale

    @staticmethod
    def backward(ctx, grad_output):
        if not ctx.training:
            # Forward was the identity, so the gradient passes through unchanged.
            return grad_output, None, None
        mask, = ctx.saved_tensors
        # Same mask and scaling as the forward pass.
        return grad_output * mask * ctx.scale, None, None

def dropout(input, p=0.5, training=True):
    """Functional entry point for DropoutFunction."""
    return DropoutFunction.apply(input, p, training)

# Test Dropout
sample = torch.randn(10, 5, requires_grad=True)
dropped = dropout(sample, p=0.3, training=True)
dropped.sum().backward()

print(f"Input: {sample}")
print(f"Dropout output: {dropped}")
print(f"Gradient: {sample.grad}")

Custom Module

1. Basic Custom Module

python
class CustomLinear(nn.Module):
    """Linear layer backed by the custom LinearFunction autograd op."""

    def __init__(self, in_features, out_features, bias=True):
        super(CustomLinear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features

        # weight is (out_features, in_features), as in nn.Linear.
        self.weight = nn.Parameter(torch.randn(out_features, in_features))
        if bias:
            self.bias = nn.Parameter(torch.randn(out_features))
        else:
            # Register a None placeholder so state_dict/named_parameters stay consistent.
            self.register_parameter('bias', None)

        self.reset_parameters()

    def reset_parameters(self):
        """Kaiming-uniform weight init and fan-in-scaled uniform bias init."""
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is None:
            return
        fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
        bound = 1 / math.sqrt(fan_in)
        nn.init.uniform_(self.bias, -bound, bound)

    def forward(self, input):
        """Apply the custom linear function."""
        return linear(input, self.weight, self.bias)

    def extra_repr(self):
        """Human-readable config shown inside repr(module)."""
        return f'in_features={self.in_features}, out_features={self.out_features}, bias={self.bias is not None}'

# Test custom linear layer
layer = CustomLinear(10, 5)
features = torch.randn(3, 10)
projected = layer(features)

print(f"Custom linear layer: {layer}")
print(f"Output shape: {projected.shape}")

2. Complex Custom Module

python
class MultiHeadSelfAttention(nn.Module):
    """Custom multi-head self-attention layer"""
    
    def __init__(self, embed_dim, num_heads, dropout=0.1):
        super(MultiHeadSelfAttention, self).__init__()
        assert embed_dim % num_heads == 0
        
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        
        # Linear transform layers
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim)
        
        self.dropout = nn.Dropout(dropout)
        self.scale = math.sqrt(self.head_dim)
        
    def forward(self, x, mask=None):
        batch_size, seq_len, embed_dim = x.size()
        
        # Compute Q, K, V
        Q = self.q_proj(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        K = self.k_proj(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        V = self.v_proj(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        
        # Compute attention scores
        scores = torch.matmul(Q, K.transpose(-2, -1)) / self.scale
        
        # Apply mask
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)
        
        # Compute attention weights
        attn_weights = F.softmax(scores, dim=-1)
        attn_weights = self.dropout(attn_weights)
        
        # Apply attention
        attn_output = torch.matmul(attn_weights, V)
        
        # Reshape and project output
        attn_output = attn_output.transpose(1, 2).contiguous().view(
            batch_size, seq_len, embed_dim
        )
        output = self.out_proj(attn_output)
        
        return output, attn_weights

# Test multi-head self-attention
attn_layer = MultiHeadSelfAttention(embed_dim=256, num_heads=8)
tokens = torch.randn(2, 10, 256)
attn_out, attn_w = attn_layer(tokens)

print(f"Input shape: {tokens.shape}")
print(f"Output shape: {attn_out.shape}")
print(f"Attention weights shape: {attn_w.shape}")

3. Module with State

python
class RunningBatchNorm(nn.Module):
    """Batch normalization over dim 0 with running statistics (2-D inputs).

    Fix: running statistics are now updated in place (``lerp_``) instead of
    being rebuilt by Python-level reassignment each step, so the registered
    buffers keep their identity, dtype and device.

    NOTE(review): nn.BatchNorm1d feeds the *unbiased* batch variance into its
    running variance; this module keeps the biased one to preserve the
    original behavior — confirm which is wanted before relying on eval stats.
    """

    def __init__(self, num_features, eps=1e-5, momentum=0.1):
        super(RunningBatchNorm, self).__init__()
        self.num_features = num_features
        self.eps = eps
        # momentum=None selects cumulative (1/N) averaging, as in nn.BatchNorm1d.
        self.momentum = momentum

        # Learnable affine parameters (scale and shift).
        self.weight = nn.Parameter(torch.ones(num_features))
        self.bias = nn.Parameter(torch.zeros(num_features))

        # Running statistics live in buffers: saved with state_dict, not trained.
        self.register_buffer('running_mean', torch.zeros(num_features))
        self.register_buffer('running_var', torch.ones(num_features))
        self.register_buffer('num_batches_tracked', torch.tensor(0, dtype=torch.long))

    def forward(self, input):
        """Normalize `input`; assumes shape (batch, num_features) — TODO confirm."""
        if self.training:
            # Training mode: normalize with the statistics of this batch.
            batch_mean = input.mean(dim=0)
            batch_var = input.var(dim=0, unbiased=False)

            with torch.no_grad():
                self.num_batches_tracked += 1
                if self.momentum is None:
                    exponential_average_factor = 1.0 / float(self.num_batches_tracked)
                else:
                    exponential_average_factor = self.momentum

                # In-place EMA: buf <- (1 - f) * buf + f * batch_stat
                self.running_mean.lerp_(batch_mean, exponential_average_factor)
                self.running_var.lerp_(batch_var, exponential_average_factor)

            normalized = (input - batch_mean) / torch.sqrt(batch_var + self.eps)
        else:
            # Evaluation mode: use the accumulated running statistics.
            normalized = (input - self.running_mean) / torch.sqrt(self.running_var + self.eps)

        # Learnable scale and shift.
        return self.weight * normalized + self.bias

# Test custom batch normalization
norm_layer = RunningBatchNorm(10)
batch = torch.randn(32, 10)

# Training mode
norm_layer.train()
train_out = norm_layer(batch)

# Evaluation mode
norm_layer.eval()
eval_out = norm_layer(batch)

print(f"Training mode output: {train_out.mean():.4f}, {train_out.std():.4f}")
print(f"Evaluation mode output: {eval_out.mean():.4f}, {eval_out.std():.4f}")

Custom Loss Functions

1. Basic Custom Loss

python
class FocalLoss(nn.Module):
    """Focal Loss: cross-entropy with easy examples down-weighted by (1-pt)^gamma."""

    def __init__(self, alpha=1, gamma=2, reduction='mean'):
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, inputs, targets):
        per_sample_ce = F.cross_entropy(inputs, targets, reduction='none')
        # pt = softmax probability assigned to the true class.
        prob_true = torch.exp(-per_sample_ce)
        modulated = self.alpha * (1 - prob_true) ** self.gamma * per_sample_ce

        if self.reduction == 'sum':
            return modulated.sum()
        if self.reduction == 'mean':
            return modulated.mean()
        return modulated

class DiceLoss(nn.Module):
    """Dice Loss for segmentation tasks"""
    
    def __init__(self, smooth=1e-6):
        super(DiceLoss, self).__init__()
        self.smooth = smooth
    
    def forward(self, inputs, targets):
        # Convert inputs to probabilities
        inputs = torch.sigmoid(inputs)
        
        # Flatten tensors
        inputs = inputs.view(-1)
        targets = targets.view(-1)
        
        # Compute Dice coefficient
        intersection = (inputs * targets).sum()
        dice = (2. * intersection + self.smooth) / (inputs.sum() + targets.sum() + self.smooth)
        
        return 1 - dice

class ContrastiveLoss(nn.Module):
    """Contrastive loss: label 0 pulls a pair together, label 1 pushes it past `margin`."""

    def __init__(self, margin=1.0):
        super(ContrastiveLoss, self).__init__()
        self.margin = margin

    def forward(self, output1, output2, label):
        dist = F.pairwise_distance(output1, output2)
        # Similar pairs (label 0): penalize squared distance.
        similar_term = (1 - label) * dist.pow(2)
        # Dissimilar pairs (label 1): penalize only inside the margin.
        dissimilar_term = label * torch.clamp(self.margin - dist, min=0.0).pow(2)
        return (similar_term + dissimilar_term).mean()

# Test custom loss functions
focal = FocalLoss(alpha=1, gamma=2)
dice = DiceLoss()
contrastive = ContrastiveLoss(margin=1.0)

# Focal Loss on random 5-class logits
class_logits = torch.randn(10, 5)
class_targets = torch.randint(0, 5, (10,))
print(f"Focal Loss: {focal(class_logits, class_targets).item():.4f}")

# Dice Loss on random binary masks
predicted = torch.randn(2, 1, 64, 64)
ground_truth = torch.randint(0, 2, (2, 1, 64, 64)).float()
print(f"Dice Loss: {dice(predicted, ground_truth).item():.4f}")

Custom Optimizers

1. Basic Custom Optimizer

python
class CustomSGD(torch.optim.Optimizer):
    """SGD with momentum, dampening and L2 weight decay.

    Fixes relative to the first draft:
    - the momentum buffer is seeded with a clone of the raw gradient on its
      very first step (dampening is NOT applied then), matching
      torch.optim.SGD's update rule;
    - updates run under ``torch.no_grad()`` on ``p``/``p.grad`` instead of the
      deprecated ``.data`` attribute.
    """

    def __init__(self, params, lr=1e-3, momentum=0, dampening=0, weight_decay=0):
        if lr < 0.0:
            raise ValueError(f"Invalid learning rate: {lr}")
        if momentum < 0.0:
            raise ValueError(f"Invalid momentum value: {momentum}")
        if weight_decay < 0.0:
            raise ValueError(f"Invalid weight_decay value: {weight_decay}")

        defaults = dict(lr=lr, momentum=momentum, dampening=dampening, weight_decay=weight_decay)
        super(CustomSGD, self).__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure=None):
        """Perform a single optimization step.

        Args:
            closure: optional callable that re-evaluates the model and
                returns the loss; it is run with gradients enabled.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            weight_decay = group['weight_decay']
            momentum = group['momentum']
            dampening = group['dampening']

            for p in group['params']:
                if p.grad is None:
                    continue

                d_p = p.grad

                # L2 regularization folded into the gradient.
                if weight_decay != 0:
                    d_p = d_p.add(p, alpha=weight_decay)

                if momentum != 0:
                    state = self.state[p]
                    buf = state.get('momentum_buffer')
                    if buf is None:
                        # First step: seed the buffer with the raw gradient
                        # (no dampening), as torch.optim.SGD does.
                        buf = torch.clone(d_p).detach()
                        state['momentum_buffer'] = buf
                    else:
                        buf.mul_(momentum).add_(d_p, alpha=1 - dampening)
                    d_p = buf

                # Gradient-descent update.
                p.add_(d_p, alpha=-group['lr'])

        return loss

class AdamW(torch.optim.Optimizer):
    """AdamW: Adam with decoupled weight decay (Loshchilov & Hutter)."""

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2):
        if not 0.0 <= lr:
            raise ValueError(f"Invalid learning rate: {lr}")
        if not 0.0 <= eps:
            raise ValueError(f"Invalid epsilon value: {eps}")
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}")
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")

        super(AdamW, self).__init__(
            params, dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
        )

    def step(self, closure=None):
        """Run one update; returns the closure's loss when one is given."""
        loss = closure() if closure is not None else None

        for group in self.param_groups:
            beta1, beta2 = group['betas']

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad.data
                if grad.is_sparse:
                    raise RuntimeError('AdamW does not support sparse gradients')

                state = self.state[p]
                if not state:
                    # Lazy state init: step counter plus first/second moment estimates.
                    state['step'] = 0
                    state['exp_avg'] = torch.zeros_like(p.data)
                    state['exp_avg_sq'] = torch.zeros_like(p.data)

                state['step'] += 1
                m, v = state['exp_avg'], state['exp_avg_sq']

                # Update biased first and second moment estimates.
                m.mul_(beta1).add_(grad, alpha=1 - beta1)
                v.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

                bias_c1 = 1 - beta1 ** state['step']
                bias_c2 = 1 - beta2 ** state['step']

                # Decoupled weight decay: shrink the parameter directly,
                # independent of the adaptive gradient step.
                p.data.mul_(1 - group['lr'] * group['weight_decay'])

                # Adam update with bias-corrected moments.
                denom = (v.sqrt() / math.sqrt(bias_c2)).add_(group['eps'])
                p.data.addcdiv_(m, denom, value=-(group['lr'] / bias_c1))

        return loss

# Test custom optimizers
demo_model = nn.Linear(10, 1)
sgd_opt = CustomSGD(demo_model.parameters(), lr=0.01, momentum=0.9)
adamw_opt = AdamW(demo_model.parameters(), lr=0.001, weight_decay=0.01)

print(f"Custom SGD: {sgd_opt}")
print(f"Custom AdamW: {adamw_opt}")

Custom Data Types and Operations

1. Custom Tensor Operations

python
class ComplexTensor:
    """A minimal complex-valued tensor built from two real tensors."""

    def __init__(self, real, imag):
        self.real = real  # real component
        self.imag = imag  # imaginary component

    def __add__(self, other):
        """Elementwise addition with another ComplexTensor or a real operand."""
        if not isinstance(other, ComplexTensor):
            return ComplexTensor(self.real + other, self.imag)
        return ComplexTensor(self.real + other.real, self.imag + other.imag)

    def __mul__(self, other):
        """Elementwise product; (a + bi)(c + di) = (ac - bd) + (ad + bc)i."""
        if not isinstance(other, ComplexTensor):
            return ComplexTensor(self.real * other, self.imag * other)
        a, b, c, d = self.real, self.imag, other.real, other.imag
        return ComplexTensor(a * c - b * d, a * d + b * c)

    def abs(self):
        """Magnitude |z| = sqrt(re^2 + im^2)."""
        return torch.sqrt(self.real * self.real + self.imag * self.imag)

    def conjugate(self):
        """Complex conjugate (imaginary part negated)."""
        return ComplexTensor(self.real, -self.imag)

    def __repr__(self):
        return f"ComplexTensor(real={self.real}, imag={self.imag})"

# Test custom complex tensor
z_a = ComplexTensor(torch.tensor([1.0, 2.0]), torch.tensor([3.0, 4.0]))
z_b = ComplexTensor(torch.tensor([5.0, 6.0]), torch.tensor([7.0, 8.0]))

# Complex operations
sum_z = z_a + z_b
prod_z = z_a * z_b
mag_z = z_a.abs()
conj_z = z_a.conjugate()

print(f"Complex 1: {z_a}")
print(f"Complex 2: {z_b}")
print(f"Addition: {sum_z}")
print(f"Multiplication: {prod_z}")
print(f"Magnitude: {mag_z}")
print(f"Conjugate: {conj_z}")

2. Custom Activation Functions

python
class Swish(nn.Module):
    """Swish activation f(x) = x * sigmoid(beta * x) with learnable beta."""

    def __init__(self, beta=1.0):
        super(Swish, self).__init__()
        # beta is a Parameter so the gate's slope is learned during training.
        self.beta = nn.Parameter(torch.tensor(beta))

    def forward(self, x):
        return torch.sigmoid(self.beta * x) * x

class Mish(nn.Module):
    """Mish activation f(x) = x * tanh(softplus(x))."""

    def forward(self, x):
        return torch.tanh(F.softplus(x)) * x

class GELU(nn.Module):
    """Tanh approximation of GELU (Hendrycks & Gimpel)."""

    def forward(self, x):
        inner = math.sqrt(2 / math.pi) * (x + 0.044715 * x ** 3)
        return 0.5 * x * (1 + torch.tanh(inner))

class PReLU(nn.Module):
    """Parametric ReLU: identity for x > 0, learned slope for x <= 0."""

    def __init__(self, num_parameters=1, init=0.25):
        super(PReLU, self).__init__()
        self.num_parameters = num_parameters
        # One learnable slope, or one per channel when num_parameters > 1.
        self.weight = nn.Parameter(torch.full((num_parameters,), init))

    def forward(self, x):
        return F.prelu(x, self.weight)

# Test custom activation functions
act_input = torch.randn(10, 5)

swish_fn = Swish(beta=1.0)
mish_fn = Mish()
gelu_fn = GELU()
prelu_fn = PReLU(num_parameters=5)

print(f"Input: {act_input[0]}")
print(f"Swish: {swish_fn(act_input)[0]}")
print(f"Mish: {mish_fn(act_input)[0]}")
print(f"GELU: {gelu_fn(act_input)[0]}")
print(f"PReLU: {prelu_fn(act_input)[0]}")

C++ Extensions

1. Basic C++ Extension

cpp
// custom_ops.cpp
// Minimal C++ extension exposing an elementwise add with a hand-written backward.
#include <torch/extension.h>
#include <vector>

// Forward: elementwise sum of the two input tensors.
torch::Tensor add_forward(torch::Tensor input1, torch::Tensor input2) {
    return input1 + input2;
}

// Backward: d(a+b)/da = d(a+b)/db = 1, so the upstream gradient flows to both inputs.
std::vector<torch::Tensor> add_backward(torch::Tensor grad_output) {
    return {grad_output, grad_output};
}

// Register both functions with pybind11 under the extension's module name.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("add_forward", &add_forward, "Add forward");
    m.def("add_backward", &add_backward, "Add backward");
}
python
# setup.py for C++ extension
from setuptools import setup
from pybind11.setup_helpers import Pybind11Extension, build_ext
from torch.utils.cpp_extension import BuildExtension, CppExtension

# NOTE(review): Pybind11Extension and build_ext are imported but unused —
# the build is driven entirely by torch's CppExtension/BuildExtension below.
# Confirm whether the pybind11 import line can be dropped.
ext_modules = [
    CppExtension(
        "custom_ops",
        ["custom_ops.cpp"],
    ),
]

setup(
    name="custom_ops",
    ext_modules=ext_modules,
    # BuildExtension supplies the compiler/ABI flags PyTorch extensions need.
    cmdclass={"build_ext": BuildExtension},
)

2. CUDA Extension Example

python
# Using JIT-compiled CUDA extension
from torch.utils.cpp_extension import load

cuda_source = """
#include <torch/extension.h>
#include <cuda.h>
#include <cuda_runtime.h>

__global__ void add_kernel(float* a, float* b, float* c, int n) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx < n) {
        c[idx] = a[idx] + b[idx];
    }
}

torch::Tensor cuda_add(torch::Tensor a, torch::Tensor b) {
    auto c = torch::zeros_like(a);
    
    int n = a.numel();
    int threads = 256;
    int blocks = (n + threads - 1) / threads;
    
    add_kernel<<<blocks, threads>>>(
        a.data_ptr<float>(),
        b.data_ptr<float>(),
        c.data_ptr<float>(),
        n
    );
    
    return c;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("cuda_add", &cuda_add, "CUDA add");
}
"""

# Fix: the original called load(sources=["cuda_ops.cu"]) without ever writing
# the inline `cuda_source` string to disk, so compilation could not see it.
# Materialize the source file first, and only build/run when CUDA exists
# (nvcc compilation of a .cu file fails on CPU-only machines anyway).
with open("cuda_ops.cu", "w") as f:
    f.write(cuda_source)

if torch.cuda.is_available():
    # JIT compile CUDA extension
    cuda_ops = load(
        name="cuda_ops",
        sources=["cuda_ops.cu"],
        verbose=True
    )

    # Use CUDA extension
    a = torch.randn(1000, device='cuda')
    b = torch.randn(1000, device='cuda')
    c = cuda_ops.cuda_add(a, b)
    print(f"CUDA addition result: {c[:5]}")

Performance Optimization Techniques

1. Memory Optimization

python
class MemoryEfficientFunction(Function):
    """Matrix product y = x @ W^T with an explicit backward.

    Fix: the original backward returned all-zero placeholder gradients, which
    silently broke any training that used this op. For a matmul both operands
    are required to form the gradients, so they are stored with
    ``save_for_backward`` (which lets autograd manage their lifetime and
    detect in-place modification) instead of keeping only the shapes.
    """

    @staticmethod
    def forward(ctx, input, weight):
        ctx.save_for_backward(input, weight)
        return torch.mm(input, weight.t())

    @staticmethod
    def backward(ctx, grad_output):
        input, weight = ctx.saved_tensors
        grad_input = grad_weight = None

        # dL/dx = dL/dy @ W
        if ctx.needs_input_grad[0]:
            grad_input = grad_output.mm(weight)

        # dL/dW = (dL/dy)^T @ x
        if ctx.needs_input_grad[1]:
            grad_weight = grad_output.t().mm(input)

        return grad_input, grad_weight

class CheckpointFunction(Function):
    """Gradient checkpointing: run `run_function` without recording a tape,
    then recompute the activations during backward.

    Fixes:
    - the RNG helpers (check_backward_validity, get_device_states,
      set_device_states) are imported from torch.utils.checkpoint instead of
      being referenced as undefined globals;
    - saved inputs are detached before recomputation, so gradients are not
      accumulated twice (once by the inner backward writing into the original
      leaves, once by the outer autograd applying the returned grads);
    - torch.cuda.is_initialized() replaces the private torch.cuda._initialized.
    """

    @staticmethod
    def forward(ctx, run_function, preserve_rng_state, *args):
        from torch.utils.checkpoint import check_backward_validity, get_device_states

        check_backward_validity(args)
        ctx.run_function = run_function
        ctx.preserve_rng_state = preserve_rng_state

        if preserve_rng_state:
            # Capture RNG so the recomputation in backward sees identical randomness.
            ctx.fwd_cpu_state = torch.get_rng_state()
            ctx.had_cuda_in_fwd = False
            if torch.cuda.is_initialized():
                ctx.had_cuda_in_fwd = True
                ctx.fwd_gpu_devices, ctx.fwd_gpu_states = get_device_states(*args)

        ctx.save_for_backward(*args)

        # No activations are stored here; backward recomputes them.
        with torch.no_grad():
            outputs = run_function(*args)

        return outputs

    @staticmethod
    def backward(ctx, *args):
        from torch.utils.checkpoint import set_device_states

        if not torch.autograd._is_checkpoint_valid():
            raise RuntimeError("Checkpointing is not compatible with .grad()")

        # Detach so the recomputed graph is rooted at fresh leaves; otherwise
        # the inner backward below would also write into the originals' .grad.
        inputs = tuple(
            inp.detach().requires_grad_(inp.requires_grad) for inp in ctx.saved_tensors
        )

        # Restore RNG state so stochastic ops replay exactly.
        if ctx.preserve_rng_state:
            if ctx.had_cuda_in_fwd:
                set_device_states(ctx.fwd_gpu_devices, ctx.fwd_gpu_states)
            torch.set_rng_state(ctx.fwd_cpu_state)

        with torch.enable_grad():
            outputs = ctx.run_function(*inputs)

        if isinstance(outputs, torch.Tensor):
            outputs = (outputs,)

        torch.autograd.backward(outputs, args)

        # First two Nones correspond to run_function and preserve_rng_state.
        return (None, None) + tuple(inp.grad for inp in inputs)

def checkpoint(function, *args, **kwargs):
    """Checkpoint `function(*args)`; pass preserve_rng_state=False to skip RNG capture."""
    preserve_rng_state = kwargs.pop('preserve_rng_state', True)
    return CheckpointFunction.apply(function, preserve_rng_state, *args)

2. Numerical Stability

python
class NumericallyStableFunction(Function):
    """log-sum-exp over the last dimension using the max-shift trick."""

    @staticmethod
    def forward(ctx, input):
        # Subtract the row max so exp() never overflows for large inputs.
        row_max = input.max(dim=-1, keepdim=True)[0]
        exp_shifted = torch.exp(input - row_max)
        sum_exp = exp_shifted.sum(dim=-1, keepdim=True)

        ctx.save_for_backward(exp_shifted, sum_exp)
        return torch.log(sum_exp) + row_max

    @staticmethod
    def backward(ctx, grad_output):
        # d logsumexp / dx_i = softmax(x)_i
        exp_shifted, sum_exp = ctx.saved_tensors
        return grad_output * (exp_shifted / sum_exp)

def stable_log_sum_exp(input):
    """Numerically stable log-sum-exp over the last dimension."""
    return NumericallyStableFunction.apply(input)

# Test numerical stability with values that would overflow a naive exp()
big_values = torch.tensor([1000.0, 1001.0, 1002.0])
stable_result = stable_log_sum_exp(big_values)
print(f"Numerically stable result: {stable_result}")

Summary

Custom operations are a powerful feature of PyTorch. This chapter introduced:

  1. Custom Function: Custom functions implementing forward and backward propagation
  2. Custom Module: Creating reusable neural network components
  3. Custom Loss Functions: Implementing loss functions for specific tasks
  4. Custom Optimizers: Implementing new optimization algorithms
  5. Custom Data Types: Extending PyTorch's data processing capabilities
  6. C++/CUDA Extensions: High-performance low-level implementations
  7. Performance Optimization: Memory efficiency and numerical stability techniques

Mastering these techniques will help you extend PyTorch's functionality and implement innovative deep learning algorithms!

Content is for learning and research only.