diff --git a/records/track_10min_16mb/2026-04-06_SP8192_HessianSDClip_ProgressiveRecurrence/README.md b/records/track_10min_16mb/2026-04-06_SP8192_HessianSDClip_ProgressiveRecurrence/README.md new file mode 100644 index 0000000000..f7924d3d0d --- /dev/null +++ b/records/track_10min_16mb/2026-04-06_SP8192_HessianSDClip_ProgressiveRecurrence/README.md @@ -0,0 +1,91 @@ +# Non-record: Parallel Residuals + Hessian-Aware SDClip (3-seed mean 1.08354 BPB) + +**val bpb: 1.08354** (3-seed mean, std=0.00050) + +Not a record. This is a small 3-seed experiment over [PR #1394](https://github.com/openai/parameter-golf/pull/1394) on my runs, but not enough evidence for a statistical claim — the seed count is too small for confidence. Posting because the changes are zero-cost, reproducible, and may be useful to others trying out different techniques. + +| Seed | Steps | Pre-quant BPB | Post-quant BPB | **Sliding BPB** | Artifact | +|-|-|-|-|-|-| +| 1337 | 5178 | 1.08765 | 1.09959 | **1.08301** | 15,976,275 | +| 42 | 5180 | 1.08816 | 1.10013 | **1.08363** | 15,978,439 | +| 3141 | 5182 | 1.08872 | 1.10044 | **1.08399** | 15,979,649 | +| **Mean** | | 1.08818 | 1.10005 | **1.08354** | 15,978,121 | + +## Changes + +Three zero-cost modifications on top of [PR #1394](https://github.com/openai/parameter-golf/pull/1394), adding zero extra parameters or bytes: + +### 1. Parallel Residuals (Layers 7+) + +GPT-J style parallel attention+MLP ([Wang & Komatsuzaki, 2021](https://github.com/kingoflolz/mesh-transformer-jax)) for the last 4 layers. Both attention and MLP read from the same input and their outputs are added in parallel: + +``` +# Parallel (layers 7-10): +x_out = x + attn_scale * Attn(norm(x)) + mlp_scale * MLP(norm(x)) + +# Sequential (layers 0-6, unchanged): +h = x + attn_scale * Attn(norm(x)) +x_out = h + mlp_scale * MLP(norm(h)) +``` + +I expected parallel residuals to reduce interference between attention and MLP during GPTQ calibration. Pre-quant BPB barely moved, but the quantization gap tightened across all 3 seeds, which made this the most useful change in practice. + +### 2. Hessian-Aware SDClip + +I used GPTQ's existing Hessian diagonal as a cheap importance signal to slightly modulate SDClip thresholds by row: + +$$c_i = k \cdot \sigma_i \cdot [1 + \lambda(r_i - 1)], \quad \lambda = 0.175$$ + +where $\sigma_i$ is the standard deviation of row $i$ and $r_i$ is the row importance derived from Hessian-weighted magnitude. The effect is small but directionally useful at $\lambda = 0.175$; higher $\lambda$ hurt compression. I initially used $\lambda = 0.30$ but found $\lambda = 0.175$ is consistently better across seeds — both lower BPB and smaller artifact. Higher $\lambda$ reduces rounding error but increases entropy, which makes Brotli compression less effective. + +### 3. Progressive Recurrence + +Depth recurrence split into two phases: first loop enabled at 50% of training, second at 65%. The split points were not optimized — 50% matches the original and 65% was a single manual choice. Enabling both loops at once causes a sharper loss spike; splitting gives the model time to adapt to each additional pass before adding the next. + +## Hessian Analysis (Cross-Seed) + +Hessian diagnostics from 3 seeds, 67 matrices each: + +- **Group-level traces** (early/loop/mid/late blocks): $r=0.997$ across seeds +- **Per-matrix traces**: $r=0.994$ +- **Per-row importance**: $r=0.12$ (noise) + +Importance hierarchy: early blocks (30x trace of late blocks) >> loop >> mid >> late. Per-row importance is too noisy to be a reliable signal, but group-level traces are very stable across seeds. This suggests per-group clip allocation could be a useful direction. + +## Future Directions + +Several ideas I'd like to explore with more compute time: + +- **Per-group clip allocation**: Non-uniform $k$ across layer groups, using the stable group-level trace hierarchy as a guide. +- **Output-Hessian weighting**: Using backward-pass gradients for output-side row importance rather than input-side alone. +- **More seeds**: 3 seeds is not enough for strong statistical claims. I'd want 5+ to be confident about the gap vs PR #1394. +- **YAQA**: I like the idea of the paper ([arXiv:2505.22988](https://arxiv.org/abs/2505.22988)), but I couldn't get a working backward pass for it. I think maybe it could be adapted for the parameter golf problem in an interesting way. I also like the math in Mousse ([arXiv:2603.09697](https://arxiv.org/abs/2603.09697)), but exploiting curvature in small LMs seems tough. + +## Run Command + +```bash +HESSIAN_CLIP_LAMBDA=0.175 LOOP_PHASE2_AT=0.65 PARALLEL_RESIDUAL_START=7 SEED=1337 \ +torchrun --standalone --nproc_per_node=8 train_gpt_sweep.py +``` + +## Requirements + +Flash Attention 3 (Hopper) required. SP8192 BPE tokenizer trained on FineWeb 10B (sentencepiece BPE, 8192 vocab). + +```bash +pip install torch --index-url https://download.pytorch.org/whl/cu130 +pip install --no-cache-dir \ + "https://download.pytorch.org/whl/cu130/flash_attn_3-3.0.0-cp39-abi3-manylinux_2_28_x86_64.whl" +pip install -r requirements.txt +``` + +## Compliance (Track A — Fixed Predictor) + +- No TTT, SLOT, n-gram cache, or eval-time adaptation +- GPTQ calibration within training budget +- Standard autoregressive sliding-window eval (stride=64) + + +## Credits + +Learned from and inspired by [PR #1394](https://github.com/openai/parameter-golf/pull/1394) (@clarkkev) — SDClip, depth recurrence, and GPTQ embedding quantization ideas. Parallel residuals from GPT-J ([Wang & Komatsuzaki, 2021](https://github.com/kingoflolz/mesh-transformer-jax)). Additional credits: PR #1204 (@msisovic, depth recurrence), PR #1217 (@bigbag, MuonEq-R), PR #1019 (@abaybektursun, previous SOTA). diff --git a/records/track_10min_16mb/2026-04-06_SP8192_HessianSDClip_ProgressiveRecurrence/submission.json b/records/track_10min_16mb/2026-04-06_SP8192_HessianSDClip_ProgressiveRecurrence/submission.json new file mode 100644 index 0000000000..e1e16baf85 --- /dev/null +++ b/records/track_10min_16mb/2026-04-06_SP8192_HessianSDClip_ProgressiveRecurrence/submission.json @@ -0,0 +1,34 @@ +{ + "author": "Robby Sneiderman", + "github_id": "Robby955", + "name": "Non-record: Parallel Residuals + Hessian-Aware SDClip", + "blurb": "Three zero-cost modifications on PR #1394: GPT-J parallel residuals (layers 7+), Hessian-diagonal SDClip modulation (lambda=0.175), two-phase progressive recurrence. 3-seed mean 1.08354 BPB.", + "date": "2026-04-06T00:00:00Z", + "val_loss": 2.7995, + "val_bpb": 1.08354, + "val_bpb_std": 0.00050, + "seeds": [1337, 42, 3141], + "seed_results": { + "1337": { + "val_loss": 2.79749, + "val_bpb": 1.08301, + "artifact_bytes": 15976275, + "steps": 5178 + }, + "42": { + "val_loss": 2.79910, + "val_bpb": 1.08363, + "artifact_bytes": 15978439, + "steps": 5180 + }, + "3141": { + "val_loss": 2.80002, + "val_bpb": 1.08399, + "artifact_bytes": 15979649, + "steps": 5182 + } + }, + "hardware": "8xH100 80GB SXM", + "bytes_total": 15978121, + "based_on": "PR #1394 (@clarkkev)" +} diff --git a/records/track_10min_16mb/2026-04-06_SP8192_HessianSDClip_ProgressiveRecurrence/train.log b/records/track_10min_16mb/2026-04-06_SP8192_HessianSDClip_ProgressiveRecurrence/train.log new file mode 100644 index 0000000000..14fce6e5cc --- /dev/null +++ b/records/track_10min_16mb/2026-04-06_SP8192_HessianSDClip_ProgressiveRecurrence/train.log @@ -0,0 +1,165 @@ +W0406 05:48:52.138000 4182076 torch/distributed/run.py:803] +W0406 05:48:52.138000 4182076 torch/distributed/run.py:803] ***************************************** +W0406 05:48:52.138000 4182076 torch/distributed/run.py:803] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. +W0406 05:48:52.138000 4182076 torch/distributed/run.py:803] ***************************************** +Hyperparameters: + adam_eps: 1e-08 + adam_wd: 0.02 + beta1: 0.9 + beta2: 0.95 + compressor: brotli + data_dir: ./data/ + datasets_dir: ./data/datasets/fineweb10B_sp8192 + distributed: True + ema_decay: 0.997 + embed_bits: 8 + embed_clip_sigmas: 20.0 + embed_lr: 0.6 + embed_wd: 0.085 + embedding_dim: 512 + enable_looping_at: 0.5 + eval_seq_len: 2048 + eval_stride: 64 + gptq_calibration_batches: 64 + gptq_reserve_seconds: 12.0 + grad_accum_steps: 1 + grad_clip_norm: 0.3 + head_lr: 0.008 + hessian_clip_lambda: 0.175 + is_main_process: True + iterations: 20000 + ln_scale: True + local_rank: 0 + logfile: logs/f3971278-d577-499b-8fde-755434809ba9.txt + logit_softcap: 30.0 + loop_end: 5 + loop_layer_bits: 0 + loop_layer_clip_sigmas: 0.0 + loop_phase2_at: 0.65 + loop_start: 4 + matrix_bits: 6 + matrix_clip_sigmas: 12.85 + matrix_lr: 0.02 + max_wallclock_seconds: 600.0 + min_lr: 0.0 + mlp_mult: 4.0 + model_dim: 512 + model_path: final_model.pt + muon_backend_steps: 5 + muon_beta2: 0.95 + muon_momentum: 0.99 + muon_momentum_warmup_start: 0.92 + muon_momentum_warmup_steps: 1500 + muon_row_normalize: True + muon_wd: 0.085 + num_heads: 8 + num_kv_heads: 4 + num_layers: 11 + num_loops: 2 + parallel_residual_start: 7 + qk_gain_init: 4.0 + quantized_model_path: final_model.int6.ptz + rank: 0 + rope_base: 10000.0 + rope_dims: 16 + rope_train_seq_len: 2048 + run_id: f3971278-d577-499b-8fde-755434809ba9 + scalar_lr: 0.02 + seed: 1337 + skip_gates_enabled: True + sliding_window_enabled: True + tie_embeddings: True + tied_embed_init_std: 0.005 + tied_embed_lr: 0.03 + tokenizer_path: ./data/tokenizers/fineweb_8192_bpe.model + train_batch_tokens: 786432 + train_files: ./data/datasets/fineweb10B_sp8192/fineweb_train_*.bin + train_log_every: 500 + train_seq_len: 2048 + ttt_chunk_tokens: 32768 + ttt_enabled: False + ttt_entropy_high: 2.1 + ttt_entropy_low: 1.75 + ttt_epochs: 4 + ttt_freeze_blocks: 2 + ttt_lr: 0.0005 + ttt_ns_steps: 3 + untie_loop_mlps: False + val_batch_tokens: 524288 + val_files: ./data/datasets/fineweb10B_sp8192/fineweb_val_*.bin + val_loss_every: 4000 + vocab_size: 8192 + warmdown_frac: 0.667 + warmup_steps: 20 + world_size: 8 + xsa_last_n: 11 +train_shards: 128 +val_tokens: 40540160 +model_params:35943512 +hessian_clip: lambda=0.175 +parallel_residuals: ON (layers 7-10) +progressive_recurrence: phase1=0.5 phase2=0.65 +gptq:reserving 12s, effective=588000ms +warmup_step: 1/20 +warmup_step: 2/20 +warmup_step: 3/20 +warmup_step: 4/20 +warmup_step: 5/20 +warmup_step: 6/20 +warmup_step: 10/20 +warmup_step: 20/20 +loop_warmup_phase1: encoder:[0, 1, 2, 3, 4, 5] decoder:[4, 5, 6, 7, 8, 9, 10] +loop_warmup_p1_step: 1/20 +loop_warmup_p1_step: 2/20 +loop_warmup_p1_step: 3/20 +loop_warmup_p1_step: 4/20 +loop_warmup_p1_step: 5/20 +loop_warmup_p1_step: 6/20 +loop_warmup_p1_step: 10/20 +loop_warmup_p1_step: 20/20 +loop_warmup_phase2: encoder:[0, 1, 2, 3, 4, 5, 4] decoder:[5, 4, 5, 6, 7, 8, 9, 10] +loop_warmup_p2_step: 1/20 +loop_warmup_p2_step: 2/20 +loop_warmup_p2_step: 3/20 +loop_warmup_p2_step: 4/20 +loop_warmup_p2_step: 5/20 +loop_warmup_p2_step: 6/20 +loop_warmup_p2_step: 10/20 +loop_warmup_p2_step: 20/20 +0/20000 val_loss: 9.0052 val_bpb: 3.4862 +1/20000 train_loss: 9.0086 train_time: 0.0m tok/s: 8101558 looping:False +2/20000 train_loss: 12.1911 train_time: 0.0m tok/s: 8047236 looping:False +3/20000 train_loss: 11.0242 train_time: 0.0m tok/s: 7763017 looping:False +4/20000 train_loss: 9.5010 train_time: 0.0m tok/s: 7744899 looping:False +5/20000 train_loss: 8.3911 train_time: 0.0m tok/s: 7764479 looping:False +500/20000 train_loss: 3.3106 train_time: 0.8m tok/s: 7721577 looping:False +1000/20000 train_loss: 3.2012 train_time: 1.7m tok/s: 7711891 looping:False +1500/20000 train_loss: 3.1827 train_time: 2.5m tok/s: 7711479 looping:False +2000/20000 train_loss: 2.9936 train_time: 3.4m tok/s: 7711845 looping:False +2500/20000 train_loss: 3.0679 train_time: 4.2m tok/s: 7713114 looping:False +layer_loop:phase1 step:2884 frac:0.500 +3000/20000 train_loss: 3.1068 train_time: 5.1m tok/s: 7668374 looping:True +3500/20000 train_loss: 2.9483 train_time: 6.1m tok/s: 7510155 looping:True +layer_loop:phase2 step:3634 frac:0.650 +4000/20000 train_loss: 2.9482 train_time: 7.2m tok/s: 7297673 looping:True +4000/20000 val_loss: 2.9279 val_bpb: 1.1335 +4500/20000 train_loss: 2.8499 train_time: 8.3m tok/s: 7110716 looping:True +5000/20000 train_loss: 2.8598 train_time: 9.4m tok/s: 6967320 looping:True +5178/20000 val_loss: 2.8121 val_bpb: 1.0887 +stopping_early: wallclock_cap train_time: 588103ms step: 5178/20000 +peak memory allocated: 34604 MiB reserved: 34634 MiB +ema:applying EMA weights +pre-quantization post-ema val_loss:2.80947408 val_bpb:1.08764765 eval_time:6554ms +Serialized model: 135426937 bytes +Code size: 78688 bytes +GPTQ:collecting Hessians from calibration data... +GPTQ:collected 67 Hessians in 11.3s +GPTQ:saved Hessian diagnostics to hessian_diagnostics.pt (67 matrices) +Quantized weights: + gptq (int6): blocks.attn.c_k.weight, blocks.attn.c_q.weight, blocks.attn.c_v.weight, blocks.attn.proj.weight, blocks.mlp.fc.weight, blocks.mlp.proj.weight + gptq (int8): tok_emb.weight + passthrough (float16): blocks.attn.q_gain, blocks.attn_scale, blocks.mlp_scale, blocks.resid_mix, skip_gates, skip_weights +Serialized model quantized+brotli: 15976275 bytes +Total submission size quantized+brotli: 16054963 bytes +quantized val_loss:2.84032998 val_bpb:1.09959307 eval_time:8134ms +quantized_sliding_window val_loss:2.79749368 val_bpb:1.08300961 eval_time:82837ms diff --git a/records/track_10min_16mb/2026-04-06_SP8192_HessianSDClip_ProgressiveRecurrence/train_gpt.py b/records/track_10min_16mb/2026-04-06_SP8192_HessianSDClip_ProgressiveRecurrence/train_gpt.py new file mode 100644 index 0000000000..e6e5ce2b28 --- /dev/null +++ b/records/track_10min_16mb/2026-04-06_SP8192_HessianSDClip_ProgressiveRecurrence/train_gpt.py @@ -0,0 +1 @@ +exec(open(__file__.replace("train_gpt.py","train_gpt_decode.py")).read()) diff --git a/records/track_10min_16mb/2026-04-06_SP8192_HessianSDClip_ProgressiveRecurrence/train_gpt_decode.py b/records/track_10min_16mb/2026-04-06_SP8192_HessianSDClip_ProgressiveRecurrence/train_gpt_decode.py new file mode 100644 index 0000000000..ce94647361 --- /dev/null +++ b/records/track_10min_16mb/2026-04-06_SP8192_HessianSDClip_ProgressiveRecurrence/train_gpt_decode.py @@ -0,0 +1,1875 @@ +""" +Improved version of Clark's PR #1394 (1.08563 BPB). +Additions: + 1. Hessian-aware SDClip (clip = f(row_std, Hessian importance)) + 2. Untied repeated MLPs (shared attention, pass-specific MLPs on loop passes) + 3. Per-layer quant policy for looped layers (wider clip / more bits) + 4. Progressive recurrence curriculum (phase1 / phase2 instead of binary switch) + 5. QK gain already env-sweepable from Clark's code +""" +import collections +import copy +import glob +import io +import lzma +import math +import os +from pathlib import Path +import random +import re +import subprocess +import sys +import time +import uuid + +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch.nn.parallel import DistributedDataParallel as DDP +from torch import Tensor, nn + +from flash_attn_interface import flash_attn_func as flash_attn_3_func + +# ---------------------------------------- +# Hyperparameters +# ---------------------------------------- + +class Hyperparameters(): + # Experiment settings + data_dir = os.environ.get('DATA_DIR', './data/') + seed = int(os.environ.get('SEED', 1337)) + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + + # Training length + iterations = int(os.environ.get('ITERATIONS', 20000)) + warmdown_frac = float(os.environ.get('WARMDOWN_FRAC', 0.667)) + warmup_steps = int(os.environ.get('WARMUP_STEPS', 20)) + train_batch_tokens = int(os.environ.get('TRAIN_BATCH_TOKENS', 2048 * 48 * 8)) + train_seq_len = int(os.environ.get('TRAIN_SEQ_LEN', 2048)) + train_log_every = int(os.environ.get('TRAIN_LOG_EVERY', 500)) + max_wallclock_seconds = float(os.environ.get('MAX_WALLCLOCK_SECONDS', 600.0)) + + # Validation/Evals + val_batch_tokens = int(os.environ.get('VAL_BATCH_TOKENS', 2048 * 32 * 8)) + eval_seq_len = int(os.environ.get('EVAL_SEQ_LEN', 2048)) + val_loss_every = int(os.environ.get('VAL_LOSS_EVERY', 4000)) + sliding_window_enabled = bool(int(os.environ.get('SLIDING_WINDOW_ENABLED', '1'))) + + # Model architecture + vocab_size = int(os.environ.get('VOCAB_SIZE', 8192)) + num_layers = int(os.environ.get('NUM_LAYERS', 11)) + xsa_last_n = int(os.environ.get('XSA_LAST_N', 11)) + model_dim = int(os.environ.get('MODEL_DIM', 512)) + embedding_dim = int(os.environ.get('EMBEDDING_DIM', 512)) + num_kv_heads = int(os.environ.get('NUM_KV_HEADS', 4)) + num_heads = int(os.environ.get('NUM_HEADS', 8)) + mlp_mult = float(os.environ.get('MLP_MULT', 4.0)) + skip_gates_enabled = bool(int(os.environ.get('SKIP_GATES_ENABLED', '1'))) + tie_embeddings = bool(int(os.environ.get('TIE_EMBEDDINGS', '1'))) + logit_softcap = float(os.environ.get('LOGIT_SOFTCAP', 30.0)) + rope_base = float(os.environ.get('ROPE_BASE', 10000.0)) + rope_dims = int(os.environ.get('ROPE_DIMS', 16)) + rope_train_seq_len = int(os.environ.get('ROPE_TRAIN_SEQ_LEN', 2048)) + ln_scale = bool(int(os.environ.get('LN_SCALE', '1'))) + qk_gain_init = float(os.environ.get('QK_GAIN_INIT', 4.0)) + + # Layer looping + num_loops = int(os.environ.get('NUM_LOOPS', 2)) + loop_start = int(os.environ.get('LOOP_START', 4)) + loop_end = int(os.environ.get('LOOP_END', 5)) + enable_looping_at = float(os.environ.get('ENABLE_LOOPING_AT', 0.5)) + # NEW: Progressive recurrence curriculum + loop_phase2_at = float(os.environ.get('LOOP_PHASE2_AT', 0.0)) # 0 = disabled (use binary switch) + # NEW: Untie repeated MLPs + untie_loop_mlps = bool(int(os.environ.get('UNTIE_LOOP_MLPS', '0'))) + # Parallel residuals: MLP reads pre-attention input (GPT-J style) from this layer onward + parallel_residual_start: int = int(os.environ.get('PARALLEL_RESIDUAL_START', '-1')) # -1 = off + + # Optimizer + min_lr = float(os.environ.get('MIN_LR', 0.0)) + embed_lr = float(os.environ.get('EMBED_LR', 0.6)) + head_lr = float(os.environ.get('HEAD_LR', 0.008)) + tied_embed_lr = float(os.environ.get('TIED_EMBED_LR', 0.03)) + tied_embed_init_std = float(os.environ.get('TIED_EMBED_INIT_STD', 0.005)) + matrix_lr = float(os.environ.get('MATRIX_LR', 0.02)) + scalar_lr = float(os.environ.get('SCALAR_LR', 0.02)) + muon_momentum = float(os.environ.get('MUON_MOMENTUM', 0.99)) + muon_backend_steps = int(os.environ.get('MUON_BACKEND_STEPS', 5)) + muon_momentum_warmup_start = float(os.environ.get('MUON_MOMENTUM_WARMUP_START', 0.92)) + muon_momentum_warmup_steps = int(os.environ.get('MUON_MOMENTUM_WARMUP_STEPS', 1500)) + muon_row_normalize = bool(int(os.environ.get('MUON_ROW_NORMALIZE', '1'))) + beta1 = float(os.environ.get('BETA1', 0.9)) + beta2 = float(os.environ.get('BETA2', 0.95)) + adam_eps = float(os.environ.get('ADAM_EPS', 1e-8)) + grad_clip_norm = float(os.environ.get('GRAD_CLIP_NORM', 0.3)) + eval_stride = int(os.environ.get('EVAL_STRIDE', 64)) + muon_beta2 = float(os.environ.get('MUON_BETA2', 0.95)) + adam_wd = float(os.environ.get('ADAM_WD', 0.02)) + muon_wd = float(os.environ.get('MUON_WD', 0.085)) + embed_wd = float(os.environ.get('EMBED_WD', 0.085)) + ema_decay = float(os.environ.get('EMA_DECAY', 0.997)) + + # Quantization & Compression + compressor = os.environ.get('COMPRESSOR', 'brotli') + gptq_calibration_batches = int(os.environ.get('GPTQ_CALIBRATION_BATCHES', 64)) + gptq_reserve_seconds = float(os.environ.get('GPTQ_RESERVE_SECONDS', 12.0)) + matrix_bits = int(os.environ.get('MATRIX_BITS', 6)) + embed_bits = int(os.environ.get('EMBED_BITS', 8)) + matrix_clip_sigmas = float(os.environ.get('MATRIX_CLIP_SIGMAS', 12.85)) + embed_clip_sigmas = float(os.environ.get('EMBED_CLIP_SIGMAS', 20.0)) + # NEW: Hessian-aware clipping + hessian_clip_lambda = float(os.environ.get('HESSIAN_CLIP_LAMBDA', 0.0)) # 0 = disabled + # Per-group clip multipliers (multiply matrix_clip_sigmas per block group) + # Informed by 3-seed Hessian analysis: early=317M, loop=128M, mid=50M, late=11M trace + clip_mult_early = float(os.environ.get('CLIP_MULT_EARLY', '1.0')) # blocks 0-2 + clip_mult_loop = float(os.environ.get('CLIP_MULT_LOOP', '1.0')) # blocks 4-5 + clip_mult_mid = float(os.environ.get('CLIP_MULT_MID', '1.0')) # blocks 3,6-7 + clip_mult_late = float(os.environ.get('CLIP_MULT_LATE', '1.0')) # blocks 8-10 + # NEW: Per-layer quant for looped layers + loop_layer_bits = int(os.environ.get('LOOP_LAYER_BITS', 0)) # 0 = use matrix_bits + loop_layer_clip_sigmas = float(os.environ.get('LOOP_LAYER_CLIP_SIGMAS', 0)) # 0 = use matrix_clip_sigmas + + # Discriminative TTT (score-first, legal per Issue #1082) + ttt_enabled = bool(int(os.environ.get('TTT_ENABLED', '0'))) + ttt_lr = float(os.environ.get('TTT_LR', '0.0005')) + ttt_epochs = int(os.environ.get('TTT_EPOCHS', '4')) + ttt_chunk_tokens = int(os.environ.get('TTT_CHUNK_TOKENS', '32768')) + ttt_freeze_blocks = int(os.environ.get('TTT_FREEZE_BLOCKS', '2')) + ttt_ns_steps = int(os.environ.get('TTT_NS_STEPS', '3')) + ttt_entropy_high = float(os.environ.get('TTT_ENTROPY_HIGH', '2.1')) + ttt_entropy_low = float(os.environ.get('TTT_ENTROPY_LOW', '1.75')) + + # Distributed setup + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + is_main_process = rank == 0 + grad_accum_steps = 8 // world_size + + # Data paths + datasets_dir = os.path.join(data_dir, 'datasets', f'fineweb10B_sp{vocab_size}') + train_files = os.path.join(datasets_dir, 'fineweb_train_*.bin') + val_files = os.path.join(datasets_dir, 'fineweb_val_*.bin') + tokenizer_path = os.path.join(data_dir, 'tokenizers', f'fineweb_{vocab_size}_bpe.model') + + # Experiment files + logfile = f"logs/{run_id}.txt" + model_path = "final_model.pt" + quantized_model_path = "final_model.int6.ptz" + +# ---------------------------------------- +# Global Logging Function +# ---------------------------------------- + +_logger_hparams = None + + +def set_logging_hparams(h: Hyperparameters) -> None: + global _logger_hparams + _logger_hparams = h + + +def log(msg, console: bool = True) -> None: + if _logger_hparams is None: + print(msg) + return + if _logger_hparams.is_main_process: + if console: + print(msg) + if _logger_hparams.logfile is not None: + with open(_logger_hparams.logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + +# ---------------------------------------- +# Data Loading +# ---------------------------------------- + +class ValidationData: + def __init__(self, h: Hyperparameters, device: torch.device): + self.sp = spm.SentencePieceProcessor(model_file=h.tokenizer_path) + if int(self.sp.vocab_size()) != h.vocab_size: + raise ValueError( + f"VOCAB_SIZE={h.vocab_size} does not match tokenizer vocab_size={int(self.sp.vocab_size())}" + ) + self.val_tokens = load_validation_tokens(h.val_files, h.eval_seq_len) + self.base_bytes_lut, self.has_leading_space_lut, self.is_boundary_token_lut = ( + build_sentencepiece_luts(self.sp, h.vocab_size, device)) + + +def build_sentencepiece_luts( + sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device +) -> tuple[Tensor, Tensor, Tensor]: + sp_vocab_size = int(sp.vocab_size()) + assert sp.piece_to_id("\u2581") != sp.unk_id(), \ + "Tokenizer must have '▁' (space) as its own token for correct BPB byte counting" + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("\u2581"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) + + +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] + + +def load_data_shard(file: Path) -> Tensor: + header_bytes = 256 * np.dtype(" int: + key = str(file) + cached = _SHARD_NTOKENS_CACHE.get(key) + if cached is not None: + return cached + header = np.fromfile(file, dtype=" np.memmap: + key = str(file) + mm = _MMAP_CACHE.get(key) + if mm is not None: + return mm + n = _read_num_tokens(file) + mm = np.memmap(file, mode="r", dtype=" None: + max_phase = min(self.seq_len - 1, max(0, self.num_tokens[si] - self.seq_len - 1)) + phase = int(self.rng.integers(max_phase + 1)) if max_phase > 0 else 0 + num_sequences = (self.num_tokens[si] - 1 - phase) // self.seq_len + sequence_order = self.rng.permutation(num_sequences) + self.start_inds[si] = (phase + sequence_order * self.seq_len).tolist() + + def next_batch(self, global_tokens: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + device_tokens = global_tokens // (self.world_size * grad_accum_steps) + device_batch_size = device_tokens // self.seq_len + remaining = np.array([len(s) for s in self.start_inds], dtype=np.float64) + x = torch.empty((device_batch_size, self.seq_len), dtype=torch.int64) + y = torch.empty((device_batch_size, self.seq_len), dtype=torch.int64) + for bi in range(device_batch_size): + total = remaining.sum() + if total <= 0: + for si in range(len(self.files)): + self._reset_shard(si) + remaining = np.array([len(s) for s in self.start_inds], dtype=np.float64) + total = remaining.sum() + probs = remaining / total + si = int(self.rng.choice(len(self.files), p=probs)) + start_ind = self.start_inds[si].pop() + remaining[si] -= 1 + mm = _get_shard_memmap(self.files[si]) + window = torch.as_tensor( + np.array(mm[start_ind:start_ind + self.seq_len + 1], dtype=np.int64)) + x[bi] = window[:-1] + y[bi] = window[1:] + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + +# ---------------------------------------- +# Model Architecture +# ---------------------------------------- + +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + + +class CastedLinear(nn.Linear): + def forward(self, x: Tensor) -> Tensor: + w = self.weight.to(x.dtype) + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, w, bias) + + +class Rotary(nn.Module): + def __init__(self, dim: int, base: float = 10000.0, train_seq_len: int = 1024, rope_dims: int = 0): + super().__init__() + self.dim = dim + self.base = base + self.train_seq_len = train_seq_len + self.rope_dims = rope_dims if rope_dims > 0 else dim + inv_freq = 1.0 / (base ** (torch.arange(0, self.rope_dims, 2, dtype=torch.float32) / self.rope_dims)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached != seq_len + or self._cos_cached.device != device + ): + rd = self.rope_dims + if seq_len > self.train_seq_len: + scale = seq_len / self.train_seq_len + new_base = self.base * (scale ** (rd / (rd - 2))) + inv_freq = 1.0 / (new_base ** (torch.arange( + 0, rd, 2, dtype=torch.float32, device=device) / rd)) + else: + inv_freq = self.inv_freq.to(device) + t = torch.arange(seq_len, device=device, dtype=inv_freq.dtype) + freqs = torch.outer(t, inv_freq) + self._cos_cached = freqs.cos()[None, :, None, :] + self._sin_cached = freqs.sin()[None, :, None, :] + self._seq_len_cached = seq_len + return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) + + +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor, rope_dims: int = 0) -> Tensor: + if rope_dims > 0 and rope_dims < x.size(-1): + x_rope, x_pass = x[..., :rope_dims], x[..., rope_dims:] + half = rope_dims // 2 + x1, x2 = x_rope[..., :half], x_rope[..., half:] + x_rope = torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + return torch.cat((x_rope, x_pass), dim=-1) + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + + +class CausalSelfAttention(nn.Module): + def __init__(self, dim: int, num_heads: int, num_kv_heads: int, + rope_base: float, qk_gain_init: float, train_seq_len: int): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + kv_dim = self.num_kv_heads * self.head_dim + self.c_q = CastedLinear(dim, dim, bias=False) + self.c_k = CastedLinear(dim, kv_dim, bias=False) + self.c_v = CastedLinear(dim, kv_dim, bias=False) + self.proj = CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rope_dims = 0 + self.rotary = Rotary(self.head_dim, base=rope_base, train_seq_len=train_seq_len) + self.use_xsa = False + + def _xsa_efficient(self, y: Tensor, v: Tensor) -> Tensor: + B, T, H, D = y.shape + Hkv = v.size(-2) + group = H // Hkv + y_g = y.reshape(B, T, Hkv, group, D) + vn = F.normalize(v, dim=-1).unsqueeze(-2) + proj = (y_g * vn).sum(dim=-1, keepdim=True) * vn + return (y_g - proj).reshape(B, T, H, D) + + def forward(self, x: Tensor) -> Tensor: + bsz, seqlen, dim = x.shape + q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim) + k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) + v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin, self.rope_dims) + k = apply_rotary_emb(k, cos, sin, self.rope_dims) + q = q * self.q_gain.to(dtype=q.dtype)[None, None, :, None] + y = flash_attn_3_func(q, k, v, causal=True) + if self.use_xsa: + y = self._xsa_efficient(y, v) + y = y.reshape(bsz, seqlen, dim) + return self.proj(y) + + +class MLP(nn.Module): + def __init__(self, dim: int, mlp_mult: int): + super().__init__() + hidden = int(mlp_mult * dim) + self.fc = CastedLinear(dim, hidden, bias=False) + self.proj = CastedLinear(hidden, dim, bias=False) + self.proj._zero_init = True + + def forward(self, x: Tensor) -> Tensor: + return self.proj(F.leaky_relu(self.fc(x), negative_slope=0.5).square()) + + +class Block(nn.Module): + def __init__(self, dim: int, num_heads: int, num_kv_heads: int, mlp_mult: int, + rope_base: float, qk_gain_init: float, train_seq_len: int, + layer_idx: int = 0, ln_scale: bool = False): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention( + dim, num_heads, num_kv_heads, rope_base, qk_gain_init, train_seq_len) + self.mlp = MLP(dim, mlp_mult) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + self.ln_scale_factor = 1.0 / math.sqrt(layer_idx + 1) if ln_scale else 1.0 + self.parallel = False # set by GPT.__init__ if PARALLEL_RESIDUAL_START + + def forward(self, x: Tensor, x0: Tensor) -> Tensor: + mix = self.resid_mix.to(dtype=x.dtype) + x_in = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + attn_out = self.attn(self.attn_norm(x_in) * self.ln_scale_factor) + if self.parallel: + # GPT-J style: MLP reads pre-attention input, both added in parallel + mlp_out = self.mlp(self.mlp_norm(x_in) * self.ln_scale_factor) + x_out = x_in + self.attn_scale.to(dtype=x_in.dtype)[None, None, :] * attn_out \ + + self.mlp_scale.to(dtype=x_in.dtype)[None, None, :] * mlp_out + else: + # Standard sequential: MLP reads post-attention output + x_out = x_in + self.attn_scale.to(dtype=x_in.dtype)[None, None, :] * attn_out + x_out = x_out + self.mlp_scale.to(dtype=x_out.dtype)[None, None, :] * self.mlp( + self.mlp_norm(x_out) * self.ln_scale_factor) + return x_out + + +class GPT(nn.Module): + def __init__(self, h: Hyperparameters): + super().__init__() + if h.logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {h.logit_softcap}") + self.tie_embeddings = h.tie_embeddings + self.tied_embed_init_std = h.tied_embed_init_std + self.logit_softcap = h.logit_softcap + self.tok_emb = nn.Embedding(h.vocab_size, h.embedding_dim) + if h.embedding_dim != h.model_dim: + self.embed_proj = CastedLinear(h.embedding_dim, h.model_dim, bias=False) + self.head_proj = CastedLinear(h.model_dim, h.embedding_dim, bias=False) + else: + self.embed_proj = None + self.head_proj = None + self.num_encoder_layers = h.num_layers // 2 + self.num_decoder_layers = h.num_layers - self.num_encoder_layers + self.blocks = nn.ModuleList([ + Block(h.model_dim, h.num_heads, h.num_kv_heads, h.mlp_mult, h.rope_base, + h.qk_gain_init, h.train_seq_len, layer_idx=i, ln_scale=h.ln_scale) + for i in range(h.num_layers) + ]) + if h.rope_dims > 0: + head_dim = h.model_dim // h.num_heads + for block in self.blocks: + block.attn.rope_dims = h.rope_dims + block.attn.rotary = Rotary(head_dim, base=h.rope_base, train_seq_len=h.train_seq_len, rope_dims=h.rope_dims) + self.final_norm = RMSNorm() + self.lm_head = None if h.tie_embeddings else CastedLinear(h.embedding_dim, h.vocab_size, bias=False) + if self.lm_head is not None: + self.lm_head._zero_init = True + if h.xsa_last_n > 0: + for i in range(max(0, h.num_layers - h.xsa_last_n), h.num_layers): + self.blocks[i].attn.use_xsa = True + if h.parallel_residual_start >= 0: + for i in range(h.parallel_residual_start, h.num_layers): + self.blocks[i].parallel = True + + # Layer looping — Clark-style: single boolean + mutable index lists + self.looping_active = False + self.loop_start = h.loop_start + self.loop_end = h.loop_end + + # Phase 0: no looping (standard enc/dec) + self.phase0_enc = list(range(self.num_encoder_layers)) + self.phase0_dec = list(range(self.num_encoder_layers, h.num_layers)) + + if h.num_loops > 0: + loop_seg = list(range(h.loop_start, h.loop_end + 1)) + + # Phase 2: full looping (original Clark behavior) + all2 = list(range(h.loop_start)) + for _ in range(h.num_loops + 1): + all2.extend(loop_seg) + all2.extend(range(h.loop_end + 1, h.num_layers)) + ne2 = len(all2) // 2 + self.phase2_enc = all2[:ne2] + self.phase2_dec = all2[ne2:] + + # Phase 1: one repeat only (progressive curriculum) + all1 = list(range(h.loop_start)) + for _ in range(min(h.num_loops, 1) + 1): + all1.extend(loop_seg) + all1.extend(range(h.loop_end + 1, h.num_layers)) + ne1 = len(all1) // 2 + self.phase1_enc = all1[:ne1] + self.phase1_dec = all1[ne1:] + else: + self.phase1_enc = self.phase0_enc + self.phase1_dec = self.phase0_dec + self.phase2_enc = self.phase0_enc + self.phase2_dec = self.phase0_dec + + # Active indices (mutated by activate_looping) + self.encoder_indices = self.phase2_enc + self.decoder_indices = self.phase2_dec + + # Skip weights sized for max phase (phase 2) + self.num_skip_weights = min(len(self.phase2_enc), len(self.phase2_dec)) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, h.model_dim, dtype=torch.float32)) + self.skip_gates = nn.Parameter(torch.zeros(self.num_skip_weights, h.model_dim, dtype=torch.float32)) if h.skip_gates_enabled else None + + # Untied loop MLPs: separate MLP weights per repeat pass (attention stays shared) + self.override_mlps = None + if h.untie_loop_mlps and h.num_loops > 0: + override = {} + for layer_idx in range(h.loop_start, h.loop_end + 1): + for pass_num in range(1, h.num_loops + 1): + override[f"{layer_idx}_p{pass_num}"] = MLP(h.model_dim, h.mlp_mult) + self.override_mlps = nn.ModuleDict(override) + + self._init_weights() + + def activate_looping(self, phase: int = 2): + """Switch looping phase. 0=off, 1=one repeat, 2=full repeats.""" + if phase == 0: + self.looping_active = False + elif phase == 1: + self.encoder_indices = self.phase1_enc + self.decoder_indices = self.phase1_dec + self.looping_active = True + else: + self.encoder_indices = self.phase2_enc + self.decoder_indices = self.phase2_dec + self.looping_active = True + + def _block_forward_with_mlp(self, block: 'Block', x: Tensor, x0: Tensor, mlp: nn.Module) -> Tensor: + """Run a block using shared attention but an override MLP.""" + mix = block.resid_mix.to(dtype=x.dtype) + x_in = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + attn_out = block.attn(block.attn_norm(x_in) * block.ln_scale_factor) + if block.parallel: + mlp_out = mlp(block.mlp_norm(x_in) * block.ln_scale_factor) + x_out = x_in + block.attn_scale.to(dtype=x_in.dtype)[None, None, :] * attn_out \ + + block.mlp_scale.to(dtype=x_in.dtype)[None, None, :] * mlp_out + else: + x_out = x_in + block.attn_scale.to(dtype=x_in.dtype)[None, None, :] * attn_out + x_out = x_out + block.mlp_scale.to(dtype=x_out.dtype)[None, None, :] * mlp( + block.mlp_norm(x_out) * block.ln_scale_factor) + return x_out + + def _init_weights(self) -> None: + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + for name, module in self.named_modules(): + if isinstance(module, nn.Linear): + if getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + elif (module.weight.ndim == 2 and module.weight.shape[0] >= 64 and + module.weight.shape[1] >= 64): + nn.init.orthogonal_(module.weight, gain=1.0) + + def forward_logits(self, input_ids: Tensor) -> Tensor: + x = self.tok_emb(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + if self.embed_proj is not None: + x = self.embed_proj(x) + x0 = x + skips: list[Tensor] = [] + enc_iter = self.encoder_indices if self.looping_active else range(self.num_encoder_layers) + dec_iter = self.decoder_indices if self.looping_active else range(self.num_encoder_layers, self.num_encoder_layers + self.num_decoder_layers) + pass_count: dict[int, int] = {} + for i in enc_iter: + pc = pass_count.get(i, 0) + if pc > 0 and self.override_mlps is not None: + x = self._block_forward_with_mlp(self.blocks[i], x, x0, self.override_mlps[f"{i}_p{pc}"]) + else: + x = self.blocks[i](x, x0) + pass_count[i] = pc + 1 + skips.append(x) + for skip_idx, i in enumerate(dec_iter): + if skip_idx < self.num_skip_weights and skips: + scaled_skip = self.skip_weights[skip_idx].to(dtype=x.dtype)[None, None, :] * skips.pop() + if self.skip_gates is not None: + g = torch.sigmoid(self.skip_gates[skip_idx].to(dtype=x.dtype))[None, None, :] + x = torch.lerp(scaled_skip, x, g) + else: + x = x + scaled_skip + pc = pass_count.get(i, 0) + if pc > 0 and self.override_mlps is not None: + x = self._block_forward_with_mlp(self.blocks[i], x, x0, self.override_mlps[f"{i}_p{pc}"]) + else: + x = self.blocks[i](x, x0) + pass_count[i] = pc + 1 + x = self.final_norm(x) + if self.head_proj is not None: + x = self.head_proj(x) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + logits_proj = self.lm_head(x) + return self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + logits = self.forward_logits(input_ids) + return F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), target_ids.reshape(-1), reduction="mean") + + +def classify_param(name: str) -> str: + if "tok_emb" in name or "lm_head" in name: + return "embed" + if ".mlp." in name or "override_mlps" in name: + return "mlp" + if ".attn." in name or (".proj." in name and ".mlp." not in name): + return "attn" + return "other" + +# ---------------------------------------- +# Optimization +# ---------------------------------------- + +# Polar Express: minimax-optimal per-iteration NS coefficients (arXiv:2505.16932) +# Computed via Remez algorithm with l=1e-3, safety_factor=1.01, cushion=0.02 +POLAR_EXPRESS_COEFFS = [ + (8.237312490495558, -23.157747414558205, 16.680568411445918), + (4.082441999064834, -2.893047735332586, 0.525284925697565), + (3.926347992254655, -2.854746803476530, 0.531802242289499), + (3.298218713308514, -2.424541981026706, 0.486320083588441), + (2.320007312889811, -1.686216972996763, 0.420680273402352), +] +MUON_FIXED_COEFFS = (3.4445, -4.7750, 2.0315) +USE_POLAR_EXPRESS = bool(int(os.environ.get('POLAR_EXPRESS', '0'))) + +@torch.compile +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + if USE_POLAR_EXPRESS: + for i in range(steps): + a, b, c = POLAR_EXPRESS_COEFFS[i] if i < len(POLAR_EXPRESS_COEFFS) else POLAR_EXPRESS_COEFFS[-1] + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + else: + a, b, c = MUON_FIXED_COEFFS + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X + + +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, + nesterov: bool = True, weight_decay: float = 0.0, + row_normalize: bool = False): + super().__init__( + params, + dict(lr=lr, momentum=momentum, backend_steps=backend_steps, + nesterov=nesterov, weight_decay=weight_decay, + row_normalize=row_normalize), + ) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + for group in self.param_groups: + params = group["params"] + if not params: + continue + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + total_params = sum(int(p.numel()) for p in params) + updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: + g = g.add(buf, alpha=momentum) + if group.get("row_normalize", False): + row_norms = g.float().norm(dim=-1, keepdim=True).clamp_min(1e-07) + g = g / row_norms.to(g.dtype) + g = zeropower_via_newtonschulz5(g, steps=backend_steps) + g *= max(1, g.size(0) / g.size(1)) ** 0.5 + updates_flat[curr : curr + p.numel()] = g.reshape(-1) + curr += p.numel() + if distributed: + dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) + wd = group.get("weight_decay", 0.0) + curr = 0 + for p in params: + if wd > 0.0: + p.data.mul_(1.0 - lr * wd) + g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) + p.add_(g, alpha=-lr) + curr += p.numel() + return loss + + +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights,skip_gates", + ).split(",") + if pattern +) + + +class Optimizers(): + def __init__(self, h: Hyperparameters, base_model: GPT): + block_named_params = list(base_model.blocks.named_parameters()) + matrix_params = [ + p + for name, p in block_named_params + if p.ndim == 2 and not any(pattern in name for pattern in + CONTROL_TENSOR_NAME_PATTERNS) + ] + scalar_params = [ + p + for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in + CONTROL_TENSOR_NAME_PATTERNS) + ] + # Add override MLP params (untied loop MLPs) to Muon + if base_model.override_mlps is not None: + for name, p in base_model.override_mlps.named_parameters(): + if p.ndim == 2: + matrix_params.append(p) + else: + scalar_params.append(p) + if base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) + if base_model.skip_gates is not None and base_model.skip_gates.numel() > 0: + scalar_params.append(base_model.skip_gates) + + token_lr = h.tied_embed_lr if h.tie_embeddings else h.embed_lr + tok_params = [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}] + self.optimizer_tok = torch.optim.AdamW( + tok_params, + betas=(h.beta1, h.beta2), + eps=h.adam_eps, + weight_decay=h.embed_wd, + fused=True, + ) + self.optimizer_muon = Muon( + matrix_params, + lr=h.matrix_lr, + momentum=h.muon_momentum, + backend_steps=h.muon_backend_steps, + weight_decay=h.muon_wd, + row_normalize=h.muon_row_normalize, + ) + for group in self.optimizer_muon.param_groups: + group["base_lr"] = h.matrix_lr + self.optimizer_scalar = torch.optim.AdamW( + [{"params": scalar_params, "lr": h.scalar_lr, "base_lr": h.scalar_lr}], + betas=(h.beta1, h.beta2), + eps=h.adam_eps, + weight_decay=h.adam_wd, + fused=True, + ) + self.optimizers = [self.optimizer_tok, self.optimizer_muon, self.optimizer_scalar] + if base_model.lm_head is not None: + self.optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": h.head_lr, "base_lr": h.head_lr}], + betas=(h.beta1, h.beta2), + eps=h.adam_eps, + fused=True, + ) + self.optimizers.insert(1, self.optimizer_head) + else: + self.optimizer_head = None + + def __iter__(self): + return iter(self.optimizers) + + def zero_grad_all(self) -> None: + for opt in self.optimizers: + opt.zero_grad(set_to_none=True) + + def step(self): + for opt in self.optimizers: + opt.step() + self.zero_grad_all() + +# ---------------------------------------- +# Quantization +# ---------------------------------------- + +def restore_fp32_params(model: nn.Module) -> None: + for module in model.modules(): + if isinstance(module, CastedLinear): + module.float() + for name, param in model.named_parameters(): + if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() + + +def collect_hessians( + model: nn.Module, + train_loader: ShuffledSequenceLoader, + h: Hyperparameters, + device: torch.device, + n_calibration_batches: int = 64, +) -> dict[str, Tensor]: + hessians: dict[str, Tensor] = {} + hooks = [] + + def make_hook(name: str): + def hook_fn(module, inp, out): + x = inp[0].detach().float() + if x.ndim == 3: + x = x.reshape(-1, x.shape[-1]) + if name not in hessians: + hessians[name] = torch.zeros( + x.shape[1], x.shape[1], dtype=torch.float32, device=device + ) + hessians[name].addmm_(x.T, x) + return hook_fn + + for name, module in model.named_modules(): + if isinstance(module, CastedLinear) and module.weight.numel() > 65536: + cat = classify_param(name + ".weight") + if cat in ("mlp", "attn"): + hooks.append(module.register_forward_hook(make_hook(name + ".weight"))) + + if model.tie_embeddings: + hook_module = model.head_proj if model.head_proj is not None else model.final_norm + def make_output_hook(name: str): + def hook_fn(module, inp, out): + x = out.detach().float() + if x.ndim == 3: + x = x.reshape(-1, x.shape[-1]) + if name not in hessians: + hessians[name] = torch.zeros( + x.shape[1], x.shape[1], dtype=torch.float32, device=device + ) + hessians[name].addmm_(x.T, x) + return hook_fn + hooks.append(hook_module.register_forward_hook(make_output_hook("tok_emb.weight"))) + + model.eval() + with torch.no_grad(): + for _ in range(n_calibration_batches): + x, _ = train_loader.next_batch(h.train_batch_tokens, h.grad_accum_steps) + model.forward_logits(x) + + for hook in hooks: + hook.remove() + + for name in hessians: + hessians[name] = hessians[name].cpu() / n_calibration_batches + + return hessians + + +def gptq_quantize_weight( + w: Tensor, + H: Tensor, + clip_sigmas: float = 3.0, + clip_range: int = 63, + block_size: int = 128, + hessian_clip_lambda: float = 0.0, +) -> tuple[Tensor, Tensor]: + W_orig = w.float().clone() + rows, cols = W_orig.shape + H = H.float().clone() + + dead = torch.diag(H) == 0 + H[dead, dead] = 1 + damp = 0.01 * H.diag().mean() + H.diagonal().add_(damp) + + perm = torch.argsort(H.diag(), descending=True) + invperm = torch.argsort(perm) + W_perm = W_orig[:, perm].clone() + W_perm[:, dead[perm]] = 0 + H = H[perm][:, perm] + + Hinv = torch.cholesky_inverse(torch.linalg.cholesky(H)) + Hinv = torch.linalg.cholesky(Hinv, upper=True) + + # SDClip: c = clip_sigmas * std(row) + row_std = W_orig.std(dim=1) + + # Hessian-aware SDClip: per-row modulation using GPTQ Hessian diagonal + if hessian_clip_lambda > 0.0: + diagH = torch.diag(H[invperm][:, invperm]).clamp_min(1e-8) + col_importance = diagH / diagH.mean() + row_importance = (W_orig.abs() * col_importance.unsqueeze(0)).mean(dim=1) + row_importance = row_importance / row_importance.mean() + adj = 1.0 + hessian_clip_lambda * (row_importance - 1.0) + s = (clip_sigmas * row_std * adj / clip_range).clamp_min(1e-10).to(torch.float16) + else: + s = (clip_sigmas * row_std / clip_range).clamp_min(1e-10).to(torch.float16) + + sf = s.float() + + Q = torch.zeros(rows, cols, dtype=torch.int8) + W_work = W_perm.clone() + for i1 in range(0, cols, block_size): + i2 = min(i1 + block_size, cols) + W_block = W_work[:, i1:i2].clone() + Hinv_block = Hinv[i1:i2, i1:i2] + Err = torch.zeros(rows, i2 - i1) + for j in range(i2 - i1): + w_col = W_block[:, j] + d = Hinv_block[j, j] + q_col = torch.clamp(torch.round(w_col / sf), -clip_range, clip_range) + Q[:, i1 + j] = q_col.to(torch.int8) + err = (w_col - q_col.float() * sf) / d + Err[:, j] = err + W_block[:, j:] -= err.unsqueeze(1) * Hinv_block[j, j:].unsqueeze(0) + if i2 < cols: + W_work[:, i2:] -= Err @ Hinv[i1:i2, i2:] + + return Q[:, invperm], s + + +def _is_loop_layer(name: str, h: Hyperparameters) -> bool: + """Check if a param name belongs to a looped layer.""" + if "override_mlps" in name: + return True # override MLPs are always loop layers + m = re.search(r'blocks\.(\d+)\.', name) + if m: + idx = int(m.group(1)) + return h.loop_start <= idx <= h.loop_end + return False + + +def _get_group_clip_mult(name: str, h: Hyperparameters) -> float: + """Get per-group clip multiplier for a weight matrix.""" + m = re.search(r'blocks\.(\d+)\.', name) + if not m: + return 1.0 + idx = int(m.group(1)) + if idx <= 2: + return h.clip_mult_early + if h.loop_start <= idx <= h.loop_end: + return h.clip_mult_loop + if idx in (3, 6, 7): + return h.clip_mult_mid + if idx >= 8: + return h.clip_mult_late + return 1.0 + + +def gptq_mixed_quantize( + state_dict: dict[str, Tensor], + hessians: dict[str, Tensor], + h: Hyperparameters, +) -> tuple[dict[str, Tensor], dict[str, object]]: + result: dict[str, Tensor] = {} + meta: dict[str, object] = {} + + for name, tensor in state_dict.items(): + t = tensor.detach().cpu().contiguous() + if not t.is_floating_point() or t.numel() <= 65536: + result[name] = t.to(torch.float16) if t.is_floating_point() else t + meta[name] = "passthrough (float16)" + continue + + # Determine bits and clip sigmas + if "tok_emb" in name: + cs = h.embed_clip_sigmas + bits = h.embed_bits + elif _is_loop_layer(name, h) and h.loop_layer_bits > 0: + cs = h.loop_layer_clip_sigmas if h.loop_layer_clip_sigmas > 0 else h.matrix_clip_sigmas + bits = h.loop_layer_bits + log(f" loop_layer_quant: {name} -> int{bits} clip={cs:.2f}") + else: + cs = h.matrix_clip_sigmas + bits = h.matrix_bits + + # Per-group clip multiplier (from 3-seed Hessian analysis) + group_mult = _get_group_clip_mult(name, h) + if group_mult != 1.0: + cs = cs * group_mult + + q, s = gptq_quantize_weight( + t, hessians[name], clip_sigmas=cs, clip_range=2**(bits - 1) - 1, + hessian_clip_lambda=h.hessian_clip_lambda) + result[name + ".q"] = q + result[name + ".scale"] = s + meta[name] = f"gptq (int{bits})" + + categories = collections.defaultdict(set) + for name, cat in meta.items(): + short = re.sub(r'\.\d+$', '', re.sub(r'blocks\.\d+', 'blocks', name)) + categories[cat].add(short) + log("Quantized weights:") + for cat in sorted(categories): + log(f" {cat}: {', '.join(sorted(categories[cat]))}") + + return result, meta + + +def dequantize_mixed(result: dict[str, Tensor], meta: dict[str, object], + template_sd: dict[str, Tensor]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + for name, orig in template_sd.items(): + info = meta.get(name) + if info is None: + continue + orig_dtype = orig.dtype + if "passthrough" in info: + t = result[name] + if t.dtype == torch.float16 and orig_dtype in (torch.float32, torch.bfloat16): + t = t.to(orig_dtype) + out[name] = t + continue + q, s = result[name + ".q"], result[name + ".scale"] + if s.ndim > 0: + out[name] = (q.float() * s.float().view(q.shape[0], *([1] * (q.ndim - 1)))).to(orig_dtype) + else: + out[name] = (q.float() * float(s.item())).to(orig_dtype) + return out + + +_BSHF_MAGIC = b"BSHF" + + +def _byte_shuffle(data: bytes, stride: int = 2) -> bytes: + if stride <= 1 or len(data) < stride: + return data + src = np.frombuffer(data, dtype=np.uint8) + n = len(src) + out = np.empty(n, dtype=np.uint8) + dest_off = 0 + for pos in range(stride): + chunk = src[pos::stride] + out[dest_off:dest_off + len(chunk)] = chunk + dest_off += len(chunk) + return _BSHF_MAGIC + bytes([stride]) + out.tobytes() + + +def _byte_unshuffle(data: bytes) -> bytes: + if len(data) < 5 or data[:4] != _BSHF_MAGIC: + return data + stride = data[4] + if stride < 2: + return data[5:] + payload = np.frombuffer(data, dtype=np.uint8, offset=5) + n = len(payload) + out = np.empty(n, dtype=np.uint8) + src_off = 0 + for pos in range(stride): + chunk_len = n // stride + (1 if pos < n % stride else 0) + out[pos::stride][:chunk_len] = payload[src_off:src_off + chunk_len] + src_off += chunk_len + return out.tobytes() + + +def _compress(data: bytes, compressor: str) -> bytes: + data = _byte_shuffle(data) + if compressor == "lzma": + return lzma.compress(data, preset=6) + elif compressor == "brotli": + import brotli + return brotli.compress(data, quality=11) + raise ValueError(f"Unknown compressor: {compressor!r}") + + +def _decompress(data: bytes, compressor: str) -> bytes: + if compressor == "lzma": + raw = lzma.decompress(data) + elif compressor == "brotli": + import brotli + raw = brotli.decompress(data) + else: + raise ValueError(f"Unknown compressor: {compressor!r}") + raw = _byte_unshuffle(raw) + return raw + + +def serialize(h: Hyperparameters, base_model: torch.nn.Module, code: str) -> tuple[int, int]: + code_bytes = len(code.encode("utf-8")) + if h.is_main_process: + torch.save(base_model.state_dict(), h.model_path) + model_bytes = os.path.getsize(h.model_path) + log(f"Serialized model: {model_bytes} bytes") + log(f"Code size: {code_bytes} bytes") + + sd_cpu = {k: v.detach().cpu() for k, v in base_model.state_dict().items()} + device = torch.device("cuda", h.local_rank) + log("GPTQ:collecting Hessians from calibration data...") + t0 = time.perf_counter() + calib_loader = ShuffledSequenceLoader(h, device) + hessians = collect_hessians( + base_model, calib_loader, h, device, + n_calibration_batches=h.gptq_calibration_batches, + ) + log(f"GPTQ:collected {len(hessians)} Hessians in {time.perf_counter() - t0:.1f}s") + + # Dump Hessian diagnostics for offline optimal allocation analysis + if h.is_main_process and h.hessian_clip_lambda > 0: + diag_path = os.path.join(os.path.dirname(h.model_path), "hessian_diagnostics.pt") + diagnostics = {} + for name, H_mat in hessians.items(): + H_f = H_mat.float() + diagH = H_f.diag().clamp_min(1e-8) + col_imp = diagH / diagH.mean() + w = sd_cpu[name].float() + row_imp = (w.abs() * col_imp.unsqueeze(0)).mean(dim=1) + row_imp = row_imp / row_imp.mean() + row_std = w.std(dim=1) + diagnostics[name] = { + "hessian_diag": diagH.cpu(), + "col_importance": col_imp.cpu(), + "row_importance": row_imp.cpu(), + "row_std": row_std.cpu(), + "hessian_trace": H_f.trace().item(), + "shape": list(w.shape), + } + torch.save(diagnostics, diag_path) + log(f"GPTQ:saved Hessian diagnostics to {diag_path} ({len(diagnostics)} matrices)") + + quant_result, quant_meta = gptq_mixed_quantize(sd_cpu, hessians, h) + + quant_buf = io.BytesIO() + torch.save({"w": quant_result, "m": quant_meta}, quant_buf) + quant_raw = quant_buf.getvalue() + quant_blob = _compress(quant_raw, h.compressor) + quant_file_bytes = len(quant_blob) + bytes_total = quant_file_bytes + code_bytes + if h.is_main_process: + with open(h.quantized_model_path, "wb") as f: + f.write(quant_blob) + log(f"Serialized model quantized+{h.compressor}: {quant_file_bytes} bytes") + log(f"Total submission size quantized+{h.compressor}: {bytes_total} bytes") + return bytes_total, quant_file_bytes + + +def deserialize(h: Hyperparameters, device: torch.device) -> GPT: + eval_model = GPT(h).to(device).bfloat16() + restore_fp32_params(eval_model) + sd_cpu = {k: v.detach().cpu() for k, v in eval_model.state_dict().items()} + + with open(h.quantized_model_path, "rb") as f: + quant_blob_disk = f.read() + quant_state = torch.load( + io.BytesIO(_decompress(quant_blob_disk, h.compressor)), + map_location="cpu", + ) + deq_state = dequantize_mixed(quant_state["w"], quant_state["m"], sd_cpu) + eval_model.load_state_dict(deq_state, strict=True) + + return eval_model + +# ---------------------------------------- +# Evaluation +# ---------------------------------------- + +def _loss_bpb(loss_sum, token_count, byte_count) -> tuple[float, float]: + val_loss = (loss_sum / token_count).item() + val_bpb = val_loss / math.log(2.0) * (token_count.item() / byte_count.item()) + return val_loss, val_bpb + + +def eval_val( + h: Hyperparameters, + device: torch.device, + val_data: ValidationData, + model: nn.Module +) -> tuple[float, float]: + seq_len = h.eval_seq_len + local_batch_tokens = h.val_batch_tokens // (h.world_size * h.grad_accum_steps) + if local_batch_tokens < seq_len: + raise ValueError( + "VAL_BATCH_SIZE must provide at least one sequence per rank; " + f"got VAL_BATCH_SIZE={h.val_batch_tokens}, WORLD_SIZE={h.world_size}, " + f"GRAD_ACCUM_STEPS={h.grad_accum_steps}, seq_len={seq_len}" + ) + local_batch_seqs = local_batch_tokens // seq_len + total_seqs = (val_data.val_tokens.numel() - 1) // seq_len + seq_start = (total_seqs * h.rank) // h.world_size + seq_end = (total_seqs * (h.rank + 1)) // h.world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + + model.eval() + with torch.inference_mode(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * seq_len + raw_end = batch_seq_end * seq_len + 1 + local = val_data.val_tokens[raw_start:raw_end].to( + device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + batch_loss = model(x, y).detach() + batch_token_count = float(y.numel()) + val_loss_sum += batch_loss.to(torch.float64) * batch_token_count + val_token_count += batch_token_count + prev_ids = x.reshape(-1) + tgt_ids = y.reshape(-1) + token_bytes = val_data.base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += (val_data.has_leading_space_lut[tgt_ids] & + ~val_data.is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + + model.train() + return _loss_bpb(val_loss_sum, val_token_count, val_byte_count) + + +def eval_val_sliding( + h: Hyperparameters, + device: torch.device, + val_data: ValidationData, + base_model: nn.Module, + batch_seqs: int = 32 +) -> tuple[float, float]: + base_model.eval() + logits_fn = torch.compile(base_model.forward_logits, dynamic=False, fullgraph=True) + + seq_len = h.eval_seq_len + context_size = seq_len - h.eval_stride + total_tokens = val_data.val_tokens.numel() - 1 + + window_starts = [ws for ws in range(0, total_tokens, h.eval_stride) + if ws + context_size < total_tokens] + + total_windows = len(window_starts) + my_s = (total_windows * h.rank) // h.world_size + my_e = (total_windows * (h.rank + 1)) // h.world_size + my_windows = window_starts[my_s:my_e] + + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + + with torch.inference_mode(): + for bi in range(0, len(my_windows), batch_seqs): + batch_ws = my_windows[bi:bi + batch_seqs] + bsz = len(batch_ws) + + x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + wlens: list[int] = [] + + for i, ws in enumerate(batch_ws): + we = min(ws + seq_len, total_tokens) + wlen = we - ws + wlens.append(wlen) + chunk = val_data.val_tokens[ws:we + 1].to(dtype=torch.int64, device=device) + x_batch[i, :wlen] = chunk[:-1] + y_batch[i, :wlen] = chunk[1:] + + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + logits = logits_fn(x_batch) + + nll = F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), + y_batch.reshape(-1), + reduction="none", + ).reshape(bsz, seq_len) + + for i, ws in enumerate(batch_ws): + wlen = wlens[i] + s = 0 if ws == 0 else context_size + scored_nll = nll[i, s:wlen].to(torch.float64) + loss_sum += scored_nll.sum() + token_count += float(wlen - s) + tgt = y_batch[i, s:wlen] + prev = x_batch[i, s:wlen] + tb = val_data.base_bytes_lut[tgt].to(torch.float64) + tb += (val_data.has_leading_space_lut[tgt] & + ~val_data.is_boundary_token_lut[prev]).to(torch.float64) + byte_count += tb.sum() + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) + + base_model.train() + return _loss_bpb(loss_sum, token_count, byte_count) + + +def eval_val_ttt( + h: Hyperparameters, + device: torch.device, + val_data: ValidationData, + base_model: nn.Module, + batch_seqs: int = 32, +) -> tuple[float, float]: + """Score-first discriminative TTT (legal per Issue #1082). + + Protocol per chunk: + Phase 1: SCORE chunk in inference_mode (immutable BPB contribution) + Phase 2: TRAIN on scored chunk so future chunks benefit + """ + if not h.ttt_enabled: + return eval_val_sliding(h, device, val_data, base_model, batch_seqs) + + seq_len = h.eval_seq_len + stride = h.eval_stride + context_size = seq_len - stride + total_tokens = val_data.val_tokens.numel() - 1 + chunk_size = h.ttt_chunk_tokens + + # Build sliding windows (same as eval_val_sliding) + window_starts = [ws for ws in range(0, total_tokens, stride) + if ws + context_size < total_tokens] + total_windows = len(window_starts) + + # Chunk windows for TTT + chunk_boundaries = list(range(0, total_windows, max(1, chunk_size // stride))) + if chunk_boundaries[-1] != total_windows: + chunk_boundaries.append(total_windows) + chunk_windows = [window_starts[chunk_boundaries[c]:chunk_boundaries[c + 1]] + for c in range(len(chunk_boundaries) - 1)] + + log(f"ttt: {len(chunk_windows)} chunks, {total_windows} windows, " + f"freeze_blocks={h.ttt_freeze_blocks}, lr={h.ttt_lr}, epochs={h.ttt_epochs}") + + # Freeze early blocks, unfreeze rest + for p in base_model.parameters(): + p.requires_grad = False + ttt_params = [] + for i, block in enumerate(base_model.blocks): + if i >= h.ttt_freeze_blocks: + for p in block.parameters(): + p.requires_grad = True + ttt_params.append(p) + for p in base_model.final_norm.parameters(): + p.requires_grad = True + ttt_params.append(p) + + saved_state = {id(p): p.data.clone() for p in ttt_params} + + # Per-layer LR groups: projections adapt faster, FCs slower + proj_params, fc_params, other_params = [], [], [] + for name, p in base_model.named_parameters(): + if not p.requires_grad: + continue + if ".proj." in name or name.endswith(".proj.weight"): + proj_params.append(p) + elif ".fc." in name or name.endswith(".fc.weight"): + fc_params.append(p) + else: + other_params.append(p) + + ttt_lr = h.ttt_lr + param_groups = [pg for pg in [ + {"params": proj_params, "lr": ttt_lr * 3.0, "_lr_mult": 3.0} if proj_params else None, + {"params": fc_params, "lr": ttt_lr * 0.5, "_lr_mult": 0.5} if fc_params else None, + {"params": other_params, "lr": ttt_lr, "_lr_mult": 1.0} if other_params else None, + ] if pg is not None] + optimizer = torch.optim.SGD(param_groups, lr=ttt_lr, momentum=0.9) + + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + + for ci, windows in enumerate(chunk_windows): + if not windows: + continue + + my_s = (len(windows) * h.rank) // h.world_size + my_e = (len(windows) * (h.rank + 1)) // h.world_size + my_windows = windows[my_s:my_e] + + # === Phase 1: SCORE chunk (inference-only, contributes to official BPB) === + base_model.eval() + chunk_entropy = 0.0 + chunk_count = 0 + with torch.inference_mode(): + for bi in range(0, len(my_windows), batch_seqs): + batch_ws = my_windows[bi:bi + batch_seqs] + bsz = len(batch_ws) + x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + wlens = [] + for i, ws in enumerate(batch_ws): + we = min(ws + seq_len, total_tokens) + wlen = we - ws + wlens.append(wlen) + chunk = val_data.val_tokens[ws:we + 1].to(dtype=torch.int64, device=device) + x_batch[i, :wlen] = chunk[:-1] + y_batch[i, :wlen] = chunk[1:] + + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + logits = base_model.forward_logits(x_batch).float() + + # Track NLL for adaptive epochs + for i in range(bsz): + wl = wlens[i] + if wl > 0: + nll = F.cross_entropy(logits[i, :wl], y_batch[i, :wl], reduction='sum') + chunk_entropy += nll.item() + chunk_count += wl + + # Score (same sliding window logic as eval_val_sliding) + for i, ws in enumerate(batch_ws): + wlen = wlens[i] + s = 0 if ws == 0 else context_size + scored_nll = F.cross_entropy( + logits[i, s:wlen], y_batch[i, s:wlen], reduction='none' + ).to(torch.float64) + loss_sum += scored_nll.sum() + token_count += float(wlen - s) + tgt = y_batch[i, s:wlen] + prev = x_batch[i, s:wlen] + tb = val_data.base_bytes_lut[tgt].to(torch.float64) + tb += (val_data.has_leading_space_lut[tgt] & + ~val_data.is_boundary_token_lut[prev]).to(torch.float64) + byte_count += tb.sum() + + # === Phase 2: TRAIN on scored chunk (entropy-adaptive epochs) === + if chunk_count > 0: + chunk_nll_avg = chunk_entropy / chunk_count + if chunk_nll_avg > h.ttt_entropy_high: + epochs = h.ttt_epochs + 1 + elif chunk_nll_avg < h.ttt_entropy_low: + epochs = max(h.ttt_epochs - 1, 1) + else: + epochs = h.ttt_epochs + else: + epochs = h.ttt_epochs + + base_model.train() + for _epoch in range(epochs): + for bi in range(0, len(my_windows), batch_seqs): + batch_ws = my_windows[bi:bi + batch_seqs] + bsz = len(batch_ws) + x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + for i, ws in enumerate(batch_ws): + we = min(ws + seq_len, total_tokens) + wlen = we - ws + chunk = val_data.val_tokens[ws:we + 1].to(dtype=torch.int64, device=device) + x_batch[i, :wlen] = chunk[:-1] + y_batch[i, :wlen] = chunk[1:] + + optimizer.zero_grad() + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + ttt_logits = base_model.forward_logits(x_batch) + ttt_loss = F.cross_entropy( + ttt_logits.reshape(-1, ttt_logits.size(-1)), + y_batch.reshape(-1), + ) + ttt_loss.backward() + + # NS-orthogonalized gradients for 2D params (Muon-style) + if h.ttt_ns_steps > 0: + with torch.no_grad(): + for pg in param_groups: + for p in pg["params"]: + if p.grad is None: + continue + g = p.grad.detach().float() + if g.ndim == 2: + g = zeropower_via_newtonschulz5(g, steps=h.ttt_ns_steps) + g *= max(1, g.size(0) / g.size(1)) ** 0.5 + p.data.add_(g.to(p.dtype), alpha=-ttt_lr * pg["_lr_mult"]) + p.grad = None + optimizer.step() + + if ci % 10 == 0 or ci == len(chunk_windows) - 1: + log(f"ttt_chunk: {ci + 1}/{len(chunk_windows)} epochs={epochs}") + + # Restore original weights (TTT adapted weights don't persist) + with torch.no_grad(): + for p in ttt_params: + p.data.copy_(saved_state[id(p)]) + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) + + for p in base_model.parameters(): + p.requires_grad = True + base_model.train() + return _loss_bpb(loss_sum, token_count, byte_count) + + +def timed_eval(label: str, fn, *args, **kwargs) -> tuple[float, float]: + torch.cuda.synchronize() + t0 = time.perf_counter() + val_loss, val_bpb = fn(*args, **kwargs) + torch.cuda.synchronize() + elapsed_ms = 1000.0 * (time.perf_counter() - t0) + log(f"{label} val_loss:{val_loss:.8f} val_bpb:{val_bpb:.8f} eval_time:{elapsed_ms:.0f}ms") + return val_loss, val_bpb + + +# ----------------------------- +# Training +# ----------------------------- + +def train_model(h: Hyperparameters, device: torch.device, val_data: ValidationData): + base_model = GPT(h).to(device).bfloat16() + restore_fp32_params(base_model) + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) + if h.distributed: + model = DDP(compiled_model, device_ids=[h.local_rank], broadcast_buffers=False) + else: + model = compiled_model + log(f"model_params:{sum(p.numel() for p in base_model.parameters())}") + + # Log new features + if h.hessian_clip_lambda > 0: + log(f"hessian_clip: lambda={h.hessian_clip_lambda}") + if h.untie_loop_mlps: + log(f"untied_loop_mlps: ON (override MLPs for layers {h.loop_start}-{h.loop_end})") + if h.parallel_residual_start >= 0: + log(f"parallel_residuals: ON (layers {h.parallel_residual_start}-{h.num_layers-1})") + if h.loop_layer_bits > 0: + log(f"loop_layer_quant: bits={h.loop_layer_bits} clip={h.loop_layer_clip_sigmas}") + if h.loop_phase2_at > 0: + log(f"progressive_recurrence: phase1={h.enable_looping_at} phase2={h.loop_phase2_at}") + has_group_mults = any(getattr(h, f'clip_mult_{g}') != 1.0 for g in ('early', 'loop', 'mid', 'late')) + if has_group_mults: + log(f"per_group_clip: early={h.clip_mult_early} loop={h.clip_mult_loop} mid={h.clip_mult_mid} late={h.clip_mult_late}") + + optimizers = Optimizers(h, base_model) + train_loader = ShuffledSequenceLoader(h, device) + + max_wallclock_ms = 1000.0 * h.max_wallclock_seconds if h.max_wallclock_seconds > 0 else None + if max_wallclock_ms is not None: + max_wallclock_ms -= h.gptq_reserve_seconds * 1000.0 + log(f"gptq:reserving {h.gptq_reserve_seconds:.0f}s, effective={max_wallclock_ms:.0f}ms") + + def training_frac(step: int, elapsed_ms: float) -> float: + if max_wallclock_ms is None: + return step / max(h.iterations, 1) + return elapsed_ms / max(max_wallclock_ms, 1e-9) + + def lr_mul(frac: float) -> float: + if h.warmdown_frac <= 0: + return 1.0 + if frac >= 1.0 - h.warmdown_frac: + return max((1.0 - frac) / h.warmdown_frac, h.min_lr) + return 1.0 + + def step_fn(step, lr_scale): + optimizers.zero_grad_all() + train_loss = torch.zeros((), device=device) + for micro_step in range(h.grad_accum_steps): + if h.distributed: + model.require_backward_grad_sync = micro_step == h.grad_accum_steps - 1 + x, y = train_loader.next_batch(h.train_batch_tokens, h.grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + loss = model(x, y) + train_loss += loss.detach() + (loss / h.grad_accum_steps).backward() + train_loss /= h.grad_accum_steps + + frac = min(step / h.muon_momentum_warmup_steps, 1.0) if h.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * h.muon_momentum_warmup_start + frac * h.muon_momentum + for group in optimizers.optimizer_muon.param_groups: + group["momentum"] = muon_momentum + + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * lr_scale + + if h.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), h.grad_clip_norm) + + optimizers.step() + return train_loss + + # Model warmup — also warmup looped path + if h.warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() + for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(h.warmup_steps): + step_fn(warmup_step, 1.0) + if warmup_step <= 5 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == h.warmup_steps: + log(f"warmup_step: {warmup_step + 1}/{h.warmup_steps}") + if h.num_loops > 0: + if h.loop_phase2_at > 0: + # Progressive: warmup both phase1 and phase2 + base_model.activate_looping(1) + log(f"loop_warmup_phase1: encoder:{base_model.encoder_indices} decoder:{base_model.decoder_indices}") + for warmup_step in range(h.warmup_steps): + step_fn(warmup_step, 1.0) + if warmup_step <= 5 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == h.warmup_steps: + log(f"loop_warmup_p1_step: {warmup_step + 1}/{h.warmup_steps}") + base_model.activate_looping(2) + log(f"loop_warmup_phase2: encoder:{base_model.encoder_indices} decoder:{base_model.decoder_indices}") + for warmup_step in range(h.warmup_steps): + step_fn(warmup_step, 1.0) + if warmup_step <= 5 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == h.warmup_steps: + log(f"loop_warmup_p2_step: {warmup_step + 1}/{h.warmup_steps}") + else: + # Original Clark behavior: warmup full loop config + base_model.activate_looping(2) + log(f"loop_warmup: encoder:{base_model.encoder_indices} decoder:{base_model.decoder_indices}") + for warmup_step in range(h.warmup_steps): + step_fn(warmup_step, 1.0) + if warmup_step <= 5 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == h.warmup_steps: + log(f"loop_warmup_step: {warmup_step + 1}/{h.warmup_steps}") + base_model.activate_looping(0) + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + optimizers.zero_grad_all() + if h.distributed: + model.require_backward_grad_sync = True + train_loader = ShuffledSequenceLoader(h, device) + + # Training loop + ema_state = {name: t.detach().float().clone() for name, t in base_model.state_dict().items()} + ema_decay = h.ema_decay + + training_time_ms = 0.0 + stop_after_step: int | None = None + torch.cuda.synchronize() + t0 = time.perf_counter() + + step = 0 + while True: + last_step = step == h.iterations or (stop_after_step is not None and step >= stop_after_step) + + should_validate = last_step or (h.val_loss_every > 0 and step % h.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val(h, device, val_data, model) + log(f"{step}/{h.iterations} val_loss: {val_loss:.4f} val_bpb: {val_bpb:.4f}") + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if stop_after_step is not None and step < h.iterations: + log( + f"stopping_early: wallclock_cap train_time: {training_time_ms:.0f}ms " + f"step: {step}/{h.iterations}" + ) + break + + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + frac = training_frac(step, elapsed_ms) + scale = lr_mul(frac) + + # Progressive recurrence curriculum + if h.num_loops > 0: + if h.loop_phase2_at > 0: + if not base_model.looping_active and frac >= h.enable_looping_at: + base_model.activate_looping(1) + log(f"layer_loop:phase1 step:{step} frac:{frac:.3f}") + elif base_model.looping_active and base_model.encoder_indices is base_model.phase1_enc and frac >= h.loop_phase2_at: + base_model.activate_looping(2) + log(f"layer_loop:phase2 step:{step} frac:{frac:.3f}") + else: + if not base_model.looping_active and frac >= h.enable_looping_at: + base_model.activate_looping(2) + log(f"layer_loop:enabled step:{step} frac:{frac:.3f}") + + train_loss = step_fn(step, scale) + + with torch.no_grad(): + for name, t in base_model.state_dict().items(): + ema_state[name].mul_(ema_decay).add_(t.detach().float(), alpha=1.0 - ema_decay) + + step += 1 + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + + should_log_train = ( + h.train_log_every > 0 + and (step <= 5 or step % h.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + tok_per_sec = step * h.train_batch_tokens / (approx_training_time_ms / 1000.0) + log( + f"{step}/{h.iterations} train_loss: {train_loss.item():.4f} " + f"train_time: {approx_training_time_ms / 60000:.1f}m tok/s: {tok_per_sec:.0f} " + f"looping:{base_model.looping_active}" + ) + + reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if h.distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + + log( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + ) + + # Weight averaging + log("ema:applying EMA weights") + current_state = base_model.state_dict() + avg_state = {name: t.to(dtype=current_state[name].dtype) for name, t in ema_state.items()} + base_model.load_state_dict(avg_state, strict=True) + + return base_model, compiled_model + + +def train_and_eval(h: Hyperparameters, device: torch.device) -> None: + random.seed(h.seed) + np.random.seed(h.seed) + torch.manual_seed(h.seed) + torch.cuda.manual_seed_all(h.seed) + + val_data = ValidationData(h, device) + log(f"train_shards: {len(list(Path(h.datasets_dir).resolve().glob('fineweb_train_*.bin')))}") + log(f"val_tokens: {val_data.val_tokens.numel() - 1}") + + base_model, compiled_model = train_model(h, device, val_data) + torch._dynamo.reset() + timed_eval("pre-quantization post-ema", eval_val, h, device, val_data, compiled_model) + + serialize(h, base_model, Path(__file__).read_text(encoding="utf-8")) + if h.distributed: + dist.barrier() + eval_model = deserialize(h, device) + if h.num_loops > 0: + eval_model.activate_looping(2) # Full looping for eval + + compiled_model = torch.compile(eval_model, dynamic=False, fullgraph=True) + timed_eval("quantized", eval_val, h, device, val_data, compiled_model) + if h.sliding_window_enabled: + timed_eval("quantized_sliding_window", eval_val_sliding, h, device, val_data, eval_model) + if h.ttt_enabled: + timed_eval("quantized_ttt", eval_val_ttt, h, device, val_data, eval_model) + + +def main(): + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if 8 % world_size != 0: + raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") + + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + torch.set_float32_matmul_precision("high") + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + torch._dynamo.config.optimize_ddp = False + + h = Hyperparameters() + set_logging_hparams(h) + if h.is_main_process: + os.makedirs("logs", exist_ok=True) + log(100 * "=", console=False) + log("Hyperparameters:", console=True) + for k, v in sorted(vars(type(h)).items()): + if not k.startswith("_"): + log(f" {k}: {v}", console=True) + log("=" * 100, console=False) + log(f"Running Python {sys.version}", console=False) + log(f"Running PyTorch {torch.__version__}", console=False) + log( + subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, + text=True, check=False).stdout, + console=False, + ) + log("=" * 100, console=False) + + train_and_eval(h, device) + + if distributed: + dist.destroy_process_group() + + +if __name__ == "__main__": + main() diff --git a/records/track_10min_16mb/2026-04-06_SP8192_HessianSDClip_ProgressiveRecurrence/train_seed1337.log b/records/track_10min_16mb/2026-04-06_SP8192_HessianSDClip_ProgressiveRecurrence/train_seed1337.log new file mode 100644 index 0000000000..14fce6e5cc --- /dev/null +++ b/records/track_10min_16mb/2026-04-06_SP8192_HessianSDClip_ProgressiveRecurrence/train_seed1337.log @@ -0,0 +1,165 @@ +W0406 05:48:52.138000 4182076 torch/distributed/run.py:803] +W0406 05:48:52.138000 4182076 torch/distributed/run.py:803] ***************************************** +W0406 05:48:52.138000 4182076 torch/distributed/run.py:803] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. +W0406 05:48:52.138000 4182076 torch/distributed/run.py:803] ***************************************** +Hyperparameters: + adam_eps: 1e-08 + adam_wd: 0.02 + beta1: 0.9 + beta2: 0.95 + compressor: brotli + data_dir: ./data/ + datasets_dir: ./data/datasets/fineweb10B_sp8192 + distributed: True + ema_decay: 0.997 + embed_bits: 8 + embed_clip_sigmas: 20.0 + embed_lr: 0.6 + embed_wd: 0.085 + embedding_dim: 512 + enable_looping_at: 0.5 + eval_seq_len: 2048 + eval_stride: 64 + gptq_calibration_batches: 64 + gptq_reserve_seconds: 12.0 + grad_accum_steps: 1 + grad_clip_norm: 0.3 + head_lr: 0.008 + hessian_clip_lambda: 0.175 + is_main_process: True + iterations: 20000 + ln_scale: True + local_rank: 0 + logfile: logs/f3971278-d577-499b-8fde-755434809ba9.txt + logit_softcap: 30.0 + loop_end: 5 + loop_layer_bits: 0 + loop_layer_clip_sigmas: 0.0 + loop_phase2_at: 0.65 + loop_start: 4 + matrix_bits: 6 + matrix_clip_sigmas: 12.85 + matrix_lr: 0.02 + max_wallclock_seconds: 600.0 + min_lr: 0.0 + mlp_mult: 4.0 + model_dim: 512 + model_path: final_model.pt + muon_backend_steps: 5 + muon_beta2: 0.95 + muon_momentum: 0.99 + muon_momentum_warmup_start: 0.92 + muon_momentum_warmup_steps: 1500 + muon_row_normalize: True + muon_wd: 0.085 + num_heads: 8 + num_kv_heads: 4 + num_layers: 11 + num_loops: 2 + parallel_residual_start: 7 + qk_gain_init: 4.0 + quantized_model_path: final_model.int6.ptz + rank: 0 + rope_base: 10000.0 + rope_dims: 16 + rope_train_seq_len: 2048 + run_id: f3971278-d577-499b-8fde-755434809ba9 + scalar_lr: 0.02 + seed: 1337 + skip_gates_enabled: True + sliding_window_enabled: True + tie_embeddings: True + tied_embed_init_std: 0.005 + tied_embed_lr: 0.03 + tokenizer_path: ./data/tokenizers/fineweb_8192_bpe.model + train_batch_tokens: 786432 + train_files: ./data/datasets/fineweb10B_sp8192/fineweb_train_*.bin + train_log_every: 500 + train_seq_len: 2048 + ttt_chunk_tokens: 32768 + ttt_enabled: False + ttt_entropy_high: 2.1 + ttt_entropy_low: 1.75 + ttt_epochs: 4 + ttt_freeze_blocks: 2 + ttt_lr: 0.0005 + ttt_ns_steps: 3 + untie_loop_mlps: False + val_batch_tokens: 524288 + val_files: ./data/datasets/fineweb10B_sp8192/fineweb_val_*.bin + val_loss_every: 4000 + vocab_size: 8192 + warmdown_frac: 0.667 + warmup_steps: 20 + world_size: 8 + xsa_last_n: 11 +train_shards: 128 +val_tokens: 40540160 +model_params:35943512 +hessian_clip: lambda=0.175 +parallel_residuals: ON (layers 7-10) +progressive_recurrence: phase1=0.5 phase2=0.65 +gptq:reserving 12s, effective=588000ms +warmup_step: 1/20 +warmup_step: 2/20 +warmup_step: 3/20 +warmup_step: 4/20 +warmup_step: 5/20 +warmup_step: 6/20 +warmup_step: 10/20 +warmup_step: 20/20 +loop_warmup_phase1: encoder:[0, 1, 2, 3, 4, 5] decoder:[4, 5, 6, 7, 8, 9, 10] +loop_warmup_p1_step: 1/20 +loop_warmup_p1_step: 2/20 +loop_warmup_p1_step: 3/20 +loop_warmup_p1_step: 4/20 +loop_warmup_p1_step: 5/20 +loop_warmup_p1_step: 6/20 +loop_warmup_p1_step: 10/20 +loop_warmup_p1_step: 20/20 +loop_warmup_phase2: encoder:[0, 1, 2, 3, 4, 5, 4] decoder:[5, 4, 5, 6, 7, 8, 9, 10] +loop_warmup_p2_step: 1/20 +loop_warmup_p2_step: 2/20 +loop_warmup_p2_step: 3/20 +loop_warmup_p2_step: 4/20 +loop_warmup_p2_step: 5/20 +loop_warmup_p2_step: 6/20 +loop_warmup_p2_step: 10/20 +loop_warmup_p2_step: 20/20 +0/20000 val_loss: 9.0052 val_bpb: 3.4862 +1/20000 train_loss: 9.0086 train_time: 0.0m tok/s: 8101558 looping:False +2/20000 train_loss: 12.1911 train_time: 0.0m tok/s: 8047236 looping:False +3/20000 train_loss: 11.0242 train_time: 0.0m tok/s: 7763017 looping:False +4/20000 train_loss: 9.5010 train_time: 0.0m tok/s: 7744899 looping:False +5/20000 train_loss: 8.3911 train_time: 0.0m tok/s: 7764479 looping:False +500/20000 train_loss: 3.3106 train_time: 0.8m tok/s: 7721577 looping:False +1000/20000 train_loss: 3.2012 train_time: 1.7m tok/s: 7711891 looping:False +1500/20000 train_loss: 3.1827 train_time: 2.5m tok/s: 7711479 looping:False +2000/20000 train_loss: 2.9936 train_time: 3.4m tok/s: 7711845 looping:False +2500/20000 train_loss: 3.0679 train_time: 4.2m tok/s: 7713114 looping:False +layer_loop:phase1 step:2884 frac:0.500 +3000/20000 train_loss: 3.1068 train_time: 5.1m tok/s: 7668374 looping:True +3500/20000 train_loss: 2.9483 train_time: 6.1m tok/s: 7510155 looping:True +layer_loop:phase2 step:3634 frac:0.650 +4000/20000 train_loss: 2.9482 train_time: 7.2m tok/s: 7297673 looping:True +4000/20000 val_loss: 2.9279 val_bpb: 1.1335 +4500/20000 train_loss: 2.8499 train_time: 8.3m tok/s: 7110716 looping:True +5000/20000 train_loss: 2.8598 train_time: 9.4m tok/s: 6967320 looping:True +5178/20000 val_loss: 2.8121 val_bpb: 1.0887 +stopping_early: wallclock_cap train_time: 588103ms step: 5178/20000 +peak memory allocated: 34604 MiB reserved: 34634 MiB +ema:applying EMA weights +pre-quantization post-ema val_loss:2.80947408 val_bpb:1.08764765 eval_time:6554ms +Serialized model: 135426937 bytes +Code size: 78688 bytes +GPTQ:collecting Hessians from calibration data... +GPTQ:collected 67 Hessians in 11.3s +GPTQ:saved Hessian diagnostics to hessian_diagnostics.pt (67 matrices) +Quantized weights: + gptq (int6): blocks.attn.c_k.weight, blocks.attn.c_q.weight, blocks.attn.c_v.weight, blocks.attn.proj.weight, blocks.mlp.fc.weight, blocks.mlp.proj.weight + gptq (int8): tok_emb.weight + passthrough (float16): blocks.attn.q_gain, blocks.attn_scale, blocks.mlp_scale, blocks.resid_mix, skip_gates, skip_weights +Serialized model quantized+brotli: 15976275 bytes +Total submission size quantized+brotli: 16054963 bytes +quantized val_loss:2.84032998 val_bpb:1.09959307 eval_time:8134ms +quantized_sliding_window val_loss:2.79749368 val_bpb:1.08300961 eval_time:82837ms diff --git a/records/track_10min_16mb/2026-04-06_SP8192_HessianSDClip_ProgressiveRecurrence/train_seed3141.log b/records/track_10min_16mb/2026-04-06_SP8192_HessianSDClip_ProgressiveRecurrence/train_seed3141.log new file mode 100644 index 0000000000..54fa637146 --- /dev/null +++ b/records/track_10min_16mb/2026-04-06_SP8192_HessianSDClip_ProgressiveRecurrence/train_seed3141.log @@ -0,0 +1,169 @@ +W0406 07:46:11.230000 63493 torch/distributed/run.py:803] +W0406 07:46:11.230000 63493 torch/distributed/run.py:803] ***************************************** +W0406 07:46:11.230000 63493 torch/distributed/run.py:803] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. +W0406 07:46:11.230000 63493 torch/distributed/run.py:803] ***************************************** +Hyperparameters: + adam_eps: 1e-08 + adam_wd: 0.02 + beta1: 0.9 + beta2: 0.95 + clip_mult_early: 1.0 + clip_mult_late: 1.0 + clip_mult_loop: 1.0 + clip_mult_mid: 1.0 + compressor: brotli + data_dir: ./data/ + datasets_dir: ./data/datasets/fineweb10B_sp8192 + distributed: True + ema_decay: 0.997 + embed_bits: 8 + embed_clip_sigmas: 20.0 + embed_lr: 0.6 + embed_wd: 0.085 + embedding_dim: 512 + enable_looping_at: 0.5 + eval_seq_len: 2048 + eval_stride: 64 + gptq_calibration_batches: 64 + gptq_reserve_seconds: 12.0 + grad_accum_steps: 1 + grad_clip_norm: 0.3 + head_lr: 0.008 + hessian_clip_lambda: 0.175 + is_main_process: True + iterations: 20000 + ln_scale: True + local_rank: 0 + logfile: logs/1db830a7-adf6-4a60-a91f-a8758ece4bce.txt + logit_softcap: 30.0 + loop_end: 5 + loop_layer_bits: 0 + loop_layer_clip_sigmas: 0.0 + loop_phase2_at: 0.65 + loop_start: 4 + matrix_bits: 6 + matrix_clip_sigmas: 12.85 + matrix_lr: 0.02 + max_wallclock_seconds: 600.0 + min_lr: 0.0 + mlp_mult: 4.0 + model_dim: 512 + model_path: final_model.pt + muon_backend_steps: 5 + muon_beta2: 0.95 + muon_momentum: 0.99 + muon_momentum_warmup_start: 0.92 + muon_momentum_warmup_steps: 1500 + muon_row_normalize: True + muon_wd: 0.085 + num_heads: 8 + num_kv_heads: 4 + num_layers: 11 + num_loops: 2 + parallel_residual_start: 7 + qk_gain_init: 4.0 + quantized_model_path: final_model.int6.ptz + rank: 0 + rope_base: 10000.0 + rope_dims: 16 + rope_train_seq_len: 2048 + run_id: 1db830a7-adf6-4a60-a91f-a8758ece4bce + scalar_lr: 0.02 + seed: 3141 + skip_gates_enabled: True + sliding_window_enabled: True + tie_embeddings: True + tied_embed_init_std: 0.005 + tied_embed_lr: 0.03 + tokenizer_path: ./data/tokenizers/fineweb_8192_bpe.model + train_batch_tokens: 786432 + train_files: ./data/datasets/fineweb10B_sp8192/fineweb_train_*.bin + train_log_every: 500 + train_seq_len: 2048 + ttt_chunk_tokens: 32768 + ttt_enabled: False + ttt_entropy_high: 2.1 + ttt_entropy_low: 1.75 + ttt_epochs: 4 + ttt_freeze_blocks: 2 + ttt_lr: 0.0005 + ttt_ns_steps: 3 + untie_loop_mlps: False + val_batch_tokens: 524288 + val_files: ./data/datasets/fineweb10B_sp8192/fineweb_val_*.bin + val_loss_every: 4000 + vocab_size: 8192 + warmdown_frac: 0.667 + warmup_steps: 20 + world_size: 8 + xsa_last_n: 11 +train_shards: 128 +val_tokens: 40540160 +model_params:35943512 +hessian_clip: lambda=0.175 +parallel_residuals: ON (layers 7-10) +progressive_recurrence: phase1=0.5 phase2=0.65 +gptq:reserving 12s, effective=588000ms +warmup_step: 1/20 +warmup_step: 2/20 +warmup_step: 3/20 +warmup_step: 4/20 +warmup_step: 5/20 +warmup_step: 6/20 +warmup_step: 10/20 +warmup_step: 20/20 +loop_warmup_phase1: encoder:[0, 1, 2, 3, 4, 5] decoder:[4, 5, 6, 7, 8, 9, 10] +loop_warmup_p1_step: 1/20 +loop_warmup_p1_step: 2/20 +loop_warmup_p1_step: 3/20 +loop_warmup_p1_step: 4/20 +loop_warmup_p1_step: 5/20 +loop_warmup_p1_step: 6/20 +loop_warmup_p1_step: 10/20 +loop_warmup_p1_step: 20/20 +loop_warmup_phase2: encoder:[0, 1, 2, 3, 4, 5, 4] decoder:[5, 4, 5, 6, 7, 8, 9, 10] +loop_warmup_p2_step: 1/20 +loop_warmup_p2_step: 2/20 +loop_warmup_p2_step: 3/20 +loop_warmup_p2_step: 4/20 +loop_warmup_p2_step: 5/20 +loop_warmup_p2_step: 6/20 +loop_warmup_p2_step: 10/20 +loop_warmup_p2_step: 20/20 +0/20000 val_loss: 9.0096 val_bpb: 3.4879 +1/20000 train_loss: 9.0125 train_time: 0.0m tok/s: 8307551 looping:False +2/20000 train_loss: 12.2985 train_time: 0.0m tok/s: 8177168 looping:False +3/20000 train_loss: 11.1300 train_time: 0.0m tok/s: 8052485 looping:False +4/20000 train_loss: 9.5814 train_time: 0.0m tok/s: 7995229 looping:False +5/20000 train_loss: 8.4251 train_time: 0.0m tok/s: 7959908 looping:False +500/20000 train_loss: 3.3163 train_time: 0.8m tok/s: 7730151 looping:False +1000/20000 train_loss: 3.2014 train_time: 1.7m tok/s: 7720048 looping:False +1500/20000 train_loss: 3.1887 train_time: 2.5m tok/s: 7721291 looping:False +2000/20000 train_loss: 2.9960 train_time: 3.4m tok/s: 7720167 looping:False +2500/20000 train_loss: 3.0707 train_time: 4.2m tok/s: 7722132 looping:False +layer_loop:phase1 step:2888 frac:0.500 +3000/20000 train_loss: 3.1073 train_time: 5.1m tok/s: 7679644 looping:True +3500/20000 train_loss: 2.9516 train_time: 6.1m tok/s: 7521178 looping:True +layer_loop:phase2 step:3639 frac:0.650 +4000/20000 train_loss: 2.9475 train_time: 7.2m tok/s: 7307604 looping:True +4000/20000 val_loss: 2.9309 val_bpb: 1.1347 +4500/20000 train_loss: 2.8551 train_time: 8.3m tok/s: 7119273 looping:True +5000/20000 train_loss: 2.8637 train_time: 9.4m tok/s: 6974987 looping:True +5182/20000 val_loss: 2.8149 val_bpb: 1.0897 +stopping_early: wallclock_cap train_time: 588016ms step: 5182/20000 +peak memory allocated: 34604 MiB reserved: 34636 MiB +ema:applying EMA weights +pre-quantization post-ema val_loss:2.81224419 val_bpb:1.08872006 eval_time:6671ms +Serialized model: 135426937 bytes +Code size: 80055 bytes +GPTQ:collecting Hessians from calibration data... +GPTQ:collected 67 Hessians in 11.3s +GPTQ:saved Hessian diagnostics to hessian_diagnostics.pt (67 matrices) +Quantized weights: + gptq (int6): blocks.attn.c_k.weight, blocks.attn.c_q.weight, blocks.attn.c_v.weight, blocks.attn.proj.weight, blocks.mlp.fc.weight, blocks.mlp.proj.weight + gptq (int8): tok_emb.weight + passthrough (float16): blocks.attn.q_gain, blocks.attn_scale, blocks.mlp_scale, blocks.resid_mix, skip_gates, skip_weights +Serialized model quantized+brotli: 15979649 bytes +Total submission size quantized+brotli: 16059704 bytes +quantized val_loss:2.84252324 val_bpb:1.10044216 eval_time:8167ms +quantized_sliding_window val_loss:2.80002165 val_bpb:1.08398828 eval_time:83092ms diff --git a/records/track_10min_16mb/2026-04-06_SP8192_HessianSDClip_ProgressiveRecurrence/train_seed42.log b/records/track_10min_16mb/2026-04-06_SP8192_HessianSDClip_ProgressiveRecurrence/train_seed42.log new file mode 100644 index 0000000000..f8815ecda0 --- /dev/null +++ b/records/track_10min_16mb/2026-04-06_SP8192_HessianSDClip_ProgressiveRecurrence/train_seed42.log @@ -0,0 +1,169 @@ +W0406 06:54:37.638000 61576 torch/distributed/run.py:803] +W0406 06:54:37.638000 61576 torch/distributed/run.py:803] ***************************************** +W0406 06:54:37.638000 61576 torch/distributed/run.py:803] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. +W0406 06:54:37.638000 61576 torch/distributed/run.py:803] ***************************************** +Hyperparameters: + adam_eps: 1e-08 + adam_wd: 0.02 + beta1: 0.9 + beta2: 0.95 + clip_mult_early: 1.0 + clip_mult_late: 1.0 + clip_mult_loop: 1.0 + clip_mult_mid: 1.0 + compressor: brotli + data_dir: ./data/ + datasets_dir: ./data/datasets/fineweb10B_sp8192 + distributed: True + ema_decay: 0.997 + embed_bits: 8 + embed_clip_sigmas: 20.0 + embed_lr: 0.6 + embed_wd: 0.085 + embedding_dim: 512 + enable_looping_at: 0.5 + eval_seq_len: 2048 + eval_stride: 64 + gptq_calibration_batches: 64 + gptq_reserve_seconds: 12.0 + grad_accum_steps: 1 + grad_clip_norm: 0.3 + head_lr: 0.008 + hessian_clip_lambda: 0.175 + is_main_process: True + iterations: 20000 + ln_scale: True + local_rank: 0 + logfile: logs/fe8e0f78-33ff-4666-9129-1790da679899.txt + logit_softcap: 30.0 + loop_end: 5 + loop_layer_bits: 0 + loop_layer_clip_sigmas: 0.0 + loop_phase2_at: 0.65 + loop_start: 4 + matrix_bits: 6 + matrix_clip_sigmas: 12.85 + matrix_lr: 0.02 + max_wallclock_seconds: 600.0 + min_lr: 0.0 + mlp_mult: 4.0 + model_dim: 512 + model_path: final_model.pt + muon_backend_steps: 5 + muon_beta2: 0.95 + muon_momentum: 0.99 + muon_momentum_warmup_start: 0.92 + muon_momentum_warmup_steps: 1500 + muon_row_normalize: True + muon_wd: 0.085 + num_heads: 8 + num_kv_heads: 4 + num_layers: 11 + num_loops: 2 + parallel_residual_start: 7 + qk_gain_init: 4.0 + quantized_model_path: final_model.int6.ptz + rank: 0 + rope_base: 10000.0 + rope_dims: 16 + rope_train_seq_len: 2048 + run_id: fe8e0f78-33ff-4666-9129-1790da679899 + scalar_lr: 0.02 + seed: 42 + skip_gates_enabled: True + sliding_window_enabled: True + tie_embeddings: True + tied_embed_init_std: 0.005 + tied_embed_lr: 0.03 + tokenizer_path: ./data/tokenizers/fineweb_8192_bpe.model + train_batch_tokens: 786432 + train_files: ./data/datasets/fineweb10B_sp8192/fineweb_train_*.bin + train_log_every: 500 + train_seq_len: 2048 + ttt_chunk_tokens: 32768 + ttt_enabled: False + ttt_entropy_high: 2.1 + ttt_entropy_low: 1.75 + ttt_epochs: 4 + ttt_freeze_blocks: 2 + ttt_lr: 0.0005 + ttt_ns_steps: 3 + untie_loop_mlps: False + val_batch_tokens: 524288 + val_files: ./data/datasets/fineweb10B_sp8192/fineweb_val_*.bin + val_loss_every: 4000 + vocab_size: 8192 + warmdown_frac: 0.667 + warmup_steps: 20 + world_size: 8 + xsa_last_n: 11 +train_shards: 128 +val_tokens: 40540160 +model_params:35943512 +hessian_clip: lambda=0.175 +parallel_residuals: ON (layers 7-10) +progressive_recurrence: phase1=0.5 phase2=0.65 +gptq:reserving 12s, effective=588000ms +warmup_step: 1/20 +warmup_step: 2/20 +warmup_step: 3/20 +warmup_step: 4/20 +warmup_step: 5/20 +warmup_step: 6/20 +warmup_step: 10/20 +warmup_step: 20/20 +loop_warmup_phase1: encoder:[0, 1, 2, 3, 4, 5] decoder:[4, 5, 6, 7, 8, 9, 10] +loop_warmup_p1_step: 1/20 +loop_warmup_p1_step: 2/20 +loop_warmup_p1_step: 3/20 +loop_warmup_p1_step: 4/20 +loop_warmup_p1_step: 5/20 +loop_warmup_p1_step: 6/20 +loop_warmup_p1_step: 10/20 +loop_warmup_p1_step: 20/20 +loop_warmup_phase2: encoder:[0, 1, 2, 3, 4, 5, 4] decoder:[5, 4, 5, 6, 7, 8, 9, 10] +loop_warmup_p2_step: 1/20 +loop_warmup_p2_step: 2/20 +loop_warmup_p2_step: 3/20 +loop_warmup_p2_step: 4/20 +loop_warmup_p2_step: 5/20 +loop_warmup_p2_step: 6/20 +loop_warmup_p2_step: 10/20 +loop_warmup_p2_step: 20/20 +0/20000 val_loss: 9.0089 val_bpb: 3.4877 +1/20000 train_loss: 9.0121 train_time: 0.0m tok/s: 8297317 looping:False +2/20000 train_loss: 12.2971 train_time: 0.0m tok/s: 8179522 looping:False +3/20000 train_loss: 11.0982 train_time: 0.0m tok/s: 8073126 looping:False +4/20000 train_loss: 9.5163 train_time: 0.0m tok/s: 8021123 looping:False +5/20000 train_loss: 8.4080 train_time: 0.0m tok/s: 7978901 looping:False +500/20000 train_loss: 3.3093 train_time: 0.8m tok/s: 7721491 looping:False +1000/20000 train_loss: 3.1978 train_time: 1.7m tok/s: 7712198 looping:False +1500/20000 train_loss: 3.1825 train_time: 2.5m tok/s: 7713820 looping:False +2000/20000 train_loss: 2.9967 train_time: 3.4m tok/s: 7716076 looping:False +2500/20000 train_loss: 3.0687 train_time: 4.2m tok/s: 7716771 looping:False +layer_loop:phase1 step:2886 frac:0.500 +3000/20000 train_loss: 3.1027 train_time: 5.1m tok/s: 7672009 looping:True +3500/20000 train_loss: 2.9498 train_time: 6.1m tok/s: 7513271 looping:True +layer_loop:phase2 step:3635 frac:0.650 +4000/20000 train_loss: 2.9502 train_time: 7.2m tok/s: 7299379 looping:True +4000/20000 val_loss: 2.9293 val_bpb: 1.1340 +4500/20000 train_loss: 2.8533 train_time: 8.3m tok/s: 7111295 looping:True +5000/20000 train_loss: 2.8595 train_time: 9.4m tok/s: 6967776 looping:True +5178/20000 val_loss: 2.8133 val_bpb: 1.0891 +stopping_early: wallclock_cap train_time: 588098ms step: 5178/20000 +peak memory allocated: 34604 MiB reserved: 34634 MiB +ema:applying EMA weights +pre-quantization post-ema val_loss:2.81078864 val_bpb:1.08815657 eval_time:6623ms +Serialized model: 135426937 bytes +Code size: 80055 bytes +GPTQ:collecting Hessians from calibration data... +GPTQ:collected 67 Hessians in 11.4s +GPTQ:saved Hessian diagnostics to hessian_diagnostics.pt (67 matrices) +Quantized weights: + gptq (int6): blocks.attn.c_k.weight, blocks.attn.c_q.weight, blocks.attn.c_v.weight, blocks.attn.proj.weight, blocks.mlp.fc.weight, blocks.mlp.proj.weight + gptq (int8): tok_emb.weight + passthrough (float16): blocks.attn.q_gain, blocks.attn_scale, blocks.mlp_scale, blocks.resid_mix, skip_gates, skip_weights +Serialized model quantized+brotli: 15978439 bytes +Total submission size quantized+brotli: 16058494 bytes +quantized val_loss:2.84171694 val_bpb:1.10013002 eval_time:8159ms +quantized_sliding_window val_loss:2.79909586 val_bpb:1.08362987 eval_time:83408ms