Implementing Attention from Scratch
Transformer Architecture Deep Dive, lesson 3 of 28
Time to actually code it. This notebook builds scaled dot-product attention from first principles in PyTorch and verifies it against PyTorch's built-in F.scaled_dot_product_attention. It then plots an attention matrix on a small toy input, wraps the computation in an nn.Module, and finishes with a causal-masked variant ready to drop into a decoder. The same code generalises to multi-head attention in Lesson 7.
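For reference, this is the formula the code below implements. It is written with an additive mask M that mirrors the masked_fill step, and M = 0 everywhere when no mask is passed. The softmax is taken row by row, that is, over the keys for each query:

$$
\operatorname{Attention}(Q, K, V) = \operatorname{softmax}\!\left(\frac{Q K^{\top}}{\sqrt{d_k}} + M\right) V,
\qquad
M_{ij} =
\begin{cases}
0 & \text{if query } i \text{ may attend to key } j,\\
-\infty & \text{otherwise.}
\end{cases}
$$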
pip install "torch==2.5.1" "matplotlib==3.9.2" "numpy==2.1.2"
1. The Imports
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
torch.manual_seed(0)  # make the random weights and inputs reproducible
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"  # use a GPU when one is available
2. Attention From Scratch
def scaled_dot_product_attention(Q, K, V, mask=None):
    # Q: (..., n_q, d_k); K: (..., n_k, d_k); V: (..., n_k, d_v)
    # mask (optional): broadcastable to (..., n_q, n_k); 0 = blocked, nonzero = allowed
    d_k = Q.size(-1)
    scores = Q @ K.transpose(-2, -1) / math.sqrt(d_k)  # (..., n_q, n_k)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, float("-inf"))
    attn = F.softmax(scores, dim=-1)  # (..., n_q, n_k), each row sums to 1
    out = attn @ V  # (..., n_q, d_v)
    return out, attn
Seven lines of code. Note that we return the attention matrix alongside the output. It is useful for visualisation, and it costs nothing extra because we have already computed it.
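Why divide by √d_k? Suppose the entries of a query q and a key k are independent, with zero mean and unit variance. Their dot product then has variance d_k, so raw scores grow with the head dimension and push the softmax toward a near one-hot output. The quick NumPy check below makes that i.i.d. standard-normal assumption explicit. Real Q and K are learned projections, so treat this as a sketch of the motivation, not a property of every trained model:

```python
import numpy as np

rng = np.random.default_rng(0)
d_k = 64
num_samples = 50_000

# Assumption: i.i.d. standard-normal query and key entries
q = rng.standard_normal((num_samples, d_k))
k = rng.standard_normal((num_samples, d_k))

raw = (q * k).sum(axis=1)      # one dot product per sample
scaled = raw / np.sqrt(d_k)    # same scaling as the from-scratch function

print("var of raw dot products   :", raw.var())     # close to d_k = 64
print("var of scaled dot products:", scaled.var())  # close to 1
```

After scaling, the variance stays near 1 whatever d_k is. This is the usual argument for why the scaling keeps the softmax out of its saturated, small-gradient regime.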
3. Smoke Test
n, d_k, d_v = 4, 8, 8
Q = torch.randn(n, d_k)
K = torch.randn(n, d_k)
V = torch.randn(n, d_v)
out, attn = scaled_dot_product_attention(Q, K, V)
print("output shape :", out.shape)
print("attention shape :", attn.shape)
print("each row sums to:", attn.sum(dim=-1))
# tensor([1.0000, 1.0000, 1.0000, 1.0000])
Output is (4, 8); attention is (4, 4); each row of the attention matrix sums to 1, confirming row-wise softmax.
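To see each step with numbers small enough to check by hand, here is a two-token example in plain Python with no PyTorch. Q and K are both the 2×2 identity and d_k = 2. The values are illustrative only:

```python
import math

Q = [[1.0, 0.0], [0.0, 1.0]]
K = [[1.0, 0.0], [0.0, 1.0]]
V = [[1.0, 2.0], [3.0, 4.0]]
d_k = len(Q[0])

# Step 1: scores[i][j] = (q_i . k_j) / sqrt(d_k)
scores = [[sum(a * b for a, b in zip(q, k)) / math.sqrt(d_k) for k in K] for q in Q]
# -> [[0.7071, 0.0], [0.0, 0.7071]]

# Step 2: row-wise softmax
def softmax(row):
    m = max(row)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]

attn = [softmax(row) for row in scores]
# -> roughly [[0.670, 0.330], [0.330, 0.670]]

# Step 3: each output row is a weighted average of the rows of V
out = [[sum(w * v[c] for w, v in zip(weights, V)) for c in range(len(V[0]))]
       for weights in attn]
# -> roughly [[1.660, 2.660], [2.340, 3.340]]

print(attn)
print(out)
```

Because every query matches its own key best, each token puts about two-thirds of its weight on itself. Its output is therefore pulled toward its own row of V.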
4. Verify Against PyTorch's Built-In
# PyTorch 2.0+ ships an optimized scaled dot-product attention
ours, _ = scaled_dot_product_attention(Q.unsqueeze(0), K.unsqueeze(0), V.unsqueeze(0))
theirs = F.scaled_dot_product_attention(Q.unsqueeze(0), K.unsqueeze(0), V.unsqueeze(0))
print("max abs diff:", (ours - theirs).abs().max().item())
# very small, should be < 1e-6
Numerical agreement to floating-point precision. This is the point to internalize: PyTorch's optimized kernel computes the same function we just hand-wrote, only faster. From here on you can use the built-in everywhere; we wrote it from scratch once, for understanding.
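The built-in also accepts a mask, but its convention differs from the function above. Its attn_mask is either a boolean tensor where True means "may attend", or a float tensor that is added to the scores. Setting is_causal=True builds the causal mask internally. Passing a 0/1 float mask straight through would therefore add 0 or 1 to the scores and block nothing. Below is a minimal sketch comparing the options. reference_attention is a hypothetical copy of the from-scratch function, included so the snippet runs on its own:

```python
import math
import torch
import torch.nn.functional as F

def reference_attention(Q, K, V, mask=None):
    # Same logic as the from-scratch function: 0 in `mask` means "blocked"
    scores = Q @ K.transpose(-2, -1) / math.sqrt(Q.size(-1))
    if mask is not None:
        scores = scores.masked_fill(mask == 0, float("-inf"))
    return torch.softmax(scores, dim=-1) @ V

torch.manual_seed(0)
n, d = 5, 8
Q, K, V = (torch.randn(1, n, d) for _ in range(3))
mask01 = torch.tril(torch.ones(n, n))  # 1 = allowed, 0 = blocked

ours = reference_attention(Q, K, V, mask=mask01)
builtin_bool = F.scaled_dot_product_attention(Q, K, V, attn_mask=mask01.bool())
builtin_causal = F.scaled_dot_product_attention(Q, K, V, is_causal=True)
# Wrong usage: a float mask is *added* to the scores, so nothing gets blocked
builtin_float = F.scaled_dot_product_attention(Q, K, V, attn_mask=mask01)

print(torch.allclose(ours, builtin_bool, atol=1e-5))    # True
print(torch.allclose(ours, builtin_causal, atol=1e-5))  # True
print(torch.allclose(ours, builtin_float, atol=1e-5))   # False
```

Convert the mask once with .bool(), or use is_causal=True for the purely causal case, and the two implementations stay interchangeable.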
5. A Concrete Self-Attention Example
# A toy "sentence" of 4 tokens, each with a 4-dim embedding
sentence = torch.tensor([
[1.0, 0.0, 0.0, 0.0], # "the"
[0.0, 1.0, 0.0, 0.0], # "cat"
[0.0, 0.0, 1.0, 0.0], # "sat"
[0.0, 0.0, 0.0, 1.0], # "mat"
])
# Self-attention: Q, K, V all derived from the same input
W_Q = torch.randn(4, 4)
W_K = torch.randn(4, 4)
W_V = torch.randn(4, 4)
Q = sentence @ W_Q
K = sentence @ W_K
V = sentence @ W_V
out, attn = scaled_dot_product_attention(Q, K, V)
print("attention matrix:")
print(attn.round(decimals=3))  # round() returns a copy; round_() would overwrite attn in place
This is self-attention in its simplest possible form. The W_Q, W_K, W_V matrices are the parameters that would normally be learned; here they're random.
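Because each toy embedding is one-hot, the projections above are just row lookups: token i's query is row i of W_Q, and the same holds for K and V. Here is a quick NumPy check, with its own random W_Q standing in for the one above. The "mixed" embedding is a hypothetical example:

```python
import numpy as np

rng = np.random.default_rng(0)
sentence = np.eye(4)               # one-hot rows for "the", "cat", "sat", "mat"
W_Q = rng.standard_normal((4, 4))  # stand-in for the random projection above

Q = sentence @ W_Q
print(np.allclose(Q, W_Q))         # True: row i of Q is row i of W_Q

# Hypothetical embedding halfway between "the" and "cat"
mixed = np.array([0.5, 0.5, 0.0, 0.0])
print(np.allclose(mixed @ W_Q, 0.5 * W_Q[0] + 0.5 * W_Q[1]))  # True
```

With dense embeddings, as in a real model, each query becomes a linear combination of W_Q's rows rather than a single row.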
6. Visualizing Attention
tokens = ["the", "cat", "sat", "mat"]
fig, ax = plt.subplots(figsize=(4, 4))
ax.imshow(attn.detach().numpy(), cmap="Blues", vmin=0, vmax=1)
ax.set_xticks(range(len(tokens))); ax.set_xticklabels(tokens)
ax.set_yticks(range(len(tokens))); ax.set_yticklabels(tokens)
ax.set_xlabel("attended-to (key)")
ax.set_ylabel("query")
for i in range(len(tokens)):
    for j in range(len(tokens)):
        ax.text(j, i, f"{attn[i, j]:.2f}", ha="center", va="center",
                color="black" if attn[i, j] < 0.5 else "white")
plt.tight_layout(); plt.show()
Visualizing attention matrices is a common first step when interpreting a trained transformer. With random weights you'll see no meaningful pattern. In a trained model on real text, you may see structures such as punctuation attending to the previous word or pronouns attending to their antecedents.
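The plotting code above can be packaged as a small reusable helper, so the same heatmap works for any attention matrix, including the causal one later on. plot_attention is a hypothetical name. The sketch accepts a PyTorch tensor on any device or a NumPy array, and it assumes the same token labels on both axes, as in self-attention. It selects matplotlib's non-interactive Agg backend only so that it also runs headless:

```python
import matplotlib
matplotlib.use("Agg")  # headless backend; drop this line in a notebook
import matplotlib.pyplot as plt
import numpy as np

def plot_attention(attn, tokens, ax=None):
    """Heatmap of an (n_queries, n_keys) attention matrix with per-cell labels."""
    if hasattr(attn, "detach"):  # PyTorch tensor: drop autograd, move to CPU
        attn = attn.detach().cpu().numpy()
    attn = np.asarray(attn)
    if ax is None:
        _, ax = plt.subplots(figsize=(4, 4))
    ax.imshow(attn, cmap="Blues", vmin=0, vmax=1)
    ax.set_xticks(range(attn.shape[1]), labels=tokens)
    ax.set_yticks(range(attn.shape[0]), labels=tokens)
    ax.set_xlabel("attended-to (key)")
    ax.set_ylabel("query")
    for i in range(attn.shape[0]):
        for j in range(attn.shape[1]):
            ax.text(j, i, f"{attn[i, j]:.2f}", ha="center", va="center",
                    color="black" if attn[i, j] < 0.5 else "white")
    return ax

# Example with a hand-made lower-triangular matrix
example = np.array([[1.0, 0.0, 0.0],
                    [0.4, 0.6, 0.0],
                    [0.2, 0.3, 0.5]])
ax = plot_attention(example, ["a", "b", "c"])
plt.tight_layout()
```

Calling plot_attention(attn, tokens) followed by plt.show() in an interactive session should reproduce the figure above.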
7. Wrap as a nn.Module
class SelfAttention(nn.Module):
    def __init__(self, d_model):
        super().__init__()
        self.W_Q = nn.Linear(d_model, d_model, bias=False)
        self.W_K = nn.Linear(d_model, d_model, bias=False)
        self.W_V = nn.Linear(d_model, d_model, bias=False)

    def forward(self, x, mask=None):
        Q = self.W_Q(x)
        K = self.W_K(x)
        V = self.W_V(x)
        out, attn = scaled_dot_product_attention(Q, K, V, mask)
        return out, attn
attn_layer = SelfAttention(d_model=64)
x = torch.randn(2, 10, 64) # (batch, seq, d_model)
out, attn = attn_layer(x)
print(out.shape, attn.shape)
# torch.Size([2, 10, 64]) torch.Size([2, 10, 10])
Same logic, dressed up as a proper PyTorch module with learnable parameters. This is the building block we'll scale up to multi-head attention in Lesson 7.
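Two quick checks on the module. The class and function are repeated here, in compact form, so the snippet runs standalone. First, with bias=False the module holds 3 × d_model² learnable weights. Second, the mask broadcasts over the batch dimension, so a single (seq, seq) causal mask works for a whole batch:

```python
import math
import torch
import torch.nn as nn

def scaled_dot_product_attention(Q, K, V, mask=None):
    scores = Q @ K.transpose(-2, -1) / math.sqrt(Q.size(-1))
    if mask is not None:
        scores = scores.masked_fill(mask == 0, float("-inf"))
    attn = torch.softmax(scores, dim=-1)
    return attn @ V, attn

class SelfAttention(nn.Module):
    def __init__(self, d_model):
        super().__init__()
        self.W_Q = nn.Linear(d_model, d_model, bias=False)
        self.W_K = nn.Linear(d_model, d_model, bias=False)
        self.W_V = nn.Linear(d_model, d_model, bias=False)

    def forward(self, x, mask=None):
        return scaled_dot_product_attention(self.W_Q(x), self.W_K(x), self.W_V(x), mask)

torch.manual_seed(0)
d_model, seq = 64, 10
layer = SelfAttention(d_model)
n_params = sum(p.numel() for p in layer.parameters())
print(n_params)  # 12288 = 3 * 64 * 64

x = torch.randn(2, seq, d_model)         # (batch, seq, d_model)
mask = torch.tril(torch.ones(seq, seq))  # (seq, seq), broadcast over the batch
out, attn = layer(x, mask=mask)
print(out.shape, attn.shape)             # torch.Size([2, 10, 64]) torch.Size([2, 10, 10])
print(attn.triu(diagonal=1).max().item())  # 0.0: no weight on future positions
```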
8. Causal Self-Attention
def causal_mask(n):
    # 1 where allowed, 0 where blocked
    return torch.tril(torch.ones(n, n))
n = 5
mask = causal_mask(n).to(DEVICE)
print(mask)
# tensor([[1., 0., 0., 0., 0.],
# [1., 1., 0., 0., 0.],
# [1., 1., 1., 0., 0.],
# [1., 1., 1., 1., 0.],
# [1., 1., 1., 1., 1.]])
# Use it
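# Create Q, K, V on the same device as the mask; otherwise masked_fill raises a
# device-mismatch error when DEVICE is "cuda"
Q = torch.randn(1, n, 8, device=DEVICE)
K = torch.randn(1, n, 8, device=DEVICE)
V = torch.randn(1, n, 8, device=DEVICE)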
out, attn = scaled_dot_product_attention(Q, K, V, mask=mask)
print("causal attention:")
print(attn.squeeze(0).round(decimals=3))  # round() avoids modifying attn in place
The lower-triangular mask blocks each query from attending to future keys, which is exactly what GPT-style decoders need. Blocked scores are set to -inf, so after the softmax they receive weight exactly 0. Each row's probability mass is therefore spread only over the past and current positions.
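One way to confirm that the mask really blocks the future is to perturb only the last position's key and value, then check that every earlier output row is unchanged. This is a self-contained sketch. masked_attention is a hypothetical stand-in for the from-scratch function:

```python
import math
import torch

def masked_attention(Q, K, V, mask):
    scores = Q @ K.transpose(-2, -1) / math.sqrt(Q.size(-1))
    scores = scores.masked_fill(mask == 0, float("-inf"))
    return torch.softmax(scores, dim=-1) @ V

torch.manual_seed(0)
n, d = 5, 8
Q, K, V = (torch.randn(1, n, d) for _ in range(3))
mask = torch.tril(torch.ones(n, n))

out = masked_attention(Q, K, V, mask)

# Overwrite the last position's key and value with fresh random values
K2, V2 = K.clone(), V.clone()
K2[:, -1] = torch.randn(d)
V2[:, -1] = torch.randn(d)
out2 = masked_attention(Q, K2, V2, mask)

print(torch.allclose(out[:, :-1], out2[:, :-1]))  # True: earlier rows never see position n-1
print(torch.allclose(out[:, -1], out2[:, -1]))    # False: the last row does
```

The -inf fill has one edge case. If a mask ever left a query row with no allowed key at all, the softmax over an all -inf row would return NaN. A causal mask never does this, because the diagonal is always allowed.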