PyTorch Transformer Models
Introduction to Transformer
Transformer is a neural network architecture based on the self-attention mechanism, proposed by Vaswani et al. in 2017. It has revolutionized the field of natural language processing and has become the foundational architecture for large language models like BERT and GPT.
python
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import numpy as np
import matplotlib.pyplot as plt
# Core components of Transformer (torch.nn built-ins):
# - MultiheadAttention: 8 attention heads over a 512-dim embedding;
#   batch_first=True means inputs/outputs are (batch, seq, feature).
multihead_attn = nn.MultiheadAttention(embed_dim=512, num_heads=8, batch_first=True)
# - TransformerEncoderLayer: one encoder block (self-attention + feed-forward).
transformer_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8, batch_first=True)
transformer = nn.TransformerEncoder(transformer_layer, num_layers=6)

Positional Encoding
1. Sinusoidal Positional Encoding
python
class PositionalEncoding(nn.Module):
def __init__(self, d_model, max_len=5000):
super(PositionalEncoding, self).__init__()
# Create positional encoding matrix
pe = torch.zeros(max_len, d_model)
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
# Calculate division term
div_term = torch.exp(torch.arange(0, d_model, 2).float() *
(-math.log(10000.0) / d_model))
# Apply sine and cosine functions
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
# Add batch dimension and register as buffer
pe = pe.unsqueeze(0).transpose(0, 1)
self.register_buffer('pe', pe)
def forward(self, x):
# x shape: (seq_len, batch_size, d_model)
return x + self.pe[:x.size(0), :]2. Learnable Positional Encoding
python
class LearnablePositionalEncoding(nn.Module):
def __init__(self, d_model, max_len=5000):
super(LearnablePositionalEncoding, self).__init__()
self.pe = nn.Parameter(torch.randn(max_len, d_model))
def forward(self, x):
seq_len = x.size(0)
return x + self.pe[:seq_len, :].unsqueeze(1)Multi-Head Attention Mechanism
1. Self-Attention Implementation
python
class MultiHeadAttention(nn.Module):
def __init__(self, d_model, num_heads, dropout=0.1):
super(MultiHeadAttention, self).__init__()
assert d_model % num_heads == 0
self.d_model = d_model
self.num_heads = num_heads
self.d_k = d_model // num_heads
# Linear transformation layers
self.w_q = nn.Linear(d_model, d_model)
self.w_k = nn.Linear(d_model, d_model)
self.w_v = nn.Linear(d_model, d_model)
self.w_o = nn.Linear(d_model, d_model)
self.dropout = nn.Dropout(dropout)
def scaled_dot_product_attention(self, Q, K, V, mask=None):
# Q, K, V shape: (batch_size, num_heads, seq_len, d_k)
scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
if mask is not None:
scores = scores.masked_fill(mask == 0, -1e9)
attention_weights = F.softmax(scores, dim=-1)
attention_weights = self.dropout(attention_weights)
context = torch.matmul(attention_weights, V)
return context, attention_weights
def forward(self, query, key, value, mask=None):
batch_size, seq_len, d_model = query.size()
# Linear transformation and reshape to multi-head
Q = self.w_q(query).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
K = self.w_k(key).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
V = self.w_v(value).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
# Apply attention
context, attention_weights = self.scaled_dot_product_attention(Q, K, V, mask)
# Reshape and concatenate multi-heads
context = context.transpose(1, 2).contiguous().view(
batch_size, seq_len, d_model
)
# Output projection
output = self.w_o(context)
return output, attention_weightsTransformer Encoder
1. Encoder Layer
python
class TransformerEncoderLayer(nn.Module):
def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
super(TransformerEncoderLayer, self).__init__()
# Multi-head attention
self.self_attention = MultiHeadAttention(d_model, num_heads, dropout)
# Feed-forward network
self.feed_forward = nn.Sequential(
nn.Linear(d_model, d_ff),
nn.ReLU(),
nn.Dropout(dropout),
nn.Linear(d_ff, d_model)
)
# Layer normalization
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
self.dropout = nn.Dropout(dropout)
def forward(self, x, mask=None):
# Self-attention + residual connection + layer normalization
attn_output, attention_weights = self.self_attention(x, x, x, mask)
x = self.norm1(x + self.dropout(attn_output))
# Feed-forward network + residual connection + layer normalization
ff_output = self.feed_forward(x)
x = self.norm2(x + self.dropout(ff_output))
return x, attention_weights2. Complete Encoder
python
class TransformerEncoder(nn.Module):
def __init__(self, vocab_size, d_model, num_heads, num_layers, d_ff, max_len, dropout=0.1):
super(TransformerEncoder, self).__init__()
# Word embedding
self.embedding = nn.Embedding(vocab_size, d_model)
self.pos_encoding = PositionalEncoding(d_model, max_len)
# Encoder layers
self.layers = nn.ModuleList([
TransformerEncoderLayer(d_model, num_heads, d_ff, dropout)
for _ in range(num_layers)
])
self.dropout = nn.Dropout(dropout)
self.d_model = d_model
def forward(self, x, mask=None):
# Word embedding + positional encoding
x = self.embedding(x) * math.sqrt(self.d_model)
x = x.transpose(0, 1) # (seq_len, batch_size, d_model)
x = self.pos_encoding(x)
x = x.transpose(0, 1) # (batch_size, seq_len, d_model)
x = self.dropout(x)
# Through encoder layers
attention_weights = []
for layer in self.layers:
x, attn_weights = layer(x, mask)
attention_weights.append(attn_weights)
return x, attention_weightsPractical Application Example
1. Text Classification Transformer
python
class TextClassificationTransformer(nn.Module):
    """Transformer encoder + mean-pooling classification head.

    Args:
        vocab_size: size of the token vocabulary.
        num_classes: number of output classes.
        d_model, num_heads, num_layers: encoder hyperparameters.
    """

    def __init__(self, vocab_size, num_classes, d_model=512, num_heads=8, num_layers=6):
        super(TextClassificationTransformer, self).__init__()
        self.encoder = TransformerEncoder(
            vocab_size, d_model, num_heads, num_layers, d_ff=2048, max_len=5000
        )
        # Classification head: d_model -> d_model/2 -> num_classes
        self.classifier = nn.Sequential(
            nn.Linear(d_model, d_model // 2),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(d_model // 2, num_classes)
        )

    def forward(self, x, mask=None):
        # Encode: (batch, seq_len) token ids -> (batch, seq_len, d_model)
        encoder_output, attention_weights = self.encoder(x, mask)
        # Masked global average pooling over the sequence dimension
        if mask is not None:
            # assumes mask arrives as (batch, 1, 1, seq_len) with nonzero
            # marking real tokens — TODO confirm against the caller
            mask = mask.squeeze(1).squeeze(1)  # (batch_size, seq_len)
            masked_output = encoder_output * mask.unsqueeze(-1).float()
            pooled = masked_output.sum(dim=1) / mask.sum(dim=1, keepdim=True).float()
        else:
            pooled = encoder_output.mean(dim=1)
        # Classification logits (no softmax; pair with CrossEntropyLoss)
        logits = self.classifier(pooled)
        return logits, attention_weights

Summary
Transformer models represent a major breakthrough in modern deep learning. This chapter introduced:
- Core Components: Positional encoding, multi-head attention, encoder architecture
- Complete Implementation: From basic components to complete Transformer models
- Practical Applications: Specific tasks like text classification
- Training Techniques: Optimization methods like learning rate scheduling and label smoothing
Mastering Transformer will lay a solid foundation for understanding and using modern large language models!