Skip to content

PyTorch Transformer Models

Introduction to Transformer

The Transformer is a neural network architecture based on the self-attention mechanism, proposed by Vaswani et al. in 2017. It has revolutionized the field of natural language processing and has become the foundational architecture for large language models such as BERT and GPT.

python
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import numpy as np
import matplotlib.pyplot as plt

# Core components of Transformer, using PyTorch's built-in modules.
# All three are configured for a model width of 512 split across 8 heads.

# Stand-alone multi-head attention taking batch-first inputs
multihead_attn = nn.MultiheadAttention(
    embed_dim=512,
    num_heads=8,
    batch_first=True,
)

# A single encoder layer (self-attention + feed-forward)
transformer_layer = nn.TransformerEncoderLayer(
    d_model=512,
    nhead=8,
    batch_first=True,
)

# A stack of 6 encoder layers built from the layer above
transformer = nn.TransformerEncoder(
    transformer_layer,
    num_layers=6,
)

Positional Encoding

1. Sinusoidal Positional Encoding

python
class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding.

    Even feature columns hold ``sin(pos / 10000^(2i / d_model))`` and odd
    columns hold the matching ``cos`` term. The table is stored in the
    non-trainable buffer ``pe`` with shape ``(max_len, 1, d_model)`` so it
    broadcasts over dimension 1 of the input.

    Args:
        d_model: Size of the feature dimension (may be even or odd).
        max_len: Number of positions to precompute.
    """

    def __init__(self, d_model, max_len=5000):
        super().__init__()

        # Column vector of positions: (max_len, 1)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)

        # One frequency per sin/cos pair: ceil(d_model / 2) entries
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
        )
        angles = position * div_term  # (max_len, ceil(d_model / 2))

        # Apply sine to even columns and cosine to odd columns
        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(angles)
        # An odd d_model has one fewer odd column than even columns, so the
        # last frequency has no cosine slot and must be dropped.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])

        # Insert a singleton dimension at index 1 -> (max_len, 1, d_model)
        # and register as a buffer (saved with the module, not trained).
        self.register_buffer('pe', pe.unsqueeze(1))

    def forward(self, x):
        """Add positional encodings to ``x``.

        Args:
            x: Tensor whose dimension 0 indexes positions and whose last
                dimension is ``d_model``, i.e. ``(seq_len, batch_size, d_model)``.

        Returns:
            ``x`` plus the first ``seq_len`` rows of the encoding table.

        Raises:
            ValueError: If ``seq_len`` exceeds the precomputed ``max_len``.
        """
        seq_len = x.size(0)
        if seq_len > self.pe.size(0):
            raise ValueError(
                f"sequence length {seq_len} exceeds max_len {self.pe.size(0)}"
            )
        return x + self.pe[:seq_len]

2. Learnable Positional Encoding

python
class LearnablePositionalEncoding(nn.Module):
    """Positional encoding whose table is a trainable parameter.

    The table ``pe`` has shape ``(max_len, d_model)`` and is initialised
    from a standard normal distribution.
    """

    def __init__(self, d_model, max_len=5000):
        super().__init__()
        initial_table = torch.randn(max_len, d_model)
        self.pe = nn.Parameter(initial_table)

    def forward(self, x):
        """Add the learned encoding for each position along dimension 0 of ``x``.

        The sliced table gets a singleton dimension at index 1 so that it
        broadcasts across dimension 1 of ``x`` (intended layout:
        ``(seq_len, batch_size, d_model)``).
        """
        num_positions = x.shape[0]
        learned = self.pe[:num_positions].unsqueeze(1)
        return x + learned

Multi-Head Attention Mechanism

1. Self-Attention Implementation

python
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention for batch-first inputs.

    Args:
        d_model: Feature size of the inputs and outputs. Must be divisible
            by ``num_heads``.
        num_heads: Number of attention heads.
        dropout: Dropout probability applied to the attention weights.
    """

    def __init__(self, d_model, num_heads, dropout=0.1):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"

        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads

        # Linear transformation layers
        self.w_q = nn.Linear(d_model, d_model)
        self.w_k = nn.Linear(d_model, d_model)
        self.w_v = nn.Linear(d_model, d_model)
        self.w_o = nn.Linear(d_model, d_model)

        self.dropout = nn.Dropout(dropout)

    def _split_heads(self, x, batch_size):
        """Reshape (batch, seq, d_model) into (batch, num_heads, seq, d_k)."""
        # -1 lets each tensor keep its own sequence length, so key/value may
        # be longer or shorter than the query (cross-attention).
        return x.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)

    def scaled_dot_product_attention(self, Q, K, V, mask=None):
        """Compute softmax(Q K^T / sqrt(d_k)) V.

        Args:
            Q: ``(batch_size, num_heads, query_len, d_k)``.
            K, V: ``(batch_size, num_heads, key_len, d_k)``.
            mask: Optional tensor broadcastable to the score shape
                ``(batch_size, num_heads, query_len, key_len)``; positions
                where ``mask == 0`` are excluded from attention.

        Returns:
            Tuple ``(context, attention_weights)``.
        """
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)

        if mask is not None:
            # Use the dtype's most negative finite value rather than a
            # literal -1e9, which cannot be represented in float16.
            scores = scores.masked_fill(mask == 0, torch.finfo(scores.dtype).min)

        attention_weights = F.softmax(scores, dim=-1)
        attention_weights = self.dropout(attention_weights)

        context = torch.matmul(attention_weights, V)
        return context, attention_weights

    def forward(self, query, key, value, mask=None):
        """Run multi-head attention.

        Args:
            query: ``(batch_size, query_len, d_model)``.
            key, value: ``(batch_size, key_len, d_model)``; ``key_len`` may
                differ from ``query_len``.
            mask: Optional mask, see ``scaled_dot_product_attention``.

        Returns:
            Tuple ``(output, attention_weights)`` with shapes
            ``(batch_size, query_len, d_model)`` and
            ``(batch_size, num_heads, query_len, key_len)``.
        """
        batch_size, query_len, d_model = query.size()

        # Linear transformation and reshape to multi-head
        Q = self._split_heads(self.w_q(query), batch_size)
        K = self._split_heads(self.w_k(key), batch_size)
        V = self._split_heads(self.w_v(value), batch_size)

        # Apply attention
        context, attention_weights = self.scaled_dot_product_attention(Q, K, V, mask)

        # Reshape and concatenate multi-heads
        context = context.transpose(1, 2).contiguous().view(
            batch_size, query_len, d_model
        )

        # Output projection
        output = self.w_o(context)

        return output, attention_weights

Transformer Encoder

1. Encoder Layer

python
class TransformerEncoderLayer(nn.Module):
    """One post-norm encoder layer: self-attention then a feed-forward network.

    Each sub-layer's output goes through dropout, is added back to its
    input (residual connection), and the sum is layer-normalised.
    """

    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()

        # Self-attention sub-layer
        self.self_attention = MultiHeadAttention(d_model, num_heads, dropout)

        # Position-wise feed-forward sub-layer: d_model -> d_ff -> d_model
        expand = nn.Linear(d_model, d_ff)
        activation = nn.ReLU()
        inner_dropout = nn.Dropout(dropout)
        project = nn.Linear(d_ff, d_model)
        self.feed_forward = nn.Sequential(expand, activation, inner_dropout, project)

        # One layer norm per sub-layer
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

        # Dropout applied to each sub-layer's output before the residual add
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        """Return ``(layer_output, attention_weights)`` for input ``x``.

        ``mask`` is forwarded unchanged to the self-attention module.
        """
        # Sub-layer 1: self-attention (query = key = value = x)
        attended, attention_weights = self.self_attention(x, x, x, mask)
        hidden = self.norm1(x + self.dropout(attended))

        # Sub-layer 2: feed-forward network
        transformed = self.feed_forward(hidden)
        result = self.norm2(hidden + self.dropout(transformed))

        return result, attention_weights

2. Complete Encoder

python
class TransformerEncoder(nn.Module):
    """Token embedding + positional encoding followed by a stack of encoder layers."""

    def __init__(self, vocab_size, d_model, num_heads, num_layers, d_ff, max_len, dropout=0.1):
        super().__init__()
        self.d_model = d_model

        # Token embedding and positional encoding
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoding = PositionalEncoding(d_model, max_len)

        # Stack of num_layers independent encoder layers
        stack = []
        for _ in range(num_layers):
            stack.append(TransformerEncoderLayer(d_model, num_heads, d_ff, dropout))
        self.layers = nn.ModuleList(stack)

        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        """Encode token ids ``x`` (batch-first).

        Returns:
            Tuple ``(encoded, attention_weights)`` where ``attention_weights``
            is a list holding one attention tensor per encoder layer, in
            layer order.
        """
        # Scale embeddings by sqrt(d_model) before adding positions
        embedded = self.embedding(x) * math.sqrt(self.d_model)

        # The positional encoding works on (seq_len, batch_size, d_model),
        # so swap to that layout and back around the call.
        positioned = self.pos_encoding(embedded.transpose(0, 1)).transpose(0, 1)
        hidden = self.dropout(positioned)

        # Run the stack, collecting each layer's attention weights
        attention_weights = []
        for encoder_layer in self.layers:
            hidden, layer_weights = encoder_layer(hidden, mask)
            attention_weights.append(layer_weights)

        return hidden, attention_weights

Practical Application Example

1. Text Classification Transformer

python
class TextClassificationTransformer(nn.Module):
    """Transformer encoder with a mean-pooling classification head.

    Args:
        vocab_size: Number of tokens in the vocabulary.
        num_classes: Number of output classes.
        d_model: Model width.
        num_heads: Number of attention heads.
        num_layers: Number of encoder layers.
        d_ff: Hidden size of each encoder feed-forward network.
        max_len: Maximum sequence length passed to the encoder.
        dropout: Dropout probability for the encoder and classifier head.
    """

    def __init__(self, vocab_size, num_classes, d_model=512, num_heads=8, num_layers=6,
                 d_ff=2048, max_len=5000, dropout=0.1):
        super().__init__()

        self.encoder = TransformerEncoder(
            vocab_size, d_model, num_heads, num_layers,
            d_ff=d_ff, max_len=max_len, dropout=dropout
        )

        # Classification head
        self.classifier = nn.Sequential(
            nn.Linear(d_model, d_model // 2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(d_model // 2, num_classes)
        )

    def forward(self, x, mask=None):
        """Classify a batch of token-id sequences.

        Args:
            x: Token ids, batch-first.
            mask: Optional attention mask; dimensions 1 and 2 are squeezed
                to obtain a per-token mask, so the expected shape is
                ``(batch_size, 1, 1, seq_len)`` with 0 marking padding.

        Returns:
            Tuple ``(logits, attention_weights)``; ``attention_weights`` is
            whatever the encoder returns as its second value.
        """
        # Encode
        encoder_output, attention_weights = self.encoder(x, mask)

        # Global average pooling over the sequence dimension
        if mask is not None:
            token_mask = mask.squeeze(1).squeeze(1).float()  # (batch_size, seq_len)
            masked_output = encoder_output * token_mask.unsqueeze(-1)
            # Clamp the token count so a fully masked sequence pools to
            # zeros instead of dividing by zero and producing NaN logits.
            token_counts = token_mask.sum(dim=1, keepdim=True).clamp(min=1.0)
            pooled = masked_output.sum(dim=1) / token_counts
        else:
            pooled = encoder_output.mean(dim=1)

        # Classification
        logits = self.classifier(pooled)

        return logits, attention_weights

Summary

Transformer models represent a major breakthrough in modern deep learning. This chapter introduced:

  1. Core Components: Positional encoding, multi-head attention, encoder architecture
  2. Complete Implementation: From basic components to complete Transformer models
  3. Practical Applications: Specific tasks like text classification
  4. Training Techniques: Optimization methods like learning rate scheduling and label smoothing

Mastering the Transformer will lay a solid foundation for understanding and using modern large language models!

Content is for learning and research only.