Skip to content

XSA-All 11L + LeakyReLU(0.75)² + Aggressive Legal TTT → 1.1219 BPB#1092

Open
teddyoweh wants to merge 3 commits intoopenai:mainfrom
teddyoweh:submission/xsa11-leakyrelu075-legalttt
Open

XSA-All 11L + LeakyReLU(0.75)² + Aggressive Legal TTT → 1.1219 BPB#1092
teddyoweh wants to merge 3 commits intoopenai:mainfrom
teddyoweh:submission/xsa11-leakyrelu075-legalttt

Conversation

@teddyoweh
Copy link
Copy Markdown

Results

val_bpb: 1.1219 | Artifact: 15,916,230 bytes (15.92 MB) | 8×H100 SXM

Seed step_avg steps Pre-TTT bpb Post-TTT bpb TTT gain TTT time Artifact
1337 93.97ms 6,173 1.1252 1.1219 -0.0033 464s 15,916,230

What's New

Three independently validated improvements on top of the PR #414 + PR #399 stack:

1. XSA on All 11 Layers (XSA_LAST_N=11)

Extending eXtended Self-Attention from last 4 layers to all 11 yields -0.0007 BPB. The richer attention outweighs ~4% slower step time (93.97ms vs ~90ms).

2. LeakyReLU(0.75)²

Higher negative slope than the current SOTA (0.75 vs 0.5). From PR #977's ablation, 0.75 is strictly better than 0.5 for the int6 stack. Preserves more gradient flow through the MLP.

x = F.leaky_relu(self.fc(x), negative_slope=0.75).square()

3. Aggressive Legal TTT (lr=0.03)

Score-first TTT using PR #461's legal framework with a 15× higher learning rate (0.03 vs 0.002). Delivers -0.0033 BPB improvement (vs -0.0025 in SOTA). All blocks unfrozen, SGD with momentum 0.9, 3 epochs per chunk, cosine LR decay.

torch.inference_mode() guarantees scoring is stateless — weights are only updated AFTER the chunk is scored.

FA3 Fallback

Script includes automatic fallback from Flash Attention 3 to PyTorch SDPA:

try:
    from flash_attn_interface import flash_attn_func as flash_attn_3_func
    _HAS_FA3 = True
except ImportError:
    _HAS_FA3 = False

Our run used SDPA (93.97ms/step → 6,173 steps). With FA3 (~84ms/step → ~7,100 steps), expected BPB would be in the 1.119x range.

Timing

Phase Time
Training 580s
Eval (Legal TTT sliding) 464s
Total < 20 min

Run Command

BIGRAM_VOCAB_SIZE=2048 TRIGRAM_VOCAB_SIZE=0 \
XSA_LAST_N=11 \
EMA_ENABLED=1 EMA_DECAY=0.997 SWA_ENABLED=1 SWA_EVERY=50 \
ROPE_DIMS=16 LN_SCALE=1 LATE_QAT=1 LATE_QAT_THRESHOLD=0.15 \
VE_ENABLED=1 VE_DIM=128 VE_LAYERS=9,10 \
TTT_ENABLED=1 TTT_LR=0.03 TTT_EPOCHS=3 TTT_CHUNK_TOKENS=32768 \
TTT_FREEZE_BLOCKS=0 TTT_MOMENTUM=0.9 TTT_BATCH_SEQS=32 TTT_GRAD_CLIP=1.0 \
MUON_WD=0.04 ADAM_WD=0.04 \
MATRIX_LR=0.025 SCALAR_LR=0.025 TIED_EMBED_LR=0.035 \
MUON_MOMENTUM=0.99 MUON_MOMENTUM_WARMUP_START=0.92 \
MUON_MOMENTUM_WARMUP_STEPS=1500 WARMDOWN_ITERS=3500 \
ITERATIONS=9000 MAX_WALLCLOCK_SECONDS=580 EVAL_STRIDE=64 \
SEED=1337 \
torchrun --standalone --nproc_per_node=8 train_gpt.py

Credits

@Christopher-Lee-McClendon
Copy link
Copy Markdown

Excellent combination of tweaks that synergize with more aggressive TTT. I'm surprised that the 15x learning rate was better, nice finding!

@MatoTeziTanka
Copy link
Copy Markdown

Community Review — XSA-All 11L + LeakyReLU(0.75)² + Aggressive Legal TTT → 1.1219 BPB

BPB: 1.1219 | Compliance: LOOKS CLEAN — score-first-per-chunk TTT (legal #1416/#1423 pattern)

What I found in the code (head SHA 86f53b217c7d, file records/track_10min_16mb/2026-03-29_XSA11_LeakyReLU075_LegalTTT/train_gpt.py):

The TTT path at line 1133 implements the score-first-per-chunk pattern: each chunk is scored under torch.no_grad() / inference_mode() before the base_model.train() + SGD adaptation runs on that same chunk, with an is_last_chunk guard so the final chunk gets no adaptation pass. This is the structural shape the legal frontier uses (PRs #1416 erichroepke, #1423 aryanbhosale).

Per Issue #402 and Issue #677, TTT is legal when each token is scored before the adapter updates on it, and that's what the code does here — chunk ci is scored under weights adapted only on chunks 0..ci-1. No prequant_ttt_adapt_adamw(val_tokens, ...) multi-epoch fine-tune, no scored-region SLOT, no target-in-key n-gram cache.

CPU smoke test (CT2038 proteus-engine, 2026-04-11): import OK in 0.04s, dim=512, layers=11, vocab=1024, code=94098 B, SMOKE_TEST_PASS

Verdict: LOOKS CLEAN.

Recommendation to @cocohearts @valerio-oai @0hq @yuzhougu-oai @notapplica: MERGE pending standard checks (3-seed validation, 16MB artifact cap, 10-min wallclock on 8×H100 SXM). The compliance picture matches the legal reference frontier and no flags were raised by the classification pass.

Auto-classification caveat: this review was drafted by the AST-based classifier against a template derived from manually-reviewed cluster PRs (#1420, #1450, #1487, #1541, #1529, #1533, #1518). If I've misread a subtlety in your eval path — e.g., multi-epoch TTT that I mistook for single-pass, or a target-in-key lookup I missed in a helper function — please flag it and I'll re-run the audit manually.


Reviewed by @MatoTeziTankaThe Agora. CPU smoke test (CT2038 proteus-engine, 2026-04-11): import OK in 0.04s, dim=512, layers=11, vocab=1024, code=94098 B, SMOKE_TEST_PASS. Classification via deterministic AST-based classify_prs.py (pattern bank derived from ~65 manually-reviewed PRs earlier in the 2026-04-11 sweep). This review was auto-drafted from a template and spot-checked before posting — if the template misread your code, please call it out so I can iterate the classifier.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants