AIMaks

Implementing Attention from Scratch

50 min read · notebook · Attention Mechanisms
3 of 28 · Transformer Architecture Deep Dive

Time to actually code it. This notebook builds scaled dot-product attention from first principles in PyTorch, verifies it against PyTorch's built-in F.scaled_dot_product_attention, plots an attention matrix on a small toy input, and finishes with a causal-masked variant ready to drop into a decoder. The same code generalises to multi-head attention in Lesson 7.

code
pip install "torch==2.5.1" "matplotlib==3.9.2" "numpy==2.1.2"
CPU works for the toy examples in this notebook.

1. The Imports

code
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt

torch.manual_seed(0)
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

2. Attention From Scratch

code
def scaled_dot_product_attention(Q, K, V, mask=None):
    # Q, K: (..., n, d_k);  V: (..., n, d_v)
    d_k = Q.size(-1)
    scores = Q @ K.transpose(-2, -1) / math.sqrt(d_k)   # (..., n, n)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, float("-inf"))
    attn = F.softmax(scores, dim=-1)                    # (..., n, n)
    out = attn @ V                                      # (..., n, d_v)
    return out, attn

Seven lines of body. Note that we return the attention matrix alongside the output: it is useful for visualisation, and free since we already computed it.
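
Why divide by math.sqrt(d_k)? A quick empirical sketch, assuming query and key entries are independent standard-normal values (roughly what random initialisation gives you): each raw dot product is a sum of d_k unit-variance terms, so its variance grows like d_k, and large scores push the softmax towards a nearly one-hot distribution. Dividing by the square root brings the variance back to about 1.

```python
import math
import torch

torch.manual_seed(0)

results = {}
for d_k in (4, 64, 512):
    # Assumption: i.i.d. standard-normal entries, 10,000 query/key pairs per width
    q = torch.randn(10_000, d_k)
    k = torch.randn(10_000, d_k)
    raw = (q * k).sum(dim=-1)        # one unscaled dot product per pair
    scaled = raw / math.sqrt(d_k)    # the scaling used in scaled_dot_product_attention
    results[d_k] = (raw.var().item(), scaled.var().item())
    print(f"d_k={d_k:4d}  var(raw)={results[d_k][0]:8.1f}  var(scaled)={results[d_k][1]:.2f}")

# Softmax over 8 "keys" at d_k=512 (the last loop iteration): how peaked is it?
peak_raw = torch.softmax(raw[:8], dim=0).max().item()
peak_scaled = torch.softmax(scaled[:8], dim=0).max().item()
print("max weight, raw   :", peak_raw)
print("max weight, scaled:", peak_scaled)
```

Expect var(raw) to land near d_k and var(scaled) near 1 at every width; the unscaled softmax is always at least as peaked as the scaled one.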

3. Smoke Test

code
n, d_k, d_v = 4, 8, 8
Q = torch.randn(n, d_k)
K = torch.randn(n, d_k)
V = torch.randn(n, d_v)

out, attn = scaled_dot_product_attention(Q, K, V)
print("output shape    :", out.shape)
print("attention shape :", attn.shape)
print("each row sums to:", attn.sum(dim=-1))
# tensor([1.0000, 1.0000, 1.0000, 1.0000])

Output is (4, 8); attention is (4, 4); each row of the attention matrix sums to 1, confirming row-wise softmax.
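
The `(..., n, d_k)` shape comment means any leading dimensions are treated as batch dimensions, because `@` and `softmax(dim=-1)` broadcast over them. A sketch that checks this by comparing one batched call against a Python loop over 2-D slices; the sizes `B` and `H` (think batch and heads) are arbitrary choices for illustration, and the function is repeated so the cell runs on its own.

```python
import math
import torch
import torch.nn.functional as F

def scaled_dot_product_attention(Q, K, V, mask=None):
    # Same from-scratch function as above, repeated so this cell is self-contained
    d_k = Q.size(-1)
    scores = Q @ K.transpose(-2, -1) / math.sqrt(d_k)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, float("-inf"))
    attn = F.softmax(scores, dim=-1)
    return attn @ V, attn

torch.manual_seed(0)
B, H, n, d = 2, 3, 4, 8                        # arbitrary leading dims
Q, K, V = (torch.randn(B, H, n, d) for _ in range(3))

out, attn = scaled_dot_product_attention(Q, K, V)

# Reference: solve each (b, h) slice separately as an unbatched 2-D problem
loop_out = torch.stack([
    torch.stack([scaled_dot_product_attention(Q[b, h], K[b, h], V[b, h])[0] for h in range(H)])
    for b in range(B)
])
print(out.shape, attn.shape)   # torch.Size([2, 3, 4, 8]) torch.Size([2, 3, 4, 4])
print("batched == looped:", torch.allclose(out, loop_out, atol=1e-5))
```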

4. Verify Against PyTorch's Built-In

code
# PyTorch 2.0+ ships an optimized scaled dot-product attention
ours, _    = scaled_dot_product_attention(Q.unsqueeze(0), K.unsqueeze(0), V.unsqueeze(0))
theirs     = F.scaled_dot_product_attention(Q.unsqueeze(0), K.unsqueeze(0), V.unsqueeze(0))
print("max abs diff:", (ours - theirs).abs().max().item())
# very small, should be < 1e-6

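The two agree to floating-point precision. This is the moment to internalize: PyTorch's built-in computes the same function we just hand-wrote. The production version is faster because, on supported hardware, it can dispatch to fused kernels (such as FlashAttention) that never materialize the full n × n score matrix. One difference matters for this notebook: the built-in returns only the output, not the attention weights. From here forward you can use it wherever you don't need those weights; we wrote it from scratch once for understanding, and we keep our version for the plots below.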
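
The comparison above passes no mask. When you do pass one, the two conventions line up once the mask is boolean: our function blocks positions where `mask == 0`, while `F.scaled_dot_product_attention` treats a boolean `attn_mask` entry of `True` as "allowed to attend". A sketch using a hypothetical padding mask that hides the last key from every query:

```python
import math
import torch
import torch.nn.functional as F

def scaled_dot_product_attention(Q, K, V, mask=None):
    # Same from-scratch function as above, repeated so this cell is self-contained
    d_k = Q.size(-1)
    scores = Q @ K.transpose(-2, -1) / math.sqrt(d_k)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, float("-inf"))
    attn = F.softmax(scores, dim=-1)
    return attn @ V, attn

torch.manual_seed(0)
n, d = 4, 8
Q, K, V = (torch.randn(1, n, d) for _ in range(3))

# Hypothetical padding mask: rows are queries, columns are keys; the last key is padding
mask = torch.ones(n, n)
mask[:, -1] = 0

ours, attn = scaled_dot_product_attention(Q, K, V, mask=mask)
theirs = F.scaled_dot_product_attention(Q, K, V, attn_mask=mask.bool())  # True = may attend

print("max abs diff:", (ours - theirs).abs().max().item())
print("weight on padded key:", attn[..., -1])   # expect all zeros
```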

5. A Concrete Self-Attention Example

code
# A toy "sentence" of 4 tokens, each with a 4-dim embedding
sentence = torch.tensor([
    [1.0, 0.0, 0.0, 0.0],   # "the"
    [0.0, 1.0, 0.0, 0.0],   # "cat"
    [0.0, 0.0, 1.0, 0.0],   # "sat"
    [0.0, 0.0, 0.0, 1.0],   # "mat"
])

# Self-attention: Q, K, V all derived from the same input
W_Q = torch.randn(4, 4)
W_K = torch.randn(4, 4)
W_V = torch.randn(4, 4)

Q = sentence @ W_Q
K = sentence @ W_K
V = sentence @ W_V

out, attn = scaled_dot_product_attention(Q, K, V)
print("attention matrix:")
print(attn.round(decimals=3))   # non-in-place, so the plot below uses unrounded weights

This is self-attention in its simplest possible form. The W_Q, W_K, W_V matrices are the parameters that would normally be learned; here they're random.
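
To see what a single row of that computation does, the sketch below recomputes the output for the "cat" query (row 1) step by step: score it against every key, softmax the scores, then take the weighted sum of the value rows. It reseeds and redraws its own random weights, so the numbers will differ from the cell above; the point is that the step-by-step version matches the matrix form. Because the one-hot sentence is the identity matrix, `Q` is simply `W_Q`.

```python
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)
sentence = torch.eye(4)                      # one-hot "the", "cat", "sat", "mat"
W_Q, W_K, W_V = torch.randn(4, 4), torch.randn(4, 4), torch.randn(4, 4)
Q, K, V = sentence @ W_Q, sentence @ W_K, sentence @ W_V

i = 1                                                                    # query row for "cat"
scores_i = torch.stack([Q[i] @ K[j] for j in range(4)]) / math.sqrt(4)  # one score per key
weights_i = torch.exp(scores_i) / torch.exp(scores_i).sum()             # softmax by hand
out_i = sum(weights_i[j] * V[j] for j in range(4))                      # weighted sum of values

# Matrix form of the same computation, for comparison
attn = F.softmax(Q @ K.T / math.sqrt(4), dim=-1)
out = attn @ V
print("by hand :", weights_i)
print("matrix  :", attn[i])
print("outputs match:", torch.allclose(out_i, out[i], atol=1e-5))
```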

6. Visualizing Attention

code
tokens = ["the", "cat", "sat", "mat"]

fig, ax = plt.subplots(figsize=(4, 4))
ax.imshow(attn.detach().numpy(), cmap="Blues", vmin=0, vmax=1)
ax.set_xticks(range(len(tokens))); ax.set_xticklabels(tokens)
ax.set_yticks(range(len(tokens))); ax.set_yticklabels(tokens)
ax.set_xlabel("attended-to (key)")
ax.set_ylabel("query")
for i in range(len(tokens)):
    for j in range(len(tokens)):
        ax.text(j, i, f"{attn[i, j]:.2f}", ha="center", va="center",
                color="black" if attn[i, j] < 0.5 else "white")
plt.tight_layout(); plt.show()

Plotting attention matrices is a common first step when inspecting a trained transformer. With random weights you'll see no meaningful pattern, but in a trained model on real text, individual attention heads often show recognisable structure, such as punctuation attending to the previous word or pronouns attending to their antecedents.
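
To get a feel for what a structured pattern looks like, the sketch below hand-picks (it does not learn) projection weights that produce a "previous token" pattern on the same toy sentence. `W_Q` scales each one-hot query by 10, and `W_K` shifts each key so that key j points at position j + 1; query i therefore matches only key i - 1. These weights are an illustrative assumption, not what any particular trained model contains.

```python
import math
import torch
import torch.nn.functional as F

def scaled_dot_product_attention(Q, K, V, mask=None):
    # Same from-scratch function as above, repeated so this cell is self-contained
    d_k = Q.size(-1)
    scores = Q @ K.transpose(-2, -1) / math.sqrt(d_k)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, float("-inf"))
    attn = F.softmax(scores, dim=-1)
    return attn @ V, attn

sentence = torch.eye(4)                        # one-hot "the", "cat", "sat", "mat"

# Hand-picked weights (hypothetical, chosen to create a previous-token pattern)
W_Q = 10.0 * torch.eye(4)                      # query i -> 10 * e_i
W_K = torch.diag(torch.ones(3), diagonal=1)    # key j -> e_{j+1}; zero vector for the last token
W_V = torch.eye(4)

Q, K, V = sentence @ W_Q, sentence @ W_K, sentence @ W_V
out, attn = scaled_dot_product_attention(Q, K, V)
print(attn.round(decimals=2))
```

Rows 1 to 3 put about 0.98 of their weight on the previous token. Row 0 ("the") has no previous token, so its scores all tie and its weights come out uniform. Feeding this `attn` into the plotting cell above shows the pattern as a bright sub-diagonal.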

7. Wrap as a nn.Module

code
class SelfAttention(nn.Module):
    def __init__(self, d_model):
        super().__init__()
        self.W_Q = nn.Linear(d_model, d_model, bias=False)
        self.W_K = nn.Linear(d_model, d_model, bias=False)
        self.W_V = nn.Linear(d_model, d_model, bias=False)

    def forward(self, x, mask=None):
        Q = self.W_Q(x)
        K = self.W_K(x)
        V = self.W_V(x)
        out, attn = scaled_dot_product_attention(Q, K, V, mask)
        return out, attn

attn_layer = SelfAttention(d_model=64)
x = torch.randn(2, 10, 64)              # (batch, seq, d_model)
out, attn = attn_layer(x)
print(out.shape, attn.shape)
# torch.Size([2, 10, 64]) torch.Size([2, 10, 10])

Same logic, dressed up as a proper PyTorch module with learnable parameters. This is the building block we'll scale up to multi-head attention in Lesson 7.
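
Two quick checks on the module, sketched below. First, `nn.Linear(d_model, d_model, bias=False)` computes `x @ weight.T`, so the module should match the functional version driven by the transposed weight matrices. Second, a single `(seq, seq)` causal mask broadcasts across the batch dimension, so one mask serves every sequence in the batch. The function and class are repeated so the cell runs on its own.

```python
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

def scaled_dot_product_attention(Q, K, V, mask=None):
    d_k = Q.size(-1)
    scores = Q @ K.transpose(-2, -1) / math.sqrt(d_k)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, float("-inf"))
    attn = F.softmax(scores, dim=-1)
    return attn @ V, attn

class SelfAttention(nn.Module):
    # Same module as above, with the three projections inlined into forward()
    def __init__(self, d_model):
        super().__init__()
        self.W_Q = nn.Linear(d_model, d_model, bias=False)
        self.W_K = nn.Linear(d_model, d_model, bias=False)
        self.W_V = nn.Linear(d_model, d_model, bias=False)

    def forward(self, x, mask=None):
        return scaled_dot_product_attention(self.W_Q(x), self.W_K(x), self.W_V(x), mask)

torch.manual_seed(0)
layer = SelfAttention(d_model=64)
x = torch.randn(2, 10, 64)                     # (batch, seq, d_model)

with torch.no_grad():
    # 1) nn.Linear without bias is x @ weight.T
    out_mod, _ = layer(x)
    out_ref, _ = scaled_dot_product_attention(
        x @ layer.W_Q.weight.T, x @ layer.W_K.weight.T, x @ layer.W_V.weight.T)
    # 2) One (10, 10) causal mask broadcast over the batch of 2
    mask = torch.tril(torch.ones(10, 10))
    _, attn = layer(x, mask=mask)

print("module == functional:", torch.allclose(out_mod, out_ref, atol=1e-5))
print(attn.shape, "max weight above diagonal:", attn.triu(diagonal=1).abs().max().item())
```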

8. Causal Self-Attention

code
def causal_mask(n):
    # 1 where allowed, 0 where blocked
    return torch.tril(torch.ones(n, n))

n = 5
mask = causal_mask(n).to(DEVICE)
print(mask)
# tensor([[1., 0., 0., 0., 0.],
#         [1., 1., 0., 0., 0.],
#         [1., 1., 1., 0., 0.],
#         [1., 1., 1., 1., 0.],
#         [1., 1., 1., 1., 1.]])

# Use it
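Q = torch.randn(1, n, 8, device=DEVICE)   # same device as the mask
K = torch.randn(1, n, 8, device=DEVICE)
V = torch.randn(1, n, 8, device=DEVICE)
out, attn = scaled_dot_product_attention(Q, K, V, mask=mask)
print("causal attention:")
print(attn.squeeze(0).round(decimals=3))   # non-in-place, leaves attn untouched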

The lower-triangular mask blocks each query from attending to future keys, which is exactly what GPT-style decoders need. After the softmax, each row's probability mass is spread only over the past and current positions.
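
A sketch of three checks on the causal variant: it should match the built-in's `is_causal=True` path, put exactly zero weight above the diagonal, and leave earlier outputs unchanged no matter what happens to a later token. The perturbation of the last key and value by 100 is an arbitrary choice for illustration.

```python
import math
import torch
import torch.nn.functional as F

def scaled_dot_product_attention(Q, K, V, mask=None):
    # Same from-scratch function as above, repeated so this cell is self-contained
    d_k = Q.size(-1)
    scores = Q @ K.transpose(-2, -1) / math.sqrt(d_k)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, float("-inf"))
    attn = F.softmax(scores, dim=-1)
    return attn @ V, attn

torch.manual_seed(0)
n, d = 5, 8
Q, K, V = (torch.randn(1, n, d) for _ in range(3))
mask = torch.tril(torch.ones(n, n))

ours, attn = scaled_dot_product_attention(Q, K, V, mask=mask)
theirs = F.scaled_dot_product_attention(Q, K, V, is_causal=True)
print("max abs diff vs is_causal=True:", (ours - theirs).abs().max().item())
print("mass on future keys:", attn.triu(diagonal=1).sum().item())
print("first query's weight on itself:", attn[0, 0, 0].item())   # can only see itself

# Changing the last token's key and value must not change any earlier output
K2, V2 = K.clone(), V.clone()
K2[:, -1] += 100.0
V2[:, -1] += 100.0
out2, _ = scaled_dot_product_attention(Q, K2, V2, mask=mask)
print("earlier outputs unchanged:", torch.allclose(out2[:, :-1], ours[:, :-1]))
```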

9. Common Pitfalls

10. Exercises

Up next · Self-Attention vs Cross-Attention