Switch Transformers: what the sparse MoE scaling paper actually says
Reading Fedus, Zoph & Dean (Google Brain, 2021) after Mixtral-8x7B started OOMing during fine-tuning in ways that T5-XXL never did.
The question that sent me to this paper wasn't "how do I train a trillion-parameter model." It was more mundane: why did a 46.7B-parameter Mixtral model behave like a ~13B model in per-token compute (while still needing memory for all 46.7B parameters), but more like a 46.7B model in quality? What is actually happening inside these "mixture of experts" architectures that everyone started shipping in 2023?
The answer lives in Switch Transformers — "Switch Transformers: Scaling to Trillion Parameter Models with Simple and Efficient Sparsity," Fedus, Zoph, and Dean, Google Brain, JMLR 2022 (arxiv 2021). This paper did two things: it made sparse MoE training simple enough that it actually works reliably, and it established the routing design space that every subsequent MoE model has worked within. Mixtral, DeepSeek-V2, Gemini — all are refinements on the routing decisions this paper made.
The problem: dense models are wasteful
A standard transformer uses all its parameters for every token. That's the "dense" in "dense transformer." If you have a 70B parameter model with 80 FFN layers, every token routes through all 80 FFN sublayers, and each sublayer is fully active.
This is straightforwardly wasteful from a compute perspective. Not every token needs the same computation. The token "the" and the token "gradient" don't require the same knowledge. But dense architectures can't distinguish — every parameter is activated for every token.
The Mixture of Experts idea (Shazeer et al. 2017, "Outrageously Large Neural Networks") was to replace a single feed-forward block with N parallel feed-forward "experts," plus a routing function that selects which expert(s) process each token. If you route each token to 2 of 64 experts, you've multiplied that layer's parameter count by 64 (64 experts instead of 1), but each token only uses 2/64 = 1/32 of those expert parameters on its forward pass.
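The 2-of-64 arithmetic is easy to get backwards, so here is a minimal Python sketch of it. It assumes every expert is an identical FFN and ignores the small router matrix; `moe_expert_params` and the 100M-parameter FFN size are hypothetical, chosen only to make the ratios visible.

```python
def moe_expert_params(ffn_params: int, num_experts: int, top_k: int) -> tuple:
    """Return (stored expert parameters, expert parameters used per token) for one MoE layer.

    Hypothetical helper: assumes every expert is an identical FFN with `ffn_params`
    weights and ignores the small router matrix.
    """
    total = ffn_params * num_experts  # every expert must be stored
    active = ffn_params * top_k       # only the selected experts run for a given token
    return total, active


# Route each token to 2 of 64 experts, with a hypothetical 100M-parameter FFN.
total, active = moe_expert_params(ffn_params=100_000_000, num_experts=64, top_k=2)
print(total // 100_000_000)  # 64: the layer stores 64x the weights of a single FFN
print(active / total)        # 0.03125: each token touches 1/32 of them
```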
The problem with the 2017 approach was instability. The routing used noisy top-k gating with a separate auxiliary loss to prevent all tokens collapsing onto one expert. It worked, but it was complicated and training was fragile. Teams trying to replicate it reported inconsistent results.
Switch Transformers makes one key simplification: route to exactly 1 expert, always. No noise. No complex gating. Just a linear layer, a softmax, an argmax. This is the "switch" in Switch Transformer.
The architecture, precisely
Each Switch layer replaces an FFN sublayer. The structure:
- Router: a learned linear projection W_r from hidden dimension d_model to N logits, followed by softmax. For each token x, the router produces a probability distribution over N experts.
- Expert selection: argmax of the router probabilities. Token x goes to expert k, where k = argmax(W_r · x).
- Expert computation: the selected expert (a standard FFN with ReLU) processes the token. The output is scaled by the router probability: output = router_prob_k * expert_k(x).
- Residual: the expert output is added back to the residual stream, as in standard transformer FFN sublayers.
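The four steps above fit in a few lines. This is a pure-Python illustration for a single token, not the paper's Mesh TensorFlow code: `route_top1`, `switch_layer`, and the toy `make_ffn` expert (ReLU then a scalar weight, standing in for a real two-matrix FFN) are hypothetical names, and batching, capacity limits, and dropped tokens are left out.

```python
import math
from typing import Callable, List, Sequence, Tuple

Vector = List[float]


def softmax(logits: Sequence[float]) -> Vector:
    m = max(logits)  # subtract the max so exp() stays in range
    exps = [math.exp(l - m) for l in logits]
    total = sum(exps)
    return [e / total for e in exps]


def route_top1(x: Vector, W_r: Sequence[Vector]) -> Tuple[int, float]:
    """Router: logits = W_r . x (one row per expert), softmax, then argmax."""
    logits = [sum(w * xi for w, xi in zip(row, x)) for row in W_r]
    probs = softmax(logits)
    k = max(range(len(probs)), key=probs.__getitem__)
    return k, probs[k]


def make_ffn(scale: float) -> Callable[[Vector], Vector]:
    """Toy stand-in for an expert FFN: ReLU followed by a scalar weight."""
    return lambda v: [scale * max(0.0, vi) for vi in v]


def switch_layer(x: Vector, W_r: Sequence[Vector],
                 experts: Sequence[Callable[[Vector], Vector]]) -> Vector:
    k, p = route_top1(x, W_r)
    y = experts[k](x)                             # only the selected expert runs
    return [xi + p * yi for xi, yi in zip(x, y)]  # gate by router prob, then add the residual


W_r = [[1.0, 0.0], [0.0, 1.0], [0.0, 0.0]]  # hypothetical router weights: N = 3, d_model = 2
experts = [make_ffn(1.0), make_ffn(2.0), make_ffn(3.0)]
x = [1.0, 2.0]
k, p = route_top1(x, W_r)
print(k)  # 1: the second logit (2.0) is the largest
print(switch_layer(x, W_r, experts))
```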
Not every FFN sublayer in the model is a Switch layer: the paper places Switch layers in every other transformer block, alternating with dense FFN layers.
The full Switch-Base model has 128 experts per Switch layer, 12 Switch layers out of 24 total transformer blocks, and 7.4B total parameters — with ~223M parameters active per token (comparable to T5-Base).
The capacity problem, and how they solved it
k=1 routing creates an immediate practical problem: what if most tokens get routed to the same expert?
If you have 128 experts and 1024 tokens in a batch, the fair share is 8 tokens per expert. If 200 tokens all route to expert 3, expert 3 has to process 200 tokens while the other 127 experts split the remaining 824, about 6.5 each. This is fine in theory but breaks parallelism in practice. In a distributed setting where each expert lives on a different device, you need expert 3 to hold a buffer for 200 tokens while other devices wait. This blows up memory allocation and synchronization.
The solution is expert capacity: each expert is allocated a fixed buffer of C tokens, where:
C = (tokens_per_batch / num_experts) * capacity_factor
A capacity_factor of 1.0 means each expert gets exactly its "fair share" of tokens. A capacity_factor of 1.25 means each expert gets a 25% buffer above fair share.
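The capacity formula can be checked against the 1024-token, 128-expert example from earlier. This sketch assumes capacity is rounded down to whole token slots (implementations differ on the rounding rule), and `expert_capacity` is a hypothetical helper.

```python
def expert_capacity(tokens_per_batch: int, num_experts: int, capacity_factor: float) -> int:
    """C = (tokens_per_batch / num_experts) * capacity_factor, rounded down (an assumption)."""
    return int(tokens_per_batch / num_experts * capacity_factor)


# 1024 tokens over 128 experts: the fair share is 8 tokens per expert.
for cf in (1.0, 1.25, 2.0):
    print(cf, expert_capacity(1024, 128, cf))  # 8, 10, 16 slots per expert

# The skewed batch: 200 tokens choose expert 3, which has 10 slots at capacity_factor=1.25.
overflow = max(0, 200 - expert_capacity(1024, 128, 1.25))
print(overflow)  # 190 tokens skip expert 3
```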
If more than C tokens are routed to an expert, the overflow tokens are dropped: their hidden states are passed through unchanged via the residual connection, skipping the expert computation entirely. The token is processed; it just doesn't go through any expert for that layer.
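Dropping is easiest to see as buffer filling. The sketch below fills each expert's fixed-size buffer first-come-first-served in token order (an assumption; the priority order is an implementation detail) and reports which tokens overflow and therefore take only the residual path. `dispatch_top1` is a hypothetical name, and real implementations build dense dispatch tensors rather than Python lists.

```python
from typing import Dict, List, Sequence, Tuple


def dispatch_top1(assignments: Sequence[int], num_experts: int,
                  capacity: int) -> Tuple[Dict[int, List[int]], List[int]]:
    """assignments[t] is the expert chosen for token t.

    Returns (buffers, dropped): buffers[e] lists the tokens expert e processes;
    dropped lists tokens whose output is just their input (residual only).
    """
    buffers: Dict[int, List[int]] = {e: [] for e in range(num_experts)}
    dropped: List[int] = []
    for t, e in enumerate(assignments):
        if len(buffers[e]) < capacity:
            buffers[e].append(t)
        else:
            dropped.append(t)  # expert e is full: skip the expert computation
    return buffers, dropped


# 8 tokens, 2 experts, capacity 3: expert 0 is oversubscribed.
buffers, dropped = dispatch_top1([0, 0, 1, 0, 0, 1, 0, 1], num_experts=2, capacity=3)
print(buffers)  # {0: [0, 1, 3], 1: [2, 5, 7]}
print(dropped)  # [4, 6]
```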
This is a deliberate design choice that trades computation quality for hardware efficiency. The paper uses capacity_factor=1.25 during training and capacity_factor=2.0 during inference. At inference you can afford a larger buffer because there are no activations to keep for the backward pass and no optimizer state competing for device memory.
The token drop rate is a metric you actually want to track. High drop rates (>10–15% of tokens) indicate that load balancing is failing — tokens are routing unevenly, some experts are being systematically overwhelmed, and model quality degrades because those tokens aren't receiving expert computation.
The load balancing loss
The router has a differentiation problem. You want uniform token distribution across experts. But the quantity you'd naturally optimize — "fraction of tokens routed to each expert" — isn't differentiable because argmax isn't differentiable.
The paper's solution uses a differentiable approximation of the fraction:
L_aux = α * N * Σ_{i=1}^{N} f_i * P_i
Where:
- f_i = fraction of tokens dispatched to expert i in the batch (computed via token counts, not differentiable)
- P_i = mean of the router's softmax probability for expert i across all tokens (differentiable)
- N = number of experts
- α = auxiliary loss coefficient, set to 0.01 in the paper
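The product f_i * P_i is high when both many tokens route to expert i AND the router assigns high probability to expert i. Only P_i carries gradient; f_i acts as a per-expert weight, so the router is pushed hardest to lower its probability on the experts currently receiving the most tokens. Minimizing this sum pushes toward uniform routing.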
Multiplying by N keeps the loss magnitude constant as the number of experts changes: under perfectly uniform routing f_i = P_i = 1/N, so Σ f_i * P_i = 1/N and L_aux = α for any N. The α = 0.01 is small enough that the auxiliary loss doesn't overwhelm the primary language modeling loss but large enough to prevent expert collapse.
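Here is the auxiliary loss computed directly from the definitions of f_i and P_i, as a plain-Python sketch over a batch of router probabilities. In a real framework P_i would be a differentiable tensor and f_i a constant; `load_balancing_loss` and the two toy batches are illustrative assumptions.

```python
from typing import List, Sequence


def load_balancing_loss(router_probs: Sequence[Sequence[float]], alpha: float = 0.01) -> float:
    """L_aux = alpha * N * sum_i f_i * P_i for one batch (plain floats, no autodiff).

    router_probs[t][i] is the router's softmax probability of expert i for token t.
    """
    T = len(router_probs)
    N = len(router_probs[0])
    counts: List[int] = [0] * N
    for p in router_probs:
        counts[max(range(N), key=p.__getitem__)] += 1            # top-1 (argmax) assignment
    f = [c / T for c in counts]                                   # fraction of tokens per expert
    P = [sum(p[i] for p in router_probs) / T for i in range(N)]  # mean router probability
    return alpha * N * sum(fi * Pi for fi, Pi in zip(f, P))


balanced = [[0.7 if i == t else 0.1 for i in range(4)] for t in range(4)]  # token t prefers expert t
collapsed = [[0.7, 0.1, 0.1, 0.1] for _ in range(4)]                     # every token prefers expert 0
print(round(load_balancing_loss(balanced), 4))   # 0.01: uniform routing gives L_aux = alpha
print(round(load_balancing_loss(collapsed), 4))  # 0.028: collapse is penalized
```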
Getting α right matters more than it looks. Too small and you get expert collapse (training is stable but most tokens route to 2–3 experts, capacity is wasted). Too large and the model over-invests in routing uniformity at the expense of model quality — the experts become interchangeable and specialization disappears.
Training instability and the low-precision problem
The main reason the 2017 MoE work didn't become standard practice was training instability. Switch Transformers still has this problem, but the paper characterizes it precisely and gives concrete fixes.
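The instability mechanism: the router computes softmax(W_r · x), a 128-way softmax when N = 128. The paper trains in bfloat16, which keeps float32's exponent range but has only 8 bits of precision (7 stored mantissa bits). The exponentiation inside the softmax is sensitive to that round-off, so a bfloat16 router produces imprecise probabilities and gradients, the errors compound, and training can diverge. In float16 the failure is more abrupt: exp() of an unshifted logit above about 11.1 exceeds float16's maximum of 65504, overflows to infinity, and the gradient is garbage.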
The paper's fixes:
1. Selective float32 for routing. Even when the rest of the model trains in bfloat16, cast the router's input to float32 and compute the routing (logits and softmax) in float32. Cast back to bfloat16 before multiplying by the expert output. This adds minimal compute overhead, and the paper reports that it matches a full float32 run's quality at bfloat16 speed.
2. Smaller weight initialization. The paper draws weight matrices from a truncated normal with σ = sqrt(s/n), where n is the fan-in and s is a scale hyperparameter, and recommends cutting the default Transformer scale from s = 1.0 to s = 0.1 across the model, router included. That shrinks initial weights by about √10 ≈ 3.2×. Smaller initial weights produce router logits with smaller magnitude, reducing early instability and giving the auxiliary loss more time to establish load balance before routing decisions crystallize.
3. Expert dropout. When fine-tuning, apply a higher dropout rate inside the expert FFNs (0.4 in the paper's best setting) than in the rest of the model (0.1). The intuition is that individual experts should generalize, not memorize. Expert dropout reduces the per-expert overfitting that makes fine-tuning on small datasets unreliable.
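The first fix can be sketched without any ML framework by emulating bfloat16 rounding. `to_bfloat16` rounds a value to the nearest bfloat16, and `route_selective_precision` does the softmax and argmax at higher precision and downcasts only the gate. Both are hypothetical helpers; Python floats are float64 and stand in for float32 here, and the paper upcasts the router input before the matmul, whereas this sketch starts from the logits for brevity.

```python
import math
import struct
from typing import List, Sequence, Tuple


def to_bfloat16(x: float) -> float:
    """Round x to the nearest bfloat16 value (round-half-to-even), for finite float32-range x.

    bfloat16 is the top 16 bits of a float32: same 8 exponent bits, only 7 stored mantissa bits.
    """
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits += 0x7FFF + ((bits >> 16) & 1)  # round on the 16 low bits that bfloat16 discards
    return struct.unpack("<f", struct.pack("<I", bits & 0xFFFF0000))[0]


def softmax_high_precision(logits: Sequence[float]) -> List[float]:
    # Python floats are float64; here they stand in for the float32 router computation.
    m = max(logits)
    exps = [math.exp(l - m) for l in logits]
    s = sum(exps)
    return [e / s for e in exps]


def route_selective_precision(logits: Sequence[float]) -> Tuple[int, float]:
    """Softmax and argmax at high precision; only the gate is cast down to bfloat16."""
    probs = softmax_high_precision([float(l) for l in logits])
    k = max(range(len(probs)), key=probs.__getitem__)
    return k, to_bfloat16(probs[k])  # downcast just before multiplying with the expert output


print(to_bfloat16(1.0 + 2**-9))  # 1.0: steps smaller than 2**-7 vanish near 1.0 in bfloat16
k, gate = route_selective_precision([2.0, 1.5, -1.0])
print(k, gate)                   # 0 0.60546875 (the float64 probability is about 0.6037)
```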
These three changes together take Switch Transformer training from "frequently diverges" to "reliable enough to be the standard approach." But they need to be applied together — selective float32 alone isn't sufficient if initialization is too large.
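The initialization change in item 2 is a one-line difference in the scale s. This sketch draws weights from a normal with σ = sqrt(s/n) and redraws samples beyond 2σ; the 2σ truncation point, the fan-in of 768, the 8-expert router shape, and `truncated_normal_init` are assumptions for illustration, not values from the paper.

```python
import math
import random
from typing import List


def truncated_normal_init(fan_in: int, count: int, scale: float = 0.1, seed: int = 0) -> List[float]:
    """Sample `count` weights with sigma = sqrt(scale / fan_in), redrawing any beyond 2 sigma.

    scale=1.0 is the default Transformer setting the paper starts from; scale=0.1 is its
    recommended 10x reduction. The 2-sigma truncation point is an assumption of this sketch.
    """
    sigma = math.sqrt(scale / fan_in)
    rng = random.Random(seed)
    weights: List[float] = []
    while len(weights) < count:
        w = rng.gauss(0.0, sigma)
        if abs(w) <= 2 * sigma:
            weights.append(w)
    return weights


d_model = 768  # hypothetical fan-in
router_init = truncated_normal_init(fan_in=d_model, count=8 * d_model)  # e.g. 8 experts x d_model
# How much smaller initial weights are than at the default scale: sqrt(10), about 3.16x.
print(math.sqrt(1.0 / d_model) / math.sqrt(0.1 / d_model))
```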
What the performance numbers actually show
The primary result: Switch-Base achieves 7× speedup over T5-Base in pre-training steps-to-quality. Both models have roughly the same FLOPs per token; the Switch model just has more parameters and therefore more capacity per FLOP.
This is the key frame for understanding MoE speedups: you're not reducing compute per token, you're increasing parameter count without increasing compute per token. A Switch model with 7B total parameters but ~220M active parameters matches or beats a dense T5-Large (well under 1B parameters, but about 3.5× the FLOPs per token) while having the same per-token compute as T5-Base.
The paper's fine-tuning comparisons are FLOP-matched: Switch-Base against T5-Base, Switch-Large against T5-Large. The sparse models come out ahead on most downstream tasks, SuperGLUE included, with the AI2 Reasoning Challenge as the notable exception. This is the quality-per-FLOP efficiency that makes MoE attractive.
Fine-tuning still needs care. A Switch model has far more parameters than its FLOP-matched dense counterpart, so it can overfit small downstream datasets. The paper's remedy is expert dropout: keeping dropout at 0.1 in non-expert layers while raising it to 0.4 inside the experts gave its best results on the smaller tasks, whereas simply increasing dropout everywhere made things worse. The sensitivity is real: the right rates depend on how much fine-tuning data you have.
The largest model in the paper, Switch-C, has 2048 experts and about 1.6 trillion total parameters. It is built with expert parallelism alone (no model parallelism), spreading its experts across TPU cores, and the paper's trillion-parameter models reach about a 4× pre-training speedup over T5-XXL. But it requires a large TPU pod in a specific topology to achieve this, not a model you can reproduce on a standard cluster.
Production tradeoffs no one mentions in the benchmark post
All-to-all communication is the real bottleneck. In a distributed MoE setup, each expert lives on a different device. Dispatching tokens to their experts requires all-to-all communication: every device sends tokens to every other device (for the tokens that got routed to remote experts) and receives tokens from every other device. Then after expert computation, you do a second all-to-all to gather results.
This means each Switch layer requires two all-to-all collectives whose message volume scales with O(N * d_model * capacity), where capacity is the per-expert buffer each device fills. For 128 experts across 128 GPUs with d_model = 1024 and capacity_factor = 1.25, and assuming 128 tokens per device, that is roughly 160K values per device per all-to-all (1.25 × 128 × 1024 = 163,840), or about 320K per Switch layer per batch. At 12 Switch layers, this is non-trivial: the all-to-all can become the wall-clock bottleneck, especially if interconnect bandwidth (NVLink, InfiniBand) is limited.
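The volume estimate is simple arithmetic once you notice that the number of experts cancels: each device fills N buffers of (tokens_per_device / N) × capacity_factor slots. The sketch assumes 128 tokens per device and bfloat16 values (2 bytes), and it counts padded slots, since buffer shapes are fixed. `all_to_all_elements` is a hypothetical helper, and the result is an upper bound that ignores slots destined for a local expert.

```python
def all_to_all_elements(tokens_per_device: int, d_model: int, capacity_factor: float) -> int:
    """Values one device sends in a single dispatch all-to-all for one Switch layer.

    num_experts buffers * (tokens_per_device / num_experts * capacity_factor) slots * d_model
    simplifies to capacity_factor * tokens_per_device * d_model. Padded slots are counted
    because buffer shapes are fixed; slots for a local expert are not subtracted.
    """
    return int(capacity_factor * tokens_per_device * d_model)


per_collective = all_to_all_elements(tokens_per_device=128, d_model=1024, capacity_factor=1.25)
print(per_collective)  # 163840, the "roughly 160K" figure
# Dispatch + combine = 2 all-to-alls per Switch layer, 12 Switch layers, 2 bytes per bfloat16 value.
mib_per_step = per_collective * 2 * 12 * 2 / 2**20
print(mib_per_step)    # 7.5 MiB per device per forward pass; the backward pass adds more
```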
The implication: MoE models are more sensitive to topology than dense models. A Switch Transformer trained on 128 TPU cores with 900 GB/s interconnect will have a different effective FLOP/s than the "same" model running on 128 A100s connected via InfiniBand at 200 GB/s. The benchmarks in the paper are TPU benchmarks. GPU-side performance varies.
Expert specialization is fragile early in training. For the first several hundred steps, experts don't specialize — the routing loss hasn't had time to establish stable patterns and the router is close to random. During this window, expert utilization can be wildly uneven, drop rates are high, and the training loss has a characteristic "instability bump" before stabilizing. If you checkpoint during this window and resume from it, the routing can re-randomize. Checkpointing only after the routing has stabilized (watch the expert utilization distribution) avoids unnecessary recovery pain.
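One way to watch the utilization distribution is to reduce each layer's routing decisions to a few scalars and log them every few hundred steps. This is a hypothetical helper, not a framework API: `max_load_ratio` compares the busiest expert to its fair share, and `normalized_entropy` is 1.0 for perfectly even routing and 0.0 when everything lands on one expert.

```python
import math
from collections import Counter
from typing import Dict, Sequence


def utilization_summary(assignments: Sequence[int], num_experts: int) -> Dict[str, float]:
    """Summarize how evenly one layer's tokens were spread across experts."""
    counts = Counter(assignments)
    total = len(assignments)
    fractions = [counts.get(e, 0) / total for e in range(num_experts)]
    entropy = sum(-f * math.log(f) for f in fractions if f > 0)
    return {
        "max_load_ratio": max(fractions) * num_experts,         # 1.0 = busiest expert at fair share
        "normalized_entropy": entropy / math.log(num_experts),  # 1.0 = even, 0.0 = one expert
        "idle_experts": float(sum(1 for f in fractions if f == 0)),
    }


even = utilization_summary([0, 1, 2, 3] * 4, num_experts=4)
skewed = utilization_summary([0] * 16, num_experts=4)
print(even)    # max_load_ratio 1.0, normalized_entropy ~1.0, no idle experts
print(skewed)  # max_load_ratio 4.0, normalized_entropy 0.0, 3 idle experts
```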
Capacity factor creates a hidden memory-throughput tradeoff. Higher capacity factor means fewer dropped tokens (better quality) but more memory allocated per expert buffer. With capacity_factor = 2.0 at inference and 128 experts, you're holding 2× "fair share" token buffers on every device simultaneously. For large batch sizes, this can be the binding memory constraint — not the model weights. Teams that size their inference memory budget around the parameter count (which stays fixed) and not the capacity buffers (which scale with batch size) hit OOM in production when batch sizes vary.
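A rough calculation shows why capacity buffers, not weights, can set the memory ceiling. The sketch sizes one Switch layer's dispatch buffer as num_experts × C × d_model values, assuming bfloat16 storage, C rounded down, and the illustrative d_model = 1024 and 128 experts; `dispatch_buffer_bytes` is hypothetical, and a real layer also holds a combine buffer of similar size.

```python
def dispatch_buffer_bytes(tokens: int, num_experts: int, capacity_factor: float,
                          d_model: int, bytes_per_value: int = 2) -> int:
    """Bytes for one Switch layer's dispatch buffer: num_experts * C slots of d_model values."""
    capacity = int(tokens / num_experts * capacity_factor)  # C, rounded down (an assumption)
    return num_experts * capacity * d_model * bytes_per_value


# Expert weights are fixed, but the buffer grows linearly with the tokens in flight.
for tokens in (4096, 32768):
    for cf in (1.25, 2.0):
        mib = dispatch_buffer_bytes(tokens, num_experts=128, capacity_factor=cf, d_model=1024) / 2**20
        print(f"{tokens} tokens, capacity_factor={cf}: {mib:.0f} MiB per layer")
```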
Dropping tokens isn't logged by default. When tokens exceed expert capacity and get dropped, there's no exception, no warning, no gradient spike. The model just processes them with the residual connection and moves on. High token drop rates (>15%) can silently hurt quality while all your other metrics look normal. Add explicit monitoring of token_drop_rate per expert per layer — most training frameworks don't surface this by default.
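A minimal sketch of that monitoring, assuming your MoE layer can report each token's chosen expert and the expert capacity: `DropRateMonitor` is a hypothetical class, not part of any training framework, and it assumes first-come-first-served overflow in token order. The 15% alert threshold is the rule of thumb above.

```python
from collections import defaultdict
from typing import Dict, Sequence, Tuple


class DropRateMonitor:
    """Accumulate routed and dropped token counts per (layer, expert)."""

    def __init__(self) -> None:
        self.routed: Dict[Tuple[int, int], int] = defaultdict(int)
        self.dropped: Dict[Tuple[int, int], int] = defaultdict(int)

    def record(self, layer: int, assignments: Sequence[int], capacity: int) -> None:
        seen: Dict[int, int] = defaultdict(int)
        for expert in assignments:
            self.routed[(layer, expert)] += 1
            seen[expert] += 1
            if seen[expert] > capacity:  # this token arrived after the buffer filled
                self.dropped[(layer, expert)] += 1

    def drop_rate(self, layer: int, expert: int) -> float:
        routed = self.routed.get((layer, expert), 0)
        return self.dropped.get((layer, expert), 0) / routed if routed else 0.0

    def layer_drop_rate(self, layer: int) -> float:
        routed = sum(v for (l, _), v in self.routed.items() if l == layer)
        dropped = sum(v for (l, _), v in self.dropped.items() if l == layer)
        return dropped / routed if routed else 0.0


monitor = DropRateMonitor()
monitor.record(layer=0, assignments=[0, 0, 1, 0, 0, 1, 0, 1], capacity=3)
rate = monitor.layer_drop_rate(0)
if rate > 0.15:
    print(f"layer 0 drop rate {rate:.0%} exceeds 15%")  # layer 0 drop rate 25% exceeds 15%
```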
Failure modes in practice
Expert collapse is the most common failure for teams running their own MoE training. It looks like this: training proceeds normally for thousands of steps, then a few experts accumulate most tokens, other experts receive near-zero tokens, the auxiliary loss starts fighting this, and the training loss gets stuck at a suboptimal plateau. The fix is to restart training with a larger α for the auxiliary loss, and to verify that the float32 routing fix is actually applied — it's easy to have it theoretically enabled but running in bfloat16 due to an autocast scope mistake.
Fine-tuning instability on small datasets is the next most common failure. You've pre-trained a Switch model, then fine-tune on 10K examples, and the model underperforms a much smaller dense model. This isn't necessarily a bug: sparse models have far more parameters than their per-token compute suggests, so they overfit small datasets easily, and the dropout inside the experts needs its own setting rather than inheriting the rest of the model's. Sweep the expert dropout rate separately from the dense-layer rate, and check whether the auxiliary loss α should be reduced: the load balancing pressure appropriate for pre-training can over-constrain expert selection when the fine-tuning distribution is narrow.
Silent quality degradation from dropped tokens in production inference. You set capacity_factor = 1.25 to save memory. Your average batch is 128 tokens and you have 64 experts. Fair share is 2 tokens per expert. But your production traffic has long prompts with unusual token distributions, and certain expert positions consistently overflow. Tokens are being dropped, the model handles it via residual, and outputs are slightly but consistently worse on those prompts. You won't see this unless you're monitoring token_drop_rate separately from loss.
When not to use Switch Transformers
When you're training on a single node or small cluster. MoE's efficiency depends on expert parallelism — having enough devices that each expert gets its own device. If you can't assign at least one expert per device, you're simulating expert parallelism on shared hardware, the all-to-all communication still occurs (now within-device, which is cheaper, but the buffers still bloat), and the memory overhead of holding unused expert buffers cancels out much of the parameter efficiency. The break-even point depends on your hardware but is roughly 4+ GPUs for practical expert counts.
When you need low-latency single-request inference. Batch-1 inference on MoE models is brutal. The all-to-all dispatch for a single token still allocates full expert capacity buffers (since the buffer size is fixed at serving startup), loads the relevant expert's weights from device memory, and routes. You're paying the routing overhead while getting minimal benefit from expert specialization. Dense models at the same quality level typically have lower p99 latency for single requests.
When your fine-tuning data is small and you can't tune expert dropout carefully. If you're planning to fine-tune on <50K examples and you don't have the compute budget to sweep dropout values, a smaller dense model fine-tuned carefully will beat a larger Switch model fine-tuned hastily. The parameter count advantage of MoE evaporates when the fine-tuning distribution doesn't have enough signal to drive useful expert specialization.
When you can't monitor expert utilization. If your training infrastructure doesn't expose per-expert utilization metrics, you're flying blind on the failure modes that actually matter for Switch models. Expert collapse and high token drop rates are silent; you need explicit instrumentation or you'll waste weeks of training on a silently degraded run.
What the paper actually gives you
Switch Transformers resolved the question "is sparse MoE actually viable at scale" with a clear yes — but a conditional yes. The conditions are: distributed training with expert parallelism across enough devices, careful routing in float32, auxiliary loss tuning, explicit monitoring of expert utilization and token drop rates, and fine-tuning workflows adapted for higher baseline dropout.
The design choices in this paper (k=1 routing, capacity factor, the specific form of the auxiliary loss, the float32 routing fix) are the baseline that Mixtral's routing, DeepSeek-V2's DeepSeekMoE layers, and Gemini's expert configuration start from. They all refined and extended Switch Transformers. Reading the paper gives you the vocabulary to understand what those refinements are doing: Mixtral adds top-2 routing back (trading communication for quality), DeepSeek-V2 uses fine-grained experts with much higher N plus always-active shared experts, Gemini uses different capacity management. Without the Switch baseline, those differences look like arbitrary choices. With it, they're legible tradeoffs.
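The difference between Switch and Mixtral routing is small in code. This sketch contrasts a Switch-style top-1 gate (the chosen expert's full-softmax probability) with a Mixtral-style top-2 gate (a softmax over only the two largest logits, as the Mixtral paper describes its gating). `route_top_k` and the logits are illustrative, and real implementations operate on batched tensors.

```python
import math
from typing import List, Sequence, Tuple


def softmax(xs: Sequence[float]) -> List[float]:
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    s = sum(exps)
    return [e / s for e in exps]


def route_top_k(logits: Sequence[float], k: int, renormalize: bool) -> List[Tuple[int, float]]:
    """Return [(expert, gate weight)] for one token.

    k=1, renormalize=False: Switch-style gate, the expert's full-softmax probability.
    k=2, renormalize=True: Mixtral-style gate, a softmax over just the top-2 logits.
    """
    ranked = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)[:k]
    if renormalize:
        gates = softmax([logits[i] for i in ranked])
    else:
        full = softmax(logits)
        gates = [full[i] for i in ranked]
    return list(zip(ranked, gates))


logits = [2.0, 1.5, -1.0, 0.0]
print(route_top_k(logits, k=1, renormalize=False))  # one expert, gate below 1
print(route_top_k(logits, k=2, renormalize=True))   # two experts, gates sum to 1
```

The top-2 variant doubles the expert compute and the tokens sent through the all-to-all, which is the communication cost the paragraph above refers to.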
For your specific situation: if you're evaluating a pre-trained MoE model for inference, the things to measure are per-request latency under realistic batch size distributions (not just throughput), actual token drop rates under your specific prompt distributions (not just average), and memory under realistic maximum batch sizes (not just parameter count). Those three metrics tell you whether MoE's paper efficiency translates to your deployment context.
The Mixtral OOM that started this? Expert capacity buffers allocated at capacity_factor=2.0 with a batch size we hadn't anticipated. Once we set capacity_factor=1.25 and added token drop monitoring, the memory footprint became predictable and the drop rate was under 2%. The quality difference between 1.25 and 2.0 capacity factors was not measurable in our evals.
Switch Transformers: Scaling to Trillion Parameter Models with Simple and Efficient Sparsity — Fedus, Zoph, Dean. JMLR 2022.