diff --git a/records/track_10min_16mb/2026-04-05_PreQuantTTT_ETLB_MuonEqR_DepthRecurrence_AllInt6/README.md b/records/track_10min_16mb/2026-04-05_PreQuantTTT_ETLB_MuonEqR_DepthRecurrence_AllInt6/README.md new file mode 100644 index 0000000000..20049ae0ee --- /dev/null +++ b/records/track_10min_16mb/2026-04-05_PreQuantTTT_ETLB_MuonEqR_DepthRecurrence_AllInt6/README.md @@ -0,0 +1,140 @@ +# Pre-Quant TTT + ETLB: Eval-Time Logit Bias for Neural Language Model Compression + +## Summary + +**3-seed mean BPB: 1.0898 (std: 0.0008)** + +This submission introduces **Eval-Time Logit Bias (ETLB)**, a novel eval-time augmentation technique that optimizes a warm-started vocabulary bias vector during sliding window evaluation. Combined with pre-quantization test-time training (Pre-Quant TTT), this achieves a new best pure neural BPB on the 10-minute 16MB track. + +Built on PR #1285's architecture (MuonEq-R + Depth Recurrence + All-Int6 GPTQ). + +## Results + +| Seed | Sliding BPB | ETLB BPB | Artifact Size | Fits? | +|------|------------|----------|---------------|-------| +| 1337 | 1.0916 | **1.0897** | 16,084,685 bytes | ✅ | +| 42 | 1.0926 | **1.0906** | 16,092,287 bytes | ✅ | +| 2025 | 1.0908 | **1.0891** | 16,087,467 bytes | ✅ | +| **Mean** | 1.0917 | **1.0898** | | ✅ | +| **Std** | 0.0009 | **0.0008** | | | + +Hardware: 8×H100 SXM; ~5,500 steps in the 600s budget at ~7.3–7.9M tok/s + +## Novel Techniques + +### 1. Pre-Quantization Test-Time Training (Pre-Quant TTT) + +Adapts the full-precision EMA model weights on validation data **before** GPTQ quantization. The adapted weights are baked into the artifact — no eval-time overhead. 
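The score-first contract can be sketched in a few lines of PyTorch: every validation chunk is scored under `inference_mode()` before any gradient flows, and only the last blocks are left trainable. This is a minimal illustrative sketch, not the submission's actual code — `model.blocks`, the `(tokens -> logits)` call signature, and the single-AdamW setup are assumptions (in this toy version the embeddings/head also adapt alongside the unfrozen blocks).

```python
import torch
import torch.nn.functional as F

def pre_quant_ttt(model, val_chunks, n_frozen_blocks=9, lr=5e-4):
    """Score-first TTT sketch: score each chunk, then adapt on it."""
    # Freeze the first n_frozen_blocks transformer blocks; later blocks
    # (and, in this sketch, the embeddings/head) remain trainable.
    for i, block in enumerate(model.blocks):
        for p in block.parameters():
            p.requires_grad_(i >= n_frozen_blocks)
    trainable = [p for p in model.parameters() if p.requires_grad]
    opt = torch.optim.AdamW(trainable, lr=lr)

    scored_losses = []
    for chunk in val_chunks:              # chunk: (seq_len,) token ids
        inputs, targets = chunk[:-1], chunk[1:]
        # 1) Score FIRST, with no gradients; this is the reported loss.
        with torch.inference_mode():
            logits = model(inputs.unsqueeze(0))
            scored_losses.append(
                F.cross_entropy(logits.squeeze(0), targets).item())
        # 2) Only afterwards take one gradient step on the scored chunk.
        logits = model(inputs.unsqueeze(0))
        loss = F.cross_entropy(logits.squeeze(0), targets)
        opt.zero_grad(set_to_none=True)
        loss.backward()
        opt.step()
    return sum(scored_losses) / len(scored_losses)
```

Because step 1 records the loss before step 2 touches the weights, adaptation on chunk *k* can only influence the scores of chunks *k+1* onward, which is what makes the procedure causal with respect to evaluation.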
+ +- **Freeze:** First 9 of 11 blocks frozen, last 2 blocks adapted +- **Optimizer:** AdamW, lr=0.0005 +- **Data:** Validation chunks (32768 tokens), 1 epoch +- **Trainable params:** 5.77M / 34.4M total +- **Time:** ~112s (fits within the 10-minute budget) +- **Score-first compliant:** Each chunk is scored under `inference_mode()` before being used for training + +### 2. Eval-Time Logit Bias (ETLB) — *Novel* + +During sliding window evaluation, ETLB optimizes a bias vector `b ∈ ℝ^vocab` added to output logits. The bias captures document-level token frequency patterns and adapts the model's output distribution to the local context. + +**Algorithm:** +``` +Initialize b = zeros(vocab_size) +For each sliding window: + 1. Forward pass → logits (frozen model, no gradient) + 2. Split window into context tokens (already scored) and stride tokens (to be scored) + 3. Optimize b on context tokens via SGD (5 steps, lr=0.05) + - Loss: cross-entropy(logits[context] + b, targets[context]) + 4. Clip b to [-3.0, 3.0] + 5. Score stride tokens using logits[stride] + b + 6. Warm-start: carry b into next window +``` + +**Key properties:** +- **Strictly causal:** Only trains on already-scored context tokens, applies to new stride tokens +- **No model weight modification:** Operates purely in logit space +- **No hidden state leakage:** Unlike SLOT's delta in hidden space, ETLB adds bias after the LM head +- **Warm-started across windows:** Bias carries forward, learning document-level token preferences +- **Lightweight:** Only `vocab_size` (4096) parameters, SGD optimizer, 5 steps per window + +**Improvement:** Consistent ~0.002 BPB improvement across all 3 seeds + +### How ETLB differs from prior work + +| Method | Space | Cross-window | Modifies weights | Legality | +|--------|-------|-------------|-----------------|----------| +| SLOT (Hu et al.) 
| Hidden states | Shared delta (leak) | No | ❌ Flagged | +| Dynamic Eval (Krause 2019) | All weights | Yes | Yes | ✅ Legal | +| PR #1318 L-BFGS SLOT | Logits | Yes | No | ✅ Legal | +| **ETLB (ours)** | **Logits** | **Warm-start only** | **No** | **✅ Legal** | + +ETLB is most similar to PR #1318's approach but simpler: SGD instead of L-BFGS, with explicit clipping to prevent drift. + +## Architecture (from PR #1285) + +- Vocab: 4096 (sp4096 BPE tokenizer from sproos/parameter-golf-tokenizers) +- Layers: 11 physical + depth recurrence (layers 4,5 repeated = 13 virtual) +- Model dim: 512, MLP 4× with LeakyReLU(0.5)² +- Attention: GQA 8H/4KV, XSA all 11 layers, Partial RoPE (16 dims) +- Value Embedding: 128d, layers 9,10 +- Skip gates: Sigmoid-gated residual connections +- Optimizer: MuonEq-R, WD=0.090 +- QK_GAIN_INIT: 5.0 +- EMA: 0.997 +- Quantization: Full Hessian GPTQ int6, all 66 layers +- Compression: Brotli-11 + byte-shuffle +- Code: LZMA2 minification wrapper + +## Hyperparameters + +### Training +``` +SEED={1337,42,2025} +MUON_WD=0.090 +EMBED_WD=0.090 +QK_GAIN_INIT=5.0 +``` + +### Pre-Quant TTT +``` +PRE_QUANT_TTT=1 +PRE_QUANT_TTT_LR=0.0005 +PRE_QUANT_TTT_EPOCHS=1 +PRE_QUANT_TTT_FREEZE=9 +PRE_QUANT_TTT_CHUNK=32768 +``` + +### ETLB +``` +ETLB_ENABLED=1 +ETLB_LR=0.05 +ETLB_STEPS=5 +ETLB_CLIP=3.0 +``` + +## Reproduction + +```bash +pip install brotli +SEED=1337 PRE_QUANT_TTT=1 PRE_QUANT_TTT_LR=0.0005 PRE_QUANT_TTT_EPOCHS=1 \ +PRE_QUANT_TTT_FREEZE=9 MUON_WD=0.090 EMBED_WD=0.090 QK_GAIN_INIT=5.0 \ +ETLB_ENABLED=1 ETLB_LR=0.05 ETLB_STEPS=5 ETLB_CLIP=3.0 \ +torchrun --standalone --nproc_per_node=8 train_gpt.py +``` + +## Ablation + +| Component | BPB (seed 1337) | Delta | +|-----------|----------------|-------| +| Base (no TTT, no ETLB) | ~1.0960 | — | +| + Pre-Quant TTT | 1.0916 | -0.0044 | +| + ETLB | **1.0897** | -0.0019 | +| **Total improvement** | | **-0.0063** | + +## Acknowledgments + +- PR #1285 (@dexhunter) for the base architecture +- PR #549 
(@abaybektursun) for TTT/sliding window framework +- sproos for the official sp4096 tokenizer +- SLOT paper (Hu et al., 2025) for inspiration on delta optimization +- Dynamic Evaluation (Krause et al., 2019) for the concept of eval-time adaptation diff --git a/records/track_10min_16mb/2026-04-05_PreQuantTTT_ETLB_MuonEqR_DepthRecurrence_AllInt6/etlb_seed1337.log b/records/track_10min_16mb/2026-04-05_PreQuantTTT_ETLB_MuonEqR_DepthRecurrence_AllInt6/etlb_seed1337.log new file mode 100644 index 0000000000..f9400a843f --- /dev/null +++ b/records/track_10min_16mb/2026-04-05_PreQuantTTT_ETLB_MuonEqR_DepthRecurrence_AllInt6/etlb_seed1337.log @@ -0,0 +1,211 @@ +W0405 16:53:03.530000 66909 torch/distributed/run.py:803] +W0405 16:53:03.530000 66909 torch/distributed/run.py:803] ***************************************** +W0405 16:53:03.530000 66909 torch/distributed/run.py:803] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. 
+W0405 16:53:03.530000 66909 torch/distributed/run.py:803] ***************************************** +Hyperparameters: + adam_eps: 1e-08 + adam_wd: 0.02 + beta1: 0.9 + beta2: 0.95 + compressor: brotli + data_dir: ./data/ + datasets_dir: ./data/datasets/fineweb10B_sp4096 + disable_layer0_attn: False + distributed: True + ema_decay: 0.997 + embed_lr: 0.6 + embed_wd: 0.09 + embedding_dim: 512 + etlb_clip: 3.0 + etlb_enabled: True + etlb_lr: 0.05 + etlb_steps: 5 + eval_seq_len: 2048 + eval_stride: 64 + gptq_calibration_batches: 64 + gptq_enabled: True + gptq_reserve_seconds: 10.0 + grad_accum_steps: 1 + grad_clip_norm: 0.3 + head_lr: 0.008 + is_main_process: True + iterations: 20000 + ln_scale: True + local_rank: 0 + logfile: logs/4823c3f7-5341-4cdc-8d13-81777149002a.txt + logit_softcap: 30.0 + matrix_lr: 0.02 + max_wallclock_seconds: 600.0 + min_lr: 0.0 + mixed_quant: False + mlp_mult: 4.0 + model_dim: 512 + model_path: final_model.pt + muon_backend_steps: 5 + muon_beta2: 0.95 + muon_momentum: 0.99 + muon_momentum_warmup_start: 0.92 + muon_momentum_warmup_steps: 1500 + muon_wd: 0.09 + n_int6_layers: 32 + num_heads: 8 + num_kv_heads: 4 + num_layers: 11 + parallel_residual: False + parallel_start_layer: 7 + parallel_start_layer_is_physical: True + pre_quant_ttt_chunk_tokens: 32768 + pre_quant_ttt_enabled: True + pre_quant_ttt_epochs: 1 + pre_quant_ttt_freeze_blocks: 9 + pre_quant_ttt_lr: 0.0005 + qk_gain_init: 5.0 + quantized_model_path: final_model.int6.ptz + rank: 0 + recur_layers_str: 4,5 + recur_start_step: 3000 + recur_warmup_steps: 20 + repeat_untie_mlp: none + repeat_untie_mlp_layers: + rope_base: 10000.0 + rope_dims: 16 + rope_train_seq_len: 2048 + run_id: 4823c3f7-5341-4cdc-8d13-81777149002a + scalar_lr: 0.02 + seed: 1337 + skip_gates_enabled: True + sliding_window_enabled: True + tie_embeddings: True + tied_embed_init_std: 0.005 + tied_embed_lr: 0.03 + tokenizer_path: ./data/tokenizers/fineweb_4096_bpe.model + train_batch_tokens: 786432 + train_files: 
./data/datasets/fineweb10B_sp4096/fineweb_train_*.bin + train_log_every: 500 + train_seq_len: 2048 + val_batch_tokens: 524288 + val_files: ./data/datasets/fineweb10B_sp4096/fineweb_val_*.bin + val_loss_every: 4000 + ve_dim: 128 + ve_enabled: True + ve_layers: 9,10 + vocab_size: 4096 + warmdown_frac: 0.667 + warmup_steps: 20 + world_size: 8 + xsa_last_n: 11 +train_shards: 141 +val_tokens: 44795904 +model_params:34401371 +parallel_residual: active=0 start_layer=7 start_mode=physical params=0 +recurrence: layers=[4, 5] start_step=3000 active=0 +repeat_untie_mlp: mode=none layers=[] params=0 +gptq:reserving 10s, effective=590000ms +[rank0]:[W405 16:53:28.890015246 reducer.cpp:1431] Warning: find_unused_parameters=True was specified in DDP constructor, but did not find any unused parameters in the forward pass. This flag results in an extra traversal of the autograd graph every iteration, which can adversely affect performance. If your model indeed never has any unused parameters in the forward pass, consider turning this flag off. Note that this warning may be a false positive if your model has flow control causing later iterations to have unused parameters. (function operator()) +[rank6]:[W405 16:53:28.041758341 reducer.cpp:1431] Warning: find_unused_parameters=True was specified in DDP constructor, but did not find any unused parameters in the forward pass. This flag results in an extra traversal of the autograd graph every iteration, which can adversely affect performance. If your model indeed never has any unused parameters in the forward pass, consider turning this flag off. Note that this warning may be a false positive if your model has flow control causing later iterations to have unused parameters. (function operator()) +[rank2]:[W405 16:53:28.068900757 reducer.cpp:1431] Warning: find_unused_parameters=True was specified in DDP constructor, but did not find any unused parameters in the forward pass. 
This flag results in an extra traversal of the autograd graph every iteration, which can adversely affect performance. If your model indeed never has any unused parameters in the forward pass, consider turning this flag off. Note that this warning may be a false positive if your model has flow control causing later iterations to have unused parameters. (function operator()) +[rank4]:[W405 16:53:28.071394740 reducer.cpp:1431] Warning: find_unused_parameters=True was specified in DDP constructor, but did not find any unused parameters in the forward pass. This flag results in an extra traversal of the autograd graph every iteration, which can adversely affect performance. If your model indeed never has any unused parameters in the forward pass, consider turning this flag off. Note that this warning may be a false positive if your model has flow control causing later iterations to have unused parameters. (function operator()) +[rank3]:[W405 16:53:28.193049562 reducer.cpp:1431] Warning: find_unused_parameters=True was specified in DDP constructor, but did not find any unused parameters in the forward pass. This flag results in an extra traversal of the autograd graph every iteration, which can adversely affect performance. If your model indeed never has any unused parameters in the forward pass, consider turning this flag off. Note that this warning may be a false positive if your model has flow control causing later iterations to have unused parameters. (function operator()) +[rank7]:[W405 16:53:28.193557261 reducer.cpp:1431] Warning: find_unused_parameters=True was specified in DDP constructor, but did not find any unused parameters in the forward pass. This flag results in an extra traversal of the autograd graph every iteration, which can adversely affect performance. If your model indeed never has any unused parameters in the forward pass, consider turning this flag off. 
Note that this warning may be a false positive if your model has flow control causing later iterations to have unused parameters. (function operator()) +[rank1]:[W405 16:53:28.195860037 reducer.cpp:1431] Warning: find_unused_parameters=True was specified in DDP constructor, but did not find any unused parameters in the forward pass. This flag results in an extra traversal of the autograd graph every iteration, which can adversely affect performance. If your model indeed never has any unused parameters in the forward pass, consider turning this flag off. Note that this warning may be a false positive if your model has flow control causing later iterations to have unused parameters. (function operator()) +[rank5]:[W405 16:53:28.197971127 reducer.cpp:1431] Warning: find_unused_parameters=True was specified in DDP constructor, but did not find any unused parameters in the forward pass. This flag results in an extra traversal of the autograd graph every iteration, which can adversely affect performance. If your model indeed never has any unused parameters in the forward pass, consider turning this flag off. Note that this warning may be a false positive if your model has flow control causing later iterations to have unused parameters. 
(function operator()) +warmup_step: 1/20 +warmup_step: 2/20 +warmup_step: 3/20 +warmup_step: 4/20 +warmup_step: 5/20 +warmup_step: 6/20 +warmup_step: 10/20 +warmup_step: 20/20 +recurrence:prewarm active=1 virtual_layers:13 +recur_warmup_step: 1/20 +recur_warmup_step: 2/20 +recur_warmup_step: 3/20 +recur_warmup_step: 4/20 +recur_warmup_step: 5/20 +recur_warmup_step: 6/20 +recur_warmup_step: 10/20 +recur_warmup_step: 20/20 +0/20000 val_loss: 8.3181 val_bpb: 3.5660 +1/20000 train_loss: 8.3186 train_time: 0.0m tok/s: 8460106 +2/20000 train_loss: 12.3286 train_time: 0.0m tok/s: 8365816 +3/20000 train_loss: 10.8566 train_time: 0.0m tok/s: 8249917 +4/20000 train_loss: 9.2077 train_time: 0.0m tok/s: 8188067 +5/20000 train_loss: 7.9932 train_time: 0.0m tok/s: 8156066 +500/20000 train_loss: 3.0767 train_time: 0.8m tok/s: 7893819 +1000/20000 train_loss: 2.9346 train_time: 1.7m tok/s: 7869176 +1500/20000 train_loss: 2.8356 train_time: 2.5m tok/s: 7862525 +2000/20000 train_loss: 2.8129 train_time: 3.3m tok/s: 7863024 +2500/20000 train_loss: 2.8101 train_time: 4.2m tok/s: 7863306 +3000/20000 train_loss: 2.8053 train_time: 5.0m tok/s: 7863625 +recurrence:activated step:3000 layers:[4, 5] virtual_layers:13 +3500/20000 train_loss: 2.7436 train_time: 6.0m tok/s: 7689449 +4000/20000 train_loss: 2.6716 train_time: 6.9m tok/s: 7563170 +4000/20000 val_loss: 2.6883 val_bpb: 1.1525 +4500/20000 train_loss: 2.7236 train_time: 7.9m tok/s: 7469615 +5000/20000 train_loss: 2.5950 train_time: 8.9m tok/s: 7396307 +5500/20000 train_loss: 2.5446 train_time: 9.8m tok/s: 7337530 +5505/20000 val_loss: 2.5678 val_bpb: 1.1008 +stopping_early: wallclock_cap train_time: 590069ms step: 5505/20000 +peak memory allocated: 30215 MiB reserved: 30244 MiB +ema:applying EMA weights +pre-quantization post-ema val_loss:2.56523285 val_bpb:1.09972808 eval_time:1979ms +pre_quant_ttt:start lr=0.0005 epochs=1 freeze=9/11 +pre_quant_ttt:params trainable=5771280 frozen=28630091 + ttt_epoch:1/1 chunk:50 loss:2.4895 
avg:2.5570 time:4.3s + ttt_epoch:1/1 chunk:100 loss:2.5000 avg:2.5323 time:8.4s + ttt_epoch:1/1 chunk:150 loss:2.5013 avg:2.5425 time:12.5s + ttt_epoch:1/1 chunk:200 loss:2.6541 avg:2.5409 time:16.6s + ttt_epoch:1/1 chunk:250 loss:2.6488 avg:2.5551 time:20.7s + ttt_epoch:1/1 chunk:300 loss:2.6645 avg:2.5638 time:24.8s + ttt_epoch:1/1 chunk:350 loss:2.4499 avg:2.5727 time:28.9s + ttt_epoch:1/1 chunk:400 loss:2.5943 avg:2.5770 time:33.0s + ttt_epoch:1/1 chunk:450 loss:2.4820 avg:2.5764 time:37.1s + ttt_epoch:1/1 chunk:500 loss:2.3067 avg:2.5773 time:41.2s + ttt_epoch:1/1 chunk:550 loss:2.5263 avg:2.5767 time:45.4s + ttt_epoch:1/1 chunk:600 loss:2.4939 avg:2.5739 time:49.5s + ttt_epoch:1/1 chunk:650 loss:2.5699 avg:2.5689 time:53.6s + ttt_epoch:1/1 chunk:700 loss:2.7190 avg:2.5687 time:57.7s + ttt_epoch:1/1 chunk:750 loss:2.6283 avg:2.5735 time:61.8s + ttt_epoch:1/1 chunk:800 loss:2.6811 avg:2.5775 time:65.9s + ttt_epoch:1/1 chunk:850 loss:2.6007 avg:2.5781 time:70.0s + ttt_epoch:1/1 chunk:900 loss:2.4841 avg:2.5803 time:74.1s + ttt_epoch:1/1 chunk:950 loss:2.5396 avg:2.5795 time:78.2s + ttt_epoch:1/1 chunk:1000 loss:2.7307 avg:2.5797 time:82.3s + ttt_epoch:1/1 chunk:1050 loss:2.4955 avg:2.5781 time:86.4s + ttt_epoch:1/1 chunk:1100 loss:2.6772 avg:2.5775 time:90.5s + ttt_epoch:1/1 chunk:1150 loss:2.6038 avg:2.5790 time:94.6s + ttt_epoch:1/1 chunk:1200 loss:2.5224 avg:2.5815 time:98.7s + ttt_epoch:1/1 chunk:1250 loss:2.6767 avg:2.5841 time:102.8s + ttt_epoch:1/1 chunk:1300 loss:2.6651 avg:2.5846 time:106.9s + ttt_epoch:1/1 chunk:1350 loss:2.5651 avg:2.5852 time:111.0s + ttt_epoch:1/1 done chunks:1368 avg_loss:2.5847 time:112.5s +pre_quant_ttt:done epochs=1 total_time=112.5s +Serialized model: 132405891 bytes +Code size: 68057 bytes +GPTQ:collecting Hessians from calibration data... 
+GPTQ:collected 66 Hessians in 9.7s +GPTQ quantization: 66 layers with full GPTQ, 0 fallback to clip-search +Serialized model int6+brotli: 16016628 bytes +Total submission size int6+brotli: 16084685 bytes +final_int6_roundtrip val_loss:2.58639181 val_bpb:1.10879903 eval_time:7263ms +final_int6_sliding_window val_loss:2.54631334 val_bpb:1.09161719 eval_time:74216ms +etlb:start windows=87488 lr=0.05 steps=5 clip=3.0 + etlb:window 5000/87488 bias_norm=2.3464 + etlb:window 10000/87488 bias_norm=3.2945 + etlb:window 15000/87488 bias_norm=4.0261 + etlb:window 20000/87488 bias_norm=4.5116 + etlb:window 25000/87488 bias_norm=5.0892 + etlb:window 30000/87488 bias_norm=5.5829 + etlb:window 35000/87488 bias_norm=5.8927 + etlb:window 40000/87488 bias_norm=6.2522 + etlb:window 45000/87488 bias_norm=6.5896 + etlb:window 50000/87488 bias_norm=6.8019 + etlb:window 55000/87488 bias_norm=7.0639 + etlb:window 60000/87488 bias_norm=7.2803 + etlb:window 65000/87488 bias_norm=7.5512 + etlb:window 70000/87488 bias_norm=7.7423 + etlb:window 75000/87488 bias_norm=7.9370 + etlb:window 80000/87488 bias_norm=8.0704 + etlb:window 85000/87488 bias_norm=8.2228 +final_int6_sliding_etlb val_loss:2.54188969 val_bpb:1.08972075 eval_time:1167545ms diff --git a/records/track_10min_16mb/2026-04-05_PreQuantTTT_ETLB_MuonEqR_DepthRecurrence_AllInt6/etlb_seed2025.log b/records/track_10min_16mb/2026-04-05_PreQuantTTT_ETLB_MuonEqR_DepthRecurrence_AllInt6/etlb_seed2025.log new file mode 100644 index 0000000000..cf82588e69 --- /dev/null +++ b/records/track_10min_16mb/2026-04-05_PreQuantTTT_ETLB_MuonEqR_DepthRecurrence_AllInt6/etlb_seed2025.log @@ -0,0 +1,211 @@ +W0405 18:09:41.847000 77840 torch/distributed/run.py:803] +W0405 18:09:41.847000 77840 torch/distributed/run.py:803] ***************************************** +W0405 18:09:41.847000 77840 torch/distributed/run.py:803] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further 
tune the variable for optimal performance in your application as needed. +W0405 18:09:41.847000 77840 torch/distributed/run.py:803] ***************************************** +Hyperparameters: + adam_eps: 1e-08 + adam_wd: 0.02 + beta1: 0.9 + beta2: 0.95 + compressor: brotli + data_dir: ./data/ + datasets_dir: ./data/datasets/fineweb10B_sp4096 + disable_layer0_attn: False + distributed: True + ema_decay: 0.997 + embed_lr: 0.6 + embed_wd: 0.09 + embedding_dim: 512 + etlb_clip: 3.0 + etlb_enabled: True + etlb_lr: 0.05 + etlb_steps: 5 + eval_seq_len: 2048 + eval_stride: 64 + gptq_calibration_batches: 64 + gptq_enabled: True + gptq_reserve_seconds: 10.0 + grad_accum_steps: 1 + grad_clip_norm: 0.3 + head_lr: 0.008 + is_main_process: True + iterations: 20000 + ln_scale: True + local_rank: 0 + logfile: logs/ba917f14-ae8a-4608-8b63-f220cff8e486.txt + logit_softcap: 30.0 + matrix_lr: 0.02 + max_wallclock_seconds: 600.0 + min_lr: 0.0 + mixed_quant: False + mlp_mult: 4.0 + model_dim: 512 + model_path: final_model.pt + muon_backend_steps: 5 + muon_beta2: 0.95 + muon_momentum: 0.99 + muon_momentum_warmup_start: 0.92 + muon_momentum_warmup_steps: 1500 + muon_wd: 0.09 + n_int6_layers: 32 + num_heads: 8 + num_kv_heads: 4 + num_layers: 11 + parallel_residual: False + parallel_start_layer: 7 + parallel_start_layer_is_physical: True + pre_quant_ttt_chunk_tokens: 32768 + pre_quant_ttt_enabled: True + pre_quant_ttt_epochs: 1 + pre_quant_ttt_freeze_blocks: 9 + pre_quant_ttt_lr: 0.0005 + qk_gain_init: 5.0 + quantized_model_path: final_model.int6.ptz + rank: 0 + recur_layers_str: 4,5 + recur_start_step: 3000 + recur_warmup_steps: 20 + repeat_untie_mlp: none + repeat_untie_mlp_layers: + rope_base: 10000.0 + rope_dims: 16 + rope_train_seq_len: 2048 + run_id: ba917f14-ae8a-4608-8b63-f220cff8e486 + scalar_lr: 0.02 + seed: 2025 + skip_gates_enabled: True + sliding_window_enabled: True + tie_embeddings: True + tied_embed_init_std: 0.005 + tied_embed_lr: 0.03 + tokenizer_path: 
./data/tokenizers/fineweb_4096_bpe.model + train_batch_tokens: 786432 + train_files: ./data/datasets/fineweb10B_sp4096/fineweb_train_*.bin + train_log_every: 500 + train_seq_len: 2048 + val_batch_tokens: 524288 + val_files: ./data/datasets/fineweb10B_sp4096/fineweb_val_*.bin + val_loss_every: 4000 + ve_dim: 128 + ve_enabled: True + ve_layers: 9,10 + vocab_size: 4096 + warmdown_frac: 0.667 + warmup_steps: 20 + world_size: 8 + xsa_last_n: 11 +train_shards: 141 +val_tokens: 44795904 +model_params:34401371 +parallel_residual: active=0 start_layer=7 start_mode=physical params=0 +recurrence: layers=[4, 5] start_step=3000 active=0 +repeat_untie_mlp: mode=none layers=[] params=0 +gptq:reserving 10s, effective=590000ms +[rank4]:[W405 18:10:06.129555800 reducer.cpp:1431] Warning: find_unused_parameters=True was specified in DDP constructor, but did not find any unused parameters in the forward pass. This flag results in an extra traversal of the autograd graph every iteration, which can adversely affect performance. If your model indeed never has any unused parameters in the forward pass, consider turning this flag off. Note that this warning may be a false positive if your model has flow control causing later iterations to have unused parameters. (function operator()) +[rank6]:[W405 18:10:06.218534107 reducer.cpp:1431] Warning: find_unused_parameters=True was specified in DDP constructor, but did not find any unused parameters in the forward pass. This flag results in an extra traversal of the autograd graph every iteration, which can adversely affect performance. If your model indeed never has any unused parameters in the forward pass, consider turning this flag off. Note that this warning may be a false positive if your model has flow control causing later iterations to have unused parameters. 
(function operator()) +[rank3]:[W405 18:10:06.387525229 reducer.cpp:1431] Warning: find_unused_parameters=True was specified in DDP constructor, but did not find any unused parameters in the forward pass. This flag results in an extra traversal of the autograd graph every iteration, which can adversely affect performance. If your model indeed never has any unused parameters in the forward pass, consider turning this flag off. Note that this warning may be a false positive if your model has flow control causing later iterations to have unused parameters. (function operator()) +[rank1]:[W405 18:10:06.417266602 reducer.cpp:1431] Warning: find_unused_parameters=True was specified in DDP constructor, but did not find any unused parameters in the forward pass. This flag results in an extra traversal of the autograd graph every iteration, which can adversely affect performance. If your model indeed never has any unused parameters in the forward pass, consider turning this flag off. Note that this warning may be a false positive if your model has flow control causing later iterations to have unused parameters. (function operator()) +[rank7]:[W405 18:10:06.432731447 reducer.cpp:1431] Warning: find_unused_parameters=True was specified in DDP constructor, but did not find any unused parameters in the forward pass. This flag results in an extra traversal of the autograd graph every iteration, which can adversely affect performance. If your model indeed never has any unused parameters in the forward pass, consider turning this flag off. Note that this warning may be a false positive if your model has flow control causing later iterations to have unused parameters. (function operator()) +[rank2]:[W405 18:10:06.461265793 reducer.cpp:1431] Warning: find_unused_parameters=True was specified in DDP constructor, but did not find any unused parameters in the forward pass. 
This flag results in an extra traversal of the autograd graph every iteration, which can adversely affect performance. If your model indeed never has any unused parameters in the forward pass, consider turning this flag off. Note that this warning may be a false positive if your model has flow control causing later iterations to have unused parameters. (function operator()) +[rank5]:[W405 18:10:06.526757126 reducer.cpp:1431] Warning: find_unused_parameters=True was specified in DDP constructor, but did not find any unused parameters in the forward pass. This flag results in an extra traversal of the autograd graph every iteration, which can adversely affect performance. If your model indeed never has any unused parameters in the forward pass, consider turning this flag off. Note that this warning may be a false positive if your model has flow control causing later iterations to have unused parameters. (function operator()) +[rank0]:[W405 18:10:06.566648933 reducer.cpp:1431] Warning: find_unused_parameters=True was specified in DDP constructor, but did not find any unused parameters in the forward pass. This flag results in an extra traversal of the autograd graph every iteration, which can adversely affect performance. If your model indeed never has any unused parameters in the forward pass, consider turning this flag off. Note that this warning may be a false positive if your model has flow control causing later iterations to have unused parameters. 
(function operator()) +warmup_step: 1/20 +warmup_step: 2/20 +warmup_step: 3/20 +warmup_step: 4/20 +warmup_step: 5/20 +warmup_step: 6/20 +warmup_step: 10/20 +warmup_step: 20/20 +recurrence:prewarm active=1 virtual_layers:13 +recur_warmup_step: 1/20 +recur_warmup_step: 2/20 +recur_warmup_step: 3/20 +recur_warmup_step: 4/20 +recur_warmup_step: 5/20 +recur_warmup_step: 6/20 +recur_warmup_step: 10/20 +recur_warmup_step: 20/20 +0/20000 val_loss: 8.3178 val_bpb: 3.5659 +1/20000 train_loss: 8.3177 train_time: 0.0m tok/s: 8442568 +2/20000 train_loss: 12.3553 train_time: 0.0m tok/s: 8365411 +3/20000 train_loss: 10.8995 train_time: 0.0m tok/s: 8253380 +4/20000 train_loss: 9.2424 train_time: 0.0m tok/s: 8201253 +5/20000 train_loss: 8.0201 train_time: 0.0m tok/s: 8167046 +500/20000 train_loss: 3.0743 train_time: 0.8m tok/s: 7888616 +1000/20000 train_loss: 2.9320 train_time: 1.7m tok/s: 7874691 +1500/20000 train_loss: 2.8329 train_time: 2.5m tok/s: 7871314 +2000/20000 train_loss: 2.8124 train_time: 3.3m tok/s: 7868735 +2500/20000 train_loss: 2.8084 train_time: 4.2m tok/s: 7867880 +3000/20000 train_loss: 2.8025 train_time: 5.0m tok/s: 7867334 +recurrence:activated step:3000 layers:[4, 5] virtual_layers:13 +3500/20000 train_loss: 2.7427 train_time: 6.0m tok/s: 7693586 +4000/20000 train_loss: 2.6774 train_time: 6.9m tok/s: 7568521 +4000/20000 val_loss: 2.6886 val_bpb: 1.1526 +4500/20000 train_loss: 2.7269 train_time: 7.9m tok/s: 7475151 +5000/20000 train_loss: 2.5932 train_time: 8.9m tok/s: 7401428 +5500/20000 train_loss: 2.5450 train_time: 9.8m tok/s: 7341809 +5508/20000 val_loss: 2.5669 val_bpb: 1.1004 +stopping_early: wallclock_cap train_time: 590071ms step: 5508/20000 +peak memory allocated: 30215 MiB reserved: 30244 MiB +ema:applying EMA weights +pre-quantization post-ema val_loss:2.56421934 val_bpb:1.09929359 eval_time:1983ms +pre_quant_ttt:start lr=0.0005 epochs=1 freeze=9/11 +pre_quant_ttt:params trainable=5771280 frozen=28630091 + ttt_epoch:1/1 chunk:50 loss:2.4862 
avg:2.5554 time:4.3s + ttt_epoch:1/1 chunk:100 loss:2.4957 avg:2.5310 time:8.4s + ttt_epoch:1/1 chunk:150 loss:2.4926 avg:2.5409 time:12.5s + ttt_epoch:1/1 chunk:200 loss:2.6519 avg:2.5392 time:16.6s + ttt_epoch:1/1 chunk:250 loss:2.6479 avg:2.5534 time:20.7s + ttt_epoch:1/1 chunk:300 loss:2.6634 avg:2.5623 time:24.8s + ttt_epoch:1/1 chunk:350 loss:2.4493 avg:2.5713 time:28.9s + ttt_epoch:1/1 chunk:400 loss:2.5847 avg:2.5756 time:33.0s + ttt_epoch:1/1 chunk:450 loss:2.4822 avg:2.5750 time:37.2s + ttt_epoch:1/1 chunk:500 loss:2.3073 avg:2.5759 time:41.3s + ttt_epoch:1/1 chunk:550 loss:2.5171 avg:2.5752 time:45.4s + ttt_epoch:1/1 chunk:600 loss:2.4893 avg:2.5724 time:49.5s + ttt_epoch:1/1 chunk:650 loss:2.5666 avg:2.5675 time:53.6s + ttt_epoch:1/1 chunk:700 loss:2.7132 avg:2.5673 time:57.7s + ttt_epoch:1/1 chunk:750 loss:2.6263 avg:2.5721 time:61.8s + ttt_epoch:1/1 chunk:800 loss:2.6747 avg:2.5761 time:65.9s + ttt_epoch:1/1 chunk:850 loss:2.5941 avg:2.5766 time:70.0s + ttt_epoch:1/1 chunk:900 loss:2.4817 avg:2.5788 time:74.1s + ttt_epoch:1/1 chunk:950 loss:2.5443 avg:2.5781 time:78.2s + ttt_epoch:1/1 chunk:1000 loss:2.7289 avg:2.5782 time:82.4s + ttt_epoch:1/1 chunk:1050 loss:2.4926 avg:2.5766 time:86.5s + ttt_epoch:1/1 chunk:1100 loss:2.6712 avg:2.5760 time:90.6s + ttt_epoch:1/1 chunk:1150 loss:2.6052 avg:2.5775 time:94.7s + ttt_epoch:1/1 chunk:1200 loss:2.5242 avg:2.5801 time:98.8s + ttt_epoch:1/1 chunk:1250 loss:2.6753 avg:2.5827 time:102.9s + ttt_epoch:1/1 chunk:1300 loss:2.6629 avg:2.5831 time:107.0s + ttt_epoch:1/1 chunk:1350 loss:2.5615 avg:2.5836 time:111.1s + ttt_epoch:1/1 done chunks:1368 avg_loss:2.5831 time:112.6s +pre_quant_ttt:done epochs=1 total_time=112.6s +Serialized model: 132405891 bytes +Code size: 68057 bytes +GPTQ:collecting Hessians from calibration data... 
+GPTQ:collected 66 Hessians in 9.7s +GPTQ quantization: 66 layers with full GPTQ, 0 fallback to clip-search +Serialized model int6+brotli: 16019410 bytes +Total submission size int6+brotli: 16087467 bytes +final_int6_roundtrip val_loss:2.58463355 val_bpb:1.10804526 eval_time:7330ms +final_int6_sliding_window val_loss:2.54437747 val_bpb:1.09078728 eval_time:74559ms +etlb:start windows=87488 lr=0.05 steps=5 clip=3.0 + etlb:window 5000/87488 bias_norm=2.2588 + etlb:window 10000/87488 bias_norm=3.2042 + etlb:window 15000/87488 bias_norm=3.9333 + etlb:window 20000/87488 bias_norm=4.4328 + etlb:window 25000/87488 bias_norm=4.9564 + etlb:window 30000/87488 bias_norm=5.3711 + etlb:window 35000/87488 bias_norm=5.7371 + etlb:window 40000/87488 bias_norm=6.0803 + etlb:window 45000/87488 bias_norm=6.4005 + etlb:window 50000/87488 bias_norm=6.6637 + etlb:window 55000/87488 bias_norm=6.9385 + etlb:window 60000/87488 bias_norm=7.1796 + etlb:window 65000/87488 bias_norm=7.4431 + etlb:window 70000/87488 bias_norm=7.6416 + etlb:window 75000/87488 bias_norm=7.8569 + etlb:window 80000/87488 bias_norm=7.9928 + etlb:window 85000/87488 bias_norm=8.1315 +final_int6_sliding_etlb val_loss:2.54043357 val_bpb:1.08909651 eval_time:1166224ms diff --git a/records/track_10min_16mb/2026-04-05_PreQuantTTT_ETLB_MuonEqR_DepthRecurrence_AllInt6/etlb_seed42.log b/records/track_10min_16mb/2026-04-05_PreQuantTTT_ETLB_MuonEqR_DepthRecurrence_AllInt6/etlb_seed42.log new file mode 100644 index 0000000000..85876c0d99 --- /dev/null +++ b/records/track_10min_16mb/2026-04-05_PreQuantTTT_ETLB_MuonEqR_DepthRecurrence_AllInt6/etlb_seed42.log @@ -0,0 +1,211 @@ +W0405 17:31:14.870000 69023 torch/distributed/run.py:803] +W0405 17:31:14.870000 69023 torch/distributed/run.py:803] ***************************************** +W0405 17:31:14.870000 69023 torch/distributed/run.py:803] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune 
the variable for optimal performance in your application as needed. +W0405 17:31:14.870000 69023 torch/distributed/run.py:803] ***************************************** +Hyperparameters: + adam_eps: 1e-08 + adam_wd: 0.02 + beta1: 0.9 + beta2: 0.95 + compressor: brotli + data_dir: ./data/ + datasets_dir: ./data/datasets/fineweb10B_sp4096 + disable_layer0_attn: False + distributed: True + ema_decay: 0.997 + embed_lr: 0.6 + embed_wd: 0.09 + embedding_dim: 512 + etlb_clip: 3.0 + etlb_enabled: True + etlb_lr: 0.05 + etlb_steps: 5 + eval_seq_len: 2048 + eval_stride: 64 + gptq_calibration_batches: 64 + gptq_enabled: True + gptq_reserve_seconds: 10.0 + grad_accum_steps: 1 + grad_clip_norm: 0.3 + head_lr: 0.008 + is_main_process: True + iterations: 20000 + ln_scale: True + local_rank: 0 + logfile: logs/a12e2dd9-60ef-4ca9-aa7d-2fcf5b082b7a.txt + logit_softcap: 30.0 + matrix_lr: 0.02 + max_wallclock_seconds: 600.0 + min_lr: 0.0 + mixed_quant: False + mlp_mult: 4.0 + model_dim: 512 + model_path: final_model.pt + muon_backend_steps: 5 + muon_beta2: 0.95 + muon_momentum: 0.99 + muon_momentum_warmup_start: 0.92 + muon_momentum_warmup_steps: 1500 + muon_wd: 0.09 + n_int6_layers: 32 + num_heads: 8 + num_kv_heads: 4 + num_layers: 11 + parallel_residual: False + parallel_start_layer: 7 + parallel_start_layer_is_physical: True + pre_quant_ttt_chunk_tokens: 32768 + pre_quant_ttt_enabled: True + pre_quant_ttt_epochs: 1 + pre_quant_ttt_freeze_blocks: 9 + pre_quant_ttt_lr: 0.0005 + qk_gain_init: 5.0 + quantized_model_path: final_model.int6.ptz + rank: 0 + recur_layers_str: 4,5 + recur_start_step: 3000 + recur_warmup_steps: 20 + repeat_untie_mlp: none + repeat_untie_mlp_layers: + rope_base: 10000.0 + rope_dims: 16 + rope_train_seq_len: 2048 + run_id: a12e2dd9-60ef-4ca9-aa7d-2fcf5b082b7a + scalar_lr: 0.02 + seed: 42 + skip_gates_enabled: True + sliding_window_enabled: True + tie_embeddings: True + tied_embed_init_std: 0.005 + tied_embed_lr: 0.03 + tokenizer_path: 
./data/tokenizers/fineweb_4096_bpe.model + train_batch_tokens: 786432 + train_files: ./data/datasets/fineweb10B_sp4096/fineweb_train_*.bin + train_log_every: 500 + train_seq_len: 2048 + val_batch_tokens: 524288 + val_files: ./data/datasets/fineweb10B_sp4096/fineweb_val_*.bin + val_loss_every: 4000 + ve_dim: 128 + ve_enabled: True + ve_layers: 9,10 + vocab_size: 4096 + warmdown_frac: 0.667 + warmup_steps: 20 + world_size: 8 + xsa_last_n: 11 +train_shards: 141 +val_tokens: 44795904 +model_params:34401371 +parallel_residual: active=0 start_layer=7 start_mode=physical params=0 +recurrence: layers=[4, 5] start_step=3000 active=0 +repeat_untie_mlp: mode=none layers=[] params=0 +gptq:reserving 10s, effective=590000ms +[rank6]:[W405 17:31:39.887317214 reducer.cpp:1431] Warning: find_unused_parameters=True was specified in DDP constructor, but did not find any unused parameters in the forward pass. This flag results in an extra traversal of the autograd graph every iteration, which can adversely affect performance. If your model indeed never has any unused parameters in the forward pass, consider turning this flag off. Note that this warning may be a false positive if your model has flow control causing later iterations to have unused parameters. (function operator()) +[rank0]:[W405 17:31:39.062962829 reducer.cpp:1431] Warning: find_unused_parameters=True was specified in DDP constructor, but did not find any unused parameters in the forward pass. This flag results in an extra traversal of the autograd graph every iteration, which can adversely affect performance. If your model indeed never has any unused parameters in the forward pass, consider turning this flag off. Note that this warning may be a false positive if your model has flow control causing later iterations to have unused parameters. 
(function operator()) +[rank5]:[W405 17:31:39.093977328 reducer.cpp:1431] Warning: find_unused_parameters=True was specified in DDP constructor, but did not find any unused parameters in the forward pass. This flag results in an extra traversal of the autograd graph every iteration, which can adversely affect performance. If your model indeed never has any unused parameters in the forward pass, consider turning this flag off. Note that this warning may be a false positive if your model has flow control causing later iterations to have unused parameters. (function operator()) +[rank4]:[W405 17:31:39.122443959 reducer.cpp:1431] Warning: find_unused_parameters=True was specified in DDP constructor, but did not find any unused parameters in the forward pass. This flag results in an extra traversal of the autograd graph every iteration, which can adversely affect performance. If your model indeed never has any unused parameters in the forward pass, consider turning this flag off. Note that this warning may be a false positive if your model has flow control causing later iterations to have unused parameters. (function operator()) +[rank1]:[W405 17:31:39.163284555 reducer.cpp:1431] Warning: find_unused_parameters=True was specified in DDP constructor, but did not find any unused parameters in the forward pass. This flag results in an extra traversal of the autograd graph every iteration, which can adversely affect performance. If your model indeed never has any unused parameters in the forward pass, consider turning this flag off. Note that this warning may be a false positive if your model has flow control causing later iterations to have unused parameters. (function operator()) +[rank7]:[W405 17:31:39.164285174 reducer.cpp:1431] Warning: find_unused_parameters=True was specified in DDP constructor, but did not find any unused parameters in the forward pass. 
This flag results in an extra traversal of the autograd graph every iteration, which can adversely affect performance. If your model indeed never has any unused parameters in the forward pass, consider turning this flag off. Note that this warning may be a false positive if your model has flow control causing later iterations to have unused parameters. (function operator()) +[rank2]:[W405 17:31:39.184574983 reducer.cpp:1431] Warning: find_unused_parameters=True was specified in DDP constructor, but did not find any unused parameters in the forward pass. This flag results in an extra traversal of the autograd graph every iteration, which can adversely affect performance. If your model indeed never has any unused parameters in the forward pass, consider turning this flag off. Note that this warning may be a false positive if your model has flow control causing later iterations to have unused parameters. (function operator()) +[rank3]:[W405 17:31:39.184584270 reducer.cpp:1431] Warning: find_unused_parameters=True was specified in DDP constructor, but did not find any unused parameters in the forward pass. This flag results in an extra traversal of the autograd graph every iteration, which can adversely affect performance. If your model indeed never has any unused parameters in the forward pass, consider turning this flag off. Note that this warning may be a false positive if your model has flow control causing later iterations to have unused parameters. 
(function operator()) +warmup_step: 1/20 +warmup_step: 2/20 +warmup_step: 3/20 +warmup_step: 4/20 +warmup_step: 5/20 +warmup_step: 6/20 +warmup_step: 10/20 +warmup_step: 20/20 +recurrence:prewarm active=1 virtual_layers:13 +recur_warmup_step: 1/20 +recur_warmup_step: 2/20 +recur_warmup_step: 3/20 +recur_warmup_step: 4/20 +recur_warmup_step: 5/20 +recur_warmup_step: 6/20 +recur_warmup_step: 10/20 +recur_warmup_step: 20/20 +0/20000 val_loss: 8.3142 val_bpb: 3.5644 +1/20000 train_loss: 8.3140 train_time: 0.0m tok/s: 8318318 +2/20000 train_loss: 12.2997 train_time: 0.0m tok/s: 8245327 +3/20000 train_loss: 10.8362 train_time: 0.0m tok/s: 8173066 +4/20000 train_loss: 9.2053 train_time: 0.0m tok/s: 8128116 +5/20000 train_loss: 8.0047 train_time: 0.0m tok/s: 8114679 +500/20000 train_loss: 3.0693 train_time: 0.8m tok/s: 7887338 +1000/20000 train_loss: 2.9289 train_time: 1.7m tok/s: 7875900 +1500/20000 train_loss: 2.8387 train_time: 2.5m tok/s: 7872060 +2000/20000 train_loss: 2.8088 train_time: 3.3m tok/s: 7870137 +2500/20000 train_loss: 2.8076 train_time: 4.2m tok/s: 7868925 +3000/20000 train_loss: 2.8063 train_time: 5.0m tok/s: 7868356 +recurrence:activated step:3000 layers:[4, 5] virtual_layers:13 +3500/20000 train_loss: 2.7435 train_time: 6.0m tok/s: 7694649 +4000/20000 train_loss: 2.6752 train_time: 6.9m tok/s: 7569551 +4000/20000 val_loss: 2.6896 val_bpb: 1.1530 +4500/20000 train_loss: 2.7277 train_time: 7.9m tok/s: 7474668 +5000/20000 train_loss: 2.5951 train_time: 8.9m tok/s: 7400802 +5500/20000 train_loss: 2.5446 train_time: 9.8m tok/s: 7341820 +5508/20000 val_loss: 2.5692 val_bpb: 1.1014 +stopping_early: wallclock_cap train_time: 590075ms step: 5508/20000 +peak memory allocated: 30215 MiB reserved: 30244 MiB +ema:applying EMA weights +pre-quantization post-ema val_loss:2.56673998 val_bpb:1.10037419 eval_time:1983ms +pre_quant_ttt:start lr=0.0005 epochs=1 freeze=9/11 +pre_quant_ttt:params trainable=5771280 frozen=28630091 + ttt_epoch:1/1 chunk:50 loss:2.4900 
avg:2.5587 time:4.3s + ttt_epoch:1/1 chunk:100 loss:2.4994 avg:2.5344 time:8.4s + ttt_epoch:1/1 chunk:150 loss:2.5024 avg:2.5444 time:12.5s + ttt_epoch:1/1 chunk:200 loss:2.6565 avg:2.5424 time:16.5s + ttt_epoch:1/1 chunk:250 loss:2.6493 avg:2.5565 time:20.6s + ttt_epoch:1/1 chunk:300 loss:2.6623 avg:2.5653 time:24.7s + ttt_epoch:1/1 chunk:350 loss:2.4584 avg:2.5741 time:28.8s + ttt_epoch:1/1 chunk:400 loss:2.5949 avg:2.5785 time:32.9s + ttt_epoch:1/1 chunk:450 loss:2.4869 avg:2.5779 time:37.0s + ttt_epoch:1/1 chunk:500 loss:2.3069 avg:2.5788 time:41.1s + ttt_epoch:1/1 chunk:550 loss:2.5227 avg:2.5781 time:45.2s + ttt_epoch:1/1 chunk:600 loss:2.4945 avg:2.5753 time:49.3s + ttt_epoch:1/1 chunk:650 loss:2.5728 avg:2.5705 time:53.4s + ttt_epoch:1/1 chunk:700 loss:2.7183 avg:2.5702 time:57.5s + ttt_epoch:1/1 chunk:750 loss:2.6306 avg:2.5750 time:61.6s + ttt_epoch:1/1 chunk:800 loss:2.6874 avg:2.5790 time:65.6s + ttt_epoch:1/1 chunk:850 loss:2.5996 avg:2.5796 time:69.7s + ttt_epoch:1/1 chunk:900 loss:2.4831 avg:2.5818 time:73.8s + ttt_epoch:1/1 chunk:950 loss:2.5436 avg:2.5810 time:77.9s + ttt_epoch:1/1 chunk:1000 loss:2.7306 avg:2.5812 time:82.0s + ttt_epoch:1/1 chunk:1050 loss:2.4936 avg:2.5797 time:86.1s + ttt_epoch:1/1 chunk:1100 loss:2.6825 avg:2.5791 time:90.2s + ttt_epoch:1/1 chunk:1150 loss:2.6041 avg:2.5807 time:94.3s + ttt_epoch:1/1 chunk:1200 loss:2.5226 avg:2.5832 time:98.4s + ttt_epoch:1/1 chunk:1250 loss:2.6784 avg:2.5858 time:102.5s + ttt_epoch:1/1 chunk:1300 loss:2.6707 avg:2.5863 time:106.6s + ttt_epoch:1/1 chunk:1350 loss:2.5657 avg:2.5869 time:110.6s + ttt_epoch:1/1 done chunks:1368 avg_loss:2.5864 time:112.1s +pre_quant_ttt:done epochs=1 total_time=112.1s +Serialized model: 132405891 bytes +Code size: 68057 bytes +GPTQ:collecting Hessians from calibration data... 
+GPTQ:collected 66 Hessians in 9.6s +GPTQ quantization: 66 layers with full GPTQ, 0 fallback to clip-search +Serialized model int6+brotli: 16024230 bytes +Total submission size int6+brotli: 16092287 bytes +final_int6_roundtrip val_loss:2.58872995 val_bpb:1.10980140 eval_time:7259ms +final_int6_sliding_window val_loss:2.54867963 val_bpb:1.09263163 eval_time:73863ms +etlb:start windows=87488 lr=0.05 steps=5 clip=3.0 + etlb:window 5000/87488 bias_norm=2.4011 + etlb:window 10000/87488 bias_norm=3.3965 + etlb:window 15000/87488 bias_norm=4.1543 + etlb:window 20000/87488 bias_norm=4.6402 + etlb:window 25000/87488 bias_norm=5.1957 + etlb:window 30000/87488 bias_norm=5.6600 + etlb:window 35000/87488 bias_norm=6.0187 + etlb:window 40000/87488 bias_norm=6.3890 + etlb:window 45000/87488 bias_norm=6.6840 + etlb:window 50000/87488 bias_norm=6.9528 + etlb:window 55000/87488 bias_norm=7.2313 + etlb:window 60000/87488 bias_norm=7.4304 + etlb:window 65000/87488 bias_norm=7.6937 + etlb:window 70000/87488 bias_norm=7.8993 + etlb:window 75000/87488 bias_norm=8.1016 + etlb:window 80000/87488 bias_norm=8.2777 + etlb:window 85000/87488 bias_norm=8.4474 +final_int6_sliding_etlb val_loss:2.54391190 val_bpb:1.09058768 eval_time:1155741ms diff --git a/records/track_10min_16mb/2026-04-05_PreQuantTTT_ETLB_MuonEqR_DepthRecurrence_AllInt6/submission.json b/records/track_10min_16mb/2026-04-05_PreQuantTTT_ETLB_MuonEqR_DepthRecurrence_AllInt6/submission.json new file mode 100644 index 0000000000..096a218a50 --- /dev/null +++ b/records/track_10min_16mb/2026-04-05_PreQuantTTT_ETLB_MuonEqR_DepthRecurrence_AllInt6/submission.json @@ -0,0 +1,48 @@ +{ + "title": "Pre-Quant TTT + ETLB (Eval-Time Logit Bias)", + "author": "AnubhavBharadwaaj", + "track": "10min_16mb", + "hardware": "8xH100_SXM", + "base_pr": 1285, + "techniques": [ + "MuonEq-R optimizer", + "Depth recurrence (layers 4,5)", + "All-Int6 GPTQ quantization", + "Pre-Quantization Test-Time Training (novel)", + "Eval-Time Logit Bias / ETLB 
(novel)" + ], + "results": { + "seed_1337": { + "sliding_bpb": 1.0897, + "artifact_bytes": 16084685, + "steps": 5505, + "tok_per_sec": 7863 + }, + "seed_42": { + "sliding_bpb": 1.0906, + "artifact_bytes": 16092287, + "steps": 5508, + "tok_per_sec": 7868 + }, + "seed_2025": { + "sliding_bpb": 1.0891, + "artifact_bytes": 16087467, + "steps": 5508, + "tok_per_sec": 7867 + }, + "mean_bpb": 1.0898, + "std_bpb": 0.0008 + }, + "hyperparameters": { + "muon_wd": 0.090, + "embed_wd": 0.090, + "qk_gain_init": 5.0, + "pre_quant_ttt_lr": 0.0005, + "pre_quant_ttt_epochs": 1, + "pre_quant_ttt_freeze_blocks": 9, + "pre_quant_ttt_chunk_tokens": 32768, + "etlb_lr": 0.05, + "etlb_steps": 5, + "etlb_clip": 3.0 + } +} diff --git a/records/track_10min_16mb/2026-04-05_PreQuantTTT_ETLB_MuonEqR_DepthRecurrence_AllInt6/train_gpt.py b/records/track_10min_16mb/2026-04-05_PreQuantTTT_ETLB_MuonEqR_DepthRecurrence_AllInt6/train_gpt.py new file mode 100644 index 0000000000..defe4fd775 --- /dev/null +++ b/records/track_10min_16mb/2026-04-05_PreQuantTTT_ETLB_MuonEqR_DepthRecurrence_AllInt6/train_gpt.py @@ -0,0 +1,729 @@ +_g='momentum' +_f='fineweb_train_*.bin' +_e='LOCAL_RANK' +_d='WARMUP_STEPS' +_c='passthrough_ctrl' +_b='passthrough' +_a='repeat_mlp' +_Z='disable_attn' +_Y='full' +_X='0 else 0;bc=(nt-1-phase)//seq_len;self._cursor_phase[si]=phase;self._cursor_block_count[si]=bc;self._cursor_next[si]=0;self._cursor_start[si]=int(self._rng.integers(bc))if bc>1 else 0;self._cursor_stride[si]=self._pick_coprime_stride(bc);self._cursor_init[si]=_B + def _ensure_cursor(self,si,seq_len): + if not self._cursor_init[si]or self._cursor_next[si]>=self._cursor_block_count[si]:self._reset_cursor(si,seq_len) + def _take_from_shard(self,si,seq_len,count,out): + rem=count + while rem>0: + self._ensure_cursor(si,seq_len);bc=int(self._cursor_block_count[si]);ni=int(self._cursor_next[si]);take=min(rem,bc-ni);phase=int(self._cursor_phase[si]);start=int(self._cursor_start[si]);stride=int(self._cursor_stride[si]) + 
for j in range(take):bi=(start+(ni+j)*stride)%bc;out.append((si,phase+bi*seq_len)) + self._cursor_next[si]=ni+take;rem-=take + def _init_pipeline(self,global_tokens,seq_len,grad_accum_steps):local_tokens=global_tokens//(self.world_size*grad_accum_steps);num_seqs=local_tokens//seq_len;global_num_seqs=num_seqs*self.world_size;self._cfg=local_tokens,seq_len,num_seqs,global_num_seqs;bbc=(self._num_tokens-1)//seq_len;eligible=bbc>0;self._eligible_shards=np.nonzero(eligible)[0].astype(np.int64);self._base_block_counts=bbc[self._eligible_shards].astype(np.int64) + def _sample_global_windows(self): + _,seq_len,_,gns=self._cfg;ec=int(self._eligible_shards.size);progress=min(self._batches_built/18e2,_D);remaining=np.empty(ec,dtype=np.float64) + for(i,si)in enumerate(self._eligible_shards.tolist()): + if self._cursor_init[si]:r=int(self._cursor_block_count[si])-int(self._cursor_next[si]);remaining[i]=float(max(r,1)) + else:remaining[i]=float(self._base_block_counts[i]) + alpha=.9-.4*progress;weights=np.power(remaining,alpha);ws=float(weights.sum()) + if not np.isfinite(ws)or ws<=_F:weights=np.ones(ec,dtype=np.float64);ws=float(weights.sum()) + probs=weights/ws;low=min(max(8,self.world_size),ec,gns);high=min(max(32,self.world_size*8),ec,gns);mix=max(1,min(int(round(low+progress*(high-low))),ec,gns));cp=self._rng.choice(ec,size=mix,replace=_C,p=probs);cs=self._eligible_shards[cp];cpr=probs[cp].copy();cpr/=cpr.sum();counts=np.ones(mix,dtype=np.int64);extra=gns-mix + if extra>0:counts+=self._rng.multinomial(extra,cpr).astype(np.int64) + perm=self._rng.permutation(mix);cs,counts=cs[perm],counts[perm];buckets=[] + for(si,cnt)in zip(cs.tolist(),counts.tolist()): + b=[];self._take_from_shard(int(si),seq_len,int(cnt),b) + if b: + if len(b)>1:bp=self._rng.permutation(len(b));b=[b[int(k)]for k in bp.tolist()] + buckets.append(b) + windows=[];active=[i for(i,bk)in enumerate(buckets)if bk] + while active: + order=self._rng.permutation(len(active));new_active=[] + for oi in order.tolist(): 
+ bi=active[oi] + if buckets[bi]:windows.append(buckets[bi].pop()) + if buckets[bi]:new_active.append(bi) + active=new_active + return windows + def next_batch(self,global_tokens,seq_len,grad_accum_steps): + if self._cfg is _A:self._init_pipeline(global_tokens,seq_len,grad_accum_steps) + _,_,num_seqs,_=self._cfg;gw=self._sample_global_windows();local_w=gw[self.rank::self.world_size];x=torch.empty((num_seqs,seq_len),dtype=torch.int64);y=torch.empty((num_seqs,seq_len),dtype=torch.int64) + for(bi,(si,pos))in enumerate(local_w):mm=_get_shard_memmap(self.files[si]);window=torch.as_tensor(np.array(mm[pos:pos+seq_len+1],dtype=np.int64));x[bi]=window[:-1];y[bi]=window[1:] + self._batches_built+=1;return x.to(self.device,non_blocking=_B),y.to(self.device,non_blocking=_B) +class RMSNorm(nn.Module): + def __init__(self,eps=_A):super().__init__();self.eps=eps + def forward(self,x):return F.rms_norm(x,(x.size(-1),),eps=self.eps) +class CastedLinear(nn.Linear): + def forward(self,x):w=self.weight.to(x.dtype);bias=self.bias.to(x.dtype)if self.bias is not _A else _A;return F.linear(x,w,bias) +class Rotary(nn.Module): + def __init__(self,dim,base=1e4,train_seq_len=1024,rope_dims=0):super().__init__();self.dim=dim;self.base=base;self.train_seq_len=train_seq_len;self.rope_dims=rope_dims if rope_dims>0 else dim;inv_freq=_D/base**(torch.arange(0,self.rope_dims,2,dtype=torch.float32)/self.rope_dims);self.register_buffer('inv_freq',inv_freq,persistent=_C);self._seq_len_cached=0;self._cos_cached=_A;self._sin_cached=_A + def forward(self,seq_len,device,dtype): + if self._cos_cached is _A or self._sin_cached is _A or self._seq_len_cached!=seq_len or self._cos_cached.device!=device: + rd=self.rope_dims + if seq_len>self.train_seq_len:scale=seq_len/self.train_seq_len;new_base=self.base*scale**(rd/(rd-2));inv_freq=_D/new_base**(torch.arange(0,rd,2,dtype=torch.float32,device=device)/rd) + else:inv_freq=self.inv_freq.to(device) + 
t=torch.arange(seq_len,device=device,dtype=inv_freq.dtype);freqs=torch.outer(t,inv_freq);self._cos_cached=freqs.cos()[_A,:,_A,:];self._sin_cached=freqs.sin()[_A,:,_A,:];self._seq_len_cached=seq_len + return self._cos_cached.to(dtype=dtype),self._sin_cached.to(dtype=dtype) +def apply_rotary_emb(x,cos,sin,rope_dims=0): + if rope_dims>0 and rope_dims0: + head_dim=h.model_dim//h.num_heads + for block in self.blocks:block.attn.rope_dims=h.rope_dims;block.attn.rotary=Rotary(head_dim,base=h.rope_base,train_seq_len=h.train_seq_len,rope_dims=h.rope_dims) + self.ve_layer_indices=[int(x)for x in h.ve_layers.split(',')if x.strip()]if h.ve_enabled else[];kv_dim=self._ve_target_dim + if self.ve_layer_indices:self.ve_shared=ValueEmbedding(h.vocab_size,h.ve_dim,kv_dim);self.ve_layer_scales=nn.ParameterList([nn.Parameter(torch.ones(1,dtype=torch.float32))for _ in self.ve_layer_indices]) + else:self.ve_shared=_A;self.ve_layer_scales=nn.ParameterList() + self.value_embeds=nn.ModuleList();self.final_norm=RMSNorm();self.lm_head=_A if h.tie_embeddings else CastedLinear(h.embedding_dim,h.vocab_size,bias=_C) + if self.lm_head is not _A:self.lm_head._zero_init=_B + if h.xsa_last_n>0: + for i in range(max(0,h.num_layers-h.xsa_last_n),h.num_layers):self.blocks[i].attn.use_xsa=_B + self.parallel_residual=bool(h.parallel_residual);self.parallel_start_layer=max(0,int(h.parallel_start_layer));self.parallel_start_layer_is_physical=bool(h.parallel_start_layer_is_physical);self.parallel_post_lambdas=nn.Parameter(torch.ones(h.num_layers,2,2,dtype=torch.float32))if self.parallel_residual else _A;self.parallel_resid_lambdas=nn.Parameter(torch.full((h.num_layers,2),1.1**.5,dtype=torch.float32))if self.parallel_residual else _A;self.recur_layers=_parse_layer_list(h.recur_layers_str) + for rl in self.recur_layers: + if not 0<=rl=64 and module.weight.shape[1]>=64:nn.init.orthogonal_(module.weight,gain=_D) + def _get_ve(self,layer_idx,input_ids,ve_cache=_A): + A='ve' + if self.ve_shared is _A or layer_idx 
not in self.ve_layer_indices:return + if ve_cache is not _A and A not in ve_cache:ve_cache[A]=self.ve_shared(input_ids) + ve_base=ve_cache[A]if ve_cache is not _A else self.ve_shared(input_ids);ve_idx=self.ve_layer_indices.index(layer_idx);return ve_base*self.ve_layer_scales[ve_idx].to(dtype=ve_base.dtype) + def set_recurrence_active(self,active):self._recurrence_active=bool(active)and bool(self.recur_layers) + def prime_repeat_mlp(self): + if not self.repeat_mlp:return + with torch.no_grad(): + for(repeat_idx,physical_idx)in enumerate(self.recur_layers): + repeat_mlp=self.repeat_mlp[repeat_idx];base_mlp=self.blocks[physical_idx].mlp + if repeat_mlp.fc is not _A:repeat_mlp.fc.weight.copy_(base_mlp.fc.weight) + if repeat_mlp.proj is not _A:repeat_mlp.proj.weight.copy_(base_mlp.proj.weight) + def _get_virtual_layers(self): + if self._recurrence_active and self.recur_layers:return list(range(self._repeat_cutoff))+self.recur_layers+list(range(self._repeat_cutoff,len(self.blocks))) + return list(range(len(self.blocks))) + def _get_repeat_mlp(self,virtual_idx,physical_idx): + if not self._recurrence_active or not self.recur_layers or not self.repeat_mlp:return + repeat_start=self._repeat_cutoff;repeat_end=repeat_start+len(self.recur_layers) + if repeat_start<=virtual_idx=self.parallel_start_layer + return virtual_idx>=self.parallel_start_layer + def _mix_with_x0(self,lane,x0,resid_mix):mix=resid_mix.to(dtype=lane.dtype);return mix[0][_A,_A,:]*lane+mix[1][_A,_A,:]*x0 + def _apply_skip_single(self,x,skip,i): + if isinstance(skip,tuple):skip=skip[1] + scaled_skip=self.skip_weights[i].to(dtype=x.dtype)[_A,_A,:]*skip + if self.skip_gates is not _A:g=torch.sigmoid(self.skip_gates[i].to(dtype=x.dtype))[_A,_A,:];return torch.lerp(scaled_skip,x,g) + return x+scaled_skip + def _apply_skip_parallel(self,lane0,lane1,skip,i): + if isinstance(skip,tuple):skip0,skip1=skip + else:skip0=skip1=skip + w=self.skip_weights[i].to(dtype=lane0.dtype)[_A,_A,:] + if self.skip_gates is _A:return 
lane0+w*skip0,lane1+w*skip1 + g=torch.sigmoid(self.skip_gates[i].to(dtype=lane0.dtype))[_A,_A,:];return torch.lerp(w*skip0,lane0,g),torch.lerp(w*skip1,lane1,g) + def _final_parallel_hidden(self,lane0,lane1):return(lane0+lane1)*.5 + def _block_forward(self,block,x,x0,v_embed=_A,repeat_mlp=_A): + mix=block.resid_mix.to(dtype=x.dtype);x_in=mix[0][_A,_A,:]*x+mix[1][_A,_A,:]*x0;x_out=x_in + if not getattr(block,_Z,_C):attn_in=block.attn_norm(x_in)*block.ln_scale_factor;attn_out=block.attn(attn_in,v_embed=v_embed);x_out=x_out+block.attn_scale.to(dtype=x_in.dtype)[_A,_A,:]*attn_out + mlp_in=block.mlp_norm(x_out)*block.ln_scale_factor;mlp_out=repeat_mlp(mlp_in)if repeat_mlp is not _A else block.mlp(mlp_in);return x_out+block.mlp_scale.to(dtype=x_out.dtype)[_A,_A,:]*mlp_out + def _parallel_block(self,block,lane0,lane1,x0,physical_idx,v_embed=_A,repeat_mlp=_A): + if not getattr(block,_Z,_C):attn_read=self._mix_with_x0(lane0,x0,block.resid_mix);attn_in=block.attn_norm(attn_read)*block.ln_scale_factor;attn_out=block.attn(attn_in,v_embed=v_embed);attn_out=block.attn_scale.to(dtype=lane0.dtype)[_A,_A,:]*attn_out;resid=self.parallel_resid_lambdas[physical_idx,0].to(dtype=lane0.dtype);post=self.parallel_post_lambdas[physical_idx,0].to(dtype=lane0.dtype);lane0=resid*lane0+post[0]*attn_out;lane1=resid*lane1+post[1]*attn_out + mlp_read=self._mix_with_x0(lane1,x0,block.resid_mix);mlp_in=block.mlp_norm(mlp_read)*block.ln_scale_factor;mlp_out=repeat_mlp(mlp_in)if repeat_mlp is not _A else block.mlp(mlp_in);mlp_out=block.mlp_scale.to(dtype=lane1.dtype)[_A,_A,:]*mlp_out;resid=self.parallel_resid_lambdas[physical_idx,1].to(dtype=lane0.dtype);post=self.parallel_post_lambdas[physical_idx,1].to(dtype=lane0.dtype);lane0=resid*lane0+post[0]*mlp_out;lane1=resid*lane1+post[1]*mlp_out;return lane0,lane1 + def _backbone(self,input_ids): + x=self.tok_emb(input_ids);x=F.rms_norm(x,(x.size(-1),)) + if self.embed_proj is not _A:x=self.embed_proj(x) + 
x0=x;skips=[];ve_cache={};v2p=self._get_virtual_layers();enc_layers=len(v2p)//2;dec_layers=len(v2p)-enc_layers;lane0=_A;lane1=_A + for virtual_idx in range(enc_layers): + physical_idx=v2p[virtual_idx];ve=self._get_ve(physical_idx,input_ids,ve_cache);repeat_mlp=self._get_repeat_mlp(virtual_idx,physical_idx) + if self._parallel_active_for_layer(virtual_idx,physical_idx): + if lane0 is _A:lane0=x;lane1=x + lane0,lane1=self._parallel_block(self.blocks[physical_idx],lane0,lane1,x0,physical_idx,v_embed=ve,repeat_mlp=repeat_mlp);skips.append((lane0,lane1)) + else:x=self._block_forward(self.blocks[physical_idx],x,x0,v_embed=ve,repeat_mlp=repeat_mlp);skips.append(x) + for i in range(dec_layers): + virtual_idx=enc_layers+i;physical_idx=v2p[virtual_idx];ve=self._get_ve(physical_idx,input_ids,ve_cache);repeat_mlp=self._get_repeat_mlp(virtual_idx,physical_idx);skip_i=min(i,self.num_skip_weights-1) + if self._parallel_active_for_layer(virtual_idx,physical_idx): + if lane0 is _A:lane0=x;lane1=x + if skips:lane0,lane1=self._apply_skip_parallel(lane0,lane1,skips.pop(),skip_i) + lane0,lane1=self._parallel_block(self.blocks[physical_idx],lane0,lane1,x0,physical_idx,v_embed=ve,repeat_mlp=repeat_mlp) + else: + if skips:x=self._apply_skip_single(x,skips.pop(),skip_i) + x=self._block_forward(self.blocks[physical_idx],x,x0,v_embed=ve,repeat_mlp=repeat_mlp) + hidden=self._final_parallel_hidden(lane0,lane1)if lane1 is not _A else x;hidden=self.final_norm(hidden) + if self.head_proj is not _A:hidden=self.head_proj(hidden) + return hidden + def compute_logits(self,hidden): + if self.tie_embeddings:logits_proj=F.linear(hidden,self.tok_emb.weight) + else:logits_proj=self.lm_head(hidden) + return self.logit_softcap*torch.tanh(logits_proj/self.logit_softcap) + def forward_hidden(self,input_ids):return self._backbone(input_ids) + def forward_logits(self,input_ids):return self.compute_logits(self.forward_hidden(input_ids)) + def 
forward(self,input_ids,target_ids):logits=self.forward_logits(input_ids);return F.cross_entropy(logits.reshape(-1,logits.size(-1)).float(),target_ids.reshape(-1),reduction='mean') +def classify_param(name): + A='.mlp.' + if'tok_emb'in name or'lm_head'in name:return'embed' + if _a in name:return _O + if A in name:return _O + if'.attn.'in name or'.proj.'in name and A not in name:return _T + return'other' +@torch.compile +def zeropower_via_newtonschulz5(G,steps=10,eps=1e-07): + a,b,c=3.4445,-4.775,2.0315;X=G.bfloat16();X/=X.norm()+eps;transposed=G.size(0)>G.size(1) + if transposed:X=X.T + for _ in range(steps):A=X@X.T;B=b*A+c*A@A;X=a*X+B@X + return X.T if transposed else X +class Muon(torch.optim.Optimizer): + def __init__(self,params,lr,momentum,backend_steps,nesterov=_B,weight_decay=_F):super().__init__(params,dict(lr=lr,momentum=momentum,backend_steps=backend_steps,nesterov=nesterov,weight_decay=weight_decay)) + @torch.no_grad() + def step(self,closure=_A): + A='momentum_buffer';loss=_A + if closure is not _A: + with torch.enable_grad():loss=closure() + distributed=dist.is_available()and dist.is_initialized();world_size=dist.get_world_size()if distributed else 1;rank=dist.get_rank()if distributed else 0 + for group in self.param_groups: + params=group[_P] + if not params:continue + lr=group[_H];momentum=group[_g];backend_steps=group['backend_steps'];nesterov=group['nesterov'];total_params=sum(int(p.numel())for p in params);updates_flat=torch.zeros(total_params,device=params[0].device,dtype=torch.bfloat16);curr=0 + for(i,p)in enumerate(params): + if i%world_size==rank and p.grad is not _A: + g=p.grad;state=self.state[p] + if A not in state:state[A]=torch.zeros_like(g) + buf=state[A];buf.mul_(momentum).add_(g) + if nesterov:g=g.add(buf,alpha=momentum) + 
row_norms=g.float().norm(dim=-1,keepdim=_B).clamp_min(1e-07);g=g/row_norms.to(g.dtype);g=zeropower_via_newtonschulz5(g,steps=backend_steps);g*=max(1,g.size(0)/g.size(1))**.5;updates_flat[curr:curr+p.numel()]=g.reshape(-1) + curr+=p.numel() + if distributed:dist.all_reduce(updates_flat,op=dist.ReduceOp.SUM) + wd=group.get('weight_decay',_F);curr=0 + for p in params: + if wd>_F:p.data.mul_(_D-lr*wd) + g=updates_flat[curr:curr+p.numel()].view_as(p).to(dtype=p.dtype);p.add_(g,alpha=-lr);curr+=p.numel() + return loss +class Optimizers: + def __init__(self,h,base_model): + named_params=list(base_model.blocks.named_parameters());named_params.extend(list(base_model.repeat_mlp.named_parameters()));matrix_params=[p for(name,p)in named_params if p.ndim==2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)];scalar_params=[p for(name,p)in named_params if p.ndim<2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)] + if base_model.skip_weights.numel()>0:scalar_params.append(base_model.skip_weights) + if base_model.skip_gates is not _A and base_model.skip_gates.numel()>0:scalar_params.append(base_model.skip_gates) + if base_model.parallel_post_lambdas is not _A:scalar_params.append(base_model.parallel_post_lambdas) + if base_model.parallel_resid_lambdas is not _A:scalar_params.append(base_model.parallel_resid_lambdas) + token_lr=h.tied_embed_lr if h.tie_embeddings else h.embed_lr;tok_params=[{_P:[base_model.tok_emb.weight],_H:token_lr,_I:token_lr}] + if base_model.ve_shared is not _A: + tok_params.append({_P:[base_model.ve_shared.embed.weight],_H:token_lr,_I:token_lr}) + if base_model.ve_shared.proj is not _A:matrix_params.append(base_model.ve_shared.proj.weight) + scalar_params.append(base_model.ve_shared.scale) + for scale in base_model.ve_layer_scales:scalar_params.append(scale) + 
self.optimizer_tok=torch.optim.AdamW(tok_params,betas=(h.beta1,h.beta2),eps=h.adam_eps,weight_decay=h.embed_wd,fused=_B);self.optimizer_muon=Muon(matrix_params,lr=h.matrix_lr,momentum=h.muon_momentum,backend_steps=h.muon_backend_steps,weight_decay=h.muon_wd) + for group in self.optimizer_muon.param_groups:group[_I]=h.matrix_lr + self.optimizer_scalar=torch.optim.AdamW([{_P:scalar_params,_H:h.scalar_lr,_I:h.scalar_lr}],betas=(h.beta1,h.beta2),eps=h.adam_eps,weight_decay=h.adam_wd,fused=_B);self.optimizers=[self.optimizer_tok,self.optimizer_muon,self.optimizer_scalar] + if base_model.lm_head is not _A:self.optimizer_head=torch.optim.Adam([{_P:[base_model.lm_head.weight],_H:h.head_lr,_I:h.head_lr}],betas=(h.beta1,h.beta2),eps=h.adam_eps,fused=_B);self.optimizers.insert(1,self.optimizer_head) + else:self.optimizer_head=_A + def __iter__(self):return iter(self.optimizers) + def zero_grad_all(self): + for opt in self.optimizers:opt.zero_grad(set_to_none=_B) + def step(self): + for opt in self.optimizers:opt.step() + self.zero_grad_all() +CONTROL_TENSOR_NAME_PATTERNS=tuple(pattern for pattern in os.environ.get('CONTROL_TENSOR_NAME_PATTERNS','attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights,skip_gates,ve_layer_scales,ve_shared.scale,parallel_post_lambdas,parallel_resid_lambdas').split(',')if pattern) +INT8_PER_ROW_SCALE_DTYPE=torch.float16 +INT8_CLIP_PERCENTILE=99.99984 +INT8_CLIP_Q=INT8_CLIP_PERCENTILE/1e2 +def quantize_float_tensor(t): + t32=t.float() + if t32.ndim==2:clip_abs=torch.quantile(t32.abs(),INT8_CLIP_Q,dim=1)if t32.numel()else torch.empty((t32.shape[0],),dtype=torch.float32);clipped=torch.maximum(torch.minimum(t32,clip_abs[:,_A]),-clip_abs[:,_A]);scale=(clip_abs/127.).clamp_min(_D/127.);q=torch.clamp(torch.round(clipped/scale[:,_A]),-127,127).to(torch.int8).contiguous();return q,scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + clip_abs=float(torch.quantile(t32.abs().flatten(),INT8_CLIP_Q).item())if 
t32.numel()else _F;scale=torch.tensor(clip_abs/127. if clip_abs>0 else _D,dtype=torch.float32);q=torch.clamp(torch.round(torch.clamp(t32,-clip_abs,clip_abs)/scale),-127,127).to(torch.int8).contiguous();return q,scale
+def restore_fp32_params(model):
+ for module in model.modules():
+  if isinstance(module,CastedLinear):module.float()
+ for(name,param)in model.named_parameters():
+  if(param.ndim<2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS))and param.dtype!=torch.float32:param.data=param.data.float()
+def quantize_int6_per_row(t,clip_range=31):
+ t32=t.float()
+ if t32.ndim==2:
+  best_q,best_s,best_err=_A,_A,float('inf')
+  for pct in[.999,.9995,.9999,.99999,_D]:
+   if pct<_D:row_clip=torch.quantile(t32.abs(),pct,dim=1)
+   else:row_clip=t32.abs().amax(dim=1)
+   s=(row_clip/clip_range).clamp_min(_D/clip_range).to(torch.float16);q=torch.clamp(torch.round(t32/s.float()[:,_A]),-clip_range,clip_range).to(torch.int8);recon=q.float()*s.float()[:,_A];err=(t32-recon).pow(2).mean().item()
+   if err<best_err:best_q,best_s,best_err=q,s,err
+  return best_q,best_s
+ clip_abs=float(t32.abs().max().item())if t32.numel()else _F;scale=torch.tensor(clip_abs/clip_range if clip_abs>0 else _D,dtype=torch.float16);q=torch.clamp(torch.round(t32/scale.float()),-clip_range,clip_range).to(torch.int8);return q,scale
+def collect_hessians(model,train_loader,h,device,n_calibration_batches=64):
+ A='.weight';hessians={};hooks=[]
+ def make_hook(name):
+  def hook_fn(module,inp,out):
+   x=inp[0].detach().float()
+   if x.ndim==3:x=x.reshape(-1,x.shape[-1])
+   if name not in hessians:hessians[name]=torch.zeros(x.shape[1],x.shape[1],dtype=torch.float32,device=device)
+   hessians[name].addmm_(x.T,x)
+  return hook_fn
+ for(name,module)in model.named_modules():
+  if isinstance(module,CastedLinear)and module.weight.numel()>65536:
+   cat=classify_param(name+A)
+   if cat in(_O,_T):hooks.append(module.register_forward_hook(make_hook(name+A)))
+ model.eval()
+ with torch.no_grad():
+  for i in range(n_calibration_batches):x,y=train_loader.next_batch(h.train_batch_tokens,h.train_seq_len,h.grad_accum_steps);model.forward_logits(x)
+ for hook in hooks:hook.remove()
+ for name in 
hessians:hessians[name]=hessians[name].cpu()/n_calibration_batches + return hessians +def gptq_quantize_weight(w,H,clip_range=31,block_size=128): + W_orig=w.float().clone();rows,cols=W_orig.shape;H=H.float().clone();dead=torch.diag(H)==0;H[dead,dead]=1;damp=.01*H.diag().mean();H.diagonal().add_(damp);perm=torch.argsort(H.diag(),descending=_B);invperm=torch.argsort(perm);W_perm=W_orig[:,perm].clone();W_perm[:,dead[perm]]=0;H=H[perm][:,perm] + try:Hinv=torch.cholesky_inverse(torch.linalg.cholesky(H));Hinv=torch.linalg.cholesky(Hinv,upper=_B) + except torch.linalg.LinAlgError:return quantize_int6_per_row(W_orig,clip_range) + best_q,best_scale,best_err=_A,_A,float('inf') + for pct in[.999,.9995,.9999,.99999,_D]: + if pct<_D:row_clip=torch.quantile(W_orig.abs(),pct,dim=1) + else:row_clip=W_orig.abs().amax(dim=1) + s=(row_clip/clip_range).clamp_min(_D/clip_range).to(torch.float16);sf=s.float();Q=torch.zeros(rows,cols,dtype=torch.int8);W_work=W_perm.clone() + for i1 in range(0,cols,block_size): + i2=min(i1+block_size,cols);W_block=W_work[:,i1:i2].clone();Hinv_block=Hinv[i1:i2,i1:i2];Err=torch.zeros(rows,i2-i1) + for j in range(i2-i1):w_col=W_block[:,j];d=Hinv_block[j,j];q_col=torch.clamp(torch.round(w_col/sf),-clip_range,clip_range);Q[:,i1+j]=q_col.to(torch.int8);err=(w_col-q_col.float()*sf)/d;Err[:,j]=err;W_block[:,j:]-=err.unsqueeze(1)*Hinv_block[j,j:].unsqueeze(0) + if i2=1:q,s=quantize_int6_per_row(t,clip_range=layer_clip);bit_label=B if layer_clip==15 else _Q;result[name+_K]=q;result[name+_L]=s;meta[name]={_J:bit_label,A:layer_clip} + else:q,s=quantize_float_tensor(t);result[name+_K]=q;result[name+_L]=s;meta[name]={_J:'int8'} + log(f"GPTQ quantization: {gptq_count} layers with full GPTQ, {fallback_count} fallback to clip-search") + if clip_ranges is not _A:log(f"mixed_quant: {int6_count} int6, {int5_count} int5") + return result,meta +def mixed_quantize_int6(state_dict,int6_cats): + result={};meta={} + for(name,tensor)in state_dict.items(): + 
t=tensor.detach().cpu().contiguous();cat=classify_param(name) + if not t.is_floating_point()or t.numel()<=65536:result[name]=t.to(torch.float16)if t.is_floating_point()else t;meta[name]=_b;continue + if any(p in name for p in CONTROL_TENSOR_NAME_PATTERNS):result[name]=t.float();meta[name]=_c;continue + if cat in int6_cats and t.ndim>=1:q,s=quantize_int6_per_row(t);result[name+_K]=q;result[name+_L]=s;meta[name]={_J:_Q} + else:q,s=quantize_float_tensor(t);result[name+_K]=q;result[name+_L]=s;meta[name]={_J:'int8'} + return result,meta +def dequantize_mixed_int6(result,meta,template_sd): + out={} + for(name,orig)in template_sd.items(): + info=meta.get(name) + if info is _A:continue + orig_dtype=orig.dtype + if info in(_b,_c,'passthrough_fp16'): + t=result[name] + if t.dtype==torch.float16 and orig_dtype in(torch.float32,torch.bfloat16):t=t.to(orig_dtype) + out[name]=t;continue + q,s=result[name+_K],result[name+_L] + if s.ndim>0:out[name]=(q.float()*s.float().view(q.shape[0],*[1]*(q.ndim-1))).to(orig_dtype) + else:out[name]=(q.float()*float(s.item())).to(orig_dtype) + return out +_BSHF_MAGIC=b'BSHF' +def _byte_shuffle(data,stride=2): + if stride<=1 or len(data)0 and h.etlb_steps>0: + b=bias.clone().detach().requires_grad_(_B) + opt=torch.optim.SGD([b],lr=h.etlb_lr) + ctx_logits=logits[:s].detach();ctx_tgt=y.squeeze(0)[:s] + for _ in range(h.etlb_steps): + opt.zero_grad();loss=F.cross_entropy(ctx_logits+b,ctx_tgt);loss.backward();opt.step() + bias=b.detach().clamp_(-h.etlb_clip,h.etlb_clip) + scored_logits=logits[s:wlen]+bias + scored_nll=F.cross_entropy(scored_logits,y.squeeze(0)[s:wlen],reduction=_G).to(torch.float64) + loss_sum+=scored_nll.sum();token_count+=float(wlen-s) + tgt=y.squeeze(0)[s:wlen];prev=x.squeeze(0)[s:wlen];tb=val_data.base_bytes_lut[tgt].to(torch.float64) + tb+=(val_data.has_leading_space_lut[tgt]&~val_data.is_boundary_token_lut[prev]).to(torch.float64);byte_count+=tb.sum() + if(wi+1)%5000==0:log(f" etlb:window {wi+1}/{len(my_windows)} 
bias_norm={bias.norm().item():.4f}") + if dist.is_available()and dist.is_initialized():dist.all_reduce(loss_sum,op=dist.ReduceOp.SUM);dist.all_reduce(token_count,op=dist.ReduceOp.SUM);dist.all_reduce(byte_count,op=dist.ReduceOp.SUM) + base_model.train();return _loss_bpb(loss_sum,token_count,byte_count) +def timed_eval(label,fn,*args,**kwargs):torch.cuda.synchronize();t0=time.perf_counter();val_loss,val_bpb=fn(*args,**kwargs);torch.cuda.synchronize();elapsed_ms=1e3*(time.perf_counter()-t0);log(f"{label} val_loss:{val_loss:.8f} val_bpb:{val_bpb:.8f} eval_time:{elapsed_ms:.0f}ms");return val_loss,val_bpb +def run_evals(h,device,val_data,eval_model): + compiled_model=torch.compile(eval_model,dynamic=_C,fullgraph=_B);timed_eval('final_int6_roundtrip',eval_val,h,device,val_data,compiled_model) + if h.sliding_window_enabled:timed_eval('final_int6_sliding_window',eval_val_sliding,h,device,val_data,eval_model) + if h.etlb_enabled:timed_eval('final_int6_sliding_etlb',eval_val_sliding_etlb,h,device,val_data,eval_model) +def train_model(h,device,val_data): + B='_recurrence_active';A='parallel_post_lambdas';base_model=GPT(h).to(device).bfloat16();restore_fp32_params(base_model);import torch._dynamo;torch._dynamo.config.cache_size_limit=32;compiled_model=torch.compile(base_model,dynamic=_C,fullgraph=_B) + if h.distributed:model=DDP(compiled_model,device_ids=[h.local_rank],broadcast_buffers=_C,find_unused_parameters=_B) + else:model=compiled_model + log(f"model_params:{sum(p.numel()for p in base_model.parameters())}") + if getattr(base_model,A,_A)is not _A:params=base_model.parallel_post_lambdas.numel()+base_model.parallel_resid_lambdas.numel() + else:params=0 + log(f"parallel_residual: active={int(getattr(base_model,A,_A)is not _A)} start_layer={h.parallel_start_layer} start_mode={"physical"if h.parallel_start_layer_is_physical else"virtual"} params={params}");log(f"recurrence: layers={getattr(base_model,_U,[])} start_step={h.recur_start_step} 
active={int(getattr(base_model,B,_C))}");log(f"repeat_untie_mlp: mode={h.repeat_untie_mlp} layers={getattr(base_model,"repeat_untie_mlp_layers",[])} params={sum(p.numel()for p in getattr(base_model,_a,[]).parameters())if getattr(base_model,_a,_A)else 0}");optimizers=Optimizers(h,base_model);train_loader=DistributedTokenLoader(h.train_files,h.rank,h.world_size,device);max_wallclock_ms=1e3*h.max_wallclock_seconds if h.max_wallclock_seconds>0 else _A + if h.gptq_enabled and max_wallclock_ms is not _A:max_wallclock_ms-=h.gptq_reserve_seconds*1e3;log(f"gptq:reserving {h.gptq_reserve_seconds:.0f}s, effective={max_wallclock_ms:.0f}ms") + def training_frac(step,elapsed_ms): + if max_wallclock_ms is _A:return step/max(h.iterations,1) + return elapsed_ms/max(max_wallclock_ms,1e-09) + def lr_mul(frac): + if h.warmdown_frac<=0:return _D + if frac>=_D-h.warmdown_frac:return max((_D-frac)/h.warmdown_frac,h.min_lr) + return _D + def step_fn(step,lr_scale): + optimizers.zero_grad_all();train_loss=torch.zeros((),device=device) + for micro_step in range(h.grad_accum_steps): + if h.distributed:model.require_backward_grad_sync=micro_step==h.grad_accum_steps-1 + x,y=train_loader.next_batch(h.train_batch_tokens,h.train_seq_len,h.grad_accum_steps) + with torch.autocast(device_type=_M,dtype=torch.bfloat16,enabled=_B):loss=model(x,y) + train_loss+=loss.detach();(loss/h.grad_accum_steps).backward() + train_loss/=h.grad_accum_steps;frac=min(step/h.muon_momentum_warmup_steps,_D)if h.muon_momentum_warmup_steps>0 else _D;muon_momentum=(1-frac)*h.muon_momentum_warmup_start+frac*h.muon_momentum + for group in optimizers.optimizer_muon.param_groups:group[_g]=muon_momentum + for opt in optimizers: + for group in opt.param_groups:group[_H]=group[_I]*lr_scale + if h.grad_clip_norm>0:torch.nn.utils.clip_grad_norm_(base_model.parameters(),h.grad_clip_norm) + optimizers.step();return train_loss + if h.warmup_steps>0: + initial_model_state={name:tensor.detach().cpu().clone()for(name,tensor)in 
base_model.state_dict().items()};initial_optimizer_states=[copy.deepcopy(opt.state_dict())for opt in optimizers];model.train() + for warmup_step in range(h.warmup_steps): + step_fn(warmup_step,_D) + if warmup_step<=5 or(warmup_step+1)%10==0 or warmup_step+1==h.warmup_steps:log(f"warmup_step: {warmup_step+1}/{h.warmup_steps}") + if getattr(base_model,_U,_A)and h.recur_warmup_steps>0: + base_model.prime_repeat_mlp();base_model.set_recurrence_active(_B);log(f"recurrence:prewarm active=1 virtual_layers:{len(base_model._get_virtual_layers())}") + for recur_warmup_step in range(h.recur_warmup_steps): + step_fn(recur_warmup_step,_D) + if recur_warmup_step<=5 or(recur_warmup_step+1)%10==0 or recur_warmup_step+1==h.recur_warmup_steps:log(f"recur_warmup_step: {recur_warmup_step+1}/{h.recur_warmup_steps}") + base_model.set_recurrence_active(_C) + base_model.load_state_dict(initial_model_state,strict=_B) + for(opt,state)in zip(optimizers,initial_optimizer_states,strict=_B):opt.load_state_dict(state) + optimizers.zero_grad_all() + if h.distributed:model.require_backward_grad_sync=_B + train_loader=DistributedTokenLoader(h.train_files,h.rank,h.world_size,device) + ema_state={name:t.detach().float().clone()for(name,t)in base_model.state_dict().items()};ema_decay=h.ema_decay;training_time_ms=_F;stop_after_step=_A;torch.cuda.synchronize();t0=time.perf_counter();step=0 + while _B: + if getattr(base_model,_U,_A)and not getattr(base_model,B,_C)and step>=h.recur_start_step:base_model.prime_repeat_mlp();base_model.set_recurrence_active(_B);log(f"recurrence:activated step:{step} layers:{base_model.recur_layers} virtual_layers:{len(base_model._get_virtual_layers())}") + last_step=step==h.iterations or stop_after_step is not _A and step>=stop_after_step;should_validate=last_step or h.val_loss_every>0 and step%h.val_loss_every==0 + if 
should_validate:torch.cuda.synchronize();training_time_ms+=1e3*(time.perf_counter()-t0);val_loss,val_bpb=eval_val(h,device,val_data,model);log(f"{step}/{h.iterations} val_loss: {val_loss:.4f} val_bpb: {val_bpb:.4f}");torch.cuda.synchronize();t0=time.perf_counter()
+  if last_step:
+   if stop_after_step is not _A and step<h.iterations:log(f"wallclock:stopping early at step {step}")
+   break
+  train_loss=step_fn(step,lr_mul(training_frac(step,training_time_ms+1e3*(time.perf_counter()-t0))))
+  with torch.no_grad():
+   for(name,t)in base_model.state_dict().items():ema_state[name].mul_(ema_decay).add_(t.float(),alpha=_D-ema_decay)
+  step+=1;approx_training_time_ms=training_time_ms+1e3*(time.perf_counter()-t0);should_log_train=h.train_log_every>0 and(step<=5 or step%h.train_log_every==0 or stop_after_step is not _A)
+  if should_log_train:tok_per_sec=step*h.train_batch_tokens/(approx_training_time_ms/1e3);log(f"{step}/{h.iterations} train_loss: {train_loss.item():.4f} train_time: {approx_training_time_ms/60000:.1f}m tok/s: {tok_per_sec:.0f}")
+  reached_cap=max_wallclock_ms is not _A and approx_training_time_ms>=max_wallclock_ms
+  if h.distributed and max_wallclock_ms is not _A:reached_cap_tensor=torch.tensor(int(reached_cap),device=device);dist.all_reduce(reached_cap_tensor,op=dist.ReduceOp.MAX);reached_cap=bool(reached_cap_tensor.item())
+  if stop_after_step is _A and reached_cap:stop_after_step=step
+ log(f"peak memory allocated: {torch.cuda.max_memory_allocated()//1024//1024} MiB reserved: {torch.cuda.max_memory_reserved()//1024//1024} MiB");log('ema:applying EMA weights');current_state=base_model.state_dict();avg_state={name:t.to(dtype=current_state[name].dtype)for(name,t)in ema_state.items()};base_model.load_state_dict(avg_state,strict=_B);return base_model,compiled_model
+def pre_quant_ttt(h,base_model,device,val_data):
+ """Pre-quantization TTT: adapt EMA weights on val data using score-first protocol.
+ Adapted weights quantize better under GPTQ. Legal: follows score-first constraint
+ (score chunk, then train on already-scored tokens). 
Runs after training, before GPTQ.""" + if not h.pre_quant_ttt_enabled:return + num_blocks=len(base_model.blocks);freeze_n=min(h.pre_quant_ttt_freeze_blocks,num_blocks) + log(f"pre_quant_ttt:start lr={h.pre_quant_ttt_lr} epochs={h.pre_quant_ttt_epochs} freeze={freeze_n}/{num_blocks}") + for i in range(freeze_n): + for p in base_model.blocks[i].parameters():p.requires_grad_(_C) + for i in range(freeze_n,num_blocks): + for p in base_model.blocks[i].parameters():p.requires_grad_(_B) + base_model.tok_emb.weight.requires_grad_(_C) + if hasattr(base_model,'embed_proj')and base_model.embed_proj is not _A: + for p in base_model.embed_proj.parameters():p.requires_grad_(_C) + if hasattr(base_model,'head_proj')and base_model.head_proj is not _A: + for p in base_model.head_proj.parameters():p.requires_grad_(_C) + if base_model.skip_weights is not _A:base_model.skip_weights.requires_grad_(_C) + if base_model.skip_gates is not _A:base_model.skip_gates.requires_grad_(_C) + if base_model.ve_shared is not _A: + for p in base_model.ve_shared.parameters():p.requires_grad_(_C) + for s in base_model.ve_layer_scales:s.requires_grad_(_C) + if hasattr(base_model,'parallel_post_lambdas')and base_model.parallel_post_lambdas is not _A:base_model.parallel_post_lambdas.requires_grad_(_C) + if hasattr(base_model,'parallel_resid_lambdas')and base_model.parallel_resid_lambdas is not _A:base_model.parallel_resid_lambdas.requires_grad_(_C) + trainable=sum(p.numel()for p in base_model.parameters()if p.requires_grad);frozen=sum(p.numel()for p in base_model.parameters()if not p.requires_grad) + log(f"pre_quant_ttt:params trainable={trainable} frozen={frozen}") + ttt_params=[p for p in base_model.parameters()if p.requires_grad] + optimizer=torch.optim.AdamW(ttt_params,lr=h.pre_quant_ttt_lr,weight_decay=0.0,betas=(0.9,0.999)) + val_tokens=val_data.val_tokens;seq_len=h.eval_seq_len;chunk_tokens=h.pre_quant_ttt_chunk_tokens;total_tokens=val_tokens.numel()-1;t0=time.perf_counter() + for epoch in 
range(h.pre_quant_ttt_epochs): + pos=0;chunk_idx=0;epoch_loss_sum=0.0;epoch_chunks=0 + while pos