
Commit ba665dd

abaybektursun and claude committed
Record: Fused Triton MLP + Brotli + Turbo-Muon + Memmap pipeline
Continuation of PR 1019. Stacks four hardware/systems optimizations:

1. Fused Triton MLP kernel (fwd-only): projected -17ms/step on 8xH100. Builds on our kernel profiling in PR 670.
2. Brotli-11 compression: -581KB (-5.9%) vs LZMA-9. Independently discovered.
3. Turbo-Muon (AOL + Polar Express NS4): better convergence quality per step.
4. Memmap multi-shard data pipeline + GPU prefetch, from PR 726.

8xH100 results pending. 2xH100 validated: -8ms/step (-2.5%), -0.012 BPB.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 630bb5e commit ba665dd

File tree: 3 files changed, +2795 -0 lines

First file changed (63 additions):
# Record: Fused Triton MLP + Brotli Compression + Turbo-Muon
2+
3+
Continuation of PR 1019. Three hardware/systems optimizations stacked on top of our previous best (1.1147 BPB).
4+
5+
## Changes vs PR 1019
6+
7+
### 1. Fused Triton MLP Kernel (forward-only)
8+
Fuses `F.linear(x, up_w) -> LeakyReLU(0.5) -> square` into a single Triton TMA kernel. The 302MB intermediate activation per layer never touches HBM. Backward uses explicit cuBLAS matmuls — this avoids the Inductor bypass issue we identified in our PR 670, where fusing fwd+bwd via `torch.autograd.Function` was 2.7x slower net because the backward ran in eager mode.
9+
10+
- Builds on our kernel profiling work in PR 670 (abaybektursun); key insight: fuse only the forward MLP up-projection, keep backward as explicit cuBLAS matmuls
11+
- Validated: -8ms/step on 2xH100 (-2.5%), projects to ~-17ms on 8xH100
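The Triton kernel itself is not reproduced here; as a reference for what it computes, the following is a minimal NumPy sketch of the fused forward op and of keeping the backward as explicit matmuls (function names are illustrative, not the actual kernel API):

```python
import numpy as np

def mlp_up_forward(x, up_w, slope=0.5):
    """Reference semantics of the fused op: linear -> LeakyReLU(0.5) -> square.
    In the real kernel this runs as one Triton TMA pass so the intermediate
    activation never hits HBM; here it is plain NumPy for clarity."""
    z = x @ up_w.T                           # F.linear(x, up_w)
    a = np.where(z > 0, z, slope * z)        # LeakyReLU(0.5)
    return a * a, (z, a)                     # squared activation, plus saved tensors

def mlp_up_backward(grad_out, x, up_w, saved, slope=0.5):
    """Backward kept as explicit matmuls (cuBLAS in the real code), avoiding
    the eager-mode autograd.Function slowdown noted above."""
    z, a = saved
    # chain rule through square then LeakyReLU
    dz = grad_out * 2.0 * a * np.where(z > 0, 1.0, slope)
    grad_x = dz @ up_w                       # dL/dx
    grad_w = dz.T @ x                        # dL/d(up_w)
    return grad_x, grad_w
```

The point of the split is visible in the structure: the forward saves only what the two backward matmuls need, and nothing in the backward depends on the fused forward having materialized its intermediate.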
### 2. Brotli-11 Compression (replaces LZMA-9)

Drop-in replacement. Brotli at quality=11 saves 581 KB (-5.9%) vs LZMA at preset=9 on the int6-quantized weights. Byte-shuffle was tested and found to provide no additional benefit.

- Independently discovered; also used in PR 1089 (mikeapedia)
- Frees headroom for more BigramHash buckets or mixed bit allocation
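A sketch of the swap, assuming a synthetic payload in place of the real int6-quantized weight blob (the `brotli` package is third-party, hence the guarded import):

```python
import lzma

# Hypothetical stand-in for the quantized weight blob; the real input is the
# int6 GPTQ-quantized checkpoint described in this record.
payload = bytes(range(256)) * 4096  # ~1 MB of moderately compressible data

lzma_blob = lzma.compress(payload, preset=9)

try:
    import brotli  # third-party: pip install brotli
    brotli_blob = brotli.compress(payload, quality=11)
    print(f"lzma-9: {len(lzma_blob)} B, brotli-11: {len(brotli_blob)} B")
except ImportError:
    print(f"lzma-9: {len(lzma_blob)} B (brotli not installed)")
```

Which codec wins depends on the byte distribution of the payload; the 581 KB figure above is specific to the record's quantized weights.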
### 3. Turbo-Muon (AOL + Polar Express + NS4)

Replaces the standard 5-iteration Newton-Schulz orthogonalization with an AOL-preconditioned 4-iteration variant using the Polar Express per-iteration optimal coefficients (Amsel et al., arXiv:2505.16932), plus post-NS row/col L2 normalization for stability.

- From PR 1089 (mikeapedia)
- Drop-in replacement for `zeropower_via_newtonschulz5()` inside the existing Parallel Muon
- Neutral throughput on 2xH100 (AOL cost ~= 1 NS iteration saved), but better convergence quality (-0.044 nats train loss at step 500)
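For reference, what `zeropower_via_newtonschulz5()` approximates is the orthogonal polar factor of the gradient matrix. The sketch below uses the plain cubic Newton-Schulz iteration, not Muon's tuned quintic coefficients or the Polar Express/AOL variants, which trade per-iteration accuracy for fewer iterations:

```python
import numpy as np

def newton_schulz_orth(G, steps=30):
    """Plain cubic Newton-Schulz iteration toward the polar factor of G:
    X <- 1.5*X - 0.5*(X X^T X). Muon's zeropower_via_newtonschulz5() uses a
    tuned quintic with 5 iterations; Turbo-Muon adds AOL preconditioning and
    Polar Express per-iteration coefficients to get away with 4."""
    X = G / (np.linalg.norm(G, 2) + 1e-7)  # spectral-norm normalize so all
    for _ in range(steps):                 # singular values start in (0, 1]
        X = 1.5 * X - 0.5 * (X @ X.T @ X)
    return X

rng = np.random.default_rng(0)
G = rng.standard_normal((8, 8))
X = newton_schulz_orth(G)
print(np.max(np.abs(X @ X.T - np.eye(8))))  # near zero: X is ~orthogonal
```

Each iteration pushes every singular value toward 1 while leaving the singular vectors untouched, which is exactly the "zeropower" normalization Muon applies to its momentum update.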
### 4. Memmap Multi-Shard Data Pipeline + GPU Prefetch

Coprime-stride sampling across multiple shards, with a daemon-thread CPU batch builder and CUDA-stream prefetch. Gives better data diversity per batch.

- From PR 726 (DeepReinforce)
- Already validated in our stack
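A minimal sketch of the coprime-stride sampling idea, with illustrative names and toy in-memory shards standing in for the actual PR 726 memmap pipeline:

```python
import numpy as np
from math import gcd

def coprime_stride_indices(n_samples, shard_len, stride):
    """When gcd(stride, shard_len) == 1, i*stride mod shard_len visits every
    position exactly once per cycle: full coverage with scattered order."""
    assert gcd(stride, shard_len) == 1, "stride must be coprime with shard length"
    return (np.arange(n_samples) * stride) % shard_len

# Toy "shards": in the real pipeline these are np.memmap views over token
# files, filled into pinned CPU buffers by a daemon thread and copied to the
# GPU ahead of time on a side CUDA stream.
shards = [np.arange(100), np.arange(100, 200)]
idx = coprime_stride_indices(10, len(shards[0]), stride=37)
batch = np.stack([shard[idx] for shard in shards])
print(idx[:5])  # scattered positions, no repeats within a full cycle
```

The coprimality condition is what buys diversity: consecutive samples land far apart in each shard while the sequence still covers the shard exactly once before repeating.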
## Architecture (unchanged from PR 1019)

- 11L/512d, 8 heads, 4 KV heads, GQA
- LeakyReLU(0.5)^2, MLP 3x (1536)
- XSA on all 11 layers
- BigramHash 3072 x dim=112
- Partial RoPE 16/64, LN 1/sqrt(layer+1)
- VE128 layers 9-10, SmearGate, U-Net skips
- EMA(0.997) + SWA(every 50), Late QAT (STE at scale<0.15)
- Full-Hessian GPTQ int6 (AR self-gen calibration)
- Sliding-window eval, stride=64
## 2xH100 Validation Results (8xH100 pending)

| Metric | PR 1019 baseline | This work | Delta |
|---|---|---|---|
| Step avg (2xH100) | 327.51 ms | ~319 ms | -8 ms (-2.5%) |
| Artifact size | 10.09 MB | ~9.5 MB (Brotli) | -0.6 MB |
## Run command

```bash
BIGRAM_VOCAB_SIZE=3072 BIGRAM_DIM=112 WARMDOWN_ITERS=4000 SEED=314 \
torchrun --standalone --nproc_per_node=8 train_gpt.py
```

## Requirements

```bash
pip install torch==2.9.1 --index-url https://download.pytorch.org/whl/cu128
pip install flash_attn_3 --find-links https://windreamer.github.io/flash-attention3-wheels/cu128_torch291
pip install sentencepiece brotli
```
Second file changed (3 additions), a requirements list:

```
# FlashAttention 3 must be installed separately; see README.md
sentencepiece
brotli>=1.1
```
