# Record: Fused Triton MLP + Brotli Compression + Turbo-Muon

Continuation of PR 1019. Three new hardware/systems optimizations stacked on top of our previous best (1.1147 BPB), plus the multi-shard data pipeline already validated in our stack.

## Changes vs PR 1019

### 1. Fused Triton MLP Kernel (forward-only)
Fuses `F.linear(x, up_w) -> LeakyReLU(0.5) -> square` into a single Triton TMA kernel, so the 302 MB intermediate activation per layer never touches HBM. The backward pass stays as explicit cuBLAS matmuls — this avoids the Inductor bypass issue we identified in PR 670, where fusing forward+backward via `torch.autograd.Function` was a net 2.7x slowdown because the backward ran in eager mode.

- Builds on our kernel profiling work in PR 670 (abaybektursun); key insight: fuse only the forward MLP up-projection, keep the backward as explicit cuBLAS matmuls
- Validated: -8ms/step on 2xH100 (-2.5%); projects to ~-17ms on 8xH100
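The fused epilogue computes `square(leaky_relu(x @ up_w.T, 0.5))` in one pass. A minimal NumPy sketch of that pointwise activation and its analytic gradient — the term the explicit cuBLAS backward matmuls consume. Names and shapes are illustrative; this is not the Triton kernel itself:

```python
import numpy as np

def leaky_relu_sq(z, slope=0.5):
    # Forward epilogue fused into the Triton kernel: LeakyReLU(0.5), then square.
    a = np.where(z > 0, z, slope * z)
    return a * a

def leaky_relu_sq_grad(z, slope=0.5):
    # d/dz [leaky_relu(z)^2] = 2 * leaky_relu(z) * leaky_relu'(z).
    # The backward applies this pointwise factor, then two explicit cuBLAS
    # matmuls produce the grads w.r.t. x and up_w.
    a = np.where(z > 0, z, slope * z)
    da = np.where(z > 0, 1.0, slope)
    return 2.0 * a * da
```

The point of the fusion is that this epilogue runs inside the up-projection matmul's tile loop, so the pre-activation tensor never round-trips through HBM.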

### 2. Brotli-11 Compression (replaces LZMA-9)
Drop-in replacement. Brotli quality=11 saves 581 KB (-5.9%) vs LZMA preset=9 on the int6-quantized weights. Byte-shuffle was tested and provided no additional benefit.

- Independently discovered; also used in PR 1089 (mikeapedia)
- Frees headroom for more BigramHash buckets or mixed bit allocation
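The swap itself is a one-line codec change. A hedged sketch of the comparison on a compressible stand-in payload — the real input is the packed int6 weight blob, and `brotli` here is the pip package, guarded in case it is absent:

```python
import lzma

# Stand-in for the packed int6 artifact: a periodic byte pattern.
# Real weights compress differently; this only demonstrates the API swap.
payload = bytes((i * 7) % 64 for i in range(100_000))

lzma_blob = lzma.compress(payload, preset=9 | lzma.PRESET_EXTREME)

try:
    import brotli  # pip install brotli
    brotli_blob = brotli.compress(payload, quality=11)
    print(f"raw={len(payload)}  lzma9e={len(lzma_blob)}  brotli11={len(brotli_blob)}")
except ImportError:
    print(f"raw={len(payload)}  lzma9e={len(lzma_blob)}  (brotli not installed)")
```

Decompression on the load path is symmetric (`brotli.decompress` in place of `lzma.decompress`), which is why it drops in with no format changes elsewhere.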

### 3. Turbo-Muon (AOL + Polar Express + NS4)
Replaces the standard 5-iteration Newton-Schulz with an AOL-preconditioned 4-iteration variant using the Polar Express per-iteration optimal coefficients (Amsel et al., arXiv:2505.16932), plus post-NS row/column L2 normalization for stability.

- From PR 1089 (mikeapedia)
- Drop-in replacement for `zeropower_via_newtonschulz5()` inside the existing Parallel Muon
- Neutral throughput on 2xH100 (the AOL cost roughly equals the one NS iteration saved), but better convergence quality (-0.044 nats train loss at step 500)
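For context, what `zeropower_via_newtonschulz5()` approximates is the polar factor (orthogonalization) of the gradient matrix. A textbook cubic Newton-Schulz sketch in NumPy — deliberately not the PR's quintic/AOL/Polar Express variant, whose coefficients and step count differ:

```python
import numpy as np

def newton_schulz_polar(G, steps=15):
    # Iterates X <- 1.5 X - 0.5 X X^T X toward the nearest orthogonal matrix
    # (polar factor) of G. Muon's production version is a quintic 5-step
    # iteration with tuned coefficients; Turbo-Muon cuts it to 4 steps via
    # AOL preconditioning plus per-iteration Polar Express coefficients.
    X = G / (np.linalg.norm(G, ord=2) + 1e-7)  # spectral norm <= 1 so the iteration converges
    for _ in range(steps):
        X = 1.5 * X - 0.5 * (X @ X.T @ X)
    return X
```

The cubic map drives every singular value toward 1 while leaving the singular vectors fixed; fewer, better-chosen iterations (the PR's variant) buy the same orthogonality at lower cost.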

### 4. Memmap Multi-Shard Data Pipeline + GPU Prefetch
Coprime-stride sampling across multiple shards, with daemon-thread CPU batch building and CUDA-stream prefetch. Gives better data diversity per batch.

- From PR 726 (DeepReinforce)
- Already validated in our stack
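A pure-Python sketch of the two ideas: a coprime-stride visit order over a shard (a cheap deterministic shuffle) and a daemon-thread batch queue. The CUDA-stream host-to-device copy on the consumer side is elided, and all names are illustrative:

```python
import math
import queue
import threading

def coprime_stride_order(n_samples, seed=1):
    # Visit every index exactly once: any stride coprime to n_samples
    # generates a full cycle of Z/n, i.e. a permutation.
    stride = seed
    while math.gcd(stride, n_samples) != 1:
        stride += 1
    return [(i * stride) % n_samples for i in range(n_samples)]

def prefetching_batches(samples, batch_size, depth=2):
    # Daemon thread builds batches on CPU while the consumer trains.
    # In the real pipeline the consumer overlaps an async copy to GPU
    # on a side CUDA stream; here we just hand batches over a queue.
    q = queue.Queue(maxsize=depth)

    def worker():
        for start in range(0, len(samples), batch_size):
            q.put(samples[start:start + batch_size])
        q.put(None)  # end-of-epoch sentinel

    threading.Thread(target=worker, daemon=True).start()
    while (batch := q.get()) is not None:
        yield batch
```

With one such order per shard and different strides, consecutive batches mix samples from across the corpus instead of reading one shard sequentially.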

## Architecture (unchanged from PR 1019)
- 11L/512d, 8 heads, 4 KV heads, GQA
- LeakyReLU(0.5)^2, MLP 3x (1536)
- XSA on all 11 layers
- BigramHash 3072 x dim=112
- Partial RoPE 16/64, LN 1/sqrt(layer+1)
- VE128 layers 9-10, SmearGate, U-Net skips
- EMA(0.997) + SWA(every 50), Late QAT (STE at scale<0.15)
- Full Hessian GPTQ int6 (AR self-gen calibration)
- Sliding window eval stride=64
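For reference, the list above collected into a config sketch. Field names are illustrative, not the actual `train_gpt.py` identifiers:

```python
from dataclasses import dataclass

@dataclass
class ModelConfig:
    # Values from the PR 1019 architecture summary; names are hypothetical.
    n_layer: int = 11
    d_model: int = 512
    n_head: int = 8
    n_kv_head: int = 4          # GQA
    d_mlp: int = 1536           # 3x expansion
    leaky_slope: float = 0.5    # LeakyReLU(0.5)^2 activation
    bigram_vocab: int = 3072
    bigram_dim: int = 112
    rope_dims: int = 16         # partial RoPE: 16 of 64 head dims rotated
    ema_decay: float = 0.997
    swa_every: int = 50
    eval_window_stride: int = 64
```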

## 2xH100 Validation Results (PENDING 8xH100)

| Metric | PR 1019 baseline | This work | Delta |
|---|---|---|---|
| Step avg (2xH100) | 327.51ms | ~319ms | -8ms (-2.5%) |
| Artifact size | 10.09 MB | ~9.5 MB (Brotli) | -0.6 MB |

8xH100 results pending.

## Run command
```bash
BIGRAM_VOCAB_SIZE=3072 BIGRAM_DIM=112 WARMDOWN_ITERS=4000 SEED=314 \
torchrun --standalone --nproc_per_node=8 train_gpt.py
```

## Requirements
```bash
pip install torch==2.9.1 --index-url https://download.pytorch.org/whl/cu128
pip install flash_attn_3 --find-links https://windreamer.github.io/flash-attention3-wheels/cu128_torch291
pip install sentencepiece brotli
```