Attention Mechanism

The attention mechanism is the cornerstone of modern transformer architectures. It allows models to dynamically focus on relevant parts of the input when producing each output element.

Scaled Dot-Product Attention

The core attention function computes a weighted sum of values based on compatibility between queries and keys.

Definition: Scaled Dot-Product Attention

$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^\top}{\sqrt{d_k}}\right) V$$

Where:

  • Q (Query): What we are looking for; shape (n, d_k)
  • K (Key): What we offer to match against; shape (m, d_k)
  • V (Value): What we actually retrieve; shape (m, d_v)
  • d_k: Dimensionality of the queries and keys; its square root √d_k is the scaling factor
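
The shapes listed above can be traced through a small sketch. The sizes below are illustrative assumptions, chosen so that n differs from m and d_v differs from d_k, which makes it easy to see which dimension ends up where.

```python
import torch

n, m, d_k, d_v = 3, 5, 8, 4   # illustrative sizes, not from the lesson
Q = torch.randn(n, d_k)       # n queries
K = torch.randn(m, d_k)       # m keys
V = torch.randn(m, d_v)       # m values, one per key

scores = Q @ K.T / d_k ** 0.5            # (n, m): one score per query-key pair
weights = torch.softmax(scores, dim=-1)  # each row is a distribution over the m keys
output = weights @ V                     # (n, d_v): one weighted sum of values per query

print(scores.shape, weights.shape, output.shape)
```

Each query gets back one vector of size d_v, regardless of how many keys and values it attended over.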

The scaling factor √d_k prevents the dot products from growing too large in magnitude, which would push the softmax into regions with extremely small gradients.

Why Scaling Matters

Without scaling, the dot products tend to have large magnitudes when d_k is large. This causes the softmax to saturate, producing gradients near zero. If the components of q and k are independent with mean 0 and variance 1, the variance of q · k is exactly d_k, so dividing by √d_k normalizes the variance to 1.
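
To make the saturation effect concrete, the sketch below compares the softmax of unscaled and scaled scores for one random query. The dimension, the number of keys, and the seed are illustrative assumptions; the exact numbers will differ for other draws, but the unscaled distribution is never less peaked than the scaled one.

```python
import torch

torch.manual_seed(0)
d_k = 512                    # illustrative key dimension
q = torch.randn(d_k)
K = torch.randn(8, d_k)      # 8 keys to attend over

raw = K @ q                  # unscaled scores, spread on the order of sqrt(d_k)
scaled = raw / d_k ** 0.5    # scaled scores, spread on the order of 1

p_raw = torch.softmax(raw, dim=-1)
p_scaled = torch.softmax(scaled, dim=-1)

# Diagonal of the softmax Jacobian: d p_i / d s_i = p_i * (1 - p_i).
# Its sum shrinks toward zero as the distribution approaches one-hot.
grad_raw = (p_raw * (1 - p_raw)).sum()
grad_scaled = (p_scaled * (1 - p_scaled)).sum()

print("largest weight:", p_raw.max().item(), "vs", p_scaled.max().item())
print("Jacobian diagonal sum:", grad_raw.item(), "vs", grad_scaled.item())
```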

Variance Analysis

$$\text{Var}(q \cdot k) = \sum_{i=1}^{d_k} \text{Var}(q_i k_i) = d_k$$
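
The variance claim can be checked empirically. The following is a rough Monte Carlo sketch, assuming standard normal components for q and k; the sample count and the list of dimensions are arbitrary choices.

```python
import torch

torch.manual_seed(0)
num_samples = 20_000                     # illustrative sample count

results = {}
for d_k in (16, 64, 256):
    q = torch.randn(num_samples, d_k)    # components with mean 0, variance 1
    k = torch.randn(num_samples, d_k)
    dots = (q * k).sum(dim=-1)           # one dot product per sample
    raw_var = dots.var().item()
    scaled_var = (dots / d_k ** 0.5).var().item()
    results[d_k] = (raw_var, scaled_var)
    print(f"d_k={d_k}: Var(q.k)={raw_var:.1f}, after scaling={scaled_var:.3f}")
```

The first number should land close to d_k and the second close to 1, up to sampling noise.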

Self-Attention

In self-attention, Q, K, and V are all computed from the same sequence. Each token attends to every token in the sequence, including itself, which lets the model capture dependencies regardless of distance.

import torch
import torch.nn.functional as F

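def scaled_dot_product_attention(Q, K, V, mask=None):
    """Return (output, attention_weights) for inputs shaped (..., seq_len, dim)."""
    d_k = Q.size(-1)
    # Compatibility score for every query-key pair, scaled by sqrt(d_k)
    scores = torch.matmul(Q, K.transpose(-2, -1)) / (d_k ** 0.5)

    if mask is not None:
        # Positions where mask == 0 become -inf, so softmax gives them zero weight
        scores = scores.masked_fill(mask == 0, float('-inf'))

    attention_weights = F.softmax(scores, dim=-1)  # each row sums to 1
    output = torch.matmul(attention_weights, V)    # weighted sum of values
    return output, attention_weights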

# Example usage
batch_size, seq_len, d_k = 2, 10, 64
Q = torch.randn(batch_size, seq_len, d_k)
K = torch.randn(batch_size, seq_len, d_k)
V = torch.randn(batch_size, seq_len, d_k)

output, weights = scaled_dot_product_attention(Q, K, V)
print(f"Output shape: {output.shape}")   # (2, 10, 64)
print(f"Weights shape: {weights.shape}") # (2, 10, 10)
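
The usage example above draws Q, K, and V independently. In self-attention proper, all three are projections of one input sequence. The sketch below illustrates that; the layer names `to_q`, `to_k`, and `to_v` are hypothetical, and the weights are random stand-ins for learned ones.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
batch_size, seq_len, d_model = 2, 10, 64        # illustrative sizes

x = torch.randn(batch_size, seq_len, d_model)   # one input sequence per batch item

# Hypothetical projection layers; in a trained model these weights are learned.
to_q = nn.Linear(d_model, d_model, bias=False)
to_k = nn.Linear(d_model, d_model, bias=False)
to_v = nn.Linear(d_model, d_model, bias=False)

Q, K, V = to_q(x), to_k(x), to_v(x)             # all three derived from the same x

scores = Q @ K.transpose(-2, -1) / d_model ** 0.5
weights = F.softmax(scores, dim=-1)             # (batch, seq_len, seq_len)
output = weights @ V                            # (batch, seq_len, d_model)

print(output.shape, weights.shape)
```

Row i of `weights` describes how much token i draws from every token in the same sequence, itself included.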

Multi-Head Attention

Multi-head attention runs multiple attention operations in parallel, each with different learned projections, then concatenates and projects the results.

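Definition: Multi-Head Attention

$$\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \ldots, \text{head}_h)\, W^O$$

Definition: Individual Head

$$\text{head}_i = \text{Attention}(Q W_i^Q, K W_i^K, V W_i^V)$$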

Where:

  • h: Number of attention heads
  • W_i^Q, W_i^K: Projection matrices of shape (d_model, d_k), where d_k = d_model / h
  • W_i^V: Value projection matrix of shape (d_model, d_v), where typically d_v = d_k
  • W^O: Output projection matrix of shape (h · d_v, d_model)
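
The per-head projections listed above can be written out as an explicit loop, one head at a time. This is a sketch for tracing shapes only: the matrices are random stand-ins for learned weights, the sizes are assumptions, and practical implementations usually fuse the h projections into a single matrix multiply.

```python
import torch

torch.manual_seed(0)
n, d_model, h = 6, 32, 4             # illustrative sizes
d_k = d_v = d_model // h             # common choice: d_k = d_v = d_model / h

X = torch.randn(n, d_model)          # self-attention: queries, keys, values all come from X

# Random stand-ins for the learned matrices W_i^Q, W_i^K, W_i^V and W^O.
W_q = [torch.randn(d_model, d_k) for _ in range(h)]
W_k = [torch.randn(d_model, d_k) for _ in range(h)]
W_v = [torch.randn(d_model, d_v) for _ in range(h)]
W_o = torch.randn(h * d_v, d_model)

def attention(Q, K, V):
    # Scaled dot-product attention for unbatched 2-D inputs.
    scores = Q @ K.T / K.shape[-1] ** 0.5
    return torch.softmax(scores, dim=-1) @ V

# One attention call per head, each producing an (n, d_v) result.
heads = [attention(X @ W_q[i], X @ W_k[i], X @ W_v[i]) for i in range(h)]
concat = torch.cat(heads, dim=-1)    # (n, h * d_v)
output = concat @ W_o                # (n, d_model)

print(concat.shape, output.shape)
```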

Complete Multi-Head Attention Implementation

import torch
import torch.nn as nn
import math

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads, dropout=0.1):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"

        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads

        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

        self.dropout = nn.Dropout(dropout)

    def forward(self, query, key, value, mask=None):
        batch_size = query.size(0)

        # Linear projections and reshape to (batch, heads, seq_len, d_k)
        Q = self.W_q(query).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        K = self.W_k(key).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        V = self.W_v(value).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)

        # Scaled dot-product attention
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)

        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))

        attn_weights = torch.softmax(scores, dim=-1)
        attn_weights = self.dropout(attn_weights)

        # Apply attention to values
        context = torch.matmul(attn_weights, V)

        # Concatenate heads and project
        context = context.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)
        output = self.W_o(context)

        return output, attn_weights

# Usage
mha = MultiHeadAttention(d_model=512, num_heads=8)
x = torch.randn(2, 10, 512)  # (batch, seq_len, d_model)
out, weights = mha(x, x, x)  # Self-attention
print(f"Output: {out.shape}, Weights: {weights.shape}")
# Output: torch.Size([2, 10, 512]), Weights: torch.Size([2, 8, 10, 10])
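
For comparison, PyTorch ships a built-in `nn.MultiheadAttention` module. The sketch below only checks that it produces the same output shapes as the class above; it assumes a PyTorch release recent enough to support the `batch_first` and `average_attn_weights` arguments, and it does not share weights with the hand-written module, so the values will differ.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
builtin = nn.MultiheadAttention(embed_dim=512, num_heads=8, batch_first=True)
builtin.eval()                       # inference mode, no dropout

x = torch.randn(2, 10, 512)          # (batch, seq_len, d_model)
with torch.no_grad():
    # average_attn_weights=False keeps one weight matrix per head
    out, weights = builtin(x, x, x, average_attn_weights=False)

print(out.shape, weights.shape)
```

With the default `average_attn_weights=True`, the returned weights are averaged over heads and the head dimension disappears.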

Causal (Masked) Attention

For autoregressive models like GPT, we mask future positions so each token can only attend to itself and previous tokens.

import torch

def create_causal_mask(seq_len):
    # Lower-triangular matrix: 1 = may attend, 0 = masked (future position)
    mask = torch.tril(torch.ones(seq_len, seq_len))
    return mask.unsqueeze(0).unsqueeze(0)  # (1, 1, seq_len, seq_len)

mask = create_causal_mask(5)
print(mask)
# tensor([[[[1., 0., 0., 0., 0.],
#           [1., 1., 0., 0., 0.],
#           [1., 1., 1., 0., 0.],
#           [1., 1., 1., 1., 0.],
#           [1., 1., 1., 1., 1.]]]])
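
To see the mask take effect, the sketch below applies a lower-triangular mask to random scores and inspects the resulting weights. The tensor sizes are illustrative assumptions, and a single batch item with a single head keeps the output easy to read.

```python
import torch

torch.manual_seed(0)
seq_len, d_k = 5, 16                        # illustrative sizes
Q = torch.randn(1, 1, seq_len, d_k)         # (batch, heads, seq_len, d_k)
K = torch.randn(1, 1, seq_len, d_k)

# Same lower-triangular mask as above, shaped (1, 1, seq_len, seq_len)
mask = torch.tril(torch.ones(seq_len, seq_len)).unsqueeze(0).unsqueeze(0)

scores = Q @ K.transpose(-2, -1) / d_k ** 0.5
scores = scores.masked_fill(mask == 0, float("-inf"))   # hide future positions
weights = torch.softmax(scores, dim=-1)

print(weights[0, 0])
```

Everything above the diagonal is exactly zero, and the first token can only attend to itself, so its single weight is 1.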

Attention Variants Comparison

Variant            | Complexity | Use Case        | Key Feature
Scaled Dot-Product | O(n²·d)    | General         | Standard attention
Causal/Masked      | O(n²·d)    | Autoregressive  | No future information
Cross-Attention    | O(n·m·d)   | Encoder-Decoder | Different Q and K, V
Linear Attention   | O(n·d²)    | Long sequences  | Kernel approximation
Sparse Attention   | O(n·√n)    | Long sequences  | Local + global patterns

Here n is the query sequence length, m is the key/value sequence length, and d is the feature dimension.
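
The linear-attention row can be illustrated with a small sketch. It assumes one particular feature map, elu(x) + 1, as the kernel; this is an illustrative choice, not something the lesson specifies. The kernel replaces the softmax rather than reproducing it, so the result is not numerically equal to standard attention. What the code checks is that reordering the matrix products gives the same answer as the quadratic form for that kernel, without ever building an (n, n) matrix.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
n, d = 128, 16                       # illustrative sizes
Q, K, V = torch.randn(n, d), torch.randn(n, d), torch.randn(n, d)

def phi(x):
    # Assumed feature map, elu(x) + 1, which keeps every feature positive.
    return F.elu(x) + 1

# Quadratic form: builds the full (n, n) similarity matrix, O(n^2 * d).
sim = phi(Q) @ phi(K).T
quadratic = (sim / sim.sum(dim=-1, keepdim=True)) @ V

# Linear form: reassociates the product so only (d, d) state is needed, O(n * d^2).
kv = phi(K).T @ V                    # (d, d) summary of keys and values
z = phi(K).sum(dim=0)                # (d,) normalizer
linear = (phi(Q) @ kv) / (phi(Q) @ z).unsqueeze(-1)

print(torch.allclose(quadratic, linear, atol=1e-4))
```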

Key Takeaways

  1. Attention is all you need: in the transformer it replaces recurrence and convolutions entirely
  2. Scaling by √d_k keeps the softmax out of saturated regions where gradients vanish
  3. Multi-head attention captures different types of relationships
  4. Self-attention enables direct modeling of pairwise token relationships
  5. Masking controls information flow (causal for generation, padding for batching)
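
The padding case mentioned in the last takeaway can be sketched the same way as the causal one. The token ids and the padding id of 0 below are made-up assumptions for illustration; real tokenizers define their own padding id.

```python
import torch

torch.manual_seed(0)
pad_id = 0                                        # assumed padding token id
token_ids = torch.tensor([[5, 7, 9, 0, 0],        # 3 real tokens, 2 pads
                          [4, 8, 6, 3, 2]])       # 5 real tokens
batch_size, seq_len = token_ids.shape
d_k = 8                                           # illustrative size

# True where a real token sits; shaped (batch, 1, 1, seq_len) so it
# broadcasts over heads and query positions.
padding_mask = (token_ids != pad_id).unsqueeze(1).unsqueeze(2)

Q = torch.randn(batch_size, 1, seq_len, d_k)      # single head for brevity
K = torch.randn(batch_size, 1, seq_len, d_k)

scores = Q @ K.transpose(-2, -1) / d_k ** 0.5
scores = scores.masked_fill(~padding_mask, float("-inf"))   # hide padded keys
weights = torch.softmax(scores, dim=-1)

print(weights[0, 0])   # last two columns are zero: no query attends to padding
```

This mask hides padded keys only. Rows belonging to padded query positions still produce outputs, which are typically ignored downstream, for example when computing the loss.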
