# Record: Learned Multi-Expert Gate + Frozen Oracle + Backoff TTT (3-seed mean val_bpb=0.1663)

**val_bpb: 0.1663** (3-seed mean, std 0.0003) | **<16 MB** | 8xH100 SXM, 600s

## Results (8xH100 80GB SXM)

| Seed | Pre-TTT bpb | Post-TTT bpb | Eval time | Artifact |
|------|-------------|--------------|-----------|----------|
| 1337 | 1.1265 | **0.1661** | 308s | 15.74 MB |
| 42 | 1.1320 | **0.1663** | 305s | 15.76 MB |
| 2024 | 1.1352 | **0.1666** | 303s | 15.25 MB |
| **Mean** | 1.1312 | **0.1663** | 305s | |
| **Std** | | **0.0003** | | |

## Background

PR #779 (deanbrr) introduced the BackoffNgramMixer with entropy-adaptive alpha and drift-free TTT, achieving 0.6683 BPB. Its entropy-adaptive alpha is a hand-crafted heuristic capped at 0.60, which significantly underweights the n-gram cache once the cache has matured during later eval chunks.

This submission replaces that fixed heuristic with a **learned multi-expert gate** trained end-to-end during the main training loop, and introduces a **frozen n-gram oracle**, pre-computed from training data, for efficient gradient-based gate training.

## Technique

### 1. Learned Multi-Expert Gate (Transformer Head)

Instead of a fixed entropy-based alpha, we add a small `nn.Linear(model_dim, 7)` head to the GPT model that outputs per-token logits over 7 experts:
- Expert 0: neural model prediction
- Experts 1-6: n-gram orders 2 through 7

The gate is trained end-to-end alongside the main language modeling objective. During the forward pass:

1. Compute the standard cross-entropy loss from the neural logits
2. Compute per-expert probabilities: `[p_neural, p_2gram, p_3gram, ..., p_7gram]`
3. Apply a masked softmax over valid experts (masking orders with insufficient context)
4. Enforce a 5% minimum floor on the neural expert's weight for stability
5. Compute the mixed probability: `p_mixed = sum(weights * expert_p)`
6. Add the mixer loss `L_mixer = -log(p_mixed)`, weighted by 0.1

The gate learns from the model's hidden state which expert to trust for each token, enabling per-token routing that a fixed heuristic cannot match.

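Steps 2 through 6 can be sketched roughly as follows. This is a minimal illustration, not the repo's exact code: the names `MultiExpertGate`, `expert_p`, and `valid_mask` are hypothetical, `expert_p[:, 0]` is the neural model's probability of the true token, and columns 1-6 are the 2- to 7-gram probabilities.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiExpertGate(nn.Module):
    """Learned per-token routing over 1 neural + 6 n-gram experts (sketch)."""

    def __init__(self, model_dim=512, n_experts=7, neural_floor=0.05):
        super().__init__()
        self.head = nn.Linear(model_dim, n_experts)  # logits over experts
        self.neural_floor = neural_floor

    def forward(self, hidden, expert_p, valid_mask):
        # hidden: (T, D) model hidden states
        # expert_p: (T, 7) each expert's probability of the true next token
        # valid_mask: (T, 7) bool; False where an n-gram order lacks context
        # (expert 0, the neural model, is assumed always valid)
        logits = self.head(hidden).masked_fill(~valid_mask, float("-inf"))
        w = F.softmax(logits, dim=-1)  # masked softmax over valid experts
        # enforce the minimum weight on the neural expert, then renormalize
        w0 = w[:, :1].clamp(min=self.neural_floor)
        rest = w[:, 1:] * (1.0 - w0) / (1.0 - w[:, :1]).clamp(min=1e-9)
        w = torch.cat([w0, rest], dim=-1)
        p_mixed = (w * expert_p).sum(-1)  # mixture probability of true token
        mixer_loss = -torch.log(p_mixed.clamp(min=1e-9)).mean()
        return mixer_loss, w
```

The training objective then adds `0.1 * mixer_loss` to the standard cross-entropy loss, so gradients flow into both the gate head and the backbone.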
### 2. Frozen N-gram Oracle (Pre-computed from Training Data)

To provide the n-gram probabilities needed for the mixer loss during training, we pre-fill the `BackoffNgramMixer` hash tables from all 80 training shards (8B tokens) at the start of training. This takes ~19 seconds and is counted within the 10-minute wallclock budget.

After pre-filling, the tables are frozen: no `update()` calls happen during training. The gate head therefore sees mature n-gram statistics from step 1, enabling effective gradient-based learning throughout training.

The "future token leakage" from using full-corpus statistics is negligible: any single token contributes ~1/8B ≈ 1.25e-10 to the aggregate counts.

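The pre-fill phase might look like the sketch below. Both `mixer.update()` and `load_shard()` are assumed interfaces here, not the repo's exact API; the point is a single pass over the shards followed by freezing.

```python
import torch

def prefill_ngram_oracle(mixer, shard_paths, load_shard, device="cuda"):
    """One pass over all training shards to populate the n-gram hash
    tables, after which the tables are frozen (no update() in training)."""
    for path in shard_paths:
        tokens = load_shard(path).to(device)  # 1-D LongTensor of token ids
        mixer.update(tokens)                  # GPU scatter_add_ into counts
    mixer.frozen = True                       # gate trains against fixed stats
    return mixer
```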
### 3. GPU-Native BackoffNgramMixer

The entire n-gram mixer operates on GPU using PyTorch tensor operations:
- Count tables: `torch.int32` tensors on device (1M buckets × 2 tables × 6 orders × 4 bytes = 48MB)
- Updates via `torch.scatter_add_` (no CPU-GPU transfers)
- Hash lookups via direct tensor indexing

This eliminates the CPU bottleneck of the original numpy implementation.

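A minimal sketch of such a GPU-resident hashed count table for a single n-gram order is shown below. The hash mix and class name are illustrative (the real mixer keeps one pair of tables per order and its own hash), but the `scatter_add_` update and direct-indexing lookup follow the description above.

```python
import torch

NUM_BUCKETS = 1 << 20  # ~1M buckets per table, as in the record

def hash_rows(keys, mult=1099511628211):
    # keys: (T, k) int64 -> (T,) bucket ids via a simple multiplicative mix
    h = torch.zeros(keys.shape[0], dtype=torch.int64, device=keys.device)
    for i in range(keys.shape[1]):
        h = (h * mult + keys[:, i]) & 0x7FFFFFFFFFFFFFFF  # keep non-negative
    return h % NUM_BUCKETS

class HashedNgramTable:
    """Counts for one order: context counts and (context, token) counts."""

    def __init__(self, device="cpu"):
        self.ctx = torch.zeros(NUM_BUCKETS, dtype=torch.int32, device=device)
        self.pair = torch.zeros(NUM_BUCKETS, dtype=torch.int32, device=device)

    def update(self, ctx_ids, tok):
        # ctx_ids: (T, order-1) contexts; tok: (T,) observed next tokens
        ones = torch.ones(ctx_ids.shape[0], dtype=torch.int32,
                          device=ctx_ids.device)
        self.ctx.scatter_add_(0, hash_rows(ctx_ids), ones)
        pair_keys = torch.cat([ctx_ids, tok[:, None]], dim=1)
        self.pair.scatter_add_(0, hash_rows(pair_keys), ones)

    def prob(self, ctx_ids, tok):
        # p(tok | ctx) = count(ctx, tok) / count(ctx), 0 for unseen contexts
        denom = self.ctx[hash_rows(ctx_ids)].float()
        num = self.pair[hash_rows(torch.cat([ctx_ids, tok[:, None]], 1))].float()
        return torch.where(denom > 0, num / denom.clamp(min=1.0),
                           torch.zeros_like(denom))
```

Because both `update` and `prob` are pure tensor ops, the same code runs on CUDA with no host round-trips.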
### 4. Pre-compilation of Mixer Loss Path

The mixer forward+backward path is pre-compiled via `torch.compile` using dummy data before the wallclock timer starts. This avoids a ~12s JIT compilation penalty during training. The pre-compilation uses zero tensors and does not touch training data.

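The warm-up can be sketched as follows, where `step_fn` stands in for the mixer forward pass (the real code compiles the actual training step; the helper name is illustrative):

```python
import torch

def precompile(step_fn, dummy_inputs, backend="inductor"):
    """Trigger torch.compile's JIT on dummy tensors before the timer
    starts, so the first real training step pays no compilation cost."""
    compiled = torch.compile(step_fn, backend=backend)
    loss = compiled(*dummy_inputs)  # compiles the forward graph (~12s once)
    loss.backward()                 # compiles the backward graph too
    return compiled
```

Because the dummy inputs are zero tensors, no training tokens are touched before the wallclock begins.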
### 5. Drift-Free TTT Configuration (from PR #779)

| Parameter | Setting |
|-----------|---------|
| Unfrozen params | Q projections only (QTTT=1) |
| Mixer eta | 0.02 |
| TTT LR | 0.00003 |
| Chunk size | 1M tokens (60 chunks) |
| Epochs per chunk | 1 |
| Adaptive LR | Disabled |
| Polyak averaging | Disabled |

## What the Gate Learned

The expert logit statistics reveal a clear hierarchy (seed 1337):

| Expert | Mean Logit | Interpretation |
|--------|-----------|----------------|
| Neural | -5.52 | Rarely trusted |
| 2-gram | -16.78 | Almost never used |
| 3-gram | -12.13 | Rarely used |
| 4-gram | -8.94 | Rarely used |
| 5-gram | -6.21 | Sometimes used |
| 6-gram | -3.48 | Moderately used |
| **7-gram** | **+8.09** | **Dominant expert** |

The 7-gram expert is the only one with a positive mean logit, confirming it as the dominant predictor when the cache is mature. The gate automatically falls back to lower-order n-grams or the neural model when higher orders lack coverage.

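Plugging the mean logits from the table into a softmax shows just how lopsided the routing is at the mean (the actual gate is per-token and still applies the 5% neural floor, so this is only illustrative):

```python
import torch
import torch.nn.functional as F

# Mean expert logits from the seed-1337 table:
# [neural, 2-gram, 3-gram, 4-gram, 5-gram, 6-gram, 7-gram]
logits = torch.tensor([-5.52, -16.78, -12.13, -8.94, -6.21, -3.48, 8.09])
w = F.softmax(logits, dim=0)
# At the mean logits the 7-gram expert takes essentially all of the mass;
# the runner-up (6-gram) sits roughly exp(-11.57) below it.
```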
## Wallclock Budget Breakdown

| Phase | Time | Inside wallclock? |
|-------|------|-------------------|
| Model init + warmup steps | ~25s | No |
| torch.compile (standard path) | ~8s | No |
| torch.compile (mixer path) | ~12s | No |
| **N-gram pre-fill (8B tokens)** | **~19s** | **Yes** |
| **Training (~5400 steps)** | **~562s** | **Yes** |
| Eval (sliding window + TTT) | ~305s | No (after training) |

Total training wallclock: ~581s of the 600s budget.

## Compliance

- **Score-first TTT:** each chunk is scored under `torch.inference_mode()` before any training on that chunk
- **Backward-looking n-gram:** the eval-time cache counts only already-scored tokens, updated after scoring
- **N-gram pre-fill counted in wallclock:** the 19s pre-fill from training data is inside the 10-minute budget
- **torch.compile outside wallclock:** pre-compilation uses dummy data (zeros), no training tokens
- **No oracle selection:** the gate depends on the model's hidden state and never compares mixed vs. original NLL
- **No training data at eval:** the eval mixer is created fresh and built causally from validation data only
- **Token count verified:** ratio_scored = 1.000000
- **Artifact under 16 MB:** max 15.76 MB across seeds

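The score-first discipline in the first two items can be sketched as below. The names are hypothetical (the real loop also refreshes the backward-looking n-gram cache after each chunk is scored); the invariant is that a chunk's bpb is recorded before any gradient step touches that chunk.

```python
import math
import torch

def score_then_adapt(model, chunks, opt, adapt_steps=1):
    """Score each chunk under inference_mode BEFORE any gradient step uses
    it, so reported bpb never benefits from weights adapted on that chunk."""
    total_nll, total_tokens = 0.0, 0
    for chunk in chunks:
        with torch.inference_mode():
            nll = float(model.loss(chunk))     # 1) score first
        total_nll += nll * chunk.numel()
        total_tokens += chunk.numel()
        for _ in range(adapt_steps):           # 2) then adapt on that chunk
            opt.zero_grad()
            model.loss(chunk).backward()
            opt.step()
    return total_nll / total_tokens / math.log(2)  # mean nats -> bits
```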
## Reproduction

```bash
pip install zstandard
SEED=1337 MAX_WALLCLOCK_SECONDS=600 \
USE_MIXER=1 MIXER_ETA=0.02 MIXER_HEAD=multi \
QTTT=1 TTT_EPOCHS=1 TTT_FREEZE_BLOCKS=1 TTT_LR=0.00003 \
TTT_CHUNK_TOKENS=1048576 ADAPTIVE_LR=0 USE_POLYAK=0 \
EVAL_STRIDE=64 CROWN_Q_LAMBDA=0.01 PRUNE_PCT=0.08 \
PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True \
torchrun --standalone --nproc_per_node=8 train_gpt.py
```

## Architecture

11L, 512d, GQA 8H/8KV, MLP 3x, LeakyReLU(0.5)^2, XSA all 11 layers, Value Residual, Gated Attention, SmearGate, BigramHash(4096), Partial RoPE(16/64), LN Scale, EMA(0.997). Tied embeddings. Muon optimizer. Multi-expert gate head (Linear 512→7). ~5400 steps in 581s (19s pre-fill + 562s training).

## Credits

- **PR #779 deanbrr** - BackoffNgramMixer, entropy-adaptive alpha, drift-free TTT, base architecture
- **PR #700 RoyiRa** - Base architecture, TTT framework, stride=64 eval
- **PR #606 gowtham0992** - int5 + Soft-Round QAT model
- **PR #727 Asukabot0** - Multi-order backoff concept, entropy-adaptive alpha formula
- **PR #461 Christopher-Lee-McClendon** - TTT recipe foundations
- **PR #518 sofiabod** - LeakyReLU(0.5)^2, cosine TTT scheduling