Commit a8e8027

Submission: Learned Routing + Two-Pass N-gram Rescoring + Extended Orders
Combines PR openai#834's learned multi-expert routing head with PR openai#846's two-pass cold-cache rescoring. Key changes:

- Extended n-gram orders from 2-7 to 2-12 with 8M bucket hash tables
- Two-pass eval: rescore first 15 chunks with full cache after pass 1
- Per-chunk loss tracking for precise pass-1/pass-2 delta computation
- Configurable via env vars: NGRAM_MAX_ORDER, NGRAM_BUCKETS, TWO_PASS_ENABLED, TWO_PASS_RESCORE_CHUNKS

Based on PR openai#834 (AnirudhRahul) + PR openai#846 (himanshudongre) stack.
1 parent 50390d6 commit a8e8027

File tree

4 files changed, +2042 -0 lines changed

4 files changed

+2042
-0
lines changed
Lines changed: 70 additions & 0 deletions
# Learned Routing + Two-Pass N-gram Rescoring + Extended Orders

**val_bpb: TBD** | **~15.9 MB** | 8xH100 SXM

## Key Innovation: Combining Learned Routing with Two-Pass Rescoring

PR #834 introduced a learned `Linear(512->7)` routing head trained end-to-end against the mixer objective, but used single-pass eval. PR #846 introduced two-pass rescoring (rescoring cold-cache early chunks with the full cache) but used a heuristic entropy-sigmoid alpha. This submission combines both: a learned routing head with two-pass rescoring, plus extended n-gram orders (2-12) and larger hash tables.

## Techniques

### Learned Multi-Expert Routing Head

- `Linear(512, 12)` head reads the transformer hidden state
- Routes between 1 neural expert + 11 n-gram orders (2-12)
- Trained end-to-end with a frozen n-gram oracle during training
- Masked softmax: invalid orders (insufficient context) masked to -inf
- Neural floor: 5% minimum weight on the neural expert
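The masking and flooring steps above can be sketched in plain Python (function and argument names are hypothetical; in the submission the logits come from the learned `Linear(512, 12)` head, with index 0 as the neural expert):

```python
import math

def route_weights(logits, valid, neural_floor=0.05):
    """Masked softmax over [neural, order-2, ..., order-12] logits.

    Invalid n-gram orders (insufficient context) are masked to -inf
    before the softmax; the neural expert (index 0) is always valid.
    If the neural expert falls below the floor, it is clamped to the
    floor and the remaining mass is rescaled so weights still sum to 1.
    """
    masked = [l if v else float("-inf") for l, v in zip(logits, valid)]
    m = max(masked)
    exps = [math.exp(l - m) for l in masked]  # exp(-inf) -> 0.0
    z = sum(exps)
    w = [e / z for e in exps]
    if w[0] < neural_floor:
        scale = (1.0 - neural_floor) / (1.0 - w[0])
        w = [neural_floor] + [wi * scale for wi in w[1:]]
    return w
```

With all 12 experts valid and equal logits the neural weight is 1/12, above the floor, so no clamping occurs; the floor only engages when the head actively down-weights the neural expert.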
### Two-Pass N-gram Rescoring

- Pass 1: Standard sequential chunk eval with causal cache building
- Pass 2: Rescore first 15 chunks with the full cache (no cache updates)
- Early chunks improve dramatically (from ~1.15 BPB to ~0.12 BPB)
- Adds ~50-60s to eval time
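A minimal sketch of the two-pass loop, with `score_chunk` and `update_cache` as hypothetical stand-ins for the real scoring and cache code; keeping both pass-1 and final per-chunk losses is what enables the pass-1/pass-2 delta computation:

```python
def two_pass_eval(chunks, score_chunk, update_cache, cache,
                  rescore_chunks=15):
    """Pass 1: sequential causal eval, folding each chunk into the
    cache after it is scored. Pass 2: rescore the first
    `rescore_chunks` chunks (which pass 1 saw with a cold cache)
    against the now-full cache, with no further cache updates."""
    pass1 = []
    for chunk in chunks:
        pass1.append(score_chunk(chunk, cache))  # cache holds prior chunks only
        update_cache(cache, chunk)
    final = list(pass1)
    for i in range(min(rescore_chunks, len(chunks))):
        final[i] = score_chunk(chunks[i], cache)  # full cache, frozen
    return pass1, final
```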
23+
24+
### Extended N-gram Orders (2-12)
25+
- 11 n-gram expert orders vs 6 (PR #834) or 8 (PR #846)
26+
- 8M bucket hash tables (vs 1M or 4M) for fewer collisions
27+
- Per-order min_count thresholds
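A toy sketch of one bucketed table (class and method names are assumptions, not the submission's code): each context hashes into one of a fixed number of buckets, so distinct contexts can collide and share counts; growing buckets from 1M/4M to 8M (8388608 = 2^23) reduces that collision rate. The `min_count` threshold here defers to other experts when a context is too rare to estimate reliably.

```python
class HashedNgramTable:
    """Minimal bucketed n-gram count table (assumed design)."""

    def __init__(self, order, buckets=8_388_608, min_count=2):
        self.order = order
        self.buckets = buckets
        self.min_count = min_count
        self.counts = {}  # bucket -> {next_token: count}

    def _bucket(self, context):
        return hash(context) % self.buckets

    def update(self, tokens):
        n = self.order
        for i in range(len(tokens) - n + 1):
            ctx = tuple(tokens[i:i + n - 1])      # (n-1)-token context
            nxt = tokens[i + n - 1]               # predicted next token
            dist = self.counts.setdefault(self._bucket(ctx), {})
            dist[nxt] = dist.get(nxt, 0) + 1

    def prob(self, context, token):
        b = self._bucket(tuple(context[-(self.order - 1):]))
        dist = self.counts.get(b, {})
        total = sum(dist.values())
        if total < self.min_count:
            return None  # too rare: defer to other experts
        return dist.get(token, 0) / total
```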
28+
29+
### TTT -> N-gram Pipeline
30+
- TTT adapts model weights on already-scored chunks
31+
- N-gram eval uses TTT-adapted weights (not base model)
32+
- Better neural expert contribution in the mixture
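The ordering that matters here is score-then-adapt: a chunk contributes to the loss before TTT updates on it, so eval stays causal, while every later chunk sees a neural expert that has already adapted to the stream. A schematic sketch (names hypothetical):

```python
def ttt_then_ngram_eval(chunks, model, score_with_mixture):
    """Score each chunk with the current (possibly TTT-adapted) model,
    then take TTT steps on that chunk. Later chunks are therefore
    scored by the adapted model, not the base checkpoint."""
    losses = []
    for chunk in chunks:
        losses.append(score_with_mixture(chunk, model))  # current weights
        model.adapt(chunk)  # TTT update on the already-scored chunk
    return losses
```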
## Architecture

PR #834/414 stack:

- 11 layers, 512d, 8H, 8KV
- LeakyReLU(0.5)^2 MLP (3.5x)
- U-Net skip connections, SmearGate, BigramHash(6144)
- Partial RoPE (16/64), LN Scale, XSA on all layers
- VE128 on layers 9-10
- EMA(0.997) + Tight SWA
- GPTQ int5 + zstd-22, 3% pruning
- Late QAT with Soft-Round STE + CROWN-Q

## Run Command

```bash
TWO_PASS_ENABLED=1 TWO_PASS_RESCORE_CHUNKS=15 \
NGRAM_MAX_ORDER=12 NGRAM_BUCKETS=8388608 \
TTT_TO_NGRAM=1 \
NUM_LAYERS=11 BIGRAM_VOCAB_SIZE=6144 XSA_LAST_N=11 \
EMA_ENABLED=1 EMA_DECAY=0.997 SWA_ENABLED=1 SWA_EVERY=50 \
ROPE_DIMS=16 LN_SCALE=1 LATE_QAT=1 LATE_QAT_THRESHOLD=0.5 \
VE_ENABLED=1 VE_DIM=128 VE_LAYERS=9,10 \
TTT_ENABLED=1 TTT_LR=0.0005 TTT_EPOCHS=4 TTT_CHUNK_TOKENS=32768 \
TTT_FREEZE_BLOCKS=2 TTT_MOMENTUM=0.9 TTT_GRAD_CLIP=1.0 \
MIXER_LOSS_WEIGHT=0.1 MIXER_NEURAL_FLOOR=0.05 \
SEED=1337 \
torchrun --standalone --nproc_per_node=8 train_gpt.py
```

## Credits

- **Learned routing head + frozen oracle**: PR #834 by @AnirudhRahul
- **Two-pass rescoring**: PR #846 by @himanshudongre
- **Base architecture**: PR #414 by @signalrush, PR #549 by @abaybektursun
- **N-gram cache concept**: PR #659/#779 by @deanbrr
- **TTT recipe**: PR #461 by @Christopher-Lee-McClendon
- **LeakyReLU activation**: PR #493/#518 by @parinzee/@sofiabod
Lines changed: 5 additions & 0 deletions

torch>=2.4.0
numpy
sentencepiece
zstandard
flash-attn-hopper
Lines changed: 9 additions & 0 deletions

{
  "name": "Learned Routing + Two-Pass N-gram Rescoring + Extended Orders",
  "val_bpb": null,
  "bytes_total": null,
  "blurb": "Combines PR #834's learned multi-expert gate (Linear(512->12)) with PR #846's two-pass cold-cache rescoring. Extended n-gram orders 2-12 with 8M bucket hash tables. TTT-adapted model feeds into n-gram eval. Built on PR #834/414 stack.",
  "author": "pappanick",
  "github_id": "pappanick",
  "date": "2026-03-26"
}