FlashAttention: what the IO-aware paper actually says
Reading Dao et al. (Stanford 2022) while debugging memory OOM errors on 8K-context inference.
The error was straightforward: CUDA out of memory. Tried to allocate 12.50 GiB. We were running inference on a 13B model with 8,192-token contexts on an A100. The model fit. The KV cache fit. Attention didn't.
Standard attention allocates an N×N matrix where N is sequence length. At 8K tokens, that's 8192×8192 = 67 million float32 values, or 256 MiB of attention scores for a single head. Every head in each of the 40 layers needs its own matrix. You hit the wall fast, and the wall doesn't move.
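The arithmetic is easy to check. A minimal sketch, where the function name is mine and the only inputs are the sequence length and the element width:

```python
def attention_scores_bytes(seq_len: int, bytes_per_value: int = 4) -> int:
    """Bytes needed to hold one full seq_len x seq_len score matrix."""
    return seq_len * seq_len * bytes_per_value

n = 8192
one_matrix = attention_scores_bytes(n)
print(f"{n * n:,} values")               # 67,108,864 values
print(f"{one_matrix / 2**20:.0f} MiB")   # 256 MiB

# Doubling the context quadruples the footprint.
print(attention_scores_bytes(2 * n) // one_matrix)  # 4
```

The last line is the part that hurts in practice: the cost is quadratic, so going from 8K to 16K context multiplies this term by four, not two.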
The standard response was to use approximate attention: locality-sensitive hashing (Reformer), random features (Performer), linear approximations. The problem is that approximate attention changes model behavior in ways that are hard to characterize, and you're no longer running the same computation as the model was trained on. Teams would implement approximate attention in their serving stack and then spend weeks figuring out why outputs degraded on specific prompt patterns.
FlashAttention — "FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness," Dao et al., Stanford 2022 — solves the memory problem without approximation. It computes exact attention, identical to standard attention, but uses O(N) memory instead of O(N²). The insight isn't a new mathematical approximation. It's a different way of thinking about what the bottleneck actually is.
The problem the paper is actually solving
To understand FlashAttention, you need to start with the GPU memory hierarchy.
A modern A100 GPU has two relevant memory types:
- HBM (High Bandwidth Memory): 80 GB, ~2 TB/s bandwidth
- SRAM (on-chip cache, shared memory): ~20 MB total, ~19 TB/s bandwidth
SRAM is roughly 10× faster than HBM for reads and writes, but 4000× smaller. The GPU executes most of its matrix multiplications (GEMM operations) with high utilization because GEMM can be structured so data lives in SRAM for reuse. Attention in the standard implementation is different — it explicitly materializes the full N×N score matrix to HBM, then reads it back to apply softmax, then writes the result back again before computing the output.
This isn't a computation problem. The floating-point operations required for attention are straightforward. It's a memory movement problem: the algorithm unnecessarily moves data through the slow HBM multiple times. The FLOPs are cheap. The I/O is not.
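A rough back-of-the-envelope version of that claim, as a sketch rather than a measurement. The assumptions are mine: a peak of 312 TFLOPs/s for A100 tensor cores, the ~2 TB/s HBM figure from above, 2-byte scores, and four passes over the N×N matrix (write scores, read scores, write softmax output, read softmax output). Real kernels reach neither peak.

```python
def attention_time_estimates(n, d, peak_flops=312e12, hbm_bytes_per_s=2e12,
                             bytes_per_value=2):
    """Idealized seconds spent on math vs. on moving the N x N matrix."""
    # Two matmuls (Q @ K^T and P @ V), each about 2 * n * n * d FLOPs.
    flops = 2 * (2 * n * n * d)
    # Standard attention: write S, read S, write P, read P.
    score_traffic_bytes = 4 * n * n * bytes_per_value
    return flops / peak_flops, score_traffic_bytes / hbm_bytes_per_s

compute_s, io_s = attention_time_estimates(n=8192, d=64)
print(f"compute: {compute_s * 1e6:.0f} us, score-matrix IO: {io_s * 1e6:.0f} us")
print(f"IO / compute: {io_s / compute_s:.1f}x")
```

Under these assumptions, moving the score matrix for one head takes several times longer than the matrix multiplications themselves, which is the sense in which the FLOPs are cheap and the I/O is not.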
The paper defines this formally using the concept of IO complexity: the number of HBM reads and writes, not the number of arithmetic operations. Standard attention requires Θ(Nd + N²) HBM accesses, where d is the head dimension. FlashAttention requires Θ(N²d²/M), where M is the size of SRAM. Because d² is typically many times smaller than M, FlashAttention does far fewer round-trips to slow memory.
The tiling approach, and why it was hard
The natural response to "process attention in blocks that fit in SRAM" runs into a mathematical problem: softmax.
The softmax operation normalizes each attention row so scores sum to 1:
softmax(x)_i = exp(x_i) / sum_j(exp(x_j))
To compute the denominator, you need the entire row. For a sequence of length N, that's N values — they have to be in memory simultaneously. If you split K and V into blocks and process them one at a time, you can't complete the softmax for any query until you've seen every key.
This is why naïve tiling doesn't work. You need all of K to normalize the softmax for each query, so you have to load K entirely from HBM. The standard algorithm is "natural" precisely because the math seems to require it.
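A small illustration of the failure, using made-up scores and hypothetical helper names. Normalizing each block on its own gives every block its own denominator, so the result is not a softmax over the row:

```python
import math

def softmax(xs):
    """Numerically stable softmax over a full row."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

def blockwise_softmax_naive(xs, block):
    """Wrong on purpose: softmax applied to each block independently."""
    out = []
    for start in range(0, len(xs), block):
        out.extend(softmax(xs[start:start + block]))
    return out

scores = [1.0, 2.0, 3.0, 4.0]
full = softmax(scores)
naive = blockwise_softmax_naive(scores, block=2)

print(sum(full))    # sums to 1 (up to rounding), as a softmax row should
print(sum(naive))   # sums to about 2: one unit of probability mass per block
```

The per-block weights are not even proportional to the correct ones, so there is no single rescaling at the end that repairs them. Something has to carry information across blocks.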
FlashAttention uses the online softmax algorithm (Milakov and Gimelshein 2018, adapted) to break this dependency. The insight: you can compute softmax incrementally, processing one block of keys at a time, if you maintain running statistics:
- A running maximum m of scores seen so far
- A running sum l of exponentiated and scaled scores
When you process a new block of keys, you update both statistics and rescale the partial output accordingly. The final output is mathematically identical to standard softmax, with no approximation, but you never need the full N×N matrix in memory. You read each block of K and V once, compute its contribution to the output for each query, update the running statistics, and move on.
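The update rule can be sketched in plain Python for a single query. This is an illustration of the running-statistics idea only, not the paper's kernel: there is no SRAM management, no batching, and the function names and random test data are mine.

```python
import math
import random

def naive_attention(q, keys, values):
    """Reference: materialize every score for one query, then softmax."""
    scale = 1.0 / math.sqrt(len(q))
    scores = [scale * sum(a * b for a, b in zip(q, k)) for k in keys]
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    dim = len(values[0])
    return [sum(w * v[j] for w, v in zip(weights, values)) / total
            for j in range(dim)]

def tiled_attention(q, keys, values, block=2):
    """Same result, but only `block` scores exist at any one time."""
    scale = 1.0 / math.sqrt(len(q))
    m = float("-inf")              # running maximum of scores seen so far
    l = 0.0                        # running sum of exp(score - m)
    acc = [0.0] * len(values[0])   # unnormalized partial output
    for start in range(0, len(keys), block):
        k_blk = keys[start:start + block]
        v_blk = values[start:start + block]
        scores = [scale * sum(a * b for a, b in zip(q, k)) for k in k_blk]
        m_new = max(m, max(scores))
        correction = math.exp(m - m_new)   # rescale old stats to the new maximum
        weights = [math.exp(s - m_new) for s in scores]
        l = l * correction + sum(weights)
        acc = [a * correction + sum(w * v[j] for w, v in zip(weights, v_blk))
               for j, a in enumerate(acc)]
        m = m_new
    return [a / l for a in acc]

rng = random.Random(0)
q = [rng.gauss(0, 1) for _ in range(4)]
keys = [[rng.gauss(0, 1) for _ in range(4)] for _ in range(7)]
values = [[rng.gauss(0, 1) for _ in range(3)] for _ in range(7)]

reference = naive_attention(q, keys, values)
tiled = tiled_attention(q, keys, values, block=3)
max_diff = max(abs(a - b) for a, b in zip(reference, tiled))
print(max_diff)   # tiny: the two paths differ only by floating-point rounding
```

The first block is handled by the same code as every other block: with m at negative infinity the correction factor is exp(-inf) = 0, which wipes the empty accumulator. The two results agree to rounding error rather than bit for bit, because the additions happen in a different order.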
The paper shows the IO complexity for this approach:
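HBM accesses = O(N² d² / M)
where d is head dimension and M is the SRAM available to the kernel, which is on the order of 100 KB per streaming multiprocessor rather than the ~20 MB aggregate. For typical head dimensions (d = 64 to 128), d² is many times smaller than M. On GPT-2 medium attention (N=1024, d=64), the paper measures roughly 9× fewer HBM reads/writes than standard attention.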
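To get a feel for how the two bounds from the paper scale, Θ(Nd + N²) for standard attention and Θ(N²d²/M) for FlashAttention, here is a sketch that plugs in numbers. Big-Theta notation drops constant factors, so the ratios it prints show a trend and are not predictions of measured traffic. The SRAM budget of about 100 KB holding 2-byte values is my assumption.

```python
def standard_hbm_accesses(n: int, d: int) -> float:
    """Element accesses up to a constant factor: N*d + N^2."""
    return n * d + n * n

def flash_hbm_accesses(n: int, d: int, sram_elems: int) -> float:
    """Element accesses up to a constant factor: N^2 * d^2 / M."""
    return n * n * d * d / sram_elems

SRAM_ELEMS = 100 * 1024 // 2   # assumed: ~100 KB of SRAM, 2 bytes per value

for n in (1024, 2048, 8192):
    ratio = standard_hbm_accesses(n, 64) / flash_hbm_accesses(n, 64, SRAM_ELEMS)
    print(n, round(ratio, 1))

# More SRAM means proportionally fewer HBM accesses.
print(flash_hbm_accesses(2048, 64, SRAM_ELEMS)
      / flash_hbm_accesses(2048, 64, 2 * SRAM_ELEMS))
```

Two things stand out. The advantage is governed by M/d², so a larger head dimension eats into it quadratically. And doubling the usable SRAM halves the bound, which is why the kernel has to be tuned per GPU generation.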
What the performance numbers actually show
The speedup depends heavily on sequence length because the memory bottleneck worsens quadratically with N.
For short sequences, standard attention is barely memory-bound — most of the N×N matrix fits comfortably in cache, and the kernel overhead of FlashAttention partially offsets the benefits. For BERT-large at 512 tokens, the paper reports 15% end-to-end training speedup — real but modest.
For longer sequences, the gap widens:
- GPT-2 at 1K tokens: ~3× speedup over standard attention
- Long-range arena at 1K–4K tokens: ~2.4× speedup on tasks where standard attention was the bottleneck
The more significant impact is on what becomes feasible, not just what becomes faster. The paper reports that FlashAttention enables training transformers on 16K-token sequences (Path-X benchmark) where standard attention would OOM. The model achieves 61.4% accuracy — the first transformer to exceed 50% on that task. Not because the model architecture changed, but because you could finally train it.
FlashAttention 2 (Dao 2023) extends this with algorithmic changes targeting GPU utilization: reduced non-matmul operations, better work partitioning across thread blocks, and improved warp-level scheduling. The result is ~2× faster than FA-1, reaching 225 TFLOPs/s on an A100 and 72% model FLOPs utilization — up from 25–40% in FA-1. The bottleneck shifts from memory bandwidth to compute, which is where you want it.
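The utilization figure is just achieved throughput divided by hardware peak. A quick sanity check, assuming the A100's peak dense FP16/BF16 tensor-core rate of 312 TFLOPs/s as the denominator (that number comes from NVIDIA's spec sheet, and the function name is mine):

```python
A100_PEAK_DENSE_TFLOPS = 312.0  # assumed denominator: FP16/BF16 tensor-core peak

def flops_utilization(achieved_tflops: float,
                      peak_tflops: float = A100_PEAK_DENSE_TFLOPS) -> float:
    """Fraction of peak throughput actually achieved."""
    return achieved_tflops / peak_tflops

print(f"{flops_utilization(225):.0%}")  # 72%
```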
Production tradeoffs no one mentions in the benchmark post
It doesn't help your optimizer state. FlashAttention reduces activation memory during the forward and backward passes. It doesn't touch parameter memory, gradient memory, or optimizer state. For training runs bottlenecked by Adam's 2× parameter memory in optimizer state, FlashAttention alone won't solve the OOM. Teams that have spent memory budget on 8-bit Adam see less marginal benefit from FA.
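To see how much of a training budget sits outside FlashAttention's reach, here is a rough accounting sketch. The byte widths are assumptions for a common mixed-precision setup (16-bit parameters and gradients, two FP32 Adam moments per parameter); it ignores FP32 master weights and activations, and the function name is hypothetical.

```python
def training_state_gib(n_params: float, param_bytes: int = 2,
                       grad_bytes: int = 2, adam_state_bytes: int = 4) -> dict:
    """Memory that exists regardless of which attention kernel runs."""
    gib = 2**30
    return {
        "parameters": n_params * param_bytes / gib,
        "gradients": n_params * grad_bytes / gib,
        # Adam keeps two moments (mean and variance) for every parameter.
        "adam_moments": 2 * n_params * adam_state_bytes / gib,
    }

budget = training_state_gib(13e9)
for name, gib in budget.items():
    print(f"{name:>12}: {gib:6.1f} GiB")
print(f"{'total':>12}: {sum(budget.values()):6.1f} GiB")
```

For a 13B-parameter model under these assumptions, the Adam moments alone are larger than parameters and gradients combined, and none of it shrinks when you swap the attention kernel.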
The kernel is hardware-specific. FlashAttention's implementation is a custom CUDA kernel tightly tuned for SRAM sizes and bandwidth ratios on specific GPU generations (A100, H100, A6000). On consumer-grade cards (3090, 4090), the speedup is smaller because the SRAM/HBM ratio and bandwidth figures are different. On non-NVIDIA hardware — Apple Silicon, AMD, TPUs — you're relying on third-party ports with varying maturity. Before assuming FA gives you the paper's speedups, benchmark on your actual hardware.
Attention visualization breaks. Standard attention explicitly materializes the N×N score matrix. Tools that read attention weights — visualization dashboards, mechanistic interpretability analysis, custom masking operations that depend on attention patterns — need that matrix. FlashAttention never writes it to HBM; the intermediate scores live temporarily in SRAM and are discarded. If you have any tooling that reads attention weights, switching to FlashAttention silently breaks it. You'll get incorrect attention heatmaps, or zeros, depending on how your visualization hooks into the model.
Custom attention patterns require custom kernels. Sliding window attention, prefix-suffix asymmetric masking, RoPE with non-standard positions, attention sink patterns: any modification to the standard softmax(QKᵀ)V computation requires re-implementing the tiling logic. The stock FlashAttention kernel handles causal masking and standard full attention. Everything else needs a new kernel. Teams that start with FA and then need to add a custom attention modification for their architecture discover this when they find that the FA source is about 2000 lines of CUDA.
The memory savings don't always compound. If you're running multi-head attention with 32 heads at 128 head dimension on a 4K sequence, FA's memory savings are substantial per layer. But if your bottleneck is the feed-forward network (which is often larger than attention in modern architectures), or the KV cache at serving time, FA addresses only part of the memory budget.
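The KV cache is worth sizing separately, because it grows linearly with context and FlashAttention leaves it alone. A sketch under stated assumptions: plain multi-head attention (no grouped-query or multi-query sharing), 16-bit values, the 40 layers mentioned at the top of this post, and a hidden size of 5120, which is my guess at a typical 13B configuration.

```python
def kv_cache_bytes(seq_len: int, n_layers: int, hidden_size: int,
                   batch: int = 1, bytes_per_value: int = 2) -> int:
    """One K and one V vector of width hidden_size per token, per layer."""
    return 2 * batch * n_layers * seq_len * hidden_size * bytes_per_value

cache = kv_cache_bytes(seq_len=8192, n_layers=40, hidden_size=5120)
print(f"{cache / 2**30:.2f} GiB per sequence")  # 6.25 GiB per sequence
```

That is per sequence, so batching multiplies it directly. At serving time this term, not the attention scores, is often what sets the maximum batch size.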
Failure modes in practice
The most common production failure mode I've seen: silent precision degradation on fine-tuning.
FA-1 only supported FP16 and BF16 inputs and outputs; there was no FP32 path. FP16 has limited dynamic range, so large attention logits can overflow. The standard numerical fix is to subtract the row maximum before exponentiating, which FA-1 does, but Q, K, V, and the output are still stored in 16 bits, and the rounding errors accumulate differently than in FP32 standard attention. For inference from a pretrained model, this rarely matters. For fine-tuning, you're computing gradients through these operations, and the rounding behavior influences gradient signal. FA-2 improves this with better accumulation, but teams that switched to FA-1 mid-fine-tuning sometimes saw training diverge in ways that were hard to attribute.
The second failure mode is the fallback problem. Most frameworks (Hugging Face Transformers, vLLM, TGI) have FA support behind a flag. When FA isn't available — unsupported hardware, unsupported attention variant, wrong dtype — they fall back to standard attention. The code runs. The benchmark looks fine. Memory usage doubles, throughput drops by 30%, and the only signal is that your GPU utilization metrics changed. If you're not explicitly verifying which attention implementation is executing, you can spend a lot of time debugging "performance regression" that's actually just FA not being used.
When not to use FlashAttention
Short sequences (<512 tokens). The memory overhead of standard attention is small, the kernel launch overhead of FA is proportionally larger, and the speedup doesn't materialize. For chatbot use cases with short prompts and short completions, FA's benefit is marginal. Run the benchmark.
When you need to read attention weights. Any interpretability work, attention visualization, custom masking logic, or debugging workflow that depends on the N×N matrix is incompatible with FA's design. You can run standard attention for debug passes and FA for production inference, but keeping that separation consistent across a codebase is operationally annoying.
When your architecture deviates from standard scaled dot-product attention. Cross-attention between mismatched sequence lengths, attention with learned sparsity masks, multi-query attention with unusual key-sharing patterns — these may or may not have FA support depending on the version and implementation. Check before assuming.
When you can't verify the kernel is running. If you don't have a reliable way to instrument which attention kernel executes at runtime, the silent fallback problem will bite you. Either add an assertion that FA is active, or treat the memory and throughput as uncertain. The framework flag saying use_flash_attention=True doesn't guarantee it runs on all inputs.
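One way to make that check concrete, assuming your stack routes attention through PyTorch's scaled_dot_product_attention and you are on PyTorch 2.3 or later: restrict SDPA to the flash backend and see whether a representative call succeeds. The helper name and the probe shapes are hypothetical. This probes PyTorch's built-in kernel only; it says nothing about the flash-attn package or about the path your serving framework actually takes.

```python
def flash_sdpa_executes(seq_len=1024, heads=8, head_dim=64):
    """True/False if the flash backend can/cannot run this shape; None if untestable."""
    try:
        import torch
        import torch.nn.functional as F
        from torch.nn.attention import SDPBackend, sdpa_kernel
    except ImportError:
        return None  # PyTorch missing, or older than the torch.nn.attention API
    if not torch.cuda.is_available():
        return None  # no GPU to probe
    # Probe with the dtype and shape you really serve; flash needs fp16 or bf16.
    q = torch.randn(1, heads, seq_len, head_dim, device="cuda", dtype=torch.float16)
    try:
        # With every other backend disabled, SDPA raises instead of falling back.
        with sdpa_kernel([SDPBackend.FLASH_ATTENTION]):
            F.scaled_dot_product_attention(q, q, q, is_causal=True)
    except RuntimeError:
        return False
    return True

status = flash_sdpa_executes()
print({True: "flash kernel ran", False: "flash kernel unavailable",
       None: "could not check"}[status])
```

The point of disabling the other backends is to turn a silent fallback into a loud failure. Run the probe with the same dtype, head dimension, and masking you use in production, since any of those can disqualify the flash path.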
What the paper actually gives you
FlashAttention is a demonstration that a lot of the apparent limitations in transformer scaling were implementation artifacts, not architectural constraints. The N×N memory requirement for attention was treated as fundamental for years. It isn't — it follows from a specific sequence of operations that were never required by the math.
The IO-aware framing is the generalization worth carrying. GPU operations that seem memory-bound aren't always fundamentally so; sometimes the kernel is just doing unnecessary round-trips through slow memory. Layernorm, dropout, and the residual stream all have fused kernel implementations that follow the same logic: keep intermediate values in SRAM, avoid HBM writes for temporaries. FA demonstrated that this approach could recover significant performance even in operations that look inherently sequential.
For your specific situation: if you're running inference or training at sequence lengths above 2K tokens, FA-2 is the default choice and the baseline against which you should measure everything else. If you're at short contexts, measure first. And if you're instrumenting attention patterns for any reason, understand that you're no longer running the FA kernel on those paths — verify that the fallback behavior is what you expect.
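The 8K-context OOM that started this post? We switched to FA-2, added an assertion that checked torch.backends.cuda.flash_sdp_enabled() at startup (which reports that the flash backend is allowed, not that it was selected for a given call), and the memory budget dropped by roughly 40% at that sequence length. Inference now fits without context truncation. The model computes the same function as standard attention, not an approximation of it. The outputs are mathematically identical by construction and differ only by floating-point rounding from the changed order of operations.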
FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness. Dao, Fu, Ermon, Rudra, Ré. NeurIPS 2022.
FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning. Dao. ICLR 2024.