PyTorch Automatic Differentiation
What is Automatic Differentiation?
Automatic Differentiation (AutoGrad) is one of PyTorch's core features. It can automatically compute gradients for tensor operations, which is crucial for the backpropagation algorithm in deep learning.
python
import torch

# Leaf tensor that records operations for autograd
x = torch.tensor(2.0, requires_grad=True)
y = x ** 2 + 3 * x + 1

# Backpropagate from the scalar output to populate x.grad
y.backward()
print(f"x = {x}")
print(f"y = {y}")
print(f"dy/dx = {x.grad}") # Should be 2*2 + 3 = 7
Computational Graph
PyTorch uses dynamic computational graphs to track operations and compute gradients:
python
import torch

x = torch.tensor(1.0, requires_grad=True)
y = torch.tensor(2.0, requires_grad=True)

# Every op on grad-tracking tensors extends the dynamic graph
z = x * y + x ** 2
w = z.mean()
print(f"Functions in computational graph: {w.grad_fn}")
print(f"Gradient function for z: {z.grad_fn}")

# Walk the graph backwards to fill in the leaf gradients
w.backward()
print(f"dw/dx = {x.grad}") # 2*x + y = 2*1 + 2 = 4
print(f"dw/dy = {y.grad}") # x = 1
requires_grad Attribute
1. Basic Usage
python
# Request gradient tracking at construction time
x = torch.randn(3, 4, requires_grad=True)

# ...or flip it on afterwards with the in-place setter
y = torch.randn(3, 4)
y.requires_grad_(True) # In-place modification

# Inspect the flag
print(f"x requires grad: {x.requires_grad}")
print(f"y requires grad: {y.requires_grad}")
2. Gradient Propagation Rules
python
x = torch.randn(2, 2, requires_grad=True)
y = torch.randn(2, 2, requires_grad=False)

# A result tracks gradients as soon as any operand does
z = x + y
print(f"z requires grad: {z.requires_grad}") # True

# With no tracking operand, the result doesn't track either
a = torch.randn(2, 2)
b = torch.randn(2, 2)
c = a + b
print(f"c requires grad: {c.requires_grad}") # False
Gradient Computation
1. Gradients for Scalar Functions
python
# Single-variable function
x = torch.tensor(3.0, requires_grad=True)
y = x ** 3 - 2 * x ** 2 + x - 1
y.backward()
print(f"dy/dx = {x.grad}") # 3*9 - 4*3 + 1 = 16

# Multi-variable function: each leaf receives its own partial derivative
x = torch.tensor(1.0, requires_grad=True)
y = torch.tensor(2.0, requires_grad=True)
z = x ** 2 + y ** 2 + 2 * x * y
z.backward()
print(f"dz/dx = {x.grad}") # 2*x + 2*y = 2*1 + 2*2 = 6
print(f"dz/dy = {y.grad}") # 2*y + 2*x = 2*2 + 2*1 = 6
2. Gradients for Vector Functions
python
# A non-scalar output needs an explicit upstream gradient for backward()
x = torch.randn(3, requires_grad=True)
y = x * 2

# The supplied gradient must have y's shape
gradient = torch.ones_like(y)
y.backward(gradient)
print(f"x gradients: {x.grad}") # Should all be 2
3. Jacobian-Vector Product
python
x = torch.randn(3, requires_grad=True)
y = x ** 2

# backward(v) evaluates the vector-Jacobian product J^T * v
v = torch.tensor([1.0, 1.0, 1.0])
y.backward(v)
print(f"Jacobian-vector product: {x.grad}") # 2*x
Gradient Accumulation
python
x = torch.tensor(1.0, requires_grad=True)

# First backward pass writes into x.grad
y1 = x ** 2
y1.backward()
print(f"First gradient: {x.grad}") # 2

# A second backward pass adds to (does not replace) x.grad
y2 = x ** 3
y2.backward()
print(f"Accumulated gradient: {x.grad}") # 2 + 3 = 5

# Reset the accumulator in place
x.grad.zero_()
print(f"After zeroing gradient: {x.grad}") # 0
Higher-Order Derivatives
python
# Second derivative of y = x**4 at x = 2
x = torch.tensor(2.0, requires_grad=True)
y = x ** 4

# create_graph=True keeps the graph so grad1 is itself differentiable
grad1 = torch.autograd.grad(y, x, create_graph=True)[0]
print(f"First derivative: {grad1}") # 4 * x^3 = 32

# Differentiate the first derivative once more
grad2 = torch.autograd.grad(grad1, x)[0]
print(f"Second derivative: {grad2}") # 12 * x^2 = 48
Controlling Gradient Computation
1. torch.no_grad()
python
x = torch.randn(3, requires_grad=True)

# Nothing computed inside no_grad() is recorded in the graph
with torch.no_grad():
    y = x ** 2
print(f"y requires grad: {y.requires_grad}") # False

# The same switch is available as a decorator
@torch.no_grad()
def inference(x):
    return x ** 2 + 1

result = inference(x)
print(f"Inference result requires grad: {result.requires_grad}") # False
2. detach() Method
python
x = torch.randn(3, requires_grad=True)
y = x ** 2

# detach() returns a tensor cut off from the graph; gradient flow stops here
y_detached = y.detach()
z = y_detached * 2
print(f"y requires grad: {y.requires_grad}") # True
print(f"y_detached requires grad: {y_detached.requires_grad}") # False
print(f"z requires grad: {z.requires_grad}") # False
3. torch.set_grad_enabled()
python
# Process-wide gradient switch
torch.set_grad_enabled(False)
x = torch.randn(3, requires_grad=True)
y = x ** 2
print(f"y requires grad with global disabled: {y.requires_grad}") # False
torch.set_grad_enabled(True) # Re-enable
Custom autograd Function
python
class MySquare(torch.autograd.Function):
    """Custom autograd Function implementing y = x ** 2."""

    @staticmethod
    def forward(ctx, input):
        # Stash the input; backward needs it for the chain rule
        ctx.save_for_backward(input)
        return input ** 2

    @staticmethod
    def backward(ctx, grad_output):
        # d(x^2)/dx = 2x, multiplied by the upstream gradient
        input, = ctx.saved_tensors
        return grad_output * 2 * input

# Custom Functions are invoked through .apply
my_square = MySquare.apply
x = torch.tensor(3.0, requires_grad=True)
y = my_square(x)
y.backward()
print(f"Custom function gradient: {x.grad}") # 6
Gradient Checking
python
def gradient_check(func, inputs, eps=1e-6):
    """Check autograd gradients of ``func`` against central finite differences.

    Args:
        func: Callable taking ``*inputs`` and returning a tensor.
        inputs: List of tensors with ``requires_grad=True``.
        eps: Step size for the central-difference quotient.

    Returns:
        List of per-input max absolute differences between the analytical
        and numerical gradients (also printed, one line per input).
    """
    # Analytical gradients via autograd (reduce to a scalar first)
    outputs = func(*inputs)
    if outputs.numel() != 1:
        outputs = outputs.sum()
    analytical_grads = torch.autograd.grad(outputs, inputs)
    # Numerical gradients via (f(x+eps) - f(x-eps)) / (2*eps).
    # NOTE: the previous version used torch.nditer -- which does not exist
    # (nditer is a NumPy API) -- and wrote into grad-requiring leaf tensors
    # in-place, which raises a RuntimeError. Iterate a flattened view inside
    # torch.no_grad() instead (assumes contiguous inputs, as view() requires).
    numerical_grads = []
    with torch.no_grad():
        for inp in inputs:
            grad = torch.zeros_like(inp)
            flat_inp = inp.view(-1)
            flat_grad = grad.view(-1)
            for j in range(flat_inp.numel()):
                old_value = flat_inp[j].item()
                flat_inp[j] = old_value + eps
                pos_output = func(*inputs).sum()  # f(x + eps)
                flat_inp[j] = old_value - eps
                neg_output = func(*inputs).sum()  # f(x - eps)
                flat_grad[j] = (pos_output - neg_output) / (2 * eps)
                flat_inp[j] = old_value  # restore original value
            numerical_grads.append(grad)
    # Compare gradients and report the worst deviation per input
    diffs = []
    for i, (analytical, numerical) in enumerate(zip(analytical_grads, numerical_grads)):
        diff = torch.abs(analytical - numerical).max()
        diffs.append(diff)
        print(f"Gradient difference for input {i}: {diff:.8f}")
    return diffs
# Sanity test: x**2 + y**3 has analytical gradients 2x and 3y**2
def test_func(x, y):
    return x ** 2 + y ** 3

x = torch.tensor(2.0, requires_grad=True)
y = torch.tensor(3.0, requires_grad=True)
gradient_check(test_func, [x, y])
Common Problems and Solutions
1. Gradient Explosion
python
# Gradient clipping
def clip_gradients(model, max_norm):
    """Clip the global L2 norm of ``model``'s parameter gradients to ``max_norm``."""
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
# Usage example
# clip_gradients(model, max_norm=1.0)
2. Gradient Vanishing
python
# Check gradient magnitude
def check_gradients(model):
    """Compute the global L2 norm over all parameter gradients of ``model``."""
    # Sum the squared per-parameter norms, skipping params without gradients
    total_norm = sum(
        p.grad.data.norm(2).item() ** 2
        for p in model.parameters()
        if p.grad is not None
    )
    total_norm = total_norm ** (1. / 2)
    print(f"Total gradient norm: {total_norm}")
3. Memory Leaks
python
# NOTE(review): illustrative fragment -- `optimizer` and `model` are assumed
# to be defined elsewhere; these lines are not runnable on their own.
# Clean up gradients in time: zero them so stale buffers are not retained
optimizer.zero_grad()
# Detach tensors that don't need gradients so their graph can be freed
prediction = model(x).detach()
# Use torch.no_grad() for inference: no graph is recorded at all
with torch.no_grad():
    prediction = model(x)
Practical Application Examples
1. Simple Linear Regression
python
import torch
import torch.nn as nn

# Synthetic data: y = 3x + 2 plus Gaussian noise
torch.manual_seed(42)
x = torch.randn(100, 1)
y = 3 * x + 2 + 0.1 * torch.randn(100, 1)

# Trainable parameters (plain leaf tensors, no nn.Module)
w = torch.randn(1, 1, requires_grad=True)
b = torch.randn(1, requires_grad=True)

# Hand-written gradient-descent loop
learning_rate = 0.01
for epoch in range(100):
    # Forward pass: linear model + MSE loss
    y_pred = x @ w + b
    loss = ((y_pred - y) ** 2).mean()
    # Backward pass
    loss.backward()
    # Update parameters outside the graph
    with torch.no_grad():
        w -= learning_rate * w.grad
        b -= learning_rate * b.grad
        # Reset accumulated gradients for the next step
        w.grad.zero_()
        b.grad.zero_()
    if epoch % 20 == 0:
        print(f"Epoch {epoch}, Loss: {loss.item():.4f}")
print(f"Learned parameters: w={w.item():.2f}, b={b.item():.2f}")
2. Gradient Flow in Neural Networks
python
import torch.nn as nn

class SimpleNet(nn.Module):
    """Small MLP: 10 -> 50 -> 20 -> 1 with ReLU between the linear layers."""

    def __init__(self):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(10, 50),
            nn.ReLU(),
            nn.Linear(50, 20),
            nn.ReLU(),
            nn.Linear(20, 1)
        )

    def forward(self, x):
        return self.layers(x)

# Build the model and a random batch
model = SimpleNet()
x = torch.randn(32, 10)
y = torch.randn(32, 1)

# Forward pass
output = model(x)
loss = nn.MSELoss()(output, y)

# Backward pass fills .grad on every parameter that fed the loss
loss.backward()
# Check gradients for each layer
for name, param in model.named_parameters():
    # Params never reached by backward() keep grad=None; skip those
    if param.grad is not None:
        print(f"{name}: gradient norm = {param.grad.norm().item():.6f}")
Summary
Automatic differentiation is a core feature of PyTorch, and understanding it is essential for deep learning:
- Computational Graph: Understand dynamic computational graph construction and execution
- Gradient Computation: Master gradient computation for scalar and vector functions
- Gradient Control: Learn to use no_grad, detach, etc. to control gradients
- Performance Optimization: Avoid unnecessary gradient computation, clean up memory in time
- Debugging Techniques: Use gradient checking to verify implementation correctness
Mastering these concepts will lay a solid foundation for subsequent neural network training!