LESSON
Day 291: Multi-Head Attention
The core idea: multi-head attention lets the model look at the same sequence through several learned relevance patterns in parallel instead of forcing a single attention map to do every job.
Today's "Aha!" Moment
The insight: Self-attention already lets tokens consult other tokens. Multi-head attention asks the next natural question:
- why force one attention pattern to capture syntax, reference, locality, and long-range relations all at once?
Instead, the model learns several smaller attention views in parallel, then recombines them.
Why this matters: This is one of the reasons Transformers are expressive. Different heads can specialize in different relation types, scales, or interaction patterns, even though that specialization is learned rather than hand-coded.
Concrete anchor: In a sentence, one head may focus on local grammatical agreement, another on longer-distance subject-reference, and another on punctuation or separator structure. None of those views alone is the whole story, but together they give the layer richer context.
The practical sentence to remember:
Multi-head attention is multiple learned ways of relating the same tokens, run in parallel and then fused.
Why This Matters
One self-attention head can already compute a relevance map over the sequence. But a single map has limited representational bandwidth:
- it must choose one set of projections
- it produces one pattern of weights per token
- it mixes all relationship types into one channel of contextualization
Multi-head attention loosens that bottleneck.
Instead of doing:
- one big attention computation over the whole model dimension
the Transformer does:
- several smaller attention computations over different learned subspaces
This creates two big benefits:
- different heads can capture different relationship patterns
- the model can represent several contextual views at once before recombining them
Operational payoff:
- richer token interactions
- more expressive contextualization per layer
- better fit for complex structure than a single monolithic attention map
Learning Objectives
By the end of this session, you should be able to:
- Explain why multi-head attention exists as a way to avoid forcing one head to represent every useful relationship.
- Describe the multi-head computation from per-head projections through attention, concatenation, and final output projection.
- Evaluate what multiple heads buy and what they cost, including expressiveness, parameter use, head redundancy, and compute overhead.
Core Concepts Explained
Concept 1: One Attention Head Is Often Too Narrow a Bottleneck
Concrete example / mini-scenario: In the sentence "The scientist who reviewed the paper said it was flawed," a useful model may need to track:
- local agreement
- clause structure
- long-distance reference
- which token "it" most likely refers to
Intuition: A single attention pattern can be informative, but it must collapse many different kinds of relevance into one learned view. That is restrictive.
Technical structure (how it works): A single head builds one Q, K, V projection set and produces one attention distribution per position. That means one representational subspace is doing all the matching and all the mixing for that head.
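To make that bottleneck concrete, here is a minimal NumPy sketch of one attention head. The shapes, variable names, and random inputs are illustrative assumptions, not a reference implementation; the point is that one projection set yields exactly one softmax relevance map per position.

import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)  # subtract max for numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def single_head_attention(X, Wq, Wk, Wv):
    # One Q/K/V projection set -> one attention distribution per position.
    Q, K, V = X @ Wq, X @ Wk, X @ Wv
    d_k = Q.shape[-1]
    scores = Q @ K.T / np.sqrt(d_k)      # (seq_len, seq_len) relevance map
    weights = softmax(scores, axis=-1)   # exactly one pattern per query token
    return weights @ V, weights

# Illustrative sizes: 6 tokens, model width 16, head width 8.
rng = np.random.default_rng(0)
X = rng.normal(size=(6, 16))
Wq, Wk, Wv = (rng.normal(size=(16, 8)) for _ in range(3))
out, attn = single_head_attention(X, Wq, Wk, Wv)
print(out.shape, attn.shape)  # (6, 8) (6, 6)

Every kind of relevance this head wants to express has to be squeezed into that single weights matrix.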
Practical implications:
- a single head may capture useful structure, but not all structure equally well
- the model benefits from being able to distribute relational work across several views
- deeper layers help, but the per-layer attention mechanism itself also matters
Fundamental trade-off: Keeping one head is simpler, but it limits how many distinct interaction patterns the layer can model at once.
Mental model: Asking one analyst to summarize a complex situation from a single angle versus letting several analysts inspect it from different perspectives and then combining their reports.
Connection to other fields: Similar to ensemble reasoning and feature subspaces. Multiple partial views often represent complex structure better than one oversized projection.
When to use it:
- Best fit: tasks where several relation types matter simultaneously across the same sequence.
- Misuse pattern: assuming one large head is automatically equivalent to several smaller specialized heads.
Concept 2: Mechanically, Multi-Head Attention Splits the Representation into Several Learned Attention Paths
Concrete example / mini-scenario: Suppose the model dimension is d_model = 512 and there are 8 heads. Each head might operate on a smaller d_k = d_v = 64 subspace.
Intuition: Instead of one big attention operation over the full space, the layer runs several attention operations over different learned projections of the same input.
Technical structure (how it works):
For each head h, the model learns separate matrices:
Q_h = XWq_h
K_h = XWk_h
V_h = XWv_h
Then each head computes its own scaled dot-product attention output:
head_h = Attention(Q_h, K_h, V_h) = softmax(Q_h K_h^T / sqrt(d_k)) V_h
After all heads finish:
concat = [head_1 ; head_2 ; ... ; head_H]
output = concat Wo
So the structure is:
- project into multiple per-head subspaces
- run attention independently in each head
- concatenate head outputs
- apply one output projection to fuse them
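Putting those four steps together, here is a minimal NumPy sketch of the whole layer. Sizes, the column-slicing trick, and the random weights are illustrative assumptions, not a reference implementation.

import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def multi_head_attention(X, Wq, Wk, Wv, Wo, n_heads):
    # Wq/Wk/Wv and Wo are (d_model, d_model); slicing columns into n_heads blocks
    # is equivalent to learning a separate (d_model, d_head) matrix per head.
    seq_len, d_model = X.shape
    d_head = d_model // n_heads
    heads = []
    for h in range(n_heads):
        sl = slice(h * d_head, (h + 1) * d_head)
        Q, K, V = X @ Wq[:, sl], X @ Wk[:, sl], X @ Wv[:, sl]  # per-head projections
        scores = Q @ K.T / np.sqrt(d_head)
        weights = softmax(scores, axis=-1)                     # one attention map per head
        heads.append(weights @ V)                              # (seq_len, d_head)
    concat = np.concatenate(heads, axis=-1)                    # (seq_len, d_model)
    return concat @ Wo                                         # fuse the heads with the output projection

# Illustrative sizes echoing the d_model = 512, 8-head example, scaled down for speed.
rng = np.random.default_rng(0)
d_model, n_heads, seq_len = 64, 8, 10
X = rng.normal(size=(seq_len, d_model))
Wq, Wk, Wv, Wo = (rng.normal(size=(d_model, d_model)) * 0.1 for _ in range(4))
out = multi_head_attention(X, Wq, Wk, Wv, Wo, n_heads)
print(out.shape)  # (10, 64)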
Practical implications:
- each head can learn a different matching behavior
- the output projection lets the model recombine those views into one representation
- the layer stays fully differentiable and trainable end to end
Fundamental trade-off: You gain richer structure, but add more parameters, more tensor movement, and more places where compute is spent.
Mental model: Several search engines ranking the same document set with different criteria, followed by one fusion step that combines the results.
Connection to other fields: This resembles mixture-of-features thinking more than strict modular decomposition. Heads are not manually assigned roles, but the architecture gives room for partial specialization.
When to use it:
- Best fit: Transformer blocks that need several relational views in parallel.
- Misuse pattern: treating the number of heads as pure free capacity with no cost or diminishing returns.
Concept 3: Multiple Heads Add Expressiveness, but Not Every Head Is Equally Useful
Concrete example / mini-scenario: In trained models, some heads clearly capture meaningful patterns, while others appear weak, redundant, or easy to prune with little loss.
Intuition: Multi-head attention creates the opportunity for specialization, not a guarantee that every head will become equally important.
Technical structure (how it works): During training, gradients shape each head's projection matrices and attention behavior. Some heads may specialize in local structure, separators, positional shortcuts, or semantic relations. Others may converge to overlapping or low-signal behavior.
This leads to a practical truth:
- more heads can help
- but "more heads" is not the same as "strictly better model"
Quality depends on:
- model dimension
- head dimension
- task complexity
- depth
- optimization dynamics
Practical implications:
- head count is a design knob, not a magic upgrade
- some heads may be redundant
- efficiency work later often studies pruning, compression, or alternative attention patterns
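One way to see the redundancy question in code is to ablate one head at a time and check how much the fused output moves. The sketch below is only in the spirit of head-masking studies: it uses random stand-in head outputs and an output-norm metric, whereas real analyses (such as the paper in the resources below) measure the change in task performance.

import numpy as np

def layer_output(head_outputs, Wo, mask=None):
    # head_outputs: list of (seq_len, d_head) arrays; mask: optional per-head 0/1 multipliers.
    n_heads = len(head_outputs)
    mask = np.ones(n_heads) if mask is None else np.asarray(mask, dtype=float)
    concat = np.concatenate([m * h for m, h in zip(mask, head_outputs)], axis=-1)
    return concat @ Wo

# Hypothetical head outputs for an 8-head layer (stand-ins for real attention outputs).
rng = np.random.default_rng(0)
seq_len, d_head, n_heads = 10, 8, 8
heads = [rng.normal(size=(seq_len, d_head)) for _ in range(n_heads)]
Wo = rng.normal(size=(n_heads * d_head, n_heads * d_head)) * 0.1

baseline = layer_output(heads, Wo)
for h in range(n_heads):
    mask = np.ones(n_heads)
    mask[h] = 0.0  # ablate one head
    delta = np.linalg.norm(layer_output(heads, Wo, mask) - baseline)
    print(f"head {h}: output change {delta:.3f}")  # small change -> candidate for pruning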
Fundamental trade-off:
- richer representation and relational diversity
- potential redundancy, extra memory traffic, and more expensive inference
Mental model: A team with several specialists can outperform one generalist, but only if those specialists contribute distinct signal instead of duplicating each other.
Connection to other fields: Similar to systems parallelism and organizational design: splitting work into parallel units only helps if those units actually reduce bottlenecks or cover different needs.
When to use it:
- Best fit: standard Transformer blocks where broad contextual expressiveness matters.
- Misuse pattern: increasing heads blindly without considering per-head dimension, hardware cost, or redundancy.
Troubleshooting
Issue: "Why not just make one head bigger instead of adding many heads?"
Why it happens / is confusing: Bigger often sounds like strictly more capable.
Clarification / Fix: A bigger single head still produces one attention pattern per token. Multi-head attention creates several independently learned patterns, which is a different kind of capacity.
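A quick arithmetic sketch, reusing the hypothetical d_model = 512, 8-head sizes from Concept 2, shows that the two options spend roughly the same projection parameters; what changes is how many attention patterns each token receives.

d_model = 512
n_heads, d_head = 8, 64

# Q, K, V projection parameters (the shared output projection adds d_model * d_model either way).
one_big_head = 3 * d_model * d_model                   # a single head spanning the full width
eight_small_heads = 3 * d_model * (n_heads * d_head)   # eight 64-wide slices
print(one_big_head, eight_small_heads)                 # 786432 786432 -> same parameter budget

# Attention patterns produced per token position:
print(1, n_heads)                                      # one pattern vs. eight independently learned patterns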
Issue: "Do heads have fixed human-interpretable roles?"
Why it happens / is confusing: Visualizations sometimes make heads look like clean labeled modules.
Clarification / Fix: Some heads show interpretable patterns, but specialization is learned and partial. Heads can overlap, shift roles by layer, or look noisy while still helping the model.
Issue: "If some heads are redundant, does that mean multi-head attention was a bad idea?"
Why it happens / is confusing: Redundancy can sound like waste.
Clarification / Fix: Not necessarily. The architecture gives the model room to discover useful decompositions. Some redundancy is common in expressive systems; the key question is whether the overall layer gains useful capacity.
Advanced Connections
Connection 1: Multi-Head Attention <-> Feature Subspaces
The parallel: Each head is a learned projection into a different representational subspace, where different similarity patterns can become easier to express.
Real-world case: This is one reason head count, model width, and per-head dimension interact so strongly in Transformer design.
Connection 2: Multi-Head Attention <-> Model Compression
The parallel: Once a model has many heads, later optimization work naturally asks whether all of them are necessary at inference time.
Real-world case: Head pruning and other compression strategies often start from the observation that not all learned heads contribute equally.
Resources
Suggested Resources
- [PAPER] Attention Is All You Need - arXiv
Focus: the original definition of multi-head attention in the Transformer.
- [DOC] The Annotated Transformer - Harvard NLP
Focus: a clear code-oriented explanation of how the head split, concat, and projection work.
- [PAPER] Are Sixteen Heads Really Better than One? - arXiv
Focus: a useful paper for understanding redundancy and head importance in trained models.
Key Insights
- Multi-head attention exists because one attention map is often too limited to represent all useful relationships at once.
- Mechanically, it runs several smaller attention computations in parallel and then fuses them with an output projection.
- Multiple heads add expressive capacity, not guaranteed perfect specialization, so head count is a design trade-off rather than a free win.