
Building a Transformer with AI

An educational journey through the architecture that powers modern AI

Learn how transformers work by building one in PyTorch, from attention mechanisms to complete text generation

What is a Transformer?

What is a transformer? A transformer is a type of neural network architecture introduced in the landmark paper "Attention is All You Need" (Vaswani et al., 2017). It revolutionized artificial intelligence and is now the foundation of virtually all modern large language models, including GPT, BERT, Claude, and many others.

What makes transformers special? Previous approaches to language modeling used recurrent neural networks (RNNs), which process text one word at a time in sequence—like reading a sentence from left to right. Transformers instead use a mechanism called attention that allows them to process all words simultaneously while still understanding their relationships. This parallel processing makes them much faster to train and more effective at capturing long-range dependencies in text.

Learning Path

Step 1: Token Embeddings & Positional Encoding

Token Embeddings

What are tokens? Before we can process text with a neural network, we need to break it into pieces called tokens. A token might be a word ("hello"), a subword ("ing"), or even a single character. For example, the sentence "The cat sat" might be tokenized as ["The", "cat", "sat"], and each token gets assigned a unique number (ID) from a vocabulary—perhaps "The"=5, "cat"=142, "sat"=89.

Why do we need embeddings? Computers can't directly understand these token IDs—they're just arbitrary numbers. We need to convert them into meaningful representations that capture semantic relationships. That's where embeddings come in.

What is an embedding? An embedding is a learned vector representation (a list of numbers) for each token. Instead of representing "cat" as the ID 142, we represent it as a dense vector like [0.2, -0.5, 0.8, ...] with d_model dimensions (typically 512 or 768). These vectors are learned during training so that similar words end up with similar vectors.

Think of this as giving each word a unique coordinate in a high-dimensional space. Words with similar meanings (like "cat" and "kitten") end up close together, while unrelated words (like "cat" and "democracy") are far apart.

class TokenEmbedding(nn.Module):
    """Convert token indices to dense vectors."""

    def __init__(self, vocab_size, d_model):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.d_model = d_model

    def forward(self, x):
        # x: (batch, seq_len) - token indices
        # returns: (batch, seq_len, d_model) - embeddings
        return self.embedding(x)

Positional Encoding: Three Modern Approaches

Why do we need positional information? Consider the sentences "The cat ate the mouse" vs "The mouse ate the cat"—same words, completely different meanings! The order matters. Traditional recurrent neural networks (RNNs) process words one at a time in sequence, so they naturally know the order. But transformers process all tokens simultaneously in parallel (which is faster), so they have no inherent notion of position.

The solution: Position encoding. We need to give the model information about where each token appears in the sequence. Modern transformers use three main approaches, in order from simplest to most complex:

Approach 1: ALiBi (Attention with Linear Biases) — Our Default! 🎯

The simplest and most effective! Instead of modifying embeddings or rotating vectors, ALiBi just adds distance-based penalties directly to attention scores. Brilliantly simple:

attention_score[i, j] = (qᵢ · kⱼ) / √d_k - slope × |i - j|

What this means: When position i attends to position j, we add a negative bias proportional to their distance. The further apart the two positions are, the more negative the bias, and the less attention flows between them.

Example: Position 5 looking back over the sequence, with a head slope of 0.25:

  • Position 5 (current): distance = 0 → bias = 0 → full attention
  • Position 4 (1 away): distance = 1 → bias = -0.25 → slight reduction
  • Position 3 (2 away): distance = 2 → bias = -0.50 → moderate reduction
  • Position 0 (5 away): distance = 5 → bias = -1.25 → strong reduction

Multiple heads with different "zoom levels": Each attention head gets a different slope value, creating heads that focus at different ranges:

  • Head 0 (slope = 0.25): Strong penalties → focuses on nearby tokens
  • Head 1 (slope = 0.0625): Moderate penalties → medium-range focus
  • Head 2 (slope = 0.016): Gentle penalties → long-range focus
  • Head 3 (slope = 0.004): Very gentle → very long-range relationships

class ALiBiPositionalBias(nn.Module):
    """ALiBi: The simplest modern position encoding."""

    def __init__(self, num_heads):
        super().__init__()
        # Geometric slope sequence: for 4 heads this gives 0.25, 0.0625, 0.016, 0.004
        ratio = 2 ** (-8.0 / num_heads)
        slopes = torch.tensor([ratio ** (i + 1) for i in range(num_heads)])
        self.register_buffer("slopes", slopes.view(num_heads, 1, 1))

    def forward(self, seq_len):
        # Compute pairwise distances: |i - j|
        positions = torch.arange(seq_len)
        distances = torch.abs(positions[None, :] - positions[:, None])

        # Apply slope to get biases: -slope × distance
        biases = -self.slopes * distances

        # Added to attention scores before softmax!
        return biases  # (num_heads, seq_len, seq_len)

Approach 2: Learned Positional Embeddings (GPT-2, BERT)

How it works: We create special "position vectors" that are added to token embeddings. Each position gets its own learnable embedding—position 0 has one vector, position 1 has another, and so on. These are learned during training, just like token embeddings.

We add (not concatenate) these position embeddings to the token embeddings, so each token now carries information about both what it is and where it is in the sequence.

class PositionalEncoding(nn.Module):
    """Learned positional embeddings (GPT-2 style)."""

    def __init__(self, d_model, max_seq_len=5000):
        super().__init__()
        self.pos_embedding = nn.Embedding(max_seq_len, d_model)

    def forward(self, x):
        # x: (batch, seq_len, d_model)
        batch_size, seq_len, d_model = x.shape

        # Create position indices: [0, 1, 2, ..., seq_len-1]
        positions = torch.arange(seq_len, device=x.device)

        # Get position embeddings and ADD to input
        pos_emb = self.pos_embedding(positions)
        return x + pos_emb  # Encodes absolute position

Approach 3: RoPE (Rotary Position Embeddings) — Also Excellent

The breakthrough idea: Instead of adding position information to embeddings, we rotate the query and key vectors by an angle proportional to their position. This is now the standard approach in 2024!

The clock analogy: Imagine each token as a hand on a clock. Position 0 points at 12 o'clock. Position 1 rotates to 1 o'clock. Position 2 rotates to 2 o'clock. When two tokens "meet" in attention, the angle between them automatically tells you their relative distance!

class RotaryPositionalEmbedding(nn.Module):
    """RoPE: Rotary Position Embeddings (modern standard)."""

    def __init__(self, head_dim, base=10000):
        super().__init__()
        # One rotation frequency per pair of dimensions (high to low frequency)
        inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
        self.register_buffer("inv_freq", inv_freq)

    def forward(self, q, k):
        # q, k: (batch, num_heads, seq_len, head_dim)
        positions = torch.arange(q.size(-2), device=q.device).float()
        angles = torch.outer(positions, self.inv_freq)   # (seq_len, head_dim/2)
        cos, sin = angles.cos(), angles.sin()

        # Instead of adding, we ROTATE q and k by their position angles.
        # When q_rotated @ k_rotated is computed, only the relative
        # position (m - n) survives: position is encoded through geometry!
        return self._rotate(q, cos, sin), self._rotate(k, cos, sin)

    @staticmethod
    def _rotate(x, cos, sin):
        # Split dimensions into (even, odd) pairs and rotate each pair
        x1, x2 = x[..., 0::2], x[..., 1::2]
        rotated = torch.stack((x1 * cos - x2 * sin, x1 * sin + x2 * cos), dim=-1)
        return rotated.flatten(-2)

The Math (Simplified): For vectors at positions m and n:

  • Rotate query q at position m by angle m×θ
  • Rotate key k at position n by angle n×θ
  • When computing q·k, the result depends on (m-n), the relative distance!
  • This is the "angle difference" property of rotations

Different frequency bands allow the model to capture both fine-grained local patterns (adjacent words like "the cat") and long-range dependencies (distant references like "the cat ... it").
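
A quick numeric check of that property, as a standalone sketch rather than the project's code: rotate two 2-D vectors by angles proportional to their positions, and the dot product comes out the same whenever the position difference is the same.

import torch

def rotate(vec, angle):
    cos, sin = torch.cos(angle), torch.sin(angle)   # 2-D rotation by `angle` radians
    return torch.stack([vec[0] * cos - vec[1] * sin,
                        vec[0] * sin + vec[1] * cos])

theta = torch.tensor(0.1)
q, k = torch.tensor([1.0, 2.0]), torch.tensor([0.5, -1.0])

# Position pairs (3, 1) and (7, 5) both have relative distance m - n = 2
a = rotate(q, 3 * theta) @ rotate(k, 1 * theta)
b = rotate(q, 7 * theta) @ rotate(k, 5 * theta)
print(a.item(), b.item())   # identical values: the dot product only "sees" m - n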

Comparing the three approaches:

  • 🎯 ALiBi (our default): 0 parameters (pure math!), relative positions, the best extrapolation, and the easiest to implement. Used in BLOOM and MPT (2022-2024).
  • ⭐ RoPE: 0 parameters (pure math!), relative positions, excellent extrapolation, moderate complexity. Used in LLaMA and Mistral (2023-2024).
  • 📊 Learned: 1.28M+ parameters, absolute positions, limited extrapolation, simple to implement. Used in GPT-2 and GPT-3 (2018-2020).

Why ALiBi is our default:

  • Simplest to understand: Just subtract distance! No complex rotation math or embedding layers.
  • BEST extrapolation: Benchmarks show ALiBi handles extreme length changes better than RoPE or learned embeddings. Train on 512 tokens, test on 10,000+ tokens!
  • Zero parameters: Like RoPE, purely mathematical. No weights to learn = faster training, better generalization.
  • Different heads, different ranges: Geometric slope sequence gives heads natural "zoom levels" from local to long-range.
  • Proven in production: Powers BLOOM (176B params) and the MPT family of models.

When to use alternatives:

  • RoPE: Also excellent! Use for LLaMA-style models or if you prefer rotation-based encoding.
  • Learned: Only for GPT-2/GPT-3 reproduction or educational comparison of historical approaches.

📝 Implementation: src/transformer/embeddings.py

Step 2: Scaled Dot-Product Attention

The core innovation of transformers: Attention is the mechanism that allows each word to "look at" and gather information from all other words in the sentence. This is what makes transformers so powerful.

A concrete example: Consider the sentence "The animal didn't cross the street because it was too tired." What does "it" refer to? A human knows "it" refers to "the animal" (not "the street"). Attention allows the model to learn this—when processing "it", the model can attend strongly to "animal" and incorporate that information.

How does attention work? The mechanism uses three components for each token, derived from the input embeddings through learned linear transformations:

  • Query (Q): "What am I looking for?" - represents what the current token wants to know
  • Key (K): "What do I contain?" - represents what information each token offers
  • Value (V): "What information do I have?" - the actual content that gets passed along

The process: For each token, we compare its Query against all Keys (using dot products) to compute attention scores—how much should we pay attention to each other token? We normalize these scores with softmax to get probabilities (weights that sum to 1), then use these weights to take a weighted average of all Values.

The formula is elegant: Attention(Q, K, V) = softmax(Q·Kᵀ / √d_k) · V

The division by √d_k is a scaling factor that prevents very large dot products in high dimensions, which would cause the softmax to produce near-zero gradients and make training difficult.

[Figure: Attention mechanism flow showing Query, Key, and Value transformations]

def forward(self, query, key, value, mask=None):
    # Get dimension for scaling
    d_k = query.size(-1)

    # Compute attention scores: Q·Kᵀ / √d_k
    scores = torch.matmul(query, key.transpose(-2, -1))
    scores = scores / torch.sqrt(torch.tensor(d_k, dtype=torch.float32))

    # Apply causal mask if provided (prevent looking at future)
    if mask is not None:
        scores = scores.masked_fill(mask == 1, float('-inf'))

    # Apply softmax to get attention weights (probabilities)
    attention_weights = torch.softmax(scores, dim=-1)

    # Apply attention weights to values
    output = torch.matmul(attention_weights, value)

    return output, attention_weights

Causal masking for language models: In decoder-only transformers (like GPT), we add a mask to prevent tokens from attending to future positions. This is essential for autoregressive generation—when predicting the next word, the model shouldn't "cheat" by looking ahead! The mask sets future attention scores to -∞ before softmax, making those positions receive zero attention weight.
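
As a concrete sketch of that mask (using the mask == 1 means "hidden" convention from the code above; the project's own helper may be organized differently):

import torch

def make_causal_mask(seq_len):
    # 1 marks future positions that must be hidden; 0 marks positions we may attend to
    return torch.triu(torch.ones(seq_len, seq_len), diagonal=1)

print(make_causal_mask(4))
# tensor([[0., 1., 1., 1.],
#         [0., 0., 1., 1.],
#         [0., 0., 0., 1.],
#         [0., 0., 0., 0.]])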

📝 Implementation: src/transformer/attention.py

Step 3: Multi-Head Attention

Why multiple heads? A single attention mechanism is powerful, but it can only learn one way of relating tokens. Multi-head attention runs several attention mechanisms in parallel (typically 8 or 16), each called a "head." This gives the model multiple "perspectives" to understand relationships between words.

What do different heads learn? Through training, different heads naturally specialize in different types of relationships. Research shows that real models develop heads that focus on:

  • Syntactic relationships: One head might track subject-verb agreement
  • Semantic relationships: Another head might connect related concepts
  • Long-range dependencies: A head might link pronouns to their antecedents
  • Local patterns: Another head might attend to adjacent words in phrases

How it works: We split the d_model dimensions across heads. With d_model=512 and 8 heads, each head operates on 64 dimensions (512/8). All heads process the input in parallel, then we concatenate their outputs and apply a final linear transformation.

[Figure: Multi-head attention architecture showing parallel attention heads]

def forward(self, x, mask=None):
    batch_size, seq_len, d_model = x.shape

    # 1. Project input to Q, K, V
    Q = self.W_q(x)  # (batch, seq_len, d_model)
    K = self.W_k(x)
    V = self.W_v(x)

    # 2. Split into multiple heads
    # Reshape: (batch, seq_len, num_heads, d_k)
    Q = Q.view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
    K = K.view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
    V = V.view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)

    # 3. Apply attention to each head (in parallel!)
    output, attn_weights = self.attention(Q, K, V, mask)

    # 4. Concatenate heads back together
    output = output.transpose(1, 2).contiguous()
    output = output.view(batch_size, seq_len, d_model)

    # 5. Final linear projection
    output = self.W_o(output)

    return output

Why don't heads learn the same thing? Different random initializations, different learned projections, and the optimization process all encourage diversity. Redundancy doesn't help reduce loss, so heads naturally specialize.

📝 Implementation: src/transformer/attention.py

Step 4: Position-Wise Feed-Forward Networks

What is a feed-forward network? After attention gathers information from across the sequence, we need to actually process that information. The feed-forward network (FFN) is a simple two-layer neural network—also called a Multi-Layer Perceptron (MLP)—that transforms each token's representation independently.

Why do we need it? Think of the attention layer as "communication" between tokens—gathering relevant context. The FFN is the "computation" step—processing that gathered information to extract useful features and patterns. Without the FFN, the model would only shuffle information around without transforming it.

The architecture:

  1. Expand: Project from d_model (e.g., 512) to d_ff (typically 4× larger, e.g., 2048). This expansion gives the model more "capacity" to learn complex patterns.
  2. Activate: Apply the GELU activation, a smooth nonlinearity. Without it, stacking layers would be pointless (multiple linear transformations collapse into a single one); this nonlinearity is what lets the model learn complex relationships.
  3. Project back: Compress back down from d_ff to d_model so the output shape matches the input, allowing us to stack more layers.

Position-wise: Crucially, the same FFN (same weights) is applied to every position independently. This is efficient and helps the model learn general transformations that work regardless of position.

class FeedForward(nn.Module):
    """Position-wise feed-forward network."""

    def __init__(self, d_model, d_ff, dropout=0.1):
        super().__init__()

        # Expand dimension
        self.linear1 = nn.Linear(d_model, d_ff)

        # GELU activation (used in GPT-2, GPT-3)
        self.activation = nn.GELU()
        self.dropout1 = nn.Dropout(dropout)

        # Project back to d_model
        self.linear2 = nn.Linear(d_ff, d_model)
        self.dropout2 = nn.Dropout(dropout)

    def forward(self, x):
        # x: (batch, seq_len, d_model)
        x = self.linear1(x)        # → (batch, seq_len, d_ff)
        x = self.activation(x)
        x = self.dropout1(x)
        x = self.linear2(x)        # → (batch, seq_len, d_model)
        x = self.dropout2(x)
        return x

Division of labor: Attention answers "What should I pay attention to?" while the FFN answers "Now that I have this information, what should I do with it?"

📝 Implementation: src/transformer/feedforward.py

Step 5: Transformer Block

Bringing it all together: A transformer block combines all our components into one repeatable unit. The full transformer model is just many of these blocks stacked on top of each other (GPT-3 has 96 blocks!).

What's in a block? Each block contains four key components:

  • Multi-head attention: Communication layer—tokens gather information from other tokens
  • Feed-forward network: Computation layer—each token processes its gathered information
  • Layer normalization: Stabilizes training by normalizing activations (prevents them from growing too large or small)
  • Residual connections: "Skip connections" that create gradient highways for training deep networks

Pre-LN architecture: We use the Pre-LN (Pre-Layer Normalization) arrangement found in modern models like GPT-2 and GPT-3: layer normalization is applied before each sub-layer (attention or FFN) rather than after. This makes training more stable, especially for very deep networks.

[Figure: Transformer block architecture with residual connections]

def forward(self, x, mask=None):
    # First sub-layer: Multi-head attention with residual
    residual = x
    x = self.norm1(x)                    # Pre-LN
    x = self.attention(x, mask=mask)
    x = self.dropout1(x)
    x = x + residual                     # Residual connection

    # Second sub-layer: Feed-forward with residual
    residual = x
    x = self.norm2(x)                    # Pre-LN
    x = self.ffn(x)
    x = self.dropout2(x)
    x = x + residual                     # Residual connection

    return x

Why residual connections? They create gradient "highways" that allow gradients to flow directly from the output back to early layers. Without them, deep networks struggle to learn: ∂(x + f(x))/∂x = 1 + ∂f(x)/∂x—the "1" ensures gradients always flow!

📝 Implementation: src/transformer/block.py

Step 6: The Complete Transformer

The complete picture: We now assemble all our components into a working decoder-only transformer (GPT-style). This is a complete language model that can be trained to predict the next word in a sequence.

What is "decoder-only"? The original transformer paper had both an encoder (for reading input) and decoder (for generating output), used for translation. Modern language models like GPT use only the decoder part, which is simpler and works great for text generation. The key difference is that decoder-only models use causal masking—they can only look at previous tokens, not future ones.

How data flows through the model:

  1. Token Embedding: Convert input token IDs (integers) to dense vectors
  2. Positional Encoding: Add position information to tell the model where each token is
  3. Transformer Blocks (×N): Stack multiple identical blocks (we use 6; GPT-3 uses 96). Each block refines the representations through attention and feed-forward processing
  4. Final LayerNorm: One last normalization to stabilize the final outputs
  5. Output Projection: Project from d_model dimensions to vocabulary size, giving us scores (logits) for every possible next token

What are logits? The model outputs "logits"—raw, unnormalized scores for each token in the vocabulary. Higher scores mean the model thinks that token is more likely to come next. We can convert these to probabilities using softmax, then either pick the highest (greedy decoding) or sample from the distribution (for more creative generation).

def __init__(
    self, vocab_size, d_model=512, num_heads=8,
    num_layers=6, d_ff=2048, max_seq_len=5000, dropout=0.1
):
    super().__init__()

    # Token and positional embeddings
    self.token_embedding = TokenEmbedding(vocab_size, d_model)
    self.pos_encoding = PositionalEncoding(d_model, max_seq_len)

    # Stack of transformer blocks
    self.blocks = nn.ModuleList([
        TransformerBlock(d_model, num_heads, d_ff, dropout)
        for _ in range(num_layers)
    ])

    # Final layer norm and output projection
    self.ln_f = nn.LayerNorm(d_model)
    self.output_proj = nn.Linear(d_model, vocab_size)

def forward(self, x, mask=None):
    # Create causal mask if not provided
    if mask is None:
        mask = self.create_causal_mask(x.size(1)).to(x.device)

    # 1. Embed tokens and add positions
    x = self.token_embedding(x)      # (batch, seq) → (batch, seq, d_model)
    x = self.pos_encoding(x)

    # 2. Pass through all transformer blocks
    for block in self.blocks:
        x = block(x, mask=mask)      # (batch, seq, d_model) → (batch, seq, d_model)

    # 3. Final normalization and projection to vocabulary
    x = self.ln_f(x)
    logits = self.output_proj(x)     # (batch, seq, d_model) → (batch, seq, vocab_size)

    return logits
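
To turn those logits into an actual next token, a common pattern looks like this (a small sketch; the project's generate method in Step 8 wraps this up with more options):

next_token_logits = logits[:, -1, :]                     # scores for the position after the input
probs = torch.softmax(next_token_logits, dim=-1)         # convert logits to probabilities

greedy_token = torch.argmax(probs, dim=-1)               # greedy decoding: most likely token
sampled_token = torch.multinomial(probs, num_samples=1)  # sampling: more creative generation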

Training the model: During training, we feed the model sequences of text and ask it to predict the next token at each position. We compare its predictions (logits) against the actual next tokens using cross-entropy loss, then use backpropagation to adjust all the weights (embeddings, attention projections, FFN weights, etc.). After training on billions of tokens, the model learns to predict plausible next words based on context.
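
In code, that objective is shifted next-token prediction. A minimal sketch (the names here are placeholders, not the project's exact training loop):

import torch.nn.functional as F

tokens = next(iter(train_dataloader))             # (batch, seq_len) token IDs
inputs, targets = tokens[:, :-1], tokens[:, 1:]   # predict token t+1 from tokens up to t

logits = model(inputs)                            # (batch, seq_len - 1, vocab_size)
loss = F.cross_entropy(
    logits.reshape(-1, logits.size(-1)),          # one prediction per position
    targets.reshape(-1),                          # the actual next tokens
)
loss.backward()                                   # gradients for every weight in the model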

Model scale: Our implementation uses 6 layers with d_model=512, similar to the original transformer paper. For comparison, GPT-3 has 96 layers with d_model=12,288. The architecture scales beautifully—the same fundamental components work at wildly different scales!

📝 Implementation: src/transformer/model.py

Step 7: Training at Scale

Building the transformer architecture is only half the battle. To train it effectively, we need techniques that make training stable, prevent overfitting, and work within the constraints of hobby-scale hardware. This section covers two critical techniques: gradient accumulation and validation splits.

The Challenge: Small Batches, Noisy Training

What is a batch? During training, we process multiple examples together in a "batch." The model makes predictions for all examples, we compute the average loss, then we calculate gradients and update weights. Larger batches give us more stable gradient estimates because we're averaging over more examples.

The problem with small batches: On hobby hardware (like an M1 Mac or consumer GPU), we're limited to small batches—typically just 8 sequences at a time. Small batches lead to noisy gradients: each batch gives a slightly different signal about which direction to update the weights, causing erratic training.

Memory bottleneck: Why can't we just use bigger batches? Each example in a batch requires storing activations in memory for the backward pass. M1 Macs have ~8GB unified memory, and a batch of 8 sequences already uses ~4GB. Doubling to 16 would run out of memory!

Gradient Accumulation: Large Batches Without the Memory Cost

The key insight: We don't need to process all examples simultaneously! Gradient accumulation lets us simulate large batch sizes by accumulating gradients over multiple small batches before updating weights.

How it works:

  1. Process batch 1: Forward pass → Loss → Backward pass → Store gradients (don't update yet!)
  2. Process batch 2: Forward pass → Loss → Backward pass → Add gradients to stored ones
  3. Repeat for N batches (e.g., 16 times)
  4. Update weights: Use the accumulated (averaged) gradients

[Figure: Gradient accumulation comparison diagram]

Why this works mathematically: Gradients are linear, so averaging gradients from N separate batches gives the same result as computing the gradient on one large batch containing all N×batch_size examples. The key formula:

∇(L₁ + L₂ + ... + Lₙ) = ∇L₁ + ∇L₂ + ... + ∇Lₙ

By accumulating gradients over 16 batches of 8 sequences each, we get a gradient equivalent to one computed on a batch of 128 sequences (a much more stable estimate) while only ever holding 8 sequences in memory at once.

# Without accumulation (noisy)
for batch in dataloader:
    loss = compute_loss(batch)
    loss.backward()           # Compute gradients
    optimizer.step()          # Update every batch (noisy!)
    optimizer.zero_grad()

# With accumulation (stable)
accumulation_steps = 16
for i, batch in enumerate(dataloader):
    loss = compute_loss(batch)
    loss = loss / accumulation_steps  # Scale for correct averaging
    loss.backward()                   # Accumulate gradients

    if (i + 1) % accumulation_steps == 0:
        optimizer.step()              # Update every 16 batches (stable!)
        optimizer.zero_grad()

Validation: Detecting Overfitting

The problem: Memorization vs. Learning

Imagine a student preparing for an exam. They could:

  • Memorize answers to practice problems → Fails on new problems (overfitting)
  • Learn concepts from practice problems → Succeeds on new problems (good generalization)

The same happens with neural networks. As training progresses, the model might start memorizing the training data instead of learning general patterns. This is called overfitting.

The solution: Validation split

We set aside 10% of our data that the model never sees during training. After each epoch, we evaluate the model on this "validation" data. If the model is truly learning patterns (not memorizing), it should perform well on both training and validation data.

[Figure: Training vs. validation loss curves]

How to interpret the curves:

✓ Good Training

Train: 5.0 → 4.0 → 3.0
Val: 5.2 → 4.2 → 3.2

Both losses decreasing together. Model is learning general patterns that work on new data!

⚠ Underfitting

Train: 5.0 → 4.8 → 4.7
Val: 5.2 → 5.0 → 4.9

Both losses barely improving. Model is too simple or needs more training epochs.

⚠ Overfitting

Train: 5.0 → 3.0 → 1.5
Val: 5.2 → 3.5 → 4.0

Training loss decreasing but validation increasing. Model is memorizing training data!

# Training with validation
for epoch in range(num_epochs):
    # Training phase
    model.train()
    train_loss = 0.0
    for batch in train_dataloader:
        loss = compute_loss(batch)
        # ... backward, update ...
        train_loss += loss.item()
    train_loss /= len(train_dataloader)

    # Validation phase (no weight updates!)
    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for batch in val_dataloader:
            val_loss += compute_loss(batch).item()  # just measure, don't update
    val_loss /= len(val_dataloader)

    print(f"Train loss: {train_loss:.2f}, Val loss: {val_loss:.2f}")

    # Check for overfitting
    if val_loss > train_loss * 1.3:
        print("Warning: Possible overfitting!")

Implementation in this project: Our training script automatically splits FineWeb into 90% training and 10% validation using a deterministic hash-based split. After each epoch, you'll see both training and validation metrics, along with interpretation hints to help you understand if your model is learning well!
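
A minimal sketch of what such a deterministic hash-based split can look like (placeholder names; see the training script for the project's actual logic):

import hashlib

documents = [{"id": f"doc-{i}", "text": "..."} for i in range(1000)]  # placeholder corpus

def is_validation(doc_id, val_fraction=0.10):
    # Hashing the ID makes the split stable across runs and machines
    bucket = int(hashlib.md5(doc_id.encode()).hexdigest(), 16) % 100
    return bucket < int(val_fraction * 100)

train_docs = [d for d in documents if not is_validation(d["id"])]
val_docs = [d for d in documents if is_validation(d["id"])]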

Expected Improvements

With gradient accumulation and validation:

  • 20-30% lower final loss due to stable training
  • Smoother training curves that are easier to debug
  • Confidence in generalization by monitoring validation
  • Early stopping when validation stops improving
  • Works on hobby hardware without expensive GPUs

Step 8: Fast Generation with KV-Cache

The Problem: Slow Autoregressive Generation

When generating text, transformers produce one token at a time. After generating each token, we feed the entire sequence back through the model to predict the next token. This means we repeatedly recompute the same values!

Example without cache:

# Generate "The cat sat"
# Step 1: Generate token 3
Input: [The, cat]
Compute: K[The], V[The], K[cat], V[cat]
Output: "sat" ✓

# Step 2: Generate token 4
Input: [The, cat, sat]
Compute: K[The], V[The], K[cat], V[cat], K[sat], V[sat]  ← Redundant!
Output: "on"

# Step 3: Generate token 5
Input: [The, cat, sat, on]
Compute: K[The], V[The], K[cat], V[cat], K[sat], V[sat], K[on], V[on]  ← Redundant!
Output: "the"

For generating n tokens, we process 1 + 2 + 3 + ... + n = O(n²) tokens total. Very slow!

The Solution: KV-Cache

Key Insight: In attention, K (Key) and V (Value) for past tokens never change! Only the new token's query matters. We can cache K and V from previous steps and reuse them.

[Figure: KV-cache speedup comparison diagram]

How It Works

Two Modes:

  • PREFILL: Process initial prompt, compute and cache K, V for all tokens
  • DECODE: For each new token, compute only its K, V, concatenate with cached values

# PREFILL: Process prompt "The cat"
prompt = [The, cat]
K_all = [K_The, K_cat]  # Cache these!
V_all = [V_The, V_cat]  # Cache these!
Output: "sat"

# DECODE: Generate next token
new_token = [sat]
K_new = [K_sat]  # Only compute for new token
V_new = [V_sat]
K_all = concat(K_cached, K_new)  # = [K_The, K_cat, K_sat]
V_all = concat(V_cached, V_new)  # = [V_The, V_cat, V_sat]
Output: "on"

# Continue...
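
In PyTorch terms, the decode step boils down to one concatenation before the usual attention computation. A minimal sketch (the names here are illustrative; the project's cache lives inside its attention module):

import math
import torch

def attend_with_cache(q_new, k_new, v_new, cache):
    # q_new, k_new, v_new: (batch, heads, 1, d_k), projections for the newest token only
    if cache is not None:
        k_all = torch.cat([cache["k"], k_new], dim=2)   # reuse cached keys
        v_all = torch.cat([cache["v"], v_new], dim=2)   # reuse cached values
    else:
        k_all, v_all = k_new, v_new                     # prefill: nothing cached yet

    scores = q_new @ k_all.transpose(-2, -1) / math.sqrt(q_new.size(-1))
    out = torch.softmax(scores, dim=-1) @ v_all         # attend over the full history
    return out, {"k": k_all, "v": v_all}                # updated cache for the next step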

Memory vs Speed Tradeoff

Memory Cost: For each layer, we cache K and V tensors with shape (batch, num_heads, seq_len, d_k). For a 6-layer model with d_model=256, 4 heads, and 200-token sequence, this is only ~3 MB per example. Very affordable!
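
That figure is easy to sanity-check for the configuration above (float32 values, 4 bytes each):

layers, heads, seq_len, d_k = 6, 4, 200, 64           # d_k = d_model / num_heads = 256 / 4
cache_bytes = 2 * layers * heads * seq_len * d_k * 4  # K and V, for every layer
print(f"{cache_bytes / 1e6:.1f} MB")                  # ≈ 2.5 MB per example, in the ballpark of the ~3 MB above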

Speed Benefit: Reduces time complexity from O(n²) to O(n) for generating n tokens. Typical speedups:

  • Short sequences (10-20 tokens): 2-5x faster
  • Medium sequences (50-100 tokens): 10-20x faster
  • Long sequences (200+ tokens): 20-50x faster

Why ALL production LLMs use KV-cache: The memory cost is tiny compared to the model weights, but the speed improvement is massive. Every production system (GPT, Claude, etc.) uses KV-cache for generation!

Using KV-Cache

# KV-cache is enabled by default!
generated = model.generate(
    start_tokens,
    max_length=100,
    sampling_strategy="greedy",
    use_cache=True  # ← Default!
)

# Disable cache (for debugging/comparison)
generated = model.generate(
    start_tokens,
    max_length=100,
    use_cache=False  # ← Much slower!
)

# Benchmark the speedup yourself
python commands/benchmark_generation.py

Implementation Detail: The cache must correctly handle positional encodings! When processing token at position N, it must receive position embedding for N, not 0. Our implementation tracks the cache length and adjusts positions automatically.
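
A minimal sketch of that bookkeeping (tensor shapes and names are illustrative, not the project's exact API):

import torch

cache = {"k": torch.zeros(1, 4, 7, 64)}   # pretend 7 tokens are already cached
new_tokens = torch.tensor([[42]])         # one newly generated token ID

cache_len = 0 if cache is None else cache["k"].size(2)               # tokens processed so far
positions = torch.arange(cache_len, cache_len + new_tokens.size(1))  # positions for new tokens
print(positions)                          # tensor([7]): position 7, not position 0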

Step 9: Model Interpretability

Now that we've built and trained a transformer, how do we understand what it has learned? Mechanistic interpretability provides tools to peek inside the "black box" and discover the circuits and patterns the model uses.

This section covers four powerful techniques: Logit Lens, Attention Analysis, Induction Heads, and Activation Patching.

What is Mechanistic Interpretability?

Instead of just asking "does the model work?", we ask:

  • When does the model "know" the answer? (which layer?)
  • How does information flow through the network?
  • Which components are responsible for specific behaviors?
  • What patterns or circuits has the model learned?

This connects to cutting-edge research from Anthropic, OpenAI, and academic labs exploring how LLMs actually work under the hood.

Logit Lens: Seeing Predictions Evolve

The logit lens technique lets us visualize what the model would predict if we stopped at each layer.

How It Works

Normally, we only see the final output:

Input → Layer 1 → Layer 2 → ... → Layer N
                 → Unembed → Logits

With logit lens, we apply unembedding at each layer:

Input → Layer 1 → [Unembed] → "What now?"
       → Layer 2 → [Unembed] → "What now?"
       → Layer 3 → [Unembed] → "What now?"
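
Using the model class built in Step 6, a bare-bones logit lens can be sketched like this (attribute names follow the Step 6 code; the project's actual tool is more polished):

import torch

def logit_lens(model, tokens, top_k=3):
    # tokens: (1, seq_len) token IDs
    mask = model.create_causal_mask(tokens.size(1)).to(tokens.device)
    with torch.no_grad():
        x = model.pos_encoding(model.token_embedding(tokens))
        for i, block in enumerate(model.blocks):
            x = block(x, mask=mask)
            # Pretend this layer were the last one: apply the final norm + unembedding
            logits = model.output_proj(model.ln_f(x))
            probs = torch.softmax(logits[0, -1], dim=-1)
            top = torch.topk(probs, top_k)
            print(f"Layer {i}: ids {top.indices.tolist()}, probs {top.values.tolist()}")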

Example Insight

Input: "The capital of France is"

Layer 0: "the" (15%), "a" (12%)
→ Generic, common words

Layer 2: "located" (18%), "Paris" (15%)
→ Starting to understand context

Layer 4: "Paris" (65%), "French" (10%)
→ Confident, correct answer!

Layer 6: "Paris" (72%), "France" (8%)
→ Final refinement

Key Insight: The model "knows" Paris by Layer 4. Later layers just refine the distribution.

Try It Yourself

Our implementation provides three ways to explore:

# Demo mode - educational examples
$ python main.py interpret logit-lens checkpoints/model.pt --demo
# Analyze specific text
$ python main.py interpret logit-lens checkpoints/model.pt \
--text "The Eiffel Tower is in"
# Interactive mode
$ python main.py interpret logit-lens checkpoints/model.pt --interactive

Beautiful Terminal Output: Uses the Rich library to display color-coded predictions in tables. High-probability predictions are highlighted in green, making it easy to see when the model converges on the right answer.

Attention Analysis: What Do Heads Focus On?

The attention analysis tool reveals what each attention head is looking at when processing text.

How It Works

Attention weights show which tokens each position "attends to". By analyzing these patterns across heads, we discover specialized behaviors:

  • Previous token heads: Always look at position i-1
  • Uniform heads: Spread attention evenly (averaging information)
  • Start token heads: Focus on the beginning of the sequence
  • Sparse heads: Concentrate on very few key tokens

Example Discovery

Input: "The cat sat on the mat"

Head 2.3 (Previous Token):
"cat" attends to "The" (100%)
"sat" attends to "cat" (100%)
→ Implements a previous-token circuit!

Head 4.1 (Uniform):
Each token: 16.7% to all positions
→ Averages information uniformly
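
Spotting a previous-token head can be as simple as measuring how much attention mass sits one step below the diagonal. A small sketch, assuming attn holds one head's attention weights for some input:

def prev_token_score(attn):
    # attn: (seq_len, seq_len) attention weights for a single head
    off_diag = attn.diagonal(offset=-1)   # attn[i, i-1] for every position i >= 1
    return off_diag.mean().item()         # close to 1.0 → a previous-token head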

# Analyze a specific head
$ python main.py interpret attention checkpoints/model.pt \
--text "Hello world" --layer 2 --head 3
# Find all previous-token heads
$ python main.py interpret attention checkpoints/model.pt \
--text "Hello world" # Shows pattern summary

Induction Heads: Pattern Matching Circuits

The induction head detector finds circuits that implement in-context learning - the ability to copy from earlier patterns.

What Are Induction Heads?

Given a repeated pattern like:

Input: "A B C ... A B [?]"
Prediction: "C"

Induction heads learn to predict C by recognizing the repeated "A B" pattern and copying what came after the first occurrence.

The Circuit

Induction typically involves two heads working together:

1. Previous Token Head (Layer L):
At position i, attends to i-1
Creates representation of "what came before"

2. Induction Head (Layer L+1):
Queries for matches to previous token
Attends to what came AFTER those matches
Predicts the next token
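
One simple way to score a head for this behaviour (a sketch, not the project's exact detector): feed a sequence whose second half repeats the first, take the head's attention weights, and check how much attention each second-half position pays to the token that followed its first occurrence.

def induction_score(attn, half):
    # attn: (seq_len, seq_len) weights on a sequence that repeats after `half` tokens
    total = 0.0
    for i in range(half, attn.size(0)):
        # Position i repeats position i - half, so an induction head should
        # attend to (i - half) + 1: the token that came AFTER the first occurrence
        total += attn[i, i - half + 1].item()
    return total / (attn.size(0) - half)  # near 1.0 → strong induction head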

# Detect induction heads across all layers
$ python main.py interpret induction-heads checkpoints/model.pt
# Custom parameters: fewer tests, longer sequences
$ python main.py interpret induction-heads checkpoints/model.pt \
--num-sequences 50 --seq-length 40 --top-k 10

Why It Matters: Induction heads are the first clearly-identified circuit in transformers. They're crucial for few-shot learning and emerge suddenly during training ("grokking"). Almost all transformer language models develop induction heads!

Activation Patching: Causal Interventions

The activation patching tool performs causal experiments to identify which components are truly responsible for specific behaviors.

The Question

We can observe what the model does, but which parts are actually causing the behavior?

Activation patching answers this through intervention experiments:

  1. Run model on "clean" input (correct behavior)
  2. Run model on "corrupted" input (incorrect behavior)
  3. For each component, swap clean activations into corrupted run
  4. Measure how much this restores correct behavior

High recovery = that component is causally important!

Example Experiment

Clean:     "The Eiffel Tower is in"
           → Predicts: "Paris" (85%)

Corrupted: "The Empire State is in"
           → Predicts: "New York" (78%)

Test: Patch Layer 4 activations
      from clean → corrupted
           → Predicts: "Paris" (82%)

Result: Layer 4 recovery = 90%
        Layer 4 is CRITICAL!

# Test which layers are causally important
$ python main.py interpret patch checkpoints/model.pt \
--clean "The Eiffel Tower is in" \
--corrupted "The Empire State Building is in" \
--target "Paris"

Why It's Powerful: This is causal evidence, not just correlation. If patching a layer restores the behavior, that layer demonstrably carries information the computation depends on. This technique has been used to locate where factual knowledge is stored in large language models!
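
Under the hood, a patch like this can be implemented with PyTorch forward hooks on one of the model's blocks. A minimal sketch following the Step 6 model (clean_tokens and corrupted_tokens stand for the two tokenized prompts; the project's CLI wraps something similar):

stash = {}

def save_output(module, inputs, output):
    stash["clean"] = output.detach()      # remember the clean activation

def replace_output(module, inputs, output):
    return stash["clean"]                 # returning a value overrides this block's output

layer = model.blocks[4]                   # the block we want to test

handle = layer.register_forward_hook(save_output)
model(clean_tokens)                       # 1. clean run: capture the activation
handle.remove()

handle = layer.register_forward_hook(replace_output)
patched_logits = model(corrupted_tokens)  # 2. corrupted run with the clean activation patched in
handle.remove()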

Learn More

Explore the implementation files - each includes comprehensive documentation explaining the theory and methods.

Try It Yourself

Ready to train and experiment with your own transformer? The complete implementation is available on GitHub with everything you need to get started.

Quick Start

$ git clone https://github.com/zhubert/transformer.git
$ cd transformer
# Install dependencies with uv
$ make install
# Launch interactive CLI - easiest way to get started!
$ python main.py

[Figure: Interactive CLI menu showing training, generation, evaluation, and interpretability options]

Interactive Mode: The CLI provides a beautiful, arrow-key navigated menu system. No flags to memorize! Just run python main.py and the interface guides you through:

  • Training models with configurable presets
  • Generating text with various sampling strategies
  • Evaluating and comparing checkpoints
  • Analyzing model internals (attention, logit lens, etc.)
  • Downloading training data for offline use

For advanced users, all operations are also available via traditional command-line flags - see the README for details.

Want something more production-ready? Check out Andrej Karpathy's nanochat - a polished, optimized GPT implementation that's ready to use for real applications.