Titans: what the test-time memorization paper actually says
Reading Behrouz, Zhong, and Mirrokni (2024) while watching a 200K-context retrieval task silently degrade at 180K tokens.
Paper: "Titans: Learning to Memorize at Test Time" — Behrouz, Zhong, Mirrokni (Google Research, 2024).
The failure mode was subtle. We were processing long legal documents — contracts with exhibits and amendments, often 150–200 pages of dense text — through a transformer-based pipeline. Retrieval accuracy was fine in testing. In production, on the longest documents, it collapsed. The model would correctly identify provisions in the first 50 pages and completely miss identical provisions on page 180. The context fit. The model had attended to the right tokens at training time. But somewhere in the forward pass through a 200K sequence, the signal from early tokens had been diluted to noise.
This is the problem Titans is solving, and it's worth being precise about what the problem actually is before getting into the solution.
Two constraints that bound every production deployment
Current neural sequence models live under one of two regimes.
Transformers scale their memory with the sequence: every token in context is accessible via attention, at O(N²) cost. This is expensive but exact — any token can contribute to any output, regardless of where it appears. The practical ceiling for standard attention is around 128K–256K tokens on modern hardware before you hit either memory limits or latency limits that make the system unusable. Techniques like Flash Attention and Ring Attention push this boundary, but don't eliminate it.
Recurrent models (SSMs like Mamba, linear attention variants, LSTMs) scale linearly and maintain a fixed-size hidden state. At inference, each new token is processed in constant time and memory, regardless of sequence history. The catch: all historical information must be compressed into that fixed state. As the sequence grows, older information gets overwritten or diluted. The state size is set at training time and doesn't grow. You can't expand it at runtime to accommodate an unexpectedly information-dense input.
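A toy way to see the fixed-state constraint is to store key/value pairs in a fixed D×D matrix and read them back, as linear-attention-style recurrences do. The dimension, the one-hot keys, and the values below are arbitrary illustrative choices. Real SSMs use learned, input-dependent updates, so this shows only the capacity argument, not any particular model's behavior.

```python
# Toy fixed-state memory (an illustration, not any specific SSM): key/value
# pairs are summed into a D x D matrix as outer products, M += v k^T.
D = 4  # state size, fixed before any input is seen


def one_hot(i):
    return [1.0 if j == i else 0.0 for j in range(D)]


def write(M, k, v):
    # Add the outer product v k^T into the state matrix.
    return [[M[r][c] + v[r] * k[c] for c in range(D)] for r in range(D)]


def read(M, k):
    # Retrieval is the matrix-vector product M k.
    return [sum(M[r][c] * k[c] for c in range(D)) for r in range(D)]


def retrieval_error(num_pairs):
    # Hypothetical keys: once num_pairs > D, keys must share directions.
    pairs = [(one_hot(i % D), [float(i + 1)] * D) for i in range(num_pairs)]
    M = [[0.0] * D for _ in range(D)]
    for k, v in pairs:
        M = write(M, k, v)
    # Total absolute error when reading every stored value back.
    return sum(abs(got - want)
               for k, v in pairs
               for got, want in zip(read(M, k), v))


print(retrieval_error(4))  # 0.0: as many pairs as dimensions, exact recall
print(retrieval_error(8))  # 144.0: pairs that share a key blend together
```

The state does not grow when the input does. Past capacity, new writes land on top of old ones and reads return a blend.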
The Titans paper — "Titans: Learning to Memorize at Test Time," Behrouz, Zhong, and Mirrokni, Google, arXiv 2501.00663 — proposes a third option: a neural long-term memory module that is explicitly trained to compress historical context into its parameters, and that continues learning — updating those parameters — during inference.
This is an unusual proposal. Most neural network inference is read-only with respect to model weights. Titans makes the memory module's weights writable at inference time, and uses gradient descent to write them.
The neural memory module
The memory module is a small MLP. Its "memory" isn't stored in activations or a KV cache — it's stored in its weights. To retrieve information, you do a forward pass. To write information, you do a gradient update.
The update rule at time step t is:
M_t = (1 - α_t) M_{t-1} + S_t
S_t = η_t S_{t-1} - θ_t ∇ℓ(M_{t-1}; x_t)
Breaking this down:
- M_t is the memory module's weights at step t. This is what's being updated.
- α_t ∈ [0,1] is a forget gate. When α_t → 0, old memory is preserved. When α_t → 1, memory is cleared. This is learned and input-dependent.
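- S_t is a momentum term: gradient updates are smoothed over time rather than applied raw, similar to momentum in SGD. η_t controls how quickly past surprise decays, and θ_t scales how much of the current gradient is folded in (in effect, an input-dependent learning rate).
- ∇ℓ(M_{t-1}; x_t) is the gradient of the memory module's prediction loss on input x_t. In the paper this is an associative-memory loss, ℓ(M_{t-1}; x_t) = ‖M_{t-1}(k_t) − v_t‖², where the key k_t and value v_t are linear projections of x_t. This is the surprise signal: how much does the current input deviate from what the memory module predicted?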
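A minimal sketch of the two update equations makes the roles of α, η, and θ concrete. It assumes a linear memory (one d×d matrix) instead of the paper's MLP. It also takes keys and values as given rather than projecting them from x_t, and it uses fixed scalars where the paper uses learned, input-dependent gates. With a linear memory, the gradient of ‖M k − v‖² has the closed form 2(M k − v)kᵀ, so no autodiff is needed.

```python
# Sketch of the Titans memory update rule under simplifying assumptions:
# the memory is a single linear map (a d x d matrix) rather than an MLP,
# keys/values are given directly, and alpha/eta/theta are fixed scalars
# instead of learned, input-dependent gates.


def matvec(M, x):
    return [sum(m * xi for m, xi in zip(row, x)) for row in M]


def loss_and_grad(M, k, v):
    # l(M; x) = ||M k - v||^2, with gradient dl/dM = 2 (M k - v) k^T.
    err = [p - t for p, t in zip(matvec(M, k), v)]
    loss = sum(e * e for e in err)
    grad = [[2.0 * e * kj for kj in k] for e in err]
    return loss, grad


def memory_step(M, S, k, v, alpha, eta, theta):
    # S_t = eta * S_{t-1} - theta * grad l(M_{t-1}; x_t)
    # M_t = (1 - alpha) * M_{t-1} + S_t
    loss, grad = loss_and_grad(M, k, v)
    S = [[eta * s - theta * g for s, g in zip(s_row, g_row)]
         for s_row, g_row in zip(S, grad)]
    M = [[(1.0 - alpha) * m + s for m, s in zip(m_row, s_row)]
         for m_row, s_row in zip(M, S)]
    return M, S, loss


# Usage: repeatedly write one association and watch the pre-update loss fall.
d = 2
M = [[0.0] * d for _ in range(d)]
S = [[0.0] * d for _ in range(d)]
k, v = [1.0, 0.0], [0.0, 1.0]
losses = []
for _ in range(5):
    M, S, loss = memory_step(M, S, k, v, alpha=0.0, eta=0.0, theta=0.25)
    losses.append(loss)
print(losses)  # [1.0, 0.25, 0.0625, 0.015625, 0.00390625]
```

With η = 0 and α = 0, this reduces to plain gradient descent on one association. Raising η lets earlier surprise keep pushing the weights for a few more steps.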
The "surprise" framing is important. The gradient is large when the memory module's current state makes a bad prediction on x_t — meaning x_t contains information the memory doesn't already have. The update is large when information is new, small when it's redundant. This is a form of learned importance weighting: inputs that are surprising get written to memory more strongly.
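The "large when new, small when redundant" behavior can be checked directly if we measure surprise as the size of the gradient. The sketch again assumes a linear memory, where the gradient 2(M k − v)kᵀ has Frobenius norm 2‖M k − v‖‖k‖. The stored association and the probe pairs are hypothetical.

```python
import math

# Surprise as gradient magnitude for a linear memory (simplifying assumption;
# the paper's memory is an MLP). For l = ||M k - v||^2 the gradient is
# 2 (M k - v) k^T, whose Frobenius norm is 2 * ||M k - v|| * ||k||.


def matvec(M, x):
    return [sum(m * xi for m, xi in zip(row, x)) for row in M]


def surprise(M, k, v):
    err = [p - t for p, t in zip(matvec(M, k), v)]
    return 2.0 * math.sqrt(sum(e * e for e in err)) * math.sqrt(sum(x * x for x in k))


# Hypothetical memory that already stores key_a -> val_a (M = val_a key_a^T).
key_a, val_a = [1.0, 0.0], [0.0, 1.0]
key_b, val_b = [0.0, 1.0], [1.0, 0.0]
M = [[va * ka for ka in key_a] for va in val_a]

print(surprise(M, key_a, val_a))  # 0.0: redundant input, zero gradient
print(surprise(M, key_b, val_b))  # 2.0: new association, large gradient
```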
The forget gate α_t is equally important. Without forgetting, nothing written to memory ever decays, so a fixed-size memory keeps piling up history until old and new associations interfere. With an adaptive forget gate, the model can learn to clear memory when context shifts (a new document, a topic change) and preserve it when information should persist. The paper relates this weight-decay term to the gating mechanisms in modern recurrent models such as Mamba and LRU.
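To see why an input-dependent α_t matters, the sketch below measures how much previously stored content survives a run of tokens. The sigmoid gate and its weights are hypothetical, since the paper only requires α_t ∈ [0, 1] to be learned and input-dependent. The sketch also ignores new writes, so it isolates the decay term.

```python
import math

# Forget-gate sketch. The gating function below is a hypothetical
# parameterization (sigmoid of a linear function of the input).


def forget_gate(x, w, b):
    z = sum(wi * xi for wi, xi in zip(w, x)) + b
    return 1.0 / (1.0 + math.exp(-z))


def surviving_fraction(alphas):
    # Ignoring new writes (S_t = 0), M_t = (1 - alpha_t) M_{t-1}, so anything
    # stored before these steps is scaled by the product of (1 - alpha_t).
    frac = 1.0
    for a in alphas:
        frac *= 1.0 - a
    return frac


w, b = [0.0, 20.0], -10.0          # feature 1 flags a hypothetical "document boundary"
ordinary, boundary = [1.0, 0.0], [0.0, 1.0]

same_doc = [forget_gate(ordinary, w, b)] * 100
with_shift = same_doc[:50] + [forget_gate(boundary, w, b)] + same_doc[50:]
print(round(surviving_fraction(same_doc), 4))   # 0.9955: memory persists
print(surviving_fraction(with_shift) < 1e-3)    # True: the boundary clears it
```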
Three ways to integrate memory into an architecture
The paper defines three architectural variants — MAC, MAG, and MAL — that differ in where and how the memory module integrates with attention.
MAC: Memory as Context
In MAC, the sequence is processed in fixed-size segments. For each segment, the memory module is queried first, and the retrieved information is concatenated with the segment before attention:
h_t = M*_{t-1}(q_t) // retrieve from memory using the segment's queries
y_t = Attn([h_t, x_t]) // attend over retrieved history + current segment
M_t = update(M_{t-1}, y_t) // write the attention output into memory
o_t = y_t ⊗ M*_t(y_t) // combine with a read from the updated memory
This is the closest analog to how people think about RAG: query memory first, augment with result, then process. The attention window handles recent tokens precisely; the memory module provides compressed access to history. The two sources don't need to agree — attention on recent context can override stale memory retrievals.
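A toy version of one MAC segment shows the data flow: retrieve, concatenate, attend, then write. It assumes identity Q/K/V projections, no causal mask inside the segment, and no persistent-memory tokens. A hypothetical `MeanMemory` stands in for the neural memory. The paper also gates y_t with a read from the updated memory, and that step is omitted here.

```python
import math


def attend(q, keys, values):
    # Single-head scaled dot-product attention for one query vector.
    scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(len(q)) for k in keys]
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    z = sum(weights)
    return [sum(w * v[j] for w, v in zip(weights, values)) / z
            for j in range(len(values[0]))]


class MeanMemory:
    """Hypothetical stand-in memory: reads return the mean of everything written."""

    def __init__(self, d):
        self.total, self.count, self.d = [0.0] * d, 0, d

    def read(self, q):
        if self.count == 0:
            return [0.0] * self.d
        return [t / self.count for t in self.total]

    def write(self, y):
        self.total = [t + yi for t, yi in zip(self.total, y)]
        self.count += 1


def mac_segment(memory, segment):
    retrieved = [memory.read(x) for x in segment]             # h_t = M*_{t-1}(q_t)
    context = retrieved + segment                              # [h_t, x_t]
    outputs = [attend(x, context, context) for x in segment]   # y_t = Attn(...)
    for y in outputs:
        memory.write(y)                                        # write y_t into memory
    return outputs


mem = MeanMemory(d=2)
seg1 = mac_segment(mem, [[1.0, 0.0], [0.0, 1.0]])
seg2 = mac_segment(mem, [[1.0, 1.0], [0.0, 0.0]])
print(mem.count)  # 4: two segments of two tokens each
```

The second segment's attention already sees a summary of the first segment through `retrieved`, even though the first segment's tokens are no longer in the attention context.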
MAG: Memory as Gating
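In MAG, memory and attention run in parallel branches. The memory module processes the sequence independently, and the two outputs are merged via a non-linear gate. Following the paper's notation, x̃ is the input sequence with a small set of learnable persistent-memory tokens prepended:
y_attn = SlideWindowAttn(x̃)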
y_mem = M(x̃)
o = y_attn ⊗ y_mem // element-wise gated merge
The gate is itself learned, allowing the model to weight attention heavily when recency matters and memory heavily when it doesn't. This is architecturally similar to the Mamba block, where the SSM output is multiplied element-wise by a SiLU-gated projection of the block's input.
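A sketch of a learned element-wise gate between the two branches follows. The paper's ⊗ normalizes both branches with learnable weights and then applies a non-linearity. The version below swaps in a simpler convex mix, which is an assumption chosen to show the "lean on attention vs. lean on memory" behavior. The weights and bias are hypothetical.

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def gated_merge(y_attn, y_mem, w_attn, w_mem, bias):
    # g = sigmoid(w_attn * y_attn + w_mem * y_mem + bias), per element;
    # g -> 1 leans on attention, g -> 0 leans on memory.
    out = []
    for a, m, wa, wm, b in zip(y_attn, y_mem, w_attn, w_mem, bias):
        g = sigmoid(wa * a + wm * m + b)
        out.append(g * a + (1.0 - g) * m)
    return out


o = gated_merge(y_attn=[1.0, 1.0], y_mem=[-1.0, -1.0],
                w_attn=[0.0, 0.0], w_mem=[0.0, 0.0], bias=[10.0, -10.0])
print([round(x, 3) for x in o])  # [1.0, -1.0]
```

Because the gate inputs include the branch outputs themselves, a trained gate can make the attention-versus-memory choice per token and per feature rather than once per layer.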
MAL: Memory as Layer
MAL is the simplest: memory and attention are stacked sequentially. Memory processes first, attention processes the memory output:
y = M(x̃)
o = SlideWindowAttn(y)
This is a straightforward layered composition. The memory module learns to present its compressed view of history; attention then applies precise in-context reasoning on top of that representation.
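The MAL composition is short enough to write out. The sketch assumes a single attention head with identity Q/K/V projections and a hypothetical `double` function in place of the neural memory. The point is structural: the attention layer never sees `xs` directly, only the memory's output.

```python
import math


def sliding_window_attn(xs, window):
    # Position t attends to positions max(0, t - window + 1) .. t only.
    out = []
    for t, q in enumerate(xs):
        ctx = xs[max(0, t - window + 1): t + 1]
        scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(len(q)) for k in ctx]
        m = max(scores)
        weights = [math.exp(s - m) for s in scores]
        z = sum(weights)
        out.append([sum(w * v[j] for w, v in zip(weights, ctx)) / z
                    for j in range(len(q))])
    return out


def mal_layer(memory, xs, window):
    # y = M(x), then o = SlideWindowAttn(y): attention only sees memory output.
    return sliding_window_attn([memory(x) for x in xs], window)


def double(x):
    return [2.0 * xi for xi in x]  # hypothetical stand-in memory


xs = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
print(mal_layer(double, xs, window=1))  # [[2.0, 0.0], [0.0, 2.0], [2.0, 2.0]]
```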
Among the three, MAC and MAG consistently outperform MAL in the paper's experiments. The authors' interpretation is that sequential stacking (MAL) creates an information bottleneck: attention can only access what the memory layer chooses to surface. In MAC and MAG, attention still operates on the raw current input, and memory contributes alongside it (as extra context in MAC, as a gated parallel branch in MAG) rather than replacing it.
What the benchmark numbers actually say
The language modeling results compare 760M-parameter models trained on 30 billion tokens. Perplexity on WikiText:
| Model | Perplexity (lower is better) |
|---|---|
| Transformer++ | 25.21 |
| Mamba-2 | 22.94 |
| Titans (MAC) | 19.93 |
| Titans (MAG) | 18.61 |
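This is a substantial gap, especially between Mamba-2 and Titans MAG. Because perplexity is the exponential of the average per-token cross-entropy, gaps are best read as ratios: 22.94 / 18.61 ≈ 1.23, which is roughly 0.21 nats (about 0.3 bits) less loss per token for Titans MAG. At this scale, that is a meaningful improvement in how well the model captures the token distribution.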
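The arithmetic behind reading perplexity as a ratio is shown below, using the Mamba-2 and Titans (MAG) numbers from the table. It assumes both perplexities are per-token and were computed on the same tokenization and evaluation set, which is what makes the comparison meaningful.

```python
import math


def perplexity_gap(ppl_worse, ppl_better):
    # Perplexity = exp(mean per-token cross-entropy in nats), so the log of a
    # perplexity ratio is the per-token loss difference.
    ratio = ppl_worse / ppl_better
    nats = math.log(ratio)
    bits = nats / math.log(2)
    return ratio, nats, bits


ratio, nats, bits = perplexity_gap(22.94, 18.61)  # Mamba-2 vs. Titans (MAG)
print(round(ratio, 3), round(nats, 3), round(bits, 3))  # 1.233 0.209 0.302
```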
On downstream tasks (averaged accuracy across PIQA, HellaSwag, WinoGrande, ARC, SIQA, and BoolQ), Titans MAC hits 52.51% and MAG hits 52.50%, vs. 48.69% for Transformer++ and approximately 51% for Mamba-2. These are zero-shot evals at 760M scale, so the absolute numbers are modest, but the improvement over both the transformer baseline and the SSM baseline is consistent.
The most striking result is on BABILong, a benchmark designed to test long-document reasoning by embedding facts at known positions in very long documents ("needle in a haystack" at scale). After fine-tuning, Titans beats GPT-4 and Llama3.1-70B on this benchmark. The comparison isn't apples-to-apples (model sizes and training setups differ, and Titans was fine-tuned for the task), but it supports the core claim: the neural memory module enables effective retrieval from contexts that break standard attention and swamp SSM hidden states.
The paper also reports scaling to 2M+ context windows, with accuracy holding where Transformer++ and Mamba degrade significantly. They don't show this with production-scale models (the experiments are at smaller scale), but the principle is demonstrated.
What actually happens at inference time
The part of Titans that gets underplayed in coverage is what "test-time learning" actually costs.
During inference, for each input token, the system:
1. Queries the memory module with a forward pass
2. Computes the memory update gradient via backpropagation through the memory module
3. Updates the memory module's weights with the gradient step
4. Proceeds with the main model forward pass
Steps 2 and 3 are gradient descent, not just a forward pass. This means inference requires running backprop through the memory module on every token. The memory MLP is kept small to make this tractable, but it's a qualitatively different inference profile from standard models: you're doing mini-training during inference.
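A toy loop makes the per-token cost visible: every token triggers a read, a gradient, and a weight write before the main model runs. The sketch assumes a linear d×d memory, so step 2 is a closed-form gradient rather than backprop through an MLP. The key doubles as the query, and `main_model` is a hypothetical callable standing in for the rest of the network.

```python
def run_inference(tokens, d, main_model, alpha=0.0, eta=0.9, theta=0.1):
    M = [[0.0] * d for _ in range(d)]
    S = [[0.0] * d for _ in range(d)]
    outputs = []
    for k, v in tokens:
        # 1. Query memory: forward pass only, weights unchanged.
        h = [sum(m * kj for m, kj in zip(row, k)) for row in M]
        # 2. Gradient of ||M k - v||^2 w.r.t. M, i.e. 2 (M k - v) k^T.
        grad = [[2.0 * (hi - vi) * kj for kj in k] for hi, vi in zip(h, v)]
        # 3. Momentum + forget-gate weight update (this is the "write").
        S = [[eta * s - theta * g for s, g in zip(s_row, g_row)]
             for s_row, g_row in zip(S, grad)]
        M = [[(1.0 - alpha) * m + s for m, s in zip(m_row, s_row)]
             for m_row, s_row in zip(M, S)]
        # 4. Main model forward pass, given the token and the retrieved memory.
        outputs.append(main_model(k, h))
    return outputs, M


def echo(token, retrieved):
    return retrieved  # hypothetical main model: just returns the memory read


tokens = [([1.0, 0.0], [0.0, 1.0])] * 3
outs, _ = run_inference(tokens, d=2, main_model=echo)
print([round(o[1], 2) for o in outs])  # [0.0, 0.2, 0.54]
```

The same prompt token produces a different memory read each time it appears, because the weights changed in between. That is the behavior the deployment notes below have to accommodate.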
The paper's efficiency work targets training rather than inference. The sequence is split into chunks, and the gradients inside a chunk are all taken with respect to the memory state at the start of that chunk, so they can be computed together with matrix multiplications. The momentum recurrence is then evaluated with a parallel associative scan. None of this makes the write free at inference time: per token, Titans does more work than a forward-only model with the same parameter count. The paper doesn't give exact inference latency numbers, which is a gap worth noting.
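The momentum term does not force a strictly sequential loop once the per-step gradient terms are known, and a sketch shows why. In the paper's chunked scheme those terms are known up front because they are all taken at the chunk's starting memory; the sketch simply takes them as given. Scalars stand in for the matrix-valued S_t. Hillis-Steele is one standard scan algorithm, not necessarily the one the paper uses.

```python
# S_t = eta_t * S_{t-1} + u_t (with u_t = -theta_t * grad_t) is a first-order
# linear recurrence, and composing affine maps s -> a*s + b is associative.


def combine(left, right):
    # Apply `left` (s -> a1*s + b1), then `right` (s -> a2*s + b2).
    a1, b1 = left
    a2, b2 = right
    return (a1 * a2, a2 * b1 + b2)


def sequential(etas, us, s0=0.0):
    out, s = [], s0
    for eta, u in zip(etas, us):
        s = eta * s + u
        out.append(s)
    return out


def parallel_scan(etas, us, s0=0.0):
    # Hillis-Steele inclusive scan: log2(n) rounds; all updates within a round
    # are independent, so each round could run in parallel on an accelerator.
    xs = list(zip(etas, us))
    offset = 1
    while offset < len(xs):
        xs = [combine(xs[i - offset], xs[i]) if i >= offset else xs[i]
              for i in range(len(xs))]
        offset *= 2
    # Each prefix is an affine map; apply it to the initial state s0.
    return [a * s0 + b for a, b in xs]


etas = [0.9, 0.5, 0.8, 0.7, 0.95]
us = [1.0, -2.0, 0.5, 3.0, -1.0]
print(all(abs(p - q) < 1e-12
          for p, q in zip(parallel_scan(etas, us, 0.3), sequential(etas, us, 0.3))))  # True
```

The forget-gate update M_t = (1 − α_t)M_{t-1} + S_t has the same affine form, so it can be scanned the same way.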
For deployment, this means:
- Inference is more memory-write-intensive — GPUs need to sustain weight update bandwidth alongside activation bandwidth
- The memory state is mutable — serving systems that share model weights across requests need to manage per-request memory state separately, similar to how KV caches are per-request
- The memory module's state needs to be saved and restored for multi-turn or stateful sessions, adding to checkpoint/state management complexity
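A minimal sketch of keeping per-session memory state apart from shared weights is shown below, much as a serving system keeps per-request KV caches. The class name and the JSON serialization are hypothetical choices for illustration, not part of any existing serving framework.

```python
import json


class SessionMemoryStore:
    """Hypothetical per-session store for the mutable memory state (M and S)."""

    def __init__(self, initial_state):
        self._initial = json.dumps(initial_state)   # starting state for new sessions
        self._saved = {}                            # session_id -> serialized state

    def load(self, session_id):
        # Always returns a fresh copy, so sessions never share mutable state.
        return json.loads(self._saved.get(session_id, self._initial))

    def save(self, session_id, state):
        self._saved[session_id] = json.dumps(state)


store = SessionMemoryStore({"M": [[0.0, 0.0], [0.0, 0.0]], "S": [[0.0, 0.0], [0.0, 0.0]]})
state = store.load("session-a")
state["M"][1][0] = 0.5          # pretend request A wrote something into memory
store.save("session-a", state)
print(store.load("session-a")["M"][1][0], store.load("session-b")["M"][1][0])  # 0.5 0.0
```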
None of these are blocking issues, but they require awareness.
When NOT to use Titans
Short or medium context. Below ~32K tokens, standard attention with Flash Attention is faster, well-understood, and has extensive optimization work behind it. The memory module overhead isn't justified when your context comfortably fits in an attention window.
When you're quantizing aggressively. The memory update step involves gradients, which require higher numerical precision than standard inference. INT4 or INT8 quantization of the memory module may destabilize the update rule in ways that aren't obvious from perplexity alone. The paper evaluates in FP16/BF16. Be cautious here.
Stateless serving. If your inference infrastructure is designed around stateless request handlers that load model weights from shared storage, the mutable per-request memory state is a significant architectural mismatch. You'd need to serialize, store, and restore the memory module weights per session, which is operationally complex.
When you need exact reproducibility. Because the memory module updates depend on the exact sequence of inputs processed, the same prompt + context may produce different outputs if the memory state coming in differs. This is expected behavior, but it means standard prompt-level caching (including prefix caching) doesn't apply without modification.
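One way to keep caching correct is to key on both the prompt and the incoming memory state. The helper below is hypothetical and deliberately naive. A cache hit requires an exact state match, which is why plain prefix caching doesn't carry over as-is.

```python
import hashlib
import json


def cache_key(prompt, memory_state):
    # Serialize prompt and state together so the boundary between them is
    # unambiguous, then hash.
    payload = json.dumps({"prompt": prompt, "memory": memory_state}, sort_keys=True)
    return hashlib.sha256(payload.encode("utf-8")).hexdigest()


fresh = {"M": [[0.0, 0.0], [0.0, 0.0]]}
used = {"M": [[0.0, 0.0], [0.5, 0.0]]}
print(cache_key("Summarize clause 4.", fresh) == cache_key("Summarize clause 4.", fresh))  # True
print(cache_key("Summarize clause 4.", fresh) == cache_key("Summarize clause 4.", used))   # False
```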
When you're not actually memory-limited. The Titans benefit is specifically for long-context retrieval and reasoning over information that exceeds what attention handles well. If your actual bottleneck is inference throughput on short sequences, reasoning quality on a fixed context, or training stability — Titans doesn't help and adds complexity.
What this changes in practice
The paper is from December 2024. As of writing, there's no production-grade Titans implementation equivalent to vLLM for PagedAttention or the Mamba reference implementations. The architectural ideas are sound and the reported benchmark results are strong, but the deployment story is immature.
What Titans gets right architecturally is the decomposition: attention handles short-term precision, neural memory handles long-term compression. The decomposition holds up because the two components do different jobs with different tradeoffs. The risk in earlier hybrid architectures (linear + attention) was that the components were doing the same thing, one of them less well. Titans makes the memory module distinctly responsible for compression over long horizons, which gives the attention component clean semantics.
The gradient-as-write mechanism is clever and theoretically grounded: it amounts to training a small neural network online, at test time, to map each token's key to its value, with surprise (the prediction error) as the learning signal. Whether this outperforms learned associative memory approaches (like Hopfield Networks or attention-based memory retrieval) at production scale is an open question the paper doesn't fully answer.
For teams building systems where long-context degradation is a real and measured problem — not a hypothetical — Titans is worth serious evaluation. The perplexity gap over Mamba-2 at 760M scale is large enough to be practically meaningful. The open question is whether the inference cost, when properly measured on production hardware, is justifiable given what you'd get from simply buying a larger attention-based model. That math doesn't exist in the paper yet, and it needs to before Titans becomes a default recommendation.
Paper: "Titans: Learning to Memorize at Test Time," Ali Behrouz, Peilin Zhong, Vahab Mirrokni. arXiv:2501.00663, December 2024.