Attention Mechanism: Scaled Dot-Product, Cross-Attention, Flash Attention — Asked at OpenAI & Anthropic


Attention Mechanism: From Scaled Dot-Product to Flash Attention

Premium Interview Preparation — Attention Mechanism Mastery

🎯 The Interview Question

"Explain the attention mechanism in detail, including the mathematical formulation of scaled dot-product attention. What is cross-attention and how does it differ from self-attention? Describe Flash Attention and how it achieves the same results as standard attention but with better memory efficiency. What are the recent advances in efficient attention mechanisms?"

This question tests deep understanding of the mechanism that powers modern AI — essential for roles at OpenAI and Anthropic.


📚 Detailed Answer

The Attention Mechanism: Intuition

Attention is a mechanism for dynamically weighting information based on relevance. Given a query and a set of key-value pairs, attention computes a weighted sum of values, where weights are determined by query-key compatibility.

Intuition: When reading a sentence, you "attend" to relevant words to understand context. "The cat sat on the mat" — to understand "sat", you attend to "cat" (subject) and "mat" (location).

Scaled Dot-Product Attention: Mathematical Formulation

Given:

  • Query $\mathbf{Q} \in \mathbb{R}^{n \times d_k}$
  • Key $\mathbf{K} \in \mathbb{R}^{m \times d_k}$
  • Value $\mathbf{V} \in \mathbb{R}^{m \times d_v}$

The attention output is:

$$\text{Attention}(\mathbf{Q}, \mathbf{K}, \mathbf{V}) = \text{softmax}\left(\frac{\mathbf{Q}\mathbf{K}^T}{\sqrt{d_k}}\right)\mathbf{V}$$

Step-by-step:

  1. Compute compatibility scores: $\mathbf{S} = \mathbf{Q}\mathbf{K}^T \in \mathbb{R}^{n \times m}$

    • $S_{ij} = \mathbf{q}_i^T \mathbf{k}_j$ measures the similarity between query $i$ and key $j$
  2. Scale: $\mathbf{S}' = \frac{\mathbf{S}}{\sqrt{d_k}}$

    • Prevents large values that cause softmax saturation
  3. Normalize: $\mathbf{A} = \text{softmax}(\mathbf{S}')$

    • Row-wise softmax ensures each row of weights sums to 1
  4. Aggregate: $\text{Output} = \mathbf{A}\mathbf{V}$

    • Weighted sum of values
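The four steps above map one-to-one onto a few lines of code. The sketch below is a minimal pure-Python illustration (matrices are lists of rows). The toy `Q`, `K`, `V` values are made up for the example, and the helper names `matmul`, `transpose`, `softmax` are hypothetical, not from any library.

```python
import math

def matmul(a, b):
    # (n x k) @ (k x m) for matrices stored as lists of rows
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]

def transpose(a):
    return [list(col) for col in zip(*a)]

def softmax(row):
    # Subtract the row max before exponentiating for numerical stability
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]

def scaled_dot_product_attention(Q, K, V):
    d_k = len(Q[0])
    S = matmul(Q, transpose(K))                                  # 1. scores, n x m
    S_scaled = [[s / math.sqrt(d_k) for s in row] for row in S]  # 2. scale by sqrt(d_k)
    A = [softmax(row) for row in S_scaled]                       # 3. row-wise softmax
    return matmul(A, V), A                                       # 4. weighted sum of values

# Toy example: n = 2 queries, m = 3 keys/values, d_k = d_v = 2
Q = [[1.0, 0.0], [0.0, 1.0]]
K = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
V = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
out, A = scaled_dot_product_attention(Q, K, V)
print(len(out), len(out[0]))            # 2 2
print([round(sum(r), 6) for r in A])    # [1.0, 1.0]
```

The output has shape $n \times d_v$, and every row of $\mathbf{A}$ is a probability distribution over the $m$ keys.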

💡

The scaling factor $\sqrt{d_k}$ is crucial. Without it, dot products grow in magnitude as $d_k$ increases, pushing the softmax into regions with vanishing gradients. If the components of $\mathbf{q}$ and $\mathbf{k}$ are independent with mean 0 and variance 1, their dot product has mean 0 and variance $d_k$, so dividing by $\sqrt{d_k}$ normalizes the variance to 1.
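The variance claim follows in two lines, assuming (as is standard for this argument) that all components $q_i, k_i$ are independent with mean 0 and variance 1. The products $q_i k_i$ are then independent with mean 0, so their variances add:

```latex
\operatorname{Var}(q_i k_i) = \mathbb{E}[q_i^2]\,\mathbb{E}[k_i^2] - \left(\mathbb{E}[q_i]\,\mathbb{E}[k_i]\right)^2 = 1 \cdot 1 - 0 = 1

\operatorname{Var}\left(\mathbf{q}^T\mathbf{k}\right) = \sum_{i=1}^{d_k} \operatorname{Var}(q_i k_i) = d_k
\qquad\Longrightarrow\qquad
\operatorname{Var}\left(\frac{\mathbf{q}^T\mathbf{k}}{\sqrt{d_k}}\right) = \frac{d_k}{d_k} = 1
```

For example, with $d_k = 64$ the unscaled scores have standard deviation 8, so differences of tens between logits are common and the softmax becomes nearly one-hot.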

Types of Attention

Self-Attention

Queries, keys, and values all come from the same sequence:

$$\mathbf{Q} = \mathbf{X}\mathbf{W}^Q, \quad \mathbf{K} = \mathbf{X}\mathbf{W}^K, \quad \mathbf{V} = \mathbf{X}\mathbf{W}^V$$

Each token attends to every token in the same sequence, including itself. Used in both the encoder and the decoder of Transformers (in the decoder it is causally masked).
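A minimal sketch of self-attention, assuming random (untrained) projection matrices `W_Q`, `W_K`, `W_V` and small hypothetical sizes chosen only for illustration. The point is that all three of $\mathbf{Q}$, $\mathbf{K}$, $\mathbf{V}$ are projections of the same $\mathbf{X}$.

```python
import math
import random

def matmul(a, b):
    # (n x k) @ (k x m) for matrices stored as lists of rows
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]

def transpose(a):
    return [list(col) for col in zip(*a)]

def softmax(row):
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]

def attention(Q, K, V):
    d_k = len(Q[0])
    scores = matmul(Q, transpose(K))
    A = [softmax([s / math.sqrt(d_k) for s in row]) for row in scores]
    return matmul(A, V)

def random_matrix(rows, cols, rng):
    return [[rng.gauss(0.0, 1.0) for _ in range(cols)] for _ in range(rows)]

rng = random.Random(0)
n, d_model, d_k, d_v = 4, 8, 4, 4          # hypothetical sizes
X = random_matrix(n, d_model, rng)         # one sequence of n token embeddings
W_Q = random_matrix(d_model, d_k, rng)     # hypothetical projection weights (random, untrained)
W_K = random_matrix(d_model, d_k, rng)
W_V = random_matrix(d_model, d_v, rng)

# Self-attention: queries, keys and values are all projections of the same X
Q, K, V = matmul(X, W_Q), matmul(X, W_K), matmul(X, W_V)
out = attention(Q, K, V)
print(len(out), len(out[0]))  # 4 4
```

Because each output row is a convex combination of the rows of $\mathbf{V}$, every output coordinate lies between the smallest and largest value in the corresponding column of $\mathbf{V}$.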

Cross-Attention

Queries come from one sequence, keys and values from another:

$$\mathbf{Q} = \mathbf{X}_{dec}\mathbf{W}^Q, \quad \mathbf{K} = \mathbf{X}_{enc}\mathbf{W}^K, \quad \mathbf{V} = \mathbf{X}_{enc}\mathbf{W}^V$$

Used in encoder-decoder models (e.g., T5, BART) to attend to the encoded input.
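The only change from self-attention is where $\mathbf{K}$ and $\mathbf{V}$ come from, which changes the shape of the attention matrix: it becomes $n_{dec} \times n_{enc}$ rather than square. A minimal sketch, assuming random hypothetical weights and made-up sequence lengths (3 decoder tokens, 5 encoder tokens):

```python
import math
import random

def matmul(a, b):
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]

def transpose(a):
    return [list(col) for col in zip(*a)]

def softmax(row):
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]

def attention(Q, K, V):
    d_k = len(Q[0])
    A = [softmax([s / math.sqrt(d_k) for s in row]) for row in matmul(Q, transpose(K))]
    return matmul(A, V), A

def random_matrix(rows, cols, rng):
    return [[rng.gauss(0.0, 1.0) for _ in range(cols)] for _ in range(rows)]

rng = random.Random(1)
n_dec, n_enc, d_model, d_k = 3, 5, 8, 4
X_dec = random_matrix(n_dec, d_model, rng)   # decoder states (hypothetical)
X_enc = random_matrix(n_enc, d_model, rng)   # encoder output (hypothetical)
W_Q, W_K, W_V = (random_matrix(d_model, d_k, rng) for _ in range(3))

# Cross-attention: queries from the decoder, keys and values from the encoder
out, A = attention(matmul(X_dec, W_Q), matmul(X_enc, W_K), matmul(X_enc, W_V))
print(len(A), len(A[0]))  # 3 5: one distribution over encoder tokens per decoder token
```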

Causal (Masked) Attention

Prevents tokens from attending to future positions:

$$\text{Mask}_{ij} = \begin{cases} 0 & \text{if } j \leq i \\ -\infty & \text{if } j > i \end{cases}$$

$$\text{Attention} = \text{softmax}\left(\frac{\mathbf{Q}\mathbf{K}^T}{\sqrt{d_k}} + \text{Mask}\right)\mathbf{V}$$

Essential for autoregressive generation (GPT-style models).
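A minimal sketch of the masked formula above. For brevity it assumes identity projections (so $\mathbf{Q} = \mathbf{K} = \mathbf{V} = \mathbf{X}$) and a made-up 4-token input. Adding $-\infty$ before the softmax makes $e^{-\infty} = 0$, so future positions receive exactly zero weight.

```python
import math

def matmul(a, b):
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]

def transpose(a):
    return [list(col) for col in zip(*a)]

def softmax(row):
    m = max(row)                           # finite: position j = i is never masked
    exps = [math.exp(x - m) for x in row]  # exp(-inf) == 0.0 for masked entries
    total = sum(exps)
    return [e / total for e in exps]

def causal_mask(n):
    # Mask_ij = 0 if j <= i, -inf if j > i
    return [[0.0 if j <= i else float("-inf") for j in range(n)] for i in range(n)]

def causal_attention(Q, K, V):
    n, d_k = len(Q), len(Q[0])
    scores = matmul(Q, transpose(K))
    A = [softmax([s / math.sqrt(d_k) + mk for s, mk in zip(srow, mrow)])
         for srow, mrow in zip(scores, causal_mask(n))]
    return matmul(A, V), A

# Assumption: identity projections, so Q = K = V = X
X = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [0.5, -0.5]]
out, A = causal_attention(X, X, X)
for row in A:
    print([round(w, 3) for w in row])      # upper triangle is all zeros
```

The first token can only attend to itself, so its weight row is $[1, 0, 0, 0]$.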

Flash Attention: Memory-Efficient Exact Attention

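Standard attention materializes the full $n \times n$ attention matrix, requiring $O(n^2)$ memory. Flash Attention computes the same output (up to floating-point rounding) without ever storing that matrix, so its extra memory beyond the inputs and output is only $O(n)$: a running maximum and a running sum per query row.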

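Key Insight: The softmax can be computed in a streaming fashion using the log-sum-exp (max-subtraction) trick. Subtracting any constant $m$ from every score leaves the softmax unchanged:

$$\text{softmax}(x_i) = \frac{e^{x_i}}{\sum_j e^{x_j}} = \frac{e^{x_i - m}}{\sum_j e^{x_j - m}}$$

where $m = \max_j x_j$ is the running maximum over the scores seen so far. When a new block raises the maximum from $m$ to $m'$, the partial sum (and partial output) accumulated so far is rescaled by $e^{m - m'}$, so earlier blocks never need to be revisited.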
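A minimal sketch of the streaming (online) softmax, assuming a made-up score vector and an arbitrary block size of 2. It keeps only the running maximum `m` and running denominator `l`, and rescales `l` whenever the maximum grows. The final normalization line reuses the scores purely to compare against an ordinary two-pass softmax; in Flash Attention that normalization is folded into the output accumulator instead.

```python
import math

def softmax_two_pass(xs):
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

def online_softmax_stats(xs, block_size):
    m = float("-inf")  # running maximum of the scores seen so far
    l = 0.0            # running sum of exp(x - m) over the scores seen so far
    for start in range(0, len(xs), block_size):
        block = xs[start:start + block_size]
        m_new = max(m, max(block))
        # Rescale the old sum from max m to max m_new, then add this block's terms
        l = l * math.exp(m - m_new) + sum(math.exp(x - m_new) for x in block)
        m = m_new
    return m, l

scores = [1.0, 3.0, -2.0, 5.0, 0.5]
m, l = online_softmax_stats(scores, block_size=2)
streamed = [math.exp(x - m) / l for x in scores]
print(max(abs(a - b) for a, b in zip(streamed, softmax_two_pass(scores))))  # ~0
```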

Algorithm:

  1. Divide Q, K, V into blocks small enough to fit in on-chip SRAM
  2. For each block of Q:
    • Load the Q block, then each block of K, V in turn, into fast SRAM
    • Compute block-wise attention scores
    • Update the output using online softmax
  3. The full $n \times n$ matrix is never materialized

Memory savings: $O(n)$ extra memory vs. $O(n^2)$
Speedup: 2-4× on modern GPUs, because far fewer reads and writes go to slow GPU high-bandwidth memory (HBM)
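The block loop above can be sketched in plain Python to show that tiling plus online softmax reproduces standard attention. This is only a single-threaded illustration of the idea, not the fused GPU kernel: "SRAM" is just a slice of a list here, and the block sizes `block_q=2`, `block_kv=3` and the random inputs are arbitrary assumptions.

```python
import math
import random

def standard_attention(Q, K, V):
    # Reference: computes every score of a query row at once
    d = len(Q[0])
    out = []
    for q in Q:
        s = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d) for k in K]
        m = max(s)
        w = [math.exp(x - m) for x in s]
        z = sum(w)
        out.append([sum(wj * v[c] for wj, v in zip(w, V)) / z for c in range(len(V[0]))])
    return out

def tiled_attention(Q, K, V, block_q=2, block_kv=3):
    # Scores exist only for one K/V block at a time; the full n x m matrix is never stored
    d, d_v = len(Q[0]), len(V[0])
    out = []
    for qs in range(0, len(Q), block_q):               # outer loop over Q blocks
        Q_blk = Q[qs:qs + block_q]
        m = [float("-inf")] * len(Q_blk)                # running max per query row
        l = [0.0] * len(Q_blk)                          # running softmax denominator per row
        acc = [[0.0] * d_v for _ in Q_blk]              # unnormalized output per row
        for ks in range(0, len(K), block_kv):          # inner loop over K/V blocks
            K_blk, V_blk = K[ks:ks + block_kv], V[ks:ks + block_kv]
            for i, q in enumerate(Q_blk):
                s = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d) for k in K_blk]
                m_new = max(m[i], max(s))
                scale = math.exp(m[i] - m_new)          # rescale earlier blocks to the new max
                p = [math.exp(x - m_new) for x in s]
                l[i] = l[i] * scale + sum(p)
                acc[i] = [a * scale + sum(pj * v[c] for pj, v in zip(p, V_blk))
                          for c, a in enumerate(acc[i])]
                m[i] = m_new
        # Normalize once, after all K/V blocks have been seen
        out.extend([[a / l[i] for a in acc[i]] for i in range(len(Q_blk))])
    return out

rng = random.Random(0)
def rand(rows, cols):
    return [[rng.gauss(0.0, 1.0) for _ in range(cols)] for _ in range(rows)]

Q, K, V = rand(5, 4), rand(7, 4), rand(7, 3)
ref, tiled = standard_attention(Q, K, V), tiled_attention(Q, K, V)
print(max(abs(a - b) for r1, r2 in zip(ref, tiled) for a, b in zip(r1, r2)))  # ~1e-16
```

The result matches standard attention to rounding error, which is the sense in which Flash Attention is "exact".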

Advanced Attention Variants

Grouped Query Attention (GQA)

Shares key-value heads across query heads to reduce memory:

$$\text{GQA: } h_q \text{ query heads, } h_g \text{ groups, each group shares KV}$$

Used in the larger LLaMA 2 models (e.g., 70B) and in Mistral 7B. Reduces the KV cache by a factor of $h_q/h_g$ relative to standard multi-head attention (MHA).
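A minimal sketch of how query heads map to shared KV heads. It assumes $h_q$ is divisible by $h_g$ and that query heads are grouped contiguously (heads $0 \ldots h_q/h_g - 1$ form group 0, and so on), which is a common convention but an assumption here; the function names are hypothetical.

```python
def kv_head_for_query_head(q_head, h_q, h_g):
    """Index of the KV head shared by query head `q_head` under GQA.

    Assumption: contiguous grouping, and h_q divisible by h_g.
    """
    if h_q % h_g != 0:
        raise ValueError("h_q must be divisible by h_g")
    group_size = h_q // h_g
    return q_head // group_size

def expand_kv_heads(kv_heads, h_q):
    # Give every query head a reference to its group's K (or V) tensor
    h_g = len(kv_heads)
    return [kv_heads[kv_head_for_query_head(i, h_q, h_g)] for i in range(h_q)]

h_q, h_g = 8, 2
print([kv_head_for_query_head(i, h_q, h_g) for i in range(h_q)])  # [0, 0, 0, 0, 1, 1, 1, 1]
print(h_q / h_g)  # KV cache reduction factor vs. MHA: 4.0
```

Setting $h_g = h_q$ recovers MHA (every query head has its own KV head), and $h_g = 1$ recovers MQA.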

Multi-Query Attention (MQA)

Extreme case: all query heads share one KV head:

$$\text{MQA: } h_q \text{ query heads, } 1 \text{ KV head}$$

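Reduces the KV cache by a factor of $h_q$. Because autoregressive decoding is typically memory-bandwidth bound, this can speed up inference substantially over MHA, at the cost of a slight quality loss.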
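The KV-cache savings are simple arithmetic. The sketch below assumes a hypothetical model shape (32 layers, 32 query heads of dimension 128, a 4096-token context, batch size 1, 2-byte fp16/bf16 elements); these numbers are illustrative, not taken from any specific model.

```python
def kv_cache_bytes(seq_len, n_layers, n_kv_heads, d_head, bytes_per_elem=2):
    # Factor 2: one K tensor and one V tensor are cached per layer
    return 2 * n_layers * seq_len * n_kv_heads * d_head * bytes_per_elem

# Hypothetical model shape, chosen only for illustration (batch size 1)
seq_len, n_layers, h_q, d_head = 4096, 32, 32, 128
mha = kv_cache_bytes(seq_len, n_layers, n_kv_heads=h_q, d_head=d_head)  # one KV head per query head
gqa = kv_cache_bytes(seq_len, n_layers, n_kv_heads=8, d_head=d_head)    # h_g = 8 groups
mqa = kv_cache_bytes(seq_len, n_layers, n_kv_heads=1, d_head=d_head)    # single shared KV head
print(mha / 2**30, gqa / 2**30, mqa / 2**30)  # 2.0 0.5 0.0625 (GiB)
```

Under these assumptions MQA shrinks the cache 32× and GQA with 8 groups shrinks it 4×, which is what lets longer contexts or larger batches fit in GPU memory.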

Flash Attention 2 & 3

  • Flash Attention 2: Better parallelism across sequence length dimension
  • Flash Attention 3: FP8 support, asynchronous operations on Hopper GPUs

Attention Complexity Comparison

| Variant | Time | Memory | Use Case |
|---|---|---|---|
| Standard | $O(n^2 d)$ | $O(n^2)$ | Training |
| Flash Attention | $O(n^2 d)$ | $O(n)$ extra | Training/Inference |
| Sparse (Longformer, window $w$) | $O(n w d)$ | $O(n w)$ | Very long sequences |
| Linear | $O(n d^2)$ | $O(n d)$ | Ultra-long sequences |
| GQA | $O(n^2 d)$ | KV cache $h_q/h_g$× smaller than MHA | Efficient inference |

Practical Implementation Tips

Follow-Up Questions

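Q: Why not use additive attention instead of dot-product? A: Additive attention scores each query-key pair with a small feed-forward network, $\mathbf{v}^T \tanh(\mathbf{W}_1\mathbf{q} + \mathbf{W}_2\mathbf{k})$. It can be more expressive, but it is slower and less memory-efficient in practice: dot-product attention reduces to matrix multiplication, which is heavily hardware-optimized.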
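To make the contrast concrete, here is a minimal sketch of both scoring functions, assuming random hypothetical weights `W_1`, `W_2`, `v` and small made-up sizes. Dot-product scores for all pairs come from one matrix product, while the additive score needs a separate $\tanh$ over a $d_a$-dimensional vector for every (query, key) pair.

```python
import math
import random

def dot_product_scores(Q, K):
    # All n x m scores come from a single (scaled) product Q K^T
    d = len(Q[0])
    return [[sum(a * b for a, b in zip(q, k)) / math.sqrt(d) for k in K] for q in Q]

def matvec(W, x):
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]

def additive_scores(Q, K, W_1, W_2, v):
    # Bahdanau-style score v^T tanh(W_1 q + W_2 k). The projections can be
    # precomputed, but the tanh is applied to a separate d_a-dim vector per pair,
    # so this cannot be folded into one matrix multiplication.
    proj_q = [matvec(W_1, q) for q in Q]
    proj_k = [matvec(W_2, k) for k in K]
    return [[sum(vi * math.tanh(a + b) for vi, a, b in zip(v, pq, pk)) for pk in proj_k]
            for pq in proj_q]

def rand(rows, cols, rng):
    return [[rng.gauss(0.0, 1.0) for _ in range(cols)] for _ in range(rows)]

rng = random.Random(0)
d, d_a, n, m = 4, 6, 3, 5                   # hypothetical sizes
Q, K = rand(n, d, rng), rand(m, d, rng)
W_1, W_2 = rand(d_a, d, rng), rand(d_a, d, rng)  # hypothetical learned weights
v = rand(1, d_a, rng)[0]
dp = dot_product_scores(Q, K)
add = additive_scores(Q, K, W_1, W_2, v)
print(len(dp), len(dp[0]), len(add), len(add[0]))  # 3 5 3 5
```

Both produce an $n \times m$ score matrix; only the cost profile differs.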

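Q: How does Flash Attention handle the causal mask? A: By exploiting the block structure: tiles lying entirely above the diagonal (every key in the future) are skipped, tiles entirely below it need no mask, and only tiles on the diagonal apply the mask element-wise. Since roughly half of the tiles are skipped, causal attention can run faster than non-causal attention of the same length.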
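The tile classification can be sketched in a few lines. This assumes square, aligned query and key blocks of equal size (a simplification of real kernels) and uses a hypothetical 8-token sequence with blocks of 2.

```python
def causal_tile_schedule(n, block):
    """Classify (query block, key block) tiles for causal attention.

    Assumption: square, aligned blocks of `block` rows and columns.
    """
    n_blocks = (n + block - 1) // block
    skipped, diagonal, unmasked = [], [], []
    for qb in range(n_blocks):
        for kb in range(n_blocks):
            if kb > qb:
                skipped.append((qb, kb))    # every key is after every query: no compute at all
            elif kb == qb:
                diagonal.append((qb, kb))   # straddles the diagonal: apply the mask element-wise
            else:
                unmasked.append((qb, kb))   # every key is before every query: no mask needed
    return skipped, diagonal, unmasked

skipped, diagonal, unmasked = causal_tile_schedule(n=8, block=2)
print(len(skipped), len(diagonal), len(unmasked))  # 6 4 6
```

Only 10 of the 16 tiles are computed here, and the skipped fraction approaches one half as the number of blocks grows.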

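Q: What is the relationship between attention and kernel methods? A: Attention can be viewed as a normalized kernel smoother with kernel $k(\mathbf{q}, \mathbf{k}) = \exp(\mathbf{q}^T \mathbf{k}/\sqrt{d})$. Linear attention replaces this with a factorizable kernel $\phi(\mathbf{q})^T \phi(\mathbf{k})$, so the product can be regrouped as $\phi(\mathbf{Q})\left(\phi(\mathbf{K})^T \mathbf{V}\right)$ and computed in time linear in the sequence length.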
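A minimal sketch of why factorization gives linear cost. It assumes the feature map $\phi(x) = \mathrm{elu}(x) + 1$, one common choice from the linear-attention literature but only an assumption here, plus random made-up inputs. The quadratic version evaluates the kernel for every query-key pair; the linear version summarizes all keys and values once and then reuses that summary for each query. Both give the same output.

```python
import math
import random

def phi(x):
    # Assumed feature map: elu(x) + 1, which is always positive
    return [xi + 1.0 if xi > 0 else math.exp(xi) for xi in x]

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def kernel_attention_quadratic(Q, K, V):
    # O(n * m) pairs: evaluate phi(q)^T phi(k) for every query-key pair
    d_v = len(V[0])
    out = []
    for q in Q:
        fq = phi(q)
        w = [dot(fq, phi(k)) for k in K]
        z = sum(w)
        out.append([sum(wj * v[c] for wj, v in zip(w, V)) / z for c in range(d_v)])
    return out

def kernel_attention_linear(Q, K, V):
    # Linear in n and m: build S = sum_j phi(k_j) v_j^T and z = sum_j phi(k_j) once
    d, d_v = len(K[0]), len(V[0])
    S = [[0.0] * d_v for _ in range(d)]
    z = [0.0] * d
    for k, v in zip(K, V):
        fk = phi(k)
        for a in range(d):
            z[a] += fk[a]
            for c in range(d_v):
                S[a][c] += fk[a] * v[c]
    out = []
    for q in Q:
        fq = phi(q)
        denom = dot(fq, z)
        out.append([sum(fq[a] * S[a][c] for a in range(d)) / denom for c in range(d_v)])
    return out

rng = random.Random(0)
def rand(rows, cols):
    return [[rng.gauss(0.0, 1.0) for _ in range(cols)] for _ in range(rows)]

Q, K, V = rand(4, 3), rand(6, 3), rand(6, 2)
quad = kernel_attention_quadratic(Q, K, V)
lin = kernel_attention_linear(Q, K, V)
print(max(abs(a - b) for r1, r2 in zip(quad, lin) for a, b in zip(r1, r2)))  # ~1e-16
```

Note that this matches attention with the factorized kernel, not softmax attention itself; the approximation lies in swapping the kernel.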
