The attention mechanism is the core building block of the Transformer architecture. In this post, we’ll walk through the math and intuition behind self-attention.
Scaled Dot-Product Attention
Given a set of queries $Q$, keys $K$, and values $V$, the attention function is defined as:

$$\text{Attention}(Q, K, V) = \text{softmax}\!\left(\frac{QK^\top}{\sqrt{d_k}}\right)V$$

where $d_k$ is the dimension of the key vectors. The scaling factor $\frac{1}{\sqrt{d_k}}$ prevents the dot products from growing too large, which would push the softmax into regions with extremely small gradients.
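To make the formula concrete, here is a minimal, dependency-free sketch of scaled dot-product attention using plain Python lists. The helper names (`softmax`, `attention`) and the toy matrices are illustrative, not a standard API; real code would use a tensor library.

```python
import math

def softmax(row):
    # Numerically stable softmax over one row of scores.
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    s = sum(exps)
    return [e / s for e in exps]

def attention(Q, K, V):
    d_k = len(K[0])
    # scores[i][j] = (q_i · k_j) / sqrt(d_k)
    scores = [[sum(qc * kc for qc, kc in zip(q, k)) / math.sqrt(d_k) for k in K]
              for q in Q]
    weights = [softmax(row) for row in scores]
    # output_i = sum_j weights[i][j] * v_j  (a convex combination of the values)
    return [[sum(w * v[c] for w, v in zip(ws, V)) for c in range(len(V[0]))]
            for ws in weights]

Q = [[1.0, 0.0], [0.0, 1.0]]
K = [[1.0, 0.0], [0.0, 1.0]]
V = [[1.0, 2.0], [3.0, 4.0]]
out = attention(Q, K, V)
```

Each output row is a weighted average of the value rows, with weights determined by how well the query matches each key.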
Why Scale?
Consider two vectors $q, k \in \mathbb{R}^{d_k}$ with components drawn independently from $\mathcal{N}(0, 1)$. Their dot product $q \cdot k$ has mean $0$ and variance $d_k$. As $d_k$ grows, the variance increases, so some logits become large in magnitude and the softmax saturates.
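You can verify this empirically. The sketch below (my own check, not from the original derivation) estimates the variance of dot products between random standard-normal vectors at two dimensions:

```python
import random

def dot_product_variance(d, trials=5000, seed=0):
    # Estimate Var(q · k) for q, k with i.i.d. N(0, 1) components.
    rng = random.Random(seed)
    dots = []
    for _ in range(trials):
        q = [rng.gauss(0, 1) for _ in range(d)]
        k = [rng.gauss(0, 1) for _ in range(d)]
        dots.append(sum(a * b for a, b in zip(q, k)))
    mean = sum(dots) / trials
    return sum((x - mean) ** 2 for x in dots) / trials

var_small = dot_product_variance(4)    # should be near 4
var_large = dot_product_variance(64)   # should be near 64
```

The estimates land close to $d_k$, which is exactly why dividing the logits by $\sqrt{d_k}$ keeps their variance near 1 regardless of dimension.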
Multi-Head Attention
Instead of computing a single attention function, we project the queries, keys, and values $h$ times with different learned linear projections:

$$\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \ldots, \text{head}_h)W^O$$

where each head is:

$$\text{head}_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)$$
This allows the model to attend to information from different representation subspaces.
Implementation in Python
Here’s a minimal implementation using PyTorch:
```python
import math

import torch
import torch.nn as nn


class SelfAttention(nn.Module):
    def __init__(self, d_model: int, n_heads: int):
        super().__init__()
        assert d_model % n_heads == 0, "d_model must be divisible by n_heads"
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_model // n_heads
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, T, C = x.shape  # batch size, sequence length, model dimension
        # Project, then split the model dimension into heads: (B, n_heads, T, d_k)
        q = self.W_q(x).view(B, T, self.n_heads, self.d_k).transpose(1, 2)
        k = self.W_k(x).view(B, T, self.n_heads, self.d_k).transpose(1, 2)
        v = self.W_v(x).view(B, T, self.n_heads, self.d_k).transpose(1, 2)
        # Scaled dot-product attention scores: (B, n_heads, T, T)
        attn = (q @ k.transpose(-2, -1)) / math.sqrt(self.d_k)
        attn = torch.softmax(attn, dim=-1)
        # Weighted sum of values, then merge the heads back to (B, T, C)
        out = (attn @ v).transpose(1, 2).contiguous().view(B, T, C)
        return self.W_o(out)
```
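To sanity-check the shapes, here is a short usage sketch. It restates the module above so the snippet runs standalone; the sizes (`d_model=64`, `n_heads=8`, batch of 2, sequence length 10) are illustrative:

```python
import math

import torch
import torch.nn as nn


class SelfAttention(nn.Module):
    # Same module as defined in the post, repeated so this snippet is self-contained.
    def __init__(self, d_model: int, n_heads: int):
        super().__init__()
        assert d_model % n_heads == 0
        self.n_heads = n_heads
        self.d_k = d_model // n_heads
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, T, C = x.shape
        q = self.W_q(x).view(B, T, self.n_heads, self.d_k).transpose(1, 2)
        k = self.W_k(x).view(B, T, self.n_heads, self.d_k).transpose(1, 2)
        v = self.W_v(x).view(B, T, self.n_heads, self.d_k).transpose(1, 2)
        attn = torch.softmax((q @ k.transpose(-2, -1)) / math.sqrt(self.d_k), dim=-1)
        out = (attn @ v).transpose(1, 2).contiguous().view(B, T, C)
        return self.W_o(out)


torch.manual_seed(0)
layer = SelfAttention(d_model=64, n_heads=8)
x = torch.randn(2, 10, 64)   # (batch, sequence length, model dimension)
out = layer(x)
```

The output has the same shape as the input, which is what lets Transformer blocks stack attention layers residually.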
A Quick Bash Example
You can check GPU availability before training:

```bash
nvidia-smi --query-gpu=name,memory.total,memory.free --format=csv
```
Key Takeaways
The beauty of attention is that it allows every position in the sequence to directly attend to every other position, making it far more parallelizable than recurrent architectures.
- Dot-product attention computes relevance scores between all pairs of positions.
- Scaling by $\frac{1}{\sqrt{d_k}}$ keeps the softmax out of saturated regions where gradients vanish.
- Multi-head attention lets the model capture different types of relationships simultaneously.
The Transformer has become the foundation for nearly all modern large language models, from BERT to GPT and beyond.