Skip to content

PyTorch Neural Network Fundamentals

Introduction to torch.nn Module

torch.nn is the core module in PyTorch for building neural networks, providing various components such as layers, activation functions, and loss functions.

python
import torch
import torch.nn as nn
import torch.nn.functional as F  # stateless functional variants (F.relu, F.softmax, ...)

# Basic neural network layers — constructing a layer allocates its learnable
# parameters immediately (randomly initialized).
linear = nn.Linear(10, 5)  # Linear layer: 10 input dimensions, 5 output dimensions
conv = nn.Conv2d(3, 16, 3)  # Convolutional layer: 3 input channels, 16 output channels, 3x3 kernel
lstm = nn.LSTM(10, 20, 2)   # LSTM layer: 10 input dimensions, 20 hidden dimensions, 2 layers

Building Your First Neural Network

1. Using nn.Sequential

python
import torch
import torch.nn as nn

# Simplest approach: Sequential container
# Simplest approach: a Sequential container applies the layers in order.
# Building the layer list first makes the architecture easy to tweak.
layers = [
    nn.Linear(784, 128),  # input -> hidden
    nn.ReLU(),            # non-linearity
    nn.Linear(128, 64),   # hidden -> hidden
    nn.ReLU(),
    nn.Linear(64, 10),    # hidden -> output (10 class logits)
]
model = nn.Sequential(*layers)

# Sanity-check with a random batch.
x = torch.randn(32, 784)  # Batch size 32, 784 features
output = model(x)
print(f"Output shape: {output.shape}")  # [32, 10]

2. Custom nn.Module

python
class MLP(nn.Module):
    """Three-layer perceptron: Linear -> ReLU -> Dropout twice, then a Linear head.

    Args:
        input_size: dimensionality of each input sample.
        hidden_size: width of both hidden layers.
        output_size: number of output units (e.g. class logits).
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.fc3 = nn.Linear(hidden_size, output_size)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.2)  # active only in train() mode

    def forward(self, x):
        # Two hidden blocks (linear -> relu -> dropout), then the raw output
        # layer — no final activation, so callers get logits.
        h = self.dropout(self.relu(self.fc1(x)))
        h = self.dropout(self.relu(self.fc2(h)))
        return self.fc3(h)

# Create model instance
model = MLP(784, 128, 10)
print(model)

# Count parameters; every parameter of a fresh module has requires_grad=True,
# so the two totals match here.
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Total parameters: {total_params}")
print(f"Trainable parameters: {trainable_params}")

Common Neural Network Layers

1. Linear Layer

python
# Fully connected layer
# Fully connected layer: y = x @ W.T + b
linear = nn.Linear(in_features=100, out_features=50, bias=True)

# Re-initialize in place (the nn.init functions operate under no_grad).
nn.init.xavier_uniform_(linear.weight)
nn.init.zeros_(linear.bias)

# Parameter shapes: weight is (out_features, in_features), bias is (out_features,)
print(f"Weight shape: {linear.weight.shape}")  # [50, 100]
print(f"Bias shape: {linear.bias.shape}")    # [50]

2. Convolutional Layer

python
# 2D convolution
# 2D convolution: 3 -> 16 channels, 3x3 kernel; stride 1 with padding 1
# keeps the spatial size unchanged.
conv2d = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=True)

# 1D convolution (for sequential data)
conv1d = nn.Conv1d(in_channels=10, out_channels=20, kernel_size=3)

# Transposed convolution (deconvolution) — roughly the upsampling inverse
conv_transpose = nn.ConvTranspose2d(16, 3, 3, stride=2, padding=1)

# Test convolution layer
x = torch.randn(32, 3, 64, 64)  # [batch, channels, height, width]
output = conv2d(x)
print(f"Convolution output shape: {output.shape}")  # [32, 16, 64, 64]

3. Pooling Layer

python
# Max pooling
# Max pooling — keeps the largest value in each 2x2 window
maxpool = nn.MaxPool2d(2, stride=2)

# Average pooling — mean of each 2x2 window
avgpool = nn.AvgPool2d(2, stride=2)

# Adaptive pooling: output size is fixed regardless of the input size
adaptive_pool = nn.AdaptiveAvgPool2d((7, 7))

# Global average pooling is just adaptive pooling down to 1x1
global_pool = nn.AdaptiveAvgPool2d((1, 1))

# Test pooling
x = torch.randn(32, 16, 64, 64)
pooled = maxpool(x)
print(f"After pooling shape: {pooled.shape}")  # [32, 16, 32, 32]

4. Recurrent Layers

python
# LSTM layer
# LSTM layer. batch_first=True means tensors are (batch, seq, feature);
# dropout here applies between the two stacked layers during training.
lstm = nn.LSTM(
    input_size=100,
    hidden_size=128,
    num_layers=2,
    batch_first=True,
    dropout=0.1,
    bidirectional=False,
)

# GRU layer
gru = nn.GRU(input_size=100, hidden_size=128, num_layers=2, batch_first=True)

# Simple RNN
rnn = nn.RNN(input_size=100, hidden_size=128, num_layers=2, batch_first=True)

# Test LSTM: it returns the per-step outputs plus the final (hidden, cell) pair
x = torch.randn(32, 50, 100)  # [batch, seq_len, input_size]
output, (hidden, cell) = lstm(x)
print(f"LSTM output shape: {output.shape}")  # [32, 50, 128]
print(f"Hidden state shape: {hidden.shape}")   # [2, 32, 128]

5. Attention Mechanism

python
# Multi-head attention
# Multi-head attention: 512-dim embeddings split across 8 heads (64 dims each)
attention = nn.MultiheadAttention(embed_dim=512, num_heads=8,
                                  dropout=0.1, batch_first=True)

# Self-attention style test: independent random Q/K/V of the same shape
query = torch.randn(32, 50, 512)  # [batch, seq_len, embed_dim]
key = torch.randn(32, 50, 512)
value = torch.randn(32, 50, 512)

# Returns the attended values and the (head-averaged) attention weights
attn_output, attn_weights = attention(query, key, value)
print(f"Attention output shape: {attn_output.shape}")  # [32, 50, 512]

Activation Functions

python
# Common activation functions
# Module form of the common activations (usable inside nn.Sequential)
relu = nn.ReLU()
leaky_relu = nn.LeakyReLU(0.01)
elu = nn.ELU()
gelu = nn.GELU()
swish = nn.SiLU()  # SiLU == Swish: x * sigmoid(x)
tanh = nn.Tanh()
sigmoid = nn.Sigmoid()
softmax = nn.Softmax(dim=-1)  # normalize over the last dimension

# Functional interface: same math without constructing module objects
x = torch.randn(10)
y1 = F.relu(x)
y2 = F.gelu(x)
y3 = F.softmax(x, dim=0)  # y3 sums to 1 along dim 0

print(f"ReLU output: {y1}")
print(f"GELU output: {y2}")
print(f"Softmax output: {y3}")

Regularization Techniques

1. Dropout

python
# Dropout layer
dropout = nn.Dropout(p=0.5)  # 50% of neurons randomly set to zero

# Different behavior during training and evaluation
model.train()  # Training mode, enable dropout
output_train = dropout(x)

model.eval()   # Evaluation mode, disable dropout
output_eval = dropout(x)

2. Batch Normalization

python
# 1D batch normalization (for fully connected layers)
# Batch normalization over a batch of feature vectors (after Linear layers)
bn1d = nn.BatchNorm1d(128)

# Batch normalization over (N, C, H, W) feature maps (after Conv2d layers)
bn2d = nn.BatchNorm2d(64)

# Layer normalization: normalizes over the last dimension(s) of each sample
ln = nn.LayerNorm(128)

# Group normalization: 64 channels split into 8 groups of 8
gn = nn.GroupNorm(8, 64)

# Instance normalization: per-sample, per-channel statistics
in_norm = nn.InstanceNorm2d(64)

3. Weight Initialization

python
def init_weights(m):
    """Layer-type-aware initialization; use with ``model.apply(init_weights)``.

    Linear layers get Xavier-uniform weights, Conv2d layers get Kaiming-normal
    (fan_out, relu) weights, BatchNorm2d gets weight=1 / bias=0. All bias and
    affine parameters are guarded with ``is not None`` so layers built with
    ``bias=False`` or ``affine=False`` no longer crash (the original only
    guarded the Conv2d bias).
    """
    if isinstance(m, nn.Linear):
        nn.init.xavier_uniform_(m.weight)
        if m.bias is not None:
            nn.init.zeros_(m.bias)
    elif isinstance(m, nn.Conv2d):
        nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
        if m.bias is not None:
            nn.init.zeros_(m.bias)
    elif isinstance(m, nn.BatchNorm2d):
        # affine=False BatchNorm has no learnable weight/bias
        if m.weight is not None:
            nn.init.ones_(m.weight)
        if m.bias is not None:
            nn.init.zeros_(m.bias)

# Apply initialization
model.apply(init_weights)

Complex Network Architectures

1. Residual Connection

python
class ResidualBlock(nn.Module):
    """Basic ResNet-style block: two 3x3 conv+BN stages plus a skip path.

    The skip path is the identity when shape is preserved; otherwise a 1x1
    conv + BN projects the input to the output shape.
    """

    def __init__(self, in_channels, out_channels, stride=1):
        super(ResidualBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, 3, stride, 1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, 3, 1, 1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)

        # Projection shortcut only when the spatial size or channel count changes.
        needs_projection = stride != 1 or in_channels != out_channels
        self.shortcut = (
            nn.Sequential(
                nn.Conv2d(in_channels, out_channels, 1, stride, bias=False),
                nn.BatchNorm2d(out_channels),
            )
            if needs_projection
            else nn.Sequential()  # identity
        )

    def forward(self, x):
        identity = self.shortcut(x)
        y = F.relu(self.bn1(self.conv1(x)))
        y = self.bn2(self.conv2(y))
        # Add the skip connection before the final non-linearity.
        return F.relu(y + identity)

2. Attention Mechanism

python
class SelfAttention(nn.Module):
    """From-scratch multi-head self-attention over (batch, seq, embed) input.

    Projects the input to Q/K/V, splits into ``num_heads`` heads of size
    ``embed_dim // num_heads``, applies scaled dot-product attention per head,
    then merges heads and applies a final output projection.
    """

    def __init__(self, embed_dim, num_heads):
        super(SelfAttention, self).__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads

        self.query = nn.Linear(embed_dim, embed_dim)
        self.key = nn.Linear(embed_dim, embed_dim)
        self.value = nn.Linear(embed_dim, embed_dim)
        self.out = nn.Linear(embed_dim, embed_dim)

    def forward(self, x):
        batch, seq, dim = x.size()

        def split_heads(proj):
            # (batch, seq, dim) -> (batch, heads, seq, head_dim)
            return proj.view(batch, seq, self.num_heads, self.head_dim).transpose(1, 2)

        q = split_heads(self.query(x))
        k = split_heads(self.key(x))
        v = split_heads(self.value(x))

        # Scaled dot-product attention: softmax(QK^T / sqrt(d)) V
        scale = self.head_dim ** 0.5
        weights = F.softmax((q @ k.transpose(-2, -1)) / scale, dim=-1)
        context = weights @ v

        # Merge heads back: (batch, heads, seq, head_dim) -> (batch, seq, dim)
        merged = context.transpose(1, 2).contiguous().view(batch, seq, dim)
        return self.out(merged)

Model Management

1. Parameter Access and Modification

python
# Access all parameters
# Access all parameters (named_parameters yields "layer.weight"-style names).
# NOTE(review): this snippet assumes `model` is the MLP defined earlier
# (attributes fc1/fc2/fc3) — confirm before reusing.
for name, param in model.named_parameters():
    print(f"{name}: {param.shape}")

# Access specific layer parameters directly via the attribute
linear_layer = model.fc1
print(f"Weight: {linear_layer.weight.shape}")
print(f"Bias: {linear_layer.bias.shape}")

# Freeze all parameters (no gradients will be computed for them)
for param in model.parameters():
    param.requires_grad = False

# Only train specific layers — this loop sets requires_grad on EVERY
# parameter, so it fully overrides the blanket freeze above.
for name, param in model.named_parameters():
    if 'fc3' in name:  # Only train the last layer
        param.requires_grad = True
    else:
        param.requires_grad = False

2. Model State Management

python
# Training and evaluation modes
# Training and evaluation modes (propagate to all registered submodules)
model.train()  # training mode: dropout active, batch norm uses batch statistics
model.eval()   # eval mode: dropout is identity, batch norm uses running statistics

# Check mode
print(f"Is model in training mode: {model.training}")

# Move to GPU if one is available; .to() moves all parameters and buffers
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)

# Check model device by inspecting any parameter
print(f"Model device: {next(model.parameters()).device}")

3. Model Saving and Loading

python
# Save entire model
# Save entire model (pickles the full object — the class definition must be
# importable when loading)
torch.save(model, 'model.pth')

# Save only parameters (recommended: smaller, portable across code changes)
torch.save(model.state_dict(), 'model_params.pth')

# Load model
# NOTE(review): newer PyTorch releases default torch.load to a restricted
# weights-only mode; loading a full pickled model may require
# weights_only=False — confirm against your installed version.
model = torch.load('model.pth')

# Load parameters
model = MLP(784, 128, 10)  # First create model architecture
model.load_state_dict(torch.load('model_params.pth'))

# Save training state for resuming later.
# NOTE(review): `optimizer`, `epoch`, and `loss` come from a training loop
# not shown in this snippet.
checkpoint = {
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'epoch': epoch,
    'loss': loss
}
torch.save(checkpoint, 'checkpoint.pth')

Practical Application Examples

1. Image Classification Network

python
class ImageClassifier(nn.Module):
    """Small CNN for low-resolution image classification (e.g. CIFAR-10).

    Three conv blocks (Conv -> BN -> ReLU -> MaxPool) halve the spatial size
    each time; global average pooling then feeds a two-layer classifier head.
    """

    def __init__(self, num_classes=10):
        super(ImageClassifier, self).__init__()

        def conv_block(cin, cout):
            # One feature-extraction stage; MaxPool halves H and W.
            return [
                nn.Conv2d(cin, cout, 3, padding=1),
                nn.BatchNorm2d(cout),
                nn.ReLU(inplace=True),
                nn.MaxPool2d(2, 2),
            ]

        self.features = nn.Sequential(
            *conv_block(3, 32),    # first block
            *conv_block(32, 64),   # second block
            *conv_block(64, 128),  # third block
        )

        self.classifier = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),  # global average pool to 1x1
            nn.Flatten(),
            nn.Linear(128, 256),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(256, num_classes),   # raw logits
        )

    def forward(self, x):
        return self.classifier(self.features(x))

# Test model
model = ImageClassifier(num_classes=10)
x = torch.randn(32, 3, 32, 32)  # CIFAR-10 size
output = model(x)
print(f"Classification output shape: {output.shape}")  # [32, 10]

2. Sequence-to-Sequence Model

python
class Seq2Seq(nn.Module):
    """Minimal LSTM encoder-decoder.

    The encoder consumes the source sequence; its final (hidden, cell) state
    seeds the decoder, which consumes the target sequence (teacher forcing)
    and is projected to ``output_size`` logits per step.
    """

    def __init__(self, input_size, hidden_size, output_size, num_layers=2):
        super(Seq2Seq, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers

        # Encoder
        self.encoder = nn.LSTM(input_size, hidden_size, num_layers, 
                              batch_first=True, dropout=0.1)

        # Decoder (input per step is an output_size-dim vector)
        self.decoder = nn.LSTM(output_size, hidden_size, num_layers, 
                              batch_first=True, dropout=0.1)

        # Output layer
        self.out = nn.Linear(hidden_size, output_size)

    def forward(self, src, tgt):
        # Encode the source; keep only the final state as decoder context.
        _, state = self.encoder(src)
        # Decode the target conditioned on the encoder state.
        decoded, _ = self.decoder(tgt, state)
        # Project each decoder step to output logits.
        return self.out(decoded)

Debugging Tips

1. Inspect Model Architecture

python
from torchsummary import summary

# Print model summary
summary(model, input_size=(3, 32, 32))

# Or use torchinfo
from torchinfo import summary
summary(model, input_size=(32, 3, 32, 32))

2. Gradient Check

python
# Check if gradients are normal
def check_gradients(model):
    """Print each parameter's gradient L2 norm, or note that it has none.

    Useful after ``loss.backward()`` to spot vanishing/exploding gradients
    or parameters that are disconnected from the loss.
    """
    for name, param in model.named_parameters():
        grad = param.grad
        if grad is None:
            print(f"{name}: No gradient")
        else:
            print(f"{name}: Gradient norm = {grad.norm():.6f}")

3. Visualize Network

python
import torch.nn as nn
import matplotlib.pyplot as plt

def visualize_weights(model, layer_name):
    """Display the first filter of the named Conv2d layer as a grayscale image.

    Searches the model's named modules; does nothing if no Conv2d module
    matches ``layer_name`` exactly.
    """
    for name, module in model.named_modules():
        if name != layer_name or not isinstance(module, nn.Conv2d):
            continue
        # First output channel, first input channel of the kernel tensor.
        first_filter = module.weight.data[0, 0]
        plt.imshow(first_filter.cpu(), cmap='gray')
        plt.title(f'{layer_name} - First Filter')
        plt.show()
        break

Summary

Neural networks are the core of deep learning, and PyTorch provides rich tools to build various network architectures:

  1. Basic Components: Master various layers, activation functions, and regularization techniques
  2. Model Building: Learn to use Sequential and custom Module
  3. Advanced Architectures: Understand residual connections, attention mechanisms, and modern techniques
  4. Model Management: Master parameter access, state management, and saving/loading
  5. Debugging Tips: Learn to inspect model architecture, gradients, weights, etc.

This knowledge lays a solid foundation for subsequent model training and optimization!

Content is for learning and research only.