Record: varlen+fused mlp+ttt (bpb=1.1093) #1354
Novel mechanism: a zero-initialized `nn.Embedding(4096, 512)` created at eval time, trained exclusively through the standard score-first TTT loop. It learns document-local bigram patterns without modifying any artifact weights.

- Hash: `h = (prev_token * 2039 + curr_token) % 4096`
- Injection: `tok_emb(x) + eval_hash_emb(h)`, before RMSNorm
- Compliance: same score-first pattern as the openai#549/openai#1413 TTT precedent
- Precedent for eval-time params: LoRA-TTT (openai#1254, openai#1354)

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
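To make the mechanism concrete, here is a minimal sketch of the hash-and-inject step in plain Python, with a toy list-of-lists table standing in for the `nn.Embedding`; the function names are illustrative, not taken from the PR's code:

```python
def bigram_hash(prev_token: int, curr_token: int, n_buckets: int = 4096) -> int:
    """Hash a (prev, curr) token bigram into one of n_buckets embedding rows."""
    return (prev_token * 2039 + curr_token) % n_buckets

# Toy stand-in for the zero-initialized nn.Embedding(4096, 512):
# rows start at zero, so the injection is a no-op until TTT updates them.
n_buckets, d_model = 4096, 512
eval_hash_emb = [[0.0] * d_model for _ in range(n_buckets)]

def inject(tok_emb_x, prev_token, curr_token):
    """Add the hashed bigram embedding to the token embedding (pre-RMSNorm)."""
    row = eval_hash_emb[bigram_hash(prev_token, curr_token)]
    return [a + b for a, b in zip(tok_emb_x, row)]
```

Because the table starts at zero, the model's outputs are unchanged until the score-first TTT loop has written document-local bigram statistics into the visited rows.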
Community Review — Record: varlen+fused mlp+ttt (bpb=1.1093)

Compliance: NEEDS AUTHOR ACTION

What I found: The CPU smoke test on CT2038 (proteus-engine, 128 GB RAM, Triton 3.6.0, flash_attn stub, cutlass_evt_fusion stub) failed at the import step with:

IMPORT_FAIL — ImportError: cannot import name 'flash_attn_varlen_func' from 'flash_attn_interface' (unknown location)

A few of the common patterns I've seen for this class of error in the 2026-04-11 sweep:

Recommendation: Could you run

Once the parse/import issue is fixed, I'll re-run the compliance audit through the normal pipeline. No other flags have been identified yet because the audit halts at the import step.

Reviewed by @MatoTeziTanka — The Agora.
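One common way to survive stubbed or missing FA3 builds like the one in this smoke test (a suggestion on my part, not the author's code) is to guard the import and route to a dense path when `flash_attn_varlen_func` is unavailable:

```python
# Hypothetical import guard: fall back when FA3's varlen entry point is missing.
try:
    from flash_attn_interface import flash_attn_varlen_func  # real FA3 build
    HAVE_VARLEN = True
except ImportError:
    flash_attn_varlen_func = None  # CPU smoke-test / stub environments
    HAVE_VARLEN = False

def attention(q, k, v, cu_seqlens=None, max_seqlen=None, dense_fallback=None):
    """Use varlen attention when available, else a caller-supplied dense path."""
    if HAVE_VARLEN and cu_seqlens is not None:
        # Argument order here assumes the usual (q, k, v, cu_seqlens_q,
        # cu_seqlens_k, max_seqlen_q, max_seqlen_k) varlen signature.
        return flash_attn_varlen_func(q, k, v, cu_seqlens, cu_seqlens,
                                      max_seqlen, max_seqlen)
    if dense_fallback is None:
        raise RuntimeError("flash_attn_varlen_func unavailable and no fallback given")
    return dense_fallback(q, k, v)
```

With a guard like this, the audit's import step would succeed even on hosts where `flash_attn_interface` is only a stub.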
Closing because the updated PR (bpb=1.07336) is at #1530
Record: Varlen attention + fused MLP + TTT
val_loss: 1.8729 | val_bpb: 1.1093 | ~15.9 MB | 8×H100 SXM, 600s train + ~500s TTT eval
Increased training speed ~3% via variable-length attention and a fused MLP kernel, yielding a ~0.002 nat improvement. Re-added an optimized document-based LoRA TTT that yields a ~0.007 nat improvement. Together, these three changes improve performance by ~0.009 nats.
Based on a hackathon last weekend with @aldopareja, @sestinj, and @chrishamblin7 :)
Note: it is very hard to tell what sota is because there are (at time of this PR) 905 open PRs, 279 of which are tagged as records, and many of these PRs cheat in some way. Afaict, the three improvements in this PR are orthogonal to any recent PR, so once another record is merged, I will add these changes.
Main changes
Improves upon record 2026-03-25_ValCalib_GPTQ_XSA_BigramHash3072 with 3 things:
1. Variable length attention (~1% faster training, ~0.001 nats)
Replaced dense causal attention with Flash Attention 3's `flash_attn_varlen_func`. During training, documents are packed into flat token buffers with `cu_seqlens` boundaries so attention is computed within documents only — the model never attends across unrelated documents that happen to be adjacent in a batch. This does two things: it removes spurious cross-document attention, and it cuts attention FLOPs, since the quadratic cost is paid per document rather than per packed buffer. For example, 100 documents of 100 tokens each cost `100 * 100**2 = 1M` attention FLOPs, vs `10 * 1000**2 = 10M` for dense attention over ten packed 1000-token buffers. This leads to ~1% faster training on 8xH100, and the additional training steps buy a ~0.001 nat improvement. The speedup is limited because the model is so small that much of the runtime is overhead outside attention, so attention can only be sped up so much.

2. Fused MLP (~1% faster training, ~0.001 nats)
A custom Triton kernel (`linear_leaky_relu_square_kernel`) fuses the up-projection and the LeakyReLU(0.5)² activation (LeakyReLU followed by squaring) into a single kernel, based on similar kernels from modded-nanogpt. This is also ~1% faster training on 8xH100, yielding another ~0.001 nat improvement.

3. Test-time training (TTT) (~0.007 nats)
Re-adds LoRA-based TTT, based on my old implementation, which buys ~0.007 nats. This is an instance of "Case 3" according to this classification. In the previous record, TTT had a mismatch with training: sequences did not attend to other sequences during TTT/eval but did during training. Now that we are not attending across sequences during training either, this mismatch is avoided (> 70% of training sequences in the old code attended to previous sequences, compared to ~0% now). The bigger improvement comes from using a smaller chunk size (32 instead of 256, so more gradient updates per sequence) and using RMSProp instead of Adam (just set `beta1=0` for Adam during TTT). To afford the smaller chunk size, I had to substantially optimize TTT speed by vectorizing ops and improving the batch scheduling system, roughly 2x'ing the old implementation's speed.

TTT analysis
As you can see below, TTT helps the most at later positions in the document (note this analysis is run on 1/10th of the validation set). The top plot shows, for each position x in a sequence, how much TTT improves the loss at that position.

Even though only ~5% of the tokens in the dataset are at position 10k or later in their sequence, > 15% of the loss improvement from TTT comes from those later positions. In the bottom plot, you can see that if we only cared about long-context performance and only looked at positions 10k and up, the gain from TTT would be much greater than 0.01 nats!
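For reference, the score-first chunked TTT loop described above can be sketched as follows (hypothetical names and a caller-supplied model, not the PR's implementation). The key invariant is that each chunk is scored with the current adapted weights before its own gradient step is applied, so no token is ever evaluated by a model that has already trained on it:

```python
def ttt_eval(tokens, score_fn, update_fn, chunk_size=32):
    """Score-first TTT: score each chunk, then adapt on it before the next.

    score_fn(chunk)  -> per-chunk loss under the *current* adapted weights
    update_fn(chunk) -> one optimizer step (e.g. RMSProp, i.e. Adam beta1=0)
                        on that chunk's self-supervised loss
    """
    losses = []
    for start in range(0, len(tokens), chunk_size):
        chunk = tokens[start:start + chunk_size]
        losses.append(score_fn(chunk))  # score BEFORE training on this chunk
        update_fn(chunk)                # then update LoRA params on it
    return sum(losses) / len(losses)
```

Shrinking `chunk_size` from 256 to 32 multiplies the number of `update_fn` calls per sequence by 8, which is why the loop itself had to be made roughly 2x faster to fit the eval budget.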
Other small changes and notes
Run results
```
sam:~/parameter-golf# python records/track_10min_16mb/2026-04-04_VarLenAttn/calc_p.py \
    --logs records/track_10min_16mb/2026-04-04_VarLenAttn/seed1-eval.txt \
           records/track_10min_16mb/2026-04-04_VarLenAttn/seed2-eval.txt \
           records/track_10min_16mb/2026-04-04_VarLenAttn/seed1337-total.txt
baseline val_loss: [1.8828 1.8816 1.8822] mean=1.88220 std=0.000490
new val_loss: [1.87397897 1.87277334 1.87202097] mean=1.872924 std=0.000806
delta (baseline - new): 0.009276
baseline val_bpb: [1.1151 1.1144 1.1148] mean=1.114767 std=0.000287
new val_bpb: [1.1098759 1.10916186 1.10871626] mean=1.109251 std=0.000478
delta (baseline - new): 0.005515
val delta loss threshold: 0.005
p-value (new is ≥0.005 below baseline): 0.002870
```

Also note that the logs for this run are 5 files, not 3. For seeds 1 and 2, I ran training before implementing/tuning TTT, so to save compute I did not re-run training but just loaded the checkpoints. For clarity, I will re-run with the final code, hopefully later today.
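For readers without the repo, one plausible way such a p-value could be computed (my assumption about the statistic, not necessarily what `calc_p.py` does) is a one-sided Welch t-test on the per-seed losses, testing whether the improvement exceeds the 0.005 threshold. The Student-t tail is evaluated via the regularized incomplete beta function, stdlib only:

```python
import math

def _betacf(a, b, x, max_iter=200, eps=3e-12):
    """Continued fraction for the regularized incomplete beta (Lentz's method)."""
    qab, qap, qam = a + b, a + 1.0, a - 1.0
    c, d = 1.0, 1.0 - qab * x / qap
    d = 1.0 / (d if abs(d) >= 1e-30 else 1e-30)
    h = d
    for m in range(1, max_iter + 1):
        m2 = 2 * m
        aa = m * (b - m) * x / ((qam + m2) * (a + m2))
        d = 1.0 + aa * d; d = 1.0 / (d if abs(d) >= 1e-30 else 1e-30)
        c = 1.0 + aa / c; c = c if abs(c) >= 1e-30 else 1e-30
        h *= d * c
        aa = -(a + m) * (qab + m) * x / ((a + m2) * (qap + m2))
        d = 1.0 + aa * d; d = 1.0 / (d if abs(d) >= 1e-30 else 1e-30)
        c = 1.0 + aa / c; c = c if abs(c) >= 1e-30 else 1e-30
        delta = d * c
        h *= delta
        if abs(delta - 1.0) < eps:
            break
    return h

def _betainc(a, b, x):
    """Regularized incomplete beta I_x(a, b)."""
    if x <= 0.0: return 0.0
    if x >= 1.0: return 1.0
    front = math.exp(math.lgamma(a + b) - math.lgamma(a) - math.lgamma(b)
                     + a * math.log(x) + b * math.log(1.0 - x))
    if x < (a + 1.0) / (a + b + 2.0):
        return front * _betacf(a, b, x) / a
    return 1.0 - front * _betacf(b, a, 1.0 - x) / b

def t_sf(t, df):
    """One-sided P(T > t) for Student's t with df degrees of freedom."""
    p = 0.5 * _betainc(df / 2.0, 0.5, df / (df + t * t))
    return p if t >= 0 else 1.0 - p

def welch_p(baseline, new, threshold=0.0):
    """P(mean improvement >= threshold is due to chance), one-sided Welch test."""
    n1, n2 = len(baseline), len(new)
    m1, m2 = sum(baseline) / n1, sum(new) / n2
    v1 = sum((x - m1) ** 2 for x in baseline) / (n1 - 1)  # sample variances
    v2 = sum((x - m2) ** 2 for x in new) / (n2 - 1)
    se2 = v1 / n1 + v2 / n2
    t = ((m1 - m2) - threshold) / math.sqrt(se2)
    # Welch–Satterthwaite degrees of freedom
    df = se2 ** 2 / ((v1 / n1) ** 2 / (n1 - 1) + (v2 / n2) ** 2 / (n2 - 1))
    return t_sf(t, df)
```

On the three seeds above this gives a p-value in the same low-single-digit-per-mille range as the log, though the exact value depends on the test `calc_p.py` actually implements.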
Replicating runs + dev