Skip to content

Transformer Model

Transformer is a neural network architecture based on attention mechanisms, proposed by Vaswani et al. in the 2017 paper "Attention Is All You Need". It revolutionized the field of natural language processing and became the cornerstone of modern NLP.

Transformer Basics

What is a Transformer?

The Transformer is built entirely on attention mechanisms, abandoning traditional recurrent and convolutional structures. It can process all positions in a sequence in parallel, greatly improving training efficiency, and it excels at modeling long-distance dependencies.

Core Components

python
import tensorflow as tf
from tensorflow import keras
import numpy as np
import matplotlib.pyplot as plt

# Core components of Transformer
class MultiHeadAttention(keras.layers.Layer):
    """Multi-head attention as in "Attention Is All You Need".

    Projects queries/keys/values to d_model, splits them into num_heads
    parallel heads of size d_model // num_heads, applies scaled dot-product
    attention per head, then concatenates the heads and re-projects.
    """

    def __init__(self, d_model, num_heads):
        super(MultiHeadAttention, self).__init__()
        self.num_heads = num_heads
        self.d_model = d_model

        # d_model must split evenly across the heads.
        assert d_model % self.num_heads == 0

        # Per-head feature dimension.
        self.depth = d_model // self.num_heads

        # Linear projections for queries, keys and values.
        self.wq = keras.layers.Dense(d_model)
        self.wk = keras.layers.Dense(d_model)
        self.wv = keras.layers.Dense(d_model)

        # Output projection applied after head concatenation.
        self.dense = keras.layers.Dense(d_model)

    def split_heads(self, x, batch_size):
        """Reshape (batch, seq, d_model) -> (batch, num_heads, seq, depth)."""
        x = tf.reshape(x, (batch_size, -1, self.num_heads, self.depth))
        return tf.transpose(x, perm=[0, 2, 1, 3])

    def call(self, v, k, q, mask):
        """Run multi-head attention.

        NOTE the argument order is (v, k, q, mask) — callers below rely on
        this exact order (e.g. cross-attention passes the encoder output as
        both v and k, and the decoder state as q).

        Returns:
            (output, attention_weights); output has shape
            (batch, seq_q, d_model).
        """
        batch_size = tf.shape(q)[0]

        q = self.wq(q)
        k = self.wk(k)
        v = self.wv(v)

        q = self.split_heads(q, batch_size)
        k = self.split_heads(k, batch_size)
        v = self.split_heads(v, batch_size)

        # Module-level helper defined below this class.
        scaled_attention, attention_weights = scaled_dot_product_attention(
            q, k, v, mask)

        # (batch, heads, seq_q, depth) -> (batch, seq_q, heads, depth)
        scaled_attention = tf.transpose(scaled_attention, perm=[0, 2, 1, 3])

        # Merge the heads back into a single d_model-sized feature axis.
        concat_attention = tf.reshape(scaled_attention,
                                    (batch_size, -1, self.d_model))

        output = self.dense(concat_attention)

        return output, attention_weights

def scaled_dot_product_attention(q, k, v, mask):
    """Compute attention output and weights via scaled dot products.

    Logits are q·kᵀ divided by sqrt(key depth); masked positions receive a
    large negative bias before the softmax so their weight is ~0.
    """
    # Similarity score between every query and every key.
    logits = tf.matmul(q, k, transpose_b=True)

    # Scale by sqrt of the key depth to keep gradients stable.
    depth = tf.cast(tf.shape(k)[-1], tf.float32)
    logits = logits / tf.math.sqrt(depth)

    # Push masked positions toward -inf before normalization.
    if mask is not None:
        logits += (mask * -1e9)

    # Normalize over the key axis, then blend the values.
    weights = tf.nn.softmax(logits, axis=-1)
    return tf.matmul(weights, v), weights

Positional Encoding

python
def get_angles(pos, i, d_model):
    """Return pos / 10000^(2*floor(i/2)/d_model), the sinusoid angle table.

    pos and i may be broadcastable arrays (positions and dimension indices).
    """
    exponents = (2 * (i // 2)) / np.float32(d_model)
    inv_freq = 1 / np.power(10000, exponents)
    return pos * inv_freq

def positional_encoding(position, d_model):
    """Build the fixed sinusoidal positional encoding.

    Returns a float32 tensor of shape (1, position, d_model); the leading
    batch axis lets it broadcast over a batch of sequences.
    """
    angle_rads = get_angles(np.arange(position)[:, np.newaxis],
                          np.arange(d_model)[np.newaxis, :],
                          d_model)

    # Even feature dimensions use sin, odd dimensions use cos.
    angle_rads[:, 0::2] = np.sin(angle_rads[:, 0::2])
    angle_rads[:, 1::2] = np.cos(angle_rads[:, 1::2])

    return tf.cast(angle_rads[np.newaxis, ...], dtype=tf.float32)

# Visualize positional encoding
def plot_positional_encoding(pos_encoding):
    """Render the positional-encoding matrix as a depth-vs-position heatmap.

    Expects the (1, position, d_model) tensor returned by
    positional_encoding; only the first batch element is drawn.
    """
    plt.figure(figsize=(15, 5))
    plt.pcolormesh(pos_encoding[0], cmap='RdYlBu')
    plt.xlabel('Depth')
    # x-axis clipped to the first 512 feature dimensions.
    plt.xlim((0, 512))
    plt.ylabel('Position')
    plt.colorbar()
    plt.title('Positional Encoding')
    plt.show()

# Example: 50 positions, model dimension 512 -> shape (1, 50, 512)
pos_encoding = positional_encoding(50, 512)
print(pos_encoding.shape)
# plot_positional_encoding(pos_encoding)

Feed-Forward Network

python
def point_wise_feed_forward_network(d_model, dff):
    """Position-wise FFN: Dense(dff, relu) followed by Dense(d_model)."""
    expand = keras.layers.Dense(dff, activation='relu')
    project = keras.layers.Dense(d_model)
    return keras.Sequential([expand, project])

# Example: expands 512 -> 2048 -> back to 512, so the shape is unchanged
sample_ffn = point_wise_feed_forward_network(512, 2048)
sample_ffn(tf.random.uniform((64, 50, 512))).shape

Encoder Layer

python
class EncoderLayer(keras.layers.Layer):
    """One Transformer encoder layer.

    Self-attention and a position-wise feed-forward network, each wrapped
    in dropout, a residual connection, and post-layer-norm.
    """

    def __init__(self, d_model, num_heads, dff, rate=0.1):
        super(EncoderLayer, self).__init__()

        self.mha = MultiHeadAttention(d_model, num_heads)
        self.ffn = point_wise_feed_forward_network(d_model, dff)

        self.layernorm1 = keras.layers.LayerNormalization(epsilon=1e-6)
        self.layernorm2 = keras.layers.LayerNormalization(epsilon=1e-6)

        self.dropout1 = keras.layers.Dropout(rate)
        self.dropout2 = keras.layers.Dropout(rate)

    def call(self, x, training, mask):
        # Self-attention sub-layer with residual connection.
        self_attn, _ = self.mha(x, x, x, mask)
        self_attn = self.dropout1(self_attn, training=training)
        normed = self.layernorm1(x + self_attn)

        # Feed-forward sub-layer with residual connection.
        ffn_out = self.ffn(normed)
        ffn_out = self.dropout2(ffn_out, training=training)
        return self.layernorm2(normed + ffn_out)

# Test Encoder layer: output shape matches the input, (64, 43, 512)
sample_encoder_layer = EncoderLayer(512, 8, 2048)
sample_encoder_layer_output = sample_encoder_layer(
    tf.random.uniform((64, 43, 512)), False, None)
print(sample_encoder_layer_output.shape)

Decoder Layer

python
class DecoderLayer(keras.layers.Layer):
    """One Transformer decoder layer.

    Masked self-attention, encoder-decoder cross-attention, and a
    feed-forward network; each sub-layer is followed by dropout, a
    residual connection, and layer normalization.
    """

    def __init__(self, d_model, num_heads, dff, rate=0.1):
        super(DecoderLayer, self).__init__()

        self.mha1 = MultiHeadAttention(d_model, num_heads)  # self-attention
        self.mha2 = MultiHeadAttention(d_model, num_heads)  # cross-attention

        self.ffn = point_wise_feed_forward_network(d_model, dff)

        self.layernorm1 = keras.layers.LayerNormalization(epsilon=1e-6)
        self.layernorm2 = keras.layers.LayerNormalization(epsilon=1e-6)
        self.layernorm3 = keras.layers.LayerNormalization(epsilon=1e-6)

        self.dropout1 = keras.layers.Dropout(rate)
        self.dropout2 = keras.layers.Dropout(rate)
        self.dropout3 = keras.layers.Dropout(rate)

    def call(self, x, enc_output, training, look_ahead_mask, padding_mask):
        # Masked self-attention over the target sequence.
        self_attn, weights_self = self.mha1(x, x, x, look_ahead_mask)
        self_attn = self.dropout1(self_attn, training=training)
        stage1 = self.layernorm1(self_attn + x)

        # Cross-attention: queries come from the decoder, keys/values from
        # the encoder output (MultiHeadAttention.call takes v, k, q, mask).
        cross_attn, weights_cross = self.mha2(
            enc_output, enc_output, stage1, padding_mask)
        cross_attn = self.dropout2(cross_attn, training=training)
        stage2 = self.layernorm2(cross_attn + stage1)

        # Position-wise feed-forward sub-layer.
        ffn_out = self.ffn(stage2)
        ffn_out = self.dropout3(ffn_out, training=training)
        stage3 = self.layernorm3(ffn_out + stage2)

        return stage3, weights_self, weights_cross

Complete Transformer Model

python
class Encoder(keras.layers.Layer):
    """Transformer encoder: embedding + positional encoding followed by a
    stack of num_layers EncoderLayer blocks."""

    def __init__(self, num_layers, d_model, num_heads, dff, input_vocab_size,
                 maximum_position_encoding, rate=0.1):
        super(Encoder, self).__init__()

        self.d_model = d_model
        self.num_layers = num_layers

        self.embedding = keras.layers.Embedding(input_vocab_size, d_model)
        self.pos_encoding = positional_encoding(
            maximum_position_encoding, self.d_model)

        self.enc_layers = [EncoderLayer(d_model, num_heads, dff, rate)
                           for _ in range(num_layers)]

        self.dropout = keras.layers.Dropout(rate)

    def call(self, x, training, mask):
        seq_len = tf.shape(x)[1]

        # Embed tokens, rescale by sqrt(d_model), then add position info.
        x = self.embedding(x)
        x *= tf.math.sqrt(tf.cast(self.d_model, tf.float32))
        x += self.pos_encoding[:, :seq_len, :]

        x = self.dropout(x, training=training)

        # Run the full encoder stack.
        for layer in self.enc_layers:
            x = layer(x, training, mask)

        return x

class Decoder(keras.layers.Layer):
    """Transformer decoder: embedding + positional encoding followed by a
    stack of DecoderLayer blocks; also collects per-layer attention maps."""

    def __init__(self, num_layers, d_model, num_heads, dff, target_vocab_size,
                 maximum_position_encoding, rate=0.1):
        super(Decoder, self).__init__()

        self.d_model = d_model
        self.num_layers = num_layers

        self.embedding = keras.layers.Embedding(target_vocab_size, d_model)
        self.pos_encoding = positional_encoding(maximum_position_encoding, d_model)

        self.dec_layers = [DecoderLayer(d_model, num_heads, dff, rate)
                           for _ in range(num_layers)]
        self.dropout = keras.layers.Dropout(rate)

    def call(self, x, enc_output, training, look_ahead_mask, padding_mask):
        seq_len = tf.shape(x)[1]
        attention_weights = {}

        # Embed tokens, rescale by sqrt(d_model), then add position info.
        x = self.embedding(x)
        x *= tf.math.sqrt(tf.cast(self.d_model, tf.float32))
        x += self.pos_encoding[:, :seq_len, :]

        x = self.dropout(x, training=training)

        for idx, layer in enumerate(self.dec_layers, start=1):
            x, self_w, cross_w = layer(x, enc_output, training,
                                       look_ahead_mask, padding_mask)

            # Key format matches the original so callers can look up a
            # specific layer/block by name.
            attention_weights[f'decoder_layer{idx}_block1'] = self_w
            attention_weights[f'decoder_layer{idx}_block2'] = cross_w

        return x, attention_weights

class Transformer(keras.Model):
    """Full encoder-decoder Transformer with a final vocabulary projection."""

    def __init__(self, num_layers, d_model, num_heads, dff, input_vocab_size,
                 target_vocab_size, pe_input, pe_target, rate=0.1):
        super(Transformer, self).__init__()

        self.encoder = Encoder(num_layers, d_model, num_heads, dff,
                               input_vocab_size, pe_input, rate)
        self.decoder = Decoder(num_layers, d_model, num_heads, dff,
                               target_vocab_size, pe_target, rate)

        # Maps decoder features to unnormalized vocabulary logits.
        self.final_layer = keras.layers.Dense(target_vocab_size)

    def call(self, inp, tar, training, enc_padding_mask,
             look_ahead_mask, dec_padding_mask):
        """Return (logits, attention_weights) for a source/target batch."""
        encoded = self.encoder(inp, training, enc_padding_mask)

        decoded, attention_weights = self.decoder(
            tar, encoded, training, look_ahead_mask, dec_padding_mask)

        logits = self.final_layer(decoded)
        return logits, attention_weights

Masking Mechanism

python
def create_padding_mask(seq):
    """Mark padding tokens (id 0) with 1.0; result is (batch, 1, 1, seq)."""
    is_pad = tf.cast(tf.math.equal(seq, 0), tf.float32)
    # Insert broadcast axes for the head and query dimensions.
    return is_pad[:, tf.newaxis, tf.newaxis, :]

def create_look_ahead_mask(size):
    """Strictly-upper-triangular mask (1s above the diagonal) that blocks
    attention to future positions."""
    lower_triangle = tf.linalg.band_part(tf.ones((size, size)), -1, 0)
    return 1 - lower_triangle

def create_masks(inp, tar):
    """Build the three masks the Transformer needs for one batch.

    Returns:
        enc_padding_mask: padding mask over the encoder input.
        combined_mask: look-ahead mask merged with target padding, used by
            the decoder's self-attention block.
        dec_padding_mask: padding mask over the encoder input, used by the
            decoder's cross-attention block.
    """
    enc_padding_mask = create_padding_mask(inp)
    dec_padding_mask = create_padding_mask(inp)

    # Element-wise max merges "future position" and "padding" into one mask.
    combined_mask = tf.maximum(create_padding_mask(tar),
                               create_look_ahead_mask(tf.shape(tar)[1]))

    return enc_padding_mask, combined_mask, dec_padding_mask

Training Setup

python
class CustomSchedule(keras.optimizers.schedules.LearningRateSchedule):
    """Noam learning-rate schedule from "Attention Is All You Need".

    lr(step) = d_model^-0.5 * min(step^-0.5, step * warmup_steps^-1.5):
    linear warmup for `warmup_steps` steps, then inverse-sqrt decay.
    """

    def __init__(self, d_model, warmup_steps=4000):
        super(CustomSchedule, self).__init__()

        # Cast once. The original assigned self.d_model twice — the raw
        # value was immediately overwritten by the cast (dead code).
        self.d_model = tf.cast(d_model, tf.float32)

        self.warmup_steps = warmup_steps

    def __call__(self, step):
        arg1 = tf.math.rsqrt(step)                 # inverse-sqrt decay branch
        arg2 = step * (self.warmup_steps ** -1.5)  # linear warmup branch

        return tf.math.rsqrt(self.d_model) * tf.math.minimum(arg1, arg2)

# Learning rate schedule (Noam schedule for d_model=512)
learning_rate = CustomSchedule(512)

# Adam hyperparameters follow "Attention Is All You Need".
optimizer = keras.optimizers.Adam(learning_rate, beta_1=0.9, beta_2=0.98,
                                 epsilon=1e-9)

# Loss function
# reduction='none' keeps per-token losses so padding can be masked below.
loss_object = keras.losses.SparseCategoricalCrossentropy(
    from_logits=True, reduction='none')

def loss_function(real, pred):
    """Masked sparse cross-entropy averaged over non-padding tokens.

    Positions where the target id is 0 (padding) contribute nothing to
    either the numerator or the denominator.
    """
    per_token = loss_object(real, pred)

    # 1.0 where the target is a real token, 0.0 where it is padding.
    keep = tf.cast(tf.math.logical_not(tf.math.equal(real, 0)),
                   dtype=per_token.dtype)
    masked = per_token * keep

    # Mean over real tokens only.
    return tf.reduce_sum(masked) / tf.reduce_sum(keep)

# Metrics
# Streaming metrics; updated inside train_step and reset per epoch by the
# training loop (not shown here).
train_loss = keras.metrics.Mean(name='train_loss')
train_accuracy = keras.metrics.SparseCategoricalAccuracy(name='train_accuracy')

Text Classification Example

python
def create_transformer_classifier(vocab_size, d_model=128, num_heads=8,
                                 num_layers=4, dff=512, max_seq_len=512,
                                 num_classes=2):
    """
    Build an encoder-only Transformer classifier with the functional API.

    Token embedding plus sinusoidal positions feed a stack of
    self-attention / feed-forward blocks; global average pooling and a
    softmax head produce class probabilities.
    """
    inputs = keras.layers.Input(shape=(max_seq_len,))

    # Token embedding plus fixed sinusoidal positional encoding.
    x = keras.layers.Embedding(vocab_size, d_model)(inputs)
    x += positional_encoding(max_seq_len, d_model)[:, :max_seq_len, :]

    for _ in range(num_layers):
        # Self-attention with residual connection and post-layer-norm.
        attn = keras.layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=d_model//num_heads
        )(x, x)
        x = keras.layers.LayerNormalization()(x + attn)

        # Two-layer feed-forward block, also residual + layer norm.
        hidden = keras.layers.Dense(dff, activation='relu')(x)
        projected = keras.layers.Dense(d_model)(hidden)
        x = keras.layers.LayerNormalization()(x + projected)

    # Collapse the sequence dimension, then classify.
    pooled = keras.layers.GlobalAveragePooling1D()(x)
    outputs = keras.layers.Dense(num_classes, activation='softmax')(pooled)

    return keras.Model(inputs=inputs, outputs=outputs)

# Create classification model
classifier = create_transformer_classifier(vocab_size=10000)
# Sparse loss: labels are integer class ids, not one-hot vectors.
classifier.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy']
)

classifier.summary()

Machine Translation Example

python
def create_translation_model():
    """
    Build a small Transformer for machine translation.

    NOTE(review): pe_input/pe_target are set to the vocabulary sizes, as in
    the original; positional-encoding length normally tracks the maximum
    sequence length — confirm this matches the training data.
    """
    # Hyperparameters for a small demonstration model.
    num_layers = 4
    d_model = 128
    dff = 512
    num_heads = 8

    input_vocab_size = 8500
    target_vocab_size = 8000

    return Transformer(
        num_layers, d_model, num_heads, dff,
        input_vocab_size, target_vocab_size,
        pe_input=input_vocab_size,
        pe_target=target_vocab_size
    )

# Training step
@tf.function
def train_step(inp, tar, transformer, optimizer):
    """One teacher-forced training step.

    Relies on the module-level loss_function, create_masks, train_loss and
    train_accuracy defined above.
    """
    # Shift the target: the decoder sees tokens [:-1] and predicts [1:].
    tar_inp = tar[:, :-1]
    tar_real = tar[:, 1:]

    enc_padding_mask, combined_mask, dec_padding_mask = create_masks(inp, tar_inp)

    with tf.GradientTape() as tape:
        # training=True enables dropout inside the model.
        predictions, _ = transformer(inp, tar_inp,
                                   True,
                                   enc_padding_mask,
                                   combined_mask,
                                   dec_padding_mask)
        loss = loss_function(tar_real, predictions)

    gradients = tape.gradient(loss, transformer.trainable_variables)
    optimizer.apply_gradients(zip(gradients, transformer.trainable_variables))

    # Update the streaming metrics.
    train_loss(loss)
    train_accuracy(tar_real, predictions)

BERT-style Pretrained Model

python
def create_bert_style_model(vocab_size, d_model=768, num_heads=12,
                           num_layers=12, dff=3072, max_seq_len=512):
    """
    Create BERT-style pretrained model

    An encoder-only stack with two output heads: a per-token vocabulary
    projection (MLM) and a 2-way classifier read off the first position.

    NOTE(review): position [0] is used as the classification token, which
    presumes the tokenizer places a [CLS]-like token there — confirm
    against the data pipeline. There are no segment embeddings here,
    unlike full BERT.
    """
    inputs = keras.layers.Input(shape=(max_seq_len,))

    # Word embedding
    embedding = keras.layers.Embedding(vocab_size, d_model)(inputs)

    # Positional encoding (fixed sinusoidal, not learned as in BERT proper)
    pos_encoding = positional_encoding(max_seq_len, d_model)
    embedding += pos_encoding[:, :max_seq_len, :]

    # Transformer encoder layers
    x = embedding
    for _ in range(num_layers):
        # Multi-head self-attention
        attn_output = keras.layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=d_model//num_heads
        )(x, x)

        # Dropout and residual connection
        attn_output = keras.layers.Dropout(0.1)(attn_output)
        x = keras.layers.LayerNormalization()(x + attn_output)

        # Feed-forward network (gelu, matching BERT)
        ffn_output = keras.layers.Dense(dff, activation='gelu')(x)
        ffn_output = keras.layers.Dense(d_model)(ffn_output)
        ffn_output = keras.layers.Dropout(0.1)(ffn_output)

        # Residual connection and layer normalization
        x = keras.layers.LayerNormalization()(x + ffn_output)

    # Output heads for different tasks
    # MLM head (masked language model): per-token vocabulary logits
    mlm_output = keras.layers.Dense(vocab_size)(x)

    # Classification head (for sentence-level tasks), read from position 0
    cls_output = keras.layers.Dense(2, activation='softmax')(x[:, 0, :])

    model = keras.Model(inputs=inputs, outputs=[mlm_output, cls_output])
    return model

Performance Optimization Tips

python
# 1. Use mixed precision training
policy = keras.mixed_precision.Policy('mixed_float16')
keras.mixed_precision.set_global_policy(policy)

# 2. Gradient accumulation
def train_step_with_accumulation(inp, tar, transformer, optimizer, accumulation_steps=4):
    """Accumulate gradients over several micro-batches, then apply once.

    Args:
        inp, tar: indexable collections of micro-batches, one per
            accumulation step — presumably lists of tensors; confirm
            against the caller.
        transformer: the Transformer model being trained.
        optimizer: optimizer; apply_gradients is called once at the end.
        accumulation_steps: number of micro-batches to accumulate.
    """
    # The original initialized this to a dead [] and branched on i == 0;
    # a None sentinel makes the first-iteration case explicit.
    accumulated_gradients = None

    for i in range(accumulation_steps):
        with tf.GradientTape() as tape:
            # Forward pass; scale the loss so the accumulated gradient
            # matches a single large-batch step.
            predictions, _ = transformer(inp[i], tar[i], True, None, None, None)
            loss = loss_function(tar[i][:, 1:], predictions) / accumulation_steps

        gradients = tape.gradient(loss, transformer.trainable_variables)

        if accumulated_gradients is None:
            accumulated_gradients = gradients
        else:
            accumulated_gradients = [acc_grad + grad for acc_grad, grad in
                                   zip(accumulated_gradients, gradients)]

    # Apply the summed gradients in one optimizer step.
    optimizer.apply_gradients(zip(accumulated_gradients, transformer.trainable_variables))

# 3. Learning rate warmup
def warmup_cosine_decay(step, total_steps, warmup_steps, max_lr):
    """Linear warmup to max_lr, then cosine decay toward zero."""
    if step < warmup_steps:
        # Ramp linearly from 0 to max_lr over the warmup period.
        return max_lr * step / warmup_steps
    # Fraction of the post-warmup schedule completed, in [0, 1].
    progress = (step - warmup_steps) / (total_steps - warmup_steps)
    return max_lr * 0.5 * (1 + tf.cos(np.pi * progress))

Attention Visualization

python
def plot_attention_weights(attention, sentence, result, layer):
    """
    Visualize attention weights

    Draws one subplot per attention head in a 2x4 grid (so this assumes
    8 heads) for the requested decoder layer.

    NOTE(review): relies on module-level `tokenizer_pt` and `tokenizer_en`
    objects that are not defined in this file — confirm they are in scope
    before calling.
    """
    fig = plt.figure(figsize=(16, 8))

    # Use the first (and presumably only) sentence in the batch.
    sentence = sentence[0]

    # Drop the batch axis of the selected layer's weights.
    attention = tf.squeeze(attention[layer], axis=0)

    for head in range(attention.shape[0]):
        ax = fig.add_subplot(2, 4, head+1)

        # Draw attention weights
        ax.matshow(attention[head][:-1, :], cmap='Blues')

        fontdict = {'fontsize': 10}

        # +2 accounts for the <start> and <end> tokens added below.
        ax.set_xticks(range(len(sentence)+2))
        ax.set_yticks(range(len(result)))

        ax.set_ylim(len(result)-1.5, -0.5)

        ax.set_xticklabels(
            ['<start>']+[tokenizer_pt.decode([i]) for i in sentence]+['<end>'],
            fontdict=fontdict, rotation=90)

        # Only label ids inside the target vocabulary.
        ax.set_yticklabels([tokenizer_en.decode([i]) for i in result
                           if i < tokenizer_en.vocab_size],
                          fontdict=fontdict)

        ax.set_xlabel('Head {}'.format(head+1))

    plt.tight_layout()
    plt.show()

Practical Application Recommendations

1. Model Selection

  • Small datasets: Use pretrained model fine-tuning
  • Large datasets: Train from scratch or continue pretraining
  • Real-time applications: Consider model compression and optimization

2. Hyperparameter Tuning

  • Learning rate: Use warmup and decay
  • Batch size: Adjust based on GPU memory
  • Number of layers and dimensions: Balance performance and computational cost

3. Data Processing

  • Appropriate sequence length
  • Data augmentation techniques
  • Vocabulary size optimization

Summary

Transformer models have revolutionized deep learning, especially in the NLP field. Their parallelization capability, long-distance dependency modeling, and interpretability make them core components of modern AI systems. Understanding Transformer principles and implementation is crucial for deep learning practitioners.

Next chapter we will learn Generative Adversarial Networks (GAN) and explore another important branch of generative models.

Content is for learning and research only.