Record: Varlen attention + fused MLP + doc-independent TTT (1.07336)#1530
Record: Varlen attention + fused MLP + doc-independent TTT (1.07336)#1530samacqua wants to merge 3 commits intoopenai:mainfrom
Conversation
…g + Muon 0.97 — val_bpb 1.07747 (3-seed mean) - 3-seed mean: 1.07747 BPP (std 0.00064) / 2.78321 nats - ~15.99 MB artifact, 8×H100 SXM, 600s - VarLen attention (within-document only), doc-independent LoRA TTT - Parameter banking + triple depth recurrence + parallel residuals - PyTorch MLP fallback (no Triton/CUTLASS dependency) - Based on PR openai#1530, PR openai#1523, PR openai#1514
|
I may be missing something, but I think there is one higher-scrutiny In the TTT path, the compile warmup appears to use actual validation tokens before the main eval loop, and it also does My read of the current guidance is:
If that reading is right, would you be willing to switch the warmup to:
That would make the legality story much cleaner. |
|
@dexhunter it could honestly just be commented out, given that warmup + eval time is still < 600s. But it shouldn't matter -- training warmup does the same thing, parameters and optimizer states are reset. As a sanity check I re-ran TTT on seed 2 w/ warmup commented out, and the loss was within expected variance between runs (actually did slightly better): But given that making a change + re-running what take an hour of 8xh100, I will only if it is a blocker. |
Community Review — VarLen attention + fused MLP + doc-independent TTTThanks @samacqua. Doc-independent TTT via cu_seqlens boundary isolation is a genuinely interesting approach to the causal-dependence question the SLOT cluster has been bouncing around. One import blocker, then a deeper question on the doc-independence claim. What I found (head SHA
"Doc-independent TTT" — the interesting idea. My read is that if the LoRA (or whatever TTT-like adaptation you're running) respects the same The open question is whether the adaptation state itself is per-document or per-batch. I couldn't find an Import blocker (smoke test). The CPU smoke on CT2038 hit: My flash_attn stub covers Questions
Compliance summary (partial)
Verdict: LOOKS INTERESTING, NEEDS AUTHOR CLARIFICATION on the TTT adaptation path. Recommendation to @cocohearts @valerio-oai @0hq @yuzhougu-oai @notapplica: HOLD pending author clarification on where and how the doc-independent TTT runs. If the adaptation respects cu_seqlens boundaries and the temporal ordering is score-before-adapt at the document level, this is a genuinely clean path out of the SLOT compliance bind, and I'd flip to MERGE. Reviewed by @MatoTeziTanka — The Agora. CPU smoke test (CT2038 proteus-engine, 2026-04-11): IMPORT_FAIL due to |
|
@MatoTeziTanka look at
So yes, it respects the same document boundaries. It is strictly harder (and more valid imo) than TTT on the full validation sequence autoregressively. See the "Methods" section of this blog for clarity. |
… (3-seed mean) PR openai#1530 v2 base + warmdown_frac=0.75 + TTT_CHUNK_SIZE=48 + Muon 0.97. 3-seed mean: 1.07406 (std 0.00132), 2.77441 nats. Delta vs merged SOTA (openai#1493): -0.01491 nats (clears 0.005 bar by 3.0x). All artifacts < 16 MB, train < 600s, eval < 225s.
|
The pattern structurally matches what @valerio-oai flagged as invalid in #677 ("adapt on validation before the reported eval pass"). Even though LoRA resets per batch, the compile warmup still runs backward+step on val tokens before the eval loop. Since you confirmed the fix is within variance, it would be worth switching to random/synthetic tokens to avoid any ambiguity during review. |
This submission actually looks good to me. They don't "adapt on validation before the reported eval pass", as the warmup/compilation throws away the updates. The final result wouldn't change at all if they replaced those validation tokens in warmup with any other tokens. The author even notes that the result is unchanged, even when they comment out the warmup. |
|
Hey btw @samacqua the training script will crash without the |
Record: Varlen attention + fused MLP + TTT
val_loss: 2.77261 | val_bpb: 1.07336 | ~15.99 MB | 8×H100 SXM, 587s train + ~340s TTT eval
Best PR bpb (PR #1529): bpb=1.0753 (delta=0.0019), loss=2.7776 (delta=0.0050)
Merged record bpb (PR #1493): bpb=1.0810 (delta=0.0076), loss=2.7923 (delta=0.0197)
Increased training speed ~5% via variable length attention, a fused MLP triton kernel (no
cutlass_evt_fusiondep), and grouping together small parameters, yielding ~.002 nats when comparing sliding window eval. Re-added document-based LoRA TTT which has no inter-sequence dependence and improves over strided evaluation by ~.008 nats.Main changes
Applied changes from my old PR to a recent record PR: #1523. But PR #1552 beat my previous bpb before I submitted the PR, so I incorporated their (orthogonal) improvements. Most of below is copied from my previous PR #1354.
This involves 3 things:
1. Variable length attention (~2% faster training, ~0.001 nats)
Replaced dense causal attention with Flash Attention 3's
flash_attn_varlen_func. During training, documents are packed into flat token buffers withcu_seqlensboundaries so attention is computed within documents only — the model never attends across unrelated documents that happen to be adjacent in a batch.This does two things:
100 * 100**2 = 1Mattention FLOPs vs10 * 1000**2 = 10Mwith dense attention.2. Fused MLP + grouped small params (~3% faster training, ~0.001 nats)
A custom Triton kernel (
linear_leaky_relu_square_kernel) fuses the up-projection, LeakyReLU(0.5)² activation, and squaring into a single kernel. Based on similar kernels from modded-nanogpt. I also group the many tiny replicated scalar/control gradients into a single all-reduce to avoid a pile of tiny collectives.3. Doc-based test-time training (TTT) (~0.008 nats over sliding window)
Although it is technically legal in this competition to train on tokens from previous documents in the dataset, I am spiritually opposed to this. Under the current formulation, if the eval set was bigger, the expectation of the loss would be lower which seems broken. So in this implementation, there is score-first TTT applied to each sequence in the validation set independently (and efficiently using batched LoRAs), which is strictly harder.
Re-adds LoRA-based TTT, based on my old implementation, but > 2x faster which allows for using smaller chunk sizes which leads to better performance. This is an instance of "Case 3" according to this classification.
It's interesting to note that adding test-time training improves loss more than adding ~215 steps. These 215 steps train on
786432*215=169,082,880tokens to gain ~.002 nats. The average sequence length in the validation set is ~200 tokens which means test-time training here gains ~.003 nats / 800 tokens on average (valid bc sequences are trained independently). So, in a way, TTT is~(.003/800) / (.002/169082880) >= 300ktimes more token efficient than pre-training: it helps to be in distribution :)Other small changes
Made some changes to make replication and dev based on this PR easier:
Replicating runs + dev