@@ -0,0 +1,72 @@
# 11L LatentMask TTT + GPTQ + Product-Key Bigram + Brotli

**val_bpb: 1.1124 (3-seed mean)** | Artifact: ≤15,994,891 bytes | 8xH100, 600s training + ~455s eval

## Summary

11-layer GPT with U-Net skip connections, achieving 1.1124 val_bpb (3-seed mean) through four key innovations:

1. **LatentMask TTT (Test-Time Training)**: Per-channel sigmoid masks + biases on MLP and attention outputs, trained per-chunk during evaluation using a sign-based Muon-lite optimizer. Score-first (legal): each chunk is scored before any mask update. Provides ~−0.002 bpb improvement over sliding window eval (without TTT).

2. **Full Hessian GPTQ**: Hessian-aware int6 quantization with Cholesky error compensation and column reordering. Uses autoregressive self-generated calibration data (32 sequences × 2048 tokens). Reduces quantization error vs per-row percentile search.

3. **Product-Key Bigram Embedding**: Factored bigram via `embed_prev(1024,512) * embed_cur(1024,512)` — zero hash collision, no projection layer needed. Replaces traditional hash-based bigram embedding.

4. **Brotli-11 Compression**: Custom binary serialization (JSON header + raw tensor bytes) compressed with Brotli quality=11. Combined with uint8 log-scale quantization for per-row scales.
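
A minimal sketch of the LatentMask mechanism described in point 1 (all names, the mask initialization, and the update loop are illustrative, not the submission's actual code): the frozen model's MLP and attention outputs are modulated by per-channel sigmoid masks plus biases, and only those parameters are adapted per chunk with a sign-based "Muon-lite" step.

```python
import torch

class LatentMask(torch.nn.Module):
    """Per-channel sigmoid mask + bias applied to a frozen block's output."""
    def __init__(self, dim: int):
        super().__init__()
        # logits start large so sigmoid(mask_logits) ~ 1: near-identity at init
        self.mask_logits = torch.nn.Parameter(torch.full((dim,), 4.0))
        self.bias = torch.nn.Parameter(torch.zeros(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * torch.sigmoid(self.mask_logits) + self.bias

def ttt_step(params, momenta, lr=8e-4, beta=0.9):
    """Sign-based 'Muon-lite' update: momentum-smoothed gradient, signed step."""
    with torch.no_grad():
        for p, m in zip(params, momenta):
            m.mul_(beta).add_(p.grad, alpha=1.0 - beta)
            p.add_(torch.sign(m), alpha=-lr)
            p.grad = None
```

In the score-first protocol, each evaluation chunk is scored with the current masks before `ttt_step` is ever called on that chunk's gradients, so no chunk benefits from its own update.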

## Architecture

- 11 layers, 512 dim, 8 heads, 4 KV heads (GQA)
- MLP 3x (1536 hidden), LeakyReLU(0.5)²
- GatedAttention on even layers [0,2,4,6,8,10]
- XSA (Exclusive Self-Attention) on all 11 layers
- Value Embeddings at decoder layers [5,7,10]
- U-Net encoder-decoder skip connections
- SmearGate for adjacent token mixing
- Tied embeddings, logit softcap=30
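
The product-key bigram path from the summary can be sketched as below (a minimal illustration assuming the `embed_prev`/`embed_cur` shapes given there; the padding id used for the first position is an assumption):

```python
import torch

class ProductKeyBigram(torch.nn.Module):
    """Factored bigram embedding: the pair (prev, cur) is represented by the
    elementwise product embed_prev[prev] * embed_cur[cur]. Every token pair
    gets a distinct factored code, so there is no hashing, no collisions, and
    no projection layer."""
    def __init__(self, vocab: int = 1024, dim: int = 512):
        super().__init__()
        self.embed_prev = torch.nn.Embedding(vocab, dim)
        self.embed_cur = torch.nn.Embedding(vocab, dim)

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        # tokens: (batch, seq); shift right to get each position's predecessor
        prev = torch.roll(tokens, 1, dims=1)
        prev[:, 0] = 0  # assumed padding id for the sequence start
        return self.embed_prev(prev) * self.embed_cur(tokens)
```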

## Training

- Muon optimizer (matrix params), AdamW (scalar/embed params)
- matrix_lr=0.028, muon_wd=0.0417
- EMA (decay=0.997), Late QAT (threshold=0.15)
- Warmdown: 3500 steps (time-based adaptive)
- ~6,555 steps in 600s, step_avg ~91.5ms (H100, Flash Attention 3)

## Quantization & Compression

- int6 per-row for MLP/attention weights (GPTQ when Hessian available)
- int8 per-row for embeddings
- uint8 log-scale for per-row scales (2 bytes → 1 byte per scale)
- Custom binary serialization + Brotli-11
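
A compact GPTQ-style sketch of the int6 step above (illustrative, not the submission's code; column reordering and the self-generated calibration pass are omitted): quantize one column at a time to a symmetric int6 grid with per-row scales, folding each column's quantization error into the not-yet-quantized columns through the upper Cholesky factor of the inverse Hessian.

```python
import torch

def gptq_quantize(W: torch.Tensor, H: torch.Tensor, bits: int = 6, damp: float = 0.01):
    """Hessian-aware quantization of W (rows x cols) given H = X^T X."""
    W = W.clone().double()
    H = H.clone().double()
    d = W.shape[1]
    # dampen the Hessian so the Cholesky factorization is well conditioned
    H += damp * H.diagonal().mean() * torch.eye(d, dtype=H.dtype)
    Hinv = torch.cholesky_inverse(torch.linalg.cholesky(H))
    U = torch.linalg.cholesky(Hinv, upper=True)       # upper factor of H^-1
    qmax = 2 ** (bits - 1) - 1                        # 31 for int6
    scale = W.abs().amax(dim=1, keepdim=True) / qmax  # per-row scale
    Q = torch.zeros_like(W)
    for j in range(d):
        w = W[:, j]
        q = torch.clamp(torch.round(w / scale[:, 0]), -qmax - 1, qmax) * scale[:, 0]
        Q[:, j] = q
        # compensate: push this column's error into the remaining columns
        err = (w - q) / U[j, j]
        W[:, j + 1:] -= err.unsqueeze(1) * U[j, j + 1:].unsqueeze(0)
    return Q, scale
```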

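A sketch of the scale coding and packing steps (the header layout here is an assumption, not the submission's actual format): each per-row float scale is squeezed from 2 bytes to 1 by linearly coding log(scale) over its observed range into a uint8, and the header plus raw tensor bytes are then compressed with Brotli at quality=11.

```python
import json
import struct
import brotli
import numpy as np

def encode_scales(scales: np.ndarray):
    """Code positive per-row scales as uint8 on a log-linear grid."""
    logs = np.log(scales)
    lo, hi = float(logs.min()), float(logs.max())
    codes = np.round(255.0 * (logs - lo) / max(hi - lo, 1e-12)).astype(np.uint8)
    return codes, lo, hi

def decode_scales(codes: np.ndarray, lo: float, hi: float) -> np.ndarray:
    return np.exp(lo + codes.astype(np.float64) / 255.0 * (hi - lo))

def pack(weights_int8: np.ndarray, scales: np.ndarray) -> bytes:
    """JSON header + raw bytes, Brotli-compressed (illustrative layout)."""
    codes, lo, hi = encode_scales(scales)
    header = json.dumps({"shape": list(weights_int8.shape), "lo": lo, "hi": hi}).encode()
    blob = struct.pack("<I", len(header)) + header + codes.tobytes() + weights_int8.tobytes()
    return brotli.compress(blob, quality=11)
```

With 256 log-spaced levels across a typical scale range, the relative reconstruction error per scale stays well under 1%.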
## Evaluation

- LatentMask TTT: lr=0.0008, chunk=65536 tokens, epochs=4, momentum=0.9
- Sliding window stride=64, seq_len=2048
- TTT eval time: ~455s (H100)
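
The sliding-window settings above imply a loop of the following shape (assumed mechanics, not the submission's exact code; `bytes_per_token` is a placeholder constant): each window is fed to the model, but only the last `stride` positions are scored, so every scored token sees nearly the full window as left context.

```python
import math
import torch

def sliding_window_bpb(model, tokens, seq_len=2048, stride=64, bytes_per_token=4.0):
    """Score only the final `stride` tokens of each window; return bits/byte."""
    nll, count = 0.0, 0
    for start in range(0, tokens.numel() - seq_len, stride):
        window = tokens[start:start + seq_len + 1]
        logits = model(window[:-1].unsqueeze(0))           # (1, seq_len, vocab)
        logp = torch.log_softmax(logits[0, -stride:], dim=-1)
        tgt = window[-stride:]                             # targets for last stride positions
        nll += -logp[torch.arange(stride), tgt].sum().item()
        count += stride
    # bits per byte = (nats per token) / ln(2) / (bytes per token)
    return nll / count / math.log(2) / bytes_per_token
```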

## Results (3-seed, 8xH100)

| Seed | val_bpb | Steps | Step Avg (ms) | Artifact (bytes) |
|------|---------|-------|---------------|-------------------|
| 777 | 1.11195 | 6,555 | 91.54 | 15,985,742 |
| 999 | 1.11218 | 6,555 | 91.54 | 15,994,891 |
| 1337 | 1.11297 | 6,556 | 91.52 | 15,988,042 |
| **Mean** | **1.11237** | **6,555** | **91.53** | |

## Dependencies

```
pip install flash-attn-3 brotli
```

Flash Attention 3 (Hopper kernels) is required. The script imports `flash_attn_interface` directly.

(`sentencepiece`, `torch`, `numpy` assumed pre-installed)

## Run

```bash
torchrun --standalone --nproc_per_node=8 train_gpt.py
```
@@ -0,0 +1,2 @@
flash-attn
brotli
@@ -0,0 +1,45 @@
{
  "author": "JungyupLee",
  "github_id": "izlley",
  "name": "LatentMask TTT + Full Hessian GPTQ + Product-Key Bigram + Brotli",
  "blurb": "11L GPT with U-Net skips, LatentMask TTT (per-channel sigmoid masks, score-first legal), Full Hessian GPTQ (AR self-gen calib, Cholesky error compensation), Product-Key Bigram (zero collision), GatedAttention on even layers, XSA on all layers, Value Embeddings at decoder [5,7,10], Muon optimizer, EMA+SWA, Late QAT, Brotli-11 + uint8 log-scale compression.",
  "date": "2026-04-05",
  "track": "10min_16mb",
  "val_loss": 1.87817944,
  "val_bpb": 1.11236660,
  "val_loss_std": 0.00073910,
  "val_bpb_std": 0.00043774,
  "seeds": [777, 999, 1337],
  "seed_results": {
    "777": {
      "val_loss": 1.87746873,
      "val_bpb": 1.11194568,
      "artifact_bytes": 15985742,
      "steps": 6555,
      "step_avg_ms": 91.54
    },
    "999": {
      "val_loss": 1.87787103,
      "val_bpb": 1.11218395,
      "artifact_bytes": 15994891,
      "steps": 6555,
      "step_avg_ms": 91.54
    },
    "1337": {
      "val_loss": 1.87919855,
      "val_bpb": 1.11297018,
      "artifact_bytes": 15988042,
      "steps": 6556,
      "step_avg_ms": 91.52
    }
  },
  "artifact_bytes_max": 15994891,
  "bytes_total": 15994891,
  "code_bytes": 61917,
  "model_bytes": 15932974,
  "train_steps_mean": 6555.3,
  "step_avg_ms_mean": 91.53,
  "hardware": "8xH100 80GB SXM (RunPod)",
  "flash_attn_version": "3.0.0+20260303.cu128torch291cxx11abitrue.ceb109",
  "technique_summary": "LatentMask TTT + Full Hessian GPTQ (AR self-gen) + Product-Key Bigram + GatedAttn-even + XSA-all + VE-decoder + Muon + Brotli-11"
}
@@ -0,0 +1,171 @@
W0407 13:54:02.290000 68766 torch/distributed/run.py:803]
W0407 13:54:02.290000 68766 torch/distributed/run.py:803] *****************************************
W0407 13:54:02.290000 68766 torch/distributed/run.py:803] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.
W0407 13:54:02.290000 68766 torch/distributed/run.py:803] *****************************************
logs/4efa9d48-88d4-4ff4-aff3-f3fa722476d7.txt
val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=/workspace/parameter-golf/parameter-golf/data/tokenizers/fineweb_1024_bpe.model
train_loader:dataset:fineweb10B_sp1024 train_shards:80
val_loader:shards pattern=/workspace/parameter-golf/parameter-golf/data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632
model_params:27739277
XSA:last_11 active_layers:[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
world_size:8 grad_accum_steps:1
attention_backend:flash_attn_3
attention_mode:gqa num_heads:8 num_kv_heads:4 gated_attention:True
tie_embeddings:True embed_lr:0.035 head_lr:0.0 matrix_lr:0.028 scalar_lr:0.025
train_batch_tokens:786432 train_seq_len:2048 iterations:20000 warmup_steps:20 max_wallclock_seconds:600.000
seed:1337
warmup_step:1/20
warmup_step:2/20
warmup_step:3/20
warmup_step:4/20
warmup_step:5/20
warmup_step:6/20
warmup_step:7/20
warmup_step:8/20
warmup_step:9/20
warmup_step:10/20
warmup_step:11/20
warmup_step:12/20
warmup_step:13/20
warmup_step:14/20
warmup_step:15/20
warmup_step:16/20
warmup_step:17/20
warmup_step:18/20
warmup_step:19/20
warmup_step:20/20
step:1/20000 train_loss:6.9313 train_time:149ms step_avg:149.01ms
step:2/20000 train_loss:8.7077 train_time:235ms step_avg:117.28ms
step:3/20000 train_loss:7.7451 train_time:324ms step_avg:108.03ms
step:4/20000 train_loss:7.1697 train_time:414ms step_avg:103.47ms
step:5/20000 train_loss:6.9473 train_time:503ms step_avg:100.66ms
step:6/20000 train_loss:6.8905 train_time:593ms step_avg:98.76ms
step:7/20000 train_loss:6.8465 train_time:682ms step_avg:97.41ms
step:8/20000 train_loss:6.7056 train_time:772ms step_avg:96.47ms
step:9/20000 train_loss:6.4039 train_time:861ms step_avg:95.61ms
step:10/20000 train_loss:6.0594 train_time:950ms step_avg:95.01ms
step:500/20000 train_loss:2.3535 train_time:45721ms step_avg:91.44ms
step:1000/20000 train_loss:2.2504 train_time:91556ms step_avg:91.56ms
step:1500/20000 train_loss:2.1935 train_time:137382ms step_avg:91.59ms
step:2000/20000 train_loss:2.0431 train_time:183157ms step_avg:91.58ms
step:2500/20000 train_loss:2.1503 train_time:228933ms step_avg:91.57ms
step:3000/20000 train_loss:2.1441 train_time:274696ms step_avg:91.57ms
step:3500/20000 train_loss:2.1507 train_time:320508ms step_avg:91.57ms
step:4000/20000 train_loss:1.9396 train_time:366267ms step_avg:91.57ms
step:4500/20000 train_loss:2.0902 train_time:411981ms step_avg:91.55ms
step:5000/20000 train_loss:2.0725 train_time:457729ms step_avg:91.55ms
step:5500/20000 train_loss:1.9867 train_time:503453ms step_avg:91.54ms
step:6000/20000 train_loss:1.9023 train_time:549179ms step_avg:91.53ms
late_qat:enabled step:6031 scale:0.1499
step:6500/20000 train_loss:2.0427 train_time:594883ms step_avg:91.52ms
step:6556/20000 val_loss:1.9138 val_bpb:1.1335 train_time:599973ms step_avg:91.52ms
stopping_early: wallclock_cap train_time:599973ms step:6556/20000
peak memory allocated: 22185 MiB reserved: 22230 MiB
ema:applying EMA weights
DIAGNOSTIC post_ema val_loss:1.9122 val_bpb:1.1325 eval_time:2117ms
Serialized model: 107548579 bytes
Code size: 61917 bytes
custom_serialize: raw=27888454 bytes
Serialized model int6+custom+<module 'brotli' from '/usr/local/lib/python3.12/dist-packages/brotli.py'>: 15926125 bytes
Total submission size: 15988042 bytes
supermask_ttt:start chunks=947 chunk_tokens=65536 total_windows=969088 stride=64 ttt_lr=0.0008 ttt_epochs=4
supermask_ttt:params mask_params=45056 (blocks=11 per_block=mlp_mask+mlp_bias+attn_mask+attn_bias) model_frozen=27739277
ttt_chunk [1/947] bpb=1.159986 time=0.5s
ttt_chunk [11/947] bpb=1.126072 time=5.3s
ttt_chunk [21/947] bpb=1.097464 time=10.1s
ttt_chunk [31/947] bpb=1.108908 time=14.8s
ttt_chunk [41/947] bpb=1.102337 time=19.6s
ttt_chunk [51/947] bpb=1.113493 time=24.3s
ttt_chunk [61/947] bpb=1.108278 time=29.1s
ttt_chunk [71/947] bpb=1.108788 time=33.8s
ttt_chunk [81/947] bpb=1.115921 time=38.6s
ttt_chunk [91/947] bpb=1.118991 time=43.3s
ttt_chunk [101/947] bpb=1.120177 time=48.1s
ttt_chunk [111/947] bpb=1.119387 time=52.9s
ttt_chunk [121/947] bpb=1.118636 time=57.6s
ttt_chunk [131/947] bpb=1.114143 time=62.4s
ttt_chunk [141/947] bpb=1.115884 time=67.1s
ttt_chunk [151/947] bpb=1.117421 time=71.9s
ttt_chunk [161/947] bpb=1.121499 time=76.6s
ttt_chunk [171/947] bpb=1.121852 time=81.3s
ttt_chunk [181/947] bpb=1.125681 time=86.1s
ttt_chunk [191/947] bpb=1.126671 time=90.8s
ttt_chunk [201/947] bpb=1.124873 time=95.5s
ttt_chunk [211/947] bpb=1.122965 time=100.3s
ttt_chunk [221/947] bpb=1.122447 time=105.0s
ttt_chunk [231/947] bpb=1.122880 time=109.7s
ttt_chunk [241/947] bpb=1.124736 time=114.5s
ttt_chunk [251/947] bpb=1.123316 time=119.2s
ttt_chunk [261/947] bpb=1.121324 time=123.9s
ttt_chunk [271/947] bpb=1.122645 time=128.7s
ttt_chunk [281/947] bpb=1.120706 time=133.4s
ttt_chunk [291/947] bpb=1.120303 time=138.2s
ttt_chunk [301/947] bpb=1.120288 time=142.9s
ttt_chunk [311/947] bpb=1.121292 time=147.6s
ttt_chunk [321/947] bpb=1.122122 time=152.4s
ttt_chunk [331/947] bpb=1.121531 time=157.1s
ttt_chunk [341/947] bpb=1.122221 time=161.8s
ttt_chunk [351/947] bpb=1.122854 time=166.6s
ttt_chunk [361/947] bpb=1.123091 time=171.3s
ttt_chunk [371/947] bpb=1.123278 time=176.1s
ttt_chunk [381/947] bpb=1.122532 time=180.8s
ttt_chunk [391/947] bpb=1.121911 time=185.6s
ttt_chunk [401/947] bpb=1.121561 time=190.3s
ttt_chunk [411/947] bpb=1.120589 time=195.1s
ttt_chunk [421/947] bpb=1.119598 time=199.8s
ttt_chunk [431/947] bpb=1.120465 time=204.6s
ttt_chunk [441/947] bpb=1.119705 time=209.3s
ttt_chunk [451/947] bpb=1.119311 time=214.1s
ttt_chunk [461/947] bpb=1.119551 time=218.8s
ttt_chunk [471/947] bpb=1.118941 time=223.5s
ttt_chunk [481/947] bpb=1.118406 time=228.3s
ttt_chunk [491/947] bpb=1.119079 time=233.0s
ttt_chunk [501/947] bpb=1.118907 time=237.7s
ttt_chunk [511/947] bpb=1.119250 time=242.5s
ttt_chunk [521/947] bpb=1.119526 time=247.3s
ttt_chunk [531/947] bpb=1.119334 time=252.0s
ttt_chunk [541/947] bpb=1.118905 time=256.7s
ttt_chunk [551/947] bpb=1.119547 time=261.5s
ttt_chunk [561/947] bpb=1.118974 time=266.2s
ttt_chunk [571/947] bpb=1.119412 time=271.0s
ttt_chunk [581/947] bpb=1.118122 time=275.7s
ttt_chunk [591/947] bpb=1.118695 time=280.4s
ttt_chunk [601/947] bpb=1.119685 time=285.2s
ttt_chunk [611/947] bpb=1.118961 time=289.9s
ttt_chunk [621/947] bpb=1.119009 time=294.6s
ttt_chunk [631/947] bpb=1.119142 time=299.3s
ttt_chunk [641/947] bpb=1.118264 time=304.1s
ttt_chunk [651/947] bpb=1.117396 time=308.8s
ttt_chunk [661/947] bpb=1.116765 time=313.6s
ttt_chunk [671/947] bpb=1.116670 time=318.3s
ttt_chunk [681/947] bpb=1.117126 time=323.0s
ttt_chunk [691/947] bpb=1.116652 time=327.8s
ttt_chunk [701/947] bpb=1.115895 time=332.5s
ttt_chunk [711/947] bpb=1.115840 time=337.3s
ttt_chunk [721/947] bpb=1.116325 time=342.0s
ttt_chunk [731/947] bpb=1.115883 time=346.7s
ttt_chunk [741/947] bpb=1.117077 time=351.5s
ttt_chunk [751/947] bpb=1.117106 time=356.2s
ttt_chunk [761/947] bpb=1.117370 time=360.9s
ttt_chunk [771/947] bpb=1.117251 time=365.7s
ttt_chunk [781/947] bpb=1.117942 time=370.4s
ttt_chunk [791/947] bpb=1.118484 time=375.2s
ttt_chunk [801/947] bpb=1.118591 time=379.9s
ttt_chunk [811/947] bpb=1.118499 time=384.6s
ttt_chunk [821/947] bpb=1.118419 time=389.4s
ttt_chunk [831/947] bpb=1.118590 time=394.1s
ttt_chunk [841/947] bpb=1.119483 time=398.8s
ttt_chunk [851/947] bpb=1.119201 time=403.6s
ttt_chunk [861/947] bpb=1.118602 time=408.3s
ttt_chunk [871/947] bpb=1.118322 time=413.1s
ttt_chunk [881/947] bpb=1.118486 time=417.8s
ttt_chunk [891/947] bpb=1.118402 time=422.6s
ttt_chunk [901/947] bpb=1.118221 time=427.3s
ttt_chunk [911/947] bpb=1.117833 time=432.1s
ttt_chunk [921/947] bpb=1.117337 time=436.8s
ttt_chunk [931/947] bpb=1.116661 time=441.6s
ttt_chunk [941/947] bpb=1.116105 time=446.3s
ttt_chunk [947/947] bpb=1.115820 time=448.8s
supermask_ttt:done val_loss=1.879199 val_bpb=1.112970 elapsed=448.8s
legal_ttt val_loss:1.8792 val_bpb:1.1130 eval_time:449323ms
legal_ttt_exact val_loss:1.87919855 val_bpb:1.11297018