Attention Mechanism
The attention mechanism is the cornerstone of modern transformer architectures. It allows models to dynamically focus on relevant parts of the input when producing each output element.
Scaled Dot-Product Attention
The core attention function computes a weighted sum of values based on compatibility between queries and keys.
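Definition (Scaled Dot-Product Attention):

$$\text{Attention}(Q, K, V) = \text{softmax}\!\left(\frac{QK^\top}{\sqrt{d_k}}\right) V$$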
Where:
- Q (Query): What we are looking for, shape `(n, d_k)`
- K (Key): What we offer to match against, shape `(m, d_k)`
- V (Value): What we actually retrieve, shape `(m, d_v)`
- d_k: Dimensionality of the keys (its square root is the scaling factor)
Dividing by √d_k keeps the dot products from growing too large in magnitude; large scores would push the softmax into regions with extremely small gradients.
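To make the saturation effect concrete, here is a minimal sketch in plain Python (no PyTorch needed). The three score values and the choice d_k = 64 are illustrative assumptions, not numbers from a real model.

```python
import math

def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    shift = max(scores)
    exps = [math.exp(s - shift) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

# Illustrative raw dot products for one query against three keys,
# assuming d_k = 64 so that sqrt(d_k) = 8.
raw_scores = [8.0, 16.0, 24.0]
scaled_scores = [s / math.sqrt(64) for s in raw_scores]  # [1.0, 2.0, 3.0]

p_raw = softmax(raw_scores)
p_scaled = softmax(scaled_scores)

# The derivative of softmax output i with respect to its own score is
# p_i * (1 - p_i); it shrinks toward 0 as the distribution saturates.
grad_raw = [p * (1 - p) for p in p_raw]
grad_scaled = [p * (1 - p) for p in p_scaled]

print("unscaled weights:", [round(p, 4) for p in p_raw])
print("scaled weights:  ", [round(p, 4) for p in p_scaled])
print("largest unscaled gradient term:", max(grad_raw))
print("largest scaled gradient term:  ", max(grad_scaled))
```

Without scaling, nearly all of the weight lands on one key and every gradient term is close to zero. With scaling, the weights stay spread out and the gradient terms remain usable.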
Why Scaling Matters
Without scaling, the dot products tend to have large magnitudes when d_k is large. This causes the softmax to saturate, producing gradients near zero. When the components of q and k are independent with mean 0 and variance 1, the variance of q · k is d_k, so dividing by √d_k normalizes the variance to 1.
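Variance Analysis

$$\operatorname{Var}(q \cdot k) = \sum_{i=1}^{d_k} \operatorname{Var}(q_i k_i) = d_k \quad\Longrightarrow\quad \operatorname{Var}\!\left(\frac{q \cdot k}{\sqrt{d_k}}\right) = 1$$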
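The variance claim can be checked empirically. The sketch below uses only the standard library; the helper name `dot_product_variance`, the sample count, and the chosen d_k values are assumptions made for illustration.

```python
import math
import random
import statistics

random.seed(0)  # fixed seed so the run is repeatable

def dot_product_variance(d_k, num_samples=2000):
    """Estimate Var(q . k) and Var(q . k / sqrt(d_k)) by sampling.

    Components of q and k are drawn independently from N(0, 1).
    """
    dots = []
    for _ in range(num_samples):
        q = [random.gauss(0.0, 1.0) for _ in range(d_k)]
        k = [random.gauss(0.0, 1.0) for _ in range(d_k)]
        dots.append(sum(qi * ki for qi, ki in zip(q, k)))
    raw_var = statistics.pvariance(dots)
    scaled_var = statistics.pvariance([d / math.sqrt(d_k) for d in dots])
    return raw_var, scaled_var

for d_k in (16, 64, 256):
    raw_var, scaled_var = dot_product_variance(d_k)
    print(f"d_k={d_k:4d}  Var(q.k)={raw_var:8.2f}  Var(q.k/sqrt(d_k))={scaled_var:.3f}")
```

The unscaled variance should come out close to d_k in each row, and the scaled variance close to 1, up to sampling noise.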
Self-Attention
In self-attention, Q, K, and V all come from the same sequence. Each token attends to every other token in the sequence, enabling the model to capture dependencies regardless of distance.
```python
import torch
import torch.nn.functional as F


def scaled_dot_product_attention(Q, K, V, mask=None):
    """Return (output, attention_weights) for scaled dot-product attention."""
    d_k = Q.size(-1)
    # Compatibility scores between every query and every key, scaled by sqrt(d_k)
    scores = torch.matmul(Q, K.transpose(-2, -1)) / (d_k ** 0.5)
    if mask is not None:
        # Positions where mask == 0 get -inf, so softmax gives them zero weight
        scores = scores.masked_fill(mask == 0, float('-inf'))
    attention_weights = F.softmax(scores, dim=-1)
    output = torch.matmul(attention_weights, V)
    return output, attention_weights


# Example usage (random tensors stand in for projected queries, keys, and values)
batch_size, seq_len, d_k = 2, 10, 64
Q = torch.randn(batch_size, seq_len, d_k)
K = torch.randn(batch_size, seq_len, d_k)
V = torch.randn(batch_size, seq_len, d_k)
output, weights = scaled_dot_product_attention(Q, K, V)
print(f"Output shape: {output.shape}")    # torch.Size([2, 10, 64])
print(f"Weights shape: {weights.shape}")  # torch.Size([2, 10, 10])
```
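To show what "Q, K, and V all come from the same sequence" means at the smallest scale, here is a sketch in plain Python. It assumes Q = K = V = X with no learned projections, which a real model would add; the three token vectors are hypothetical.

```python
import math

def softmax(row):
    """Numerically stable softmax over a list of floats."""
    shift = max(row)
    exps = [math.exp(v - shift) for v in row]
    total = sum(exps)
    return [e / total for e in exps]

def self_attention(X):
    """Scaled dot-product self-attention with Q = K = V = X.

    Assumption for illustration: no learned projections, so the
    queries, keys, and values are the token vectors themselves.
    """
    d_k = len(X[0])
    weights = []
    for q in X:  # one query per token
        scores = [sum(qi * ki for qi, ki in zip(q, k)) / math.sqrt(d_k) for k in X]
        weights.append(softmax(scores))
    # Each output row is a weighted average of all token vectors.
    output = [
        [sum(w * v[j] for w, v in zip(row, X)) for j in range(d_k)]
        for row in weights
    ]
    return output, weights

# Three hypothetical token embeddings of dimension 2.
X = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
output, weights = self_attention(X)
for row in weights:
    print([round(w, 3) for w in row])
```

Every token produces one row of weights over all tokens, including itself, and each row sums to 1.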
Multi-Head Attention
Multi-head attention runs multiple attention operations in parallel, each with different learned projections, then concatenates and projects the results.
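Definition (Multi-Head Attention):

$$\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \ldots, \text{head}_h)\, W^O$$

Definition (Individual Head):

$$\text{head}_i = \text{Attention}(Q W_i^Q,\; K W_i^K,\; V W_i^V)$$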
Where:
- h: Number of attention heads
- W_i^Q, W_i^K: Projection matrices of shape `(d_model, d_k)`, where d_k = d_model / h
- W_i^V: Value projection matrix of shape `(d_model, d_v)`
- W^O: Output projection matrix of shape `(h·d_v, d_model)`
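As a quick worked example of these shapes, the sketch below plugs in d_model = 512 and h = 8. The helper `multi_head_shapes` is hypothetical and assumes d_v = d_k, a common choice that the definitions above do not require.

```python
def multi_head_shapes(d_model, h):
    """Return projection shapes for multi-head attention.

    Assumes d_v = d_k = d_model / h.
    """
    if d_model % h != 0:
        raise ValueError("d_model must be divisible by h")
    d_k = d_model // h
    d_v = d_k
    return {
        "W_i^Q": (d_model, d_k),
        "W_i^K": (d_model, d_k),
        "W_i^V": (d_model, d_v),
        "W^O": (h * d_v, d_model),
    }

h = 8
shapes = multi_head_shapes(d_model=512, h=h)
for name, shape in shapes.items():
    print(f"{name}: {shape}")

# Weight count (biases ignored): h heads x 3 projections, plus W^O.
per_head = sum(rows * cols for name, (rows, cols) in shapes.items() if name != "W^O")
total = h * per_head + shapes["W^O"][0] * shapes["W^O"][1]
print("total projection weights:", total)  # 4 * 512 * 512 = 1048576
```

Under this assumption the total weight count does not depend on h: splitting into more heads makes each head narrower rather than adding parameters.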
Complete Multi-Head Attention Implementation
```python
import math

import torch
import torch.nn as nn


class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads, dropout=0.1):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        # Each Linear holds the projections for all heads, stacked side by side
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, query, key, value, mask=None):
        batch_size = query.size(0)

        # Linear projections and reshape to (batch, heads, seq_len, d_k)
        Q = self.W_q(query).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        K = self.W_k(key).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        V = self.W_v(value).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)

        # Scaled dot-product attention
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        if mask is not None:
            # mask must broadcast to (batch, heads, query_len, key_len)
            scores = scores.masked_fill(mask == 0, float('-inf'))
        attn_weights = torch.softmax(scores, dim=-1)
        attn_weights = self.dropout(attn_weights)

        # Apply attention to values
        context = torch.matmul(attn_weights, V)

        # Concatenate heads and project
        context = context.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)
        output = self.W_o(context)
        return output, attn_weights


# Usage
mha = MultiHeadAttention(d_model=512, num_heads=8)
x = torch.randn(2, 10, 512)   # (batch, seq_len, d_model)
out, weights = mha(x, x, x)   # Self-attention: query, key, and value are the same tensor
print(f"Output: {out.shape}, Weights: {weights.shape}")
# Output: torch.Size([2, 10, 512]), Weights: torch.Size([2, 8, 10, 10])
```
Causal (Masked) Attention
For autoregressive models like GPT, we mask future positions so each token can only attend to itself and previous tokens.
```python
import torch


def create_causal_mask(seq_len):
    # Lower-triangular matrix: row i has ones at positions 0..i (allowed) and zeros after
    mask = torch.tril(torch.ones(seq_len, seq_len))
    return mask.unsqueeze(0).unsqueeze(0)  # (1, 1, seq_len, seq_len)


mask = create_causal_mask(5)
print(mask)
# tensor([[[[1., 0., 0., 0., 0.],
#           [1., 1., 0., 0., 0.],
#           [1., 1., 1., 0., 0.],
#           [1., 1., 1., 1., 0.],
#           [1., 1., 1., 1., 1.]]]])
```
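To see what such a mask does to the attention weights, here is a small sketch in plain Python. The helper `causal_softmax` is hypothetical, and the all-zero scores are an assumption chosen so that every allowed position receives equal weight.

```python
import math

def causal_softmax(scores):
    """Row-wise softmax where position i may only attend to positions <= i."""
    weights = []
    for i, row in enumerate(scores):
        # Future positions get -inf, so exp(-inf) = 0 after the softmax.
        masked = [s if j <= i else float("-inf") for j, s in enumerate(row)]
        shift = max(masked)
        exps = [math.exp(s - shift) for s in masked]
        total = sum(exps)
        weights.append([e / total for e in exps])
    return weights

# Hypothetical scores: all zeros, so every allowed position gets equal weight.
scores = [[0.0] * 4 for _ in range(4)]
for row in causal_softmax(scores):
    print([round(w, 2) for w in row])
# [1.0, 0.0, 0.0, 0.0]
# [0.5, 0.5, 0.0, 0.0]
# [0.33, 0.33, 0.33, 0.0]
# [0.25, 0.25, 0.25, 0.25]
```

The first token can only attend to itself, while the last token spreads its weight over the whole prefix. No weight ever lands above the diagonal.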
Attention Variants Comparison
| Variant | Complexity | Use Case | Key Feature |
|---|---|---|---|
| Scaled Dot-Product | O(n²·d) | General | Standard attention |
| Causal/Masked | O(n²·d) | Autoregressive | No future information |
| Cross-Attention | O(n·m·d) | Encoder-Decoder | Different Q and K, V |
| Linear Attention | O(n·d²) | Long sequences | Kernel approximation |
| Sparse Attention | O(n·√n) | Long sequences | Local + global patterns |
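To get a feel for how far apart these complexity classes are, the sketch below evaluates the table's big-O terms for one set of sizes. The sizes are hypothetical, constants are dropped, and the results are rough operation counts rather than measured costs.

```python
import math

def attention_cost(n, d, m=None):
    """Rough operation counts from the big-O terms in the table (constants dropped)."""
    m = n if m is None else m
    return {
        "scaled_dot_product": n * n * d,
        "cross_attention": n * m * d,
        "linear_attention": n * d * d,
        # Follows the table entry O(n * sqrt(n)), which omits the d factor.
        "sparse_attention": round(n * math.sqrt(n)),
    }

# Hypothetical sizes: 4096 query tokens, head dimension 64, 512 key/value tokens.
costs = attention_cost(n=4096, d=64, m=512)
for name, ops in costs.items():
    print(f"{name:20s} {ops:>15,d}")
```

With these sizes the linear variant is far cheaper than the quadratic one because d is much smaller than n. The ordering flips when the sequence is shorter than the head dimension.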
[Figure: Attention Score Computation]
Key Takeaways
- Attention is all you need: in the transformer it replaces recurrence and convolutions entirely
- Scaling by 1/√d_k prevents vanishing gradients in the softmax
- Multi-head attention captures different types of relationships
- Self-attention enables direct modeling of pairwise token relationships
- Masking controls information flow (causal for generation, padding for batching)