Skip to content

PyTorch Automatic Differentiation

What is Automatic Differentiation?

Automatic differentiation (autograd) is one of PyTorch's core features. It automatically computes gradients of tensor operations, which is what powers the backpropagation algorithm in deep learning.

python
import torch

# Create tensor requiring gradient
x = torch.tensor(2.0, requires_grad=True)
y = x ** 2 + 3 * x + 1

# Compute gradients
y.backward()

print(f"x = {x}")
print(f"y = {y}")
print(f"dy/dx = {x.grad}")  # Should be 2*2 + 3 = 7

Computational Graph

PyTorch uses dynamic computational graphs to track operations and compute gradients:

python
import torch

x = torch.tensor(1.0, requires_grad=True)
y = torch.tensor(2.0, requires_grad=True)

# Build computational graph
z = x * y + x ** 2
w = z.mean()

print(f"Functions in computational graph: {w.grad_fn}")
print(f"Gradient function for z: {z.grad_fn}")

# Backpropagation
w.backward()

print(f"dw/dx = {x.grad}")  # 2*x + y = 2*1 + 2 = 4
print(f"dw/dy = {y.grad}")  # x = 1

requires_grad Attribute

1. Basic Usage

python
# Specify during creation
x = torch.randn(3, 4, requires_grad=True)

# Set later
y = torch.randn(3, 4)
y.requires_grad_(True)  # In-place modification

# Check if gradient is needed
print(f"x requires grad: {x.requires_grad}")
print(f"y requires grad: {y.requires_grad}")

2. Gradient Propagation Rules

python
x = torch.randn(2, 2, requires_grad=True)
y = torch.randn(2, 2, requires_grad=False)

# If one operand requires gradient, result requires gradient
z = x + y
print(f"z requires grad: {z.requires_grad}")  # True

# When all operands don't require gradient, result doesn't either
a = torch.randn(2, 2)
b = torch.randn(2, 2)
c = a + b
print(f"c requires grad: {c.requires_grad}")  # False

Gradient Computation

1. Gradients for Scalar Functions

python
# Single variable function
x = torch.tensor(3.0, requires_grad=True)
y = x ** 3 - 2 * x ** 2 + x - 1

y.backward()
print(f"dy/dx = {x.grad}")  # 3*9 - 4*3 + 1 = 16

# Multi-variable function
x = torch.tensor(1.0, requires_grad=True)
y = torch.tensor(2.0, requires_grad=True)
z = x ** 2 + y ** 2 + 2 * x * y

z.backward()
print(f"dz/dx = {x.grad}")  # 2*x + 2*y = 2*1 + 2*2 = 6
print(f"dz/dy = {y.grad}")  # 2*y + 2*x = 2*2 + 2*1 = 6

2. Gradients for Vector Functions

python
# For non-scalar output, need to provide gradient parameter
x = torch.randn(3, requires_grad=True)
y = x * 2

# Need to provide gradient of same shape as y
gradient = torch.ones_like(y)
y.backward(gradient)

print(f"x gradients: {x.grad}")  # Should all be 2

3. Jacobian-Vector Product

python
x = torch.randn(3, requires_grad=True)
y = x ** 2

# Compute Jacobian-vector product J^T * v
v = torch.tensor([1.0, 1.0, 1.0])
y.backward(v)

print(f"Jacobian-vector product: {x.grad}")  # 2*x

Gradient Accumulation

python
x = torch.tensor(1.0, requires_grad=True)

# First computation
y1 = x ** 2
y1.backward()
print(f"First gradient: {x.grad}")  # 2

# Second computation (gradients accumulate)
y2 = x ** 3
y2.backward()
print(f"Accumulated gradient: {x.grad}")  # 2 + 3 = 5

# Zero gradients
x.grad.zero_()
print(f"After zeroing gradient: {x.grad}")  # 0

Higher-Order Derivatives

python
# Compute second derivative
x = torch.tensor(2.0, requires_grad=True)
y = x ** 4

# First derivative
grad1 = torch.autograd.grad(y, x, create_graph=True)[0]
print(f"First derivative: {grad1}")  # 4 * x^3 = 32

# Second derivative
grad2 = torch.autograd.grad(grad1, x)[0]
print(f"Second derivative: {grad2}")  # 12 * x^2 = 48

Controlling Gradient Computation

1. torch.no_grad()

python
x = torch.randn(3, requires_grad=True)

# In no_grad context, computational graph is not built
with torch.no_grad():
    y = x ** 2
    print(f"y requires grad: {y.requires_grad}")  # False

# Decorator form
@torch.no_grad()
def inference(x):
    """Evaluate x**2 + 1 with autograd disabled.

    The decorator form of ``torch.no_grad()`` turns the whole function into a
    no-graph region, so the returned tensor never requires gradients.
    """
    result = x ** 2 + 1
    return result

result = inference(x)
print(f"Inference result requires grad: {result.requires_grad}")  # False

2. detach() Method

python
x = torch.randn(3, requires_grad=True)
y = x ** 2

# Detach tensor, stop gradient flow
y_detached = y.detach()
z = y_detached * 2

print(f"y requires grad: {y.requires_grad}")  # True
print(f"y_detached requires grad: {y_detached.requires_grad}")  # False
print(f"z requires grad: {z.requires_grad}")  # False

3. torch.set_grad_enabled()

python
# Global gradient control
torch.set_grad_enabled(False)
x = torch.randn(3, requires_grad=True)
y = x ** 2
print(f"y requires grad with global disabled: {y.requires_grad}")  # False

torch.set_grad_enabled(True)  # Re-enable

Custom autograd Function

python
class MySquare(torch.autograd.Function):
    """Custom autograd op computing ``input ** 2`` with a hand-written backward."""

    @staticmethod
    def forward(ctx, input):
        # Stash the input; backward needs it for d(x^2)/dx = 2x.
        ctx.save_for_backward(input)
        return input ** 2

    @staticmethod
    def backward(ctx, grad_output):
        (saved_input,) = ctx.saved_tensors
        # Chain rule: upstream gradient times the local derivative 2 * x.
        grad_input = grad_output * 2 * saved_input
        return grad_input

# Use custom function
my_square = MySquare.apply

x = torch.tensor(3.0, requires_grad=True)
y = my_square(x)
y.backward()

print(f"Custom function gradient: {x.grad}")  # 6

Gradient Checking

python
def gradient_check(func, inputs, eps=1e-6):
    """Verify autograd gradients of ``func`` against central finite differences.

    Args:
        func: Callable taking ``*inputs`` and returning a tensor.
        inputs: List of leaf tensors with ``requires_grad=True``.
        eps: Step size for the central difference. Use float64 inputs for a
            meaningful comparison — float32 rounding noise swamps a 1e-6 step.

    Returns:
        List of per-input max absolute differences between the analytical and
        numerical gradients (each difference is also printed, as before).
    """
    # Analytical gradients via autograd; reduce to a scalar first if needed.
    outputs = func(*inputs)
    if outputs.numel() != 1:
        outputs = outputs.sum()

    analytical_grads = torch.autograd.grad(outputs, inputs)

    # Numerical gradients: perturb one element at a time.
    # BUGFIX: the original called torch.nditer, which does not exist (it is a
    # NumPy API, np.nditer), and mutated grad-requiring leaf tensors outside
    # no_grad, which raises a RuntimeError. We iterate a flattened view of
    # each input inside torch.no_grad() instead, which also handles 0-dim
    # tensors uniformly.
    numerical_grads = []
    with torch.no_grad():  # permit in-place edits of leaf tensors
        for inp in inputs:
            grad = torch.zeros_like(inp)
            flat_inp = inp.view(-1)
            flat_grad = grad.view(-1)
            for j in range(flat_inp.numel()):
                old_value = flat_inp[j].item()

                # f(x + eps)
                flat_inp[j] = old_value + eps
                pos_output = func(*inputs).sum()

                # f(x - eps)
                flat_inp[j] = old_value - eps
                neg_output = func(*inputs).sum()

                # Central difference, then restore the original value.
                flat_grad[j] = (pos_output - neg_output) / (2 * eps)
                flat_inp[j] = old_value

            numerical_grads.append(grad)

    # Report and collect the worst-case discrepancy per input.
    diffs = []
    for i, (analytical, numerical) in enumerate(zip(analytical_grads, numerical_grads)):
        diff = torch.abs(analytical - numerical).max()
        diffs.append(diff)
        print(f"Gradient difference for input {i}: {diff:.8f}")
    return diffs

# Test
def test_func(x, y):
    """Two-variable polynomial f(x, y) = x^2 + y^3; gradients are (2x, 3y^2)."""
    return x ** 2 + y ** 3

x = torch.tensor(2.0, requires_grad=True)
y = torch.tensor(3.0, requires_grad=True)
gradient_check(test_func, [x, y])

Common Problems and Solutions

1. Gradient Explosion

python
# Gradient clipping
def clip_gradients(model, max_norm):
    """Rescale all parameter gradients of ``model`` in place so that their
    global L2 norm does not exceed ``max_norm``."""
    params = model.parameters()
    torch.nn.utils.clip_grad_norm_(params, max_norm)

# Usage example
# clip_gradients(model, max_norm=1.0)

2. Gradient Vanishing

python
# Check gradient magnitude
def check_gradients(model):
    """Print and return the global L2 norm of all parameter gradients.

    A norm near zero suggests vanishing gradients; a very large norm suggests
    exploding gradients.

    Args:
        model: An ``nn.Module`` whose parameters may hold ``.grad`` tensors.

    Returns:
        The total gradient norm as a Python float (0.0 if no gradients exist).
    """
    total_sq = 0.0
    for p in model.parameters():
        if p.grad is not None:
            # .detach() replaces the deprecated .data access — same values,
            # no autograd tracking.
            param_norm = p.grad.detach().norm(2)
            total_sq += param_norm.item() ** 2
    total_norm = total_sq ** 0.5
    print(f"Total gradient norm: {total_norm}")
    return total_norm

3. Memory Leaks

python
# Zero accumulated gradients promptly (backward() accumulates into .grad)
optimizer.zero_grad()

# Detach tensors that don't need gradients
prediction = model(x).detach()

# Use torch.no_grad() for inference
with torch.no_grad():
    prediction = model(x)

Practical Application Examples

1. Simple Linear Regression

python
import torch
import torch.nn as nn

# Generate synthetic data: y = 3x + 2 plus Gaussian noise (std 0.1).
# Seeding makes the run reproducible.
torch.manual_seed(42)
x = torch.randn(100, 1)
y = 3 * x + 2 + 0.1 * torch.randn(100, 1)

# Define trainable parameters (ground truth: w = 3, b = 2)
w = torch.randn(1, 1, requires_grad=True)
b = torch.randn(1, requires_grad=True)

# Training loop: plain gradient descent, no optimizer object
learning_rate = 0.01
for epoch in range(100):
    # Forward pass: linear model + mean squared error
    y_pred = x @ w + b
    loss = ((y_pred - y) ** 2).mean()
    
    # Backward pass: populates w.grad and b.grad
    loss.backward()
    
    # Update parameters inside no_grad so the updates themselves
    # are not recorded in the computational graph
    with torch.no_grad():
        w -= learning_rate * w.grad
        b -= learning_rate * b.grad
        
        # Zero gradients — backward() accumulates, so stale gradients
        # would otherwise leak into the next step
        w.grad.zero_()
        b.grad.zero_()
    
    if epoch % 20 == 0:
        print(f"Epoch {epoch}, Loss: {loss.item():.4f}")

print(f"Learned parameters: w={w.item():.2f}, b={b.item():.2f}")

2. Gradient Flow in Neural Networks

python
import torch.nn as nn

class SimpleNet(nn.Module):
    """Fully-connected 10 -> 50 -> 20 -> 1 network with ReLU activations."""

    def __init__(self):
        super().__init__()
        # Build the stack as a list first, then wrap it; keeping the
        # attribute name `layers` preserves state_dict keys.
        stack = [
            nn.Linear(10, 50),
            nn.ReLU(),
            nn.Linear(50, 20),
            nn.ReLU(),
            nn.Linear(20, 1),
        ]
        self.layers = nn.Sequential(*stack)

    def forward(self, x):
        out = self.layers(x)
        return out

# Create model and random regression data (batch of 32, 10 features)
model = SimpleNet()
x = torch.randn(32, 10)
y = torch.randn(32, 1)

# Forward pass: predictions and mean-squared-error loss
output = model(x)
loss = nn.MSELoss()(output, y)

# Backward pass: fills .grad on every parameter in the model
loss.backward()

# Inspect per-parameter gradient norms — near-zero norms in early layers
# would indicate vanishing gradients
for name, param in model.named_parameters():
    if param.grad is not None:
        print(f"{name}: gradient norm = {param.grad.norm().item():.6f}")

Summary

Automatic differentiation is a core feature of PyTorch, and understanding it is essential for deep learning:

  1. Computational Graph: Understand dynamic computational graph construction and execution
  2. Gradient Computation: Master gradient computation for scalar and vector functions
  3. Gradient Control: Learn to use no_grad, detach, etc. to control gradients
  4. Performance Optimization: Avoid unnecessary gradient computation and release graph memory promptly
  5. Debugging Techniques: Use gradient checking to verify implementation correctness

Mastering these concepts will lay a solid foundation for subsequent neural network training!

Content is for learning and research only.