
GQA: what the grouped query attention paper actually says

Reading Ainslie et al. (Google, EMNLP 2023) while debugging why our 70B serving saturated at 11 concurrent requests on a 2×A100 node.

We'd done everything right. vLLM with 16-token blocks, a 90% KV cache utilization admission threshold, prefix caching for the system prompt. The 13B model we'd started with scaled nicely. When we switched to the 70B variant for quality, the same serving configuration hit a ceiling at 11 concurrent 4K-token requests. GPU memory was the constraint again, but this time PagedAttention wasn't the answer — we were already using it. The problem was upstream of memory management: the KV cache for a 70B model with 64 attention heads and 80 layers is simply enormous, regardless of how cleverly you allocate it.

The paper that explains why modern large models don't have this problem, and how to fix it in models that do, is "GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints," Ainslie, Lee-Thorp, de Jong, Zelaski, Sanghai, and Xu, Google, EMNLP 2023. The contribution is a mechanism called Grouped Query Attention that sits between two existing extremes: standard Multi-Head Attention (expensive KV cache, high quality) and Multi-Query Attention (cheap KV cache, noticeable quality drop). LLaMA-2 70B, Mistral 7B v0.1, Gemma 2, and most production-grade models released after 2023 use GQA. Understanding it is understanding why those models are deployable at reasonable batch sizes.

The problem, quantified

To understand why GQA matters, work through the KV cache math for a model that doesn't use it.

Standard Multi-Head Attention (Vaswani et al., 2017) gives each attention head its own independent key and value projection matrices. During autoregressive decoding, every new token appends a new key vector and value vector to the cache for every head at every layer. For a model with H heads, head dimension d_k, and L layers:

KV cache per token = 2 × H × d_k × L × sizeof(dtype)

For a 70B model with LLaMA-2 70B's dimensions but standard MHA (H=64 heads, d_k=128, L=80 layers, FP16):

2 × 64 × 128 × 80 × 2 bytes = 2,621,440 bytes ≈ 2.5 MB per token

At 4,096-token context, that's 10.7 GB per concurrent request in KV cache alone. On a 2×A100 node (160 GB total), with model weights consuming ~140 GB in FP16, you have roughly 20 GB left for KV cache, enough for fewer than two complete sequences (about 1.9) at 4K context. PagedAttention reduces fragmentation and enables batching, but it can't create memory that doesn't exist. The ceiling isn't the allocator; it's the head count.

Now run the same math with GQA-8 — 8 KV groups for 64 query heads, which is what LLaMA-2 70B actually uses:

2 × 8 × 128 × 80 × 2 bytes = 327,680 bytes ≈ 320 KB per token

At 4,096 tokens: 1.34 GB per concurrent request. The same 20 GB spare capacity now supports ~15 concurrent sequences at 4K context before hitting PagedAttention's managed eviction. On identical hardware that is 8× the KV cache capacity, and for workloads limited purely by KV cache, serving throughput can rise by close to the same factor.
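The arithmetic above is easy to re-run for other shapes. A minimal sketch, assuming FP16 (2 bytes per element), full-length sequences, and no paging or activation overhead. The 20 GB spare-memory figure is the estimate from the text, not a measured value, and the helper names are illustrative:

```python
def kv_cache_bytes_per_token(n_kv_heads: int, head_dim: int, n_layers: int,
                             dtype_bytes: int = 2) -> int:
    """Bytes of K plus V appended to the cache per generated token."""
    return 2 * n_kv_heads * head_dim * n_layers * dtype_bytes


def max_concurrent_sequences(spare_bytes: float, seq_len: int, n_kv_heads: int,
                             head_dim: int, n_layers: int, dtype_bytes: int = 2) -> float:
    """Full-length sequences that fit in spare memory (ignores paging and activations)."""
    per_seq = seq_len * kv_cache_bytes_per_token(n_kv_heads, head_dim, n_layers, dtype_bytes)
    return spare_bytes / per_seq


# 70B shape from the text: 80 layers, head_dim 128, FP16, ~20 GB spare on 2xA100.
SPARE = 20e9
for name, kv_heads in [("MHA", 64), ("GQA-8", 8)]:
    per_tok = kv_cache_bytes_per_token(kv_heads, 128, 80)
    n = max_concurrent_sequences(SPARE, 4096, kv_heads, 128, 80)
    print(f"{name}: {per_tok:,} B/token, {per_tok * 4096 / 1e9:.2f} GB/seq, ~{n:.1f} seqs")
```

Passing `n_kv_heads=1` gives the MQA case discussed next.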

Multi-Query Attention: why we didn't just use that

The idea of sharing key and value heads across query heads isn't new to GQA. Shazeer proposed Multi-Query Attention (MQA) in 2019 ("Fast Transformer Decoding: One Write-Head Is All You Need"). MQA takes the sharing to the extreme: all H query heads share a single key head and a single value head.

MQA KV cache per token = 2 × 1 × d_k × L × sizeof(dtype)

For the 70B model: 2 × 1 × 128 × 80 × 2 bytes = 40,960 bytes ≈ 40 KB per token. At 4K context: ~168 MB per sequence. The throughput implications are extraordinary: you could run about 120 concurrent sequences before exhausting the 20 GB spare capacity.

The problem is quality. The paper itself notes that MQA can lead to quality degradation and training instability, and in practice the losses tend to show up on tasks that benefit from attention diversity: long-document question answering, multi-hop reasoning, complex summarization. The intuition is straightforward: if all query heads are attending using the same key and value representations, the model has much less room to learn different types of attention patterns in the same layer. MHA's expressiveness comes partly from different heads specializing: some attend to local syntactic structure, others to long-range coreference, others to positional patterns. With one shared KV head, all H query projections are learning to do different things with identical key-value inputs, which is a weaker inductive bias.

The evidence at scale is mixed. PaLM adopted MQA and reported a neutral effect on model quality, while the GQA paper's own T5 XXL experiments show uptrained MQA trailing the MHA original on average. (T5 1.1 itself uses standard MHA; its checkpoints are what the GQA paper converts.) The throughput gain is real; how much quality it costs depends on the model and how it was trained.

GQA: the mechanism

GQA generalizes MQA by introducing G groups, where each group gets its own K head and V head, and H/G query heads share within each group.

MHA:  G = H   → each query head has its own K/V head
MQA:  G = 1   → all query heads share one K/V head
GQA:  1 < G < H → G groups, H/G query heads per group

The computation for a query head in group g:

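import numpy as np

def softmax(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def gqa_head(X, W_Q, W_K, W_V, i, H, G):
    # X: (T, d_model); W_Q: (H, d_model, d_k); W_K: (G, d_model, d_k); W_V: (G, d_model, d_v)
    g = i // (H // G)                  # query head i belongs to group g
    K_g = X @ W_K[g]                   # K and V for group g, shared by its H/G query heads
    V_g = X @ W_V[g]
    d_k = W_Q.shape[-1]
    scores_i = (X @ W_Q[i]) @ K_g.T / np.sqrt(d_k)
    return softmax(scores_i) @ V_g     # attention output for head i (in group g)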

The output for each query head still depends on that head's unique query projection W_Q_i, so query heads within the same group can attend to different regions — they just attend using the same key and value representations. The diversity of attention patterns is reduced but not eliminated.
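To make the sharing concrete, here is a minimal NumPy sketch that computes all H heads at once by viewing the queries as G groups of H/G heads, so each group's K and V are used without being copied. It assumes already-projected per-head tensors and omits the causal mask and output projection. The function name `gqa_attention` is illustrative:

```python
import numpy as np


def softmax(x: np.ndarray, axis: int = -1) -> np.ndarray:
    x = x - x.max(axis=axis, keepdims=True)  # subtract max for numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def gqa_attention(q: np.ndarray, k: np.ndarray, v: np.ndarray) -> np.ndarray:
    """Grouped-query attention without materializing per-head K/V copies.

    q: (H, T, d_k) projected queries; k, v: (G, T, d_k) projected keys/values.
    Query head i uses K/V group i // (H // G). No causal mask is applied.
    """
    H, T, d_k = q.shape
    G = k.shape[0]
    assert H % G == 0, "H must be divisible by G"
    # (G, H//G, T, d_k): consecutive query heads share one K/V group.
    qg = q.reshape(G, H // G, T, d_k)
    scores = np.einsum("gqtd,gsd->gqts", qg, k) / np.sqrt(d_k)
    out = np.einsum("gqts,gsd->gqtd", softmax(scores), v)
    return out.reshape(H, T, d_k)
```

With G equal to H this is ordinary multi-head attention, and with G equal to 1 it is MQA.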

The KV cache with GQA-G:

KV cache per token = 2 × G × d_k × L × sizeof(dtype)

Choosing G=8 for the 70B model gives an 8× KV cache reduction vs. MHA with quality close to MHA — not the 64× reduction of MQA, but with substantially less quality impact.

The paper's quality results on T5 XXL are the anchor here. Averaged over its summarization (CNN/Daily Mail, arXiv, PubMed, MediaSum, Multi-News), translation (WMT), and question-answering (TriviaQA) evaluations, uptrained GQA-8 scores 47.1 against 47.2 for the MHA original, while uptrained MQA scores 46.6. The gap between GQA-8 and MHA is small; the gap between MQA and MHA is several times larger.

The uptraining result: the part of the paper that's actually surprising

The inference efficiency argument for GQA is compelling but not novel — it's a principled interpolation between MHA and MQA. What makes the paper genuinely interesting to me is the uptraining procedure.

Training a large model from scratch with GQA is straightforward — you define fewer KV projection matrices and proceed normally. But most teams deploying LLMs in 2023 weren't training from scratch. They had existing MHA checkpoints: fine-tuned weights, RLHF-adjusted models, domain-adapted variants. The question was whether you could convert an existing MHA model to GQA without re-training from scratch.

The paper's answer is yes, and the procedure is simple:

Step 1: Weight projection. For each group g, mean-pool the K and V projection matrices of the H/G original heads assigned to that group:

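# Convert MHA → GQA-G by mean-pooling the per-head K/V projection weights
import numpy as np

def mean_pool_kv(W_K, W_V, G):
    # W_K, W_V: (H, d_model, d_k); heads g*(H//G) .. (g+1)*(H//G)-1 form group g
    H = W_K.shape[0]
    assert H % G == 0, "G must divide H"
    W_K_g = W_K.reshape(G, H // G, *W_K.shape[1:]).mean(axis=1)  # (G, d_model, d_k)
    W_V_g = W_V.reshape(G, H // G, *W_V.shape[1:]).mean(axis=1)
    return W_K_g, W_V_g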

Step 2: Uptrain. Continue pre-training on a fraction of the original training data — the paper uses ~5% of the original token budget. This allows the model to adapt the query projections to the new grouped K/V structure.

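The result: a T5 XXL MHA checkpoint uptrained to GQA-8 with 5% of the original pre-training compute comes within 0.1 points of the MHA original on the paper's average score while running at close to MQA speed. The paper also finds GQA reasonably usable straight after conversion, whereas MQA needs uptraining to be useful. The mean-pooled initialization is load-bearing: in the paper's conversion ablation (T5 Large converted to MQA), mean-pooling the heads beats keeping only the first head, which in turn beats random initialization.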

This is operationally significant. The LLaMA-2 team explicitly chose GQA for the 70B variant (not 7B or 13B) and trained it from scratch. But teams working with existing checkpoints — fine-tuned LoRA adapters, RLHF models, proprietary fine-tunes — can apply the uptraining procedure to convert their MHA models to GQA without starting over.

What the performance numbers actually show

The paper reports inference time per sample for T5 XXL (11B parameters) on TPUv4, comparing the MHA original with MQA and GQA-8 models uptrained from it, plus T5 Large (MHA) as a reference. Next to each time is the average score across the paper's evaluation tasks:

  • MHA-Large: 0.37 s, 46.0
  • MHA-XXL: 1.51 s, 47.2
  • MQA-XXL: 0.24 s, 46.6
  • GQA-8-XXL: 0.28 s, 47.1

The throughput gain is not linear in the KV cache reduction because the attention computation itself (the QK^T softmax V matmul) is not the only decode bottleneck. At batch sizes where the model is weight-memory-bound (loading projection weight matrices dominates), reducing the KV heads helps less than expected. At batch sizes where KV cache bandwidth dominates — typically larger batches with longer sequences — the improvement scales more directly with the KV cache reduction.
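A back-of-the-envelope way to see this is to count bytes read per decode step. The sketch below assumes decode is purely memory-bandwidth-bound and that each step reads every weight once plus every cached K/V entry. It ignores compute, activations, and whether the MHA cache would even fit in memory, so treat the ratios as illustrative, not as predictions for T5 on TPUv4 or any GPU:

```python
def decode_step_bytes(weight_bytes: float, batch: int, context_len: int,
                      kv_bytes_per_token: float) -> float:
    """Bytes read per decode step: all weights once, plus every cached K/V entry."""
    return weight_bytes + batch * context_len * kv_bytes_per_token


def bandwidth_bound_speedup(weight_bytes: float, batch: int, context_len: int,
                            kv_mha: float, kv_gqa: float) -> float:
    """MHA/GQA ratio of bytes read; a speedup only if decode is purely bandwidth-bound."""
    mha = decode_step_bytes(weight_bytes, batch, context_len, kv_mha)
    gqa = decode_step_bytes(weight_bytes, batch, context_len, kv_gqa)
    return mha / gqa


# Illustrative inputs: the 70B shape from earlier in this post (FP16 weights,
# per-token KV bytes for 64 vs 8 KV heads). Capacity limits are ignored on purpose.
W, KV_MHA, KV_GQA = 140e9, 2_621_440, 327_680
for batch, ctx in [(1, 512), (8, 4096), (64, 4096)]:
    ratio = bandwidth_bound_speedup(W, batch, ctx, KV_MHA, KV_GQA)
    print(f"batch={batch:>3} ctx={ctx:>4}: ~{ratio:.2f}x fewer bytes per step")
```

With these inputs the ratio stays near 1 at small batch and short context. It climbs toward the 8× KV reduction only as batch × context grows and the KV term outweighs the weights.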

Speed vs. G (T5 XXL):

The paper's sweep over the number of groups measures inference time per sample, not quality. Moving from MQA to a few groups costs only a modest slowdown at first, and the cost rises as G approaches H. The authors chose G=8 as a favorable middle ground. Quality is reported only for MQA, GQA-8, and MHA.

For most production use cases, G=8 is the practical operating point: in the paper's results it keeps most of MQA's speedup over MHA while giving up almost nothing on the average score.

Production tradeoffs the paper doesn't discuss

The group count needs to divide evenly into the head count. This sounds trivial but creates real constraints. LLaMA-2 70B has 64 heads and uses G=8. If you're adapting an architecture with 40 heads (e.g., LLaMA-1 13B), G must be a divisor of 40: 1, 2, 4, 5, 8, 10, 20, or 40. G=8 reduces the KV cache by 5×; G=10 reduces it by 4×. The assumption from the 70B literature that G=8 means an 8× reduction doesn't transfer. Architects often pick H to be a power of 2 specifically to preserve GQA flexibility.
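Enumerating the legal group counts makes the constraint explicit. This is a trivial sketch. `valid_group_counts` is a hypothetical helper, and the second element of each printed pair is the KV cache reduction H/G relative to MHA:

```python
def valid_group_counts(n_heads: int) -> list[int]:
    """Group counts G that split n_heads query heads evenly (G must divide H)."""
    return [g for g in range(1, n_heads + 1) if n_heads % g == 0]


for h in (40, 64):
    # (G, KV cache reduction factor H/G) for every legal G
    print(h, [(g, h // g) for g in valid_group_counts(h)])
```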

GQA is not free during prefill. The KV cache reduction matters most at decode time, when the whole cache is read for every generated token. During prefill (processing the input prompt), GQA still computes K and V for every token, just with fewer projection matrices. The compute savings at prefill are modest. Only the K and V projections shrink, to G/H of their MHA cost. The query and output projections, the attention scores, and the feed-forward layers are unchanged. For workloads where prefill dominates (RAG with 8K-token retrieval, code context, document processing), GQA's primary benefit, a smaller KV cache, matters less because the KV cache doesn't persist across requests.
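To put a number on "modest", here is a rough count of multiply-accumulates per token per layer in the dense projections only. It ignores the sequence-length-dependent attention-score math. The shape (d_model 8192, 64 query heads of dim 128, SwiGLU feed-forward width 28672) is assumed to match LLaMA-2 70B's published config, and the 64-KV-head MHA variant is hypothetical:

```python
def dense_macs_per_token(d_model: int, n_heads: int, n_kv_heads: int,
                         head_dim: int, d_ff: int) -> int:
    """Multiply-accumulates per token per layer in the dense projections only."""
    q = d_model * n_heads * head_dim             # query projection (unchanged by GQA)
    kv = 2 * d_model * n_kv_heads * head_dim     # K and V projections: the only GQA saving
    o = n_heads * head_dim * d_model             # output projection (unchanged)
    mlp = 3 * d_model * d_ff                     # SwiGLU gate, up, and down projections
    return q + kv + o + mlp


# Assumed LLaMA-2 70B shape; the 64-KV-head variant is a hypothetical MHA baseline.
cfg = dict(d_model=8192, n_heads=64, head_dim=128, d_ff=28672)
mha = dense_macs_per_token(n_kv_heads=64, **cfg)
gqa = dense_macs_per_token(n_kv_heads=8, **cfg)
print(f"GQA-8 / MHA dense MACs per token: {gqa / mha:.3f}")
```

Under these assumptions GQA-8 removes roughly 12% of the dense-layer work per token, because the feed-forward block dominates.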

Multi-GPU tensor parallelism interacts with G. In tensor-parallel inference (common for 70B+ models), attention heads are sharded across GPUs. Each GPU handles H/N heads where N is the tensor parallelism degree. For GQA to work correctly with tensor parallelism, each GPU's shard of query heads must correspond to an integer number of KV groups. For LLaMA-2 70B with G=8 and TP=8: each GPU handles 64/8=8 query heads and 8/8=1 KV group. Each GPU handles exactly one KV group, which is clean. For TP=4: 16 query heads per GPU, 2 KV groups per GPU — also clean. For TP=5 with H=64, G=8: 12.8 heads per GPU, which doesn't divide evenly. Choosing H and G must account for planned parallelism strategies. This is a non-obvious architectural constraint that matters when scaling.
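The divisibility rule can be checked before launching a deployment. Below is a minimal sketch of the rule as stated above. `tp_shard_plan` is hypothetical, and it deliberately rejects TP > G rather than modeling the KV-head replication some serving stacks fall back to in that case:

```python
def tp_shard_plan(n_heads: int, n_kv_heads: int, tp: int) -> tuple[int, int]:
    """Return (query heads, KV heads) per GPU, or raise if the split is uneven.

    Encodes the rule that each GPU must own whole KV groups. Replicating KV
    heads when tp > n_kv_heads is not modeled; such plans are rejected.
    """
    if n_heads % tp:
        raise ValueError(f"{n_heads} query heads do not split over TP={tp}")
    if n_kv_heads % tp:
        raise ValueError(f"{n_kv_heads} KV groups do not split over TP={tp}")
    return n_heads // tp, n_kv_heads // tp


for tp in (4, 5, 8):
    try:
        print(tp, tp_shard_plan(64, 8, tp))
    except ValueError as err:
        print(tp, "invalid:", err)
```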

Converting LoRA-fine-tuned models to GQA is not straightforward. LoRA adds low-rank updates to attention projection matrices. Suppose you uptrain the base model to GQA and then want to apply a LoRA adapter fine-tuned for the original MHA checkpoint. The K and V adapters no longer fit: their output dimension assumes H K/V heads, while the uptrained checkpoint has G. You need to either re-run LoRA fine-tuning after GQA uptraining (using the uptrained checkpoint as the base), or merge the LoRA adapter first and then do GQA uptraining (which re-trains the merged weights). Neither option is free. If you're operating a model with multiple LoRA adapters, a common production pattern for multi-tenant fine-tuning, GQA conversion multiplies your fine-tuning compute budget.

FlashAttention-2 and GQA. The FlashAttention-2 kernel supports GQA explicitly (support arrived after the original GQA paper). Without kernel support, a naive implementation expands the G cached KV heads to match the H query heads before the attention computation. This materializes a full H-head copy of K and V and gives back much of the memory and bandwidth savings. Confirm that your inference stack's attention kernel supports GQA natively before deploying; without it, you get the quality tradeoff without the performance benefit.
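The naive fallback is easy to picture: repeat each cached K/V head H/G times so a standard MHA kernel can consume it. A minimal NumPy sketch for one layer of a GQA-8, 4K-context cache; `repeat_kv` here is an illustrative helper, not a specific library's API:

```python
import numpy as np


def repeat_kv(kv: np.ndarray, n_heads: int) -> np.ndarray:
    """Naive GQA fallback: copy each of the G cached K/V heads H/G times.

    kv: (G, T, d_k) -> (H, T, d_k). np.repeat returns a new array, so a kernel
    without grouped support ends up reading H heads' worth of K/V, not G.
    """
    g = kv.shape[0]
    assert n_heads % g == 0, "H must be divisible by G"
    # Repeats are contiguous, so head i maps to group i // (H // G).
    return np.repeat(kv, n_heads // g, axis=0)


cached_k = np.zeros((8, 4096, 128), dtype=np.float16)  # one layer: GQA-8, 4K context
expanded_k = repeat_kv(cached_k, 64)
print(cached_k.nbytes / 2**20, "MiB cached ->", expanded_k.nbytes / 2**20, "MiB expanded")
```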

Failure modes in practice

Head collapse at low G. When G is small — G=2 or G=1 — some query heads within a group can learn to effectively ignore the attention output and rely entirely on the feed-forward sublayer for token mixing. When this happens, the remaining query heads in the group carry an asymmetric attention burden, attention entropy becomes lopsided, and the model's effective expressiveness is lower than the architecture suggests. This is diagnosable: log the per-head attention entropy distribution during inference. Healthy multi-head attention shows diverse entropy across heads; collapsed groups show near-zero entropy on some heads and high entropy on others. The fix is either increasing G or adjusting the training regime.
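Below is a minimal sketch of that diagnostic. It assumes you can capture one layer's post-softmax attention probabilities as an (H, T_q, T_k) array. The function names, and the idea of summarizing each group by its max-minus-min entropy spread, are illustrative choices rather than a standard metric:

```python
import numpy as np


def per_head_attention_entropy(attn: np.ndarray, eps: float = 1e-12) -> np.ndarray:
    """Mean entropy (nats) of each head's attention distribution.

    attn: (H, T_q, T_k) probabilities whose last axis sums to 1. Returns (H,).
    """
    ent = -(attn * np.log(attn + eps)).sum(axis=-1)  # entropy per query position
    return ent.mean(axis=-1)                          # average over query positions


def group_entropy_spread(head_entropy: np.ndarray, n_groups: int) -> np.ndarray:
    """Max minus min head entropy within each KV group (contiguous heads per group).

    Large values flag groups where some heads are near-deterministic while others
    are diffuse, the lopsided pattern described above.
    """
    groups = head_entropy.reshape(n_groups, -1)
    return groups.max(axis=1) - groups.min(axis=1)
```

Uniform attention over T keys gives entropy log T and one-hot attention gives 0, so the spread's scale depends on context length. Compare it across checkpoints at a fixed length.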

Uptraining instability from mean-pool initialization. The mean-pooled KV projection weights are a reasonable initialization for uptraining, but the learning dynamics can be fragile if the uptraining data distribution differs significantly from the original pre-training distribution. One failure mode I've observed: teams uptrain an MHA foundation model on a domain-specific corpus to get GQA, then discover the uptraining has shifted the model's behavior on out-of-domain prompts. The mean-pool initialization preserves weight scale, but 5% token budget may be insufficient to fully adapt the KV projections while preserving general capability. Monitoring perplexity on a held-out general-purpose evaluation set throughout uptraining — not just domain eval — is necessary.

Tensor parallelism bugs from incorrect KV head sharding. The most insidious failure mode: a GQA implementation that's incorrect under tensor parallelism due to the KV head resharding logic. The symptom is deterministic quality degradation that only appears at TP≥2, with outputs that are coherent but subtly wrong. In one case I debugged, the KV head shard assignment interleaved heads across groups instead of assigning whole groups to each GPU — query head 0 from group A shared K/V with query head 0 from group B, silently. The outputs were plausible because the model partially compensated during generation. Catching this requires testing TP=1 and TP=N outputs on identical prompts with greedy decoding; any divergence is a KV sharding bug.
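The end-to-end greedy comparison is the real test, but a static check of the shard plan can catch this class of bug earlier. Below is a hypothetical sketch that assumes the contiguous mapping g = i // (H/G) used throughout this post. The interleaved plan models the kind of bug described above, not any specific framework's code:

```python
def kv_group_of(q_head: int, n_heads: int, n_groups: int) -> int:
    """Group of a query head under the contiguous mapping g = i // (H // G)."""
    return q_head // (n_heads // n_groups)


def sharding_is_consistent(q_heads_per_rank, kv_groups_per_rank, n_heads, n_groups) -> bool:
    """True if every query head's KV group lives on the same TP rank as that head."""
    for q_heads, kv_groups in zip(q_heads_per_rank, kv_groups_per_rank):
        if any(kv_group_of(i, n_heads, n_groups) not in kv_groups for i in q_heads):
            return False
    return True


H, G, TP = 64, 8, 4
# Correct: contiguous blocks of query heads, with the matching KV groups on each rank.
q_contig = [list(range(r * H // TP, (r + 1) * H // TP)) for r in range(TP)]
kv_contig = [list(range(r * G // TP, (r + 1) * G // TP)) for r in range(TP)]
# Buggy: query heads dealt round-robin, so each rank holds heads from many groups.
q_interleaved = [list(range(r, H, TP)) for r in range(TP)]
print(sharding_is_consistent(q_contig, kv_contig, H, G))
print(sharding_is_consistent(q_interleaved, kv_contig, H, G))
```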

When not to use GQA

Models below ~30B parameters where KV cache isn't the dominant constraint. LLaMA-2 7B and 13B use standard MHA. For a 7B model (32 heads, 128 head_dim, 32 layers), the KV cache is:

2 × 32 × 128 × 32 × 2 bytes ≈ 524 KB per token
At 4K context: ~2.1 GB per sequence

On a single A100 80GB, model weights occupy ~14 GB in FP16. You have 66 GB for KV cache, room for ~31 concurrent 4K-token sequences without any help from GQA. The return on switching to GQA-8 is a ~4× KV cache reduction, which increases concurrent capacity from 31 to ~125. That's meaningful at scale, but the model quality impact, however small, may matter more than the throughput gain for use cases where the 7B is already quality-constrained.

Prefill-dominated workloads. If you're processing documents, running RAG, or serving code-completion requests with 8K+ token inputs and short outputs (50–100 tokens), the KV cache for the completed output is small compared to the context size — and you're discarding the context KV cache after each request anyway (not serving a stateful session). Your bottleneck is prefill compute, which GQA doesn't improve meaningfully. Profile your TTFT (time to first token) vs. TPOT (time per output token) split. If TTFT dominates, GQA doesn't address your actual bottleneck.
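If your serving stack streams tokens, you can compute the split from client-side timestamps. A minimal sketch; the helper names and the choice to average TPOT over inter-token gaps after the first token are assumptions, not a particular benchmark tool's definition:

```python
def ttft_tpot(request_start: float, token_times: list[float]) -> tuple[float, float]:
    """TTFT and mean TPOT in seconds, from request start and per-token arrival times."""
    if not token_times:
        raise ValueError("no tokens generated")
    ttft = token_times[0] - request_start
    if len(token_times) == 1:
        return ttft, 0.0
    # Mean gap between consecutive tokens after the first one.
    tpot = (token_times[-1] - token_times[0]) / (len(token_times) - 1)
    return ttft, tpot


def ttft_share(request_start: float, token_times: list[float]) -> float:
    """Fraction of end-to-end latency spent before the first token arrives."""
    total = token_times[-1] - request_start
    return (token_times[0] - request_start) / total if total > 0 else 1.0
```

If `ttft_share` sits well above one half across your traffic, prefill is the bottleneck and a smaller KV cache will not move it much.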

When you're evaluating model quality for high-stakes tasks before committing to an architecture. GQA's benchmark results are averages across standard academic benchmarks. For tasks with unusual attention patterns (long-range coreference in very long documents, multi-document comparison, precise numerical reasoning with many intermediate steps), the degradation may be larger than the average suggests. Benchmark on your actual task distribution before selecting G. The right G for a coding assistant (where LLaMA-2 70B GQA-8 is excellent) may not be the right G for a medical record analysis system.

When adapting models with dense adapter ecosystems. If your workflow relies on a large library of LoRA adapters trained against an MHA checkpoint — multi-tenant fine-tuning, per-customer adapters, personalization — GQA uptraining forces you to retrain all adapters. The one-time uptraining cost is manageable; the adapter ecosystem migration is not. In this case, PagedAttention and continuous batching optimizations may yield more practical throughput improvement without requiring adapter retraining.

What the paper actually gives you

GQA's contribution is smaller than it first appears on paper and larger than it appears in benchmarks.

The mathematical idea, sharing K/V heads across groups of query heads, is obvious in retrospect, as most good ideas are. The non-obvious contributions are: (1) that the quality-efficiency frontier is as favorable as it is (G=8 nearly matches MHA while providing an 8× KV cache reduction); (2) that the mean-pool uptraining procedure works well enough to convert existing checkpoints without full retraining; and (3) that the recipe was simple enough to adopt quickly. LLaMA-2 70B and Mistral 7B shipped with it within months of publication, and Gemma 2 and many later models followed.

The practical upshot is that nearly every production deployment of a modern 30B+ model is implicitly a GQA deployment. When you configure vLLM to serve LLaMA-2 70B and rely on PagedAttention's block management, the reason your KV cache is manageable at all is GQA. The 2×A100 LLaMA-2 70B deployment that serves 15 concurrent 4K-token requests would require 4×A100s under MHA at equivalent batch size. That isn't because of any algorithmic inefficiency; it's because H=64 heads per layer produce an immovable memory wall.

For that 70B serving problem I started with: switching from a custom MHA 70B checkpoint to LLaMA-2 70B (which uses GQA-8 natively) increased concurrent request capacity from 11 to 47 at the same 4K-context SLA. The difference wasn't serving infrastructure — it was the number of K/V projection heads.


GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints — Ainslie, Lee-Thorp, de Jong, Zelaski, Sanghai, Xu. Google. EMNLP 2023.