Understanding the Attention Mechanism

The attention mechanism is the core building block of the Transformer architecture. In this post, we’ll walk through the math and intuition behind self-attention.

Scaled Dot-Product Attention

Given a set of queries $Q$, keys $K$, and values $V$, the attention function is defined as:

$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^\top}{\sqrt{d_k}}\right) V$$

where $d_k$ is the dimension of the key vectors. The scaling factor $\frac{1}{\sqrt{d_k}}$ prevents the dot products from growing too large, which would push the softmax into regions with extremely small gradients.
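To make the formula concrete, here is a minimal single-head sketch in NumPy (array shapes are illustrative; this is not the PyTorch module used later in the post):

```python
import numpy as np

def scaled_dot_product_attention(Q, K, V):
    """Q: (T, d_k), K: (T, d_k), V: (T, d_v) -- single head, no batch."""
    d_k = K.shape[-1]
    scores = Q @ K.T / np.sqrt(d_k)                 # (T, T) relevance scores
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)  # row-wise softmax
    return weights @ V                              # (T, d_v)

rng = np.random.default_rng(0)
Q, K, V = (rng.normal(size=(4, 8)) for _ in range(3))
out = scaled_dot_product_attention(Q, K, V)
print(out.shape)  # (4, 8)
```

Each output row is a convex combination of the rows of `V`, weighted by how strongly the corresponding query matches each key.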

Why Scale?

Consider two vectors $q, k \in \mathbb{R}^{d_k}$ with components drawn independently from $\mathcal{N}(0, 1)$. Their dot product $q \cdot k = \sum_{i=1}^{d_k} q_i k_i$ has mean $0$ and variance $d_k$. As $d_k$ grows, the variance increases, causing the softmax to saturate.
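This variance claim is easy to check empirically. A small simulation (sample count and dimensions chosen arbitrarily for illustration):

```python
import numpy as np

rng = np.random.default_rng(42)
for d_k in (4, 64, 512):
    q = rng.normal(size=(100_000, d_k))
    k = rng.normal(size=(100_000, d_k))
    dots = (q * k).sum(axis=1)        # 100k sampled dot products
    print(d_k, round(dots.var(), 1))  # sample variance is close to d_k
```

At `d_k = 512`, raw dot products routinely reach magnitudes above 40, which is exactly the regime where an unscaled softmax collapses onto a single position.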

Multi-Head Attention

Instead of computing a single attention function, we project the queries, keys, and values $h$ times with different learned linear projections:

$$\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \dots, \text{head}_h) W^O$$

where each head is:

$$\text{head}_i = \text{Attention}(Q W_i^Q, K W_i^K, V W_i^V)$$

This allows the model to attend to information from different representation subspaces.
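In practice, the $h$ projections are usually implemented as one big projection whose output is split along the feature axis. A shape walkthrough in NumPy (sizes are illustrative; this mirrors the `view`/`transpose` trick in the implementation below):

```python
import numpy as np

B, T, d_model, h = 2, 5, 16, 4
d_k = d_model // h                   # 4 dims per head

x = np.zeros((B, T, d_model))        # e.g. the projected queries
heads = x.reshape(B, T, h, d_k)      # split the last axis into h heads
heads = heads.transpose(0, 2, 1, 3)  # (B, h, T, d_k): one sequence per head
print(heads.shape)  # (2, 4, 5, 4)

# After attention, undo the split and concatenate the heads back together:
merged = heads.transpose(0, 2, 1, 3).reshape(B, T, d_model)
print(merged.shape)  # (2, 5, 16)
```

Note that each head sees only a $d_k$-dimensional slice, so the total compute is comparable to one full-width attention.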

Implementation in Python

Here’s a minimal implementation using PyTorch:

```python
import math

import torch
import torch.nn as nn

class SelfAttention(nn.Module):
    def __init__(self, d_model: int, n_heads: int):
        super().__init__()
        assert d_model % n_heads == 0, "d_model must be divisible by n_heads"
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_model // n_heads  # per-head dimension

        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, T, C = x.shape  # batch, sequence length, embedding dim

        # Project, then split into heads: (B, n_heads, T, d_k)
        q = self.W_q(x).view(B, T, self.n_heads, self.d_k).transpose(1, 2)
        k = self.W_k(x).view(B, T, self.n_heads, self.d_k).transpose(1, 2)
        v = self.W_v(x).view(B, T, self.n_heads, self.d_k).transpose(1, 2)

        # Scaled dot-product attention within each head
        attn = (q @ k.transpose(-2, -1)) / math.sqrt(self.d_k)
        attn = torch.softmax(attn, dim=-1)

        # Recombine heads and apply the output projection
        out = (attn @ v).transpose(1, 2).contiguous().view(B, T, C)
        return self.W_o(out)
```
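Recent PyTorch versions (2.0 and later, if available in your environment) also ship a fused kernel, `torch.nn.functional.scaled_dot_product_attention`, that computes the same quantity. A quick check that it matches the manual formula from the forward pass above:

```python
import math

import torch
import torch.nn.functional as F

torch.manual_seed(0)
q = torch.randn(1, 4, 8, 16)  # (batch, heads, seq, d_k)
k = torch.randn(1, 4, 8, 16)
v = torch.randn(1, 4, 8, 16)

# Manual scaled dot-product attention, as in the forward pass above
manual = torch.softmax(q @ k.transpose(-2, -1) / math.sqrt(16), dim=-1) @ v

# PyTorch's fused implementation (added in 2.0)
fused = F.scaled_dot_product_attention(q, k, v)

print(torch.allclose(manual, fused, atol=1e-5))  # True
```

For real workloads the fused version is generally preferable, since it can dispatch to memory-efficient kernels such as FlashAttention.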

A Quick Bash Example

You can check your GPU availability for training:

```bash
nvidia-smi --query-gpu=name,memory.total,memory.free --format=csv
```
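If you'd rather check from Python, PyTorch exposes the same information:

```python
import torch

print(torch.cuda.is_available())  # True if a usable GPU is present
if torch.cuda.is_available():
    print(torch.cuda.get_device_name(0))
```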

Key Takeaways

The beauty of attention is that it allows every position in the sequence to directly attend to every other position, making it far more parallelizable than recurrent architectures.

  1. Dot-product attention computes relevance scores between all pairs of positions.
  2. Scaling by $\frac{1}{\sqrt{d_k}}$ keeps the softmax out of its saturated regime, where gradients vanish.
  3. Multi-head attention lets the model capture different types of relationships simultaneously.

The Transformer has become the foundation for nearly all modern large language models, from BERT to GPT and beyond.