diff --git a/experiment_log.md b/experiment_log.md new file mode 100644 index 0000000000..ee464945bc --- /dev/null +++ b/experiment_log.md @@ -0,0 +1,429 @@ +# Parameter Golf — Full Experiment Log + +**Pod:** RTX 4000 Ada ($0.20/hr) on RunPod +**Baseline:** Cosine LR, 240 steps (600s wallclock cap), 9L/512D model = **1.6117 BPB** +**Competition SOTA:** ~1.1147 BPB (8xH100) +**Best Result:** 1.5207 BPB (H1: PureDecoder+GQA2+Untied+MatLR0.11+Parallel+SiLU²) = **-0.091 vs baseline** +**Artifact Size:** ~14MB int8+zlib (2MB headroom under 16MB budget) +**Total Experiments:** 131 + +--- + +## Phase 1: Novel Regularization Techniques (PR #1380) +*Ran on cosine LR baseline, 240 steps* + +| # | Experiment | Description | BPB | Delta | Status | +|---|-----------|-------------|-----|-------|--------| +| 1 | Z-loss 1e-4 | PaLM-style log(Z)² regularizer on logits | 1.6282 | +0.017 | Worse | +| 2 | Logit Penalty 1e-5 | L2 penalty on logit magnitudes | 1.6117 | 0.000 | Neutral | +| 3 | Token Dropout 5% | Randomly drop 5% of tokens from loss | 1.6145 | +0.003 | Worse | +| 4 | Embed Mixup 0.1 | Interpolate embedding vectors with random pairs | 1.6157 | +0.004 | Worse | +| 5 | Cosine 2-cycle | Two cosine decay cycles instead of one | 1.7383 | +0.127 | Much worse | +| 6 | Combo (Z+Logit+Drop) | Stack of Z-loss + Logit Penalty + Token Drop | 1.6171 | +0.005 | Worse | + +**Conclusion:** All novel regularization techniques hurt or were neutral. Standard cross-entropy is already well-tuned. 
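A minimal sketch of the Z-loss term from experiment 1, written in plain Python for a single row of logits. It is illustrative only and not the code used for these runs: the helper names `log_z`, `cross_entropy`, and `loss_with_z` are hypothetical, and the 1e-4 coefficient is the value from the table above.

```python
import math


def log_z(logits):
    """Numerically stable log-sum-exp, i.e. log of the softmax normalizer Z."""
    m = max(logits)
    return m + math.log(sum(math.exp(v - m) for v in logits))


def cross_entropy(logits, target):
    """Standard cross-entropy for one position: log(Z) - logit[target]."""
    return log_z(logits) - logits[target]


def loss_with_z(logits, target, coeff=1e-4):
    """Cross-entropy plus the auxiliary penalty coeff * log(Z)^2."""
    return cross_entropy(logits, target) + coeff * log_z(logits) ** 2
```

The penalty pulls log(Z) toward zero, which constrains the logit scale; with a softcap already applied to the logits, that may be one reason the extra term did not help here.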
+ +--- + +## Phase 2: Nature-Inspired v1 +*Ran on cosine LR baseline, 240 steps* + +| # | Experiment | Description | BPB | Delta | Status | +|---|-----------|-------------|-----|-------|--------| +| N1 | Controlled Burn | Prune smallest 5% of weights every 50 steps | 1.6109 | -0.001 | Neutral | +| N2 | **Tidal LR** | **Golden ratio warmup (38.2% warmup, 61.8% cosine decay)** | **1.5906** | **-0.021** | **Winner** | +| N3 | Head Diversity | Quorum sensing — penalize cosine similarity between attention head projections | 1.6100 | -0.002 | Slight help | +| N4 | Weight Perturbation | Add random noise to weights every 100 steps (viral mutation) | 1.6100 | -0.002 | Slight help | +| N5 | Golden Ratio min_lr | Set min LR fraction to 0.382 (golden ratio) | 4.1069 | — | BROKEN | +| N6 | Metamorphosis | Decay dropout from 10% to 0% over first half | 4.1069 | — | BROKEN | + +**Conclusion:** Tidal LR is a clear winner. Extended warmup (38.2% of training) before cosine decay significantly helps. N5/N6 broken due to indentation bug from patching. 
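The Tidal LR split can be sketched as a learning-rate multiplier. Assumptions, not confirmed by this log: warmup is linear and the decay ends at zero. The log only states the 38.2% / 61.8% split, and `tidal_lr_scale` is a hypothetical name.

```python
import math

GOLDEN_WARMUP_FRAC = 0.382  # warmup share reported in the log; the rest is cosine decay


def tidal_lr_scale(step, total_steps, warmup_frac=GOLDEN_WARMUP_FRAC):
    """Return a multiplier in [0, 1] to apply to the base learning rate."""
    progress = min(max(step / total_steps, 0.0), 1.0)
    if progress < warmup_frac:
        # Assumed linear ramp from 0 up to the peak LR.
        return progress / warmup_frac
    # Cosine decay over the remaining 61.8% of training.
    decay_progress = (progress - warmup_frac) / (1.0 - warmup_frac)
    return 0.5 * (1.0 + math.cos(math.pi * decay_progress))
```

With 240 steps, the peak would land around step 92 under these assumptions.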
+ +--- + +## Phase 3: Unconventional Architecture (ALL BROKEN) +*torch.compile incompatible — all crashed or produced no training* + +| # | Experiment | Description | BPB | Delta | Status | +|---|-----------|-------------|-----|-------|--------| +| U1 | Layer Dropout | Randomly skip layers with 20% probability | — | — | CRASH: torch.compile | +| U2 | Weight Sharing | Share weights between layers in groups of 3 | — | — | CRASH: AttributeError | +| U3 | Progressive Growing | Start with 3 layers, add one every 60 steps | 4.1069 | — | BROKEN: no training | +| U4 | Attention Recycling | Run attention twice per layer | — | — | CRASH: torch.compile | +| U5 | Mirror Training | Reverse 50% of input sequences | — | — | Not reached | +| U6 | Head Pruning | Start 16 heads, prune to 8 at step 100 | — | — | Not reached | +| U7 | 12L Asymmetric | 12 layers, 448D, 1 encoder layer | — | — | Not reached | +| U8 | Wide MQA | 640D, 7 layers, 1 KV head | — | — | Not reached | +| U9 | Softcap+RoPE combo | Softcap 15, RoPE base 1000 | — | — | Not reached | +| U10 | Aggressive LR | embed=0.8, matrix=0.06, muon=0.98 | — | — | Not reached | + +**Conclusion:** Dynamic model structure changes are incompatible with `torch.compile(fullgraph=True)`. Complete waste of compute. Lesson: only modify training dynamics, not model architecture. 
+ +--- + +## Phase 4: Nature-Inspired LR Schedules (Wave 1 — NF batch) +*All ran independently, 240 steps* + +| # | Experiment | Description | BPB | Delta | Status | +|---|-----------|-------------|-----|-------|--------| +| NF1 | **Breathing LR** | **4-7-8 pattern: 21% warmup, 37% steady, 42% decay** | **1.5961** | **-0.016** | **Good** | +| NF2 | Whale Dive LR | 3 dive cycles with sharp recoveries | 1.6487 | +0.037 | Worse | +| NF3 | Circadian LR | Day/night sine oscillation overlaid on cosine | 1.6141 | +0.002 | Neutral | +| NF4 | Cosmological Cooling | lr ~ 1/sqrt(1 + alpha*step), Big Bang cooling law | 1.6129 | +0.001 | Neutral | +| NF5 | Synaptic Scaling | Normalize weight norms back to initial values each step | 1.7239 | +0.112 | Very bad | +| NF6 | Mutation Decay | Exponentially decaying random weight noise | 1.6132 | +0.002 | Neutral | + +**Conclusion:** Breathing LR confirms the insight: extended high-LR phase helps. Whale Dive's cycling hurts. Synaptic scaling catastrophically bad. + +--- + +## Phase 5: Tidal LR Combinations (NF batch continued) +*Tidal LR + various techniques, 240 steps* + +| # | Experiment | Description | BPB | Delta | Status | +|---|-----------|-------------|-----|-------|--------| +| NF7 | Tidal + Controlled Burn | Tidal LR + prune 5% every 50 steps | 1.5915 | -0.020 | ~Same as Tidal | +| NF8 | **Tidal + Head Diversity** | **Tidal + quorum sensing head diversity loss (1e-4)** | **1.5867** | **-0.025** | **Better than Tidal** | +| NF9 | Tidal + Mutation Decay | Tidal + exponentially decaying noise (0.001) | 1.5902 | -0.022 | ~Same | +| NF10 | Tidal + Weight Perturb | Tidal + periodic weight perturbation (1e-4) | 1.5901 | -0.022 | ~Same | + +**Conclusion:** Head Diversity stacks well with Tidal (+0.004 on top). Other tricks are neutral on top of Tidal. 
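One way to read the head diversity loss, sketched in plain Python. The exact form used in these runs is not shown in this log; this sketch assumes the penalty is the mean squared pairwise cosine similarity between flattened per-head projection weights, scaled by the 1e-4 coefficient. `cosine_similarity` and `head_diversity_penalty` are hypothetical names.

```python
import math
from itertools import combinations


def cosine_similarity(u, v):
    """Cosine similarity between two equal-length weight vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v + 1e-12)


def head_diversity_penalty(head_weights, coeff=1e-4):
    """head_weights: one flattened projection weight vector per attention head."""
    pairs = list(combinations(head_weights, 2))
    if not pairs:
        return 0.0
    mean_sq_sim = sum(cosine_similarity(u, v) ** 2 for u, v in pairs) / len(pairs)
    return coeff * mean_sq_sim
```

Orthogonal heads incur no penalty and identical heads incur the full `coeff`, so the term only nudges heads apart when they start to duplicate each other.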
+ +--- + +## Phase 6: Architecture Sweeps with Tidal (NF batch continued) +*Tidal LR + architecture parameter changes, 240 steps* + +| # | Experiment | Description | BPB | Delta | Status | +|---|-----------|-------------|-----|-------|--------| +| NF11 | Deeper 12L/448D | 12 layers, 448 dim (fewer steps: 197) | 1.6620 | +0.050 | Worse | +| NF12 | Wider 640D/7L/MQA | 640 dim, 7 layers, 1 KV head (234 steps) | 1.6013 | -0.010 | Slightly better | +| NF13 | **Aggressive LR** | **embed_lr=0.8, matrix_lr=0.06** | **1.5877** | **-0.024** | **Good** | +| NF14 | **Softcap 20 + RoPE 5000** | **Logit softcap 30→20, RoPE base 10000→5000** | **1.5765** | **-0.035** | **Breakthrough** | + +**Conclusion:** Softcap 20 + RoPE 5000 is the biggest single finding. Deeper/wider models hurt because they get fewer steps in the 600s wallclock cap. + +--- + +## Phase 7: 3000-Step Verification (INVALID) +*Wallclock cap still limits to 240 steps regardless of ITERATIONS setting* + +| # | Experiment | Description | BPB | Delta | Status | +|---|-----------|-------------|-----|-------|--------| +| NF15 | Tidal @ 3000 steps | Tidal LR with ITERATIONS=3000 | 1.6504 | — | INVALID: only 240 steps ran, LR at 8% progress | +| NF16 | Cosine @ 3000 steps | Cosine LR with ITERATIONS=3000 | 1.6271 | — | INVALID: same issue | + +**Conclusion:** Setting ITERATIONS=3000 with 600s wallclock makes the LR schedule think it's at 8% progress. Tidal (38% warmup) hasn't peaked yet. Results meaningless. 
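The failure mode in Phase 7 comes from measuring schedule progress in steps only. A hypothetical alternative, not what the script in this PR does, is to take the larger of step progress and wallclock progress, so the schedule still completes when the time cap ends the run early. `schedule_progress` and its parameters are illustrative names.

```python
def schedule_progress(step, iterations, elapsed_s=None, max_wallclock_s=None):
    """Fraction of the LR schedule completed, in [0, 1]."""
    step_progress = step / iterations
    if elapsed_s is None or not max_wallclock_s:
        # No wallclock cap: fall back to plain step-based progress.
        return min(step_progress, 1.0)
    time_progress = elapsed_s / max_wallclock_s
    # Whichever budget runs out first drives the schedule.
    return min(max(step_progress, time_progress), 1.0)
```

With step-only progress, step 240 of 3000 reports 0.08, matching the 8% figure above; with the wallclock term, the same step at 600s reports 1.0.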
+ +--- + +## Phase 8: Stacking All Winners (W0 batch) +*Combining best techniques, 240 steps* + +| # | Experiment | Description | BPB | Delta | Status | +|---|-----------|-------------|-----|-------|--------| +| W0a | Tidal+SC20+RoPE5k+HeadDiv | Stack top 3 findings | 1.5811 | -0.031 | Good | +| W0b | Tidal+SC20+RoPE5k+AggrLR | + embed=0.8, matrix=0.06 | 1.5774 | -0.034 | Good | +| W0c | Tidal+SC15+RoPE5k | Tighter softcap (15) | 1.5793 | -0.032 | Slightly worse than SC20 | +| W0d | Tidal+SC20+RoPE3k | Tighter RoPE (3000) | 1.5845 | -0.028 | Worse than RoPE5k | +| W0e | **EVERYTHING** | **Tidal+SC20+RoPE5k+HeadDiv+AggrLR** | **1.5744** | **-0.037** | **Best result** | + +**Conclusion:** Stacking all winners gives incremental improvement. Sweet spots confirmed: Softcap=20 (not 15), RoPE=5000 (not 3000). AggressiveLR and HeadDiv each add ~0.002. + +--- + +## Phase 9: Novel Nature Gradient Tricks (W1-W10 batch) +*All catastrophically bad* + +| # | Experiment | Description | BPB | Delta | Status | +|---|-----------|-------------|-----|-------|--------| +| W1 | Canopy Light | Scale gradients by layer depth (top=1.0x, bottom=0.3x) | 1.8587 | +0.247 | Terrible | +| W2 | Predator-Prey LR | Lotka-Volterra: LR & weight decay oscillate in opposition | 1.7981 | +0.186 | Terrible | +| W3 | Punctuated Equilibrium | Low LR + gaussian bursts at 25/50/75% progress | 1.8581 | +0.246 | Terrible | +| W4 | Mycorrhizal Gradient | Blend gradients between layers separated by 3 | 1.6223 | +0.011 | Worse | +| W5 | Thermal Vent | Random 2x-3x gradient boost to 2 random layers per step | 1.6472 | +0.036 | Worse | +| W6 | Canopy + Mycorrhizal | Rainforest combo | 1.8667 | +0.255 | Terrible | +| W7 | Tidal + Canopy | Tidal LR + depth gradient scaling | 1.5954 | -0.016 | Worse than Tidal alone | +| W8 | Tidal + Mycorrhizal | Tidal + gradient sharing | 1.6009 | -0.011 | Worse than Tidal alone | +| W9 | Tidal + Thermal Vent | Tidal + random layer boosting | 1.6124 | +0.001 | Neutral | +| W10 | 
Tidal + Predator-Prey WD | Tidal + oscillating weight decay | 1.6371 | +0.025 | Worse | + +**Conclusion:** ALL gradient-level nature tricks are harmful. Don't mess with gradient flow — the optimizer (Muon) is already well-tuned. Novel LR schedules work; novel gradient modifications don't. + +--- + +## Phase 10: Hyperparameter Grid (Wave 4 — partially run, killed for Wave 5) +*Only E13 and E1 completed before the wave was killed in favor of architecture experiments* + +| # | Experiment | Description | BPB | Delta | Status | +|---|-----------|-------------|-----|-------|--------| +| E13 | Softcap 18 + RoPE 5k | Slightly tighter softcap | 1.5770 | -0.035 | ~Same as SC20 | +| E1 | Label Smoothing 0.1 | On full best config | 1.6338 | +0.022 | Worse (+0.059 vs W0e best) | +| E2-E12, E14-E25 | Various | Killed in favor of Wave 5 | — | — | Not run | + +**Conclusion:** Label smoothing hurts on best config. Remaining experiments skipped — architecture experiments (Wave 5) were higher priority. + +--- + +## Phase 11: Architecture Experiments (Wave 5) +*Static arch changes: activation swaps, parallel blocks, sandwich norm, combos, scale* + +### Activation Functions (on Tidal+SC20+RoPE5k, no HeadDiv/AggrLR) +| # | Experiment | Description | BPB | Steps | Delta | Status | +|---|-----------|-------------|-----|-------|-------|--------| +| A1 | LeakyReLU² | Leaky negative slope 0.01, then square | 1.5820 | 240 | -0.030 | Worse than relu² | +| A2 | LeakyReLU² (cosine) | On cosine baseline | 1.6132 | 239 | +0.002 | Neutral | +| A3 | SwiGLU | Llama/Mistral standard activation | 1.5798 | 226 | -0.032 | Slower (fewer steps) | +| A4 | SwiGLU (cosine) | On cosine baseline | 1.6053 | 226 | -0.006 | Slight help but slow | +| A5 | GELU² | GELU then square | 1.5779 | 240 | -0.034 | Competitive | +| A6 | **SiLU²** | **SiLU then square** | **1.5743** | **240** | **-0.037** | **Ties prev best** | + +### Block Structure +| # | Experiment | Description | BPB | Steps | Delta | Status | 
+|---|-----------|-------------|-----|-------|-------|--------| +| A7 | **Parallel blocks** | **PaLM-style: attn+MLP in parallel** | **1.5600** | **260** | **-0.052** | **Big win + faster** | +| A8 | Parallel (cosine) | On cosine baseline | 1.5868 | 259 | -0.025 | Confirms parallel helps | +| A9 | Sandwich Norm | Extra RMSNorm after attention | 1.5983 | 239 | -0.013 | Minor help | + +### Combos +| # | Experiment | Description | BPB | Steps | Delta | Status | +|---|-----------|-------------|-----|-------|-------|--------| +| A10 | LeakyReLU²+Parallel | Combine activation + structure | 1.5661 | 259 | -0.046 | Worse than plain parallel | +| A11 | SwiGLU+Parallel | — | 1.5648 | 245 | -0.047 | Worse than plain parallel | +| A12 | LeakyReLU²+Best | LeakyReLU² + HeadDiv + AggrLR | 1.5802 | 240 | -0.031 | Activation hurts | +| A13 | SwiGLU+Best | SwiGLU + HeadDiv + AggrLR | 1.5848 | 226 | -0.027 | Activation + slow | +| A14 | **EVERYTHING** | **Parallel+LeakyReLU²+HeadDiv+AggrLR** | **1.5586** | **259** | **-0.053** | **New best** | + +### Scale with Architecture +| # | Experiment | Description | BPB | Steps | Delta | Status | +|---|-----------|-------------|-----|-------|-------|--------| +| A15 | LeakyReLU²+MLP3x | Wider MLP hidden layer | 1.6098 | 216 | -0.002 | Too slow | +| A16 | SwiGLU wide 576D/8L | Wider model dim | 1.6090 | 206 | -0.003 | Too slow | +| A17 | LeakyReLU²+GQA(2) | 2 KV heads | 1.5761 | 248 | -0.036 | Decent, faster | +| A18 | LeakyReLU²+MQA(1) | 1 KV head | 1.5740 | 253 | -0.038 | Good, fastest | + +**Conclusion:** Parallel blocks are the biggest single finding (-0.052). They're both faster (2315ms/step → 260 steps vs 240) AND better quality. Activation swaps are noise (±0.003). MQA is surprisingly competitive. Wider/deeper models lose to step count. 
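The difference between the sequential and PaLM-style parallel block can be sketched with plain callables. This is a structural illustration only: `attn`, `mlp`, and the norm functions are placeholders passed in by the caller, and the function names are hypothetical rather than taken from the training script.

```python
def sequential_block(x, attn, mlp, norm1, norm2):
    """Standard pre-norm block: the MLP sees the output of the attention residual."""
    x = x + attn(norm1(x))
    return x + mlp(norm2(x))


def parallel_block(x, attn, mlp, norm):
    """PaLM-style block: attention and MLP both read the same normalized input."""
    h = norm(x)
    return x + attn(h) + mlp(h)
```

In the parallel form the two branches share one normalization and do not depend on each other, which is consistent with the lower per-step time reported above.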
+ +--- + +## Phase 12: Focused Parallel Optimization (Wave 6) +*Building on parallel blocks win, 5 experiments* + +| # | Experiment | Description | BPB | Steps | Delta | Status | +|---|-----------|-------------|-----|-------|-------|--------| +| B4 | **Parallel+SiLU²+HD+AggrLR** | **Best activation + parallel** | **1.5527** | **259** | **-0.059** | **New best** | +| B5 | Parallel+QK2.0+HD+AggrLR | Higher QK gain | 1.5535 | 260 | -0.058 | Close second | +| B1 | Parallel+relu²+HD+AggrLR | Clean combo (no LeakyReLU²) | 1.5561 | 260 | -0.056 | Good | +| B3 | Parallel+10L+HD+AggrLR | Stack extras on 10L | 1.5885 | 234 | -0.023 | 10L hurts | +| B2 | Parallel+10L | Use speed budget for extra layer | 1.5979 | 233 | -0.014 | 10L hurts | + +**Conclusion:** SiLU² beats relu² with parallel (+0.003). 10 layers lose too many steps. QK gain 2.0 nearly ties best — attention strength matters slightly with parallel. + +--- + +## Phase 13: Hyperparameter Re-tune for Parallel (Wave 7) +*Re-tuning around Parallel+SiLU² architecture, 14 experiments* + +### MQA + Parallel +| # | Experiment | Description | BPB | Steps | Delta vs B4 | Status | +|---|-----------|-------------|-----|-------|-------------|--------| +| C1 | MQA+Parallel | 1 KV head + parallel | 1.5588 | 275 | +0.006 | Faster but worse | +| C2 | MQA+Par+SiLU²+Best | Full stack with MQA | 1.5533 | 274 | +0.001 | Close but no win | +| C3 | MQA+16 query heads | More query capacity | 1.6055 | 266 | +0.053 | Broken/terrible | + +### Softcap Re-tune +| # | Experiment | Description | BPB | Steps | Delta vs B4 | Status | +|---|-----------|-------------|-----|-------|-------------|--------| +| C4 | SC15 | Tighter softcap | 1.5519 | 259 | -0.001 | Slight help | +| C5 | SC25 | Looser softcap | 1.5602 | 259 | +0.008 | Worse | +| C6 | SC18 | Mid softcap | 1.5519 | 259 | -0.001 | Tied with SC15 | + +### RoPE Re-tune +| # | Experiment | Description | BPB | Steps | Delta vs B4 | Status | 
+|---|-----------|-------------|-----|-------|-------------|--------| +| C7 | RoPE 3000 | Tighter positions | 1.5553 | 259 | +0.003 | Worse | +| C8 | RoPE 7500 | — | 1.5544 | 259 | +0.002 | Worse | + +### LR Re-tune +| # | Experiment | Description | BPB | Steps | Delta vs B4 | Status | +|---|-----------|-------------|-----|-------|-------------|--------| +| C9 | Embed LR 1.0 | More aggressive embed | 1.5552 | 259 | +0.003 | Worse | +| C10 | **Matrix LR 0.08** | **More aggressive matrix** | **1.5501** | **259** | **-0.003** | **New best** | +| C11 | Both LR aggressive | Embed 1.0 + Matrix 0.08 | 1.5541 | 259 | +0.001 | Embed LR hurts | + +### Schedule & HeadDiv +| # | Experiment | Description | BPB | Steps | Delta vs B4 | Status | +|---|-----------|-------------|-----|-------|-------------|--------| +| C12 | Breathing LR | Alt schedule | 1.5686 | 259 | +0.016 | Tidal still better | +| C13 | HeadDiv 1e-3 | Stronger diversity | 1.5551 | 259 | +0.002 | Too strong | +| C14 | No HeadDiv | Ablation | 1.5542 | 259 | +0.002 | HD barely matters | + +**Conclusion:** Matrix LR 0.08 is a real win. SC15/18 might help but within noise. RoPE 5000 still optimal. HeadDiv barely matters with parallel. MQA branch dead. Breathing worse than Tidal. 
+ +--- + +## Phase 14: Asymmetric Splits + Stacking (Wave 8 — in progress) +*Asymmetric encoder/decoder ratios + stacking Wave 7 wins, 10 experiments* + +| # | Experiment | Description | BPB | Delta | Status | +|---|-----------|-------------|-----|-------|--------| +### Stacking +| # | Experiment | Description | BPB | Steps | Delta vs C10 | Status | +|---|-----------|-------------|-----|-------|-------------|--------| +| D1 | MatLR0.08+SC18 | Stack two wins | 1.5501 | 259 | 0.000 | Tied | +| D2 | MatLR0.08+SC15 | Stack two wins | 1.5493 | 259 | -0.001 | Slight help | +| D3 | MatLR0.08+NoHD | Simplify config | 1.5541 | 259 | +0.004 | HD still helps | + +### Asymmetric Encoder/Decoder Splits +| # | Experiment | Description | BPB | Steps | Delta vs baseline | Status | +|---|-----------|-------------|-----|-------|-------------------|--------| +| D6 | **Asym 1/8** | **1 encoder, 8 decoder** | **1.5377** | **266** | **-0.074** | **New best** | +| D5 | Asym 2/7 | 2 encoder, 7 decoder | 1.5412 | 263 | -0.071 | Great | +| D4 | Asym 3/6 | 3 encoder, 6 decoder | 1.5439 | 262 | -0.068 | Good | +| D7 | Asym 5/4 | Control: more encoder | 1.5568 | 257 | -0.055 | Worse | + +### Stacking Asymmetric + Softcap +| # | Experiment | Description | BPB | Steps | Delta vs baseline | Status | +|---|-----------|-------------|-----|-------|-------------------|--------| +| D9 | Asym 2/7+SC18 | Stack best | 1.5440 | 263 | -0.068 | SC18 didn't help | +| D8 | Asym 3/6+SC18 | Stack best | 1.5461 | 261 | -0.066 | SC18 didn't help | +| D10 | C10 Rerun | Confirm 1.5501 | 1.5515 | 259 | -0.060 | ~0.001 noise | + +**Conclusion:** Asymmetric splits are the biggest Wave 8 finding. Monotonic: fewer encoder layers = better BPB AND faster. 1/8 split gives -0.074 total improvement. SC18 doesn't compound with asymmetric. Measurement noise is ~0.001-0.002. 
+ +--- + +## Phase 15: Optimize 1/8 Split (Wave 9 — COMPLETE) +*Softcap, LR, activation, QK tuning on 1/8 split, 10 experiments* + +| # | Experiment | Description | BPB | Steps | Delta vs baseline | Status | +|---|-----------|-------------|-----|-------|-------------------|--------| +| E1 | SC15 on 1/8 | Softcap 15 (was 20) | 1.5387 | 265 | -0.073 | Good | +| E2 | SC18 on 1/8 | Softcap 18 | 1.5398 | 265 | -0.072 | Slightly worse | +| E3 | SC12 on 1/8 | Softcap 12 | 1.5378 | 265 | -0.074 | ~Same as D6 | +| E4 | MatLR0.10 on 1/8 | Matrix LR 0.10 (was 0.08) | 1.5392 | 265 | -0.073 | Good alone | +| E5 | MatLR0.12 on 1/8 | Matrix LR 0.12 | 1.5391 | 265 | -0.073 | Good alone | +| E6 | ReLU² on 1/8 | Swap SiLU² → ReLU² | 1.5413 | 267 | -0.070 | SiLU² better | +| E7 | QK Gain 2.0 on 1/8 | Larger QK init | 1.5400 | 265 | -0.072 | No help | +| **E8** | **SC15+MatLR0.10** | **Stack: SC15 + MatLR0.10** | **1.5354** | **265** | **-0.076** | **👑 BEST** | +| E9 | SC15+QK2 | Stack: SC15 + QK Gain 2.0 | 1.5374 | 265 | -0.074 | Worse than E8 | +| E10 | D5 Rerun | Confidence rerun (D5 base) | 1.5395 | 265 | -0.072 | Confirms noise | + +**Conclusion:** SC15+MatLR0.10 stacking gives new best: 1.5354 (-0.076). Individual effects are small (~0.001) but compound well. QK gain doesn't help. ReLU² confirmed worse than SiLU² on asymmetric. Measurement noise ~0.002. Artifact = 12.9MB int8+zlib (3.1MB headroom). 
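For reference, a logit softcap is commonly implemented as a scaled tanh. The sketch below assumes that form, `cap * tanh(logit / cap)`; the implementation behind `LOGIT_SOFTCAP` is not shown in this log, and `softcap` / `softcap_logits` are hypothetical names.

```python
import math


def softcap(logit, cap=15.0):
    """Squash a logit smoothly into (-cap, cap); near zero it is almost the identity."""
    return cap * math.tanh(logit / cap)


def softcap_logits(logits, cap=15.0):
    """Apply the softcap elementwise to one row of logits."""
    return [softcap(v, cap) for v in logits]
```

Lowering the cap from 30 to 15 bounds the largest possible logit gap at 30 instead of 60, which limits how confident the softmax can become.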
+ +--- + +## Phase 16: Fine-grained Tuning (Wave 10 — COMPLETE) +*Softcap step-of-1 sweep, LR fine-tune, GQA, Tidal warmup on 1/8, base: SC15+MatLR0.10* + +| # | Experiment | Description | BPB | Steps | Delta vs baseline | Status | +|---|-----------|-------------|-----|-------|-------------------|--------| +| F1 | SC13+MatLR0.10 | Softcap 13 | 1.5376 | 265 | -0.074 | Worse than E8 | +| F2 | SC14+MatLR0.10 | Softcap 14 | 1.5363 | 265 | -0.075 | Neutral | +| F3 | SC16+MatLR0.10 | Softcap 16 | 1.5359 | 266 | -0.076 | Neutral | +| F4 | SC17+MatLR0.10 | Softcap 17 | 1.5356 | 266 | -0.076 | Neutral | +| F5 | SC15+MatLR0.09 | Matrix LR 0.09 | 1.5377 | 266 | -0.074 | Worse | +| F6 | **SC15+MatLR0.11** | **Matrix LR 0.11** | **1.5341** | **266** | **-0.078** | **New best** | +| F7 | **GQA 2KV** | **2 KV heads (grouped query attention)** | **1.5329** | **276** | **-0.079** | **New best + faster** | +| F8 | Tidal 30% warmup | Shorter warmup (30% vs 38.2%) | 1.5355 | 265 | -0.076 | Neutral | +| F9 | QK Gain 2.0 stack | QK init 2.0 on E8 | 1.5401 | 265 | -0.072 | Worse | +| F10 | E8 Rerun | Confidence check | 1.5365 | 265 | -0.075 | Variance ~0.001 | + +**Conclusion:** Two wins — MatLR=0.11 and GQA with 2 KV heads. GQA is a double win: faster (2181ms/step → 276 steps vs 265) AND better quality. Softcap sweep confirms SC15 optimal but SC13-17 all within noise. Tidal 30% warmup neutral. QK gain confirmed dead. 
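The speedup from fewer KV heads follows from the smaller K and V projections. A rough parameter count per attention layer, assuming bias-free Q, K, V, and output projections (an assumption; `attention_proj_params` is a hypothetical helper), using the 512-dim, 8-head shape from the baseline:

```python
def attention_proj_params(model_dim, num_heads, num_kv_heads):
    """Weight count for Q, K, V, and output projections in one attention layer."""
    assert model_dim % num_heads == 0, "model_dim must divide evenly into heads"
    assert num_heads % num_kv_heads == 0, "query heads must group evenly over KV heads"
    head_dim = model_dim // num_heads
    kv_dim = num_kv_heads * head_dim
    q_and_out = 2 * model_dim * model_dim  # Q projection + output projection
    k_and_v = 2 * model_dim * kv_dim       # K and V shrink with fewer KV heads
    return q_and_out + k_and_v


for kv_heads in (4, 2, 1):
    print(kv_heads, attention_proj_params(512, 8, kv_heads))
```

Under these assumptions, going from 4 to 2 KV heads removes 131,072 weights per layer, which also leaves more of the artifact budget for other components.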
+ +--- + +## Phase 17: Novel Ideas (Wave 11 — COMPLETE) +*Trimmed to 2 key experiments: untied embeddings, WD schedule* + +| # | Experiment | Description | BPB | Steps | Delta vs baseline | Status | +|---|-----------|-------------|-----|-------|-------------------|--------| +| G1 | **Untied Embeddings** | **Separate input/output embeddings (TIE_EMBEDDINGS=0)** | **1.5211** | **266** | **-0.091** | **Huge win!** | +| G3 | WD Schedule 0.01 | Ramp weight decay 0→0.01 | 1.5603 | 265 | -0.051 | Much worse | + +**Conclusion:** Untied embeddings is the biggest single improvement since parallel blocks (-0.012 on top of F7). WD schedule is catastrophically harmful — do not use. Artifact size with untied embeddings: ~14MB (still under 16MB budget). + +--- + +## Phase 18: Aggressive Experiments (Wave 12 — COMPLETE) +*Pure decoder, wider MLP, WD schedule, stacking — all with GQA2+Untied+MatLR0.11* + +| # | Experiment | Description | BPB | Steps | Delta vs baseline | Status | +|---|-----------|-------------|-----|-------|-------------------|--------| +| **H1** | **Pure Decoder** | **ENCODER_LAYERS=0, all 9 layers as decoder** | **1.5207** | **276** | **-0.091** | **Best overall** | +| H2 | MLP 3x width | MLP hidden 1024 (vs 682) | 1.5481 | 244 | -0.064 | Too slow | +| H3 | MLP 3x + Pure Decoder | 3x MLP + ENCODER_LAYERS=0 | 1.5474 | 244 | -0.064 | Too slow | +| H4 | WD Schedule 0.04 | Ramp weight decay 0→0.04 | 1.6569 | 276 | +0.045 | Catastrophic | +| H5 | All Stacked | Pure decoder + 3x MLP + WD 0.04 | 1.6974 | 244 | +0.086 | Terrible | +| H6 | Best Rerun | H1 config confidence run | 1.5214 | 276 | -0.090 | Confirms H1 | + +**Conclusion:** Pure decoder (0 encoder layers) gives a slight edge over 1/8 split (1.5207 vs 1.5211, which is within the ~0.001 run-to-run noise). 3x MLP is too slow on RTX 4000 Ada (2462ms/step → only 244 steps, losing quality). WD schedule catastrophically bad in all forms. H6 confirms H1 is reproducible (1.5207 vs 1.5214, noise ~0.0007). 
+ +--- + +## Summary of Key Findings (131 experiments) + +### What works (ranked by impact): +1. **Parallel blocks (PaLM-style)** — -0.052 alone, faster + better quality +2. **Untied embeddings** — -0.012 on top of best config, biggest late-stage win +3. **Pure decoder (ENCODER_LAYERS=0)** — monotonically better with fewer encoder layers +4. **GQA with 2 KV heads** — faster (2181ms vs 2264ms/step) AND better quality +5. **SiLU² activation** — best activation with parallel blocks +6. **Logit Softcap 15** (default 30) — tighter logit distribution +7. **Matrix LR 0.11** (script default 0.04; 0.06 in the earlier AggrLR runs) — more aggressive matrix learning rate +8. **Tidal LR (golden ratio warmup)** — -0.021, 38.2% warmup before cosine decay +9. **RoPE base 5000** (default 10000) — sharper positional attention +10. **Head Diversity loss** (1e-4) — marginal but real + +### What doesn't work: +- All gradient-level tricks (Canopy, Mycorrhizal, Thermal Vent, Predator-Prey) +- All regularization (Z-loss, token dropout, embed mixup, synaptic scaling, label smoothing) +- LR cycling (Whale Dive, Cosine 2-cycle, Breathing with parallel) +- Weight decay schedule (catastrophically bad in all forms) +- 3x MLP width (too slow on single GPU — loses steps) +- Wider/deeper models under the 600s wallclock cap (fewer steps kill the gains) +- Weight perturbation/mutation (neutral at best) +- 10 layers (even with parallel speed savings, too slow) +- MQA with 1 KV head (faster but loses quality vs GQA with 2) +- More encoder layers (5/4 worse than 4/5 worse than 1/8 worse than 0/9) +- SC18 on asymmetric splits (doesn't compound) +- Embed LR 1.0 (too aggressive) +- QK Gain 2.0 (no help on asymmetric/pure decoder) +- ReLU² on asymmetric (SiLU² better) + +### What was broken (wasted compute): +- N5, N6: Indentation bug from patching — train_loss=0.0000 +- U1-U10: torch.compile incompatible architecture changes +- E1-E12 (first run): `$BEST` was not expanded into env vars by bash +- NF15-NF16: 3000-step verification invalid due to 
wallclock cap + +### Best Config (H1): +```bash +TIDAL_LR=1 LOGIT_SOFTCAP=15.0 ROPE_BASE=5000 PARALLEL_BLOCK=1 MLP_ACT=silu2 HEAD_DIVERSITY=1e-4 EMBED_LR=0.8 MATRIX_LR=0.11 ENCODER_LAYERS=0 NUM_KV_HEADS=2 TIE_EMBEDDINGS=0 +``` +**Result: 1.5207 BPB (-0.091 vs 1.6117 baseline)** +**Artifact: ~14MB int8+zlib (2MB headroom under 16MB)** +**Hardware: Single RTX 4000 Ada, 276 steps in 600s wallclock** + +### Progression: +``` +Baseline (cosine): 1.6117 ++ Tidal LR: 1.5906 (-0.021) ++ Head Diversity: 1.5867 (-0.025) ++ SC20 + RoPE 5k: 1.5744 (-0.037) ++ Parallel blocks: 1.5600 (-0.052) ++ SiLU²: 1.5527 (-0.059) ++ Matrix LR 0.08: 1.5501 (-0.062) ++ Asymmetric 1/8: 1.5377 (-0.074) ++ SC15 + MatLR 0.10: 1.5354 (-0.076) ++ MatLR 0.11: 1.5341 (-0.078) ++ GQA 2KV: 1.5329 (-0.079) ++ Untied embeddings: 1.5211 (-0.091) ++ Pure decoder (0 enc): 1.5207 (-0.091) +``` + +### Gap to Competition: +- Our best: 1.5207 (single RTX 4000 Ada, 276 steps) +- Competition baseline: 1.2244 (8xH100, ~3500 steps) +- Competition SOTA: 1.1147 (8xH100, int6 QAT + TTT + XSA + 11 layers) +- Key techniques we lack: int6 QAT, SWA/EMA, sliding window eval, 10-11 layers, BigramHash diff --git a/run_best.sh b/run_best.sh new file mode 100755 index 0000000000..941f0d887d --- /dev/null +++ b/run_best.sh @@ -0,0 +1,25 @@ +#!/bin/bash +# Run the best configuration found from 131 experiments +# Best result: 1.5207 BPB on RTX 4000 Ada (276 steps in 600s) +# Note: ITERATIONS=400 is the RTX 4000 proxy schedule horizon used in the experiments. +# The competition constraint is 10 minutes on 8xH100, not 400 fixed steps. 
+# +# Usage: bash run_best.sh +# Requires: GPU with PyTorch, same environment as train_gpt.py + +cd "$(dirname "$0")" + +env \ + ITERATIONS=400 \ + TIDAL_LR=1 \ + LOGIT_SOFTCAP=15.0 \ + ROPE_BASE=5000 \ + PARALLEL_BLOCK=1 \ + MLP_ACT=silu2 \ + HEAD_DIVERSITY=1e-4 \ + EMBED_LR=0.8 \ + MATRIX_LR=0.11 \ + ENCODER_LAYERS=0 \ + NUM_KV_HEADS=2 \ + TIE_EMBEDDINGS=0 \ + python train_gpt_focal_fixed.py diff --git a/train_gpt_focal_fixed.py b/train_gpt_focal_fixed.py new file mode 100644 index 0000000000..177a9770cc --- /dev/null +++ b/train_gpt_focal_fixed.py @@ -0,0 +1,1523 @@ +""" +The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good launching-off points for new participants, not SOTA configs. We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder. + +Hard stop: To keep readable for newcomers, let's make sure `train_gpt.py` and `train_gpt_mlx.py` never are longer than 1500 lines. +""" + +from __future__ import annotations + +import copy +import glob +import io +import math +import os +import random +import subprocess +import sys +import time +import uuid +import zlib +from pathlib import Path + +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP + +# ----------------------------- +# HYPERPARAMETERS +# ----------------------------- +# Default Simple Baseline run: +# - 9 transformer blocks at width 512 +# - 8 attention heads with 4 KV heads (GQA) and 2x MLP expansion +# - vocab size 1024, sequence length 1024, tied embeddings +# - 524,288 train tokens per step for 20,000 iterations with a ~10 minute cap + +class Hyperparameters: + # Data paths are shard globs produced by the existing preprocessing pipeline. 
+ data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + seed = int(os.environ.get("SEED", 1337)) + + # Validation cadence and batch size. Validation always uses the full fineweb_val split. + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 200)) + + # Training length. + iterations = int(os.environ.get("ITERATIONS", 20000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 1200)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + + # Model shape. + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers = int(os.environ.get("NUM_LAYERS", 9)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = int(os.environ.get("MLP_MULT", 2)) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + + # Optimizer hyperparameters. 
+ embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.05)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.04)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.04)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.0)) + +# ----------------------------- +# MUON OPTIMIZER +# ----------------------------- +# +# As borrowed from modded-nanogpt +# Background on Muon: https://kellerjordan.github.io/posts/muon/ + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + # Orthogonalize a 2D update matrix with a fast Newton-Schulz iteration. + # Muon uses this to normalize matrix-shaped gradients before applying them. 
+ a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X + + +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True): + super().__init__( + params, + dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov), + ) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + + for group in self.param_groups: + params = group["params"] + if not params: + continue + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + + total_params = sum(int(p.numel()) for p in params) + updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) + + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: + g = g.add(buf, alpha=momentum) + g = zeropower_via_newtonschulz5(g, steps=backend_steps) + # Scale correction from Muon reference implementations. 
+ g *= max(1, g.size(0) / g.size(1)) ** 0.5 + updates_flat[curr : curr + p.numel()] = g.reshape(-1) + curr += p.numel() + + if distributed: + dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) + + curr = 0 + for p in params: + g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) + p.add_(g, alpha=-lr) + curr += p.numel() + + return loss + + +# ----------------------------- +# TOKENIZER-AGNOSTIC EVALUATION SETUP +# ----------------------------- +# +# It's common for small models to have a large fraction of their parameters be embeddings, since the 2 * d_model * d_vocab vectors can be gigantic. +# Instead of locking the tokenizer, we let you bring your own and calculate our validation metrics on the average compression of the validation set. +# We calculate BPB (bits-per-byte) instead of validation loss, so we need methods to count the number of bytes per token in the tokenizer. +# Note: Submissions that edit the tokenizer will be examined more carefully, since screwing this up might unjustly improve your score. 
+ +def build_sentencepiece_luts( + sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device +) -> tuple[Tensor, Tensor, Tensor]: + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("▁"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) + + +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*. 
+ tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] + + +def eval_val( + args: Hyperparameters, + model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + grad_accum_steps: int, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, +) -> tuple[float, float]: + # Validation computes two metrics: + # - val_loss: token cross-entropy (natural log) + # - val_bpb: tokenizer-agnostic compression metric used by the challenge + local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) + if local_batch_tokens < args.train_seq_len: + raise ValueError( + "VAL_BATCH_SIZE must provide at least one sequence per rank; " + f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " + f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}" + ) + local_batch_seqs = local_batch_tokens // args.train_seq_len + total_seqs = (val_tokens.numel() - 1) // args.train_seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + + model.eval() + with torch.inference_mode(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * args.train_seq_len + raw_end = batch_seq_end * args.train_seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, args.train_seq_len) + y = local[1:].reshape(-1, args.train_seq_len) + with 
torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + batch_loss = model(x, y).detach() + batch_token_count = float(y.numel()) + val_loss_sum += batch_loss.to(torch.float64) * batch_token_count + val_token_count += batch_token_count + prev_ids = x.reshape(-1) + tgt_ids = y.reshape(-1) + token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + + val_loss = val_loss_sum / val_token_count + bits_per_token = val_loss.item() / math.log(2.0) + tokens_per_byte = val_token_count.item() / val_byte_count.item() + model.train() + return float(val_loss.item()), float(bits_per_token * tokens_per_byte) + +# ----------------------------- +# POST-TRAINING QUANTIZATION +# ----------------------------- +# +# It's silly to export our model, which is trained in bf16 and fp32, at that same precision. +# Instead, we get approximately the same model (with a small hit) by quantizing the model to int8 & zlib compressing. +# We can then decompress the model and run in higher precision for evaluation, after closing in under the size limit. 
+ +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights", + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 + +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) + +def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: + if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): + return t.float().contiguous() + if t.dtype in {torch.float32, torch.bfloat16}: + passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") + return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() + return t + +def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + # Matrices get one scale per row, which usually tracks output-channel + # ranges much better than a single tensor-wide scale. + clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) + q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() + return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + + # Vectors / scalars use a simpler per-tensor scale. 
+ clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() + return q, scale + +def quantize_state_dict_int8(state_dict: dict[str, Tensor]): + # Single supported clean-script export format: + # - per-row int8 for 2D float tensors + # - per-tensor int8 for other float tensors + # - exact passthrough for non-floats + # - passthrough for small float tensors, stored as fp16 to save bytes + quantized: dict[str, Tensor] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + qmeta: dict[str, dict[str, object]] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + + # Small float tensors are cheap enough to keep directly. We still downcast + # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size. 
+ if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + + stats["num_float_tensors"] += 1 + q, s = quantize_float_tensor(t) + if s.ndim > 0: + qmeta[name] = {"scheme": "per_row", "axis": 0} + quantized[name] = q + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) + + obj: dict[str, object] = { + "__quant_format__": "int8_clean_per_row_v1", + "quantized": quantized, + "scales": scales, + "dtypes": dtypes, + "passthrough": passthrough, + } + if qmeta: + obj["qmeta"] = qmeta + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats + +def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + qmeta = obj.get("qmeta", {}) + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, q in obj["quantized"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + s = obj["scales"][name] + if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: + s = s.to(dtype=torch.float32) + # Broadcast the saved row scale back across trailing dimensions. + out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() + else: + scale = float(s.item()) + out[name] = (q.float() * scale).to(dtype=dtype).contiguous() + for name, t in obj["passthrough"].items(): + # Restore small tensors, undoing the temporary fp16 storage cast if needed. 
+ out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out + + +# ----------------------------- +# DATA LOADING +# ----------------------------- + +def load_data_shard(file: Path) -> Tensor: + header_bytes = 256 * np.dtype(" None: + self.file_idx = (self.file_idx + 1) % len(self.files) + self.tokens = load_data_shard(self.files[self.file_idx]) + self.pos = 0 + + def take(self, n: int) -> Tensor: + chunks: list[Tensor] = [] + remaining = n + while remaining > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: + self._advance_file() + continue + k = min(remaining, avail) + chunks.append(self.tokens[self.pos : self.pos + k]) + self.pos += k + remaining -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) + + +class DistributedTokenLoader: + # Each call consumes a contiguous chunk from the shared token stream, then slices out + # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting. 
+ def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank = rank + self.world_size = world_size + self.device = device + self.stream = TokenStream(pattern) + + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + chunk = self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start : start + per_rank_span].to(dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + +# ----------------------------- +# TRANSFORMER MODULES +# ----------------------------- + +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + + +class CastedLinear(nn.Linear): + # Keep weights in fp32 for optimizer/state quality, cast at matmul time for bf16 compute. + def forward(self, x: Tensor) -> Tensor: + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, self.weight.to(x.dtype), bias) + + +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + # Keep small/control parameters in fp32 even when the model body runs in bf16. + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() + + +class Rotary(nn.Module): + # Caches cos/sin tables per sequence length on the current device. 
+ def __init__(self, dim: int, base: float = 10000.0): + super().__init__() + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached != seq_len + or self._cos_cached.device != device + ): + t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) + freqs = torch.outer(t, self.inv_freq.to(device)) + self._cos_cached = freqs.cos()[None, None, :, :] + self._sin_cached = freqs.sin()[None, None, :, :] + self._seq_len_cached = seq_len + return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) + + +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + + +class CausalSelfAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + kv_dim = self.num_kv_heads * self.head_dim + self.c_q = CastedLinear(dim, dim, bias=False) + self.c_k = CastedLinear(dim, kv_dim, bias=False) + self.c_v = CastedLinear(dim, kv_dim, bias=False) + self.proj = CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + self.q_gain = 
nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rotary = Rotary(self.head_dim, base=rope_base) + + def forward(self, x: Tensor) -> Tensor: + bsz, seqlen, dim = x.shape + q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) + k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin) + k = apply_rotary_emb(k, cos, sin) + q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] + y = F.scaled_dot_product_attention( + q, + k, + v, + attn_mask=None, + is_causal=True, + enable_gqa=(self.num_kv_heads != self.num_heads), + ) + y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) + return self.proj(y) + + +class MLP(nn.Module): + # MLP with an activation selected by MLP_ACT; the default is the relu^2 MLP from the original modded-nanogpt setup + def __init__(self, dim: int, mlp_mult: int): + super().__init__() + self.act_type = os.environ.get("MLP_ACT", "relu2") + if self.act_type == "swiglu": + # SwiGLU: gate and up projections, then down + hidden = int(mlp_mult * dim * 2 / 3) # reduce to keep param count similar + self.gate = CastedLinear(dim, hidden, bias=False) + self.fc = CastedLinear(dim, hidden, bias=False) + self.proj = CastedLinear(hidden, dim, bias=False) + else: + hidden = mlp_mult * dim + self.fc = CastedLinear(dim, hidden, bias=False) + self.proj = CastedLinear(hidden, dim, bias=False) + self.proj._zero_init = True + + def forward(self, x: Tensor) -> Tensor: + if self.act_type == "swiglu": + return self.proj(F.silu(self.gate(x)) * self.fc(x)) + elif self.act_type == "leaky_relu2": + x = F.leaky_relu(self.fc(x), negative_slope=0.01) + return self.proj(x.square()) + elif self.act_type == "gelu2": + x = F.gelu(self.fc(x)) + return self.proj(x.square()) + elif self.act_type == 
"silu2": + x = F.silu(self.fc(x)) + return self.proj(x.square()) + else: # relu2 (default) + x = torch.relu(self.fc(x)) + return self.proj(x.square()) + + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + self.parallel_block = int(os.environ.get("PARALLEL_BLOCK", "0")) + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) + self.mlp = MLP(dim, mlp_mult) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + # Sandwich norm: extra norm after attention (if enabled) + self.sandwich = int(os.environ.get("SANDWICH_NORM", "0")) + if self.sandwich: + self.post_attn_norm = RMSNorm() + + def forward(self, x: Tensor, x0: Tensor) -> Tensor: + mix = self.resid_mix.to(dtype=x.dtype) + x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + if self.parallel_block: + # PaLM-style: attn and MLP in parallel + normed = self.attn_norm(x) + attn_out = self.attn(normed) + mlp_out = self.mlp(normed) # share the same normed input + x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out + x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * mlp_out + else: + attn_out = self.attn(self.attn_norm(x)) + if self.sandwich: + attn_out = self.post_attn_norm(attn_out) + x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out + x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) + return x + + +class GPT(nn.Module): + def __init__( + self, + vocab_size: int, + num_layers: int, + model_dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + tie_embeddings: bool, + tied_embed_init_std: float, + logit_softcap: float, + 
rope_base: float, + qk_gain_init: float, + ): + super().__init__() + if logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") + self.tie_embeddings = tie_embeddings + self.tied_embed_init_std = tied_embed_init_std + self.logit_softcap = logit_softcap + self.tok_emb = nn.Embedding(vocab_size, model_dim) + enc_override = int(os.environ.get("ENCODER_LAYERS", "-1")) + self.num_encoder_layers = enc_override if enc_override >= 0 else num_layers // 2 + self.num_decoder_layers = num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + self.blocks = nn.ModuleList( + [ + Block( + model_dim, + num_heads, + num_kv_heads, + mlp_mult, + rope_base, + qk_gain_init, + ) + for i in range(num_layers) + ] + ) + + # Weight sharing: make layers share weights in rings. + # This has to run after self.blocks is built, because it aliases existing blocks. + share_period = int(os.environ.get("WEIGHT_SHARE", "0")) + if share_period > 0: + for i in range(num_layers): + base = i % share_period + if base != i: + self.blocks[i] = self.blocks[base] # alias, shares parameters + self.final_norm = RMSNorm() + self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) + if self.lm_head is not None: + self.lm_head._zero_init = True + self._init_weights() + + def _init_weights(self) -> None: + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + for module in self.modules(): + if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + # Mirror training: reverse some sequences to learn bidirectional patterns + mirror_prob = float(os.environ.get("MIRROR_TRAIN", "0")) + if mirror_prob > 0 and self.training: + batch_size = input_ids.shape[0] + mask = torch.rand(batch_size, 
device=input_ids.device) < mirror_prob + if mask.any(): + input_ids = input_ids.clone() + target_ids = target_ids.clone() + input_ids[mask] = input_ids[mask].flip(dims=[1]) + target_ids[mask] = target_ids[mask].flip(dims=[1]) + x = self.tok_emb(input_ids) + embed_scale = float(os.environ.get("EMBED_SCALE", "0")) + if embed_scale > 0: + x = x * embed_scale + # Embedding mixup: interpolate adjacent token embeddings + embed_mixup = float(os.environ.get("EMBED_MIXUP", "0")) + if embed_mixup > 0 and self.training: + shifted = torch.roll(x, 1, dims=1) + alpha = embed_mixup * torch.rand(x.shape[0], x.shape[1], 1, device=x.device, dtype=x.dtype) + x = (1 - alpha) * x + alpha * shifted + x = F.rms_norm(x, (x.size(-1),)) + # Token dropout: randomly zero out some token representations + token_drop = float(os.environ.get("TOKEN_DROP", "0")) + if token_drop > 0 and self.training: + mask = torch.rand(x.shape[0], x.shape[1], 1, device=x.device, dtype=x.dtype) > token_drop + x = x * mask / (1.0 - token_drop) # scale to preserve magnitude + x0 = x + skips: list[Tensor] = [] + + # First half stores skips; second half reuses them in reverse order. 
+ layer_drop = float(os.environ.get("LAYER_DROP", "0")) + # Progressive growing: gradually activate more layers + prog_grow = int(os.environ.get("PROGRESSIVE_GROW", "0")) + if prog_grow and self.training: + progress = float(os.environ.get("_TRAIN_PROGRESS", "0")) + total_layers = self.num_encoder_layers + self.num_decoder_layers + active = max(3, int(total_layers * (0.3 + 0.7 * progress))) + # Only use first 'active' layers worth of encoder + # This is approximate - we scale encoder proportionally + for i in range(self.num_encoder_layers): + if layer_drop > 0 and self.training and i > 0 and torch.rand(1).item() < layer_drop: + skips.append(x) # still need to push for decoder + continue + x = self.blocks[i](x, x0) + skips.append(x) + for i in range(self.num_decoder_layers): + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + if layer_drop > 0 and self.training and i > 0 and torch.rand(1).item() < layer_drop: + continue + x = self.blocks[self.num_encoder_layers + i](x, x0) + # Attention recycling: run decoder block twice + attn_recycle = int(os.environ.get("ATTN_RECYCLE", "0")) + if attn_recycle and self.training and i == self.num_decoder_layers - 1: + x = self.blocks[self.num_encoder_layers + i](x, x0) # second pass + + x = self.final_norm(x).reshape(-1, x.size(-1)) + targets = target_ids.reshape(-1) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(x) + logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + # Focal loss: down-weight easy tokens, focus on hard ones (training only) + focal_gamma = float(os.environ.get("FOCAL_GAMMA", "0")) + # Gamma annealing: decay gamma from initial value to 0 over training + gamma_anneal = int(os.environ.get("GAMMA_ANNEAL", "0")) + if gamma_anneal > 0: + progress = float(os.environ.get("_TRAIN_PROGRESS", "0")) + 
focal_gamma = focal_gamma * (1.0 - progress) + # Z-loss regularizer (PaLM paper): penalize log(Z) to stabilize logits + z_loss_coeff = float(os.environ.get("Z_LOSS", "0")) + # Logit penalty: L2 penalty on logits to prevent overconfidence + logit_penalty = float(os.environ.get("LOGIT_PENALTY", "0")) + if focal_gamma > 0 and self.training: + ce = F.cross_entropy(logits.float(), targets, reduction="none") + pt = torch.exp(-ce) # probability of correct class + focal_weight = (1 - pt) ** focal_gamma + loss = (focal_weight * ce).mean() + else: + # Temperature scaling for sharper/softer predictions + logit_temp = float(os.environ.get("LOGIT_TEMP", "1.0")) + if logit_temp != 1.0 and self.training: + logits = logits / logit_temp + label_smooth = float(os.environ.get("LABEL_SMOOTH", "0")) + if label_smooth > 0 and self.training: + loss = F.cross_entropy(logits.float(), targets, reduction="mean", label_smoothing=label_smooth) + else: + loss = F.cross_entropy(logits.float(), targets, reduction="mean") + if z_loss_coeff > 0 and self.training: + log_z = torch.logsumexp(logits.float(), dim=-1) + loss = loss + z_loss_coeff * (log_z ** 2).mean() + if logit_penalty > 0 and self.training: + loss = loss + logit_penalty * (logits.float() ** 2).mean() + # Multi-token prediction: auxiliary loss for next-next token + multi_tok = float(os.environ.get("MULTI_TOKEN", "0")) + if multi_tok > 0 and self.training and target_ids.shape[1] > 2: + # Shift targets by 1 more position for next-next-token prediction + shifted_targets = torch.zeros_like(target_ids) + shifted_targets[:, :-1] = target_ids[:, 1:] + shifted_targets[:, -1] = target_ids[:, -1] # pad last + shifted_targets = shifted_targets.reshape(-1) + aux_loss = F.cross_entropy(logits.float(), shifted_targets, reduction="mean") + loss = loss + multi_tok * aux_loss + + # Entropy shaping: penalize high entropy (encourage confidence) + entropy_coeff = float(os.environ.get("ENTROPY_BONUS", "0")) + if entropy_coeff != 0 and self.training: + probs 
= F.softmax(logits.float(), dim=-1) + log_probs = F.log_softmax(logits.float(), dim=-1) + entropy = -(probs * log_probs).sum(dim=-1).mean() + # The term is added to the loss, so: + # Positive coeff = penalize entropy (more confident) + # Negative coeff = reward entropy (more exploratory) + loss = loss + entropy_coeff * entropy + + # Nature: Quorum sensing (encourage attention head diversity) + head_div_coeff = float(os.environ.get("HEAD_DIVERSITY", "0")) + if head_div_coeff > 0 and self.training: + # Penalize similarity between head output projections + for block in self.blocks: + # CausalSelfAttention names its output projection `proj`, not `c_proj` + if hasattr(block, 'attn') and hasattr(block.attn, 'proj'): + w = block.attn.proj.weight # (d_model, d_model) + n_heads = block.attn.num_heads + head_dim = w.shape[0] // n_heads + heads = w.view(n_heads, head_dim, -1) + # Cosine similarity between head weight matrices + heads_flat = heads.reshape(n_heads, -1) + heads_norm = heads_flat / (heads_flat.norm(dim=1, keepdim=True) + 1e-8) + sim = torch.mm(heads_norm, heads_norm.t()) + # Penalize off-diagonal similarity + diversity_loss = (sim - torch.eye(n_heads, device=sim.device)).pow(2).mean() + loss = loss + head_div_coeff * diversity_loss + break # just first block to keep it cheap + return loss + + +# ----------------------------- +# TRAINING +# ----------------------------- + +def main() -> None: + global zeropower_via_newtonschulz5 + + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) + + # ----------------------------- + # DISTRIBUTED + CUDA SETUP + # ----------------------------- + + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if 8 % world_size != 0: + raise 
ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") + grad_accum_steps = 8 // world_size + grad_scale = 1.0 / grad_accum_steps + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 0 + + # Fast math knobs + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + + logfile = None + if master_process: + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" + print(logfile) + + def log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch {torch.__version__}", console=False) + log0( + subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, + console=False, + ) + log0("=" * 100, console=False) + + # ----------------------------- + # TOKENIZER + VALIDATION METRIC SETUP + # ----------------------------- + + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script is only set up for SentencePiece .model files: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + 
f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + + # ----------------------------- + # MODEL + OPTIMIZER SETUP + # ----------------------------- + + base_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, + ).to(device).bfloat16() + for module in base_model.modules(): + if isinstance(module, CastedLinear): + module.float() + restore_low_dim_params_to_fp32(base_model) + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) + model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + + # Optimizer split: + # - token embedding (Adam) uses EMBED_LR + # - untied lm_head (Adam) uses HEAD_LR + # - matrix params in transformer blocks use MATRIX_LR via Muon + # - vectors/scalars use SCALAR_LR via Adam + block_named_params = list(base_model.blocks.named_parameters()) + matrix_params = [ + p + for name, p in block_named_params + if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + scalar_params = [ + p + for name, 
p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + optimizer_tok = torch.optim.Adam( + [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizer_muon = Muon( + matrix_params, + lr=args.matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + ) + for group in optimizer_muon.param_groups: + group["base_lr"] = args.matrix_lr + optimizer_scalar = torch.optim.Adam( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers.insert(1, optimizer_head) + + n_params = sum(p.numel() for p in base_model.parameters()) + log0(f"model_params:{n_params}") + log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") + log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") + log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") + log0( + f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " + f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " + f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" + ) + log0( + f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " + f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " + 
f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log0(f"seed:{args.seed}") + + # ----------------------------- + # DATA LOADER & MODEL WARMUP + # ----------------------------- + + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + + + def lr_mul(step: int, elapsed_ms: float) -> float: + import math + + tidal = int(os.environ.get("TIDAL_LR", "0")) + if tidal: + # Golden ratio asymmetry: 38.2% warmup, 61.8% decay + warmup_frac = 0.382 + min_lr_frac = 0.1 + progress = step / max(args.iterations, 1) + if progress < warmup_frac: + return min_lr_frac + (1.0 - min_lr_frac) * (progress / warmup_frac) + else: + decay_progress = (progress - warmup_frac) / (1.0 - warmup_frac) + return min_lr_frac + (1.0 - min_lr_frac) * 0.5 * (1.0 + math.cos(math.pi * decay_progress)) + + # Breathing LR: 4-7-8 pattern (inhale-hold-exhale) + breathing = int(os.environ.get("BREATHING_LR", "0")) + if breathing: + min_lr_frac = 0.1 + progress = step / max(args.iterations, 1) + # 4 parts warmup, 7 parts steady, 8 parts decay = 19 total + warmup_end = 4.0 / 19.0 # ~21% + steady_end = 11.0 / 19.0 # ~58% + if progress < warmup_end: + return min_lr_frac + (1.0 - min_lr_frac) * (progress / warmup_end) + elif progress < steady_end: + return 1.0 + else: + decay_progress = (progress - steady_end) / (1.0 - steady_end) + return min_lr_frac + (1.0 - min_lr_frac) * 0.5 * (1.0 + math.cos(math.pi * decay_progress)) + + # Whale Dive: deep dives with sharp recoveries + whale = int(os.environ.get("WHALE_LR", "0")) + if whale: + min_lr_frac = 0.05 + progress = step / max(args.iterations, 1) + # 3 dive cycles, each deeper, with sharp recovery + cycle = progress * 3.0 + cycle_pos = cycle % 1.0 + depth = min(1.0, cycle / 3.0) # gets deeper each cycle + # Sharp rise (10% of 
cycle), slow descent (90%) + if cycle_pos < 0.1: + lr = min_lr_frac + (1.0 - min_lr_frac) * (cycle_pos / 0.1) * (1.0 - 0.3 * depth) + else: + descent = (cycle_pos - 0.1) / 0.9 + peak = (1.0 - min_lr_frac) * (1.0 - 0.3 * depth) + lr = min_lr_frac + peak * 0.5 * (1.0 + math.cos(math.pi * descent)) + # Final decay in last 10% + if progress > 0.9: + lr = lr * (1.0 - progress) / 0.1 + return max(lr, min_lr_frac) + + # Circadian: alternating high/low LR cycles overlaid on cosine + circadian = int(os.environ.get("CIRCADIAN_LR", "0")) + if circadian: + min_lr_frac = 0.1 + progress = step / max(args.iterations, 1) + # Base cosine schedule + base = min_lr_frac + 0.5 * (1.0 - min_lr_frac) * (1.0 + math.cos(math.pi * progress)) + # Overlay 50-step day/night cycle (amplitude decays with progress) + cycle_amplitude = 0.15 * (1.0 - progress) # smaller oscillations later + day_night = math.sin(2 * math.pi * step / 50) * cycle_amplitude + return max(min_lr_frac, base + day_night) + + # Root-First: higher embedding LR early, equalize later + # This is handled differently - via the embed_lr parameter + # We modify the training loop scale factor instead + root_first = int(os.environ.get("ROOT_FIRST", "0")) + if root_first: + min_lr_frac = 0.1 + progress = step / max(args.iterations, 1) + # Standard cosine base + base = min_lr_frac + 0.5 * (1.0 - min_lr_frac) * (1.0 + math.cos(math.pi * progress)) + return base + + # Cosmological Cooling: lr = peak / sqrt(1 + alpha*progress), a Big Bang-style cooling law + cosmo = int(os.environ.get("COSMO_LR", "0")) + if cosmo: + min_lr_frac = 0.05 + # Warmup for first 5% + progress = step / max(args.iterations, 1) + if progress < 0.05: + return min_lr_frac + (1.0 - min_lr_frac) * (progress / 0.05) + # Cool following T ~ 1/sqrt(t); not normalized, so it ends near 0.22 and never reaches min_lr_frac + alpha = 20.0 # controls cooling rate + cooling = 1.0 / math.sqrt(1.0 + alpha * progress) + # At progress=0.05, cooling=~0.71 (a step down from the 1.0 warmup peak); at progress=1.0, cooling=~0.22 + return max(min_lr_frac, 
cooling) + + # Coral Reef: progressively freeze bottom layers + coral = int(os.environ.get("CORAL_LR", "0")) + if coral: + # Use standard cosine for LR + min_lr_frac = 0.1 + progress = step / max(args.iterations, 1) + return min_lr_frac + 0.5 * (1.0 - min_lr_frac) * (1.0 + math.cos(math.pi * progress)) + + cosine_sched = int(os.environ.get("COSINE_LR", "0")) + if cosine_sched: + # Cosine annealing from 1.0 to 0.1, supports multi-cycle + golden = int(os.environ.get("GOLDEN_LR", "0")) + min_lr_frac = 0.382 if golden else 0.1 + num_cycles = int(os.environ.get("COSINE_CYCLES", "1")) + progress = step / max(args.iterations, 1) + cycle_progress = (progress * num_cycles) % 1.0 + return min_lr_frac + 0.5 * (1.0 - min_lr_frac) * (1.0 + math.cos(math.pi * cycle_progress)) + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + + # Warmup primes the compiled forward/backward/optimizer paths, then we restore the + # initial weights/optimizer state so measured training starts from the true init. 
+ if args.warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(args.warmup_steps): + zero_grad_all() + + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + for opt in optimizers: + opt.step() + zero_grad_all() + if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: + log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + if distributed: + model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + # ----------------------------- + # MAIN TRAINING LOOP + # ----------------------------- + + training_time_ms = 0.0 + stop_after_step: int | None = None + torch.cuda.synchronize() + t0 = time.perf_counter() + + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + log0( + 
f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" + ) + break + + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + zero_grad_all() + train_loss = torch.zeros((), device=device) + canopy_active = int(os.environ.get("CANOPY_LR", "0")) + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + os.environ["_TRAIN_PROGRESS"] = str(step / max(args.iterations, 1)) + loss = model(x, y) + train_loss += loss.detach() + (loss * grad_scale).backward() + train_loss /= grad_accum_steps + # Gradient Centralization: subtract mean from gradients (GC paper) + grad_central = int(os.environ.get("GRAD_CENTRAL", "0")) + if grad_central: + with torch.no_grad(): + for p in model.parameters(): + if p.grad is not None and p.ndim >= 2: + p.grad.sub_(p.grad.mean(dim=tuple(range(1, p.grad.ndim)), keepdim=True)) + + # SAM-lite: add noise in gradient direction for flatter minima + sam_rho = float(os.environ.get("SAM_RHO", "0")) + if sam_rho > 0 and step > 20: + with torch.no_grad(): + for p in model.parameters(): + if p.grad is not None and p.ndim >= 2: + # Perturbation proportional to gradient magnitude + grad_norm = p.grad.norm() + if grad_norm > 1e-8: + noise = sam_rho * p.grad / grad_norm + p.add_(noise) # will be partially undone by optimizer step + + # Weight decay schedule: ramp from 0 to target over training + wd_sched 
= float(os.environ.get("WD_SCHEDULE", "0")) + if wd_sched > 0 and step > 0: + progress = step / max(args.iterations, 1) + # Linear ramp: 0 -> wd_sched over training + effective_wd = wd_sched * progress + with torch.no_grad(): + for p in model.parameters(): + if p.ndim >= 2: + p.mul_(1.0 - effective_wd * scale) + + # Post-backward gradient modifications + # Canopy Light: scale gradients by layer depth (top=full sun, bottom=shade) + if canopy_active and step > 0: + with torch.no_grad(): + block_list = [m for m in model.modules() if hasattr(m, 'attn')] + n_blocks = len(block_list) + for idx, block in enumerate(block_list): + light = 0.3 + 0.7 * (idx / max(n_blocks - 1, 1)) + for p in block.parameters(): + if p.grad is not None: + p.grad.mul_(light) + + # Mycorrhizal Network: share gradients between non-adjacent layers + myco = float(os.environ.get("MYCORRHIZAL", "0")) + if myco > 0 and step > 10: + with torch.no_grad(): + block_list = [m for m in model.modules() if hasattr(m, 'attn')] + n_blocks = len(block_list) + for idx in range(n_blocks - 3): + for p1, p2 in zip(block_list[idx].parameters(), block_list[idx + 3].parameters()): + if p1.grad is not None and p2.grad is not None and p1.shape == p2.shape: + g1 = p1.grad.clone() + g2 = p2.grad.clone() + p1.grad.add_(g2, alpha=myco) + p2.grad.add_(g1, alpha=myco) + + # Thermal Vent: random layer gradient hotspots + vent = float(os.environ.get("THERMAL_VENT", "0")) + if vent > 0 and step > 10: + import random as _rng + with torch.no_grad(): + block_list = [m for m in model.modules() if hasattr(m, 'attn')] + n_blocks = len(block_list) + vent_layers = _rng.sample(range(n_blocks), min(2, n_blocks)) + for idx in vent_layers: + for p in block_list[idx].parameters(): + if p.grad is not None: + p.grad.mul_(1.0 + vent) + + # Predator-Prey: oscillating weight decay + predprey_wd = int(os.environ.get("PREDPREY_LR", "0")) + if predprey_wd and step > 0: + progress = step / max(args.iterations, 1) + cycle = math.sin(2 * math.pi * 
progress * 6) + base_wd = 0.01 + wd_mod = 1.0 - 0.5 * cycle * (1.0 - progress) + effective_wd = base_wd * wd_mod + with torch.no_grad(): + for p in model.parameters(): + if p.ndim >= 2: + p.mul_(1.0 - effective_wd * scale) + + # Nature: Controlled Burn - prune smallest weights periodically + burn_frac = float(os.environ.get("CONTROLLED_BURN", "0")) + burn_interval = int(os.environ.get("BURN_INTERVAL", "50")) + if burn_frac > 0 and step > 0 and step % burn_interval == 0: + with torch.no_grad(): + for p in model.parameters(): + if p.ndim >= 2: + threshold = torch.quantile(p.abs().float(), burn_frac) + mask = p.abs() >= threshold + p.mul_(mask) + # Synaptic Scaling: maintain weight norm homeostasis + synaptic = int(os.environ.get("SYNAPTIC_SCALE", "0")) + if synaptic and step > 20: # skip warmup + with torch.no_grad(): + for p in model.parameters(): + if p.ndim >= 2: + current_norm = p.norm() + if not hasattr(p, '_init_norm'): + p._init_norm = current_norm.clone() + # Gently pull toward the reference norm captured on the first scaled step + # Named norm_scale so it does not clobber the LR multiplier `scale` used below + norm_scale = p._init_norm / (current_norm + 1e-8) + # Don't scale more than 5% per step + norm_scale = torch.clamp(norm_scale, 0.95, 1.05) + p.mul_(norm_scale) + + + # Nature: Viral Mutation - perturb weights to escape local minima + perturb_scale = float(os.environ.get("WEIGHT_PERTURB", "0")) + perturb_interval = int(os.environ.get("PERTURB_INTERVAL", "100")) + if perturb_scale > 0 and step > 0 and step % perturb_interval == 0: + with torch.no_grad(): + for p in model.parameters(): + if p.ndim >= 2: + noise = torch.randn_like(p) * perturb_scale * p.abs().mean() + p.add_(noise) + # Mutation Rate Decay: continuous decaying noise (evolution) + mutation = float(os.environ.get("MUTATION_DECAY", "0")) + if mutation > 0 and step > 20: + progress = step / max(args.iterations, 1) + # Exponential decay: high noise early, near-zero late + noise_scale = mutation * math.exp(-5.0 * progress) + if noise_scale > 1e-6: + with torch.no_grad(): + for p in model.parameters(): + if p.ndim >= 2: + noise = 
torch.randn_like(p) * noise_scale * p.abs().mean() + p.add_(noise) + + + + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + for opt in optimizers: + opt.step() + zero_grad_all() + + step += 1 + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + should_log_train = ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + log0( + f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" + ) + + # Needed to sync whether we've reached the wallclock cap. + reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + + log0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + ) + + # ----------------------------- + # SERIALIZATION + ROUNDTRIP VALIDATION + # ----------------------------- + # Save the raw state (useful for debugging/loading in PyTorch directly), then always produce + # the compressed int8+zlib artifact and validate the round-tripped weights. 
+ + if master_process: + torch.save(base_model.state_dict(), "final_model.pt") + model_bytes = os.path.getsize("final_model.pt") + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model: {model_bytes} bytes") + log0(f"Code size: {code_bytes} bytes") + log0(f"Total submission size: {model_bytes + code_bytes} bytes") + + quant_obj, quant_stats = quantize_state_dict_int8(base_model.state_dict()) + quant_buf = io.BytesIO() + torch.save(quant_obj, quant_buf) + quant_raw = quant_buf.getvalue() + quant_blob = zlib.compress(quant_raw, level=9) + quant_raw_bytes = len(quant_raw) + if master_process: + with open("final_model.int8.ptz", "wb") as f: + f.write(quant_blob) + quant_file_bytes = os.path.getsize("final_model.int8.ptz") + code_bytes = len(code.encode("utf-8")) + ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1) + log0( + f"Serialized model int8+zlib: {quant_file_bytes} bytes " + f"(payload:{quant_stats['int8_payload_bytes']} raw_torch:{quant_raw_bytes} payload_ratio:{ratio:.2f}x)" + ) + log0(f"Total submission size int8+zlib: {quant_file_bytes + code_bytes} bytes") + + if distributed: + dist.barrier() + with open("final_model.int8.ptz", "rb") as f: + quant_blob_disk = f.read() + quant_state = torch.load(io.BytesIO(zlib.decompress(quant_blob_disk)), map_location="cpu") + base_model.load_state_dict(dequantize_state_dict_int8(quant_state), strict=True) + torch.cuda.synchronize() + t_qeval = time.perf_counter() + q_val_loss, q_val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"final_int8_zlib_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" + ) + log0(f"final_int8_zlib_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + + if distributed: + dist.destroy_process_group() 
+ + +if __name__ == "__main__": + main() diff --git a/wave10_finetune.sh b/wave10_finetune.sh new file mode 100644 index 0000000000..ca8db72383 --- /dev/null +++ b/wave10_finetune.sh @@ -0,0 +1,87 @@ +#!/bin/bash +# Wave 10: Fine-grained tuning around best config from Wave 9 +# This script assumes 1/8 asymmetric is confirmed best. +# Fine-tune softcap and matrix LR with small increments. +cd /workspace/parameter-golf +LOG="/workspace/wave10_results.log" + +echo "=== WAVE 10: FINE-TUNE $(date) ===" > $LOG +echo "Building on best from Wave 9" >> $LOG +echo "" >> $LOG + +grab() { + local name="$1" + local logfile=$(ls -t /workspace/parameter-golf/logs/*.txt | head -1) + local result=$(grep "^step:.*val_bpb" "$logfile" | tail -1) + echo "$result" >> $LOG + echo "END: $(date)" >> $LOG + echo "" >> $LOG + sleep 2 + pkill -9 -f train_gpt_focal 2>/dev/null + sleep 3 +} + +run() { + local name="$1" + shift + echo "--- $name ---" >> $LOG + echo "START: $(date)" >> $LOG + env ITERATIONS=400 "$@" python train_gpt_focal_fixed.py > "/workspace/${name}.txt" 2>&1 + grab "$name" +} + +# ============================================ +# PHASE 1: FINE-GRAINED SOFTCAP SWEEP ON 1/8 +# If SC15 won Wave 9, sweep 13-17 in steps of 1 +# ============================================ +echo "========== SOFTCAP FINE SWEEP ===========" >> $LOG + +# Sweep softcap around SC15 with MatLR=0.10 (E8 base) +run "F1_SC13_MatLR10" TIDAL_LR=1 LOGIT_SOFTCAP=13.0 ROPE_BASE=5000 PARALLEL_BLOCK=1 MLP_ACT=silu2 HEAD_DIVERSITY=1e-4 EMBED_LR=0.8 MATRIX_LR=0.10 ENCODER_LAYERS=1 +run "F2_SC14_MatLR10" TIDAL_LR=1 LOGIT_SOFTCAP=14.0 ROPE_BASE=5000 PARALLEL_BLOCK=1 MLP_ACT=silu2 HEAD_DIVERSITY=1e-4 EMBED_LR=0.8 MATRIX_LR=0.10 ENCODER_LAYERS=1 +run "F3_SC16_MatLR10" TIDAL_LR=1 LOGIT_SOFTCAP=16.0 ROPE_BASE=5000 PARALLEL_BLOCK=1 MLP_ACT=silu2 HEAD_DIVERSITY=1e-4 EMBED_LR=0.8 MATRIX_LR=0.10 ENCODER_LAYERS=1 +run "F4_SC17_MatLR10" TIDAL_LR=1 LOGIT_SOFTCAP=17.0 ROPE_BASE=5000 PARALLEL_BLOCK=1 MLP_ACT=silu2 
HEAD_DIVERSITY=1e-4 EMBED_LR=0.8 MATRIX_LR=0.10 ENCODER_LAYERS=1 + +# ============================================ +# PHASE 2: FINE-GRAINED MATRIX LR SWEEP ON 1/8 +# ============================================ +echo "========== MATRIX LR FINE SWEEP ===========" >> $LOG + +# E8 showed SC15+MatLR0.10 = 1.5354. Sweep around MatLR0.10 with SC15. +run "F5_SC15_MatLR009" TIDAL_LR=1 LOGIT_SOFTCAP=15.0 ROPE_BASE=5000 PARALLEL_BLOCK=1 MLP_ACT=silu2 HEAD_DIVERSITY=1e-4 EMBED_LR=0.8 MATRIX_LR=0.09 ENCODER_LAYERS=1 +run "F6_SC15_MatLR011" TIDAL_LR=1 LOGIT_SOFTCAP=15.0 ROPE_BASE=5000 PARALLEL_BLOCK=1 MLP_ACT=silu2 HEAD_DIVERSITY=1e-4 EMBED_LR=0.8 MATRIX_LR=0.11 ENCODER_LAYERS=1 + +# ============================================ +# PHASE 3: GQA ON 1/8 +# 2 KV heads was decent before (A17 = 1.5761) +# With asymmetric it might be different. +# ============================================ +echo "========== GQA ON 1/8 ===========" >> $LOG + +run "F7_GQA2" TIDAL_LR=1 LOGIT_SOFTCAP=15.0 ROPE_BASE=5000 PARALLEL_BLOCK=1 MLP_ACT=silu2 HEAD_DIVERSITY=1e-4 EMBED_LR=0.8 MATRIX_LR=0.10 ENCODER_LAYERS=1 NUM_KV_HEADS=2 + +# ============================================ +# PHASE 4: TIDAL WARMUP RATIO ON 1/8 +# Default Tidal = 38.2% warmup. With more decoder +# layers and faster steps, maybe different ratio helps. 
+# ============================================ +echo "========== TIDAL VARIANT ===========" >> $LOG + +# Try 30% warmup (shorter warmup, more time at high LR) +run "F8_Tidal30" TIDAL_LR=1 TIDAL_WARMUP=0.30 LOGIT_SOFTCAP=15.0 ROPE_BASE=5000 PARALLEL_BLOCK=1 MLP_ACT=silu2 HEAD_DIVERSITY=1e-4 EMBED_LR=0.8 MATRIX_LR=0.10 ENCODER_LAYERS=1 + +# ============================================ +# PHASE 5: STACK BEST COMBO +# Combine the best softcap + best LR from above +# ============================================ +echo "========== FINAL STACK ===========" >> $LOG + +# Stack: E8 config (SC15+MatLR0.10) + QK2.0 +run "F9_E8_QK2" TIDAL_LR=1 LOGIT_SOFTCAP=15.0 ROPE_BASE=5000 PARALLEL_BLOCK=1 MLP_ACT=silu2 HEAD_DIVERSITY=1e-4 EMBED_LR=0.8 MATRIX_LR=0.10 ENCODER_LAYERS=1 QK_GAIN_INIT=2.0 + +# Rerun E8 config for confidence +run "F10_E8_Rerun" TIDAL_LR=1 LOGIT_SOFTCAP=15.0 ROPE_BASE=5000 PARALLEL_BLOCK=1 MLP_ACT=silu2 HEAD_DIVERSITY=1e-4 EMBED_LR=0.8 MATRIX_LR=0.10 ENCODER_LAYERS=1 + +echo "" >> $LOG +echo "=== WAVE 10 COMPLETE $(date) ===" >> $LOG +cat $LOG diff --git a/wave11_final.sh b/wave11_final.sh new file mode 100644 index 0000000000..e1ab6f9ed9 --- /dev/null +++ b/wave11_final.sh @@ -0,0 +1,39 @@ +#!/bin/bash +# Wave 11: Trimmed to 2 key experiments only +cd /workspace/parameter-golf +LOG="/workspace/wave11_results.log" + +echo "=== WAVE 11: TRIMMED $(date) ===" > $LOG +echo "Only 2 key experiments" >> $LOG +echo "" >> $LOG + +grab() { + local name="$1" + local logfile=$(ls -t /workspace/parameter-golf/logs/*.txt | head -1) + local result=$(grep "^step:.*val_bpb" "$logfile" | tail -1) + echo "$result" >> $LOG + echo "END: $(date)" >> $LOG + echo "" >> $LOG + sleep 2 + pkill -9 -f train_gpt_focal 2>/dev/null + sleep 3 +} + +run() { + local name="$1" + shift + echo "--- $name ---" >> $LOG + echo "START: $(date)" >> $LOG + env ITERATIONS=400 "$@" python train_gpt_focal_fixed.py > "/workspace/${name}.txt" 2>&1 + grab "$name" +} + +# G1: Untied embeddings — novel, could be 
big win +run "G1_Untied" TIDAL_LR=1 LOGIT_SOFTCAP=15.0 ROPE_BASE=5000 PARALLEL_BLOCK=1 MLP_ACT=silu2 HEAD_DIVERSITY=1e-4 EMBED_LR=0.8 MATRIX_LR=0.11 ENCODER_LAYERS=1 TIE_EMBEDDINGS=0 + +# G3: WD Schedule — competition winners use this +run "G3_WDSched" TIDAL_LR=1 LOGIT_SOFTCAP=15.0 ROPE_BASE=5000 PARALLEL_BLOCK=1 MLP_ACT=silu2 HEAD_DIVERSITY=1e-4 EMBED_LR=0.8 MATRIX_LR=0.11 ENCODER_LAYERS=1 WD_SCHEDULE=0.01 + +echo "" >> $LOG +echo "=== WAVE 11 COMPLETE $(date) ===" >> $LOG +cat $LOG diff --git a/wave12_aggressive.sh b/wave12_aggressive.sh new file mode 100644 index 0000000000..8672c5e1ae --- /dev/null +++ b/wave12_aggressive.sh @@ -0,0 +1,80 @@ +#!/bin/bash +# Wave 12: Aggressive experiments — pure decoder, wider MLP, model scaling +# Uses findings from Waves 9-11 + 3.1MB artifact headroom +cd /workspace/parameter-golf +LOG="/workspace/wave12_results.log" + +echo "=== WAVE 12: AGGRESSIVE $(date) ===" > $LOG +echo "Pure decoder, wider MLP, scaling experiments" >> $LOG +echo "" >> $LOG + +grab() { + local name="$1" + local logfile=$(ls -t /workspace/parameter-golf/logs/*.txt | head -1) + local result=$(grep "^step:.*val_bpb" "$logfile" | tail -1) + echo "$result" >> $LOG + echo "END: $(date)" >> $LOG + echo "" >> $LOG + sleep 2 + pkill -9 -f train_gpt_focal 2>/dev/null + sleep 3 +} + +run() { + local name="$1" + shift + echo "--- $name ---" >> $LOG + echo "START: $(date)" >> $LOG + env ITERATIONS=400 "$@" python train_gpt_focal_fixed.py > "/workspace/${name}.txt" 2>&1 + grab "$name" +} + +# ============================================ +# PHASE 1: PURE DECODER (ENCODER_LAYERS=0) +# Bug fixed: default=-1, so ENCODER_LAYERS=0 now works +# All 9 layers as decoder, no encoder skip connections +# ============================================ +echo "========== PURE DECODER ===========" >> $LOG + +# H1: Pure decoder with best config (E8 base) +run "H1_PureDecoder" TIDAL_LR=1 LOGIT_SOFTCAP=15.0 ROPE_BASE=5000 PARALLEL_BLOCK=1 MLP_ACT=silu2 HEAD_DIVERSITY=1e-4 EMBED_LR=0.8 
MATRIX_LR=0.11 ENCODER_LAYERS=0 NUM_KV_HEADS=2 TIE_EMBEDDINGS=0 + +# ============================================ +# PHASE 2: WIDER MLP (3x) +# Top competition entries use 3x MLP width +# MLP_MULT=3 + SiLU² → hidden=1024 (vs current 682) +# More params but potentially much better quality +# Will be slower per step but might make up in quality +# ============================================ +echo "========== WIDER MLP ===========" >> $LOG + +# H2: 3x MLP on best config (1/8 split) +run "H2_MLP3x" TIDAL_LR=1 LOGIT_SOFTCAP=15.0 ROPE_BASE=5000 PARALLEL_BLOCK=1 MLP_ACT=silu2 HEAD_DIVERSITY=1e-4 EMBED_LR=0.8 MATRIX_LR=0.11 ENCODER_LAYERS=1 MLP_MULT=3 NUM_KV_HEADS=2 TIE_EMBEDDINGS=0 + +# H3: 3x MLP + pure decoder +run "H3_MLP3x_PureDec" TIDAL_LR=1 LOGIT_SOFTCAP=15.0 ROPE_BASE=5000 PARALLEL_BLOCK=1 MLP_ACT=silu2 HEAD_DIVERSITY=1e-4 EMBED_LR=0.8 MATRIX_LR=0.11 ENCODER_LAYERS=0 MLP_MULT=3 NUM_KV_HEADS=2 TIE_EMBEDDINGS=0 + +# ============================================ +# PHASE 3: WD SCHEDULE (only env var that exists) +# WD_SCHEDULE ramps weight decay from 0 to target +# Competition winners use WD=0.04 +# ============================================ +echo "========== WEIGHT DECAY SCHEDULE ===========" >> $LOG + +# H4: WD schedule ramping to 0.04 +run "H4_WDSched04" TIDAL_LR=1 LOGIT_SOFTCAP=15.0 ROPE_BASE=5000 PARALLEL_BLOCK=1 MLP_ACT=silu2 HEAD_DIVERSITY=1e-4 EMBED_LR=0.8 MATRIX_LR=0.11 ENCODER_LAYERS=1 WD_SCHEDULE=0.04 NUM_KV_HEADS=2 TIE_EMBEDDINGS=0 + +# ============================================ +# PHASE 4: STACK WINNERS FROM ABOVE +# ============================================ +echo "========== STACK ===========" >> $LOG + +# H5: Pure decoder + 3x MLP + WD (aggressive combo) +run "H5_AllStack" TIDAL_LR=1 LOGIT_SOFTCAP=15.0 ROPE_BASE=5000 PARALLEL_BLOCK=1 MLP_ACT=silu2 HEAD_DIVERSITY=1e-4 EMBED_LR=0.8 MATRIX_LR=0.11 ENCODER_LAYERS=0 MLP_MULT=3 WD_SCHEDULE=0.04 NUM_KV_HEADS=2 TIE_EMBEDDINGS=0 + +# H6: Best config rerun for final confidence +run "H6_BestRerun" 
TIDAL_LR=1 LOGIT_SOFTCAP=15.0 ROPE_BASE=5000 PARALLEL_BLOCK=1 MLP_ACT=silu2 HEAD_DIVERSITY=1e-4 EMBED_LR=0.8 MATRIX_LR=0.11 ENCODER_LAYERS=1 NUM_KV_HEADS=2 TIE_EMBEDDINGS=0 + +echo "" >> $LOG +echo "=== WAVE 12 COMPLETE $(date) ===" >> $LOG +cat $LOG diff --git a/wave5_arch.sh b/wave5_arch.sh new file mode 100644 index 0000000000..6f68a4fd12 --- /dev/null +++ b/wave5_arch.sh @@ -0,0 +1,107 @@ +#!/bin/bash +# Wave 5: Architecture experiments — static changes, torch.compile safe +cd /workspace/parameter-golf +LOG="/workspace/wave5_results.log" + +echo "=== WAVE 5: ARCHITECTURE $(date) ===" > $LOG +echo "BASELINE: Cosine = 1.6117 | Best: 1.5744 (Tidal+SC20+RoPE5k+HD+AggrLR)" >> $LOG +echo "" >> $LOG + +grab() { + local name="$1" + local logfile=$(ls -t /workspace/parameter-golf/logs/*.txt | head -1) + local result=$(grep "^step:.*val_bpb" "$logfile" | tail -1) + echo "$result" >> $LOG + echo "END: $(date)" >> $LOG + echo "" >> $LOG + sleep 2 + pkill -9 -f train_gpt_focal 2>/dev/null + sleep 3 +} + +run() { + local name="$1" + shift + echo "--- $name ---" >> $LOG + echo "START: $(date)" >> $LOG + env ITERATIONS=400 "$@" python train_gpt_focal_fixed.py > "/workspace/${name}.txt" 2>&1 + grab "$name" +} + +# ============================================ +# PHASE 0: ACTIVATION FUNCTIONS (on best config) +# ============================================ +echo "========== ACTIVATION FUNCTIONS ==========" >> $LOG + +# A1: LeakyReLU² — on the leaderboard! 
(#2 entry uses this) +run "A1_LeakyReLU2" TIDAL_LR=1 LOGIT_SOFTCAP=20.0 ROPE_BASE=5000 MLP_ACT=leaky_relu2 + +# A2: LeakyReLU² on cosine baseline (to isolate activation effect) +run "A2_LeakyReLU2_cosine" COSINE_LR=1 MLP_ACT=leaky_relu2 + +# A3: SwiGLU — standard in Llama/Mistral +run "A3_SwiGLU" TIDAL_LR=1 LOGIT_SOFTCAP=20.0 ROPE_BASE=5000 MLP_ACT=swiglu + +# A4: SwiGLU on cosine baseline +run "A4_SwiGLU_cosine" COSINE_LR=1 MLP_ACT=swiglu + +# A5: GELU² — smoother than ReLU² +run "A5_GELU2" TIDAL_LR=1 LOGIT_SOFTCAP=20.0 ROPE_BASE=5000 MLP_ACT=gelu2 + +# A6: SiLU² — like SwiGLU but without gate +run "A6_SiLU2" TIDAL_LR=1 LOGIT_SOFTCAP=20.0 ROPE_BASE=5000 MLP_ACT=silu2 + +# ============================================ +# PHASE 1: BLOCK STRUCTURE +# ============================================ +echo "========== BLOCK STRUCTURE ==========" >> $LOG + +# A7: Parallel attention + MLP (PaLM-style) +run "A7_Parallel" TIDAL_LR=1 LOGIT_SOFTCAP=20.0 ROPE_BASE=5000 PARALLEL_BLOCK=1 + +# A8: Parallel on cosine +run "A8_Parallel_cosine" COSINE_LR=1 PARALLEL_BLOCK=1 + +# A9: Sandwich norm (extra norm after attention) +run "A9_Sandwich" TIDAL_LR=1 LOGIT_SOFTCAP=20.0 ROPE_BASE=5000 SANDWICH_NORM=1 + +# ============================================ +# PHASE 2: ACTIVATION + STRUCTURE COMBOS +# ============================================ +echo "========== COMBOS ==========" >> $LOG + +# A10: LeakyReLU² + Parallel (combine two arch changes) +run "A10_LeakyParallel" TIDAL_LR=1 LOGIT_SOFTCAP=20.0 ROPE_BASE=5000 MLP_ACT=leaky_relu2 PARALLEL_BLOCK=1 + +# A11: SwiGLU + Parallel +run "A11_SwiGLUParallel" TIDAL_LR=1 LOGIT_SOFTCAP=20.0 ROPE_BASE=5000 MLP_ACT=swiglu PARALLEL_BLOCK=1 + +# A12: LeakyReLU² + full best config +run "A12_LeakyBest" TIDAL_LR=1 LOGIT_SOFTCAP=20.0 ROPE_BASE=5000 HEAD_DIVERSITY=1e-4 EMBED_LR=0.8 MATRIX_LR=0.06 MLP_ACT=leaky_relu2 + +# A13: SwiGLU + full best config +run "A13_SwiGLUBest" TIDAL_LR=1 LOGIT_SOFTCAP=20.0 ROPE_BASE=5000 HEAD_DIVERSITY=1e-4 EMBED_LR=0.8 
MATRIX_LR=0.06 MLP_ACT=swiglu + +# A14: LeakyReLU² + Parallel + full best (EVERYTHING) +run "A14_EVERYTHING" TIDAL_LR=1 LOGIT_SOFTCAP=20.0 ROPE_BASE=5000 HEAD_DIVERSITY=1e-4 EMBED_LR=0.8 MATRIX_LR=0.06 MLP_ACT=leaky_relu2 PARALLEL_BLOCK=1 + +# ============================================ +# PHASE 3: WIDER/DEEPER WITH ARCH CHANGES +# ============================================ +echo "========== SCALE WITH ARCH ==========" >> $LOG + +# A15: LeakyReLU² + wider MLP (3x) +run "A15_LeakyMLP3" TIDAL_LR=1 LOGIT_SOFTCAP=20.0 ROPE_BASE=5000 MLP_ACT=leaky_relu2 MLP_MULT=3 + +# A16: SwiGLU + wider dim (since SwiGLU has fewer params per layer) +run "A16_SwiGLU_wide" TIDAL_LR=1 LOGIT_SOFTCAP=20.0 ROPE_BASE=5000 MLP_ACT=swiglu MODEL_DIM=576 NUM_LAYERS=8 + +# A17: LeakyReLU² + 2 KV heads (more aggressive GQA) +run "A17_LeakyGQA2" TIDAL_LR=1 LOGIT_SOFTCAP=20.0 ROPE_BASE=5000 MLP_ACT=leaky_relu2 NUM_KV_HEADS=2 + +# A18: LeakyReLU² + 1 KV head (MQA) +run "A18_LeakyMQA" TIDAL_LR=1 LOGIT_SOFTCAP=20.0 ROPE_BASE=5000 MLP_ACT=leaky_relu2 NUM_KV_HEADS=1 + +echo "" >> $LOG +echo "=== WAVE 5 COMPLETE $(date) ===" >> $LOG +cat $LOG diff --git a/wave6_focused.sh b/wave6_focused.sh new file mode 100644 index 0000000000..8d21427aac --- /dev/null +++ b/wave6_focused.sh @@ -0,0 +1,72 @@ +#!/bin/bash +# Wave 6: Focused experiments — building on Parallel blocks win +# Each experiment has a clear hypothesis +cd /workspace/parameter-golf +LOG="/workspace/wave6_results.log" + +echo "=== WAVE 6: FOCUSED $(date) ===" > $LOG +echo "BASELINE: Cosine = 1.6117 | Best: A14 = 1.5586 (Parallel+Leaky+HD+Aggr)" >> $LOG +echo "Best Parallel alone: A7 = 1.5600" >> $LOG +echo "" >> $LOG + +grab() { + local name="$1" + local logfile=$(ls -t /workspace/parameter-golf/logs/*.txt | head -1) + local result=$(grep "^step:.*val_bpb" "$logfile" | tail -1) + echo "$result" >> $LOG + echo "END: $(date)" >> $LOG + echo "" >> $LOG + sleep 2 + pkill -9 -f train_gpt_focal 2>/dev/null + sleep 3 +} + +run() { + local name="$1" + 
shift + echo "--- $name ---" >> $LOG + echo "START: $(date)" >> $LOG + env ITERATIONS=400 "$@" python train_gpt_focal_fixed.py > "/workspace/${name}.txt" 2>&1 + grab "$name" +} + +# ============================================ +# B1: Parallel + relu² + HeadDiv + AggrLR +# Hypothesis: A14 used LeakyReLU² which hurts. +# Plain relu² parallel + full extras should beat 1.5586. +# ============================================ +run "B1_ParallelBest" TIDAL_LR=1 LOGIT_SOFTCAP=20.0 ROPE_BASE=5000 PARALLEL_BLOCK=1 HEAD_DIVERSITY=1e-4 EMBED_LR=0.8 MATRIX_LR=0.06 + +# ============================================ +# B2: Parallel + 10 layers (9L default) +# Hypothesis: Parallel saves ~200ms/step (2315 vs 2510). +# 10L parallel should be ~2570ms → ~234 steps. +# More capacity at similar step count = lower BPB. +# ============================================ +run "B2_Parallel10L" TIDAL_LR=1 LOGIT_SOFTCAP=20.0 ROPE_BASE=5000 PARALLEL_BLOCK=1 NUM_LAYERS=10 + +# ============================================ +# B3: Parallel + 10L + HeadDiv + AggrLR +# Hypothesis: Stack the best extras on 10L parallel. +# ============================================ +run "B3_Parallel10LBest" TIDAL_LR=1 LOGIT_SOFTCAP=20.0 ROPE_BASE=5000 PARALLEL_BLOCK=1 NUM_LAYERS=10 HEAD_DIVERSITY=1e-4 EMBED_LR=0.8 MATRIX_LR=0.06 + +# ============================================ +# B4: Parallel + SiLU² + HeadDiv + AggrLR +# Hypothesis: SiLU² was best activation (1.5743) at same +# speed as relu². With parallel it won't lose steps. +# Tests if SiLU² + parallel + extras can beat B1. +# ============================================ +run "B4_ParallelSiLU2Best" TIDAL_LR=1 LOGIT_SOFTCAP=20.0 ROPE_BASE=5000 PARALLEL_BLOCK=1 MLP_ACT=silu2 HEAD_DIVERSITY=1e-4 EMBED_LR=0.8 MATRIX_LR=0.06 + +# ============================================ +# B5: Parallel + QK gain 2.0 +# Hypothesis: Default qk_gain_init=1.5. Parallel shares +# normed input between attn+MLP, so stronger attention +# signal (higher gain) might help differentiate. 
+# ============================================ +run "B5_ParallelQK2" TIDAL_LR=1 LOGIT_SOFTCAP=20.0 ROPE_BASE=5000 PARALLEL_BLOCK=1 HEAD_DIVERSITY=1e-4 EMBED_LR=0.8 MATRIX_LR=0.06 QK_GAIN_INIT=2.0 + +echo "" >> $LOG +echo "=== WAVE 6 COMPLETE $(date) ===" >> $LOG +cat $LOG diff --git a/wave7_retune.sh b/wave7_retune.sh new file mode 100644 index 0000000000..383f45ec95 --- /dev/null +++ b/wave7_retune.sh @@ -0,0 +1,120 @@ +#!/bin/bash +# Wave 7: Re-tune hyperparameters around Parallel+SiLU² architecture +# Best config: PARALLEL_BLOCK=1 MLP_ACT=silu2 + Tidal+SC20+RoPE5k+HD+AggrLR = 1.5527 +# Hypothesis: hyperparams were tuned for sequential blocks. Parallel changes +# gradient flow, so optimal softcap/RoPE/LR/etc may have shifted. +cd /workspace/parameter-golf +LOG="/workspace/wave7_results.log" + +echo "=== WAVE 7: RETUNE $(date) ===" > $LOG +echo "BASELINE: Cosine = 1.6117 | Best: B4 = 1.5527 (Parallel+SiLU2+HD+Aggr)" >> $LOG +echo "" >> $LOG + +# Base config for all experiments (B4 winner) +# TIDAL_LR=1 LOGIT_SOFTCAP=20.0 ROPE_BASE=5000 PARALLEL_BLOCK=1 MLP_ACT=silu2 HEAD_DIVERSITY=1e-4 EMBED_LR=0.8 MATRIX_LR=0.06 + +grab() { + local name="$1" + local logfile=$(ls -t /workspace/parameter-golf/logs/*.txt | head -1) + local result=$(grep "^step:.*val_bpb" "$logfile" | tail -1) + echo "$result" >> $LOG + echo "END: $(date)" >> $LOG + echo "" >> $LOG + sleep 2 + pkill -9 -f train_gpt_focal 2>/dev/null + sleep 3 +} + +run() { + local name="$1" + shift + echo "--- $name ---" >> $LOG + echo "START: $(date)" >> $LOG + env ITERATIONS=400 "$@" python train_gpt_focal_fixed.py > "/workspace/${name}.txt" 2>&1 + grab "$name" +} + +# ============================================ +# PHASE 1: MQA + PARALLEL (untested combo) +# A18 showed MQA is fast (2379ms) + good (1.5740) +# Parallel is fast (2315ms) + good (1.5600) +# Together = even faster = even more steps? 
+# ============================================ +echo "========== MQA + PARALLEL ===========" >> $LOG + +# C1: MQA + Parallel (basic) +run "C1_MQA_Parallel" TIDAL_LR=1 LOGIT_SOFTCAP=20.0 ROPE_BASE=5000 PARALLEL_BLOCK=1 NUM_KV_HEADS=1 + +# C2: MQA + Parallel + SiLU² + full extras +run "C2_MQA_ParSiLU2Best" TIDAL_LR=1 LOGIT_SOFTCAP=20.0 ROPE_BASE=5000 PARALLEL_BLOCK=1 NUM_KV_HEADS=1 MLP_ACT=silu2 HEAD_DIVERSITY=1e-4 EMBED_LR=0.8 MATRIX_LR=0.06 + +# C3: MQA + Parallel + more heads (16 query heads, 1 KV head) +# More query heads = more capacity without KV cost +run "C3_MQA_16H" TIDAL_LR=1 LOGIT_SOFTCAP=20.0 ROPE_BASE=5000 PARALLEL_BLOCK=1 NUM_KV_HEADS=1 NUM_HEADS=16 MLP_ACT=silu2 HEAD_DIVERSITY=1e-4 EMBED_LR=0.8 MATRIX_LR=0.06 + +# ============================================ +# PHASE 2: SOFTCAP RE-TUNE FOR PARALLEL +# SC20 was optimal for sequential. Parallel shares +# normed input → logit distribution may differ. +# ============================================ +echo "========== SOFTCAP RETUNE ===========" >> $LOG + +# C4: Softcap 15 (tighter) on best parallel config +run "C4_SC15" TIDAL_LR=1 LOGIT_SOFTCAP=15.0 ROPE_BASE=5000 PARALLEL_BLOCK=1 MLP_ACT=silu2 HEAD_DIVERSITY=1e-4 EMBED_LR=0.8 MATRIX_LR=0.06 + +# C5: Softcap 25 (looser) on best parallel config +run "C5_SC25" TIDAL_LR=1 LOGIT_SOFTCAP=25.0 ROPE_BASE=5000 PARALLEL_BLOCK=1 MLP_ACT=silu2 HEAD_DIVERSITY=1e-4 EMBED_LR=0.8 MATRIX_LR=0.06 + +# C6: Softcap 18 on best parallel config +run "C6_SC18" TIDAL_LR=1 LOGIT_SOFTCAP=18.0 ROPE_BASE=5000 PARALLEL_BLOCK=1 MLP_ACT=silu2 HEAD_DIVERSITY=1e-4 EMBED_LR=0.8 MATRIX_LR=0.06 + +# ============================================ +# PHASE 3: ROPE RE-TUNE FOR PARALLEL +# ============================================ +echo "========== ROPE RETUNE ===========" >> $LOG + +# C7: RoPE 3000 (tighter positional attention) +run "C7_RoPE3k" TIDAL_LR=1 LOGIT_SOFTCAP=20.0 ROPE_BASE=3000 PARALLEL_BLOCK=1 MLP_ACT=silu2 HEAD_DIVERSITY=1e-4 EMBED_LR=0.8 MATRIX_LR=0.06 + +# C8: RoPE 7500 +run 
"C8_RoPE7500" TIDAL_LR=1 LOGIT_SOFTCAP=20.0 ROPE_BASE=7500 PARALLEL_BLOCK=1 MLP_ACT=silu2 HEAD_DIVERSITY=1e-4 EMBED_LR=0.8 MATRIX_LR=0.06 + +# ============================================ +# PHASE 4: LR TUNING FOR PARALLEL +# Parallel changes gradient flow — may want different LRs +# ============================================ +echo "========== LR RETUNE ===========" >> $LOG + +# C9: More aggressive embed LR +run "C9_EmbLR1" TIDAL_LR=1 LOGIT_SOFTCAP=20.0 ROPE_BASE=5000 PARALLEL_BLOCK=1 MLP_ACT=silu2 HEAD_DIVERSITY=1e-4 EMBED_LR=1.0 MATRIX_LR=0.06 + +# C10: More aggressive matrix LR +run "C10_MatLR008" TIDAL_LR=1 LOGIT_SOFTCAP=20.0 ROPE_BASE=5000 PARALLEL_BLOCK=1 MLP_ACT=silu2 HEAD_DIVERSITY=1e-4 EMBED_LR=0.8 MATRIX_LR=0.08 + +# C11: Both more aggressive +run "C11_AggrLR2" TIDAL_LR=1 LOGIT_SOFTCAP=20.0 ROPE_BASE=5000 PARALLEL_BLOCK=1 MLP_ACT=silu2 HEAD_DIVERSITY=1e-4 EMBED_LR=1.0 MATRIX_LR=0.08 + +# ============================================ +# PHASE 5: TIDAL WARMUP RE-TUNE +# 38.2% warmup was optimal for sequential. +# Parallel converges differently — try other ratios. 
+# ============================================ +echo "========== WARMUP RETUNE ===========" >> $LOG + +# C12: Breathing LR instead of Tidal (4-7-8 pattern) +run "C12_Breathing" BREATHING_LR=1 LOGIT_SOFTCAP=20.0 ROPE_BASE=5000 PARALLEL_BLOCK=1 MLP_ACT=silu2 HEAD_DIVERSITY=1e-4 EMBED_LR=0.8 MATRIX_LR=0.06 + +# ============================================ +# PHASE 6: HEAD DIVERSITY STRENGTH +# ============================================ +echo "========== HEAD DIV TUNE ===========" >> $LOG + +# C13: Stronger head diversity +run "C13_HD1e3" TIDAL_LR=1 LOGIT_SOFTCAP=20.0 ROPE_BASE=5000 PARALLEL_BLOCK=1 MLP_ACT=silu2 HEAD_DIVERSITY=1e-3 EMBED_LR=0.8 MATRIX_LR=0.06 + +# C14: No head diversity (to confirm it still helps with parallel) +run "C14_NoHD" TIDAL_LR=1 LOGIT_SOFTCAP=20.0 ROPE_BASE=5000 PARALLEL_BLOCK=1 MLP_ACT=silu2 EMBED_LR=0.8 MATRIX_LR=0.06 + +echo "" >> $LOG +echo "=== WAVE 7 COMPLETE $(date) ===" >> $LOG +cat $LOG diff --git a/wave8_asymmetry.sh b/wave8_asymmetry.sh new file mode 100644 index 0000000000..8a054bb638 --- /dev/null +++ b/wave8_asymmetry.sh @@ -0,0 +1,91 @@ +#!/bin/bash +# Wave 8: Asymmetric splits + stacking Wave 7 wins +# Best: C10 = 1.5501 (Parallel+SiLU²+HD+MatLR0.08+Tidal+SC20+RoPE5k) +# New base = B4 config but with MATRIX_LR=0.08 +cd /workspace/parameter-golf +LOG="/workspace/wave8_results.log" + +echo "=== WAVE 8: ASYMMETRY $(date) ===" > $LOG +echo "BASELINE: 1.6117 | Best: C10 = 1.5501 (MatLR0.08)" >> $LOG +echo "Default split: 4 encoder / 5 decoder (9 layers)" >> $LOG +echo "" >> $LOG + +grab() { + local name="$1" + local logfile=$(ls -t /workspace/parameter-golf/logs/*.txt | head -1) + local result=$(grep "^step:.*val_bpb" "$logfile" | tail -1) + echo "$result" >> $LOG + echo "END: $(date)" >> $LOG + echo "" >> $LOG + sleep 2 + pkill -9 -f train_gpt_focal 2>/dev/null + sleep 3 +} + +run() { + local name="$1" + shift + echo "--- $name ---" >> $LOG + echo "START: $(date)" >> $LOG + env ITERATIONS=400 "$@" python 
train_gpt_focal_fixed.py > "/workspace/${name}.txt" 2>&1 + grab "$name" +} + +# ============================================ +# PHASE 1: ESTABLISH NEW BEST (stack C10 wins) +# ============================================ +echo "========== STACKING ===========" >> $LOG + +# D1: C10 config + SC18 (both were individual wins) +# Hypothesis: SC18 tied SC15 for best softcap, MatLR0.08 was best LR. +# Stacking should compound. +run "D1_MatLR08_SC18" TIDAL_LR=1 LOGIT_SOFTCAP=18.0 ROPE_BASE=5000 PARALLEL_BLOCK=1 MLP_ACT=silu2 HEAD_DIVERSITY=1e-4 EMBED_LR=0.8 MATRIX_LR=0.08 + +# D2: C10 config + SC15 +run "D2_MatLR08_SC15" TIDAL_LR=1 LOGIT_SOFTCAP=15.0 ROPE_BASE=5000 PARALLEL_BLOCK=1 MLP_ACT=silu2 HEAD_DIVERSITY=1e-4 EMBED_LR=0.8 MATRIX_LR=0.08 + +# D3: C10 config but NO HeadDiv (C14 showed it barely matters) +# If this ties C10, we simplify the config. +run "D3_MatLR08_NoHD" TIDAL_LR=1 LOGIT_SOFTCAP=20.0 ROPE_BASE=5000 PARALLEL_BLOCK=1 MLP_ACT=silu2 EMBED_LR=0.8 MATRIX_LR=0.08 + +# ============================================ +# PHASE 2: ASYMMETRIC ENCODER/DECODER SPLIT +# Default is 4/5. More decoder layers = more generation capacity. +# ============================================ +echo "========== ASYMMETRIC SPLITS ===========" >> $LOG + +# D4: 3 encoder / 6 decoder on C10 config +# Hypothesis: More decoder capacity helps generation quality. +run "D4_Asym36" TIDAL_LR=1 LOGIT_SOFTCAP=20.0 ROPE_BASE=5000 PARALLEL_BLOCK=1 MLP_ACT=silu2 HEAD_DIVERSITY=1e-4 EMBED_LR=0.8 MATRIX_LR=0.08 ENCODER_LAYERS=3 + +# D5: 2 encoder / 7 decoder on C10 config +# Hypothesis: Push even harder toward decoder. +run "D5_Asym27" TIDAL_LR=1 LOGIT_SOFTCAP=20.0 ROPE_BASE=5000 PARALLEL_BLOCK=1 MLP_ACT=silu2 HEAD_DIVERSITY=1e-4 EMBED_LR=0.8 MATRIX_LR=0.08 ENCODER_LAYERS=2 + +# D6: 1 encoder / 8 decoder (extreme) +# Hypothesis: If 2/7 helps, push to the limit. 
+run "D6_Asym18" TIDAL_LR=1 LOGIT_SOFTCAP=20.0 ROPE_BASE=5000 PARALLEL_BLOCK=1 MLP_ACT=silu2 HEAD_DIVERSITY=1e-4 EMBED_LR=0.8 MATRIX_LR=0.08 ENCODER_LAYERS=1 + +# D7: 5 encoder / 4 decoder (opposite direction: more encoder) +# Hypothesis: Control experiment. If more encoder helps, our assumption is wrong. +run "D7_Asym54" TIDAL_LR=1 LOGIT_SOFTCAP=20.0 ROPE_BASE=5000 PARALLEL_BLOCK=1 MLP_ACT=silu2 HEAD_DIVERSITY=1e-4 EMBED_LR=0.8 MATRIX_LR=0.08 ENCODER_LAYERS=5 + +# ============================================ +# PHASE 3: BEST ASYMMETRIC + BEST SOFTCAP +# Stack if both asymmetry and softcap help +# ============================================ +echo "========== STACK BEST ===========" >> $LOG + +# D8: Best asymmetric (3/6) + SC18 + MatLR0.08 +# Only run if D4 shows improvement over D1 +run "D8_Asym36_SC18" TIDAL_LR=1 LOGIT_SOFTCAP=18.0 ROPE_BASE=5000 PARALLEL_BLOCK=1 MLP_ACT=silu2 HEAD_DIVERSITY=1e-4 EMBED_LR=0.8 MATRIX_LR=0.08 ENCODER_LAYERS=3 + +# D9: Best asymmetric (2/7) + SC18 + MatLR0.08 +run "D9_Asym27_SC18" TIDAL_LR=1 LOGIT_SOFTCAP=18.0 ROPE_BASE=5000 PARALLEL_BLOCK=1 MLP_ACT=silu2 HEAD_DIVERSITY=1e-4 EMBED_LR=0.8 MATRIX_LR=0.08 ENCODER_LAYERS=2 + +# D10: Rerun C10 (confirm 1.5501 is real, not noise) +run "D10_C10_Rerun" TIDAL_LR=1 LOGIT_SOFTCAP=20.0 ROPE_BASE=5000 PARALLEL_BLOCK=1 MLP_ACT=silu2 HEAD_DIVERSITY=1e-4 EMBED_LR=0.8 MATRIX_LR=0.08 + +echo "" >> $LOG +echo "=== WAVE 8 COMPLETE $(date) ===" >> $LOG +cat $LOG diff --git a/wave9_optimize.sh b/wave9_optimize.sh new file mode 100644 index 0000000000..44f6017279 --- /dev/null +++ b/wave9_optimize.sh @@ -0,0 +1,96 @@ +#!/bin/bash +# Wave 9: Optimize around asymmetric 1/8 split +# Best so far: D6 = 1.5377 (Asym 1/8 + Parallel + SiLU² + HD + MatLR0.08 + SC20) +# D2 showed SC15 helps. D8/D9 test SC18 on asymmetric. +# This wave: fine-tune everything around 1/8. 
+cd /workspace/parameter-golf +LOG="/workspace/wave9_results.log" + +echo "=== WAVE 9: OPTIMIZE $(date) ===" > $LOG +echo "BASELINE: 1.6117 | Best: D6 = 1.5377 (Asym 1/8)" >> $LOG +echo "" >> $LOG + +grab() { + local name="$1" + local logfile=$(ls -t /workspace/parameter-golf/logs/*.txt | head -1) + local result=$(grep "^step:.*val_bpb" "$logfile" | tail -1) + echo "$result" >> $LOG + echo "END: $(date)" >> $LOG + echo "" >> $LOG + sleep 2 + pkill -9 -f train_gpt_focal 2>/dev/null + sleep 3 +} + +run() { + local name="$1" + shift + echo "--- $name ---" >> $LOG + echo "START: $(date)" >> $LOG + env ITERATIONS=400 "$@" python train_gpt_focal_fixed.py > "/workspace/${name}.txt" 2>&1 + grab "$name" +} + +# ============================================ +# PHASE 1: SOFTCAP ON 1/8 +# SC15 helped on 4/5 split. Test on 1/8. +# ============================================ +echo "========== SOFTCAP ON 1/8 ===========" >> $LOG + +# E1: 1/8 + SC15 (best softcap from Wave 7) +run "E1_Asym18_SC15" TIDAL_LR=1 LOGIT_SOFTCAP=15.0 ROPE_BASE=5000 PARALLEL_BLOCK=1 MLP_ACT=silu2 HEAD_DIVERSITY=1e-4 EMBED_LR=0.8 MATRIX_LR=0.08 ENCODER_LAYERS=1 + +# E2: 1/8 + SC18 +run "E2_Asym18_SC18" TIDAL_LR=1 LOGIT_SOFTCAP=18.0 ROPE_BASE=5000 PARALLEL_BLOCK=1 MLP_ACT=silu2 HEAD_DIVERSITY=1e-4 EMBED_LR=0.8 MATRIX_LR=0.08 ENCODER_LAYERS=1 + +# E3: 1/8 + SC12 (push tighter) +run "E3_Asym18_SC12" TIDAL_LR=1 LOGIT_SOFTCAP=12.0 ROPE_BASE=5000 PARALLEL_BLOCK=1 MLP_ACT=silu2 HEAD_DIVERSITY=1e-4 EMBED_LR=0.8 MATRIX_LR=0.08 ENCODER_LAYERS=1 + +# ============================================ +# PHASE 2: LR TUNING ON 1/8 +# MatLR0.08 was better than 0.06. Try more. 
+# ============================================ +echo "========== LR ON 1/8 ===========" >> $LOG + +# E4: 1/8 + Matrix LR 0.10 +run "E4_Asym18_MatLR10" TIDAL_LR=1 LOGIT_SOFTCAP=20.0 ROPE_BASE=5000 PARALLEL_BLOCK=1 MLP_ACT=silu2 HEAD_DIVERSITY=1e-4 EMBED_LR=0.8 MATRIX_LR=0.10 ENCODER_LAYERS=1 + +# E5: 1/8 + Matrix LR 0.12 +run "E5_Asym18_MatLR12" TIDAL_LR=1 LOGIT_SOFTCAP=20.0 ROPE_BASE=5000 PARALLEL_BLOCK=1 MLP_ACT=silu2 HEAD_DIVERSITY=1e-4 EMBED_LR=0.8 MATRIX_LR=0.12 ENCODER_LAYERS=1 + +# ============================================ +# PHASE 3: ACTIVATION ON 1/8 +# SiLU² won on 4/5. Does it still win on 1/8? +# ============================================ +echo "========== ACTIVATION ON 1/8 ===========" >> $LOG + +# E6: 1/8 with plain relu² (ablation) +run "E6_Asym18_ReLU2" TIDAL_LR=1 LOGIT_SOFTCAP=20.0 ROPE_BASE=5000 PARALLEL_BLOCK=1 HEAD_DIVERSITY=1e-4 EMBED_LR=0.8 MATRIX_LR=0.08 ENCODER_LAYERS=1 + +# ============================================ +# PHASE 4: QK GAIN ON 1/8 +# B5 nearly tied best on 4/5. Try on 1/8. 
+# ============================================ +echo "========== QK GAIN ON 1/8 ===========" >> $LOG + +# E7: 1/8 + QK gain 2.0 +run "E7_Asym18_QK2" TIDAL_LR=1 LOGIT_SOFTCAP=20.0 ROPE_BASE=5000 PARALLEL_BLOCK=1 MLP_ACT=silu2 HEAD_DIVERSITY=1e-4 EMBED_LR=0.8 MATRIX_LR=0.08 ENCODER_LAYERS=1 QK_GAIN_INIT=2.0 + +# ============================================ +# PHASE 5: STACK ALL WINS ON 1/8 +# Combine best softcap + best LR + best QK +# ============================================ +echo "========== STACK BEST ON 1/8 ===========" >> $LOG + +# E8: 1/8 + SC15 + MatLR0.10 (if both help individually) +run "E8_Asym18_SC15_MatLR10" TIDAL_LR=1 LOGIT_SOFTCAP=15.0 ROPE_BASE=5000 PARALLEL_BLOCK=1 MLP_ACT=silu2 HEAD_DIVERSITY=1e-4 EMBED_LR=0.8 MATRIX_LR=0.10 ENCODER_LAYERS=1 + +# E9: 1/8 + SC15 + QK2.0 +run "E9_Asym18_SC15_QK2" TIDAL_LR=1 LOGIT_SOFTCAP=15.0 ROPE_BASE=5000 PARALLEL_BLOCK=1 MLP_ACT=silu2 HEAD_DIVERSITY=1e-4 EMBED_LR=0.8 MATRIX_LR=0.08 ENCODER_LAYERS=1 QK_GAIN_INIT=2.0 + +# E10: Rerun D6 (1/8 base config) to confirm 1.5377 +run "E10_D6_Rerun" TIDAL_LR=1 LOGIT_SOFTCAP=20.0 ROPE_BASE=5000 PARALLEL_BLOCK=1 MLP_ACT=silu2 HEAD_DIVERSITY=1e-4 EMBED_LR=0.8 MATRIX_LR=0.08 ENCODER_LAYERS=1 + +echo "" >> $LOG +echo "=== WAVE 9 COMPLETE $(date) ===" >> $LOG +cat $LOG