Mixture of Depths: what the paper actually says
Reading Raposo et al. (Google DeepMind, 2024) while profiling per-token latency variance on a production language model.
The profiling session started with an anomaly: our latency distribution had a fat right tail that didn't correlate with output length. We were generating similar-length responses, but some requests took 40% longer than others. After enough instrumentation, the pattern became clear — the expensive requests had longer, more syntactically complex prompts. Shorter, simpler prompts with similar output length were consistently faster. The model was spending the same compute on "the" as it was on "photosynthesis," on punctuation as on multi-step reasoning chains.
This is the problem Mixture-of-Depths is solving. Not the tail latency issue specifically, but the deeper architectural assumption behind it: that uniform compute allocation across all tokens at all layers is the right default.
It isn't, and the paper gives you a concrete mechanism to fix it.
Paper: Mixture-of-Depths: Dynamically allocating compute in transformer-based language models — Raposo, Ritter, Richards, Lillicrap, Humphreys, Santoro, Google DeepMind, 2024.
The problem the paper is actually solving
Standard transformer architecture routes every token through every layer. At each layer, every token participates in self-attention, writes to the KV cache, passes through the feed-forward network, and accumulates in the residual stream. The projection and FFN FLOPs per forward pass scale with N × L, where N is sequence length and L is number of layers, and the attention score computation adds a term that grows with N² × L. Every token. Every time.
This uniform allocation has a calibration problem. Consider a 32-layer model processing the token "The" at the start of a sentence. The model has already encoded "The" in early layers: it's a common token with stable, largely context-independent behavior. Running it through 32 layers of full attention and FFN computation is doing something, but most of it is redundant. The same "The" representation passes through layer 4, layer 8, layer 16, layer 28, barely changing. Contrast this with the token "however" mid-paragraph in an argument. In a causal language model, that token's representation has to integrate the preceding context, track discourse structure, and condition on the local syntactic frame. It plausibly needs all 32 layers.
The paper formalizes this intuition as a routing problem: at each transformer layer, a lightweight router decides which tokens are "capacity" tokens (get the full attention + FFN treatment) and which are "skip" tokens (bypass the block via residual and pass through unchanged). The total compute budget remains predictable because the router is top-k: exactly k tokens are processed at each layer, always. The graph is static in shape even though the routing is dynamic.
The mechanism
The routing decision at each layer uses a learned scalar router: a single linear projection from hidden dimension to a scalar logit, followed by a top-k selection over the sequence.
For a sequence of N tokens at layer l:
- Score: compute logit r_i = W_r · h_i for each token i, where h_i is the token's hidden state
- Select: take the top-k tokens by r_i; let C denote this capacity set (|C| = k)
- Route capacity tokens: run attention and FFN only over tokens in C
- Skip non-capacity tokens: pass them through unchanged via the residual connection
The capacity tokens' self-attention at layer l only attends to other capacity tokens at that layer. Tokens in C write K and V entries for that layer; tokens not in C don't participate in attention at all. The skip tokens' hidden state is unchanged: h_i^{l+1} = h_i^l for i ∉ C.
At the next layer, the router re-scores all N tokens fresh. A token that was skipped at layer 3 might be selected at layer 4. The routing is not cumulative — every layer makes an independent decision.
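The steps above can be sketched in a few lines of plain Python. This is a minimal illustration, not the paper's implementation: `mod_layer`, `toy_block`, and the example vectors are hypothetical names and values, `block_fn` stands in for attention plus FFN over the selected tokens, and causal masking is left out. Scaling the block's update by the router score follows my reading of the paper's update rule (it keeps the router on the gradient path); treat that detail as an assumption if you are working from a different implementation.

```python
from typing import Callable, List, Sequence

Vector = List[float]


def router_scores(hidden: Sequence[Vector], w_r: Vector) -> List[float]:
    """Scalar router logit r_i = W_r . h_i for every token."""
    return [sum(w * x for w, x in zip(w_r, h)) for h in hidden]


def top_k_indices(scores: Sequence[float], k: int) -> List[int]:
    """Indices of the k highest scores, returned in sequence order."""
    ranked = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    return sorted(ranked[:k])


def mod_layer(
    hidden: Sequence[Vector],
    w_r: Vector,
    k: int,
    block_fn: Callable[[List[Vector]], List[Vector]],
) -> List[Vector]:
    """One MoD layer: the block sees only the capacity set C; the rest pass through."""
    scores = router_scores(hidden, w_r)
    capacity = top_k_indices(scores, k)
    # block_fn is a stand-in for attention + FFN over the selected tokens only.
    updates = block_fn([list(hidden[i]) for i in capacity])
    out = [list(h) for h in hidden]  # default: h^{l+1} = h^l (skip via residual)
    for i, delta in zip(capacity, updates):
        # Residual add, with the update scaled by the router score (assumption,
        # see the note above this block).
        out[i] = [x + scores[i] * d for x, d in zip(hidden[i], delta)]
    return out


def toy_block(selected: List[Vector]) -> List[Vector]:
    """Hypothetical block: every selected token's update is the mean of the set."""
    dim = len(selected[0])
    mean = [sum(v[d] for v in selected) / len(selected) for d in range(dim)]
    return [mean[:] for _ in selected]


hidden = [[1.0, 0.0], [0.0, 1.0], [3.0, 1.0], [0.5, 0.5]]
w_r = [1.0, 0.5]
out = mod_layer(hidden, w_r, k=2, block_fn=toy_block)
# Tokens 0 and 2 have the highest scores and are updated; tokens 1 and 3 are unchanged.
```

Stacking layers is just repeated calls to `mod_layer` with a different `w_r` per layer, which is what makes each layer's decision independent of the previous one.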
The capacity fraction κ = k/N is the core hyperparameter. The paper tests configurations including κ = 0.125 (12.5% of tokens processed per layer), κ = 0.25 (25%), and κ = 0.5 (50%). A model with κ = 0.125 processes only 12.5% of tokens in each routed block. The FFN and projection compute drops to 12.5% of a standard block, and the attention score computation, which is quadratic in the number of participating tokens, drops to roughly 1.6% (0.125²), while the other 87.5% of tokens coast through on residuals.
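A rough count makes the per-block saving concrete, and shows that the attention score term shrinks faster than the linear terms because it is quadratic in the number of participating tokens. The sketch below is back-of-the-envelope arithmetic under stated assumptions: a plain two-matmul FFN, no gating, no causal-mask halving, and hypothetical model dimensions. `block_macs` and `mod_fraction` are names I made up for illustration.

```python
def block_macs(n_tokens: int, d_model: int, d_ff: int) -> dict:
    """Approximate multiply-adds for one transformer block (norms and biases ignored)."""
    return {
        "proj": 4 * n_tokens * d_model * d_model,   # Q, K, V and output projections
        "attn": 2 * n_tokens * n_tokens * d_model,  # QK^T scores plus weighted sum over V
        "ffn": 2 * n_tokens * d_model * d_ff,       # two FFN matmuls (assumed ungated)
    }


def mod_fraction(n_tokens: int, d_model: int, d_ff: int, capacity: float) -> dict:
    """Per-component compute of a routed block relative to a dense block."""
    k = int(n_tokens * capacity)
    dense = block_macs(n_tokens, d_model, d_ff)
    routed = block_macs(k, d_model, d_ff)
    return {name: routed[name] / dense[name] for name in dense}


# Hypothetical dimensions: 2048 tokens, d_model = 4096, d_ff = 16384.
frac = mod_fraction(2048, 4096, 16384, capacity=0.125)
# proj and ffn scale with capacity; attn scales with capacity squared.
print(frac)
```

The router's own cost is not included here; it is small relative to these terms but is paid for all N tokens.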
What the skip tokens actually experience
This is the part that tripped me up reading the paper. If 87.5% of tokens skip each layer, what does their hidden state look like by the final layer?
Skip tokens don't stagnate — they simply carry their representation from the previous layer untouched into the next one. If token i skips layers 3, 5, and 7 but is processed at layers 4, 6, and 8, its final representation is shaped by the layers it participated in. The model learns to allocate its processing budget across layers in a way that achieves good representations at layer L, even though individual tokens take non-uniform paths through the depth.
The paper describes this as the model learning "which tokens require additional processing at each depth." Simple tokens with stable representations learn to be skipped frequently. Tokens that anchor discourse structure, introduce new entities, or appear at syntactically complex positions tend to be routed through more layers. The routing isn't hand-coded; the router weight matrix W_r is learned end-to-end.
What the skip tokens don't get: they cannot attend to capacity tokens at the layers they skip, and capacity tokens cannot attend to skip tokens at those layers. The attention at layer l is a reduced computation over the capacity set only. This changes the information flow relative to standard attention and is the primary reason aggressive capacity fractions (κ < 0.125) can degrade quality.
Training stability and routing collapse
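Like mixture-of-experts, MoD has to worry about routing collapse: a router that sends everything down one path and leaves most of the model's capacity unused. Several design choices in the paper bear on this:
Expert-choice top-k routing instead of a load balancing loss: because each routed block selects exactly k tokens, the load per block is fixed by construction, and the paper notes that this removes the need for the auxiliary balancing loss used in token-choice MoE. The paper does use an auxiliary objective, but for a different problem. Top-k over a full sequence is non-causal, so for autoregressive sampling the router is trained with an auxiliary binary cross-entropy loss (or paired with a small auxiliary predictor) so that it can decide from a token's own state whether that token would have landed in the top-k.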
Token merging after selection: the paper experiments with whether skipped tokens at a given layer contribute their representations to the capacity tokens' attention context. In the cleanest formulation (MoD without merging), they don't — the capacity tokens only attend to each other. This keeps the computation graph simple and avoids the scatter-gather overhead of mixing capacity and skip tokens in attention.
Independent per-layer routing: because each layer has its own router, collapse at one layer doesn't cascade. A token locked into always being skipped at layer 8 can still be selected at layer 9.
Training on standard language modeling objectives (next-token prediction) is otherwise unchanged. The MoD layers slot directly into a standard transformer architecture; non-MoD layers can coexist with MoD layers in the same model.
What the performance numbers actually show
The paper's primary comparison is isoFLOP: given the same training FLOP budget, does a MoD model achieve better perplexity than a standard transformer? The answer is yes, with some nuance.
With κ = 0.125 and a model trained to match an isoFLOP baseline:
- The MoD model achieves comparable perplexity to the baseline at the same training compute
- Per-inference-step FLOPs are substantially lower, because 87.5% of tokens skip the attention and FFN compute in each routed block
- Sampling can be upwards of 50% faster per step in the paper's measurements
The 50% figure is sampling speed (how quickly the model steps during autoregressive decoding), not first-token latency. Because most tokens skip most routed blocks, the per-step compute is reduced, and you generate tokens faster.
The key word in the isoFLOP comparison is "comparable." MoD doesn't strictly dominate a standard transformer at every parameter count and every sequence length. At small models (under ~1B parameters), the router overhead and reduced per-layer communication can hurt more than the skipping helps. At larger models with longer contexts, the efficiency gains become more pronounced. The paper's results are most compelling in the 2B+ parameter range.
MoDE (Mixture of Depths and Experts): the paper also introduces MoDE, which combines per-layer token routing (MoD) with per-token expert routing (MoE). At each MoDE layer, tokens are first selected by the depth router; selected tokens are then routed to one of several experts by the expert router. Skipped tokens bypass both. MoDE achieves the efficiency gains of both mechanisms simultaneously, with the compute savings multiplying. The operational complexity also multiplies, which matters for deployment.
Production tradeoffs no one mentions in the benchmark post
The capacity fraction commits you at deployment time. You train a model with κ = 0.125. That κ is baked into the architecture — the number of top-k selections per layer, the routing weights, all of it is tuned together. At serving time, you can't change κ without retraining. If your production traffic distribution shifts and the 87.5% skip rate is too aggressive for your actual inputs, you discover this post-deployment and the fix is retraining. Unlike speculative decoding where you can tune the draft length at runtime, MoD's routing behavior is a training-time decision.
The static compute graph is a double-edged sword. The paper makes a point of emphasizing that the computation graph is static: tensor sizes are known, k is fixed, the shape of every operation is determined by κ and N. This makes MoD easier to compile and optimize than fully dynamic routing. But it means every forward pass allocates and runs the same sized kernels regardless of whether the routing decisions actually vary. If your input distribution is such that the router almost never varies (consistently selects the same positions), you've paid for the router overhead without getting dynamic behavior.
KV cache implications are non-obvious. In standard attention, the KV cache for autoregressive generation stores one K and V vector per token per layer. In MoD, tokens that skip a layer don't write to the KV cache at that layer. For a model with κ = 0.125 and 32 routed layers, a token writes to about 4 layers' KV caches on average (32 × 0.125) rather than all 32. KV cache memory for the routed layers is dramatically reduced, roughly in proportion to κ. For long-context serving where KV cache is the primary memory bottleneck, this is significant. But it also means that when a capacity token at layer l attends to "previous tokens" at that layer, it only sees the K/V entries of other tokens that were also capacity tokens at that layer. The effective attention window is smaller than the sequence length. For most language modeling tasks this is fine; for tasks requiring precise long-range dependencies at specific positions, it's a potential failure mode.
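To put numbers on the KV cache effect, here is a small estimator. It is a sketch under assumptions: the model shape (32 layers, 32 KV heads, head dimension 128, fp16) is hypothetical, and `kv_cache_bytes` is an illustrative helper that computes the expected size given that routed layers store K and V for only a fraction κ of tokens.

```python
def kv_cache_bytes(
    n_tokens: int,
    n_layers: int,
    n_kv_heads: int,
    head_dim: int,
    bytes_per_value: int = 2,            # fp16 / bf16
    capacity: float = 1.0,               # kappa for the routed layers
    routed_layer_fraction: float = 0.0,  # share of layers that use MoD routing
) -> float:
    """Expected KV cache size in bytes for one sequence."""
    per_token_per_layer = 2 * n_kv_heads * head_dim * bytes_per_value  # K and V
    routed_layers = n_layers * routed_layer_fraction
    dense_layers = n_layers - routed_layers
    # Dense layers store every token; routed layers store capacity * n_tokens.
    token_layer_slots = n_tokens * (dense_layers + capacity * routed_layers)
    return per_token_per_layer * token_layer_slots


GIB = 2 ** 30
shape = dict(n_tokens=32768, n_layers=32, n_kv_heads=32, head_dim=128)

dense = kv_cache_bytes(**shape)
all_routed = kv_cache_bytes(**shape, capacity=0.125, routed_layer_fraction=1.0)
alternating = kv_cache_bytes(**shape, capacity=0.125, routed_layer_fraction=0.5)

print(dense / GIB, all_routed / GIB, alternating / GIB)  # 16.0 2.0 9.0
```

Note how much of the saving disappears when only every other layer is routed: the dense layers still cache every token, so the total falls from 16 GiB to 9 GiB rather than to 2 GiB.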
Latency predictability improves, but batch efficiency degrades. The original problem I described, fat-tail latency from complex prompts, improves only partially. Top-k routing fixes the number of processed tokens per layer at k, so the router decides where compute goes within a sequence rather than making an easy request cheaper than a hard one of the same length. In batched inference, a batch containing a mix of simple and complex prompts creates an efficiency problem: all requests in a batch must execute the same number of forward passes (same sequence length, same number of layers). The routing just determines which tokens do FFN/attention at each step. The skip saves compute, but it saves wall-clock time only if your hardware can take advantage of the sparse computation within each batch.
Actually running MoD efficiently requires hardware support for sparse block operations (skip the FFN computation for non-selected tokens). GPU kernels for this are not trivially available. Naive implementations fall back to masking: you compute attention over all tokens, mask out non-selected tokens, zero their contribution. This gets you the correct output but none of the speed. The 50% sampling speedup in the paper requires a custom implementation that actually skips computation for non-selected tokens.
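The difference between masking and actually skipping is easy to see with a per-token stand-in for the FFN. This is a toy sketch, not a kernel: `ffn_masked`, `ffn_gathered`, and `double` are hypothetical names, and the call counter is a proxy for work done. Attention needs the same gather applied to queries, keys, and values, which is where the real implementation effort goes.

```python
from typing import Callable, List, Set, Tuple

Vector = List[float]


def ffn_masked(
    hidden: List[Vector], selected: Set[int], token_fn: Callable[[Vector], Vector]
) -> Tuple[List[Vector], int]:
    """Naive path: compute every token, then throw away the unselected results."""
    calls = 0
    out = []
    for i, h in enumerate(hidden):
        y = token_fn(h)
        calls += 1
        out.append(y if i in selected else list(h))
    return out, calls


def ffn_gathered(
    hidden: List[Vector], selected: Set[int], token_fn: Callable[[Vector], Vector]
) -> Tuple[List[Vector], int]:
    """Gather path: compute only the selected tokens and scatter them back."""
    out = [list(h) for h in hidden]
    calls = 0
    for i in sorted(selected):
        out[i] = token_fn(hidden[i])
        calls += 1
    return out, calls


def double(h: Vector) -> Vector:
    """Hypothetical stand-in for a per-token FFN."""
    return [2.0 * x for x in h]


hidden = [[float(i), 1.0] for i in range(8)]
selected = {5}  # kappa = 1/8
masked_out, masked_calls = ffn_masked(hidden, selected, double)
gathered_out, gathered_calls = ffn_gathered(hidden, selected, double)
# Same output either way; only the gather path does less work.
print(masked_out == gathered_out, masked_calls, gathered_calls)
```

Both paths return identical tensors, which is exactly why a masked implementation can pass every correctness test and still show no speedup.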
Router overhead at 32 layers adds up. Each layer has a router: W_r · h_i for all N tokens. For a 4096-dimensional model with 32 layers and 2048-token sequences: 32 × 2048 × 4096 additional multiplications. Relative to the attention and FFN compute this is small, but it's non-zero and present for every token regardless of whether it's selected. In κ = 0.125 regimes, the router is computing scores for 87.5% of tokens that will be skipped. The paper's 50% speedup accounts for this overhead, but custom implementations that don't vectorize the router efficiently can eat into the benefit.
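The router arithmetic is worth running once against the compute it gates. The sketch below assumes a plain two-matmul FFN with d_ff = 4 × d_model = 16384, which is my assumption and not a figure from the paper; `router_macs` and `ffn_macs` are illustrative helpers.

```python
def router_macs(n_layers: int, n_tokens: int, d_model: int) -> int:
    """One d_model-wide dot product per token per layer."""
    return n_layers * n_tokens * d_model


def ffn_macs(n_layers: int, n_tokens: int, d_model: int, d_ff: int,
             capacity: float = 1.0) -> int:
    """Two FFN matmuls for the tokens that are actually processed."""
    k = int(n_tokens * capacity)
    return n_layers * 2 * k * d_model * d_ff


router = router_macs(32, 2048, 4096)
ffn_mod = ffn_macs(32, 2048, 4096, 16384, capacity=0.125)
ratio = router / ffn_mod

print(router)  # 268435456
print(ratio)   # 0.000244140625, i.e. 1/4096 of the routed FFN work
```

On paper the router is a rounding error next to even the reduced FFN. The practical cost is less about FLOPs and more about the extra top-k, gather, and scatter steps it introduces at every layer.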
Failure modes in practice
Distribution shift degrades routing quality silently. The router learns to identify which token positions need processing for your training distribution. When production input distributions shift — new topics, new languages, unusual formatting — the router's learned heuristics may misfire. Simple tokens that look unusual to the model may be routed as capacity tokens (wasting compute), and complex tokens that pattern-match to common constructions may be skipped (hurting quality). Unlike acceptance rate in speculative decoding, there's no direct metric to monitor routing quality at inference time. You'd need to track downstream task metrics and look for degradation.
The self-attention coverage gap for skipped tokens. When token i is skipped at layers 3-7 and processed at layer 8, its hidden state at layer 8 input is its hidden state from layer 2 output: it has missed five layers of contextual updating. At layer 8, when token i is a capacity token, it attends to other capacity tokens at layer 8. But those other capacity tokens' K/V entries at layer 8 can reflect up to five more layers of updates than token i's hidden state. This creates a representational mismatch: the attending token is "behind" the attended tokens in terms of depth of processing. The model learns to handle this during training, but for tasks requiring tight layer-by-layer coherence (complex multi-step arithmetic, highly structured outputs), this mismatch can produce errors that are hard to diagnose.
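If you can log routing decisions, the depth mismatch is straightforward to quantify. The sketch below assumes you have a boolean routing matrix per sequence (a hypothetical logging format) and counts how many blocks have updated each token before a given layer. Layers are 0-indexed here, so index 7 is "layer 8" in the prose.

```python
from typing import List


def processed_depth(routing: List[List[bool]]) -> List[List[int]]:
    """routing[l][i] is True if token i was a capacity token at layer l.

    Returns depth[l][i]: how many blocks have updated token i before layer l.
    """
    seen = [0] * len(routing[0])
    depth = []
    for layer in routing:
        depth.append(seen[:])
        seen = [s + int(flag) for s, flag in zip(seen, layer)]
    return depth


def depth_gap(routing: List[List[bool]], layer: int, token: int) -> int:
    """Largest lead any other capacity token at `layer` has over `token`."""
    depth = processed_depth(routing)
    others = [i for i, flag in enumerate(routing[layer]) if flag and i != token]
    return max(depth[layer][i] - depth[layer][token] for i in others)


# Token 0 is processed at every layer. Token 1 is processed at layers 1-2,
# skipped at layers 3-7, and processed again at layer 8. Token 2 always skips.
routing = (
    [[True, True, False] for _ in range(2)]
    + [[True, False, False] for _ in range(5)]
    + [[True, True, False]]
)

print(processed_depth(routing)[7])  # [7, 2, 0]
print(depth_gap(routing, 7, 1))     # 5
```

A histogram of `depth_gap` over real traffic would tell you whether the five-layer case is typical or a rare corner, which the paper's aggregate loss numbers cannot.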
Combining with speculative decoding gets complicated. Speculative decoding only pays off when the draft model's output distribution closely matches the target's. If the target model uses MoD and the draft model doesn't (or vice versa), the two can diverge on exactly the tokens where routing changes the target's predictions, and acceptance rates may drop. A draft model trained with the same κ as the target is one plausible setup, but it requires running both MoD models simultaneously, each with its own routing overhead. The operational complexity multiplies faster than the benefits.
Variants worth knowing
Layer-wise vs. alternating MoD: the paper tests both applying MoD to all layers and alternating MoD layers with standard layers (e.g., MoD at every other layer). Alternating reduces the impact of the skip behavior on long-range information flow — skip tokens at a MoD layer get updated at the adjacent standard layer. This makes the quality-efficiency tradeoff less aggressive and is easier to tune. For production deployments where you're nervous about the coverage gap, alternating MoD is a lower-risk starting point.
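A quick way to compare the two layouts is to count how many blocks an average token actually passes through. The helpers below are illustrative (`routed_layers` and `expected_blocks_per_token` are names I made up), and the choice of routing the odd-indexed layers is an assumption about which half alternates.

```python
from typing import List


def routed_layers(n_layers: int, every: int = 2, offset: int = 1) -> List[int]:
    """Indices of layers that use MoD routing; the rest are standard blocks."""
    return [l for l in range(n_layers) if l % every == offset % every]


def expected_blocks_per_token(n_layers: int, capacity: float, every: int = 2) -> float:
    """Average number of blocks that update a token across the full depth."""
    n_routed = len(routed_layers(n_layers, every))
    return (n_layers - n_routed) + capacity * n_routed


print(routed_layers(8))                                 # [1, 3, 5, 7]
print(expected_blocks_per_token(32, 0.125))             # 18.0 with alternating layers
print(expected_blocks_per_token(32, 0.125, every=1))    # 4.0 with every layer routed
```

With alternating layers a token is never more than one block away from its last update, which bounds the depth mismatch at the cost of giving up most of the compute saving.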
MoD with token merging: an alternative to the clean skip formulation involves merging skipped tokens' representations into the nearest selected token before attention, then un-merging afterward. This preserves more information flow but adds scatter-gather overhead and complicates the computation graph. The paper finds the clean skip formulation competitive without the merging overhead.
Integration with speculative decoding via Medusa: Medusa adds speculative heads directly to the target model rather than requiring a separate draft model. Combining MoD for per-layer efficiency with Medusa for autoregressive efficiency is architecturally cleaner than combining MoD with a separate draft model. Both mechanisms modify the base model; neither requires a second full model in memory.
When not to use Mixture of Depths
When you need verifiable attention coverage. Tasks where you need to guarantee that specific token pairs interact — structured extraction from documents, precise reference resolution, tasks requiring provably full information flow — are poorly served by a mechanism that conditionally skips attention. You can't inspect the routing after the fact and confirm that the important tokens were selected at every relevant layer.
When latency, not throughput, is your SLO. MoD saves compute per token, but the time savings only materialize if your hardware can execute sparse attention efficiently. On standard GPU clusters running standard kernels, MoD reduces FLOPs but not necessarily wall-clock time. If you're targeting p50 latency (not throughput), the kernel support requirements are prohibitive unless you're building on a framework that already implements sparse block operations for MoD.
When you're serving a single model across highly varied task distributions. A model trained with MoD on a single task type learns routing appropriate for that task. General-purpose assistants facing diverse query distributions — coding, math, creative writing, factual Q&A — have no single routing pattern that's optimal across all query types. Uniform compute (standard transformer) is a reasonable default when you can't characterize your distribution. MoD is most compelling when you have a well-characterized, homogeneous input distribution.
Before you can profile token-level compute utilization. The performance argument for MoD is that most tokens don't need deep processing. If you can't measure this on your actual model and data — if you don't have tooling to observe layer-by-layer representation change per token — you're deploying an optimization for a problem you haven't confirmed you have. Build the measurement infrastructure first.
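One cheap starting point for that measurement is the per-token change in hidden state between consecutive layers of your existing dense model. The sketch below uses cosine distance as the change metric, which is my choice and a crude proxy: it ignores changes in norm, and the 0.01 threshold is arbitrary. It assumes you can dump hidden states as nested lists and that no vector is all zeros.

```python
import math
from typing import List

Vector = List[float]


def cosine(a: Vector, b: Vector) -> float:
    """Cosine similarity of two nonzero vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(x * x for x in b))
    return dot / (norm_a * norm_b)


def layer_change(hidden_by_layer: List[List[Vector]]) -> List[List[float]]:
    """hidden_by_layer[l][i] is token i's hidden state after layer l.

    Returns change[l][i] = 1 - cos(h_i^l, h_i^{l+1}).
    """
    return [
        [1.0 - cosine(p, n) for p, n in zip(prev, nxt)]
        for prev, nxt in zip(hidden_by_layer, hidden_by_layer[1:])
    ]


def low_change_fraction(changes: List[List[float]], threshold: float = 0.01) -> float:
    """Share of (layer, token) pairs whose representation barely moved."""
    flat = [c for layer in changes for c in layer]
    return sum(c < threshold for c in flat) / len(flat)


# Toy dump: token 0 only changes in scale, token 1 rotates at every layer.
dump = [
    [[1.0, 0.0], [1.0, 0.0]],
    [[1.0, 0.0], [0.0, 1.0]],
    [[2.0, 0.0], [1.0, 0.0]],
]
changes = layer_change(dump)
print(changes)                       # [[0.0, 1.0], [0.0, 1.0]]
print(low_change_fraction(changes))  # 0.5
```

If the low-change fraction on your real traffic is nowhere near 1 − κ for the κ you are considering, that is a signal to pick a less aggressive capacity before spending a training run on it.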
For models under 1B parameters. The paper's strongest results are at larger scales. At smaller models, the router overhead is proportionally larger, the skip savings are smaller (fewer parameters per layer to skip), and the quality impact of reduced attention coverage is more pronounced. Below 1B parameters, standard architecture with aggressive quantization typically dominates MoD in the efficiency-quality tradeoff.
What the paper actually gives you
MoD is a principled answer to a real problem: uniform compute allocation across all tokens is wasteful, and the waste is proportional to the fraction of your sequence that's "easy." For models with long contexts and skewed token complexity distributions — which describes most production language model workloads — this waste is substantial.
The mechanism is cleaner than it initially appears. A scalar router per layer, top-k selection, skip via residual — it's a few hundred lines of code on top of a standard transformer. The computation graph is static (k is fixed), which keeps it compilable and hardware-efficient in a way that fully dynamic routing isn't. The paper's isoFLOP results show you can get comparable quality to a standard model while burning fewer FLOPs per inference step.
The catch is that the efficiency benefit requires hardware support for sparse block computation, and the right κ is a training-time commitment you can't adjust post-deployment. The 50% sampling speedup headline requires a custom kernel implementation and a deployment setup where you can actually run sparse FFN and attention blocks. Off the shelf, on standard inference frameworks without MoD-specific kernels, you get the FLOP reduction on paper but not the wall-clock improvement.
The latency profiling issue that started this post is illustrative: MoD doesn't directly solve tail latency variance from complex inputs. What it does is compress the compute difference between easy and hard tokens — easy tokens do less work at each layer. The tail might narrow, but the mechanism is indirect, and without dense instrumentation you can't confirm it's helping. Measure first. Then train with MoD if the measurement tells you something useful.
Mixture-of-Depths: Dynamically allocating compute in transformer-based language models. Raposo, Ritter, Richards, Lillicrap, Humphreys, Santoro. Google DeepMind, 2024.