diff --git a/records/track_10min_16mb/2026-04-05_11L_LatentMaskTTT_GPTQ_ProductKey_Brotli/README.md b/records/track_10min_16mb/2026-04-05_11L_LatentMaskTTT_GPTQ_ProductKey_Brotli/README.md new file mode 100644 index 0000000000..2d7e46491f --- /dev/null +++ b/records/track_10min_16mb/2026-04-05_11L_LatentMaskTTT_GPTQ_ProductKey_Brotli/README.md @@ -0,0 +1,72 @@ +# 11L LatentMask TTT + GPTQ + Product-Key Bigram + Brotli + +**val_bpb: 1.1124 (3-seed mean)** | Artifact: ≤15,994,891 bytes | 8xH100, 600s training + ~455s eval + +## Summary + +11-layer GPT with U-Net skip connections, achieving 1.1124 val_bpb (3-seed mean) through four key innovations: + +1. **LatentMask TTT (Test-Time Training)**: Per-channel sigmoid masks + biases on MLP and attention outputs, trained per-chunk during evaluation using a sign-based Muon-lite optimizer. Score-first (legal): each chunk is scored before any mask update. Provides ~−0.002 bpb improvement over sliding window eval (without TTT). + +2. **Full Hessian GPTQ**: Hessian-aware int6 quantization with Cholesky error compensation and column reordering. Uses autoregressive self-generated calibration data (32 sequences × 2048 tokens). Reduces quantization error vs per-row percentile search. + +3. **Product-Key Bigram Embedding**: Factored bigram via `embed_prev(1024,512) * embed_cur(1024,512)` — zero hash collision, no projection layer needed. Replaces traditional hash-based bigram embedding. + +4. **Brotli-11 Compression**: Custom binary serialization (JSON header + raw tensor bytes) compressed with Brotli quality=11. Combined with uint8 log-scale quantization for per-row scales. + +## Architecture + +- 11 layers, 512 dim, 8 heads, 4 KV heads (GQA) +- MLP 3x (1536 hidden), LeakyReLU(0.5)² +- GatedAttention on even layers [0,2,4,6,8,10] +- XSA (Exclusive Self-Attention) on all 11 layers +- Value Embeddings at decoder layers [5,7,10] +- U-Net encoder-decoder skip connections +- SmearGate for adjacent token mixing +- Tied embeddings, logit softcap=30 + +## Training + +- Muon optimizer (matrix params), AdamW (scalar/embed params) +- matrix_lr=0.028, muon_wd=0.0417 +- EMA (decay=0.997), Late QAT (threshold=0.15) +- Warmdown: 3500 steps (time-based adaptive) +- ~6,555 steps in 600s, step_avg ~91.5ms (H100, Flash Attention 3) + +## Quantization & Compression + +- int6 per-row for MLP/attention weights (GPTQ when Hessian available) +- int8 per-row for embeddings +- uint8 log-scale for per-row scales (2B → 1B per scale) +- Custom binary serialization + Brotli-11 + +## Evaluation + +- LatentMask TTT: lr=0.0008, chunk=65536 tokens, epochs=4, momentum=0.9 +- Sliding window stride=64, seq_len=2048 +- TTT eval time: ~455s (H100) + +## Results (3-seed, 8xH100) + +| Seed | val_bpb | Steps | Step Avg (ms) | Artifact (bytes) | +|------|---------|-------|---------------|-------------------| +| 777 | 1.11195 | 6,555 | 91.54 | 15,985,742 | +| 999 | 1.11218 | 6,555 | 91.54 | 15,994,891 | +| 1337 | 1.11297 | 6,556 | 91.52 | 15,988,042 | +| **Mean** | **1.11237** | **6,555** | **91.53** | | + +## Dependencies + +``` +pip install flash-attn-3 brotli +``` + +Flash Attention 3 (Hopper kernels) is required. The script imports `flash_attn_interface` directly. + +(`sentencepiece`, `torch`, `numpy` assumed pre-installed) + +## Run + +```bash +torchrun --standalone --nproc_per_node=8 train_gpt.py +``` diff --git a/records/track_10min_16mb/2026-04-05_11L_LatentMaskTTT_GPTQ_ProductKey_Brotli/requirements.txt b/records/track_10min_16mb/2026-04-05_11L_LatentMaskTTT_GPTQ_ProductKey_Brotli/requirements.txt new file mode 100644 index 0000000000..3b8e1cbcc6 --- /dev/null +++ b/records/track_10min_16mb/2026-04-05_11L_LatentMaskTTT_GPTQ_ProductKey_Brotli/requirements.txt @@ -0,0 +1,2 @@ +flash-attn +brotli diff --git a/records/track_10min_16mb/2026-04-05_11L_LatentMaskTTT_GPTQ_ProductKey_Brotli/submission.json b/records/track_10min_16mb/2026-04-05_11L_LatentMaskTTT_GPTQ_ProductKey_Brotli/submission.json new file mode 100644 index 0000000000..b24e27a6b7 --- /dev/null +++ b/records/track_10min_16mb/2026-04-05_11L_LatentMaskTTT_GPTQ_ProductKey_Brotli/submission.json @@ -0,0 +1,45 @@ +{ + "author": "JungyupLee", + "github_id": "izlley", + "name": "LatentMask TTT + Full Hessian GPTQ + Product-Key Bigram + Brotli", + "blurb": "11L GPT with U-Net skips, LatentMask TTT (per-channel sigmoid masks, score-first legal), Full Hessian GPTQ (AR self-gen calib, Cholesky error compensation), Product-Key Bigram (zero collision), GatedAttention on even layers, XSA on all layers, Value Embeddings at decoder [5,7,10], Muon optimizer, EMA+SWA, Late QAT, Brotli-11 + uint8 log-scale compression.", + "date": "2026-04-05", + "track": "10min_16mb", + "val_loss": 1.87817944, + "val_bpb": 1.11236660, + "val_loss_std": 0.00073910, + "val_bpb_std": 0.00043774, + "seeds": [777, 999, 1337], + "seed_results": { + "777": { + "val_loss": 1.87746873, + "val_bpb": 1.11194568, + "artifact_bytes": 15985742, + "steps": 6555, + "step_avg_ms": 91.54 + }, + "999": { + "val_loss": 1.87787103, + "val_bpb": 1.11218395, + "artifact_bytes": 15994891, + "steps": 6555, + "step_avg_ms": 91.54 + }, + "1337": { + "val_loss": 1.87919855, + "val_bpb": 1.11297018, + "artifact_bytes": 15988042, + "steps": 6556, + "step_avg_ms": 91.52 + } + }, + "artifact_bytes_max": 15994891, + "bytes_total": 15994891, + "code_bytes": 61917, + "model_bytes": 15932974, + "train_steps_mean": 6555.3, + "step_avg_ms_mean": 91.53, + "hardware": "8xH100 80GB SXM (RunPod)", + "flash_attn_version": "3.0.0+20260303.cu128torch291cxx11abitrue.ceb109", + "technique_summary": "LatentMask TTT + Full Hessian GPTQ (AR self-gen) + Product-Key Bigram + GatedAttn-even + XSA-all + VE-decoder + Muon + Brotli-11" +} diff --git a/records/track_10min_16mb/2026-04-05_11L_LatentMaskTTT_GPTQ_ProductKey_Brotli/train_1337.log b/records/track_10min_16mb/2026-04-05_11L_LatentMaskTTT_GPTQ_ProductKey_Brotli/train_1337.log new file mode 100644 index 0000000000..167db836d3 --- /dev/null +++ b/records/track_10min_16mb/2026-04-05_11L_LatentMaskTTT_GPTQ_ProductKey_Brotli/train_1337.log @@ -0,0 +1,171 @@ +W0407 13:54:02.290000 68766 torch/distributed/run.py:803] +W0407 13:54:02.290000 68766 torch/distributed/run.py:803] ***************************************** +W0407 13:54:02.290000 68766 torch/distributed/run.py:803] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. +W0407 13:54:02.290000 68766 torch/distributed/run.py:803] ***************************************** +logs/4efa9d48-88d4-4ff4-aff3-f3fa722476d7.txt +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=/workspace/parameter-golf/parameter-golf/data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:80 +val_loader:shards pattern=/workspace/parameter-golf/parameter-golf/data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 +model_params:27739277 +XSA:last_11 active_layers:[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10] +world_size:8 grad_accum_steps:1 +attention_backend:flash_attn_3 +attention_mode:gqa num_heads:8 num_kv_heads:4 gated_attention:True +tie_embeddings:True embed_lr:0.035 head_lr:0.0 matrix_lr:0.028 scalar_lr:0.025 +train_batch_tokens:786432 train_seq_len:2048 iterations:20000 warmup_steps:20 max_wallclock_seconds:600.000 +seed:1337 +warmup_step:1/20 +warmup_step:2/20 +warmup_step:3/20 +warmup_step:4/20 +warmup_step:5/20 +warmup_step:6/20 +warmup_step:7/20 +warmup_step:8/20 +warmup_step:9/20 +warmup_step:10/20 +warmup_step:11/20 +warmup_step:12/20 +warmup_step:13/20 +warmup_step:14/20 +warmup_step:15/20 +warmup_step:16/20 +warmup_step:17/20 +warmup_step:18/20 +warmup_step:19/20 +warmup_step:20/20 +step:1/20000 train_loss:6.9313 train_time:149ms step_avg:149.01ms +step:2/20000 train_loss:8.7077 train_time:235ms step_avg:117.28ms +step:3/20000 train_loss:7.7451 train_time:324ms step_avg:108.03ms +step:4/20000 train_loss:7.1697 train_time:414ms step_avg:103.47ms +step:5/20000 train_loss:6.9473 train_time:503ms step_avg:100.66ms +step:6/20000 train_loss:6.8905 train_time:593ms step_avg:98.76ms +step:7/20000 train_loss:6.8465 train_time:682ms step_avg:97.41ms +step:8/20000 train_loss:6.7056 train_time:772ms step_avg:96.47ms +step:9/20000 train_loss:6.4039 train_time:861ms step_avg:95.61ms +step:10/20000 train_loss:6.0594 train_time:950ms step_avg:95.01ms +step:500/20000 train_loss:2.3535 train_time:45721ms step_avg:91.44ms +step:1000/20000 train_loss:2.2504 train_time:91556ms step_avg:91.56ms +step:1500/20000 train_loss:2.1935 train_time:137382ms step_avg:91.59ms +step:2000/20000 train_loss:2.0431 train_time:183157ms step_avg:91.58ms +step:2500/20000 train_loss:2.1503 train_time:228933ms step_avg:91.57ms +step:3000/20000 train_loss:2.1441 train_time:274696ms step_avg:91.57ms +step:3500/20000 train_loss:2.1507 train_time:320508ms step_avg:91.57ms +step:4000/20000 train_loss:1.9396 train_time:366267ms step_avg:91.57ms +step:4500/20000 train_loss:2.0902 train_time:411981ms step_avg:91.55ms +step:5000/20000 train_loss:2.0725 train_time:457729ms step_avg:91.55ms +step:5500/20000 train_loss:1.9867 train_time:503453ms step_avg:91.54ms +step:6000/20000 train_loss:1.9023 train_time:549179ms step_avg:91.53ms +late_qat:enabled step:6031 scale:0.1499 +step:6500/20000 train_loss:2.0427 train_time:594883ms step_avg:91.52ms +step:6556/20000 val_loss:1.9138 val_bpb:1.1335 train_time:599973ms step_avg:91.52ms +stopping_early: wallclock_cap train_time:599973ms step:6556/20000 +peak memory allocated: 22185 MiB reserved: 22230 MiB +ema:applying EMA weights +DIAGNOSTIC post_ema val_loss:1.9122 val_bpb:1.1325 eval_time:2117ms +Serialized model: 107548579 bytes +Code size: 61917 bytes +custom_serialize: raw=27888454 bytes +Serialized model int6+custom+: 15926125 bytes +Total submission size: 15988042 bytes +supermask_ttt:start chunks=947 chunk_tokens=65536 total_windows=969088 stride=64 ttt_lr=0.0008 ttt_epochs=4 +supermask_ttt:params mask_params=45056 (blocks=11 per_block=mlp_mask+mlp_bias+attn_mask+attn_bias) model_frozen=27739277 + ttt_chunk [1/947] bpb=1.159986 time=0.5s + ttt_chunk [11/947] bpb=1.126072 time=5.3s + ttt_chunk [21/947] bpb=1.097464 time=10.1s + ttt_chunk [31/947] bpb=1.108908 time=14.8s + ttt_chunk [41/947] bpb=1.102337 time=19.6s + ttt_chunk [51/947] bpb=1.113493 time=24.3s + ttt_chunk [61/947] bpb=1.108278 time=29.1s + ttt_chunk [71/947] bpb=1.108788 time=33.8s + ttt_chunk [81/947] bpb=1.115921 time=38.6s + ttt_chunk [91/947] bpb=1.118991 time=43.3s + ttt_chunk [101/947] bpb=1.120177 time=48.1s + ttt_chunk [111/947] bpb=1.119387 time=52.9s + ttt_chunk [121/947] bpb=1.118636 time=57.6s + ttt_chunk [131/947] bpb=1.114143 time=62.4s + ttt_chunk [141/947] bpb=1.115884 time=67.1s + ttt_chunk [151/947] bpb=1.117421 time=71.9s + ttt_chunk [161/947] bpb=1.121499 time=76.6s + ttt_chunk [171/947] bpb=1.121852 time=81.3s + ttt_chunk [181/947] bpb=1.125681 time=86.1s + ttt_chunk [191/947] bpb=1.126671 time=90.8s + ttt_chunk [201/947] bpb=1.124873 time=95.5s + ttt_chunk [211/947] bpb=1.122965 time=100.3s + ttt_chunk [221/947] bpb=1.122447 time=105.0s + ttt_chunk [231/947] bpb=1.122880 time=109.7s + ttt_chunk [241/947] bpb=1.124736 time=114.5s + ttt_chunk [251/947] bpb=1.123316 time=119.2s + ttt_chunk [261/947] bpb=1.121324 time=123.9s + ttt_chunk [271/947] bpb=1.122645 time=128.7s + ttt_chunk [281/947] bpb=1.120706 time=133.4s + ttt_chunk [291/947] bpb=1.120303 time=138.2s + ttt_chunk [301/947] bpb=1.120288 time=142.9s + ttt_chunk [311/947] bpb=1.121292 time=147.6s + ttt_chunk [321/947] bpb=1.122122 time=152.4s + ttt_chunk [331/947] bpb=1.121531 time=157.1s + ttt_chunk [341/947] bpb=1.122221 time=161.8s + ttt_chunk [351/947] bpb=1.122854 time=166.6s + ttt_chunk [361/947] bpb=1.123091 time=171.3s + ttt_chunk [371/947] bpb=1.123278 time=176.1s + ttt_chunk [381/947] bpb=1.122532 time=180.8s + ttt_chunk [391/947] bpb=1.121911 time=185.6s + ttt_chunk [401/947] bpb=1.121561 time=190.3s + ttt_chunk [411/947] bpb=1.120589 time=195.1s + ttt_chunk [421/947] bpb=1.119598 time=199.8s + ttt_chunk [431/947] bpb=1.120465 time=204.6s + ttt_chunk [441/947] bpb=1.119705 time=209.3s + ttt_chunk [451/947] bpb=1.119311 time=214.1s + ttt_chunk [461/947] bpb=1.119551 time=218.8s + ttt_chunk [471/947] bpb=1.118941 time=223.5s + ttt_chunk [481/947] bpb=1.118406 time=228.3s + ttt_chunk [491/947] bpb=1.119079 time=233.0s + ttt_chunk [501/947] bpb=1.118907 time=237.7s + ttt_chunk [511/947] bpb=1.119250 time=242.5s + ttt_chunk [521/947] bpb=1.119526 time=247.3s + ttt_chunk [531/947] bpb=1.119334 time=252.0s + ttt_chunk [541/947] bpb=1.118905 time=256.7s + ttt_chunk [551/947] bpb=1.119547 time=261.5s + ttt_chunk [561/947] bpb=1.118974 time=266.2s + ttt_chunk [571/947] bpb=1.119412 time=271.0s + ttt_chunk [581/947] bpb=1.118122 time=275.7s + ttt_chunk [591/947] bpb=1.118695 time=280.4s + ttt_chunk [601/947] bpb=1.119685 time=285.2s + ttt_chunk [611/947] bpb=1.118961 time=289.9s + ttt_chunk [621/947] bpb=1.119009 time=294.6s + ttt_chunk [631/947] bpb=1.119142 time=299.3s + ttt_chunk [641/947] bpb=1.118264 time=304.1s + ttt_chunk [651/947] bpb=1.117396 time=308.8s + ttt_chunk [661/947] bpb=1.116765 time=313.6s + ttt_chunk [671/947] bpb=1.116670 time=318.3s + ttt_chunk [681/947] bpb=1.117126 time=323.0s + ttt_chunk [691/947] bpb=1.116652 time=327.8s + ttt_chunk [701/947] bpb=1.115895 time=332.5s + ttt_chunk [711/947] bpb=1.115840 time=337.3s + ttt_chunk [721/947] bpb=1.116325 time=342.0s + ttt_chunk [731/947] bpb=1.115883 time=346.7s + ttt_chunk [741/947] bpb=1.117077 time=351.5s + ttt_chunk [751/947] bpb=1.117106 time=356.2s + ttt_chunk [761/947] bpb=1.117370 time=360.9s + ttt_chunk [771/947] bpb=1.117251 time=365.7s + ttt_chunk [781/947] bpb=1.117942 time=370.4s + ttt_chunk [791/947] bpb=1.118484 time=375.2s + ttt_chunk [801/947] bpb=1.118591 time=379.9s + ttt_chunk [811/947] bpb=1.118499 time=384.6s + ttt_chunk [821/947] bpb=1.118419 time=389.4s + ttt_chunk [831/947] bpb=1.118590 time=394.1s + ttt_chunk [841/947] bpb=1.119483 time=398.8s + ttt_chunk [851/947] bpb=1.119201 time=403.6s + ttt_chunk [861/947] bpb=1.118602 time=408.3s + ttt_chunk [871/947] bpb=1.118322 time=413.1s + ttt_chunk [881/947] bpb=1.118486 time=417.8s + ttt_chunk [891/947] bpb=1.118402 time=422.6s + ttt_chunk [901/947] bpb=1.118221 time=427.3s + ttt_chunk [911/947] bpb=1.117833 time=432.1s + ttt_chunk [921/947] bpb=1.117337 time=436.8s + ttt_chunk [931/947] bpb=1.116661 time=441.6s + ttt_chunk [941/947] bpb=1.116105 time=446.3s + ttt_chunk [947/947] bpb=1.115820 time=448.8s +supermask_ttt:done val_loss=1.879199 val_bpb=1.112970 elapsed=448.8s +legal_ttt val_loss:1.8792 val_bpb:1.1130 eval_time:449323ms +legal_ttt_exact val_loss:1.87919855 val_bpb:1.11297018 diff --git a/records/track_10min_16mb/2026-04-05_11L_LatentMaskTTT_GPTQ_ProductKey_Brotli/train_777.log b/records/track_10min_16mb/2026-04-05_11L_LatentMaskTTT_GPTQ_ProductKey_Brotli/train_777.log new file mode 100644 index 0000000000..d86d1318cc --- /dev/null +++ b/records/track_10min_16mb/2026-04-05_11L_LatentMaskTTT_GPTQ_ProductKey_Brotli/train_777.log @@ -0,0 +1,171 @@ +W0407 13:05:35.743000 2263 torch/distributed/run.py:803] +W0407 13:05:35.743000 2263 torch/distributed/run.py:803] ***************************************** +W0407 13:05:35.743000 2263 torch/distributed/run.py:803] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. +W0407 13:05:35.743000 2263 torch/distributed/run.py:803] ***************************************** +logs/ecad83be-5825-41b3-af20-ea8c7a382f9b.txt +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=/workspace/parameter-golf/parameter-golf/data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:80 +val_loader:shards pattern=/workspace/parameter-golf/parameter-golf/data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 +model_params:27739277 +XSA:last_11 active_layers:[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10] +world_size:8 grad_accum_steps:1 +attention_backend:flash_attn_3 +attention_mode:gqa num_heads:8 num_kv_heads:4 gated_attention:True +tie_embeddings:True embed_lr:0.035 head_lr:0.0 matrix_lr:0.028 scalar_lr:0.025 +train_batch_tokens:786432 train_seq_len:2048 iterations:20000 warmup_steps:20 max_wallclock_seconds:600.000 +seed:777 +warmup_step:1/20 +warmup_step:2/20 +warmup_step:3/20 +warmup_step:4/20 +warmup_step:5/20 +warmup_step:6/20 +warmup_step:7/20 +warmup_step:8/20 +warmup_step:9/20 +warmup_step:10/20 +warmup_step:11/20 +warmup_step:12/20 +warmup_step:13/20 +warmup_step:14/20 +warmup_step:15/20 +warmup_step:16/20 +warmup_step:17/20 +warmup_step:18/20 +warmup_step:19/20 +warmup_step:20/20 +step:1/20000 train_loss:6.9314 train_time:150ms step_avg:150.27ms +step:2/20000 train_loss:8.6618 train_time:237ms step_avg:118.69ms +step:3/20000 train_loss:7.7141 train_time:327ms step_avg:108.96ms +step:4/20000 train_loss:7.2976 train_time:416ms step_avg:103.92ms +step:5/20000 train_loss:7.0479 train_time:505ms step_avg:101.00ms +step:6/20000 train_loss:6.8727 train_time:594ms step_avg:99.08ms +step:7/20000 train_loss:6.7676 train_time:684ms step_avg:97.69ms +step:8/20000 train_loss:6.6761 train_time:773ms step_avg:96.61ms +step:9/20000 train_loss:6.3446 train_time:862ms step_avg:95.76ms +step:10/20000 train_loss:5.9990 train_time:951ms step_avg:95.13ms +step:500/20000 train_loss:2.3570 train_time:46080ms step_avg:92.16ms +step:1000/20000 train_loss:2.2413 train_time:91879ms step_avg:91.88ms +step:1500/20000 train_loss:2.1902 train_time:137650ms step_avg:91.77ms +step:2000/20000 train_loss:2.0397 train_time:183498ms step_avg:91.75ms +step:2500/20000 train_loss:2.1474 train_time:229258ms step_avg:91.70ms +step:3000/20000 train_loss:2.1389 train_time:275011ms step_avg:91.67ms +step:3500/20000 train_loss:2.1457 train_time:320739ms step_avg:91.64ms +step:4000/20000 train_loss:1.9431 train_time:366461ms step_avg:91.62ms +step:4500/20000 train_loss:2.0923 train_time:412179ms step_avg:91.60ms +step:5000/20000 train_loss:2.0687 train_time:457911ms step_avg:91.58ms +step:5500/20000 train_loss:1.9832 train_time:503627ms step_avg:91.57ms +step:6000/20000 train_loss:1.9042 train_time:549334ms step_avg:91.56ms +late_qat:enabled step:6029 scale:0.1499 +step:6500/20000 train_loss:2.0409 train_time:595024ms step_avg:91.54ms +step:6555/20000 val_loss:1.9122 val_bpb:1.1325 train_time:600028ms step_avg:91.54ms +stopping_early: wallclock_cap train_time:600028ms step:6555/20000 +peak memory allocated: 22187 MiB reserved: 22284 MiB +ema:applying EMA weights +DIAGNOSTIC post_ema val_loss:1.9106 val_bpb:1.1315 eval_time:2108ms +Serialized model: 107548579 bytes +Code size: 61917 bytes +custom_serialize: raw=27888454 bytes +Serialized model int6+custom+: 15923825 bytes +Total submission size: 15985742 bytes +supermask_ttt:start chunks=947 chunk_tokens=65536 total_windows=969088 stride=64 ttt_lr=0.0008 ttt_epochs=4 +supermask_ttt:params mask_params=45056 (blocks=11 per_block=mlp_mask+mlp_bias+attn_mask+attn_bias) model_frozen=27739277 + ttt_chunk [1/947] bpb=1.161563 time=0.5s + ttt_chunk [11/947] bpb=1.124145 time=5.3s + ttt_chunk [21/947] bpb=1.096537 time=10.2s + ttt_chunk [31/947] bpb=1.108618 time=15.0s + ttt_chunk [41/947] bpb=1.101933 time=19.8s + ttt_chunk [51/947] bpb=1.113039 time=24.7s + ttt_chunk [61/947] bpb=1.107444 time=29.4s + ttt_chunk [71/947] bpb=1.107968 time=34.2s + ttt_chunk [81/947] bpb=1.115269 time=38.9s + ttt_chunk [91/947] bpb=1.118312 time=43.6s + ttt_chunk [101/947] bpb=1.119406 time=48.4s + ttt_chunk [111/947] bpb=1.118601 time=53.1s + ttt_chunk [121/947] bpb=1.117776 time=57.9s + ttt_chunk [131/947] bpb=1.113267 time=62.6s + ttt_chunk [141/947] bpb=1.114955 time=67.4s + ttt_chunk [151/947] bpb=1.116440 time=72.2s + ttt_chunk [161/947] bpb=1.120459 time=77.0s + ttt_chunk [171/947] bpb=1.120754 time=81.8s + ttt_chunk [181/947] bpb=1.124719 time=86.7s + ttt_chunk [191/947] bpb=1.125787 time=91.5s + ttt_chunk [201/947] bpb=1.124007 time=96.3s + ttt_chunk [211/947] bpb=1.122106 time=101.1s + ttt_chunk [221/947] bpb=1.121585 time=105.9s + ttt_chunk [231/947] bpb=1.122089 time=110.7s + ttt_chunk [241/947] bpb=1.123954 time=115.6s + ttt_chunk [251/947] bpb=1.122508 time=120.4s + ttt_chunk [261/947] bpb=1.120641 time=125.2s + ttt_chunk [271/947] bpb=1.121913 time=130.0s + ttt_chunk [281/947] bpb=1.120007 time=134.8s + ttt_chunk [291/947] bpb=1.119600 time=139.6s + ttt_chunk [301/947] bpb=1.119593 time=144.4s + ttt_chunk [311/947] bpb=1.120564 time=149.2s + ttt_chunk [321/947] bpb=1.121390 time=154.0s + ttt_chunk [331/947] bpb=1.120859 time=158.8s + ttt_chunk [341/947] bpb=1.121454 time=163.7s + ttt_chunk [351/947] bpb=1.122070 time=168.5s + ttt_chunk [361/947] bpb=1.122238 time=173.3s + ttt_chunk [371/947] bpb=1.122375 time=178.1s + ttt_chunk [381/947] bpb=1.121707 time=183.0s + ttt_chunk [391/947] bpb=1.121054 time=187.8s + ttt_chunk [401/947] bpb=1.120729 time=192.6s + ttt_chunk [411/947] bpb=1.119810 time=197.4s + ttt_chunk [421/947] bpb=1.118861 time=202.2s + ttt_chunk [431/947] bpb=1.119671 time=207.1s + ttt_chunk [441/947] bpb=1.118928 time=211.9s + ttt_chunk [451/947] bpb=1.118475 time=216.7s + ttt_chunk [461/947] bpb=1.118745 time=221.5s + ttt_chunk [471/947] bpb=1.118205 time=226.4s + ttt_chunk [481/947] bpb=1.117716 time=231.2s + ttt_chunk [491/947] bpb=1.118402 time=236.0s + ttt_chunk [501/947] bpb=1.118213 time=240.8s + ttt_chunk [511/947] bpb=1.118543 time=245.7s + ttt_chunk [521/947] bpb=1.118793 time=250.5s + ttt_chunk [531/947] bpb=1.118638 time=255.3s + ttt_chunk [541/947] bpb=1.118206 time=260.1s + ttt_chunk [551/947] bpb=1.118806 time=264.9s + ttt_chunk [561/947] bpb=1.118247 time=269.7s + ttt_chunk [571/947] bpb=1.118638 time=274.6s + ttt_chunk [581/947] bpb=1.117345 time=279.4s + ttt_chunk [591/947] bpb=1.117903 time=284.2s + ttt_chunk [601/947] bpb=1.118879 time=289.0s + ttt_chunk [611/947] bpb=1.118156 time=293.9s + ttt_chunk [621/947] bpb=1.118238 time=298.7s + ttt_chunk [631/947] bpb=1.118326 time=303.5s + ttt_chunk [641/947] bpb=1.117463 time=308.3s + ttt_chunk [651/947] bpb=1.116603 time=313.1s + ttt_chunk [661/947] bpb=1.115978 time=317.9s + ttt_chunk [671/947] bpb=1.115895 time=322.7s + ttt_chunk [681/947] bpb=1.116368 time=327.5s + ttt_chunk [691/947] bpb=1.115889 time=332.3s + ttt_chunk [701/947] bpb=1.115125 time=337.1s + ttt_chunk [711/947] bpb=1.115060 time=341.9s + ttt_chunk [721/947] bpb=1.115545 time=346.8s + ttt_chunk [731/947] bpb=1.115121 time=351.6s + ttt_chunk [741/947] bpb=1.116324 time=356.4s + ttt_chunk [751/947] bpb=1.116353 time=361.2s + ttt_chunk [761/947] bpb=1.116609 time=366.1s + ttt_chunk [771/947] bpb=1.116482 time=370.9s + ttt_chunk [781/947] bpb=1.117179 time=375.7s + ttt_chunk [791/947] bpb=1.117690 time=380.5s + ttt_chunk [801/947] bpb=1.117796 time=385.4s + ttt_chunk [811/947] bpb=1.117691 time=390.2s + ttt_chunk [821/947] bpb=1.117622 time=395.0s + ttt_chunk [831/947] bpb=1.117797 time=399.8s + ttt_chunk [841/947] bpb=1.118673 time=404.6s + ttt_chunk [851/947] bpb=1.118387 time=409.4s + ttt_chunk [861/947] bpb=1.117767 time=414.2s + ttt_chunk [871/947] bpb=1.117492 time=419.0s + ttt_chunk [881/947] bpb=1.117660 time=423.8s + ttt_chunk [891/947] bpb=1.117556 time=428.7s + ttt_chunk [901/947] bpb=1.117377 time=433.5s + ttt_chunk [911/947] bpb=1.116984 time=438.3s + ttt_chunk [921/947] bpb=1.116482 time=443.1s + ttt_chunk [931/947] bpb=1.115801 time=447.9s + ttt_chunk [941/947] bpb=1.115244 time=452.7s + ttt_chunk [947/947] bpb=1.114969 time=455.2s +supermask_ttt:done val_loss=1.877469 val_bpb=1.111946 elapsed=455.2s +legal_ttt val_loss:1.8775 val_bpb:1.1119 eval_time:455769ms +legal_ttt_exact val_loss:1.87746873 val_bpb:1.11194568 diff --git a/records/track_10min_16mb/2026-04-05_11L_LatentMaskTTT_GPTQ_ProductKey_Brotli/train_999.log b/records/track_10min_16mb/2026-04-05_11L_LatentMaskTTT_GPTQ_ProductKey_Brotli/train_999.log new file mode 100644 index 0000000000..f7194f9234 --- /dev/null +++ b/records/track_10min_16mb/2026-04-05_11L_LatentMaskTTT_GPTQ_ProductKey_Brotli/train_999.log @@ -0,0 +1,171 @@ +W0407 13:30:42.478000 67690 torch/distributed/run.py:803] +W0407 13:30:42.478000 67690 torch/distributed/run.py:803] ***************************************** +W0407 13:30:42.478000 67690 torch/distributed/run.py:803] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. +W0407 13:30:42.478000 67690 torch/distributed/run.py:803] ***************************************** +logs/f0eec0c9-fdbb-4ff3-a3ea-855f51bc136a.txt +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=/workspace/parameter-golf/parameter-golf/data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:80 +val_loader:shards pattern=/workspace/parameter-golf/parameter-golf/data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 +model_params:27739277 +XSA:last_11 active_layers:[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10] +world_size:8 grad_accum_steps:1 +attention_backend:flash_attn_3 +attention_mode:gqa num_heads:8 num_kv_heads:4 gated_attention:True +tie_embeddings:True embed_lr:0.035 head_lr:0.0 matrix_lr:0.028 scalar_lr:0.025 +train_batch_tokens:786432 train_seq_len:2048 iterations:20000 warmup_steps:20 max_wallclock_seconds:600.000 +seed:999 +warmup_step:1/20 +warmup_step:2/20 +warmup_step:3/20 +warmup_step:4/20 +warmup_step:5/20 +warmup_step:6/20 +warmup_step:7/20 +warmup_step:8/20 +warmup_step:9/20 +warmup_step:10/20 +warmup_step:11/20 +warmup_step:12/20 +warmup_step:13/20 +warmup_step:14/20 +warmup_step:15/20 +warmup_step:16/20 +warmup_step:17/20 +warmup_step:18/20 +warmup_step:19/20 +warmup_step:20/20 +step:1/20000 train_loss:6.9292 train_time:185ms step_avg:184.64ms +step:2/20000 train_loss:8.5453 train_time:270ms step_avg:135.17ms +step:3/20000 train_loss:7.6600 train_time:360ms step_avg:119.95ms +step:4/20000 train_loss:7.2892 train_time:450ms step_avg:112.38ms +step:5/20000 train_loss:6.9869 train_time:539ms step_avg:107.80ms +step:6/20000 train_loss:6.9148 train_time:628ms step_avg:104.69ms +step:7/20000 train_loss:6.8551 train_time:717ms step_avg:102.45ms +step:8/20000 train_loss:6.7030 train_time:806ms step_avg:100.79ms +step:9/20000 train_loss:6.4012 train_time:895ms step_avg:99.49ms +step:10/20000 train_loss:6.0383 train_time:985ms step_avg:98.46ms +step:500/20000 train_loss:2.3530 train_time:45721ms step_avg:91.44ms +step:1000/20000 train_loss:2.2415 train_time:91664ms step_avg:91.66ms +step:1500/20000 train_loss:2.1917 train_time:137531ms step_avg:91.69ms +step:2000/20000 train_loss:2.0374 train_time:183367ms step_avg:91.68ms +step:2500/20000 train_loss:2.1445 train_time:229143ms step_avg:91.66ms +step:3000/20000 train_loss:2.1392 train_time:274903ms step_avg:91.63ms +step:3500/20000 train_loss:2.1483 train_time:320637ms step_avg:91.61ms +step:4000/20000 train_loss:1.9419 train_time:366395ms step_avg:91.60ms +step:4500/20000 train_loss:2.0933 train_time:412125ms step_avg:91.58ms +step:5000/20000 train_loss:2.0694 train_time:457841ms step_avg:91.57ms +step:5500/20000 train_loss:1.9807 train_time:503552ms step_avg:91.55ms +step:6000/20000 train_loss:1.9050 train_time:549326ms step_avg:91.55ms +late_qat:enabled step:6029 scale:0.1499 +step:6500/20000 train_loss:2.0410 train_time:595013ms step_avg:91.54ms +step:6555/20000 val_loss:1.9127 val_bpb:1.1328 train_time:600021ms step_avg:91.54ms +stopping_early: wallclock_cap train_time:600021ms step:6555/20000 +peak memory allocated: 22185 MiB reserved: 22230 MiB +ema:applying EMA weights +DIAGNOSTIC post_ema val_loss:1.9111 val_bpb:1.1319 eval_time:2112ms +Serialized model: 107548579 bytes +Code size: 61917 bytes +custom_serialize: raw=27888454 bytes +Serialized model int6+custom+: 15932974 bytes +Total submission size: 15994891 bytes +supermask_ttt:start chunks=947 chunk_tokens=65536 total_windows=969088 stride=64 ttt_lr=0.0008 ttt_epochs=4 +supermask_ttt:params mask_params=45056 (blocks=11 per_block=mlp_mask+mlp_bias+attn_mask+attn_bias) model_frozen=27739277 + ttt_chunk [1/947] bpb=1.154856 time=0.5s + ttt_chunk [11/947] bpb=1.122136 time=5.2s + ttt_chunk [21/947] bpb=1.094667 time=9.9s + ttt_chunk [31/947] bpb=1.106913 time=14.7s + ttt_chunk [41/947] bpb=1.101062 time=19.4s + ttt_chunk [51/947] bpb=1.112149 time=24.1s + ttt_chunk [61/947] bpb=1.107153 time=28.8s + ttt_chunk [71/947] bpb=1.107537 time=33.5s + ttt_chunk [81/947] bpb=1.114742 time=38.2s + ttt_chunk [91/947] bpb=1.117779 time=42.9s + ttt_chunk [101/947] bpb=1.118831 time=47.6s + ttt_chunk [111/947] bpb=1.118242 time=52.3s + ttt_chunk [121/947] bpb=1.117427 time=57.1s + ttt_chunk [131/947] bpb=1.112978 time=61.8s + ttt_chunk [141/947] bpb=1.114770 time=66.5s + ttt_chunk [151/947] bpb=1.116316 time=71.2s + ttt_chunk [161/947] bpb=1.120299 time=75.9s + ttt_chunk [171/947] bpb=1.120670 time=80.6s + ttt_chunk [181/947] bpb=1.124765 time=85.4s + ttt_chunk [191/947] bpb=1.125712 time=90.1s + ttt_chunk [201/947] bpb=1.123931 time=94.8s + ttt_chunk [211/947] bpb=1.122122 time=99.5s + ttt_chunk [221/947] bpb=1.121577 time=104.2s + ttt_chunk [231/947] bpb=1.122027 time=108.9s + ttt_chunk [241/947] bpb=1.123901 time=113.6s + ttt_chunk [251/947] bpb=1.122471 time=118.3s + ttt_chunk [261/947] bpb=1.120523 time=123.1s + ttt_chunk [271/947] bpb=1.121827 time=127.8s + ttt_chunk [281/947] bpb=1.119946 time=132.5s + ttt_chunk [291/947] bpb=1.119534 time=137.2s + ttt_chunk [301/947] bpb=1.119550 time=141.9s + ttt_chunk [311/947] bpb=1.120508 time=146.6s + ttt_chunk [321/947] bpb=1.121332 time=151.4s + ttt_chunk [331/947] bpb=1.120735 time=156.1s + ttt_chunk [341/947] bpb=1.121366 time=160.8s + ttt_chunk [351/947] bpb=1.122002 time=165.5s + ttt_chunk [361/947] bpb=1.122177 time=170.2s + ttt_chunk [371/947] bpb=1.122305 time=175.0s + ttt_chunk [381/947] bpb=1.121563 time=179.7s + ttt_chunk [391/947] bpb=1.120929 time=184.4s + ttt_chunk [401/947] bpb=1.120573 time=189.1s + ttt_chunk [411/947] bpb=1.119646 time=193.8s + ttt_chunk [421/947] bpb=1.118672 time=198.5s + ttt_chunk [431/947] bpb=1.119493 time=203.3s + ttt_chunk [441/947] bpb=1.118746 time=208.0s + ttt_chunk [451/947] bpb=1.118315 time=212.7s + ttt_chunk [461/947] bpb=1.118524 time=217.5s + ttt_chunk [471/947] bpb=1.117953 time=222.2s + ttt_chunk [481/947] bpb=1.117436 time=227.0s + ttt_chunk [491/947] bpb=1.118098 time=231.7s + ttt_chunk [501/947] bpb=1.117928 time=236.5s + ttt_chunk [511/947] bpb=1.118269 time=241.2s + ttt_chunk [521/947] bpb=1.118537 time=245.9s + ttt_chunk [531/947] bpb=1.118391 time=250.7s + ttt_chunk [541/947] bpb=1.117946 time=255.4s + ttt_chunk [551/947] bpb=1.118566 time=260.1s + ttt_chunk [561/947] bpb=1.118021 time=264.9s + ttt_chunk [571/947] bpb=1.118437 time=269.6s + ttt_chunk [581/947] bpb=1.117133 time=274.3s + ttt_chunk [591/947] bpb=1.117746 time=279.0s + ttt_chunk [601/947] bpb=1.118735 time=283.7s + ttt_chunk [611/947] bpb=1.118023 time=288.5s + ttt_chunk [621/947] bpb=1.118083 time=293.3s + ttt_chunk [631/947] bpb=1.118186 time=298.0s + ttt_chunk [641/947] bpb=1.117318 time=302.7s + ttt_chunk [651/947] bpb=1.116463 time=307.5s + ttt_chunk [661/947] bpb=1.115834 time=312.2s + ttt_chunk [671/947] bpb=1.115735 time=317.0s + ttt_chunk [681/947] bpb=1.116208 time=321.7s + ttt_chunk [691/947] bpb=1.115728 time=326.4s + ttt_chunk [701/947] bpb=1.115009 time=331.2s + ttt_chunk [711/947] bpb=1.114946 time=335.9s + ttt_chunk [721/947] bpb=1.115438 time=340.6s + ttt_chunk [731/947] bpb=1.115002 time=345.4s + ttt_chunk [741/947] bpb=1.116200 time=350.1s + ttt_chunk [751/947] bpb=1.116217 time=354.9s + ttt_chunk [761/947] bpb=1.116474 time=359.6s + ttt_chunk [771/947] bpb=1.116355 time=364.4s + ttt_chunk [781/947] bpb=1.117046 time=369.1s + ttt_chunk [791/947] bpb=1.117594 time=373.9s + ttt_chunk [801/947] bpb=1.117725 time=378.6s + ttt_chunk [811/947] bpb=1.117634 time=383.4s + ttt_chunk [821/947] bpb=1.117557 time=388.1s + ttt_chunk [831/947] bpb=1.117726 time=392.8s + ttt_chunk [841/947] bpb=1.118599 time=397.6s + ttt_chunk [851/947] bpb=1.118325 time=402.3s + ttt_chunk [861/947] bpb=1.117719 time=407.0s + ttt_chunk [871/947] bpb=1.117436 time=411.8s + ttt_chunk [881/947] bpb=1.117594 time=416.5s + ttt_chunk [891/947] bpb=1.117499 time=421.2s + ttt_chunk [901/947] bpb=1.117337 time=425.9s + ttt_chunk [911/947] bpb=1.116962 time=430.6s + ttt_chunk [921/947] bpb=1.116473 time=435.3s + ttt_chunk [931/947] bpb=1.115800 time=439.9s + ttt_chunk [941/947] bpb=1.115242 time=444.6s + ttt_chunk [947/947] bpb=1.114975 time=447.1s +supermask_ttt:done val_loss=1.877871 val_bpb=1.112184 elapsed=447.1s +legal_ttt val_loss:1.8779 val_bpb:1.1122 eval_time:447622ms +legal_ttt_exact val_loss:1.87787103 val_bpb:1.11218395 diff --git a/records/track_10min_16mb/2026-04-05_11L_LatentMaskTTT_GPTQ_ProductKey_Brotli/train_gpt.py b/records/track_10min_16mb/2026-04-05_11L_LatentMaskTTT_GPTQ_ProductKey_Brotli/train_gpt.py new file mode 100644 index 0000000000..58dbdadd27 --- /dev/null +++ b/records/track_10min_16mb/2026-04-05_11L_LatentMaskTTT_GPTQ_ProductKey_Brotli/train_gpt.py @@ -0,0 +1,1484 @@ +import copy +import glob +import json +import math +import os +import random +import struct +import time +import uuid +from pathlib import Path +import brotli +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP +from flash_attn_interface import flash_attn_func as _fa3_func +def flash_attn_3_func(q, k, v, causal=True): + return _fa3_func(q, k, v, causal=causal) +class Hyperparameters: + data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + seed = int(os.environ.get("SEED", 1337)) + val_bsz = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 0)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 500)) + iterations = int(os.environ.get("ITERATIONS", 20000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 3500)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 786_432)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 2048)) + eval_seq_len = int(os.environ.get("EVAL_SEQ_LEN", 2048)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers = int(os.environ.get("NUM_LAYERS", 11)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = float(os.environ.get("MLP_MULT", 3.0)) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.035)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.028)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.025)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.99)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.92)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 1500)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.3)) + eval_stride = int(os.environ.get("EVAL_STRIDE", 64)) + muon_wd = float(os.environ.get("MUON_WD", 0.0417)) + adam_wd = float(os.environ.get("ADAM_WD", 0.04)) + xsa_last_n = int(os.environ.get("XSA_LAST_N", 11)) + rope_dims = int(os.environ.get("ROPE_DIMS", 16)) + ln_scale = bool(int(os.environ.get("LN_SCALE", "1"))) + ve_enabled = bool(int(os.environ.get("VE_ENABLED", "1"))) + ve_dim = int(os.environ.get("VE_DIM", 128)) + gated_attention = bool(int(os.environ.get("GATED_ATTENTION", "1"))) + ve_layers = os.environ.get("VE_LAYERS", "5,7,10") + ttt_enabled = bool(int(os.environ.get("TTT_ENABLED", "1"))) + ttt_lr = float(os.environ.get("TTT_LR", 0.0008)) + ttt_epochs = int(os.environ.get("TTT_EPOCHS", 4)) + ttt_chunk_tokens = int(os.environ.get("TTT_CHUNK_TOKENS", 65536)) + ttt_momentum = float(os.environ.get("TTT_MOMENTUM", 0.9)) + ttt_batch_seqs = int(os.environ.get("TTT_BATCH_SEQS", 32)) + ttt_grad_clip = float(os.environ.get("TTT_GRAD_CLIP", 1.0)) +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, + nesterov: bool = True, weight_decay: float = 0.0): + super().__init__( + params, + dict(lr=lr, momentum=momentum, backend_steps=backend_steps, + nesterov=nesterov, weight_decay=weight_decay), + ) + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + for group in self.param_groups: + params = group["params"] + if not params: + continue + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + total_params = sum(int(p.numel()) for p in params) + updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: + g = g.add(buf, alpha=momentum) + g = zeropower_via_newtonschulz5(g, steps=backend_steps) + g *= max(1, g.size(0) / g.size(1)) ** 0.5 + updates_flat[curr : curr + p.numel()] = g.reshape(-1) + curr += p.numel() + if distributed: + dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) + wd = group.get("weight_decay", 0.0) + curr = 0 + for p in params: + if wd > 0.0: + p.data.mul_(1.0 - lr * wd) + g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) + p.add_(g, alpha=-lr) + curr += p.numel() + return loss +def build_sentencepiece_luts( + sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device +) -> tuple[Tensor, Tensor, Tensor]: + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("▁"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] +def eval_val( + args: Hyperparameters, + model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + grad_accum_steps: int, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, + eval_seq_len: int | None = None, +) -> tuple[float, float]: + seq_len = eval_seq_len or args.train_seq_len + local_batch_tokens = args.val_bsz // (world_size * grad_accum_steps) + if local_batch_tokens < seq_len: + raise ValueError( + "VAL_BATCH_SIZE must provide at least one sequence per rank; " + f"got VAL_BATCH_SIZE={args.val_bsz}, WORLD_SIZE={world_size}, " + f"GRAD_ACCUM_STEPS={grad_accum_steps}, seq_len={seq_len}" + ) + local_batch_seqs = local_batch_tokens // seq_len + total_seqs = (val_tokens.numel() - 1) // seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + model.eval() + with torch.inference_mode(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * seq_len + raw_end = batch_seq_end * seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + batch_loss = model(x, y).detach() + batch_token_count = float(y.numel()) + val_loss_sum += batch_loss.to(torch.float64) * batch_token_count + val_token_count += batch_token_count + prev_ids = x.reshape(-1) + tgt_ids = y.reshape(-1) + token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + val_loss = val_loss_sum / val_token_count + bits_per_token = val_loss.item() / math.log(2.0) + tokens_per_byte = val_token_count.item() / val_byte_count.item() + model.train() + return float(val_loss.item()), float(bits_per_token * tokens_per_byte) +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights,smear,ve_layer_scales,ve_shared.scale", + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern +) +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 +def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) + q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() + return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() + return q, scale +def load_data_shard(file: Path) -> Tensor: + header_bytes = 256 * np.dtype(" None: + self.file_idx = (self.file_idx + 1) % len(self.files) + self.tokens = load_data_shard(self.files[self.file_idx]) + self.pos = 0 + def take(self, n: int) -> Tensor: + chunks: list[Tensor] = [] + remaining = n + while remaining > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: + self._advance_file() + continue + k = min(remaining, avail) + chunks.append(self.tokens[self.pos : self.pos + k]) + self.pos += k + remaining -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) +class DistributedTokenLoader: + def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank = rank + self.world_size = world_size + self.device = device + self.stream = TokenStream(pattern) + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + chunk = self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start : start + per_rank_span].to(dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) +class CastedLinear(nn.Linear): + _qat_enabled: bool = False + def forward(self, x: Tensor) -> Tensor: + w = self.weight.to(x.dtype) + if CastedLinear._qat_enabled and self.training and w.ndim == 2: + with torch.no_grad(): + w32 = self.weight.float() + row_max = w32.abs().amax(dim=1) + scale = (row_max / 31.0).clamp_min(1.0 / 31.0) + w_q = (torch.clamp(torch.round(w32 / scale[:, None]), -32, 31) * scale[:, None]).to(x.dtype) + w = w + (w_q - w).detach() + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, w, bias) +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() +class Rotary(nn.Module): + def __init__(self, dim: int, base: float = 10000.0, train_seq_len: int = 1024, rope_dims: int = 0): + super().__init__() + self.dim = dim + self.base = base + self.train_seq_len = train_seq_len + self.rope_dims = rope_dims if rope_dims > 0 else dim + inv_freq = 1.0 / (base ** (torch.arange(0, self.rope_dims, 2, dtype=torch.float32) / self.rope_dims)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached != seq_len + or self._cos_cached.device != device + ): + rd = self.rope_dims + if seq_len > self.train_seq_len: + scale = seq_len / self.train_seq_len + new_base = self.base * (scale ** (rd / (rd - 2))) + inv_freq = 1.0 / (new_base ** (torch.arange(0, rd, 2, dtype=torch.float32, device=device) / rd)) + else: + inv_freq = self.inv_freq.to(device) + t = torch.arange(seq_len, device=device, dtype=inv_freq.dtype) + freqs = torch.outer(t, inv_freq) + self._cos_cached = freqs.cos()[None, :, None, :] + self._sin_cached = freqs.sin()[None, :, None, :] + self._seq_len_cached = seq_len + return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor, rope_dims: int = 0) -> Tensor: + if rope_dims > 0 and rope_dims < x.size(-1): + x_rope, x_pass = x[..., :rope_dims], x[..., rope_dims:] + half = rope_dims // 2 + x1, x2 = x_rope[..., :half], x_rope[..., half:] + x_rope = torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + return torch.cat((x_rope, x_pass), dim=-1) + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) +class CausalSelfAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + rope_base: float, + qk_gain_init: float, + gated_attention: bool = False, + ): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + kv_dim = self.num_kv_heads * self.head_dim + self.c_q = CastedLinear(dim, dim, bias=False) + self.c_k = CastedLinear(dim, kv_dim, bias=False) + self.c_v = CastedLinear(dim, kv_dim, bias=False) + self.proj = CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rope_dims = 0 + self.rotary = Rotary(self.head_dim, base=rope_base, train_seq_len=1024) + self.use_xsa = False + self.gated_attention = gated_attention + if gated_attention: + self.attn_gate = nn.Linear(dim, num_heads, bias=True) + nn.init.zeros_(self.attn_gate.weight) + nn.init.constant_(self.attn_gate.bias, 4.0) + def _xsa_efficient(self, y: Tensor, v: Tensor) -> Tensor: + B, T, H, D = y.shape + Hkv = v.size(-2) + group = H // Hkv + y_g = y.reshape(B, T, Hkv, group, D) + vn = F.normalize(v, dim=-1).unsqueeze(-2) + proj = (y_g * vn).sum(dim=-1, keepdim=True) * vn + return (y_g - proj).reshape(B, T, H, D) + def forward(self, x: Tensor, v_embed: Tensor | None = None) -> Tensor: + bsz, seqlen, dim = x.shape + q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim) + k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) + v = self.c_v(x) + if v_embed is not None: + v = v + v_embed + v = v.reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin, self.rope_dims) + k = apply_rotary_emb(k, cos, sin, self.rope_dims) + q = q * self.q_gain.to(dtype=q.dtype)[None, None, :, None] + y = flash_attn_3_func(q, k, v, causal=True) + if self.use_xsa: + y = self._xsa_efficient(y, v) + if self.gated_attention: + gate = torch.sigmoid(self.attn_gate(x)).unsqueeze(-1) + y = y * gate + y = y.reshape(bsz, seqlen, dim) + return self.proj(y) +class SmearGate(nn.Module): + def __init__(self, dim: int): + super().__init__() + self.gate = nn.Parameter(torch.zeros(dim, dtype=torch.float32)) + def forward(self, x: Tensor) -> Tensor: + g = torch.sigmoid(self.gate.to(dtype=x.dtype))[None, None, :] + x_prev = torch.cat([torch.zeros_like(x[:, :1]), x[:, :-1]], dim=1) + return (1 - g) * x + g * x_prev +class ProductKeyEmbed(nn.Module): + def __init__(self, model_dim: int): + super().__init__() + self.embed_prev = nn.Embedding(1024, model_dim) + self.embed_cur = nn.Embedding(1024, model_dim) + nn.init.normal_(self.embed_prev.weight, std=0.02) + nn.init.normal_(self.embed_cur.weight, std=0.02) + self.proj = None + self.scale = nn.Parameter(torch.tensor(0.05, dtype=torch.float32)) + def forward(self, token_ids: Tensor) -> Tensor: + t = token_ids.long() + t_prev = torch.cat([torch.zeros_like(t[..., :1]), t[..., :-1]], dim=-1) + h = self.embed_prev(t_prev) * self.embed_cur(t) + return h * self.scale.to(dtype=h.dtype) +class ValueEmbedding(nn.Module): + def __init__(self, vocab_size: int, ve_dim: int, model_dim: int): + super().__init__() + self.embed = nn.Embedding(vocab_size, ve_dim) + nn.init.normal_(self.embed.weight, std=0.01) + self.proj = CastedLinear(ve_dim, model_dim, bias=False) if ve_dim != model_dim else None + if self.proj is not None: + nn.init.zeros_(self.proj.weight) + self.scale = nn.Parameter(torch.tensor(0.1, dtype=torch.float32)) + def forward(self, token_ids: Tensor) -> Tensor: + h = self.embed(token_ids) + if self.proj is not None: + h = self.proj(h) + return h * self.scale.to(dtype=h.dtype) +class MLP(nn.Module): + def __init__(self, dim: int, mlp_mult: int): + super().__init__() + hidden = int(mlp_mult * dim) + self.fc = CastedLinear(dim, hidden, bias=False) + self.proj = CastedLinear(hidden, dim, bias=False) + self.proj._zero_init = True + def forward(self, x: Tensor) -> Tensor: + x = F.leaky_relu(self.fc(x), negative_slope=0.5) + return self.proj(x.square()) +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + rope_base: float, + qk_gain_init: float, + layer_idx: int = 0, + ln_scale: bool = False, + gated_attention: bool = False, + ): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init, gated_attention=gated_attention) + self.mlp = MLP(dim, mlp_mult) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + self.ln_scale_factor = 1.0 / math.sqrt(layer_idx + 1) if ln_scale else 1.0 + def forward(self, x: Tensor, x0: Tensor, v_embed: Tensor | None = None) -> Tensor: + mix = self.resid_mix.to(dtype=x.dtype) + x_in = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + attn_out = self.attn(self.attn_norm(x_in) * self.ln_scale_factor, v_embed=v_embed) + x_out = x_in + self.attn_scale.to(dtype=x_in.dtype)[None, None, :] * attn_out + x_out = x_out + self.mlp_scale.to(dtype=x_out.dtype)[None, None, :] * self.mlp(self.mlp_norm(x_out) * self.ln_scale_factor) + return x_out +class GPT(nn.Module): + def __init__( + self, + vocab_size: int, + num_layers: int, + model_dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + tie_embeddings: bool, + tied_embed_init_std: float, + logit_softcap: float, + rope_base: float, + qk_gain_init: float, + xsa_last_n: int = 0, + rope_dims: int = 0, + ln_scale: bool = False, + ve_enabled: bool = False, + ve_dim: int = 128, + ve_layers: str = "9,10", + gated_attention: bool = False, + ): + super().__init__() + self._ve_target_dim = num_kv_heads * (model_dim // num_heads) + if logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") + self.tie_embeddings = tie_embeddings + self.tied_embed_init_std = tied_embed_init_std + self.logit_softcap = logit_softcap + self.tok_emb = nn.Embedding(vocab_size, model_dim) + self.bigram = ProductKeyEmbed(model_dim) + self.smear = SmearGate(model_dim) + self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + _gated_attn_layers = set(range(0, num_layers, 2)) + self.blocks = nn.ModuleList( + [ + Block( + model_dim, + num_heads, + num_kv_heads, + mlp_mult, + rope_base, + qk_gain_init, + layer_idx=i, + ln_scale=ln_scale, + gated_attention=gated_attention and (i in _gated_attn_layers), + ) + for i in range(num_layers) + ] + ) + if rope_dims > 0: + head_dim = model_dim // num_heads + for block in self.blocks: + block.attn.rope_dims = rope_dims + block.attn.rotary = Rotary(head_dim, base=rope_base, train_seq_len=1024, rope_dims=rope_dims) + self.ve_layer_indices = [int(x) for x in ve_layers.split(",") if x.strip()] if ve_enabled else [] + kv_dim = self._ve_target_dim + if self.ve_layer_indices: + self.ve_shared = ValueEmbedding(vocab_size, ve_dim, kv_dim) + self.ve_layer_scales = nn.ParameterList( + [nn.Parameter(torch.ones(1, dtype=torch.float32)) for _ in self.ve_layer_indices] + ) + else: + self.ve_shared = None + self.ve_layer_scales = nn.ParameterList() + self.final_norm = RMSNorm() + self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) + if self.lm_head is not None: + self.lm_head._zero_init = True + if xsa_last_n > 0: + for i in range(max(0, num_layers - xsa_last_n), num_layers): + self.blocks[i].attn.use_xsa = True + self._init_weights() + def _init_weights(self) -> None: + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + num_layers = len(self.blocks) + for name, module in self.named_modules(): + if isinstance(module, nn.Linear): + if getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + elif module.weight.ndim == 2 and module.weight.shape[0] >= 64 and module.weight.shape[1] >= 64: + nn.init.orthogonal_(module.weight, gain=1.0) + if ".proj." in name or name.endswith(".proj"): + with torch.no_grad(): + module.weight.mul_(1.0 / math.sqrt(2 * num_layers)) + def _get_ve(self, layer_idx: int, input_ids: Tensor, ve_cache: dict | None = None) -> Tensor | None: + if self.ve_shared is None or layer_idx not in self.ve_layer_indices: + return None + if ve_cache is not None and 've' not in ve_cache: + ve_cache['ve'] = self.ve_shared(input_ids) + ve_base = ve_cache['ve'] if ve_cache is not None else self.ve_shared(input_ids) + ve_idx = self.ve_layer_indices.index(layer_idx) + return ve_base * self.ve_layer_scales[ve_idx].to(dtype=ve_base.dtype) + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + x = self.tok_emb(input_ids) + if self.bigram is not None: + x = x + self.bigram(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x = self.smear(x) + x0 = x + skips: list[Tensor] = [] + ve_cache: dict = {} + for i in range(self.num_encoder_layers): + ve = self._get_ve(i, input_ids, ve_cache) + x = self.blocks[i](x, x0, v_embed=ve) + skips.append(x) + for i in range(self.num_decoder_layers): + bi = self.num_encoder_layers + i + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + ve = self._get_ve(bi, input_ids, ve_cache) + x = self.blocks[bi](x, x0, v_embed=ve) + x = self.final_norm(x) + x_flat = x.reshape(-1, x.size(-1)) + targets = target_ids.reshape(-1) + if self.tie_embeddings: + logits_proj = F.linear(x_flat, self.tok_emb.weight) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(x_flat) + logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + main_loss = F.cross_entropy(logits.float(), targets, reduction="mean") + return main_loss + def forward_logits(self, input_ids: Tensor) -> Tensor: + x = self.tok_emb(input_ids) + if self.bigram is not None: + x = x + self.bigram(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x = self.smear(x) + x0 = x + skips: list[Tensor] = [] + ve_cache: dict = {} + for i in range(self.num_encoder_layers): + ve = self._get_ve(i, input_ids, ve_cache) + x = self.blocks[i](x, x0, v_embed=ve) + skips.append(x) + for i in range(self.num_decoder_layers): + bi = self.num_encoder_layers + i + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + ve = self._get_ve(bi, input_ids, ve_cache) + x = self.blocks[bi](x, x0, v_embed=ve) + x = self.final_norm(x) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + logits_proj = self.lm_head(x) + return self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) +class MuonLiteTTT(torch.optim.Optimizer): + def __init__(self, params, lr=0.005, momentum=0.9, weight_decay=0.0): + defaults = dict(lr=lr, momentum=momentum, weight_decay=weight_decay) + super().__init__(params, defaults) + @torch.no_grad() + def step(self, closure=None): + for group in self.param_groups: + lr = group['lr'] + momentum = group['momentum'] + wd = group['weight_decay'] + for p in group['params']: + if p.grad is None: + continue + g = p.grad + state = self.state[p] + if 'momentum_buffer' not in state: + state['momentum_buffer'] = torch.zeros_like(g) + buf = state['momentum_buffer'] + buf.mul_(momentum).add_(g) + g = g.add(buf, alpha=momentum) + update = g.sign() + update *= (p.numel() ** 0.5) / max(update.norm(), 1e-8) + if wd > 0: + p.mul_(1 - lr * wd) + p.add_(update, alpha=-lr) +def eval_val_sliding_ttt( + args: Hyperparameters, base_model: nn.Module, rank: int, world_size: int, + device: torch.device, val_tokens: Tensor, base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, is_boundary_token_lut: Tensor, + stride: int, batch_seqs: int = 32, log0=print, +) -> tuple[float, float]: + seq_len = args.train_seq_len + total_tokens = val_tokens.numel() - 1 + ttt_chunk = args.ttt_chunk_tokens + window_starts = [ws for ws in range(0, total_tokens, stride) + if min(ws + seq_len, total_tokens) - ws >= stride or ws == 0] + num_chunks = (total_tokens + ttt_chunk - 1) // ttt_chunk + chunk_windows: list[list[int]] = [[] for _ in range(num_chunks)] + for ws in window_starts: + end = min(ws + seq_len, total_tokens) + wlen = end - ws + s = 0 if ws == 0 else max(wlen - stride, 0) + scored_start = ws + s + ci = min(scored_start // ttt_chunk, num_chunks - 1) + chunk_windows[ci].append(ws) + for p in base_model.parameters(): + p.requires_grad_(False) + mask_scores: list[Tensor] = [] + original_mlp_forwards: list = [] + original_attn_forwards: list = [] + for block in base_model.blocks: + hidden_dim = block.mlp.fc.weight.shape[0] + attn_dim = block.attn.proj.weight.shape[0] + mlp_ms = torch.full((hidden_dim,), 3.0, device=device, requires_grad=True) + mlp_bias = torch.zeros(hidden_dim, device=device, requires_grad=True) + attn_ms = torch.full((attn_dim,), 3.0, device=device, requires_grad=True) + attn_bias = torch.zeros(attn_dim, device=device, requires_grad=True) + mask_scores.extend([mlp_ms, mlp_bias, attn_ms, attn_bias]) + original_mlp_forwards.append(block.mlp.forward) + original_attn_forwards.append(block.attn.forward) + for bi, block in enumerate(base_model.blocks): + mlp_ms = mask_scores[bi * 4] + mlp_bias = mask_scores[bi * 4 + 1] + attn_ms = mask_scores[bi * 4 + 2] + attn_bias = mask_scores[bi * 4 + 3] + fc_layer = block.mlp.fc + proj_layer = block.mlp.proj + def _make_masked_mlp_fwd(fc, proj, ms, bias): + def masked_forward(x): + h = F.leaky_relu(fc(x), negative_slope=0.5) + h = h.square() + mask = torch.sigmoid(ms).to(h.dtype) + h = h * mask + bias.to(h.dtype) + return proj(h) + return masked_forward + block.mlp.forward = _make_masked_mlp_fwd(fc_layer, proj_layer, mlp_ms, mlp_bias) + orig_attn_fwd = block.attn.forward.__func__ if hasattr(block.attn.forward, '__func__') else None + def _make_masked_attn_fwd(attn_module, ms, bias, orig_fwd): + def masked_attn_forward(x, v_embed=None): + out = orig_fwd(attn_module, x, v_embed=v_embed) + mask = torch.sigmoid(ms).to(out.dtype) + return out * mask + bias.to(out.dtype) + return masked_attn_forward + if orig_attn_fwd is not None: + block.attn.forward = _make_masked_attn_fwd(block.attn, attn_ms, attn_bias, orig_attn_fwd) + total_mask_params = sum(ms.numel() for ms in mask_scores) + total_frozen = sum(p.numel() for p in base_model.parameters()) + num_blocks = len(base_model.blocks) + log0(f"supermask_ttt:start chunks={num_chunks} chunk_tokens={ttt_chunk} " + f"total_windows={len(window_starts)} stride={stride} " + f"ttt_lr={args.ttt_lr} ttt_epochs={args.ttt_epochs}") + log0(f"supermask_ttt:params mask_params={total_mask_params} " + f"(blocks={num_blocks} per_block=mlp_mask+mlp_bias+attn_mask+attn_bias) model_frozen={total_frozen}") + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + optimizer = MuonLiteTTT(mask_scores, lr=args.ttt_lr, momentum=args.ttt_momentum) + t0 = time.perf_counter() + for ci in range(num_chunks): + windows = chunk_windows[ci] + if not windows: + continue + chunk_start = ci * ttt_chunk + chunk_end = min((ci + 1) * ttt_chunk, total_tokens) + my_s = (len(windows) * rank) // world_size + my_e = (len(windows) * (rank + 1)) // world_size + my_windows = windows[my_s:my_e] + base_model.eval() + with torch.inference_mode(): + for bi in range(0, len(my_windows), batch_seqs): + batch_ws = my_windows[bi:bi + batch_seqs] + bsz = len(batch_ws) + x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + wlens: list[int] = [] + for i, ws in enumerate(batch_ws): + end = min(ws + seq_len, total_tokens) + wlen = end - ws + wlens.append(wlen) + chunk_tok = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) + x_batch[i, :wlen] = chunk_tok[:-1] + y_batch[i, :wlen] = chunk_tok[1:] + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + logits = base_model.forward_logits(x_batch) + nll = F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), + y_batch.reshape(-1), reduction="none", + ).reshape(bsz, seq_len) + for i, ws in enumerate(batch_ws): + wlen = wlens[i] + s = 0 if ws == 0 else max(wlen - stride, 0) + scored_nll = nll[i, s:wlen].to(torch.float64) + loss_sum += scored_nll.sum() + token_count += float(wlen - s) + tgt, prev = y_batch[i, s:wlen], x_batch[i, s:wlen] + tb = base_bytes_lut[tgt].to(torch.float64) + tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64) + byte_count += tb.sum() + is_last_chunk = (ci == num_chunks - 1) + if not is_last_chunk and args.ttt_epochs > 0: + base_model.train() + chunk_seqs = (chunk_end - chunk_start) // seq_len + if chunk_seqs > 0: + frac = ci / max(num_chunks - 1, 1) + if frac < 0.05: + cos_lr = args.ttt_lr * (frac / 0.05) + elif frac < 0.70: + cos_lr = args.ttt_lr + else: + decay_frac = (frac - 0.70) / 0.30 + cos_lr = args.ttt_lr * 0.5 * (1.0 + math.cos(math.pi * decay_frac)) + for pg in optimizer.param_groups: + pg['lr'] = cos_lr + my_seq_s = (chunk_seqs * rank) // world_size + my_seq_e = (chunk_seqs * (rank + 1)) // world_size + my_chunk_seqs = my_seq_e - my_seq_s + for _ep in range(args.ttt_epochs): + for bs in range(0, my_chunk_seqs, args.ttt_batch_seqs): + be = min(bs + args.ttt_batch_seqs, my_chunk_seqs) + actual_bs = my_seq_s + bs + start_tok = chunk_start + actual_bs * seq_len + end_tok = chunk_start + (my_seq_s + be) * seq_len + 1 + if end_tok > val_tokens.numel(): + continue + local = val_tokens[start_tok:end_tok].to(device=device, dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + optimizer.zero_grad(set_to_none=True) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + loss = base_model(x, y) + loss.backward() + if dist.is_available() and dist.is_initialized(): + for ms in mask_scores: + if ms.grad is not None: + dist.all_reduce(ms.grad, op=dist.ReduceOp.AVG) + torch.nn.utils.clip_grad_norm_(mask_scores, args.ttt_grad_clip) + optimizer.step() + if rank == 0 and (ci % 10 == 0 or ci == num_chunks - 1): + elapsed = time.perf_counter() - t0 + rl = loss_sum.item() / max(token_count.item(), 1) + rbpb = rl / math.log(2.0) * (token_count.item() / max(byte_count.item(), 1)) if token_count.item() > 0 else 0.0 + log0(f" ttt_chunk [{ci+1}/{num_chunks}] bpb={rbpb:.6f} time={elapsed:.1f}s") + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) + val_loss = (loss_sum / token_count).item() + val_bpb = val_loss / math.log(2.0) * (token_count.item() / byte_count.item()) + for block, orig_mlp_fwd, orig_attn_fwd in zip(base_model.blocks, original_mlp_forwards, original_attn_forwards): + block.mlp.forward = orig_mlp_fwd + block.attn.forward = orig_attn_fwd + for p in base_model.parameters(): + p.requires_grad_(True) + base_model.eval() + log0(f"supermask_ttt:done val_loss={val_loss:.6f} val_bpb={val_bpb:.6f} " + f"elapsed={time.perf_counter() - t0:.1f}s") + return val_loss, val_bpb +_DTYPE_TO_CODE = {torch.int8: 0, torch.float16: 1, torch.float32: 2, torch.uint8: 3} +_CODE_TO_NP = {0: np.int8, 1: np.float16, 2: np.float32, 3: np.uint8} +def _custom_serialize(quant_result: dict, quant_meta: dict) -> bytes: + names = sorted(quant_result.keys()) + index = [] + parts = [] + off = 0 + for n in names: + t = quant_result[n] + a = t.numpy() if not t.is_cuda else t.cpu().numpy() + dc = _DTYPE_TO_CODE.get(t.dtype, 2) + raw = a.tobytes() + index.append((n, dc, list(a.shape), off, len(raw))) + parts.append(raw) + off += len(raw) + mb = json.dumps(quant_meta, separators=(',', ':')).encode() + ib = json.dumps(index, separators=(',', ':')).encode() + return struct.pack(' tuple[dict, dict]: + ml, il = struct.unpack(' str: + if "tok_emb" in name or "lm_head" in name: + return "embed" + if ".mlp." in name: + return "mlp" + if ".attn." in name or (".proj." in name and ".mlp." not in name): + return "attn" + return "other" +def _scale_to_u8(s_fp16: Tensor) -> tuple[Tensor, Tensor]: + s = s_fp16.float() + s_min = s.min().clamp_min(1e-10) + s_max = s.max().clamp_min(1e-10) + log_min, log_max = s_min.log(), s_max.log() + if log_max - log_min < 1e-6: + return torch.full_like(s, 128, dtype=torch.uint8), torch.stack([s_min.to(torch.float16), s_max.to(torch.float16)]) + norm = ((s.log() - log_min) / (log_max - log_min) * 255).round().clamp(0, 255).to(torch.uint8) + return norm, torch.stack([s_min.to(torch.float16), s_max.to(torch.float16)]) +def _u8_to_scale(norm: Tensor, base: Tensor) -> Tensor: + s_min, s_max = base[0].float(), base[1].float() + return torch.exp(s_min.log() + (norm.float() / 255.0) * (s_max.log() - s_min.log())).to(torch.float16) +def gen_ar_calib(model, device, ns=64, seq_len=2048, + vocab_size=1024, temp=0.8, bsz=8, seed=42): + model.eval() + rng = torch.Generator(device=device) + rng.manual_seed(seed) + all_tokens = [] + with torch.inference_mode(), torch.autocast(device_type="cuda", dtype=torch.bfloat16): + for batch_start in range(0, ns, bsz): + bs = min(bsz, ns - batch_start) + tokens = torch.randint(0, vocab_size, (bs, 1), device=device, generator=rng) + for pos in range(seq_len - 1): + logits = model.forward_logits(tokens) + next_logit = logits[:, -1, :] + probs = torch.softmax(next_logit / temp, dim=-1) + next_tok = torch.multinomial(probs, 1, generator=rng) + tokens = torch.cat([tokens, next_tok], dim=1) + for i in range(bs): + all_tokens.append(tokens[i:i+1]) + return all_tokens +def collect_hessians(hm, tseqs, device): + hessians = {} + hooks = [] + for name, module in hm.named_modules(): + if isinstance(module, CastedLinear): + param_name = name + ".weight" + cols = module.weight.shape[1] + hessians[param_name] = torch.zeros(cols, cols, dtype=torch.float32, device='cpu') + def make_hook(pname): + def hook_fn(module, input, output): + x = input[0].detach().float() + if x.ndim == 3: + x = x.reshape(-1, x.shape[-1]) + hessians[pname] += (x.T @ x).cpu() + return hook_fn + h = module.register_forward_hook(make_hook(param_name)) + hooks.append(h) + hm.eval() + with torch.inference_mode(), torch.autocast(device_type="cuda", dtype=torch.bfloat16): + for seq in tseqs: + x = seq[:, :-1].to(device) + y = seq[:, 1:].to(device) + hm(x, y) + for h in hooks: + h.remove() + num_batches = len(tseqs) + for name in hessians: + H = hessians[name] + H /= num_batches + damp = 0.01 * torch.diag(H).mean().clamp_min(1e-6) + H += damp * torch.eye(H.shape[0]) + hessians[name] = H + return hessians +def quantize_int6_gptq(weight, hessian=None, cr=31, blksz=128): + t32 = weight.float() + if t32.ndim != 2 or hessian is None: + amax = t32.abs().max().item() + scale = torch.tensor(amax / cr if amax > 0 else 1.0, dtype=torch.float16) + q = torch.clamp(torch.round(t32 / scale.float()), -cr, cr).to(torch.int8) + return q, scale + rows, cols = t32.shape + H = hessian.float().clone() + dead = torch.diag(H) == 0 + H[dead, dead] = 1 + damp = 0.01 * torch.mean(torch.diag(H)) + H[torch.arange(cols), torch.arange(cols)] += damp + perm = torch.argsort(torch.diag(H), descending=True) + inv_perm = torch.argsort(perm) + W = t32[:, perm].clone() + W[:, dead[perm]] = 0 + H = H[perm][:, perm] + Hinv = torch.linalg.cholesky(H) + Hinv = torch.cholesky_inverse(Hinv) + Hinv = torch.linalg.cholesky(Hinv, upper=True) + best_q = None; best_scale = None; best_err = float('inf') + for pct in [0.9990, 0.9995, 0.9999, 0.99999, 1.0]: + if pct < 1.0: + row_clip = torch.quantile(t32.abs(), pct, dim=1) + else: + row_clip = t32.abs().amax(dim=1) + s = (row_clip / cr).clamp_min(1.0 / cr).to(torch.float16) + sf = s.float() + Q = torch.zeros_like(W, dtype=torch.int8) + W_work = W.clone() + for i1 in range(0, cols, blksz): + i2 = min(i1 + blksz, cols) + count = i2 - i1 + W1 = W_work[:, i1:i2].clone() + Q1 = torch.zeros(rows, count, dtype=torch.int8) + Err1 = torch.zeros(rows, count) + Hinv1 = Hinv[i1:i2, i1:i2] + for i in range(count): + w = W1[:, i] + d = Hinv1[i, i] + q = torch.clamp(torch.round(w / sf), -cr, cr).to(torch.int8) + Q1[:, i] = q + err = (w - q.float() * sf) / d + W1[:, i:] -= err.unsqueeze(1) * Hinv1[i, i:].unsqueeze(0) + Err1[:, i] = err + Q[:, i1:i2] = Q1 + if i2 < cols: + W_work[:, i2:] -= Err1 @ Hinv[i1:i2, i2:] + recon = Q.float() * sf[:, None] + mse = (W - recon).pow(2).mean().item() + if mse < best_err: + best_q, best_scale, best_err = Q, s, mse + best_q = best_q[:, inv_perm] + return best_q, best_scale +def mixed_quantize_int6(state_dict: dict[str, Tensor], int6_cats: set[str], hessians=None): + num_layers_total = max( + (int(k.split(".")[1]) for k in state_dict if k.startswith("blocks.")), + default=0, + ) + 1 + late_k_layers = set(range(num_layers_total - 2, num_layers_total)) + result: dict[str, Tensor] = {} + meta: dict[str, object] = {} + for name, tensor in state_dict.items(): + t = tensor.detach().cpu().contiguous() + cat = _classify_param(name) + if not t.is_floating_point() or t.numel() <= 65536: + result[name] = t.to(torch.float16) if t.is_floating_point() else t + meta[name] = "p" + continue + if any(p in name for p in CONTROL_TENSOR_NAME_PATTERNS): + result[name] = t.float() + meta[name] = "c" + continue + if cat in int6_cats and t.ndim >= 1: + if hessians is not None and name in hessians: + q, s = quantize_int6_gptq(t, hessians[name]) + else: + q, s = quantize_int6_gptq(t) + result[name + ".q"] = q + if s.ndim > 0 and s.numel() > 1: + su8, sbase = _scale_to_u8(s) + result[name + ".su8"] = su8 + result[name + ".sbase"] = sbase + meta[name] = "6u" + else: + result[name + ".scale"] = s + meta[name] = "6" + else: + q, s = quantize_float_tensor(t) + result[name + ".q"] = q + result[name + ".scale"] = s + meta[name] = "8" + return result, meta +def dequantize_mixed_int6(result: dict[str, Tensor], meta: dict[str, object], + template_sd: dict[str, Tensor]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + for name, orig in template_sd.items(): + info = meta.get(name) + if info is None: + continue + orig_dtype = orig.dtype + if info in ("p", "c", "passthrough", "passthrough_ctrl", "passthrough_fp16"): + t = result[name] + if t.dtype == torch.float16 and orig_dtype in (torch.float32, torch.bfloat16): + t = t.to(orig_dtype) + out[name] = t + continue + if info == "6u": + q = result[name + ".q"] + s = _u8_to_scale(result[name + ".su8"], result[name + ".sbase"]) + out[name] = (q.float() * s.float().view(q.shape[0], *([1] * (q.ndim - 1)))).to(orig_dtype) + continue + q, s = result[name + ".q"], result[name + ".scale"] + if s.ndim > 0: + out[name] = (q.float() * s.float().view(q.shape[0], *([1] * (q.ndim - 1)))).to(orig_dtype) + else: + out[name] = (q.float() * float(s.item())).to(orig_dtype) + return out +def main() -> None: + global zeropower_via_newtonschulz5 + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if 8 % world_size != 0: + raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") + grad_accum_steps = 8 // world_size + grad_scale = 1.0 / grad_accum_steps + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 0 + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + logfile = None + if master_process: + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" + print(logfile) + def log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + effective_eval_seq_len = args.eval_seq_len if args.eval_seq_len > 0 else args.train_seq_len + val_seq_len = max(args.train_seq_len, effective_eval_seq_len) + val_tokens = load_validation_tokens(args.val_files, val_seq_len) + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + CastedLinear._qat_enabled = False + base_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, + xsa_last_n=args.xsa_last_n, + rope_dims=args.rope_dims, + ln_scale=args.ln_scale, + ve_enabled=args.ve_enabled, + ve_dim=args.ve_dim, + ve_layers=args.ve_layers, + gated_attention=args.gated_attention, + ).to(device).bfloat16() + for module in base_model.modules(): + if isinstance(module, CastedLinear): + module.float() + restore_low_dim_params_to_fp32(base_model) + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) + model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + block_named_params = list(base_model.blocks.named_parameters()) + matrix_params = [ + p + for name, p in block_named_params + if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + scalar_params = [ + p + for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) + scalar_params.append(base_model.smear.gate) + if base_model.bigram is not None: + scalar_params.append(base_model.bigram.scale) + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + tok_params = [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}] + if base_model.bigram is not None: + bigram_embed_params = [] + for n, p in base_model.bigram.named_parameters(): + if 'embed' in n and 'weight' in n: + bigram_embed_params.append(p) + if bigram_embed_params: + tok_params.append({"params": bigram_embed_params, "lr": token_lr, "base_lr": token_lr}) + if hasattr(base_model.bigram, 'proj') and base_model.bigram.proj is not None: + matrix_params.append(base_model.bigram.proj.weight) + if base_model.ve_shared is not None: + tok_params.append({"params": [base_model.ve_shared.embed.weight], "lr": token_lr, "base_lr": token_lr}) + if base_model.ve_shared.proj is not None: + matrix_params.append(base_model.ve_shared.proj.weight) + scalar_params.append(base_model.ve_shared.scale) + for s in base_model.ve_layer_scales: + scalar_params.append(s) + optimizer_tok = torch.optim.AdamW( + tok_params, + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + weight_decay=args.adam_wd, + fused=True, + ) + optimizer_muon = Muon( + matrix_params, + lr=args.matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + weight_decay=args.muon_wd, + ) + for group in optimizer_muon.param_groups: + group["base_lr"] = args.matrix_lr + optimizer_scalar = torch.optim.AdamW( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + weight_decay=args.adam_wd, + fused=True, + ) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers.insert(1, optimizer_head) + n_params = sum(p.numel() for p in base_model.parameters()) + log0(f"model_params:{n_params}") + xsa_layers = [i for i, b in enumerate(base_model.blocks) if b.attn.use_xsa] + log0(f"XSA:last_{args.xsa_last_n} active_layers:{xsa_layers}") + log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") + log0("attention_backend:flash_attn_3") + log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads} gated_attention:{args.gated_attention}") + log0( + f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " + f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " + f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" + ) + log0( + f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " + f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " + f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log0(f"seed:{args.seed}") + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + def lr_mul(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + if args.warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(args.warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + for opt in optimizers: + opt.step() + zero_grad_all() + if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: + log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + if distributed: + model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + ema_state = {name: t.detach().float().clone() for name, t in base_model.state_dict().items()} + ema_decay = 0.997 + training_time_ms = 0.0 + stop_after_step: int | None = None + torch.cuda.synchronize() + t0 = time.perf_counter() + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + log0( + f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + if last_step: + if stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" + ) + break + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + if scale < 0.15 and not CastedLinear._qat_enabled: + CastedLinear._qat_enabled = True + log0(f"late_qat:enabled step:{step} scale:{scale:.4f}") + zero_grad_all() + train_loss = torch.zeros((), device=device) + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + loss = model(x, y) + train_loss += loss.detach() + (loss * grad_scale).backward() + train_loss /= grad_accum_steps + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + for opt in optimizers: + opt.step() + zero_grad_all() + with torch.no_grad(): + for name, t in base_model.state_dict().items(): + ema_state[name].mul_(ema_decay).add_(t.detach().float(), alpha=1.0 - ema_decay) + step += 1 + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + + should_log_train = ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + log0( + f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" + ) + reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + log0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + ) + log0("ema:applying EMA weights") + current_state = base_model.state_dict() + avg_state = {name: t.to(dtype=current_state[name].dtype) for name, t in ema_state.items()} + base_model.load_state_dict(avg_state, strict=True) + torch.cuda.synchronize() + t_diag = time.perf_counter() + diag_val_loss, diag_val_bpb = eval_val( + args, compiled_model, rank, world_size, device, grad_accum_steps, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"DIAGNOSTIC post_ema val_loss:{diag_val_loss:.4f} val_bpb:{diag_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_diag):.0f}ms" + ) + full_state_dict = base_model.state_dict() + export_sd = full_state_dict + if master_process: + torch.save(export_sd, "final_model.pt") + model_bytes = os.path.getsize("final_model.pt") + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model: {model_bytes} bytes") + log0(f"Code size: {code_bytes} bytes") + sd_cpu = {k: v.detach().cpu() for k, v in export_sd.items()} + ctoks = gen_ar_calib(base_model, device, ns=64, seq_len=2048, + vocab_size=args.vocab_size, temp=0.8, bsz=8, seed=42) + hessians = collect_hessians(base_model, ctoks, device) + quant_result, quant_meta = mixed_quantize_int6(sd_cpu, {"mlp", "attn"}, hessians=hessians) + quant_raw = _custom_serialize(quant_result, quant_meta) + log0(f"custom_serialize: raw={len(quant_raw)} bytes") + quant_blob = brotli.compress(quant_raw, quality=11) + if master_process: + with open("final_model.int6.ptz", "wb") as f: + f.write(quant_blob) + quant_file_bytes = len(quant_blob) + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model int6+custom+{brotli}: {quant_file_bytes} bytes") + log0(f"Total submission size: {quant_file_bytes + code_bytes} bytes") + if distributed: + dist.barrier() + with open("final_model.int6.ptz", "rb") as f: + quant_blob_disk = f.read() + _decompressed = brotli.decompress(quant_blob_disk) + quant_result_loaded, quant_meta_loaded = _custom_deserialize(_decompressed) + deq_state = dequantize_mixed_int6(quant_result_loaded, quant_meta_loaded, sd_cpu) + eval_model = GPT( + vocab_size=args.vocab_size, num_layers=args.num_layers, model_dim=args.model_dim, + num_heads=args.num_heads, num_kv_heads=args.num_kv_heads, mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, rope_base=args.rope_base, qk_gain_init=args.qk_gain_init, + xsa_last_n=args.xsa_last_n, + rope_dims=args.rope_dims, ln_scale=args.ln_scale, + ve_enabled=args.ve_enabled, ve_dim=args.ve_dim, ve_layers=args.ve_layers, + gated_attention=args.gated_attention, + ).to(device).bfloat16() + for m in eval_model.modules(): + if isinstance(m, CastedLinear): + m.float() + restore_low_dim_params_to_fp32(eval_model) + eval_model.load_state_dict(deq_state, strict=True) + compiled_eval = torch.compile(eval_model, dynamic=False, fullgraph=True) + if args.ttt_enabled: + torch.cuda.synchronize() + t_ttt = time.perf_counter() + ttt_loss, ttt_bpb = eval_val_sliding_ttt( + args, eval_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride=args.eval_stride, log0=log0, + ) + torch.cuda.synchronize() + log0(f"legal_ttt val_loss:{ttt_loss:.4f} val_bpb:{ttt_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_ttt):.0f}ms") + log0(f"legal_ttt_exact val_loss:{ttt_loss:.8f} val_bpb:{ttt_bpb:.8f}") + if distributed: + dist.destroy_process_group() +if __name__ == "__main__": + main()