diff --git a/records/track_non_record_16mb/2026-03-31_EntropyAware_Int5Odd_BTTMLP_3060/README.md b/records/track_non_record_16mb/2026-03-31_EntropyAware_Int5Odd_BTTMLP_3060/README.md
new file mode 100644
index 000000000..34981b2bb
--- /dev/null
+++ b/records/track_non_record_16mb/2026-03-31_EntropyAware_Int5Odd_BTTMLP_3060/README.md
@@ -0,0 +1,241 @@
+# Entropy-Aware Int5-Odd + BTT-MLP
+
+## Summary
+
+This non-record submission packages the current best fully recovered run of three changes aimed directly at Parameter Golf's artifact objective:
+
+1. **Entropy-aware training** over a 5-bin odd quantization grid `{-2,-1,0,1,2}` aligned to the exported `int5_odd + zlib` artifact.
+2. **Structured MLP matrices** using a 2-core TT/BTT-inspired `StructuredLinear` in the MLP only, leaving attention dense.
+3. **Evaluation-time materialization** of the structured MLP into dense weights, so validation runs through standard `F.linear` instead of the slower TT/BTT rank loop.
+
+The submission basis is the fully recovered `1xH100` run in [`train.log`](./train.log), which is a copy of [`logs/h100_real_r256_l16_seq1024_mb2048_materialized.txt`](./logs/h100_real_r256_l16_seq1024_mb2048_materialized.txt). This is an exploratory non-record result, not a leaderboard-comparable run, because it uses `VAL_TOKEN_LIMIT=1048576` rather than the full validation set.
+
+A later `3300s` H100 attempt progressed further but was interrupted when RunPod exhausted the remaining credits. Because the final artifact and complete log were not recovered, that longer run is not used for the submission metrics here.
+
+## Cloud Transition
+
+The harness is now prepared for a single-node `torchrun` launch on `8xH100` without changing the local-tuned default path.
+
+New training controls in [`train_gpt.py`](./train_gpt.py):
+
+- `TRAIN_MICROBATCH_TOKENS`: per-GPU tokens per microstep
+- `GRAD_ACCUM_STEPS`: optional explicit override; otherwise inferred from `TRAIN_BATCH_TOKENS`, `WORLD_SIZE`, and `TRAIN_MICROBATCH_TOKENS`
+- `LR_SCALE_METHOD=none|sqrt|linear`
+- `LR_REFERENCE_BATCH_TOKENS`
+- `WARMUP_SCALE_METHOD=none|sqrt|linear`
+- `WARMUP_REFERENCE_BATCH_TOKENS`
+
+Recommended single-node launcher:
+
+```bash
+bash run_8xh100.sh
+```
+
+The launcher uses:
+
+- `torchrun --standalone --nproc_per_node=8`
+- `TRAIN_BATCH_TOKENS=131072`
+- `TRAIN_MICROBATCH_TOKENS=8192`
+- `LR_SCALE_METHOD=sqrt`
+- `WARMUP_SCALE_METHOD=linear`
+- `VAL_TOKEN_LIMIT=0`
+
+This keeps DDP simple, scales the optimizer from the local `16384`-token reference batch, and increases warmup automatically when the effective global batch is larger.
+
+### DDP data sharding
+
+The training loader is rank-aware. Each process constructs the same contiguous shared chunk and then slices out its own disjoint span using `rank` and `world_size`, so gradients are synchronized over different token ranges instead of duplicating work across ranks.
+
+The sharding debug path is built into [`train_gpt.py`](./train_gpt.py) via:
+
+- `SIMULATED_WORLD_SIZE`
+- `SIMULATED_RANK`
+- `DEBUG_DATA_SHARDING_STEPS`
+- `DRY_RUN_INIT_ONLY`
+
+Local sharding/math sanity check:
+
+```bash
+bash run_mock_8xh100_math.sh
+```
+
+This prints:
+
+- inferred `grad_accum_steps`
+- scaled learning rates and warmup
+- one simulated shared chunk split across ranks `0..7`
+
+The checked local run is recorded in [`logs/mock_8xh100_math.txt`](./logs/mock_8xh100_math.txt).
+
+Gradient-accumulation stress check:
+
+```bash
+bash run_extreme_accum.sh
+```
+
+The checked run in [`logs/extreme_accum.txt`](./logs/extreme_accum.txt) completed with `grad_accum_steps=64` and peaked at `4617 MiB` allocated on the 3060.
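The batch-math and sharding conventions above can be sketched as follows. This is an illustrative sketch, not the script's API: the helper names, the `0.04` base learning rate, and the exact scaling formulas are assumptions for demonstration, while the authoritative logic lives in `train_gpt.py`.

```python
import math

def infer_grad_accum_steps(train_batch_tokens: int, microbatch_tokens: int,
                           world_size: int) -> int:
    # Global tokens per optimizer step are split across ranks, then microsteps.
    return max(1, train_batch_tokens // (microbatch_tokens * world_size))

def scale_lr(base_lr: float, batch_tokens: int, reference_batch_tokens: int,
             method: str = "sqrt") -> float:
    # Scale a reference learning rate with the effective global batch size.
    ratio = batch_tokens / reference_batch_tokens
    if method == "sqrt":
        return base_lr * math.sqrt(ratio)
    if method == "linear":
        return base_lr * ratio
    return base_lr

def rank_span(total_seqs: int, rank: int, world_size: int) -> tuple[int, int]:
    # Disjoint, contiguous per-rank slice of one shared chunk.
    return (total_seqs * rank) // world_size, (total_seqs * (rank + 1)) // world_size

# 8xH100 launcher defaults: 131072 global tokens, 8192 tokens per GPU microstep.
accum = infer_grad_accum_steps(131072, 8192, world_size=8)   # -> 2
lr = scale_lr(0.04, 131072, reference_batch_tokens=16384)    # 0.04 * sqrt(8)
spans = [rank_span(64, r, 8) for r in range(8)]              # 8 disjoint spans
```

With the launcher defaults, each optimizer step covers `8192 * 8 = 65536` tokens per microstep, so two microsteps are accumulated to reach the `131072`-token global batch.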
+
+## Main Result
+
+Submission command:
+
+```bash
+RUN_ID=h100_real_r256_l16_seq1024_mb2048_materialized \
+DATA_PATH=/workspace/parameter-golf/data/datasets/fineweb10B_sp1024 \
+TOKENIZER_PATH=/workspace/parameter-golf/data/tokenizers/fineweb_1024_bpe.model \
+VOCAB_SIZE=1024 \
+NUM_LAYERS=16 \
+BTT_RANK=256 \
+TRAIN_SEQ_LEN=1024 \
+TRAIN_BATCH_TOKENS=16384 \
+TRAIN_MICROBATCH_TOKENS=2048 \
+VAL_BATCH_SIZE=1048576 \
+VAL_TOKEN_LIMIT=1048576 \
+COMPILE_STRUCTURED_MLP=0 \
+ITERATIONS=20000 \
+MAX_WALLCLOCK_SECONDS=600 \
+python3 train_gpt.py
+```
+
+Recovered H100 metrics on `1x H100 80GB HBM3` with `80` FineWeb train shards and `VAL_TOKEN_LIMIT=1048576`:
+
+| Variant | Params | Step Avg | Pre-Quant val_bpb | Roundtrip val_bpb | Quantized Size | Total Size |
+|---|---:|---:|---:|---:|---:|---:|
+| BTT-MLP + entropy-aware `int5_odd` + eval materialization | 25,727,104 | 19468.45 ms | 5.3457 | 5.8880 | 5,184,543 B | 5,267,667 B |
+
+Additional notes:
+
+- Peak memory: `50290 MiB allocated / 78558 MiB reserved`
+- Optimizer-step training time: `603.522 s` for `31` steps
+- Final exact roundtrip metrics: `val_loss=9.82734489`, `val_bpb=5.88802157`
+- Quantized eval time on the `1,048,576` token validation cap: `854 ms`
+- This run is the basis for `submission.json` and `train.log`
+
+## Evaluation Materialization Result
+
+The key systems win in this folder is the eval-time materialization hook in [`train_gpt.py`](./train_gpt.py), which contracts the BTT cores into a dense matrix during validation and runs `F.linear` instead of the training-time structured path.
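The materialization trick can be illustrated with a toy two-core contraction. This is a sketch under assumed shapes, not the repo's `StructuredLinear` code: the dense weight is taken to be the sum over rank components of `U[r] @ V[r]`, so contracting the cores once reproduces the rank loop's output exactly through a single `F.linear`.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
rank, d_in, d_hidden, d_out = 4, 8, 3, 6   # toy sizes, not the run's R256 config

# Two stacks of cores; the implied dense weight is sum_r U[r] @ V[r].
U = torch.randn(rank, d_out, d_hidden)
V = torch.randn(rank, d_hidden, d_in)
x = torch.randn(5, d_in)

# Training-time style rank loop: one small matmul pair per rank component.
y_loop = sum(x @ V[r].T @ U[r].T for r in range(rank))

# Eval-time materialization: contract the cores once, then use F.linear.
W = torch.einsum("roh,rhi->oi", U, V)
y_dense = F.linear(x, W)

assert torch.allclose(y_loop, y_dense, atol=1e-5)
```

The contraction is a one-time cost per validation pass, after which every token batch pays only a single dense matmul, which is why the measured eval time drops by roughly two orders of magnitude below.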
+
+The direct benchmark is in [`logs/eval_bench_r256_l16_materialized_vb1048576.txt`](./logs/eval_bench_r256_l16_materialized_vb1048576.txt):
+
+| Eval Path | Validation Tokens | Eval Time |
+|---|---:|---:|
+| Structured rank-loop path, cached | 1,048,576 | ~96.99 s |
+| Materialized dense eval path | 1,048,576 | 0.851 s |
+
+This is the change that made the structured submission operationally viable on H100 for evaluation.
+
+## Local Result
+
+The local 3060 smoke result is still kept in [`logs/smoke_btt_40.txt`](./logs/smoke_btt_40.txt) and remains useful as a low-cost regression test for future iterations.
+
+## Dense Control
+
+The dense reference run in [`dense_ablation.log`](./dense_ablation.log) was kept as a sanity check for the export path.
+
+| Variant | Params | Step Avg | Pre-Quant val_bpb | Roundtrip val_bpb | Quantized Size | Total Size |
+|---|---:|---:|---:|---:|---:|---:|
+| Dense + entropy-aware `int5_odd` | 17,059,912 | 595.10 ms | 2.7479 | 3.1177 | 3,781,677 B | 3,847,186 B |
+
+The structured path still loses too much quality relative to dense at equal training budget, so this remains exploratory rather than competitive.
+
+## Key Debugging Findings
+
+### 1. Safe compiled structured MLP
+
+Naively wrapping the BTT MLP in `torch.compile(mode="reduce-overhead")` crashed under gradient accumulation because the mode enables CUDAGraphs by default in this torch build.
+
+The fix in [`train_gpt.py`](./train_gpt.py):
+
+- marks compile step boundaries explicitly before compiled model invocations
+- compiles the structured MLP with the `reduce-overhead` option set but forces `triton.cudagraphs=False`
+
+Measured 8-step benchmark logs:
+
+- [`logs/compile_off_bench.txt`](./logs/compile_off_bench.txt)
+- [`logs/compile_on_bench.txt`](./logs/compile_on_bench.txt)
+
+| Compile Structured MLP | Step Avg | Roundtrip val_bpb |
+|---|---:|---:|
+| Off | 5037.16 ms | 6.0884 |
+| On, safe non-CUDAGraph path | 1991.68 ms | 4.3758 |
+
+The main takeaway is speed: the compiled structured path is about **2.53x faster** on the 3060.
+
+### 2. `mup` init is slightly better than xavier
+
+Short 12-step init ablation logs:
+
+- [`logs/init_mup_bench.txt`](./logs/init_mup_bench.txt)
+- [`logs/init_xavier_bench.txt`](./logs/init_xavier_bench.txt)
+
+| BTT Init | Pre-Quant val_bpb | Roundtrip val_bpb |
+|---|---:|---:|
+| `mup` | 3.7575 | 3.9927 |
+| `xavier` | 3.7596 | 4.0048 |
+
+The gain is small but consistent, so `mup` remains the default.
+
+### 3. Lower rate penalty is the better local default
+
+The original entropy penalty was too aggressive for this small structured model.
+
+Matched 40-step comparison:
+
+- default tuned run: [`train.log`](./train.log)
+- higher-penalty ablation: [`logs/lambda_high_40.txt`](./logs/lambda_high_40.txt)
+
+| RATE_LAMBDA | Pre-Quant val_bpb | Roundtrip val_bpb | Total Size |
+|---|---:|---:|---:|
+| `0.00002` | 3.2636 | 3.4274 | 2,168,889 B |
+| `0.002` | 3.2655 | 3.4250 | 2,159,147 B |
+
+The short sweep favored the lower penalty, but the matched 40-step local comparison now gives a slight roundtrip edge to `0.002`. Keep both settings as live candidates before the cloud run.
+
+The shorter sweep harness is kept in [`run_lambda_sweep.sh`](./run_lambda_sweep.sh).
+
+### 4. The implementation can now spend the artifact budget
+
+The early structured runs were far too small.
+After increasing BTT rank and depth, the same export path can reach the intended `12MB–14MB` band.
+
+Zero-step capacity scouts:
+
+- [`logs/R1024_L20.txt`](./logs/R1024_L20.txt)
+- [`logs/R1024_L24.txt`](./logs/R1024_L24.txt)
+
+| Config | Params | Total Size | Roundtrip val_bpb |
+|---|---:|---:|---:|
+| `BTT_RANK=1024 NUM_LAYERS=20` | 79,213,728 | 11,224,259 B | 3.7634 |
+| `BTT_RANK=1024 NUM_LAYERS=24` | 94,951,616 | 13,435,175 B | 3.7671 |
+
+These are **capacity scouts only**, not trained submissions, but they show the BTT stack can occupy a realistic non-record budget before moving to cloud hardware.
+
+## Local Workflow
+
+Helper scripts included in this folder:
+
+- [`run_compile_bench.sh`](./run_compile_bench.sh): compile on/off benchmark for the structured MLP
+- [`run_init_ablation.sh`](./run_init_ablation.sh): `mup` vs `xavier`
+- [`run_lambda_sweep.sh`](./run_lambda_sweep.sh): entropy-penalty sweep with the compiled path enabled
+- [`run_capacity_scout.sh`](./run_capacity_scout.sh): quick size scouts up to the 12MB–14MB regime
+- [`run_8xh100.sh`](./run_8xh100.sh): single-node `8xH100` torchrun launcher with batch/LR/warmup scaling defaults
+- [`run_mock_8xh100_math.sh`](./run_mock_8xh100_math.sh): local dry run for 8-GPU batch math plus rank-aware sharding inspection
+- [`run_extreme_accum.sh`](./run_extreme_accum.sh): 50+ microstep accumulation stress test on one GPU
+- [`run_local_tuning.sh`](./run_local_tuning.sh): runs the full local tuning sequence
+
+## Current Limitations
+
+- Validation is still capped locally with `VAL_TOKEN_LIMIT=1048576` by default for tractable iteration on a 3060.
+- The BTT forward path still uses a rank loop to stay memory-safe during validation and export. That keeps it correct but slower than a production-quality fused implementation.
+- The entropy-aware objective is a practical rate proxy aligned to this repo's serializer, not a literal reproduction of BackSlash or CERWU.
+- The high-rank `12MB–14MB` configurations are not yet trained locally; they are cloud candidates.
+- Validation batch sizing is now independent of gradient accumulation, which matters for high-accumulation debug runs and cloud launches.
+
+## Future Work
+
+The most immediate next step is to rerun the same materialized-eval stack on `1xH100` with a longer wallclock budget. A later `3300s` H100 attempt was started after the recovered `600s` run and was progressing in the right direction before RunPod exhausted the remaining credits. The last recovered checkpoint in the live log reached `step 170` at `train_time 3108679ms` with `train_loss 4.9575`, and earlier in that same run the loss had already improved from `5.3368` at step `90` to `4.8696` at step `160`. Because the final artifact and the complete tail of the log were not recovered, that longer run is not used in `submission.json`, but it is the clearest indicator of where to spend the next block of compute.
+
+Concretely, the next cloud pass should:
+
+- rerun `R256/L16` with the materialized eval path and a longer wallclock budget
+- recover the full final artifact and log from that longer run
+- then retest higher-capacity scouts such as `R1024/L20` and `R1024/L24` once the structured training path is fast enough to give those models meaningful token exposure
diff --git a/records/track_non_record_16mb/2026-03-31_EntropyAware_Int5Odd_BTTMLP_3060/dense_ablation.log b/records/track_non_record_16mb/2026-03-31_EntropyAware_Int5Odd_BTTMLP_3060/dense_ablation.log
new file mode 100644
index 000000000..4690c497b
--- /dev/null
+++ b/records/track_non_record_16mb/2026-03-31_EntropyAware_Int5Odd_BTTMLP_3060/dense_ablation.log
@@ -0,0 +1,3209 @@
+"""
+The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good starting points for new participants, not SOTA configs.
We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder.
+
+Hard stop: to keep things readable for newcomers, let's make sure `train_gpt.py` and `train_gpt_mlx.py` are never longer than 1500 lines.
+"""
+
+from __future__ import annotations
+
+import copy
+import glob
+import io
+import math
+import os
+import random
+import subprocess
+import sys
+import time
+import uuid
+import zlib
+from pathlib import Path
+
+import numpy as np
+import sentencepiece as spm
+import torch
+import torch.distributed as dist
+import torch.nn.functional as F
+from torch import Tensor, nn
+from torch.nn.parallel import DistributedDataParallel as DDP
+
+try:
+    import zstandard as zstd
+
+    _HAS_ZSTD = True
+except ImportError:
+    zstd = None
+    _HAS_ZSTD = False
+
+# -----------------------------
+# HYPERPARAMETERS
+# -----------------------------
+# Default Simple Baseline run:
+# - 9 transformer blocks at width 512
+# - 8 attention heads with 4 KV heads (GQA) and 2x MLP expansion
+# - vocab size 1024, sequence length 1024, tied embeddings
+# - 524,288 train tokens per step for 20,000 iterations with a ~10 minute cap
+
+class Hyperparameters:
+    # Data paths are shard globs produced by the existing preprocessing pipeline.
+    data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024")
+    train_files = os.path.join(data_path, "fineweb_train_*.bin")
+    val_files = os.path.join(data_path, "fineweb_val_*.bin")
+    tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model")
+    run_id = os.environ.get("RUN_ID", str(uuid.uuid4()))
+    seed = int(os.environ.get("SEED", 1337))
+
+    # Validation cadence and batch size. Validation uses the full fineweb_val split unless VAL_TOKEN_LIMIT caps it.
+ val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 200)) + val_token_limit = int(os.environ.get("VAL_TOKEN_LIMIT", 0)) + + # Training length. + iterations = int(os.environ.get("ITERATIONS", 20000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 1200)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + enable_compile = bool(int(os.environ.get("ENABLE_COMPILE", "1"))) + + # Model shape. + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers = int(os.environ.get("NUM_LAYERS", 9)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = int(os.environ.get("MLP_MULT", 2)) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + + # Optimizer hyperparameters. 
+ embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.05)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.04)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.04)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.0)) + + # Export / compression. + quant_scheme = os.environ.get("QUANT_SCHEME", "int5_odd") + compressor = os.environ.get("COMPRESSOR", "zlib") + compressor_level = int(os.environ.get("COMPRESSOR_LEVEL", 9)) + + # Entropy-aware rate regularization. + rate_lambda = float(os.environ.get("RATE_LAMBDA", 0.02)) + scale_lambda = float(os.environ.get("SCALE_LAMBDA", 0.0005)) + rate_temperature = float(os.environ.get("RATE_TEMPERATURE", 6.0)) + rate_warmup_start_fraction = float(os.environ.get("RATE_WARMUP_START_FRACTION", 0.2)) + rate_full_fraction = float(os.environ.get("RATE_FULL_FRACTION", 0.6)) + rate_target_tensors = os.environ.get("RATE_TARGET_TENSORS", "mlp_only") + + # Structured MLP. 
+ structured_mlp = bool(int(os.environ.get("STRUCTURED_MLP", "0"))) + structured_linear_type = os.environ.get("STRUCTURED_LINEAR_TYPE", "btt_mlp") + btt_rank = int(os.environ.get("BTT_RANK", 32)) + btt_in_shape = os.environ.get("BTT_IN_SHAPE", "") + btt_out_shape = os.environ.get("BTT_OUT_SHAPE", "") + +# ----------------------------- +# MUON OPTIMIZER +# ----------------------------- +# +# As borrowed from modded-nanogpt +# Background on Muon: https://kellerjordan.github.io/posts/muon/ + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + # Orthogonalize a 2D update matrix with a fast Newton-Schulz iteration. + # Muon uses this to normalize matrix-shaped gradients before applying them. + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X + + +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True): + super().__init__( + params, + dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov), + ) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + + for group in self.param_groups: + params = group["params"] + if not params: + continue + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + + total_params = sum(int(p.numel()) for p in params) + updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) + + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and 
p.grad is not None:
+                    g = p.grad
+                    state = self.state[p]
+                    if "momentum_buffer" not in state:
+                        state["momentum_buffer"] = torch.zeros_like(g)
+                    buf = state["momentum_buffer"]
+                    buf.mul_(momentum).add_(g)
+                    if nesterov:
+                        g = g.add(buf, alpha=momentum)
+                    g = zeropower_via_newtonschulz5(g, steps=backend_steps)
+                    # Scale correction from Muon reference implementations.
+                    g *= max(1, g.size(0) / g.size(1)) ** 0.5
+                    updates_flat[curr : curr + p.numel()] = g.reshape(-1)
+                # Advance the offset for every param (owned by this rank or not)
+                # so the flat buffer layout is identical across ranks before the
+                # all-reduce below.
+                curr += p.numel()
+
+            if distributed:
+                dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM)
+
+            curr = 0
+            for p in params:
+                g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype)
+                p.add_(g, alpha=-lr)
+                curr += p.numel()
+
+        return loss
+
+
+# -----------------------------
+# TOKENIZER-AGNOSTIC EVALUATION SETUP
+# -----------------------------
+#
+# It's common for small models to have a large fraction of their parameters be embeddings, since the 2 * d_model * d_vocab vectors can be gigantic.
+# Instead of locking the tokenizer, we let you bring your own and calculate our validation metrics on the average compression of the validation set.
+# We calculate BPB (bits-per-byte) instead of validation loss, so we need methods to count the number of bits per token in the tokenizer.
+# Note: Submissions that edit the tokenizer will be examined more carefully, since screwing this up might unjustly improve your score.
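+#
+# A worked example of the conversion (hypothetical numbers, not from a run):
+# with val_loss = 2.0 nats/token, bits_per_token = 2.0 / ln(2) ≈ 2.885; if the
+# validation split averages 4 bytes per token (tokens_per_byte = 0.25), then
+# val_bpb ≈ 2.885 * 0.25 ≈ 0.72.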
+ +def build_sentencepiece_luts( + sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device +) -> tuple[Tensor, Tensor, Tensor]: + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("▁"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) + + +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*. 
+ tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] + + +def eval_val( + args: Hyperparameters, + model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + grad_accum_steps: int, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, +) -> tuple[float, float]: + # Validation computes two metrics: + # - val_loss: token cross-entropy (natural log) + # - val_bpb: tokenizer-agnostic compression metric used by the challenge + local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) + if local_batch_tokens < args.train_seq_len: + raise ValueError( + "VAL_BATCH_SIZE must provide at least one sequence per rank; " + f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " + f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}" + ) + local_batch_seqs = local_batch_tokens // args.train_seq_len + total_seqs = (val_tokens.numel() - 1) // args.train_seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + + model.eval() + with torch.inference_mode(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * args.train_seq_len + raw_end = batch_seq_end * args.train_seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, args.train_seq_len) + y = local[1:].reshape(-1, args.train_seq_len) + with 
torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + batch_loss = model(x, y).detach() + batch_token_count = float(y.numel()) + val_loss_sum += batch_loss.to(torch.float64) * batch_token_count + val_token_count += batch_token_count + prev_ids = x.reshape(-1) + tgt_ids = y.reshape(-1) + token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + + val_loss = val_loss_sum / val_token_count + bits_per_token = val_loss.item() / math.log(2.0) + tokens_per_byte = val_token_count.item() / val_byte_count.item() + model.train() + return float(val_loss.item()), float(bits_per_token * tokens_per_byte) + +# ----------------------------- +# POST-TRAINING QUANTIZATION +# ----------------------------- +# +# It's silly to export our model, which is trained in bf16 and fp32, at that same precision. +# Instead, we get approximately the same model (with a small hit) by quantizing the model to int8 & zlib compressing. +# We can then decompress the model and run in higher precision for evaluation, after closing in under the size limit. 
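+#
+# Size intuition for the int5_odd path (back-of-envelope, before zlib): three
+# five-level codes pack into one byte (5**3 = 125 <= 256), i.e. 8/3 ≈ 2.67 bits
+# per weight, versus 8 bits per weight for the int8 path.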
+ +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights", + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 +INT5_ODD_LEVELS = (-2, -1, 0, 1, 2) + +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) + + +def parse_factor_shape(spec: str, dim: int) -> tuple[int, int]: + if spec: + vals = tuple(int(v.strip()) for v in spec.split(",") if v.strip()) + if len(vals) != 2 or vals[0] * vals[1] != dim: + raise ValueError(f"Invalid factor shape '{spec}' for dim={dim}") + return vals + for a in range(int(math.isqrt(dim)), 0, -1): + if dim % a == 0: + return a, dim // a + return 1, dim + +def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: + if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): + return t.float().contiguous() + if t.dtype in {torch.float32, torch.bfloat16}: + passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") + return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() + return t + +def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + # Matrices get one scale per row, which usually tracks output-channel + # ranges much better than a single tensor-wide scale. 
+ clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) + q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() + return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + + # Vectors / scalars use a simpler per-tensor scale. + clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() + return q, scale + + +def pack_base5_codes(q: Tensor) -> tuple[bytes, int]: + vals = q.detach().to("cpu", dtype=torch.uint8).reshape(-1).numpy() + n_vals = int(vals.size) + pad = (3 - (n_vals % 3)) % 3 + if pad: + vals = np.pad(vals, (0, pad)) + groups = vals.reshape(-1, 3).astype(np.uint8) + packed = groups[:, 0] + 5 * groups[:, 1] + 25 * groups[:, 2] + return packed.tobytes(), n_vals + + +def unpack_base5_codes(data: bytes, n_vals: int) -> Tensor: + packed = np.frombuffer(data, dtype=np.uint8).astype(np.int16) + out = np.zeros((packed.size, 3), dtype=np.uint8) + for i in range(3): + out[:, i] = packed % 5 + packed //= 5 + flat = out.reshape(-1)[:n_vals] + return torch.from_numpy(flat.astype(np.uint8, copy=False)) + + +def quantize_float_tensor_int5_odd(t: Tensor) -> tuple[Tensor, Tensor]: + if t.ndim != 2: + raise ValueError("int5_odd quantization only supports 2D tensors") + t32 = t.float() + clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clip_abs = clip_abs.clamp_min(1e-6) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 
2.0).clamp_min(1.0 / 2.0e6) + q = torch.clamp(torch.round(clipped / scale[:, None]), -2, 2).to(torch.int8).contiguous() + return (q + 2).to(torch.uint8).contiguous(), scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + +def quantize_state_dict_int8(state_dict: dict[str, Tensor]): + # Single supported clean-script export format: + # - per-row int8 for 2D float tensors + # - per-tensor int8 for other float tensors + # - exact passthrough for non-floats + # - passthrough for small float tensors, stored as fp16 to save bytes + quantized: dict[str, Tensor] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + qmeta: dict[str, dict[str, object]] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + + # Small float tensors are cheap enough to keep directly. We still downcast + # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size. 
+ if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + + stats["num_float_tensors"] += 1 + q, s = quantize_float_tensor(t) + if s.ndim > 0: + qmeta[name] = {"scheme": "per_row", "axis": 0} + quantized[name] = q + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) + + obj: dict[str, object] = { + "__quant_format__": "int8_clean_per_row_v1", + "quantized": quantized, + "scales": scales, + "dtypes": dtypes, + "passthrough": passthrough, + } + if qmeta: + obj["qmeta"] = qmeta + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats + + +def quantize_state_dict_int5_odd(state_dict: dict[str, Tensor]): + packed: dict[str, bytes] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + shapes: dict[str, tuple[int, ...]] = {} + n_codes: dict[str, int] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + + if t.ndim != 2 or t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + + stats["num_float_tensors"] += 1 + q, s = quantize_float_tensor_int5_odd(t) + packed_bytes, count = 
pack_base5_codes(q) + packed[name] = packed_bytes + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + shapes[name] = tuple(int(v) for v in t.shape) + n_codes[name] = count + stats["int8_payload_bytes"] += len(packed_bytes) + tensor_nbytes(s) + + obj: dict[str, object] = { + "__quant_format__": "int5_odd_clean_per_row_v1", + "packed": packed, + "scales": scales, + "dtypes": dtypes, + "shapes": shapes, + "n_codes": n_codes, + "passthrough": passthrough, + } + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats + +def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + qmeta = obj.get("qmeta", {}) + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, q in obj["quantized"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + s = obj["scales"][name] + if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: + s = s.to(dtype=torch.float32) + # Broadcast the saved row scale back across trailing dimensions. + out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() + else: + scale = float(s.item()) + out[name] = (q.float() * scale).to(dtype=dtype).contiguous() + for name, t in obj["passthrough"].items(): + # Restore small tensors, undoing the temporary fp16 storage cast if needed. 
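`pack_base5_codes` / `unpack_base5_codes` are used above but not shown in this excerpt. One plausible implementation packs three base-5 codes per byte (5**3 = 125 <= 256), i.e. about 2.67 bits per weight before zlib sees the stream. A hedged pure-Python sketch, not the actual helpers:

```python
def pack_base5(codes: list[int]) -> tuple[bytes, int]:
    # Three codes in {0..4} per byte, lowest code in the least-significant digit.
    out = bytearray()
    for i in range(0, len(codes), 3):
        val = 0
        for c in reversed(codes[i:i + 3]):
            val = val * 5 + c
        out.append(val)
    return bytes(out), len(codes)

def unpack_base5(data: bytes, n_codes: int) -> list[int]:
    # Inverse of pack_base5; n_codes disambiguates the final partial byte.
    codes: list[int] = []
    for byte in data:
        for _ in range(3):
            codes.append(byte % 5)
            byte //= 5
    return codes[:n_codes]
```

The stored code count is what makes the trailing partial group recoverable, which is why the exporter above records `n_codes` per tensor alongside the packed bytes.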
+ out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out + + +def dequantize_state_dict_int5_odd(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, packed in obj["packed"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + shape = tuple(int(v) for v in obj["shapes"][name]) + q = unpack_base5_codes(packed, int(obj["n_codes"][name])).to(torch.float32).reshape(shape) - 2.0 + s = obj["scales"][name].to(dtype=torch.float32) + out[name] = (q * s.view(shape[0], *([1] * (len(shape) - 1)))).to(dtype=dtype).contiguous() + for name, t in obj["passthrough"].items(): + out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out + + +def quantize_state_dict(state_dict: dict[str, Tensor], quant_scheme: str): + if quant_scheme == "int8": + return quantize_state_dict_int8(state_dict) + if quant_scheme == "int5_odd": + return quantize_state_dict_int5_odd(state_dict) + raise ValueError(f"Unsupported QUANT_SCHEME={quant_scheme}") + + +def dequantize_state_dict(obj: dict[str, object]) -> dict[str, Tensor]: + quant_format = obj.get("__quant_format__") + if quant_format == "int8_clean_per_row_v1": + return dequantize_state_dict_int8(obj) + if quant_format == "int5_odd_clean_per_row_v1": + return dequantize_state_dict_int5_odd(obj) + raise ValueError(f"Unsupported quant format {quant_format}") + + +def compress_blob(data: bytes, compressor: str, level: int) -> bytes: + if compressor == "zlib": + return zlib.compress(data, level=max(0, min(level, 9))) + if compressor == "zstd": + if not _HAS_ZSTD: + raise RuntimeError("COMPRESSOR=zstd requires zstandard 
to be installed") + return zstd.ZstdCompressor(level=level).compress(data) + raise ValueError(f"Unsupported COMPRESSOR={compressor}") + + +def decompress_blob(data: bytes, compressor: str) -> bytes: + if compressor == "zlib": + return zlib.decompress(data) + if compressor == "zstd": + if not _HAS_ZSTD: + raise RuntimeError("COMPRESSOR=zstd requires zstandard to be installed") + return zstd.ZstdDecompressor().decompress(data) + raise ValueError(f"Unsupported COMPRESSOR={compressor}") + + +def should_rate_regularize_tensor(name: str, p: Tensor, target: str) -> bool: + if not p.requires_grad or not p.is_floating_point() or p.ndim < 2: + return False + if p.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + return False + if any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS): + return False + if target == "mlp_only": + return ".mlp." in name + if target == "all_matrices": + return True + raise ValueError(f"Unsupported RATE_TARGET_TENSORS={target}") + + +def estimate_tensor_rate_proxy(t: Tensor, temperature: float) -> tuple[Tensor, Tensor, Tensor]: + t2 = t.float() + if t2.ndim > 2: + t2 = t2.reshape(t2.shape[0], -1) + elif t2.ndim == 1: + t2 = t2.reshape(1, -1) + row_scale = t2.abs().mean(dim=1, keepdim=True).clamp_min(1e-6) + normalized = torch.clamp(t2 / row_scale, -2.5, 2.5) + centers = t2.new_tensor(INT5_ODD_LEVELS) + logits = -temperature * (normalized.unsqueeze(-1) - centers) ** 2 + probs = torch.softmax(logits, dim=-1) + hist = probs.mean(dim=(0, 1)) + entropy_nats = -(hist * (hist + 1e-8).log()).sum() + entropy_bits = entropy_nats / math.log(2.0) + scale_proxy = row_scale.mean() + est_bits = entropy_bits * float(t2.numel()) + return entropy_bits, scale_proxy, est_bits + + +# ----------------------------- +# DATA LOADING +# ----------------------------- + +def load_data_shard(file: Path) -> Tensor: + header_bytes = 256 * np.dtype(" None: + self.file_idx = (self.file_idx + 1) % len(self.files) + self.tokens = load_data_shard(self.files[self.file_idx]) + 
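The rate proxy above targets the entropy of the code histogram because a generic compressor's output size tracks that entropy. The connection can be seen with stdlib tools alone; this uses a synthetic skewed 5-symbol stream, not the real weight codes:

```python
import collections
import math
import random
import zlib

def bits_per_symbol(data: bytes) -> float:
    # Shannon entropy of the empirical byte histogram, in bits/symbol.
    counts = collections.Counter(data)
    n = len(data)
    return -sum(c / n * math.log2(c / n) for c in counts.values())

# Skewed 5-symbol stream mimicking quantized weight codes concentrated at zero.
random.seed(0)
data = bytes(random.choices(range(5), weights=[70, 10, 10, 5, 5], k=50_000))
entropy = bits_per_symbol(data)        # ~1.5 bits/symbol for these weights
compressed = zlib.compress(data, 9)    # size approaches len(data) * entropy / 8
```

Pushing probability mass toward a single code lowers the entropy and hence the compressed artifact size, which is exactly what the `rate_lambda` term rewards during training.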
self.pos = 0 + + def take(self, n: int) -> Tensor: + chunks: list[Tensor] = [] + remaining = n + while remaining > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: + self._advance_file() + continue + k = min(remaining, avail) + chunks.append(self.tokens[self.pos : self.pos + k]) + self.pos += k + remaining -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) + + +class DistributedTokenLoader: + # Each call consumes a contiguous chunk from the shared token stream, then slices out + # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting. + def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank = rank + self.world_size = world_size + self.device = device + self.stream = TokenStream(pattern) + + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + chunk = self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start : start + per_rank_span].to(dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + +# ----------------------------- +# TRANSFORMER MODULES +# ----------------------------- + +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + + +class CastedLinear(nn.Linear): + # Keep weights in fp32 for optimizer/state quality, cast at matmul time for bf16 compute. 
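The chunk/span arithmetic in `DistributedTokenLoader.next_batch` above can be checked standalone. This sketch repeats the same formulas without tensors so the sharding math is easy to eyeball:

```python
def shard_spans(global_tokens: int, world_size: int, grad_accum_steps: int):
    # Mirrors DistributedTokenLoader.next_batch: each microstep consumes one
    # shared chunk; every rank slices a disjoint span with one extra token
    # so (x, y) can be built by shifting.
    local_tokens = global_tokens // (world_size * grad_accum_steps)
    per_rank_span = local_tokens + 1
    chunk_len = per_rank_span * world_size
    spans = [(r * per_rank_span, (r + 1) * per_rank_span) for r in range(world_size)]
    return chunk_len, spans

chunk_len, spans = shard_spans(global_tokens=131072, world_size=8, grad_accum_steps=1)
# spans tile the chunk with no overlap: rank 0 gets [0, 16385), rank 7 ends at chunk_len
```

This is why gradients are synchronized over different token ranges rather than duplicated work: the spans are pairwise disjoint by construction.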
+ def forward(self, x: Tensor) -> Tensor: + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, self.weight.to(x.dtype), bias) + + +class StructuredLinear(nn.Module): + # Lightweight TT-style structured matrix used only in the MLP stack. + def __init__( + self, + in_features: int, + out_features: int, + bias: bool, + linear_type: str, + btt_rank: int, + in_shape_spec: str, + out_shape_spec: str, + ): + super().__init__() + self.in_features = in_features + self.out_features = out_features + self.linear_type = linear_type + self.rank = max(int(btt_rank), 1) + if linear_type == "dense": + self.impl = CastedLinear(in_features, out_features, bias=bias) + self.register_parameter("core1", None) + self.register_parameter("core2", None) + return + if linear_type != "btt_mlp": + raise ValueError(f"Unsupported STRUCTURED_LINEAR_TYPE={linear_type}") + self.impl = None + self.in_shape = parse_factor_shape(in_shape_spec, in_features) + self.out_shape = parse_factor_shape(out_shape_spec, out_features) + n1, n2 = self.in_shape + m1, m2 = self.out_shape + self.core1 = nn.Parameter(torch.empty(n1, m1, self.rank)) + self.core2 = nn.Parameter(torch.empty(m1, self.rank, n2, m2)) + if bias: + self.bias = nn.Parameter(torch.zeros(out_features)) + else: + self.register_parameter("bias", None) + self.reset_parameters() + + def reset_parameters(self) -> None: + if self.impl is not None: + self.impl.reset_parameters() + return + std = 1.0 / math.sqrt(self.in_features) + nn.init.normal_(self.core1, mean=0.0, std=std / math.sqrt(self.rank)) + nn.init.normal_(self.core2, mean=0.0, std=std) + if self.bias is not None: + nn.init.zeros_(self.bias) + + def forward(self, x: Tensor) -> Tensor: + if self.impl is not None: + return self.impl(x) + x_dtype = x.dtype + orig_shape = x.shape[:-1] + x2 = x.reshape(-1, self.in_shape[0], self.in_shape[1]) + core1 = self.core1.to(dtype=x_dtype) + core2 = self.core2.to(dtype=x_dtype) + tmp = torch.einsum("buv,umr->bmrv", x2, 
core1) + y = torch.einsum("bmrv,mrwn->bnw", tmp, core2) + y = y.reshape(*orig_shape, self.out_features) + if self.bias is not None: + y = y + self.bias.to(dtype=y.dtype) + return y + + +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + # Keep small/control parameters in fp32 even when the model body runs in bf16. + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() + + +class Rotary(nn.Module): + # Caches cos/sin tables per sequence length on the current device. + def __init__(self, dim: int, base: float = 10000.0): + super().__init__() + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached != seq_len + or self._cos_cached.device != device + ): + t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) + freqs = torch.outer(t, self.inv_freq.to(device)) + self._cos_cached = freqs.cos()[None, None, :, :] + self._sin_cached = freqs.sin()[None, None, :, :] + self._seq_len_cached = seq_len + return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) + + +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + + +class CausalSelfAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + if dim % num_heads != 0: + 
raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + kv_dim = self.num_kv_heads * self.head_dim + self.c_q = CastedLinear(dim, dim, bias=False) + self.c_k = CastedLinear(dim, kv_dim, bias=False) + self.c_v = CastedLinear(dim, kv_dim, bias=False) + self.proj = CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rotary = Rotary(self.head_dim, base=rope_base) + + def forward(self, x: Tensor) -> Tensor: + bsz, seqlen, dim = x.shape + q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) + k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin) + k = apply_rotary_emb(k, cos, sin) + q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] + y = F.scaled_dot_product_attention( + q, + k, + v, + attn_mask=None, + is_causal=True, + enable_gqa=(self.num_kv_heads != self.num_heads), + ) + y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) + return self.proj(y) + + +class MLP(nn.Module): + # relu^2 MLP from the original modded-nanogpt setup + def __init__( + self, + dim: int, + mlp_mult: int, + structured_mlp: bool, + structured_linear_type: str, + btt_rank: int, + btt_in_shape: str, + btt_out_shape: str, + ): + super().__init__() + hidden = mlp_mult * dim + if structured_mlp: + self.fc = StructuredLinear( + dim, + hidden, + bias=False, + 
linear_type=structured_linear_type, + btt_rank=btt_rank, + in_shape_spec=btt_in_shape, + out_shape_spec=btt_out_shape, + ) + self.proj = StructuredLinear( + hidden, + dim, + bias=False, + linear_type=structured_linear_type, + btt_rank=btt_rank, + in_shape_spec=btt_out_shape, + out_shape_spec=btt_in_shape, + ) + else: + self.fc = CastedLinear(dim, hidden, bias=False) + self.proj = CastedLinear(hidden, dim, bias=False) + self.proj._zero_init = True + + def forward(self, x: Tensor) -> Tensor: + x = torch.relu(self.fc(x)) + return self.proj(x.square()) + + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + rope_base: float, + qk_gain_init: float, + structured_mlp: bool, + structured_linear_type: str, + btt_rank: int, + btt_in_shape: str, + btt_out_shape: str, + ): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) + self.mlp = MLP(dim, mlp_mult, structured_mlp, structured_linear_type, btt_rank, btt_in_shape, btt_out_shape) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + + def forward(self, x: Tensor, x0: Tensor) -> Tensor: + mix = self.resid_mix.to(dtype=x.dtype) + x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + attn_out = self.attn(self.attn_norm(x)) + x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out + x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) + return x + + +class GPT(nn.Module): + def __init__( + self, + vocab_size: int, + num_layers: int, + model_dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + tie_embeddings: bool, + tied_embed_init_std: float, + logit_softcap: float, + rope_base: float, + qk_gain_init: float, 
+ structured_mlp: bool, + structured_linear_type: str, + btt_rank: int, + btt_in_shape: str, + btt_out_shape: str, + rate_target_tensors: str, + rate_temperature: float, + ): + super().__init__() + if logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") + self.tie_embeddings = tie_embeddings + self.tied_embed_init_std = tied_embed_init_std + self.logit_softcap = logit_softcap + self.rate_target_tensors = rate_target_tensors + self.rate_temperature = rate_temperature + self.tok_emb = nn.Embedding(vocab_size, model_dim) + self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + self.blocks = nn.ModuleList( + [ + Block( + model_dim, + num_heads, + num_kv_heads, + mlp_mult, + rope_base, + qk_gain_init, + structured_mlp, + structured_linear_type, + btt_rank, + btt_in_shape, + btt_out_shape, + ) + for i in range(num_layers) + ] + ) + self.final_norm = RMSNorm() + self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) + if self.lm_head is not None: + self.lm_head._zero_init = True + self._init_weights() + + def _init_weights(self) -> None: + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + for module in self.modules(): + if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + if isinstance(module, StructuredLinear) and getattr(module, "_zero_init", False): + if module.impl is not None: + nn.init.zeros_(module.impl.weight) + else: + nn.init.zeros_(module.core2) + + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + x = self.tok_emb(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x0 = x + skips: list[Tensor] = [] + + # First half stores 
skips; second half reuses them in reverse order. + for i in range(self.num_encoder_layers): + x = self.blocks[i](x, x0) + skips.append(x) + for i in range(self.num_decoder_layers): + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + x = self.blocks[self.num_encoder_layers + i](x, x0) + + x = self.final_norm(x).reshape(-1, x.size(-1)) + targets = target_ids.reshape(-1) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(x) + logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + return F.cross_entropy(logits.float(), targets, reduction="mean") + + def rate_regularizer(self) -> tuple[Tensor, Tensor, Tensor]: + total_symbols = 0.0 + weighted_rate = None + weighted_scale = None + estimated_bits = None + for name, param in self.named_parameters(): + if not should_rate_regularize_tensor(name, param, self.rate_target_tensors): + continue + rate_proxy, scale_proxy, tensor_bits = estimate_tensor_rate_proxy(param, self.rate_temperature) + weight = float(param.numel()) + total_symbols += weight + if weighted_rate is None: + weighted_rate = rate_proxy * weight + weighted_scale = scale_proxy * weight + estimated_bits = tensor_bits + else: + weighted_rate = weighted_rate + rate_proxy * weight + weighted_scale = weighted_scale + scale_proxy * weight + estimated_bits = estimated_bits + tensor_bits + if weighted_rate is None or weighted_scale is None or estimated_bits is None: + zero = self.tok_emb.weight.new_zeros(()) + return zero, zero, zero + denom = max(total_symbols, 1.0) + return weighted_rate / denom, weighted_scale / denom, estimated_bits + + +# ----------------------------- +# TRAINING +# ----------------------------- + +def main() -> None: + global zeropower_via_newtonschulz5 + + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() 
+ if args.enable_compile: + zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) + + # ----------------------------- + # DISTRIBUTED + CUDA SETUP + # ----------------------------- + + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if 8 % world_size != 0: + raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") + grad_accum_steps = 8 // world_size + grad_scale = 1.0 / grad_accum_steps + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 0 + + # Fast math knobs + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + + logfile = None + if master_process: + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" + print(logfile) + + def log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch {torch.__version__}", console=False) + log0( + subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, + console=False, + ) + log0("=" * 
100, console=False) + + # ----------------------------- + # TOKENIZER + VALIDATION METRIC SETUP + # ----------------------------- + + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) + if args.val_token_limit > 0: + usable = min(((args.val_token_limit) // args.train_seq_len) * args.train_seq_len, val_tokens.numel() - 1) + if usable <= 0: + raise ValueError(f"VAL_TOKEN_LIMIT={args.val_token_limit} is too small for TRAIN_SEQ_LEN={args.train_seq_len}") + val_tokens = val_tokens[: usable + 1].contiguous() + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + if args.val_token_limit > 0: + log0(f"val_token_limit:{args.val_token_limit}") + + # ----------------------------- + # MODEL + OPTIMIZER SETUP + # ----------------------------- + + base_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, + 
tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, + structured_mlp=args.structured_mlp, + structured_linear_type=args.structured_linear_type if args.structured_mlp else "dense", + btt_rank=args.btt_rank, + btt_in_shape=args.btt_in_shape, + btt_out_shape=args.btt_out_shape, + rate_target_tensors=args.rate_target_tensors, + rate_temperature=args.rate_temperature, + ).to(device).bfloat16() + for module in base_model.modules(): + if isinstance(module, (CastedLinear, StructuredLinear)): + module.float() + restore_low_dim_params_to_fp32(base_model) + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) if args.enable_compile else base_model + model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + + # Optimizer split: + # - token embedding (Adam) uses EMBED_LR + # - untied lm_head (Adam) uses HEAD_LR + # - matrix params in transformer blocks use MATRIX_LR via Muon + # - vectors/scalars use SCALAR_LR via Adam + block_named_params = list(base_model.blocks.named_parameters()) + matrix_params = [ + p + for name, p in block_named_params + if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + scalar_params = [ + p + for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + optimizer_tok = torch.optim.Adam( + [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizer_muon = Muon( + matrix_params, + lr=args.matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + ) + for group in 
optimizer_muon.param_groups: + group["base_lr"] = args.matrix_lr + optimizer_scalar = torch.optim.Adam( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers.insert(1, optimizer_head) + + n_params = sum(p.numel() for p in base_model.parameters()) + log0(f"model_params:{n_params}") + log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") + log0(f"enable_compile:{args.enable_compile}") + log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") + log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") + log0( + f"quant_scheme:{args.quant_scheme} compressor:{args.compressor} level:{args.compressor_level} " + f"structured_mlp:{args.structured_mlp} structured_linear_type:{args.structured_linear_type}" + ) + log0( + f"rate_lambda:{args.rate_lambda} scale_lambda:{args.scale_lambda} " + f"rate_warmup_start_fraction:{args.rate_warmup_start_fraction} rate_full_fraction:{args.rate_full_fraction} " + f"rate_target_tensors:{args.rate_target_tensors}" + ) + log0( + f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " + f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " + f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" + ) + log0( + f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " + f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " + f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log0(f"seed:{args.seed}") + + # ----------------------------- + # DATA LOADER & MODEL WARMUP + # 
----------------------------- + + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + + def lr_mul(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + + def rate_schedule(step: int) -> float: + if args.rate_lambda <= 0.0 and args.scale_lambda <= 0.0: + return 0.0 + if args.iterations <= 0: + return 1.0 + frac = step / max(args.iterations, 1) + start = args.rate_warmup_start_fraction + full = max(args.rate_full_fraction, start + 1e-6) + if frac <= start: + return 0.0 + if frac >= full: + return 1.0 + return (frac - start) / (full - start) + + # Warmup primes the compiled forward/backward/optimizer paths, then we restore the + # initial weights/optimizer state so measured training starts from the true init. 
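The lambda ramp implemented by `rate_schedule` above is a piecewise-linear function of training progress. Restated standalone (same logic, renamed for clarity):

```python
def rate_ramp(step: int, iterations: int, start_frac: float, full_frac: float) -> float:
    # 0 before start_frac of training, linear ramp, then 1 after full_frac.
    frac = step / max(iterations, 1)
    full = max(full_frac, start_frac + 1e-6)
    if frac <= start_frac:
        return 0.0
    if frac >= full:
        return 1.0
    return (frac - start_frac) / (full - start_frac)
```

With, say, start 0.2 and full 0.6, the rate/scale regularizers are off for the first fifth of training (letting the model fit freely), then ramp to full strength for the final 40% of steps so the weights settle onto the quantization grid before export.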
+ if args.warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(args.warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + for opt in optimizers: + opt.step() + zero_grad_all() + if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: + log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + if distributed: + model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + # ----------------------------- + # MAIN TRAINING LOOP + # ----------------------------- + + training_time_ms = 0.0 + stop_after_step: int | None = None + torch.cuda.synchronize() + t0 = time.perf_counter() + + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + log0( + 
f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" + ) + break + + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + zero_grad_all() + train_loss = torch.zeros((), device=device) + train_ce_loss = torch.zeros((), device=device) + train_rate_proxy = torch.zeros((), device=device) + train_scale_proxy = torch.zeros((), device=device) + train_est_bits = torch.zeros((), device=device) + rate_mul = rate_schedule(step) + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + ce_loss = model(x, y) + total_loss = ce_loss + if rate_mul > 0.0: + rate_proxy, scale_proxy, est_bits = base_model.rate_regularizer() + total_loss = total_loss + rate_mul * (args.rate_lambda * rate_proxy + args.scale_lambda * scale_proxy) + train_rate_proxy += rate_proxy.detach() + train_scale_proxy += scale_proxy.detach() + train_est_bits += est_bits.detach() + train_ce_loss += ce_loss.detach() + train_loss += total_loss.detach() + (total_loss * grad_scale).backward() + train_loss /= grad_accum_steps + train_ce_loss /= grad_accum_steps + train_rate_proxy /= grad_accum_steps + train_scale_proxy /= grad_accum_steps + train_est_bits /= grad_accum_steps + + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * 
args.muon_momentum
+        for group in optimizer_muon.param_groups:
+            group["momentum"] = muon_momentum
+
+        for opt in optimizers:
+            for group in opt.param_groups:
+                group["lr"] = group["base_lr"] * scale
+
+        if args.grad_clip_norm > 0:
+            torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm)
+        for opt in optimizers:
+            opt.step()
+        zero_grad_all()
+
+        step += 1
+        approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0)
+        should_log_train = (
+            args.train_log_every > 0
+            and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None)
+        )
+        if should_log_train:
+            log0(
+                f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} ce_loss:{train_ce_loss.item():.4f} "
+                f"rate_mul:{rate_mul:.3f} rate_proxy_bits:{train_rate_proxy.item():.4f} "
+                f"scale_proxy:{train_scale_proxy.item():.6f} est_model_bits:{train_est_bits.item():.0f} "
+                f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms"
+            )
+
+        # Needed to sync whether we've reached the wallclock cap.
+        reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms
+        if distributed and max_wallclock_ms is not None:
+            reached_cap_tensor = torch.tensor(int(reached_cap), device=device)
+            dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX)
+            reached_cap = bool(reached_cap_tensor.item())
+        if stop_after_step is None and reached_cap:
+            stop_after_step = step
+
+    log0(
+        f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB "
+        f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB"
+    )
+
+    # -----------------------------
+    # SERIALIZATION + ROUNDTRIP VALIDATION
+    # -----------------------------
+    # Save the raw state (useful for debugging/loading in PyTorch directly), then always produce
+    # the quantized, compressed artifact for the configured QUANT_SCHEME/COMPRESSOR
+    # (int5_odd + zlib for this submission) and validate the round-tripped weights.
+ + if master_process: + torch.save(base_model.state_dict(), "final_model.pt") + model_bytes = os.path.getsize("final_model.pt") + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model: {model_bytes} bytes") + log0(f"Code size: {code_bytes} bytes") + log0(f"Total submission size: {model_bytes + code_bytes} bytes") + + quant_obj, quant_stats = quantize_state_dict(base_model.state_dict(), args.quant_scheme) + quant_buf = io.BytesIO() + torch.save(quant_obj, quant_buf) + quant_raw = quant_buf.getvalue() + quant_blob = compress_blob(quant_raw, args.compressor, args.compressor_level) + quant_raw_bytes = len(quant_raw) + quant_filename = f"final_model.{args.quant_scheme}.{args.compressor}.ptc" + if master_process: + with open(quant_filename, "wb") as f: + f.write(quant_blob) + quant_file_bytes = os.path.getsize(quant_filename) + code_bytes = len(code.encode("utf-8")) + ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1) + log0( + f"Serialized model {args.quant_scheme}+{args.compressor}: {quant_file_bytes} bytes " + f"(payload:{quant_stats['int8_payload_bytes']} raw_torch:{quant_raw_bytes} payload_ratio:{ratio:.2f}x)" + ) + log0(f"Total submission size {args.quant_scheme}+{args.compressor}: {quant_file_bytes + code_bytes} bytes") + + if distributed: + dist.barrier() + with open(quant_filename, "rb") as f: + quant_blob_disk = f.read() + quant_state = torch.load(io.BytesIO(decompress_blob(quant_blob_disk, args.compressor)), map_location="cpu") + base_model.load_state_dict(dequantize_state_dict(quant_state), strict=True) + torch.cuda.synchronize() + t_qeval = time.perf_counter() + q_val_loss, q_val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"final_{args.quant_scheme}_{args.compressor}_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + 
f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" + ) + log0(f"final_{args.quant_scheme}_{args.compressor}_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + + if distributed: + dist.destroy_process_group() + + +if __name__ == "__main__": + main() + +==================================================================================================== +Running Python 3.10.12 (main, Mar 3 2026, 11:56:32) [GCC 11.4.0] +Running PyTorch 2.11.0+cu130 +Tue Mar 31 02:49:12 2026 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 590.48.01 Driver Version: 590.48.01 CUDA Version: 13.1 | ++-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. | +|=========================================+========================+======================| +| 0 NVIDIA GeForce RTX 3060 Off | 00000000:01:00.0 On | N/A | +| 30% 36C P8 19W / 170W | 214MiB / 12288MiB | 8% Default | +| | | N/A | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| 0 N/A N/A 975 G /usr/lib/xorg/Xorg 89MiB | +| 0 N/A N/A 4107 C python3 104MiB | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=/home/hz/Desktop/parameter-golf-main/data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:1 +val_loader:shards 
pattern=/home/hz/Desktop/parameter-golf-main/data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:1048576
+val_token_limit:1048576
+model_params:17059912
+world_size:1 grad_accum_steps:8
+enable_compile:False
+sdp_backends:cudnn=False flash=True mem_efficient=False math=False
+attention_mode:gqa num_heads:8 num_kv_heads:4
+quant_scheme:int5_odd compressor:zlib level:9 structured_mlp:False structured_linear_type:btt_mlp
+rate_lambda:0.002 scale_lambda:0.0002 rate_warmup_start_fraction:0.1 rate_full_fraction:0.4 rate_target_tensors:mlp_only
+tie_embeddings:True embed_lr:0.03 head_lr:0.0 matrix_lr:0.02 scalar_lr:0.02
+train_batch_tokens:16384 train_seq_len:256 iterations:40 warmup_steps:0 max_wallclock_seconds:600.000
+seed:1337
+step:0/40 val_loss:6.9354 val_bpb:4.1553 train_time:0ms step_avg:0.02ms
+"""
+The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good launching-off points for new participants, not SOTA configs. We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder.
+
+Hard stop: to keep these scripts readable for newcomers, `train_gpt.py` and `train_gpt_mlx.py` must never exceed 1500 lines.
+"""
+
+from __future__ import annotations
+
+import copy
+import glob
+import io
+import math
+import os
+import random
+import subprocess
+import sys
+import time
+import uuid
+import zlib
+from pathlib import Path
+
+import numpy as np
+import sentencepiece as spm
+import torch
+import torch.distributed as dist
+import torch.nn.functional as F
+from torch import Tensor, nn
+from torch.nn.parallel import DistributedDataParallel as DDP
+
+try:
+    import zstandard as zstd
+
+    _HAS_ZSTD = True
+except ImportError:
+    zstd = None
+    _HAS_ZSTD = False
+
+# -----------------------------
+# HYPERPARAMETERS
+# -----------------------------
+# Default Simple Baseline run:
+# - 9 transformer blocks at width 512
+# - 8 attention heads with 4 KV heads (GQA) and 2x MLP expansion
+# - vocab size 1024, sequence length 1024, tied embeddings
+# - 524,288 train tokens per step for 20,000 iterations with a ~10 minute cap
+
+class Hyperparameters:
+    # Data paths are shard globs produced by the existing preprocessing pipeline.
+    data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024")
+    train_files = os.path.join(data_path, "fineweb_train_*.bin")
+    val_files = os.path.join(data_path, "fineweb_val_*.bin")
+    tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model")
+    run_id = os.environ.get("RUN_ID", str(uuid.uuid4()))
+    seed = int(os.environ.get("SEED", 1337))
+
+    # Validation cadence and batch size. Validation uses the full fineweb_val split
+    # unless VAL_TOKEN_LIMIT > 0 caps the number of validation tokens.
+    val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288))
+    val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000))
+    train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 200))
+    val_token_limit = int(os.environ.get("VAL_TOKEN_LIMIT", 0))
+
+    # Training length.
+ iterations = int(os.environ.get("ITERATIONS", 20000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 1200)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + enable_compile = bool(int(os.environ.get("ENABLE_COMPILE", "1"))) + + # Model shape. + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers = int(os.environ.get("NUM_LAYERS", 9)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = int(os.environ.get("MLP_MULT", 2)) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + + # Optimizer hyperparameters. + embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.05)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.04)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.04)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.0)) + + # Export / compression. 
+ quant_scheme = os.environ.get("QUANT_SCHEME", "int5_odd") + compressor = os.environ.get("COMPRESSOR", "zlib") + compressor_level = int(os.environ.get("COMPRESSOR_LEVEL", 9)) + + # Entropy-aware rate regularization. + rate_lambda = float(os.environ.get("RATE_LAMBDA", 0.02)) + scale_lambda = float(os.environ.get("SCALE_LAMBDA", 0.0005)) + rate_temperature = float(os.environ.get("RATE_TEMPERATURE", 6.0)) + rate_warmup_start_fraction = float(os.environ.get("RATE_WARMUP_START_FRACTION", 0.2)) + rate_full_fraction = float(os.environ.get("RATE_FULL_FRACTION", 0.6)) + rate_target_tensors = os.environ.get("RATE_TARGET_TENSORS", "mlp_only") + + # Structured MLP. + structured_mlp = bool(int(os.environ.get("STRUCTURED_MLP", "0"))) + structured_linear_type = os.environ.get("STRUCTURED_LINEAR_TYPE", "btt_mlp") + btt_rank = int(os.environ.get("BTT_RANK", 32)) + btt_in_shape = os.environ.get("BTT_IN_SHAPE", "") + btt_out_shape = os.environ.get("BTT_OUT_SHAPE", "") + +# ----------------------------- +# MUON OPTIMIZER +# ----------------------------- +# +# As borrowed from modded-nanogpt +# Background on Muon: https://kellerjordan.github.io/posts/muon/ + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + # Orthogonalize a 2D update matrix with a fast Newton-Schulz iteration. + # Muon uses this to normalize matrix-shaped gradients before applying them. 
+ a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X + + +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True): + super().__init__( + params, + dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov), + ) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + + for group in self.param_groups: + params = group["params"] + if not params: + continue + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + + total_params = sum(int(p.numel()) for p in params) + updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) + + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: + g = g.add(buf, alpha=momentum) + g = zeropower_via_newtonschulz5(g, steps=backend_steps) + # Scale correction from Muon reference implementations. 
+                    g *= max(1, g.size(0) / g.size(1)) ** 0.5
+                    updates_flat[curr : curr + p.numel()] = g.reshape(-1)
+                curr += p.numel()
+
+            if distributed:
+                dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM)
+
+            curr = 0
+            for p in params:
+                g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype)
+                p.add_(g, alpha=-lr)
+                curr += p.numel()
+
+        return loss
+
+
+# -----------------------------
+# TOKENIZER-AGNOSTIC EVALUATION SETUP
+# -----------------------------
+#
+# It's common for small models to have a large fraction of their parameters in embeddings, since the 2 * d_model * d_vocab embedding entries can be gigantic.
+# Instead of locking the tokenizer, we let you bring your own and compute our validation metric as the average compression of the validation set.
+# We report BPB (bits-per-byte) instead of validation loss, so we need lookup tables that count how many bytes each token contributes.
+# Note: submissions that edit the tokenizer will be examined more carefully, since screwing this up might unjustly improve your score.
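+
+# As a minimal illustration of the metric (hypothetical helper, not used by the
+# training loop): BPB is per-token cross-entropy converted from nats to bits,
+# rescaled by the tokens-to-bytes ratio of the validation text.
+def _bpb_example(val_loss_nats: float, token_count: float, byte_count: float) -> float:
+    import math
+    bits_per_token = val_loss_nats / math.log(2.0)
+    return bits_per_token * (token_count / byte_count)
+
+# A loss of ln(2) nats is exactly 1 bit per token; if the tokenizer spends
+# 2 tokens per byte, the model still pays 2.0 bits for every byte it compresses.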
+ +def build_sentencepiece_luts( + sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device +) -> tuple[Tensor, Tensor, Tensor]: + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("▁"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) + + +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*. 
+ tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] + + +def eval_val( + args: Hyperparameters, + model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + grad_accum_steps: int, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, +) -> tuple[float, float]: + # Validation computes two metrics: + # - val_loss: token cross-entropy (natural log) + # - val_bpb: tokenizer-agnostic compression metric used by the challenge + local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) + if local_batch_tokens < args.train_seq_len: + raise ValueError( + "VAL_BATCH_SIZE must provide at least one sequence per rank; " + f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " + f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}" + ) + local_batch_seqs = local_batch_tokens // args.train_seq_len + total_seqs = (val_tokens.numel() - 1) // args.train_seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + + model.eval() + with torch.inference_mode(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * args.train_seq_len + raw_end = batch_seq_end * args.train_seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, args.train_seq_len) + y = local[1:].reshape(-1, args.train_seq_len) + with 
torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True):
+                batch_loss = model(x, y).detach()
+            batch_token_count = float(y.numel())
+            val_loss_sum += batch_loss.to(torch.float64) * batch_token_count
+            val_token_count += batch_token_count
+            prev_ids = x.reshape(-1)
+            tgt_ids = y.reshape(-1)
+            token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16)
+            token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16)
+            val_byte_count += token_bytes.to(torch.float64).sum()
+
+    if dist.is_available() and dist.is_initialized():
+        dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM)
+        dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM)
+        dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM)
+
+    val_loss = val_loss_sum / val_token_count
+    bits_per_token = val_loss.item() / math.log(2.0)
+    tokens_per_byte = val_token_count.item() / val_byte_count.item()
+    model.train()
+    return float(val_loss.item()), float(bits_per_token * tokens_per_byte)
+
+# -----------------------------
+# POST-TRAINING QUANTIZATION
+# -----------------------------
+#
+# It's silly to export our model, which is trained in bf16 and fp32, at that same precision.
+# Instead, we get approximately the same model (with a small quality hit) by quantizing the
+# weights to int8 or int5_odd and zlib-compressing the result.
+# We can then decompress the model and run it at higher precision for evaluation, after
+# coming in under the size limit.
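+
+# A minimal sketch of the idea (hypothetical helper, deliberately simplified:
+# plain per-row symmetric int8 with no percentile clipping, unlike the exporter
+# below). Returns the worst-case reconstruction error, which stays within half
+# of the largest row scale.
+def _int8_roundtrip_example() -> float:
+    import numpy as np
+    w = np.array([[0.5, -1.0, 0.25], [2.0, 0.0, -2.0]], dtype=np.float32)
+    scale = np.abs(w).max(axis=1, keepdims=True) / 127.0   # one scale per row
+    q = np.clip(np.round(w / scale), -127, 127).astype(np.int8)
+    w_hat = q.astype(np.float32) * scale                   # dequantized weights
+    return float(np.abs(w_hat - w).max())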
+ +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights", + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 +INT5_ODD_LEVELS = (-2, -1, 0, 1, 2) + +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) + + +def parse_factor_shape(spec: str, dim: int) -> tuple[int, int]: + if spec: + vals = tuple(int(v.strip()) for v in spec.split(",") if v.strip()) + if len(vals) != 2 or vals[0] * vals[1] != dim: + raise ValueError(f"Invalid factor shape '{spec}' for dim={dim}") + return vals + for a in range(int(math.isqrt(dim)), 0, -1): + if dim % a == 0: + return a, dim // a + return 1, dim + +def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: + if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): + return t.float().contiguous() + if t.dtype in {torch.float32, torch.bfloat16}: + passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") + return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() + return t + +def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + # Matrices get one scale per row, which usually tracks output-channel + # ranges much better than a single tensor-wide scale. 
+ clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) + q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() + return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + + # Vectors / scalars use a simpler per-tensor scale. + clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() + return q, scale + + +def pack_base5_codes(q: Tensor) -> tuple[bytes, int]: + vals = q.detach().to("cpu", dtype=torch.uint8).reshape(-1).numpy() + n_vals = int(vals.size) + pad = (3 - (n_vals % 3)) % 3 + if pad: + vals = np.pad(vals, (0, pad)) + groups = vals.reshape(-1, 3).astype(np.uint8) + packed = groups[:, 0] + 5 * groups[:, 1] + 25 * groups[:, 2] + return packed.tobytes(), n_vals + + +def unpack_base5_codes(data: bytes, n_vals: int) -> Tensor: + packed = np.frombuffer(data, dtype=np.uint8).astype(np.int16) + out = np.zeros((packed.size, 3), dtype=np.uint8) + for i in range(3): + out[:, i] = packed % 5 + packed //= 5 + flat = out.reshape(-1)[:n_vals] + return torch.from_numpy(flat.astype(np.uint8, copy=False)) + + +def quantize_float_tensor_int5_odd(t: Tensor) -> tuple[Tensor, Tensor]: + if t.ndim != 2: + raise ValueError("int5_odd quantization only supports 2D tensors") + t32 = t.float() + clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clip_abs = clip_abs.clamp_min(1e-6) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 
2.0).clamp_min(1.0 / 2.0e6) + q = torch.clamp(torch.round(clipped / scale[:, None]), -2, 2).to(torch.int8).contiguous() + return (q + 2).to(torch.uint8).contiguous(), scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + +def quantize_state_dict_int8(state_dict: dict[str, Tensor]): + # Single supported clean-script export format: + # - per-row int8 for 2D float tensors + # - per-tensor int8 for other float tensors + # - exact passthrough for non-floats + # - passthrough for small float tensors, stored as fp16 to save bytes + quantized: dict[str, Tensor] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + qmeta: dict[str, dict[str, object]] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + + # Small float tensors are cheap enough to keep directly. We still downcast + # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size. 
+ if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + + stats["num_float_tensors"] += 1 + q, s = quantize_float_tensor(t) + if s.ndim > 0: + qmeta[name] = {"scheme": "per_row", "axis": 0} + quantized[name] = q + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) + + obj: dict[str, object] = { + "__quant_format__": "int8_clean_per_row_v1", + "quantized": quantized, + "scales": scales, + "dtypes": dtypes, + "passthrough": passthrough, + } + if qmeta: + obj["qmeta"] = qmeta + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats + + +def quantize_state_dict_int5_odd(state_dict: dict[str, Tensor]): + packed: dict[str, bytes] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + shapes: dict[str, tuple[int, ...]] = {} + n_codes: dict[str, int] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + + if t.ndim != 2 or t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + + stats["num_float_tensors"] += 1 + q, s = quantize_float_tensor_int5_odd(t) + packed_bytes, count = 
pack_base5_codes(q) + packed[name] = packed_bytes + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + shapes[name] = tuple(int(v) for v in t.shape) + n_codes[name] = count + stats["int8_payload_bytes"] += len(packed_bytes) + tensor_nbytes(s) + + obj: dict[str, object] = { + "__quant_format__": "int5_odd_clean_per_row_v1", + "packed": packed, + "scales": scales, + "dtypes": dtypes, + "shapes": shapes, + "n_codes": n_codes, + "passthrough": passthrough, + } + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats + +def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + qmeta = obj.get("qmeta", {}) + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, q in obj["quantized"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + s = obj["scales"][name] + if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: + s = s.to(dtype=torch.float32) + # Broadcast the saved row scale back across trailing dimensions. + out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() + else: + scale = float(s.item()) + out[name] = (q.float() * scale).to(dtype=dtype).contiguous() + for name, t in obj["passthrough"].items(): + # Restore small tensors, undoing the temporary fp16 storage cast if needed. 
+ out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out + + +def dequantize_state_dict_int5_odd(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, packed in obj["packed"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + shape = tuple(int(v) for v in obj["shapes"][name]) + q = unpack_base5_codes(packed, int(obj["n_codes"][name])).to(torch.float32).reshape(shape) - 2.0 + s = obj["scales"][name].to(dtype=torch.float32) + out[name] = (q * s.view(shape[0], *([1] * (len(shape) - 1)))).to(dtype=dtype).contiguous() + for name, t in obj["passthrough"].items(): + out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out + + +def quantize_state_dict(state_dict: dict[str, Tensor], quant_scheme: str): + if quant_scheme == "int8": + return quantize_state_dict_int8(state_dict) + if quant_scheme == "int5_odd": + return quantize_state_dict_int5_odd(state_dict) + raise ValueError(f"Unsupported QUANT_SCHEME={quant_scheme}") + + +def dequantize_state_dict(obj: dict[str, object]) -> dict[str, Tensor]: + quant_format = obj.get("__quant_format__") + if quant_format == "int8_clean_per_row_v1": + return dequantize_state_dict_int8(obj) + if quant_format == "int5_odd_clean_per_row_v1": + return dequantize_state_dict_int5_odd(obj) + raise ValueError(f"Unsupported quant format {quant_format}") + + +def compress_blob(data: bytes, compressor: str, level: int) -> bytes: + if compressor == "zlib": + return zlib.compress(data, level=max(0, min(level, 9))) + if compressor == "zstd": + if not _HAS_ZSTD: + raise RuntimeError("COMPRESSOR=zstd requires zstandard 
to be installed")
+        return zstd.ZstdDecompressor().decompress(data)
+    raise ValueError(f"Unsupported COMPRESSOR={compressor}")
+
+
+def should_rate_regularize_tensor(name: str, p: Tensor, target: str) -> bool:
+    if not p.requires_grad or not p.is_floating_point() or p.ndim < 2:
+        return False
+    if p.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL:
+        return False
+    if any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS):
+        return False
+    if target == "mlp_only":
+        return ".mlp." in name
+    if target == "all_matrices":
+        return True
+    raise ValueError(f"Unsupported RATE_TARGET_TENSORS={target}")
+
+
+def estimate_tensor_rate_proxy(t: Tensor, temperature: float) -> tuple[Tensor, Tensor, Tensor]:
+    t2 = t.float()
+    if t2.ndim > 2:
+        t2 = t2.reshape(t2.shape[0], -1)
+    elif t2.ndim == 1:
+        t2 = t2.reshape(1, -1)
+    row_scale = t2.abs().mean(dim=1, keepdim=True).clamp_min(1e-6)
+    normalized = torch.clamp(t2 / row_scale, -2.5, 2.5)
+    centers = t2.new_tensor(INT5_ODD_LEVELS)
+    logits = -temperature * (normalized.unsqueeze(-1) - centers) ** 2
+    probs = torch.softmax(logits, dim=-1)
+    hist = probs.mean(dim=(0, 1))
+    entropy_nats = -(hist * (hist + 1e-8).log()).sum()
+    entropy_bits = entropy_nats / math.log(2.0)
+    scale_proxy = row_scale.mean()
+    est_bits = entropy_bits * float(t2.numel())
+    return entropy_bits, scale_proxy, est_bits
+
+
+# -----------------------------
+# DATA LOADING
+# -----------------------------
+
+def load_data_shard(file: Path) -> Tensor:
+    # Assumed shard layout from the preprocessing pipeline: a 256-entry
+    # little-endian int32 header (token count at index 2) followed by
+    # little-endian uint16 token ids.
+    header_bytes = 256 * np.dtype("<i4").itemsize
+    with file.open("rb") as f:
+        header = np.frombuffer(f.read(header_bytes), dtype="<i4")
+        num_tokens = int(header[2])
+        tokens = np.frombuffer(f.read(num_tokens * np.dtype("<u2").itemsize), dtype="<u2")
+    if tokens.size != num_tokens:
+        raise ValueError(f"Shard {file} is truncated: expected {num_tokens} tokens, got {tokens.size}")
+    return torch.from_numpy(tokens.astype(np.int32, copy=False))
+
+
+class TokenStream:
+    # Endless stream over the sorted shard files matching `pattern`.
+    def __init__(self, pattern: str):
+        self.files = [Path(p) for p in sorted(glob.glob(pattern))]
+        if not self.files:
+            raise FileNotFoundError(f"No files found for pattern: {pattern}")
+        self.file_idx = 0
+        self.tokens = load_data_shard(self.files[self.file_idx])
+        self.pos = 0
+
+    def _advance_file(self) -> None:
+        self.file_idx = (self.file_idx + 1) % len(self.files)
+        self.tokens = load_data_shard(self.files[self.file_idx])
+        
self.pos = 0 + + def take(self, n: int) -> Tensor: + chunks: list[Tensor] = [] + remaining = n + while remaining > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: + self._advance_file() + continue + k = min(remaining, avail) + chunks.append(self.tokens[self.pos : self.pos + k]) + self.pos += k + remaining -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) + + +class DistributedTokenLoader: + # Each call consumes a contiguous chunk from the shared token stream, then slices out + # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting. + def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank = rank + self.world_size = world_size + self.device = device + self.stream = TokenStream(pattern) + + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + chunk = self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start : start + per_rank_span].to(dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + +# ----------------------------- +# TRANSFORMER MODULES +# ----------------------------- + +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + + +class CastedLinear(nn.Linear): + # Keep weights in fp32 for optimizer/state quality, cast at matmul time for bf16 compute. 
+ def forward(self, x: Tensor) -> Tensor: + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, self.weight.to(x.dtype), bias) + + +class StructuredLinear(nn.Module): + # Lightweight TT-style structured matrix used only in the MLP stack. + def __init__( + self, + in_features: int, + out_features: int, + bias: bool, + linear_type: str, + btt_rank: int, + in_shape_spec: str, + out_shape_spec: str, + ): + super().__init__() + self.in_features = in_features + self.out_features = out_features + self.linear_type = linear_type + self.rank = max(int(btt_rank), 1) + if linear_type == "dense": + self.impl = CastedLinear(in_features, out_features, bias=bias) + self.register_parameter("core1", None) + self.register_parameter("core2", None) + return + if linear_type != "btt_mlp": + raise ValueError(f"Unsupported STRUCTURED_LINEAR_TYPE={linear_type}") + self.impl = None + self.in_shape = parse_factor_shape(in_shape_spec, in_features) + self.out_shape = parse_factor_shape(out_shape_spec, out_features) + n1, n2 = self.in_shape + m1, m2 = self.out_shape + self.core1 = nn.Parameter(torch.empty(n1, m1, self.rank)) + self.core2 = nn.Parameter(torch.empty(m1, self.rank, n2, m2)) + if bias: + self.bias = nn.Parameter(torch.zeros(out_features)) + else: + self.register_parameter("bias", None) + self.reset_parameters() + + def reset_parameters(self) -> None: + if self.impl is not None: + self.impl.reset_parameters() + return + std = 1.0 / math.sqrt(self.in_features) + nn.init.normal_(self.core1, mean=0.0, std=std / math.sqrt(self.rank)) + nn.init.normal_(self.core2, mean=0.0, std=std) + if self.bias is not None: + nn.init.zeros_(self.bias) + + def forward(self, x: Tensor) -> Tensor: + if self.impl is not None: + return self.impl(x) + x_dtype = x.dtype + orig_shape = x.shape[:-1] + x2 = x.reshape(-1, self.in_shape[0], self.in_shape[1]) + core1 = self.core1.to(dtype=x_dtype) + core2 = self.core2.to(dtype=x_dtype) + tmp = torch.einsum("buv,umr->bmrv", x2, 
core1)
+        # Contract core2 against the remaining input factor (v = n2) so the output
+        # factors to (m1, m2), matching out_features = m1 * m2 after the reshape.
+        y = torch.einsum("bmrv,mrvn->bmn", tmp, core2)
+        y = y.reshape(*orig_shape, self.out_features)
+        if self.bias is not None:
+            y = y + self.bias.to(dtype=y.dtype)
+        return y
+
+
+def restore_low_dim_params_to_fp32(module: nn.Module) -> None:
+    # Keep small/control parameters in fp32 even when the model body runs in bf16.
+    with torch.no_grad():
+        for name, param in module.named_parameters():
+            if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32:
+                param.data = param.data.float()
+
+
+class Rotary(nn.Module):
+    # Caches cos/sin tables per sequence length on the current device.
+    def __init__(self, dim: int, base: float = 10000.0):
+        super().__init__()
+        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
+        self.register_buffer("inv_freq", inv_freq, persistent=False)
+        self._seq_len_cached = 0
+        self._cos_cached: Tensor | None = None
+        self._sin_cached: Tensor | None = None
+
+    def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]:
+        if (
+            self._cos_cached is None
+            or self._sin_cached is None
+            or self._seq_len_cached != seq_len
+            or self._cos_cached.device != device
+        ):
+            t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype)
+            freqs = torch.outer(t, self.inv_freq.to(device))
+            self._cos_cached = freqs.cos()[None, None, :, :]
+            self._sin_cached = freqs.sin()[None, None, :, :]
+            self._seq_len_cached = seq_len
+        cos = self._cos_cached.to(dtype=dtype)
+        sin = self._sin_cached.to(dtype=dtype)
+        if not torch.is_inference_mode_enabled():
+            cos = cos.clone()
+            sin = sin.clone()
+        return cos, sin
+
+
+def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor:
+    half = x.size(-1) // 2
+    x1, x2 = x[..., :half], x[..., half:]
+    return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1)
+
+
+class CausalSelfAttention(nn.Module):
+    def __init__(
+        self,
+        dim: int,
+        num_heads: int,
+ 
num_kv_heads: int, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + kv_dim = self.num_kv_heads * self.head_dim + self.c_q = CastedLinear(dim, dim, bias=False) + self.c_k = CastedLinear(dim, kv_dim, bias=False) + self.c_v = CastedLinear(dim, kv_dim, bias=False) + self.proj = CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rotary = Rotary(self.head_dim, base=rope_base) + + def forward(self, x: Tensor) -> Tensor: + bsz, seqlen, dim = x.shape + q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) + k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin) + k = apply_rotary_emb(k, cos, sin) + q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] + y = F.scaled_dot_product_attention( + q, + k, + v, + attn_mask=None, + is_causal=True, + enable_gqa=(self.num_kv_heads != self.num_heads), + ) + y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) + return self.proj(y) + + +class MLP(nn.Module): + # relu^2 MLP from the original modded-nanogpt setup + def __init__( + self, + dim: int, + mlp_mult: int, + structured_mlp: bool, + structured_linear_type: str, + btt_rank: int, + btt_in_shape: str, + btt_out_shape: str, + ): + super().__init__() + hidden = mlp_mult * 
dim + if structured_mlp: + self.fc = StructuredLinear( + dim, + hidden, + bias=False, + linear_type=structured_linear_type, + btt_rank=btt_rank, + in_shape_spec=btt_in_shape, + out_shape_spec=btt_out_shape, + ) + self.proj = StructuredLinear( + hidden, + dim, + bias=False, + linear_type=structured_linear_type, + btt_rank=btt_rank, + in_shape_spec=btt_out_shape, + out_shape_spec=btt_in_shape, + ) + else: + self.fc = CastedLinear(dim, hidden, bias=False) + self.proj = CastedLinear(hidden, dim, bias=False) + self.proj._zero_init = True + + def forward(self, x: Tensor) -> Tensor: + x = torch.relu(self.fc(x)) + return self.proj(x.square()) + + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + rope_base: float, + qk_gain_init: float, + structured_mlp: bool, + structured_linear_type: str, + btt_rank: int, + btt_in_shape: str, + btt_out_shape: str, + ): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) + self.mlp = MLP(dim, mlp_mult, structured_mlp, structured_linear_type, btt_rank, btt_in_shape, btt_out_shape) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + + def forward(self, x: Tensor, x0: Tensor) -> Tensor: + mix = self.resid_mix.to(dtype=x.dtype) + x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + attn_out = self.attn(self.attn_norm(x)) + x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out + x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) + return x + + +class GPT(nn.Module): + def __init__( + self, + vocab_size: int, + num_layers: int, + model_dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + tie_embeddings: bool, + 
tied_embed_init_std: float, + logit_softcap: float, + rope_base: float, + qk_gain_init: float, + structured_mlp: bool, + structured_linear_type: str, + btt_rank: int, + btt_in_shape: str, + btt_out_shape: str, + rate_target_tensors: str, + rate_temperature: float, + ): + super().__init__() + if logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") + self.tie_embeddings = tie_embeddings + self.tied_embed_init_std = tied_embed_init_std + self.logit_softcap = logit_softcap + self.rate_target_tensors = rate_target_tensors + self.rate_temperature = rate_temperature + self.tok_emb = nn.Embedding(vocab_size, model_dim) + self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + self.blocks = nn.ModuleList( + [ + Block( + model_dim, + num_heads, + num_kv_heads, + mlp_mult, + rope_base, + qk_gain_init, + structured_mlp, + structured_linear_type, + btt_rank, + btt_in_shape, + btt_out_shape, + ) + for i in range(num_layers) + ] + ) + self.final_norm = RMSNorm() + self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) + if self.lm_head is not None: + self.lm_head._zero_init = True + self._init_weights() + + def _init_weights(self) -> None: + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + for module in self.modules(): + if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + if isinstance(module, StructuredLinear) and getattr(module, "_zero_init", False): + if module.impl is not None: + nn.init.zeros_(module.impl.weight) + else: + nn.init.zeros_(module.core2) + + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + x = self.tok_emb(input_ids) + x 
= F.rms_norm(x, (x.size(-1),)) + x0 = x + skips: list[Tensor] = [] + + # First half stores skips; second half reuses them in reverse order. + for i in range(self.num_encoder_layers): + x = self.blocks[i](x, x0) + skips.append(x) + for i in range(self.num_decoder_layers): + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + x = self.blocks[self.num_encoder_layers + i](x, x0) + + x = self.final_norm(x).reshape(-1, x.size(-1)) + targets = target_ids.reshape(-1) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(x) + logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + return F.cross_entropy(logits.float(), targets, reduction="mean") + + def rate_regularizer(self) -> tuple[Tensor, Tensor, Tensor]: + total_symbols = 0.0 + weighted_rate = None + weighted_scale = None + estimated_bits = None + for name, param in self.named_parameters(): + if not should_rate_regularize_tensor(name, param, self.rate_target_tensors): + continue + rate_proxy, scale_proxy, tensor_bits = estimate_tensor_rate_proxy(param, self.rate_temperature) + weight = float(param.numel()) + total_symbols += weight + if weighted_rate is None: + weighted_rate = rate_proxy * weight + weighted_scale = scale_proxy * weight + estimated_bits = tensor_bits + else: + weighted_rate = weighted_rate + rate_proxy * weight + weighted_scale = weighted_scale + scale_proxy * weight + estimated_bits = estimated_bits + tensor_bits + if weighted_rate is None or weighted_scale is None or estimated_bits is None: + zero = self.tok_emb.weight.new_zeros(()) + return zero, zero, zero + denom = max(total_symbols, 1.0) + return weighted_rate / denom, weighted_scale / denom, estimated_bits + + +# ----------------------------- +# TRAINING +# ----------------------------- + +def main() -> None: + global 
zeropower_via_newtonschulz5 + + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + if args.enable_compile: + zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) + + # ----------------------------- + # DISTRIBUTED + CUDA SETUP + # ----------------------------- + + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if 8 % world_size != 0: + raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") + grad_accum_steps = 8 // world_size + grad_scale = 1.0 / grad_accum_steps + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 0 + + # Fast math knobs + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + + logfile = None + if master_process: + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" + print(logfile) + + def log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch {torch.__version__}", console=False) + log0( + subprocess.run(["nvidia-smi"], 
stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, + console=False, + ) + log0("=" * 100, console=False) + + # ----------------------------- + # TOKENIZER + VALIDATION METRIC SETUP + # ----------------------------- + + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) + if args.val_token_limit > 0: + usable = min(((args.val_token_limit) // args.train_seq_len) * args.train_seq_len, val_tokens.numel() - 1) + if usable <= 0: + raise ValueError(f"VAL_TOKEN_LIMIT={args.val_token_limit} is too small for TRAIN_SEQ_LEN={args.train_seq_len}") + val_tokens = val_tokens[: usable + 1].contiguous() + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + if args.val_token_limit > 0: + log0(f"val_token_limit:{args.val_token_limit}") + + # ----------------------------- + # MODEL + OPTIMIZER SETUP + # ----------------------------- + + base_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + model_dim=args.model_dim, + num_heads=args.num_heads, + 
num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, + structured_mlp=args.structured_mlp, + structured_linear_type=args.structured_linear_type if args.structured_mlp else "dense", + btt_rank=args.btt_rank, + btt_in_shape=args.btt_in_shape, + btt_out_shape=args.btt_out_shape, + rate_target_tensors=args.rate_target_tensors, + rate_temperature=args.rate_temperature, + ).to(device).bfloat16() + for module in base_model.modules(): + if isinstance(module, (CastedLinear, StructuredLinear)): + module.float() + restore_low_dim_params_to_fp32(base_model) + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) if args.enable_compile else base_model + model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + + # Optimizer split: + # - token embedding (Adam) uses EMBED_LR + # - untied lm_head (Adam) uses HEAD_LR + # - matrix params in transformer blocks use MATRIX_LR via Muon + # - vectors/scalars use SCALAR_LR via Adam + block_named_params = list(base_model.blocks.named_parameters()) + matrix_params = [ + p + for name, p in block_named_params + if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + scalar_params = [ + p + for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + optimizer_tok = torch.optim.Adam( + [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizer_muon = Muon( + matrix_params, + lr=args.matrix_lr, + 
momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + ) + for group in optimizer_muon.param_groups: + group["base_lr"] = args.matrix_lr + optimizer_scalar = torch.optim.Adam( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers.insert(1, optimizer_head) + + n_params = sum(p.numel() for p in base_model.parameters()) + log0(f"model_params:{n_params}") + log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") + log0(f"enable_compile:{args.enable_compile}") + log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") + log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") + log0( + f"quant_scheme:{args.quant_scheme} compressor:{args.compressor} level:{args.compressor_level} " + f"structured_mlp:{args.structured_mlp} structured_linear_type:{args.structured_linear_type}" + ) + log0( + f"rate_lambda:{args.rate_lambda} scale_lambda:{args.scale_lambda} " + f"rate_warmup_start_fraction:{args.rate_warmup_start_fraction} rate_full_fraction:{args.rate_full_fraction} " + f"rate_target_tensors:{args.rate_target_tensors}" + ) + log0( + f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " + f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " + f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" + ) + log0( + f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " + f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " + f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + 
log0(f"seed:{args.seed}") + + # ----------------------------- + # DATA LOADER & MODEL WARMUP + # ----------------------------- + + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + + def lr_mul(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + + def rate_schedule(step: int) -> float: + if args.rate_lambda <= 0.0 and args.scale_lambda <= 0.0: + return 0.0 + if args.iterations <= 0: + return 1.0 + frac = step / max(args.iterations, 1) + start = args.rate_warmup_start_fraction + full = max(args.rate_full_fraction, start + 1e-6) + if frac <= start: + return 0.0 + if frac >= full: + return 1.0 + return (frac - start) / (full - start) + + # Warmup primes the compiled forward/backward/optimizer paths, then we restore the + # initial weights/optimizer state so measured training starts from the true init. 
+ if args.warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(args.warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + for opt in optimizers: + opt.step() + zero_grad_all() + if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: + log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + if distributed: + model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + # ----------------------------- + # MAIN TRAINING LOOP + # ----------------------------- + + training_time_ms = 0.0 + stop_after_step: int | None = None + torch.cuda.synchronize() + t0 = time.perf_counter() + + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + log0( + 
f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" + ) + break + + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + zero_grad_all() + train_loss = torch.zeros((), device=device) + train_ce_loss = torch.zeros((), device=device) + train_rate_proxy = torch.zeros((), device=device) + train_scale_proxy = torch.zeros((), device=device) + train_est_bits = torch.zeros((), device=device) + rate_mul = rate_schedule(step) + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + ce_loss = model(x, y) + total_loss = ce_loss + if rate_mul > 0.0: + rate_proxy, scale_proxy, est_bits = base_model.rate_regularizer() + total_loss = total_loss + rate_mul * (args.rate_lambda * rate_proxy + args.scale_lambda * scale_proxy) + train_rate_proxy += rate_proxy.detach() + train_scale_proxy += scale_proxy.detach() + train_est_bits += est_bits.detach() + train_ce_loss += ce_loss.detach() + train_loss += total_loss.detach() + (total_loss * grad_scale).backward() + train_loss /= grad_accum_steps + train_ce_loss /= grad_accum_steps + train_rate_proxy /= grad_accum_steps + train_scale_proxy /= grad_accum_steps + train_est_bits /= grad_accum_steps + + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * 
args.muon_momentum + for group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + for opt in optimizers: + opt.step() + zero_grad_all() + + step += 1 + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + should_log_train = ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + log0( + f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} ce_loss:{train_ce_loss.item():.4f} " + f"rate_mul:{rate_mul:.3f} rate_proxy_bits:{train_rate_proxy.item():.4f} " + f"scale_proxy:{train_scale_proxy.item():.6f} est_model_bits:{train_est_bits.item():.0f} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" + ) + + # Needed to sync whether we've reached the wallclock cap. + reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + + log0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + ) + + # ----------------------------- + # SERIALIZATION + ROUNDTRIP VALIDATION + # ----------------------------- + # Save the raw state (useful for debugging/loading in PyTorch directly), then always produce + # the compressed int8+zlib artifact and validate the round-tripped weights. 
+ + if master_process: + torch.save(base_model.state_dict(), "final_model.pt") + model_bytes = os.path.getsize("final_model.pt") + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model: {model_bytes} bytes") + log0(f"Code size: {code_bytes} bytes") + log0(f"Total submission size: {model_bytes + code_bytes} bytes") + + quant_obj, quant_stats = quantize_state_dict(base_model.state_dict(), args.quant_scheme) + quant_buf = io.BytesIO() + torch.save(quant_obj, quant_buf) + quant_raw = quant_buf.getvalue() + quant_blob = compress_blob(quant_raw, args.compressor, args.compressor_level) + quant_raw_bytes = len(quant_raw) + quant_filename = f"final_model.{args.quant_scheme}.{args.compressor}.ptc" + if master_process: + with open(quant_filename, "wb") as f: + f.write(quant_blob) + quant_file_bytes = os.path.getsize(quant_filename) + code_bytes = len(code.encode("utf-8")) + ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1) + log0( + f"Serialized model {args.quant_scheme}+{args.compressor}: {quant_file_bytes} bytes " + f"(payload:{quant_stats['int8_payload_bytes']} raw_torch:{quant_raw_bytes} payload_ratio:{ratio:.2f}x)" + ) + log0(f"Total submission size {args.quant_scheme}+{args.compressor}: {quant_file_bytes + code_bytes} bytes") + + if distributed: + dist.barrier() + with open(quant_filename, "rb") as f: + quant_blob_disk = f.read() + quant_state = torch.load(io.BytesIO(decompress_blob(quant_blob_disk, args.compressor)), map_location="cpu") + base_model.load_state_dict(dequantize_state_dict(quant_state), strict=True) + torch.cuda.synchronize() + t_qeval = time.perf_counter() + q_val_loss, q_val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"final_{args.quant_scheme}_{args.compressor}_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + 
f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" + ) + log0(f"final_{args.quant_scheme}_{args.compressor}_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + + if distributed: + dist.destroy_process_group() + + +if __name__ == "__main__": + main() + +==================================================================================================== +Running Python 3.10.12 (main, Mar 3 2026, 11:56:32) [GCC 11.4.0] +Running PyTorch 2.11.0+cu130 +Tue Mar 31 02:49:53 2026 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 590.48.01 Driver Version: 590.48.01 CUDA Version: 13.1 | ++-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. | +|=========================================+========================+======================| +| 0 NVIDIA GeForce RTX 3060 Off | 00000000:01:00.0 On | N/A | +| 30% 34C P5 20W / 170W | 214MiB / 12288MiB | 36% Default | +| | | N/A | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| 0 N/A N/A 975 G /usr/lib/xorg/Xorg 89MiB | +| 0 N/A N/A 4130 C python3 104MiB | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=/home/hz/Desktop/parameter-golf-main/data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:1 +val_loader:shards 
pattern=/home/hz/Desktop/parameter-golf-main/data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:1048576 +val_token_limit:1048576 +model_params:17059912 +world_size:1 grad_accum_steps:8 +enable_compile:False +sdp_backends:cudnn=False flash=True mem_efficient=False math=False +attention_mode:gqa num_heads:8 num_kv_heads:4 +quant_scheme:int5_odd compressor:zlib level:9 structured_mlp:False structured_linear_type:btt_mlp +rate_lambda:0.002 scale_lambda:0.0002 rate_warmup_start_fraction:0.1 rate_full_fraction:0.4 rate_target_tensors:mlp_only +tie_embeddings:True embed_lr:0.03 head_lr:0.0 matrix_lr:0.02 scalar_lr:0.02 +train_batch_tokens:16384 train_seq_len:256 iterations:40 warmup_steps:0 max_wallclock_seconds:600.000 +seed:1337 +step:0/40 val_loss:6.9354 val_bpb:4.1553 train_time:0ms step_avg:0.02ms +step:1/40 train_loss:6.9416 ce_loss:6.9416 rate_mul:0.000 rate_proxy_bits:0.0000 scale_proxy:0.000000 est_model_bits:0 train_time:528ms step_avg:527.52ms +step:2/40 train_loss:12.9378 ce_loss:12.9378 rate_mul:0.000 rate_proxy_bits:0.0000 scale_proxy:0.000000 est_model_bits:0 train_time:944ms step_avg:472.13ms +step:3/40 train_loss:9.3795 ce_loss:9.3795 rate_mul:0.000 rate_proxy_bits:0.0000 scale_proxy:0.000000 est_model_bits:0 train_time:1358ms step_avg:452.77ms +step:4/40 train_loss:7.0969 ce_loss:7.0969 rate_mul:0.000 rate_proxy_bits:0.0000 scale_proxy:0.000000 est_model_bits:0 train_time:1766ms step_avg:441.44ms +step:5/40 train_loss:6.7238 ce_loss:6.7238 rate_mul:0.000 rate_proxy_bits:0.0000 scale_proxy:0.000000 est_model_bits:0 train_time:2178ms step_avg:435.64ms +step:6/40 train_loss:6.7849 ce_loss:6.7845 rate_mul:0.083 rate_proxy_bits:2.2322 scale_proxy:0.012163 est_model_bits:21065706 train_time:2855ms step_avg:475.79ms +step:7/40 train_loss:6.6324 ce_loss:6.6317 rate_mul:0.167 rate_proxy_bits:2.2312 scale_proxy:0.012415 est_model_bits:21055872 train_time:3475ms step_avg:496.49ms +step:8/40 train_loss:6.5181 ce_loss:6.5170 rate_mul:0.250 
rate_proxy_bits:2.2301 scale_proxy:0.012661 est_model_bits:21045870 train_time:4093ms step_avg:511.59ms +step:9/40 train_loss:6.3127 ce_loss:6.3113 rate_mul:0.333 rate_proxy_bits:2.2291 scale_proxy:0.012896 est_model_bits:21036376 train_time:4709ms step_avg:523.24ms +step:10/40 train_loss:6.2998 ce_loss:6.2979 rate_mul:0.417 rate_proxy_bits:2.2281 scale_proxy:0.013118 est_model_bits:21026974 train_time:5334ms step_avg:533.37ms +step:20/40 train_loss:5.3163 ce_loss:5.3118 rate_mul:1.000 rate_proxy_bits:2.2187 scale_proxy:0.014803 est_model_bits:20937854 train_time:11510ms step_avg:575.51ms +step:30/40 train_loss:4.8916 ce_loss:4.8872 rate_mul:1.000 rate_proxy_bits:2.2128 scale_proxy:0.015605 est_model_bits:20882642 train_time:17688ms step_avg:589.61ms +step:40/40 train_loss:4.5637 ce_loss:4.5593 rate_mul:1.000 rate_proxy_bits:2.2093 scale_proxy:0.016088 est_model_bits:20849622 train_time:23804ms step_avg:595.09ms +step:40/40 val_loss:4.5863 val_bpb:2.7479 train_time:23804ms step_avg:595.10ms +peak memory allocated: 1298 MiB reserved: 2002 MiB +Serialized model: 67224983 bytes +Code size: 65509 bytes +Total submission size: 67290492 bytes +Serialized model int5_odd+zlib: 3781677 bytes (payload:5819363 raw_torch:5850255 payload_ratio:11.55x) +Total submission size int5_odd+zlib: 3847186 bytes +final_int5_odd_zlib_roundtrip val_loss:5.2035 val_bpb:3.1177 eval_time:4547ms +final_int5_odd_zlib_roundtrip_exact val_loss:5.20348889 val_bpb:3.11765336 diff --git a/records/track_non_record_16mb/2026-03-31_EntropyAware_Int5Odd_BTTMLP_3060/logs/R1024_L20.txt b/records/track_non_record_16mb/2026-03-31_EntropyAware_Int5Odd_BTTMLP_3060/logs/R1024_L20.txt new file mode 100644 index 000000000..b5e4a7c83 --- /dev/null +++ b/records/track_non_record_16mb/2026-03-31_EntropyAware_Int5Odd_BTTMLP_3060/logs/R1024_L20.txt @@ -0,0 +1,3380 @@ +""" +The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good launching-off points for new participants, not SOTA configs. 
We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder.
+
+Hard stop: to keep these scripts readable for newcomers, let's make sure `train_gpt.py` and `train_gpt_mlx.py` never exceed 1500 lines.
+"""
+
+from __future__ import annotations
+
+import copy
+import glob
+import io
+import math
+import os
+import random
+import subprocess
+import sys
+import time
+import uuid
+import zlib
+from pathlib import Path
+
+import numpy as np
+import sentencepiece as spm
+import torch
+import torch.distributed as dist
+import torch.nn.functional as F
+from torch import Tensor, nn
+from torch.nn.parallel import DistributedDataParallel as DDP
+
+try:
+    import zstandard as zstd
+
+    _HAS_ZSTD = True
+except ImportError:
+    zstd = None
+    _HAS_ZSTD = False
+
+REPO_ROOT = Path(__file__).resolve().parents[3]
+
+# -----------------------------
+# HYPERPARAMETERS
+# -----------------------------
+# Default configuration for this run (all overridable via environment variables):
+# - 9 transformer blocks at width 512
+# - 8 attention heads with 4 KV heads (GQA) and 2x MLP expansion
+# - vocab size 1024, sequence length 256, tied embeddings
+# - 16,384 train tokens per step for 40 iterations with a ~10 minute wallclock cap
+
+class Hyperparameters:
+    # Data paths are shard globs produced by the existing preprocessing pipeline.
+    data_path = os.environ.get("DATA_PATH", str(REPO_ROOT / "data/datasets/fineweb10B_sp1024"))
+    train_files = os.path.join(data_path, "fineweb_train_*.bin")
+    val_files = os.path.join(data_path, "fineweb_val_*.bin")
+    tokenizer_path = os.environ.get("TOKENIZER_PATH", str(REPO_ROOT / "data/tokenizers/fineweb_1024_bpe.model"))
+    run_id = os.environ.get("RUN_ID", "train")
+    seed = int(os.environ.get("SEED", 1337))
+
+    # Validation cadence and batch size. Validation uses the fineweb_val split, optionally capped by VAL_TOKEN_LIMIT.
+ val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 262_144)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 10)) + val_token_limit = int(os.environ.get("VAL_TOKEN_LIMIT", 1_048_576)) + + # Training length. + iterations = int(os.environ.get("ITERATIONS", 40)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 1200)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 0)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 16_384)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 256)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + enable_compile = bool(int(os.environ.get("ENABLE_COMPILE", "0"))) + + # Model shape. + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers = int(os.environ.get("NUM_LAYERS", 9)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = int(os.environ.get("MLP_MULT", 2)) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + + # Optimizer hyperparameters. 
+ embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.03)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.02)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.02)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.0)) + + # Export / compression. + quant_scheme = os.environ.get("QUANT_SCHEME", "int5_odd") + compressor = os.environ.get("COMPRESSOR", "zlib") + compressor_level = int(os.environ.get("COMPRESSOR_LEVEL", 9)) + + # Entropy-aware rate regularization. + rate_lambda = float(os.environ.get("RATE_LAMBDA", 0.002)) + scale_lambda = float(os.environ.get("SCALE_LAMBDA", 0.0002)) + rate_temperature = float(os.environ.get("RATE_TEMPERATURE", 6.0)) + rate_warmup_start_fraction = float(os.environ.get("RATE_WARMUP_START_FRACTION", 0.1)) + rate_full_fraction = float(os.environ.get("RATE_FULL_FRACTION", 0.4)) + rate_target_tensors = os.environ.get("RATE_TARGET_TENSORS", "mlp_only") + + # Structured MLP. 
+ structured_mlp = bool(int(os.environ.get("STRUCTURED_MLP", "1"))) + structured_linear_type = os.environ.get("STRUCTURED_LINEAR_TYPE", "btt_mlp") + btt_rank = int(os.environ.get("BTT_RANK", 32)) + btt_in_shape = os.environ.get("BTT_IN_SHAPE", "") + btt_out_shape = os.environ.get("BTT_OUT_SHAPE", "") + btt_init = os.environ.get("BTT_INIT", "mup") + btt_init_gain = float(os.environ.get("BTT_INIT_GAIN", 1.0)) + compile_structured_mlp = bool(int(os.environ.get("COMPILE_STRUCTURED_MLP", "1"))) + +# ----------------------------- +# MUON OPTIMIZER +# ----------------------------- +# +# As borrowed from modded-nanogpt +# Background on Muon: https://kellerjordan.github.io/posts/muon/ + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + # Orthogonalize a 2D update matrix with a fast Newton-Schulz iteration. + # Muon uses this to normalize matrix-shaped gradients before applying them. + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X + + +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True): + super().__init__( + params, + dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov), + ) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + + for group in self.param_groups: + params = group["params"] + if not params: + continue + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + + total_params = 
sum(int(p.numel()) for p in params)
+            updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16)
+
+            curr = 0
+            for i, p in enumerate(params):
+                if i % world_size == rank and p.grad is not None:
+                    g = p.grad
+                    state = self.state[p]
+                    if "momentum_buffer" not in state:
+                        state["momentum_buffer"] = torch.zeros_like(g)
+                    buf = state["momentum_buffer"]
+                    buf.mul_(momentum).add_(g)
+                    if nesterov:
+                        g = g.add(buf, alpha=momentum)
+                    g = zeropower_via_newtonschulz5(g, steps=backend_steps)
+                    # Scale correction from Muon reference implementations.
+                    g *= max(1, g.size(0) / g.size(1)) ** 0.5
+                    updates_flat[curr : curr + p.numel()] = g.reshape(-1)
+                curr += p.numel()
+
+            if distributed:
+                dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM)
+
+            curr = 0
+            for p in params:
+                g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype)
+                p.add_(g, alpha=-lr)
+                curr += p.numel()
+
+        return loss
+
+
+# -----------------------------
+# TOKENIZER-AGNOSTIC EVALUATION SETUP
+# -----------------------------
+#
+# It's common for small models to have a large fraction of their parameters be embeddings, since the 2 * d_model * d_vocab vectors can be gigantic.
+# Instead of locking the tokenizer, we let you bring your own and calculate our validation metrics on the average compression of the validation set.
+# We calculate BPB (bits-per-byte) instead of validation loss, so we need methods to count the number of bytes per token in the tokenizer.
+# Note: Submissions that edit the tokenizer will be examined more carefully, since screwing this up might unjustly improve your score.
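As a sketch of how the byte counts feed the metric: `val_bpb` multiplies bits-per-token (cross-entropy in nats converted to bits) by tokens-per-byte. A minimal standalone illustration of that formula (the helper name is ours, not part of the script):

```python
import math

# Hypothetical helper illustrating the BPB formula assembled in eval_val:
# bits/byte = (val_loss / ln 2) * (total_tokens / total_bytes).
def bits_per_byte(val_loss_nats: float, total_tokens: float, total_bytes: float) -> float:
    bits_per_token = val_loss_nats / math.log(2.0)  # nats -> bits
    tokens_per_byte = total_tokens / total_bytes    # tokenizer compression ratio
    return bits_per_token * tokens_per_byte

# A model at exactly ln(2) nats of loss, on a tokenizer averaging
# one token per byte, compresses to 1.0 bit per byte.
print(bits_per_byte(math.log(2.0), 1000, 1000))  # -> 1.0
```

Because the two factors multiply, a tokenizer that halves the token count (more bytes per token) can offset a proportionally higher per-token loss, which is why the metric is tokenizer-agnostic.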
+ +def build_sentencepiece_luts( + sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device +) -> tuple[Tensor, Tensor, Tensor]: + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("▁"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) + + +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*. 
+ tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] + + +def eval_val( + args: Hyperparameters, + model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + grad_accum_steps: int, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, +) -> tuple[float, float]: + # Validation computes two metrics: + # - val_loss: token cross-entropy (natural log) + # - val_bpb: tokenizer-agnostic compression metric used by the challenge + local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) + if local_batch_tokens < args.train_seq_len: + raise ValueError( + "VAL_BATCH_SIZE must provide at least one sequence per rank; " + f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " + f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}" + ) + local_batch_seqs = local_batch_tokens // args.train_seq_len + total_seqs = (val_tokens.numel() - 1) // args.train_seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + + model.eval() + with torch.inference_mode(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * args.train_seq_len + raw_end = batch_seq_end * args.train_seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, args.train_seq_len) + y = local[1:].reshape(-1, args.train_seq_len) + 
maybe_mark_compile_step_begin(args.enable_compile or args.compile_structured_mlp)
+                with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True):
+                    batch_loss = model(x, y).detach()
+                batch_token_count = float(y.numel())
+                val_loss_sum += batch_loss.to(torch.float64) * batch_token_count
+                val_token_count += batch_token_count
+                prev_ids = x.reshape(-1)
+                tgt_ids = y.reshape(-1)
+                token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16)
+                token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16)
+                val_byte_count += token_bytes.to(torch.float64).sum()
+
+    if dist.is_available() and dist.is_initialized():
+        dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM)
+        dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM)
+        dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM)
+
+    val_loss = val_loss_sum / val_token_count
+    bits_per_token = val_loss.item() / math.log(2.0)
+    tokens_per_byte = val_token_count.item() / val_byte_count.item()
+    model.train()
+    return float(val_loss.item()), float(bits_per_token * tokens_per_byte)
+
+# -----------------------------
+# POST-TRAINING QUANTIZATION
+# -----------------------------
+#
+# It's silly to export our model, which is trained in bf16 and fp32, at that same precision.
+# Instead, we get approximately the same model (with a small quality hit) by quantizing it to int8 and zlib-compressing the result.
+# We can then decompress the model and run it at higher precision for evaluation, after coming in under the size limit.
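The scheme the banner describes can be sketched end to end in a few lines. This is a simplified stand-in (plain max-abs per-row scaling, fp32 scales, no percentile clipping), not the script's exact `quantize_float_tensor`:

```python
import zlib
import numpy as np

# Simplified per-row symmetric int8 quantization followed by
# zlib compression of the int8 payload.
def quantize_per_row(w: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
    # One scale per output row; clamp to avoid dividing by zero on all-zero rows.
    scale = np.maximum(np.abs(w).max(axis=1, keepdims=True) / 127.0, 1e-8)
    q = np.clip(np.round(w / scale), -127, 127).astype(np.int8)
    return q, scale

rng = np.random.default_rng(0)
w = rng.standard_normal((256, 512)).astype(np.float32)
q, scale = quantize_per_row(w)
w_hat = q.astype(np.float32) * scale          # dequantize for evaluation
blob = zlib.compress(q.tobytes(), level=9)    # export payload
# Rounding error is bounded by half a quantization step per row.
assert np.all(np.abs(w - w_hat) <= scale / 2 + 1e-6)
```

Per-row scales are what make the small hit small: a single tensor-wide scale would let one outlier row blow up the step size for every other row.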
+ +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights", + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 +INT5_ODD_LEVELS = (-2, -1, 0, 1, 2) + +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) + + +def maybe_mark_compile_step_begin(enabled: bool) -> None: + if not enabled: + return + mark_step = getattr(torch.compiler, "cudagraph_mark_step_begin", None) + if mark_step is not None: + mark_step() + + +def parse_factor_shape(spec: str, dim: int) -> tuple[int, int]: + if spec: + vals = tuple(int(v.strip()) for v in spec.split(",") if v.strip()) + if len(vals) != 2 or vals[0] * vals[1] != dim: + raise ValueError(f"Invalid factor shape '{spec}' for dim={dim}") + return vals + for a in range(int(math.isqrt(dim)), 0, -1): + if dim % a == 0: + return a, dim // a + return 1, dim + +def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: + if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): + return t.float().contiguous() + if t.dtype in {torch.float32, torch.bfloat16}: + passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") + return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() + return t + +def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + # Matrices get one scale per row, which usually tracks output-channel + # ranges much better than a single tensor-wide scale. 
+ clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) + q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() + return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + + # Vectors / scalars use a simpler per-tensor scale. + clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() + return q, scale + + +def pack_base5_codes(q: Tensor) -> tuple[bytes, int]: + vals = q.detach().to("cpu", dtype=torch.uint8).reshape(-1).numpy() + n_vals = int(vals.size) + pad = (3 - (n_vals % 3)) % 3 + if pad: + vals = np.pad(vals, (0, pad)) + groups = vals.reshape(-1, 3).astype(np.uint8) + packed = groups[:, 0] + 5 * groups[:, 1] + 25 * groups[:, 2] + return packed.tobytes(), n_vals + + +def unpack_base5_codes(data: bytes, n_vals: int) -> Tensor: + packed = np.frombuffer(data, dtype=np.uint8).astype(np.int16) + out = np.zeros((packed.size, 3), dtype=np.uint8) + for i in range(3): + out[:, i] = packed % 5 + packed //= 5 + flat = out.reshape(-1)[:n_vals] + return torch.from_numpy(flat.astype(np.uint8, copy=False)) + + +def quantize_float_tensor_int5_odd(t: Tensor) -> tuple[Tensor, Tensor]: + if t.ndim != 2: + raise ValueError("int5_odd quantization only supports 2D tensors") + t32 = t.float() + clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clip_abs = clip_abs.clamp_min(1e-6) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 
2.0).clamp_min(1.0 / 2.0e6) + q = torch.clamp(torch.round(clipped / scale[:, None]), -2, 2).to(torch.int8).contiguous() + return (q + 2).to(torch.uint8).contiguous(), scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + +def quantize_state_dict_int8(state_dict: dict[str, Tensor]): + # Single supported clean-script export format: + # - per-row int8 for 2D float tensors + # - per-tensor int8 for other float tensors + # - exact passthrough for non-floats + # - passthrough for small float tensors, stored as fp16 to save bytes + quantized: dict[str, Tensor] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + qmeta: dict[str, dict[str, object]] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + + # Small float tensors are cheap enough to keep directly. We still downcast + # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size. 
+ if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + + stats["num_float_tensors"] += 1 + q, s = quantize_float_tensor(t) + if s.ndim > 0: + qmeta[name] = {"scheme": "per_row", "axis": 0} + quantized[name] = q + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) + + obj: dict[str, object] = { + "__quant_format__": "int8_clean_per_row_v1", + "quantized": quantized, + "scales": scales, + "dtypes": dtypes, + "passthrough": passthrough, + } + if qmeta: + obj["qmeta"] = qmeta + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats + + +def quantize_state_dict_int5_odd(state_dict: dict[str, Tensor]): + packed: dict[str, bytes] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + shapes: dict[str, tuple[int, ...]] = {} + orig_shapes: dict[str, tuple[int, ...]] = {} + n_codes: dict[str, int] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + + quantize_this = t.ndim >= 2 and (t.numel() > INT8_KEEP_FLOAT_MAX_NUMEL or ".mlp." 
in name) + if not quantize_this: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + + stats["num_float_tensors"] += 1 + t2 = t.reshape(t.shape[0], -1) if t.ndim > 2 else t + q, s = quantize_float_tensor_int5_odd(t2) + packed_bytes, count = pack_base5_codes(q) + packed[name] = packed_bytes + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + shapes[name] = tuple(int(v) for v in t2.shape) + orig_shapes[name] = tuple(int(v) for v in t.shape) + n_codes[name] = count + stats["int8_payload_bytes"] += len(packed_bytes) + tensor_nbytes(s) + + obj: dict[str, object] = { + "__quant_format__": "int5_odd_clean_per_row_v1", + "packed": packed, + "scales": scales, + "dtypes": dtypes, + "shapes": shapes, + "orig_shapes": orig_shapes, + "n_codes": n_codes, + "passthrough": passthrough, + } + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats + +def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + qmeta = obj.get("qmeta", {}) + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, q in obj["quantized"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + s = obj["scales"][name] + if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: + s = s.to(dtype=torch.float32) + # Broadcast the saved row scale back across trailing dimensions. + out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() + else: + scale = float(s.item()) + out[name] = (q.float() * scale).to(dtype=dtype).contiguous() + for name, t in obj["passthrough"].items(): + # Restore small tensors, undoing the temporary fp16 storage cast if needed. 
+ out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out + + +def dequantize_state_dict_int5_odd(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, packed in obj["packed"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + shape = tuple(int(v) for v in obj["shapes"][name]) + orig_shape = tuple(int(v) for v in obj.get("orig_shapes", {}).get(name, shape)) + q = unpack_base5_codes(packed, int(obj["n_codes"][name])).to(torch.float32).reshape(shape) - 2.0 + s = obj["scales"][name].to(dtype=torch.float32) + deq = (q * s.view(shape[0], *([1] * (len(shape) - 1)))).to(dtype=dtype) + out[name] = deq.reshape(orig_shape).contiguous() + for name, t in obj["passthrough"].items(): + out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out + + +def quantize_state_dict(state_dict: dict[str, Tensor], quant_scheme: str): + if quant_scheme == "int8": + return quantize_state_dict_int8(state_dict) + if quant_scheme == "int5_odd": + return quantize_state_dict_int5_odd(state_dict) + raise ValueError(f"Unsupported QUANT_SCHEME={quant_scheme}") + + +def dequantize_state_dict(obj: dict[str, object]) -> dict[str, Tensor]: + quant_format = obj.get("__quant_format__") + if quant_format == "int8_clean_per_row_v1": + return dequantize_state_dict_int8(obj) + if quant_format == "int5_odd_clean_per_row_v1": + return dequantize_state_dict_int5_odd(obj) + raise ValueError(f"Unsupported quant format {quant_format}") + + +def compress_blob(data: bytes, compressor: str, level: int) -> bytes: + if compressor == "zlib": + return zlib.compress(data, level=max(0, 
min(level, 9))
+    if compressor == "zstd":
+        if not _HAS_ZSTD:
+            raise RuntimeError("COMPRESSOR=zstd requires zstandard to be installed")
+        return zstd.ZstdCompressor(level=level).compress(data)
+    raise ValueError(f"Unsupported COMPRESSOR={compressor}")
+
+
+def decompress_blob(data: bytes, compressor: str) -> bytes:
+    if compressor == "zlib":
+        return zlib.decompress(data)
+    if compressor == "zstd":
+        if not _HAS_ZSTD:
+            raise RuntimeError("COMPRESSOR=zstd requires zstandard to be installed")
+        return zstd.ZstdDecompressor().decompress(data)
+    raise ValueError(f"Unsupported COMPRESSOR={compressor}")
+
+
+def should_rate_regularize_tensor(name: str, p: Tensor, target: str) -> bool:
+    if not p.requires_grad or not p.is_floating_point() or p.ndim < 2:
+        return False
+    if any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS):
+        return False
+    if target == "mlp_only":
+        return ".mlp." in name
+    if target == "all_matrices":
+        return p.numel() > INT8_KEEP_FLOAT_MAX_NUMEL or ".mlp." in name
+    raise ValueError(f"Unsupported RATE_TARGET_TENSORS={target}")
+
+
+def estimate_tensor_rate_proxy(t: Tensor, temperature: float) -> tuple[Tensor, Tensor, Tensor]:
+    t2 = t.float()
+    if t2.ndim > 2:
+        t2 = t2.reshape(t2.shape[0], -1)
+    elif t2.ndim == 1:
+        t2 = t2.reshape(1, -1)
+    row_scale = t2.abs().mean(dim=1, keepdim=True).clamp_min(1e-6)
+    normalized = torch.clamp(t2 / row_scale, -2.5, 2.5)
+    centers = t2.new_tensor(INT5_ODD_LEVELS)
+    logits = -temperature * (normalized.unsqueeze(-1) - centers) ** 2
+    probs = torch.softmax(logits, dim=-1)
+    hist = probs.mean(dim=(0, 1))
+    entropy_nats = -(hist * (hist + 1e-8).log()).sum()
+    entropy_bits = entropy_nats / math.log(2.0)
+    scale_proxy = row_scale.mean()
+    est_bits = entropy_bits * float(t2.numel())
+    return entropy_bits, scale_proxy, est_bits
+
+
+# -----------------------------
+# DATA LOADING
+# -----------------------------
+
+def load_data_shard(file: Path) -> Tensor:
+    # Shard layout assumed from the preprocessing pipeline: a 256-entry
+    # little-endian int32 header whose third entry is the token count,
+    # followed by the tokens as little-endian uint16.
+    header_bytes = 256 * np.dtype("<i4").itemsize
+    with file.open("rb") as f:
+        header = np.frombuffer(f.read(header_bytes), dtype="<i4")
+        num_tokens = int(header[2])
+        tokens = np.frombuffer(f.read(num_tokens * 2), dtype="<u2")
+    return torch.from_numpy(tokens.astype(np.int32))
+
+
+class TokenStream:
+    # Streams tokens from the sorted shard files, wrapping around at the end.
+    def __init__(self, pattern: str):
+        self.files = [Path(p) for p in sorted(glob.glob(pattern))]
+        if not self.files:
+            raise FileNotFoundError(f"No files found for pattern: {pattern}")
+        self.file_idx = 0
+        self.tokens = load_data_shard(self.files[self.file_idx])
+        self.pos = 0
+
+    def _advance_file(self) -> None:
self.file_idx = (self.file_idx + 1) % len(self.files) + self.tokens = load_data_shard(self.files[self.file_idx]) + self.pos = 0 + + def take(self, n: int) -> Tensor: + chunks: list[Tensor] = [] + remaining = n + while remaining > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: + self._advance_file() + continue + k = min(remaining, avail) + chunks.append(self.tokens[self.pos : self.pos + k]) + self.pos += k + remaining -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) + + +class DistributedTokenLoader: + # Each call consumes a contiguous chunk from the shared token stream, then slices out + # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting. + def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank = rank + self.world_size = world_size + self.device = device + self.stream = TokenStream(pattern) + + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + chunk = self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start : start + per_rank_span].to(dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + +# ----------------------------- +# TRANSFORMER MODULES +# ----------------------------- + +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + + +class CastedLinear(nn.Linear): + # Keep weights in fp32 for optimizer/state quality, cast at matmul time for bf16 compute. 
+ def forward(self, x: Tensor) -> Tensor: + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, self.weight.to(x.dtype), bias) + + +class StructuredLinear(nn.Module): + # Lightweight TT-style structured matrix used only in the MLP stack. + def __init__( + self, + in_features: int, + out_features: int, + bias: bool, + linear_type: str, + btt_rank: int, + in_shape_spec: str, + out_shape_spec: str, + btt_init: str, + btt_init_gain: float, + ): + super().__init__() + self.in_features = in_features + self.out_features = out_features + self.linear_type = linear_type + self.rank = max(int(btt_rank), 1) + self.btt_init = btt_init + self.btt_init_gain = btt_init_gain + if linear_type == "dense": + self.impl = CastedLinear(in_features, out_features, bias=bias) + self.register_parameter("core1", None) + self.register_parameter("core2", None) + return + if linear_type != "btt_mlp": + raise ValueError(f"Unsupported STRUCTURED_LINEAR_TYPE={linear_type}") + self.impl = None + self.in_shape = parse_factor_shape(in_shape_spec, in_features) + self.out_shape = parse_factor_shape(out_shape_spec, out_features) + n1, n2 = self.in_shape + m1, m2 = self.out_shape + self.core1 = nn.Parameter(torch.empty(n1, m1, self.rank)) + self.core2 = nn.Parameter(torch.empty(self.rank, n2, m2)) + if bias: + self.bias = nn.Parameter(torch.zeros(out_features)) + else: + self.register_parameter("bias", None) + self.reset_parameters() + + def reset_parameters(self) -> None: + if self.impl is not None: + self.impl.reset_parameters() + return + n1, n2 = self.in_shape + if self.btt_init == "mup": + # muP-inspired scaling: keep the per-rank branch update stable while + # preserving O(1) output variance under the sum over TT ranks. 
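+            # Hypothetical example (gain and shapes chosen for illustration,
+            # not this run's config): with btt_init_gain=1.0, n1=32, n2=32,
+            # rank=4 the formulas below give core1_std = 1/sqrt(32) ~= 0.177
+            # and core2_std = 1/sqrt(4*32) ~= 0.088, so summing the rank
+            # branches keeps the output variance O(1).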
+ core1_std = self.btt_init_gain / math.sqrt(max(n1, 1)) + core2_std = self.btt_init_gain / math.sqrt(max(self.rank * n2, 1)) + elif self.btt_init == "xavier": + std = self.btt_init_gain / math.sqrt(max(self.in_features, 1)) + core1_std = std / math.sqrt(self.rank) + core2_std = std + else: + raise ValueError(f"Unsupported BTT_INIT={self.btt_init}") + nn.init.normal_(self.core1, mean=0.0, std=core1_std) + nn.init.normal_(self.core2, mean=0.0, std=core2_std) + if self.bias is not None: + nn.init.zeros_(self.bias) + + def forward(self, x: Tensor) -> Tensor: + if self.impl is not None: + return self.impl(x) + x_dtype = x.dtype + orig_shape = x.shape[:-1] + x2 = x.reshape(-1, self.in_shape[0], self.in_shape[1]) + core1 = self.core1.to(dtype=x_dtype) + core2 = self.core2.to(dtype=x_dtype) + y = x2.new_zeros((x2.shape[0], self.out_shape[0], self.out_shape[1])) + for r in range(self.rank): + left = torch.einsum("buv,um->bmv", x2, core1[:, :, r]) + y = y + torch.einsum("bmv,vn->bmn", left, core2[r]) + y = y.reshape(*orig_shape, self.out_features) + if self.bias is not None: + y = y + self.bias.to(dtype=y.dtype) + return y + + +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + # Keep small/control parameters in fp32 even when the model body runs in bf16. + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() + + +class Rotary(nn.Module): + # Caches cos/sin tables per sequence length on the current device. 
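+    # Numeric sketch (example dims, not tied to this run's config): with
+    # dim=64 and base=10000, inv_freq holds 32 frequencies from 1.0 down to
+    # 10000**(-62/64) ~= 1.3e-4, and position t rotates channel pair i by the
+    # angle t * inv_freq[i], so low-index pairs encode fine-grained offsets
+    # and high-index pairs encode long-range position.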
+ def __init__(self, dim: int, base: float = 10000.0): + super().__init__() + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached != seq_len + or self._cos_cached.device != device + ): + t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) + freqs = torch.outer(t, self.inv_freq.to(device)) + self._cos_cached = freqs.cos()[None, None, :, :] + self._sin_cached = freqs.sin()[None, None, :, :] + self._seq_len_cached = seq_len + cos = self._cos_cached.to(dtype=dtype) + sin = self._sin_cached.to(dtype=dtype) + if not torch.is_inference_mode_enabled(): + cos = cos.clone() + sin = sin.clone() + return cos, sin + + +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + + +class CausalSelfAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + kv_dim = self.num_kv_heads * self.head_dim + self.c_q = CastedLinear(dim, dim, bias=False) + self.c_k = CastedLinear(dim, kv_dim, bias=False) + self.c_v = CastedLinear(dim, kv_dim, bias=False) + self.proj 
= CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rotary = Rotary(self.head_dim, base=rope_base) + + def forward(self, x: Tensor) -> Tensor: + bsz, seqlen, dim = x.shape + q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) + k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin) + k = apply_rotary_emb(k, cos, sin) + q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] + y = F.scaled_dot_product_attention( + q, + k, + v, + attn_mask=None, + is_causal=True, + enable_gqa=(self.num_kv_heads != self.num_heads), + ) + y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) + return self.proj(y) + + +class MLP(nn.Module): + # relu^2 MLP from the original modded-nanogpt setup + def __init__( + self, + dim: int, + mlp_mult: int, + structured_mlp: bool, + structured_linear_type: str, + btt_rank: int, + btt_in_shape: str, + btt_out_shape: str, + btt_init: str, + btt_init_gain: float, + compile_structured_mlp: bool, + ): + super().__init__() + hidden = mlp_mult * dim + if structured_mlp: + self.fc = StructuredLinear( + dim, + hidden, + bias=False, + linear_type=structured_linear_type, + btt_rank=btt_rank, + in_shape_spec=btt_in_shape, + out_shape_spec=btt_out_shape, + btt_init=btt_init, + btt_init_gain=btt_init_gain, + ) + self.proj = StructuredLinear( + hidden, + dim, + bias=False, + linear_type=structured_linear_type, + btt_rank=btt_rank, + in_shape_spec=btt_out_shape, + out_shape_spec=btt_in_shape, + btt_init=btt_init, + btt_init_gain=btt_init_gain, + ) + if compile_structured_mlp: + compile_options = 
dict(torch._inductor.list_mode_options("reduce-overhead")) + compile_options["triton.cudagraphs"] = False + self.fc = torch.compile( + self.fc, + options=compile_options, + fullgraph=False, + dynamic=False, + ) + self.proj = torch.compile( + self.proj, + options=compile_options, + fullgraph=False, + dynamic=False, + ) + else: + self.fc = CastedLinear(dim, hidden, bias=False) + self.proj = CastedLinear(hidden, dim, bias=False) + self.proj._zero_init = True + + def forward(self, x: Tensor) -> Tensor: + x = torch.relu(self.fc(x)) + return self.proj(x.square()) + + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + rope_base: float, + qk_gain_init: float, + structured_mlp: bool, + structured_linear_type: str, + btt_rank: int, + btt_in_shape: str, + btt_out_shape: str, + btt_init: str, + btt_init_gain: float, + compile_structured_mlp: bool, + ): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) + self.mlp = MLP( + dim, + mlp_mult, + structured_mlp, + structured_linear_type, + btt_rank, + btt_in_shape, + btt_out_shape, + btt_init, + btt_init_gain, + compile_structured_mlp, + ) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + + def forward(self, x: Tensor, x0: Tensor) -> Tensor: + mix = self.resid_mix.to(dtype=x.dtype) + x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + attn_out = self.attn(self.attn_norm(x)) + x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out + x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) + return x + + +class GPT(nn.Module): + def __init__( + self, + vocab_size: int, + num_layers: int, + model_dim: int, + num_heads: int, + 
num_kv_heads: int, + mlp_mult: int, + tie_embeddings: bool, + tied_embed_init_std: float, + logit_softcap: float, + rope_base: float, + qk_gain_init: float, + structured_mlp: bool, + structured_linear_type: str, + btt_rank: int, + btt_in_shape: str, + btt_out_shape: str, + btt_init: str, + btt_init_gain: float, + compile_structured_mlp: bool, + rate_target_tensors: str, + rate_temperature: float, + ): + super().__init__() + if logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") + self.tie_embeddings = tie_embeddings + self.tied_embed_init_std = tied_embed_init_std + self.logit_softcap = logit_softcap + self.rate_target_tensors = rate_target_tensors + self.rate_temperature = rate_temperature + self.tok_emb = nn.Embedding(vocab_size, model_dim) + self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + self.blocks = nn.ModuleList( + [ + Block( + model_dim, + num_heads, + num_kv_heads, + mlp_mult, + rope_base, + qk_gain_init, + structured_mlp, + structured_linear_type, + btt_rank, + btt_in_shape, + btt_out_shape, + btt_init, + btt_init_gain, + compile_structured_mlp, + ) + for i in range(num_layers) + ] + ) + self.final_norm = RMSNorm() + self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) + if self.lm_head is not None: + self.lm_head._zero_init = True + self._init_weights() + + def _init_weights(self) -> None: + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + for module in self.modules(): + if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + if isinstance(module, StructuredLinear) and getattr(module, "_zero_init", False): + if module.impl is 
not None: + nn.init.zeros_(module.impl.weight) + else: + nn.init.zeros_(module.core2) + + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + x = self.tok_emb(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x0 = x + skips: list[Tensor] = [] + + # First half stores skips; second half reuses them in reverse order. + for i in range(self.num_encoder_layers): + x = self.blocks[i](x, x0) + skips.append(x) + for i in range(self.num_decoder_layers): + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + x = self.blocks[self.num_encoder_layers + i](x, x0) + + x = self.final_norm(x).reshape(-1, x.size(-1)) + targets = target_ids.reshape(-1) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(x) + logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + return F.cross_entropy(logits.float(), targets, reduction="mean") + + def rate_regularizer(self) -> tuple[Tensor, Tensor, Tensor]: + total_symbols = 0.0 + weighted_rate = None + weighted_scale = None + estimated_bits = None + for name, param in self.named_parameters(): + if not should_rate_regularize_tensor(name, param, self.rate_target_tensors): + continue + rate_proxy, scale_proxy, tensor_bits = estimate_tensor_rate_proxy(param, self.rate_temperature) + weight = float(param.numel()) + total_symbols += weight + if weighted_rate is None: + weighted_rate = rate_proxy * weight + weighted_scale = scale_proxy * weight + estimated_bits = tensor_bits + else: + weighted_rate = weighted_rate + rate_proxy * weight + weighted_scale = weighted_scale + scale_proxy * weight + estimated_bits = estimated_bits + tensor_bits + if weighted_rate is None or weighted_scale is None or estimated_bits is None: + zero = self.tok_emb.weight.new_zeros(()) + return zero, zero, zero + denom = max(total_symbols, 1.0) + return 
weighted_rate / denom, weighted_scale / denom, estimated_bits + + +# ----------------------------- +# TRAINING +# ----------------------------- + +def main() -> None: + global zeropower_via_newtonschulz5 + + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + if args.enable_compile: + zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) + + # ----------------------------- + # DISTRIBUTED + CUDA SETUP + # ----------------------------- + + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if 8 % world_size != 0: + raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") + grad_accum_steps = 8 // world_size + grad_scale = 1.0 / grad_accum_steps + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 0 + + # Fast math knobs + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + + logfile = None + if master_process: + if args.run_id == "train": + logfile = "train.log" + else: + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" + print(logfile) + + def log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + 
print(msg, file=f) + + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch {torch.__version__}", console=False) + log0( + subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, + console=False, + ) + log0("=" * 100, console=False) + + # ----------------------------- + # TOKENIZER + VALIDATION METRIC SETUP + # ----------------------------- + + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) + if args.val_token_limit > 0: + usable = min(((args.val_token_limit) // args.train_seq_len) * args.train_seq_len, val_tokens.numel() - 1) + if usable <= 0: + raise ValueError(f"VAL_TOKEN_LIMIT={args.val_token_limit} is too small for TRAIN_SEQ_LEN={args.train_seq_len}") + val_tokens = val_tokens[: usable + 1].contiguous() + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + if args.val_token_limit > 0: + log0(f"val_token_limit:{args.val_token_limit}") + + # 
----------------------------- + # MODEL + OPTIMIZER SETUP + # ----------------------------- + + base_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, + structured_mlp=args.structured_mlp, + structured_linear_type=args.structured_linear_type if args.structured_mlp else "dense", + btt_rank=args.btt_rank, + btt_in_shape=args.btt_in_shape, + btt_out_shape=args.btt_out_shape, + btt_init=args.btt_init, + btt_init_gain=args.btt_init_gain, + compile_structured_mlp=args.compile_structured_mlp, + rate_target_tensors=args.rate_target_tensors, + rate_temperature=args.rate_temperature, + ).to(device).bfloat16() + for module in base_model.modules(): + if isinstance(module, (CastedLinear, StructuredLinear)): + module.float() + restore_low_dim_params_to_fp32(base_model) + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) if args.enable_compile else base_model + model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + + # Optimizer split: + # - token embedding (Adam) uses EMBED_LR + # - untied lm_head (Adam) uses HEAD_LR + # - matrix params in transformer blocks use MATRIX_LR via Muon + # - vectors/scalars use SCALAR_LR via Adam + block_named_params = list(base_model.blocks.named_parameters()) + matrix_params = [ + p + for name, p in block_named_params + if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + scalar_params = [ + p + for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.skip_weights.numel() > 0: + 
scalar_params.append(base_model.skip_weights) + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + optimizer_tok = torch.optim.Adam( + [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizer_muon = Muon( + matrix_params, + lr=args.matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + ) + for group in optimizer_muon.param_groups: + group["base_lr"] = args.matrix_lr + optimizer_scalar = torch.optim.Adam( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers.insert(1, optimizer_head) + + n_params = sum(p.numel() for p in base_model.parameters()) + log0(f"model_params:{n_params}") + log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") + log0(f"enable_compile:{args.enable_compile}") + log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") + log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") + log0( + f"quant_scheme:{args.quant_scheme} compressor:{args.compressor} level:{args.compressor_level} " + f"structured_mlp:{args.structured_mlp} structured_linear_type:{args.structured_linear_type} " + f"btt_rank:{args.btt_rank} btt_init:{args.btt_init} compile_structured_mlp:{args.compile_structured_mlp}" + ) + log0( + f"rate_lambda:{args.rate_lambda} scale_lambda:{args.scale_lambda} " + f"rate_warmup_start_fraction:{args.rate_warmup_start_fraction} rate_full_fraction:{args.rate_full_fraction} " + 
f"rate_target_tensors:{args.rate_target_tensors}" + ) + log0( + f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " + f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " + f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" + ) + log0( + f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " + f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " + f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log0(f"seed:{args.seed}") + + # ----------------------------- + # DATA LOADER & MODEL WARMUP + # ----------------------------- + + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + + def lr_mul(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + + def rate_schedule(step: int) -> float: + if args.rate_lambda <= 0.0 and args.scale_lambda <= 0.0: + return 0.0 + if args.iterations <= 0: + return 1.0 + frac = step / max(args.iterations, 1) + start = args.rate_warmup_start_fraction + full = max(args.rate_full_fraction, start + 1e-6) + if frac <= start: + return 0.0 + if frac >= full: + return 1.0 + return (frac - start) / (full - start) + + # Warmup primes the compiled forward/backward/optimizer paths, then we restore the + # initial weights/optimizer state so measured training starts from 
the true init. + if args.warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(args.warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + maybe_mark_compile_step_begin(args.enable_compile or args.compile_structured_mlp) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + for opt in optimizers: + opt.step() + zero_grad_all() + if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: + log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + if distributed: + model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + # ----------------------------- + # MAIN TRAINING LOOP + # ----------------------------- + + training_time_ms = 0.0 + stop_after_step: int | None = None + torch.cuda.synchronize() + t0 = time.perf_counter() + + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + 
base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + log0( + f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" + ) + break + + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + zero_grad_all() + train_loss = torch.zeros((), device=device) + train_ce_loss = torch.zeros((), device=device) + train_rate_proxy = torch.zeros((), device=device) + train_scale_proxy = torch.zeros((), device=device) + train_est_bits = torch.zeros((), device=device) + rate_mul = rate_schedule(step) + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + maybe_mark_compile_step_begin(args.enable_compile or args.compile_structured_mlp) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + ce_loss = model(x, y) + total_loss = ce_loss + if rate_mul > 0.0: + rate_proxy, scale_proxy, est_bits = base_model.rate_regularizer() + total_loss = total_loss + rate_mul * (args.rate_lambda * rate_proxy + args.scale_lambda * scale_proxy) + train_rate_proxy += rate_proxy.detach() + train_scale_proxy += scale_proxy.detach() + train_est_bits += est_bits.detach() + train_ce_loss += ce_loss.detach() + train_loss += total_loss.detach() + (total_loss * grad_scale).backward() + train_loss /= grad_accum_steps + train_ce_loss /= grad_accum_steps + train_rate_proxy /= grad_accum_steps + train_scale_proxy /= grad_accum_steps + train_est_bits /= grad_accum_steps + + frac = min(step / 
args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + for opt in optimizers: + opt.step() + zero_grad_all() + + step += 1 + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + should_log_train = ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + log0( + f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} ce_loss:{train_ce_loss.item():.4f} " + f"rate_mul:{rate_mul:.3f} rate_proxy_bits:{train_rate_proxy.item():.4f} " + f"scale_proxy:{train_scale_proxy.item():.6f} est_model_bits:{train_est_bits.item():.0f} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" + ) + + # Needed to sync whether we've reached the wallclock cap. 
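+        # Example of why the MAX all-reduce matters: with a 600s cap and rank
+        # timings of [599.9s, 600.1s], the local flags are [0, 1]; reducing
+        # with MAX makes every rank see 1, so all ranks choose the same
+        # stop_after_step and no collective is left waiting on a rank that
+        # exited early.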
+ reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + + log0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + ) + + # ----------------------------- + # SERIALIZATION + ROUNDTRIP VALIDATION + # ----------------------------- + # Save the raw state (useful for debugging/loading in PyTorch directly), then always produce + # the compressed int8+zlib artifact and validate the round-tripped weights. + + if master_process: + torch.save(base_model.state_dict(), "final_model.pt") + model_bytes = os.path.getsize("final_model.pt") + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model: {model_bytes} bytes") + log0(f"Code size: {code_bytes} bytes") + log0(f"Total submission size: {model_bytes + code_bytes} bytes") + + quant_obj, quant_stats = quantize_state_dict(base_model.state_dict(), args.quant_scheme) + quant_buf = io.BytesIO() + torch.save(quant_obj, quant_buf) + quant_raw = quant_buf.getvalue() + quant_blob = compress_blob(quant_raw, args.compressor, args.compressor_level) + quant_raw_bytes = len(quant_raw) + quant_filename = f"final_model.{args.quant_scheme}.{args.compressor}.ptc" + if master_process: + with open(quant_filename, "wb") as f: + f.write(quant_blob) + quant_file_bytes = os.path.getsize(quant_filename) + code_bytes = len(code.encode("utf-8")) + ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1) + log0( + f"Serialized model {args.quant_scheme}+{args.compressor}: {quant_file_bytes} bytes " + f"(payload:{quant_stats['int8_payload_bytes']} raw_torch:{quant_raw_bytes} 
payload_ratio:{ratio:.2f}x)" + ) + log0(f"Total submission size {args.quant_scheme}+{args.compressor}: {quant_file_bytes + code_bytes} bytes") + + if distributed: + dist.barrier() + with open(quant_filename, "rb") as f: + quant_blob_disk = f.read() + quant_state = torch.load(io.BytesIO(decompress_blob(quant_blob_disk, args.compressor)), map_location="cpu") + base_model.load_state_dict(dequantize_state_dict(quant_state), strict=True) + torch.cuda.synchronize() + t_qeval = time.perf_counter() + q_val_loss, q_val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"final_{args.quant_scheme}_{args.compressor}_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" + ) + log0(f"final_{args.quant_scheme}_{args.compressor}_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + + if distributed: + dist.destroy_process_group() + + +if __name__ == "__main__": + main() + +==================================================================================================== +Running Python 3.10.12 (main, Mar 3 2026, 11:56:32) [GCC 11.4.0] +Running PyTorch 2.11.0+cu130 +Tue Mar 31 12:19:48 2026 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 590.48.01 Driver Version: 590.48.01 CUDA Version: 13.1 | ++-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. 
| +|=========================================+========================+======================| +| 0 NVIDIA GeForce RTX 3060 Off | 00000000:01:00.0 On | N/A | +| 30% 34C P8 19W / 170W | 242MiB / 12288MiB | 7% Default | +| | | N/A | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| 0 N/A N/A 985 G /usr/lib/xorg/Xorg 117MiB | +| 0 N/A N/A 4831 C python3 104MiB | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=/home/hz/Desktop/parameter-golf-main/data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:1 +val_loader:shards pattern=/home/hz/Desktop/parameter-golf-main/data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:65536 +val_token_limit:65536 +model_params:79213728 +world_size:1 grad_accum_steps:8 +enable_compile:False +sdp_backends:cudnn=False flash=True mem_efficient=False math=False +attention_mode:gqa num_heads:8 num_kv_heads:4 +quant_scheme:int5_odd compressor:zlib level:9 structured_mlp:True structured_linear_type:btt_mlp btt_rank:1024 btt_init:mup compile_structured_mlp:False +rate_lambda:0.0002 scale_lambda:0.0002 rate_warmup_start_fraction:0.1 rate_full_fraction:0.4 rate_target_tensors:mlp_only +tie_embeddings:True embed_lr:0.03 head_lr:0.0 matrix_lr:0.02 scalar_lr:0.02 +train_batch_tokens:4096 train_seq_len:256 iterations:0 warmup_steps:0 max_wallclock_seconds:600.000 +seed:1337 +step:0/0 val_loss:6.9402 val_bpb:4.1035 train_time:0ms step_avg:0.02ms +peak memory allocated: 646 MiB reserved: 686 MiB 
+Serialized model: 315893563 bytes
+Code size: 69179 bytes
+Total submission size: 315962742 bytes
+Serialized model int5_odd+zlib: 11155080 bytes (payload:26721531 raw_torch:26810395 payload_ratio:11.82x)
+Total submission size int5_odd+zlib: 11224259 bytes
+"""
+The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good launching-off points for new participants, not SOTA configs. We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder.
+
+Hard stop: To keep things readable for newcomers, let's make sure `train_gpt.py` and `train_gpt_mlx.py` never grow longer than 1500 lines.
+"""
+
+from __future__ import annotations
+
+import copy
+import glob
+import io
+import math
+import os
+import random
+import subprocess
+import sys
+import time
+import uuid
+import zlib
+from pathlib import Path
+
+import numpy as np
+import sentencepiece as spm
+import torch
+import torch.distributed as dist
+import torch.nn.functional as F
+from torch import Tensor, nn
+from torch.nn.parallel import DistributedDataParallel as DDP
+
+try:
+    import zstandard as zstd
+
+    _HAS_ZSTD = True
+except ImportError:
+    zstd = None
+    _HAS_ZSTD = False
+
+REPO_ROOT = Path(__file__).resolve().parents[3]
+
+# -----------------------------
+# HYPERPARAMETERS
+# -----------------------------
+# Default Simple Baseline run:
+# - 9 transformer blocks at width 512
+# - 8 attention heads with 4 KV heads (GQA) and 2x MLP expansion
+# - vocab size 1024, sequence length 1024, tied embeddings
+# - 524,288 train tokens per step for 20,000 iterations with a ~10 minute cap
+
+class Hyperparameters:
+    # Data paths are shard globs produced by the existing preprocessing pipeline.
+ data_path = os.environ.get("DATA_PATH", str(REPO_ROOT / "data/datasets/fineweb10B_sp1024")) + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", str(REPO_ROOT / "data/tokenizers/fineweb_1024_bpe.model")) + run_id = os.environ.get("RUN_ID", "train") + seed = int(os.environ.get("SEED", 1337)) + + # Validation cadence and batch size. Validation always uses the full fineweb_val split. + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 262_144)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 10)) + val_token_limit = int(os.environ.get("VAL_TOKEN_LIMIT", 1_048_576)) + + # Training length. + iterations = int(os.environ.get("ITERATIONS", 40)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 1200)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 0)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 16_384)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 256)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + enable_compile = bool(int(os.environ.get("ENABLE_COMPILE", "0"))) + + # Model shape. + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers = int(os.environ.get("NUM_LAYERS", 9)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = int(os.environ.get("MLP_MULT", 2)) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + + # Optimizer hyperparameters. 
+ embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.03)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.02)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.02)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.0)) + + # Export / compression. + quant_scheme = os.environ.get("QUANT_SCHEME", "int5_odd") + compressor = os.environ.get("COMPRESSOR", "zlib") + compressor_level = int(os.environ.get("COMPRESSOR_LEVEL", 9)) + + # Entropy-aware rate regularization. + rate_lambda = float(os.environ.get("RATE_LAMBDA", 0.002)) + scale_lambda = float(os.environ.get("SCALE_LAMBDA", 0.0002)) + rate_temperature = float(os.environ.get("RATE_TEMPERATURE", 6.0)) + rate_warmup_start_fraction = float(os.environ.get("RATE_WARMUP_START_FRACTION", 0.1)) + rate_full_fraction = float(os.environ.get("RATE_FULL_FRACTION", 0.4)) + rate_target_tensors = os.environ.get("RATE_TARGET_TENSORS", "mlp_only") + + # Structured MLP. 
+ structured_mlp = bool(int(os.environ.get("STRUCTURED_MLP", "1"))) + structured_linear_type = os.environ.get("STRUCTURED_LINEAR_TYPE", "btt_mlp") + btt_rank = int(os.environ.get("BTT_RANK", 32)) + btt_in_shape = os.environ.get("BTT_IN_SHAPE", "") + btt_out_shape = os.environ.get("BTT_OUT_SHAPE", "") + btt_init = os.environ.get("BTT_INIT", "mup") + btt_init_gain = float(os.environ.get("BTT_INIT_GAIN", 1.0)) + compile_structured_mlp = bool(int(os.environ.get("COMPILE_STRUCTURED_MLP", "1"))) + +# ----------------------------- +# MUON OPTIMIZER +# ----------------------------- +# +# As borrowed from modded-nanogpt +# Background on Muon: https://kellerjordan.github.io/posts/muon/ + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + # Orthogonalize a 2D update matrix with a fast Newton-Schulz iteration. + # Muon uses this to normalize matrix-shaped gradients before applying them. + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X + + +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True): + super().__init__( + params, + dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov), + ) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + + for group in self.param_groups: + params = group["params"] + if not params: + continue + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + + total_params = 
sum(int(p.numel()) for p in params) + updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) + + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: + g = g.add(buf, alpha=momentum) + g = zeropower_via_newtonschulz5(g, steps=backend_steps) + # Scale correction from Muon reference implementations. + g *= max(1, g.size(0) / g.size(1)) ** 0.5 + updates_flat[curr : curr + p.numel()] = g.reshape(-1) + curr += p.numel() + + if distributed: + dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) + + curr = 0 + for p in params: + g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) + p.add_(g, alpha=-lr) + curr += p.numel() + + return loss + + +# ----------------------------- +# TOKENIZER-AGNOSTIC EVALUATION SETUP +# ----------------------------- +# +# It's common for small models have a large fraction of their parameters be embeddings, since the 2 * d_model * d_vocab vectors can be gigantic. +# Instead of locking the tokenizer, we let you bring your own and calculate our validation metrics on the average compression of the validation set. +# We calculate BPB (bits-per-byte) instead of validation loss, so we need methods to count the number of bits per token in the tokenizer. +# Note: Submissions that edit the tokenizer will be examined more carefully, since screwing this up might unjustly improve your score. 
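+# As a concrete sanity check of the BPB definition above, the step-0 numbers in
+# train.log can be reproduced from the two ratios involved. The byte count below is
+# back-solved from the logged metrics and is hypothetical/illustrative, not measured:

```python
import math

# Step-0 values from train.log: val_loss is token cross-entropy in nats.
val_loss_nats = 6.9402
val_tokens = 65_536    # val_token_limit for this run
val_bytes = 159_914    # hypothetical: back-solved so the example matches the log

bits_per_token = val_loss_nats / math.log(2.0)  # ~10.01 bits of code length per token
tokens_per_byte = val_tokens / val_bytes        # tokenizer compression of the raw text
val_bpb = bits_per_token * tokens_per_byte      # ~4.10, matching the logged val_bpb
```

Note how the tokenizer only enters through `tokens_per_byte`: a tokenizer that packs more bytes into each token lowers BPB for the same per-token loss, which is why the metric is tokenizer-agnostic.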
+ +def build_sentencepiece_luts( + sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device +) -> tuple[Tensor, Tensor, Tensor]: + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("▁"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) + + +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*. 
+ tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] + + +def eval_val( + args: Hyperparameters, + model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + grad_accum_steps: int, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, +) -> tuple[float, float]: + # Validation computes two metrics: + # - val_loss: token cross-entropy (natural log) + # - val_bpb: tokenizer-agnostic compression metric used by the challenge + local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) + if local_batch_tokens < args.train_seq_len: + raise ValueError( + "VAL_BATCH_SIZE must provide at least one sequence per rank; " + f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " + f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}" + ) + local_batch_seqs = local_batch_tokens // args.train_seq_len + total_seqs = (val_tokens.numel() - 1) // args.train_seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + + model.eval() + with torch.inference_mode(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * args.train_seq_len + raw_end = batch_seq_end * args.train_seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, args.train_seq_len) + y = local[1:].reshape(-1, args.train_seq_len) + 
maybe_mark_compile_step_begin(args.enable_compile or args.compile_structured_mlp) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + batch_loss = model(x, y).detach() + batch_token_count = float(y.numel()) + val_loss_sum += batch_loss.to(torch.float64) * batch_token_count + val_token_count += batch_token_count + prev_ids = x.reshape(-1) + tgt_ids = y.reshape(-1) + token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + + val_loss = val_loss_sum / val_token_count + bits_per_token = val_loss.item() / math.log(2.0) + tokens_per_byte = val_token_count.item() / val_byte_count.item() + model.train() + return float(val_loss.item()), float(bits_per_token * tokens_per_byte) + +# ----------------------------- +# POST-TRAINING QUANTIZATION +# ----------------------------- +# +# It's silly to export our model, which is trained in bf16 and fp32, at that same precision. +# Instead, we get approximately the same model (with a small hit) by quantizing the model to int8 & zlib compressing. +# We can then decompress the model and run in higher precision for evaluation, after closing in under the size limit. 
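+# Before the exporter definitions below, it may help to see the int5_odd storage
+# trick in isolation. This is a minimal self-contained sketch, not the exporter
+# itself: 5 levels need log2(5) ~ 2.32 bits each, and packing 3 base-5 codes per
+# byte (5**3 = 125 <= 256) stores them at 8/3 ~ 2.67 bits per weight before zlib
+# entropy coding closes the remaining gap.

```python
import numpy as np

def pack(codes: np.ndarray) -> tuple[bytes, int]:
    # codes are the 5 quantization levels {-2..2} shifted to {0..4}.
    n = codes.size
    pad = (3 - n % 3) % 3                       # pad to a multiple of 3 codes
    g = np.pad(codes, (0, pad)).reshape(-1, 3).astype(np.uint8)
    return (g[:, 0] + 5 * g[:, 1] + 25 * g[:, 2]).tobytes(), n

def unpack(data: bytes, n: int) -> np.ndarray:
    p = np.frombuffer(data, dtype=np.uint8).astype(np.int16)
    out = np.empty((p.size, 3), dtype=np.uint8)
    for i in range(3):                          # peel off base-5 digits
        out[:, i] = p % 5
        p = p // 5
    return out.reshape(-1)[:n]                  # drop the padding codes

codes = np.array([0, 4, 2, 1, 3, 2, 0], dtype=np.uint8)
blob, n = pack(codes)
assert np.array_equal(unpack(blob, n), codes)   # exact roundtrip; 7 codes -> 3 bytes
```

The real `pack_base5_codes`/`unpack_base5_codes` below follow the same scheme, with the per-row scales stored separately in fp16.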
+ +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights", + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 +INT5_ODD_LEVELS = (-2, -1, 0, 1, 2) + +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) + + +def maybe_mark_compile_step_begin(enabled: bool) -> None: + if not enabled: + return + mark_step = getattr(torch.compiler, "cudagraph_mark_step_begin", None) + if mark_step is not None: + mark_step() + + +def parse_factor_shape(spec: str, dim: int) -> tuple[int, int]: + if spec: + vals = tuple(int(v.strip()) for v in spec.split(",") if v.strip()) + if len(vals) != 2 or vals[0] * vals[1] != dim: + raise ValueError(f"Invalid factor shape '{spec}' for dim={dim}") + return vals + for a in range(int(math.isqrt(dim)), 0, -1): + if dim % a == 0: + return a, dim // a + return 1, dim + +def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: + if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): + return t.float().contiguous() + if t.dtype in {torch.float32, torch.bfloat16}: + passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") + return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() + return t + +def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + # Matrices get one scale per row, which usually tracks output-channel + # ranges much better than a single tensor-wide scale. 
+ clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) + q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() + return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + + # Vectors / scalars use a simpler per-tensor scale. + clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() + return q, scale + + +def pack_base5_codes(q: Tensor) -> tuple[bytes, int]: + vals = q.detach().to("cpu", dtype=torch.uint8).reshape(-1).numpy() + n_vals = int(vals.size) + pad = (3 - (n_vals % 3)) % 3 + if pad: + vals = np.pad(vals, (0, pad)) + groups = vals.reshape(-1, 3).astype(np.uint8) + packed = groups[:, 0] + 5 * groups[:, 1] + 25 * groups[:, 2] + return packed.tobytes(), n_vals + + +def unpack_base5_codes(data: bytes, n_vals: int) -> Tensor: + packed = np.frombuffer(data, dtype=np.uint8).astype(np.int16) + out = np.zeros((packed.size, 3), dtype=np.uint8) + for i in range(3): + out[:, i] = packed % 5 + packed //= 5 + flat = out.reshape(-1)[:n_vals] + return torch.from_numpy(flat.astype(np.uint8, copy=False)) + + +def quantize_float_tensor_int5_odd(t: Tensor) -> tuple[Tensor, Tensor]: + if t.ndim != 2: + raise ValueError("int5_odd quantization only supports 2D tensors") + t32 = t.float() + clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clip_abs = clip_abs.clamp_min(1e-6) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 
2.0).clamp_min(1.0 / 2.0e6) + q = torch.clamp(torch.round(clipped / scale[:, None]), -2, 2).to(torch.int8).contiguous() + return (q + 2).to(torch.uint8).contiguous(), scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + +def quantize_state_dict_int8(state_dict: dict[str, Tensor]): + # Single supported clean-script export format: + # - per-row int8 for 2D float tensors + # - per-tensor int8 for other float tensors + # - exact passthrough for non-floats + # - passthrough for small float tensors, stored as fp16 to save bytes + quantized: dict[str, Tensor] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + qmeta: dict[str, dict[str, object]] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + + # Small float tensors are cheap enough to keep directly. We still downcast + # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size. 
+ if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + + stats["num_float_tensors"] += 1 + q, s = quantize_float_tensor(t) + if s.ndim > 0: + qmeta[name] = {"scheme": "per_row", "axis": 0} + quantized[name] = q + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) + + obj: dict[str, object] = { + "__quant_format__": "int8_clean_per_row_v1", + "quantized": quantized, + "scales": scales, + "dtypes": dtypes, + "passthrough": passthrough, + } + if qmeta: + obj["qmeta"] = qmeta + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats + + +def quantize_state_dict_int5_odd(state_dict: dict[str, Tensor]): + packed: dict[str, bytes] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + shapes: dict[str, tuple[int, ...]] = {} + orig_shapes: dict[str, tuple[int, ...]] = {} + n_codes: dict[str, int] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + + quantize_this = t.ndim >= 2 and (t.numel() > INT8_KEEP_FLOAT_MAX_NUMEL or ".mlp." 
in name) + if not quantize_this: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + + stats["num_float_tensors"] += 1 + t2 = t.reshape(t.shape[0], -1) if t.ndim > 2 else t + q, s = quantize_float_tensor_int5_odd(t2) + packed_bytes, count = pack_base5_codes(q) + packed[name] = packed_bytes + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + shapes[name] = tuple(int(v) for v in t2.shape) + orig_shapes[name] = tuple(int(v) for v in t.shape) + n_codes[name] = count + stats["int8_payload_bytes"] += len(packed_bytes) + tensor_nbytes(s) + + obj: dict[str, object] = { + "__quant_format__": "int5_odd_clean_per_row_v1", + "packed": packed, + "scales": scales, + "dtypes": dtypes, + "shapes": shapes, + "orig_shapes": orig_shapes, + "n_codes": n_codes, + "passthrough": passthrough, + } + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats + +def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + qmeta = obj.get("qmeta", {}) + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, q in obj["quantized"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + s = obj["scales"][name] + if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: + s = s.to(dtype=torch.float32) + # Broadcast the saved row scale back across trailing dimensions. + out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() + else: + scale = float(s.item()) + out[name] = (q.float() * scale).to(dtype=dtype).contiguous() + for name, t in obj["passthrough"].items(): + # Restore small tensors, undoing the temporary fp16 storage cast if needed. 
+ out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out + + +def dequantize_state_dict_int5_odd(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, packed in obj["packed"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + shape = tuple(int(v) for v in obj["shapes"][name]) + orig_shape = tuple(int(v) for v in obj.get("orig_shapes", {}).get(name, shape)) + q = unpack_base5_codes(packed, int(obj["n_codes"][name])).to(torch.float32).reshape(shape) - 2.0 + s = obj["scales"][name].to(dtype=torch.float32) + deq = (q * s.view(shape[0], *([1] * (len(shape) - 1)))).to(dtype=dtype) + out[name] = deq.reshape(orig_shape).contiguous() + for name, t in obj["passthrough"].items(): + out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out + + +def quantize_state_dict(state_dict: dict[str, Tensor], quant_scheme: str): + if quant_scheme == "int8": + return quantize_state_dict_int8(state_dict) + if quant_scheme == "int5_odd": + return quantize_state_dict_int5_odd(state_dict) + raise ValueError(f"Unsupported QUANT_SCHEME={quant_scheme}") + + +def dequantize_state_dict(obj: dict[str, object]) -> dict[str, Tensor]: + quant_format = obj.get("__quant_format__") + if quant_format == "int8_clean_per_row_v1": + return dequantize_state_dict_int8(obj) + if quant_format == "int5_odd_clean_per_row_v1": + return dequantize_state_dict_int5_odd(obj) + raise ValueError(f"Unsupported quant format {quant_format}") + + +def compress_blob(data: bytes, compressor: str, level: int) -> bytes: + if compressor == "zlib": + return zlib.compress(data, level=max(0, 
min(level, 9))) + if compressor == "zstd": + if not _HAS_ZSTD: + raise RuntimeError("COMPRESSOR=zstd requires zstandard to be installed") + return zstd.ZstdCompressor(level=level).compress(data) + raise ValueError(f"Unsupported COMPRESSOR={compressor}") + + +def decompress_blob(data: bytes, compressor: str) -> bytes: + if compressor == "zlib": + return zlib.decompress(data) + if compressor == "zstd": + if not _HAS_ZSTD: + raise RuntimeError("COMPRESSOR=zstd requires zstandard to be installed") + return zstd.ZstdDecompressor().decompress(data) + raise ValueError(f"Unsupported COMPRESSOR={compressor}") + + +def should_rate_regularize_tensor(name: str, p: Tensor, target: str) -> bool: + if not p.requires_grad or not p.is_floating_point() or p.ndim < 2: + return False + if any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS): + return False + if target == "mlp_only": + return ".mlp." in name + if target == "all_matrices": + return p.numel() > INT8_KEEP_FLOAT_MAX_NUMEL or ".mlp." in name + raise ValueError(f"Unsupported RATE_TARGET_TENSORS={target}") + + +def estimate_tensor_rate_proxy(t: Tensor, temperature: float) -> tuple[Tensor, Tensor, Tensor]: + t2 = t.float() + if t2.ndim > 2: + t2 = t2.reshape(t2.shape[0], -1) + elif t2.ndim == 1: + t2 = t2.reshape(1, -1) + row_scale = t2.abs().mean(dim=1, keepdim=True).clamp_min(1e-6) + normalized = torch.clamp(t2 / row_scale, -2.5, 2.5) + centers = t2.new_tensor(INT5_ODD_LEVELS) + logits = -temperature * (normalized.unsqueeze(-1) - centers) ** 2 + probs = torch.softmax(logits, dim=-1) + hist = probs.mean(dim=(0, 1)) + entropy_nats = -(hist * (hist + 1e-8).log()).sum() + entropy_bits = entropy_nats / math.log(2.0) + scale_proxy = row_scale.mean() + est_bits = entropy_bits * float(t2.numel()) + return entropy_bits, scale_proxy, est_bits + + +# ----------------------------- +# DATA LOADING +# ----------------------------- + +def load_data_shard(file: Path) -> Tensor: + header_bytes = 256 * np.dtype(" None: + 
self.file_idx = (self.file_idx + 1) % len(self.files) + self.tokens = load_data_shard(self.files[self.file_idx]) + self.pos = 0 + + def take(self, n: int) -> Tensor: + chunks: list[Tensor] = [] + remaining = n + while remaining > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: + self._advance_file() + continue + k = min(remaining, avail) + chunks.append(self.tokens[self.pos : self.pos + k]) + self.pos += k + remaining -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) + + +class DistributedTokenLoader: + # Each call consumes a contiguous chunk from the shared token stream, then slices out + # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting. + def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank = rank + self.world_size = world_size + self.device = device + self.stream = TokenStream(pattern) + + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + chunk = self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start : start + per_rank_span].to(dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + +# ----------------------------- +# TRANSFORMER MODULES +# ----------------------------- + +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + + +class CastedLinear(nn.Linear): + # Keep weights in fp32 for optimizer/state quality, cast at matmul time for bf16 compute. 
+ def forward(self, x: Tensor) -> Tensor: + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, self.weight.to(x.dtype), bias) + + +class StructuredLinear(nn.Module): + # Lightweight TT-style structured matrix used only in the MLP stack. + def __init__( + self, + in_features: int, + out_features: int, + bias: bool, + linear_type: str, + btt_rank: int, + in_shape_spec: str, + out_shape_spec: str, + btt_init: str, + btt_init_gain: float, + ): + super().__init__() + self.in_features = in_features + self.out_features = out_features + self.linear_type = linear_type + self.rank = max(int(btt_rank), 1) + self.btt_init = btt_init + self.btt_init_gain = btt_init_gain + if linear_type == "dense": + self.impl = CastedLinear(in_features, out_features, bias=bias) + self.register_parameter("core1", None) + self.register_parameter("core2", None) + return + if linear_type != "btt_mlp": + raise ValueError(f"Unsupported STRUCTURED_LINEAR_TYPE={linear_type}") + self.impl = None + self.in_shape = parse_factor_shape(in_shape_spec, in_features) + self.out_shape = parse_factor_shape(out_shape_spec, out_features) + n1, n2 = self.in_shape + m1, m2 = self.out_shape + self.core1 = nn.Parameter(torch.empty(n1, m1, self.rank)) + self.core2 = nn.Parameter(torch.empty(self.rank, n2, m2)) + if bias: + self.bias = nn.Parameter(torch.zeros(out_features)) + else: + self.register_parameter("bias", None) + self.reset_parameters() + + def reset_parameters(self) -> None: + if self.impl is not None: + self.impl.reset_parameters() + return + n1, n2 = self.in_shape + if self.btt_init == "mup": + # muP-inspired scaling: keep the per-rank branch update stable while + # preserving O(1) output variance under the sum over TT ranks. 
+ core1_std = self.btt_init_gain / math.sqrt(max(n1, 1)) + core2_std = self.btt_init_gain / math.sqrt(max(self.rank * n2, 1)) + elif self.btt_init == "xavier": + std = self.btt_init_gain / math.sqrt(max(self.in_features, 1)) + core1_std = std / math.sqrt(self.rank) + core2_std = std + else: + raise ValueError(f"Unsupported BTT_INIT={self.btt_init}") + nn.init.normal_(self.core1, mean=0.0, std=core1_std) + nn.init.normal_(self.core2, mean=0.0, std=core2_std) + if self.bias is not None: + nn.init.zeros_(self.bias) + + def forward(self, x: Tensor) -> Tensor: + if self.impl is not None: + return self.impl(x) + x_dtype = x.dtype + orig_shape = x.shape[:-1] + x2 = x.reshape(-1, self.in_shape[0], self.in_shape[1]) + core1 = self.core1.to(dtype=x_dtype) + core2 = self.core2.to(dtype=x_dtype) + y = x2.new_zeros((x2.shape[0], self.out_shape[0], self.out_shape[1])) + for r in range(self.rank): + left = torch.einsum("buv,um->bmv", x2, core1[:, :, r]) + y = y + torch.einsum("bmv,vn->bmn", left, core2[r]) + y = y.reshape(*orig_shape, self.out_features) + if self.bias is not None: + y = y + self.bias.to(dtype=y.dtype) + return y + + +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + # Keep small/control parameters in fp32 even when the model body runs in bf16. + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() + + +class Rotary(nn.Module): + # Caches cos/sin tables per sequence length on the current device. 
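On toy shapes, the rank loop in `StructuredLinear.forward` above is equivalent to multiplying by a materialized dense matrix `W[(u,v),(m,n)] = sum_r core1[u,m,r] * core2[r,v,n]`, which is what makes evaluation-time materialization possible. A minimal pure-Python check of that identity for one sample (this mirrors the per-sample einsums; it is not the repo's materialization code):

```python
# Tiny exact check: one sample, in_shape=(2, 2), out_shape=(2, 2), rank=2.
n1, n2, m1, m2, R = 2, 2, 2, 2, 2
core1 = [[[u + m + r for r in range(R)] for m in range(m1)] for u in range(n1)]      # (n1, m1, R)
core2 = [[[r * 2 + v - n for n in range(m2)] for v in range(n2)] for r in range(R)]  # (R, n2, m2)
x = [[1, 2], [3, 4]]  # one sample reshaped to (n1, n2)

# Rank loop, mirroring: left = einsum("uv,um->mv"), then einsum("mv,vn->mn")
y = [[0] * m2 for _ in range(m1)]
for r in range(R):
    left = [[sum(x[u][v] * core1[u][m][r] for u in range(n1)) for v in range(n2)] for m in range(m1)]
    for m in range(m1):
        for n in range(m2):
            y[m][n] += sum(left[m][v] * core2[r][v][n] for v in range(n2))

# Materialized dense weight: W[(u, v), (m, n)] = sum_r core1[u][m][r] * core2[r][v][n]
y_dense = [[0] * m2 for _ in range(m1)]
for m in range(m1):
    for n in range(m2):
        for u in range(n1):
            for v in range(n2):
                w = sum(core1[u][m][r] * core2[r][v][n] for r in range(R))
                y_dense[m][n] += x[u][v] * w

assert y == y_dense  # rank loop == dense materialization
```

Both paths compute the same triple sum over `(r, u, v)`; the structured form just never stores `W` explicitly during training.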
+ def __init__(self, dim: int, base: float = 10000.0): + super().__init__() + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached != seq_len + or self._cos_cached.device != device + ): + t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) + freqs = torch.outer(t, self.inv_freq.to(device)) + self._cos_cached = freqs.cos()[None, None, :, :] + self._sin_cached = freqs.sin()[None, None, :, :] + self._seq_len_cached = seq_len + cos = self._cos_cached.to(dtype=dtype) + sin = self._sin_cached.to(dtype=dtype) + if not torch.is_inference_mode_enabled(): + cos = cos.clone() + sin = sin.clone() + return cos, sin + + +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + + +class CausalSelfAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + kv_dim = self.num_kv_heads * self.head_dim + self.c_q = CastedLinear(dim, dim, bias=False) + self.c_k = CastedLinear(dim, kv_dim, bias=False) + self.c_v = CastedLinear(dim, kv_dim, bias=False) + self.proj 
= CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rotary = Rotary(self.head_dim, base=rope_base) + + def forward(self, x: Tensor) -> Tensor: + bsz, seqlen, dim = x.shape + q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) + k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin) + k = apply_rotary_emb(k, cos, sin) + q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] + y = F.scaled_dot_product_attention( + q, + k, + v, + attn_mask=None, + is_causal=True, + enable_gqa=(self.num_kv_heads != self.num_heads), + ) + y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) + return self.proj(y) + + +class MLP(nn.Module): + # relu^2 MLP from the original modded-nanogpt setup + def __init__( + self, + dim: int, + mlp_mult: int, + structured_mlp: bool, + structured_linear_type: str, + btt_rank: int, + btt_in_shape: str, + btt_out_shape: str, + btt_init: str, + btt_init_gain: float, + compile_structured_mlp: bool, + ): + super().__init__() + hidden = mlp_mult * dim + if structured_mlp: + self.fc = StructuredLinear( + dim, + hidden, + bias=False, + linear_type=structured_linear_type, + btt_rank=btt_rank, + in_shape_spec=btt_in_shape, + out_shape_spec=btt_out_shape, + btt_init=btt_init, + btt_init_gain=btt_init_gain, + ) + self.proj = StructuredLinear( + hidden, + dim, + bias=False, + linear_type=structured_linear_type, + btt_rank=btt_rank, + in_shape_spec=btt_out_shape, + out_shape_spec=btt_in_shape, + btt_init=btt_init, + btt_init_gain=btt_init_gain, + ) + if compile_structured_mlp: + compile_options = 
dict(torch._inductor.list_mode_options("reduce-overhead")) + compile_options["triton.cudagraphs"] = False + self.fc = torch.compile( + self.fc, + options=compile_options, + fullgraph=False, + dynamic=False, + ) + self.proj = torch.compile( + self.proj, + options=compile_options, + fullgraph=False, + dynamic=False, + ) + else: + self.fc = CastedLinear(dim, hidden, bias=False) + self.proj = CastedLinear(hidden, dim, bias=False) + self.proj._zero_init = True + + def forward(self, x: Tensor) -> Tensor: + x = torch.relu(self.fc(x)) + return self.proj(x.square()) + + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + rope_base: float, + qk_gain_init: float, + structured_mlp: bool, + structured_linear_type: str, + btt_rank: int, + btt_in_shape: str, + btt_out_shape: str, + btt_init: str, + btt_init_gain: float, + compile_structured_mlp: bool, + ): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) + self.mlp = MLP( + dim, + mlp_mult, + structured_mlp, + structured_linear_type, + btt_rank, + btt_in_shape, + btt_out_shape, + btt_init, + btt_init_gain, + compile_structured_mlp, + ) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + + def forward(self, x: Tensor, x0: Tensor) -> Tensor: + mix = self.resid_mix.to(dtype=x.dtype) + x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + attn_out = self.attn(self.attn_norm(x)) + x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out + x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) + return x + + +class GPT(nn.Module): + def __init__( + self, + vocab_size: int, + num_layers: int, + model_dim: int, + num_heads: int, + 
num_kv_heads: int, + mlp_mult: int, + tie_embeddings: bool, + tied_embed_init_std: float, + logit_softcap: float, + rope_base: float, + qk_gain_init: float, + structured_mlp: bool, + structured_linear_type: str, + btt_rank: int, + btt_in_shape: str, + btt_out_shape: str, + btt_init: str, + btt_init_gain: float, + compile_structured_mlp: bool, + rate_target_tensors: str, + rate_temperature: float, + ): + super().__init__() + if logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") + self.tie_embeddings = tie_embeddings + self.tied_embed_init_std = tied_embed_init_std + self.logit_softcap = logit_softcap + self.rate_target_tensors = rate_target_tensors + self.rate_temperature = rate_temperature + self.tok_emb = nn.Embedding(vocab_size, model_dim) + self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + self.blocks = nn.ModuleList( + [ + Block( + model_dim, + num_heads, + num_kv_heads, + mlp_mult, + rope_base, + qk_gain_init, + structured_mlp, + structured_linear_type, + btt_rank, + btt_in_shape, + btt_out_shape, + btt_init, + btt_init_gain, + compile_structured_mlp, + ) + for i in range(num_layers) + ] + ) + self.final_norm = RMSNorm() + self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) + if self.lm_head is not None: + self.lm_head._zero_init = True + self._init_weights() + + def _init_weights(self) -> None: + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + for module in self.modules(): + if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + if isinstance(module, StructuredLinear) and getattr(module, "_zero_init", False): + if module.impl is 
not None: + nn.init.zeros_(module.impl.weight) + else: + nn.init.zeros_(module.core2) + + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + x = self.tok_emb(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x0 = x + skips: list[Tensor] = [] + + # First half stores skips; second half reuses them in reverse order. + for i in range(self.num_encoder_layers): + x = self.blocks[i](x, x0) + skips.append(x) + for i in range(self.num_decoder_layers): + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + x = self.blocks[self.num_encoder_layers + i](x, x0) + + x = self.final_norm(x).reshape(-1, x.size(-1)) + targets = target_ids.reshape(-1) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(x) + logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + return F.cross_entropy(logits.float(), targets, reduction="mean") + + def rate_regularizer(self) -> tuple[Tensor, Tensor, Tensor]: + total_symbols = 0.0 + weighted_rate = None + weighted_scale = None + estimated_bits = None + for name, param in self.named_parameters(): + if not should_rate_regularize_tensor(name, param, self.rate_target_tensors): + continue + rate_proxy, scale_proxy, tensor_bits = estimate_tensor_rate_proxy(param, self.rate_temperature) + weight = float(param.numel()) + total_symbols += weight + if weighted_rate is None: + weighted_rate = rate_proxy * weight + weighted_scale = scale_proxy * weight + estimated_bits = tensor_bits + else: + weighted_rate = weighted_rate + rate_proxy * weight + weighted_scale = weighted_scale + scale_proxy * weight + estimated_bits = estimated_bits + tensor_bits + if weighted_rate is None or weighted_scale is None or estimated_bits is None: + zero = self.tok_emb.weight.new_zeros(()) + return zero, zero, zero + denom = max(total_symbols, 1.0) + return 
weighted_rate / denom, weighted_scale / denom, estimated_bits + + +# ----------------------------- +# TRAINING +# ----------------------------- + +def main() -> None: + global zeropower_via_newtonschulz5 + + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + if args.enable_compile: + zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) + + # ----------------------------- + # DISTRIBUTED + CUDA SETUP + # ----------------------------- + + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if 8 % world_size != 0: + raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") + grad_accum_steps = 8 // world_size + grad_scale = 1.0 / grad_accum_steps + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 0 + + # Fast math knobs + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + + logfile = None + if master_process: + if args.run_id == "train": + logfile = "train.log" + else: + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" + print(logfile) + + def log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + 
print(msg, file=f) + + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch {torch.__version__}", console=False) + log0( + subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, + console=False, + ) + log0("=" * 100, console=False) + + # ----------------------------- + # TOKENIZER + VALIDATION METRIC SETUP + # ----------------------------- + + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) + if args.val_token_limit > 0: + usable = min(((args.val_token_limit) // args.train_seq_len) * args.train_seq_len, val_tokens.numel() - 1) + if usable <= 0: + raise ValueError(f"VAL_TOKEN_LIMIT={args.val_token_limit} is too small for TRAIN_SEQ_LEN={args.train_seq_len}") + val_tokens = val_tokens[: usable + 1].contiguous() + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + if args.val_token_limit > 0: + log0(f"val_token_limit:{args.val_token_limit}") + + # 
----------------------------- + # MODEL + OPTIMIZER SETUP + # ----------------------------- + + base_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, + structured_mlp=args.structured_mlp, + structured_linear_type=args.structured_linear_type if args.structured_mlp else "dense", + btt_rank=args.btt_rank, + btt_in_shape=args.btt_in_shape, + btt_out_shape=args.btt_out_shape, + btt_init=args.btt_init, + btt_init_gain=args.btt_init_gain, + compile_structured_mlp=args.compile_structured_mlp, + rate_target_tensors=args.rate_target_tensors, + rate_temperature=args.rate_temperature, + ).to(device).bfloat16() + for module in base_model.modules(): + if isinstance(module, (CastedLinear, StructuredLinear)): + module.float() + restore_low_dim_params_to_fp32(base_model) + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) if args.enable_compile else base_model + model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + + # Optimizer split: + # - token embedding (Adam) uses EMBED_LR + # - untied lm_head (Adam) uses HEAD_LR + # - matrix params in transformer blocks use MATRIX_LR via Muon + # - vectors/scalars use SCALAR_LR via Adam + block_named_params = list(base_model.blocks.named_parameters()) + matrix_params = [ + p + for name, p in block_named_params + if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + scalar_params = [ + p + for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.skip_weights.numel() > 0: + 
scalar_params.append(base_model.skip_weights) + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + optimizer_tok = torch.optim.Adam( + [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizer_muon = Muon( + matrix_params, + lr=args.matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + ) + for group in optimizer_muon.param_groups: + group["base_lr"] = args.matrix_lr + optimizer_scalar = torch.optim.Adam( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers.insert(1, optimizer_head) + + n_params = sum(p.numel() for p in base_model.parameters()) + log0(f"model_params:{n_params}") + log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") + log0(f"enable_compile:{args.enable_compile}") + log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") + log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") + log0( + f"quant_scheme:{args.quant_scheme} compressor:{args.compressor} level:{args.compressor_level} " + f"structured_mlp:{args.structured_mlp} structured_linear_type:{args.structured_linear_type} " + f"btt_rank:{args.btt_rank} btt_init:{args.btt_init} compile_structured_mlp:{args.compile_structured_mlp}" + ) + log0( + f"rate_lambda:{args.rate_lambda} scale_lambda:{args.scale_lambda} " + f"rate_warmup_start_fraction:{args.rate_warmup_start_fraction} rate_full_fraction:{args.rate_full_fraction} " + 
f"rate_target_tensors:{args.rate_target_tensors}" + ) + log0( + f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " + f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " + f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" + ) + log0( + f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " + f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " + f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log0(f"seed:{args.seed}") + + # ----------------------------- + # DATA LOADER & MODEL WARMUP + # ----------------------------- + + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + + def lr_mul(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + + def rate_schedule(step: int) -> float: + if args.rate_lambda <= 0.0 and args.scale_lambda <= 0.0: + return 0.0 + if args.iterations <= 0: + return 1.0 + frac = step / max(args.iterations, 1) + start = args.rate_warmup_start_fraction + full = max(args.rate_full_fraction, start + 1e-6) + if frac <= start: + return 0.0 + if frac >= full: + return 1.0 + return (frac - start) / (full - start) + + # Warmup primes the compiled forward/backward/optimizer paths, then we restore the + # initial weights/optimizer state so measured training starts from 
the true init. + if args.warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(args.warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + maybe_mark_compile_step_begin(args.enable_compile or args.compile_structured_mlp) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + for opt in optimizers: + opt.step() + zero_grad_all() + if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: + log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + if distributed: + model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + # ----------------------------- + # MAIN TRAINING LOOP + # ----------------------------- + + training_time_ms = 0.0 + stop_after_step: int | None = None + torch.cuda.synchronize() + t0 = time.perf_counter() + + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + 
base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + log0( + f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" + ) + break + + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + zero_grad_all() + train_loss = torch.zeros((), device=device) + train_ce_loss = torch.zeros((), device=device) + train_rate_proxy = torch.zeros((), device=device) + train_scale_proxy = torch.zeros((), device=device) + train_est_bits = torch.zeros((), device=device) + rate_mul = rate_schedule(step) + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + maybe_mark_compile_step_begin(args.enable_compile or args.compile_structured_mlp) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + ce_loss = model(x, y) + total_loss = ce_loss + if rate_mul > 0.0: + rate_proxy, scale_proxy, est_bits = base_model.rate_regularizer() + total_loss = total_loss + rate_mul * (args.rate_lambda * rate_proxy + args.scale_lambda * scale_proxy) + train_rate_proxy += rate_proxy.detach() + train_scale_proxy += scale_proxy.detach() + train_est_bits += est_bits.detach() + train_ce_loss += ce_loss.detach() + train_loss += total_loss.detach() + (total_loss * grad_scale).backward() + train_loss /= grad_accum_steps + train_ce_loss /= grad_accum_steps + train_rate_proxy /= grad_accum_steps + train_scale_proxy /= grad_accum_steps + train_est_bits /= grad_accum_steps + + frac = min(step / 
args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + for opt in optimizers: + opt.step() + zero_grad_all() + + step += 1 + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + should_log_train = ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + log0( + f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} ce_loss:{train_ce_loss.item():.4f} " + f"rate_mul:{rate_mul:.3f} rate_proxy_bits:{train_rate_proxy.item():.4f} " + f"scale_proxy:{train_scale_proxy.item():.6f} est_model_bits:{train_est_bits.item():.0f} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" + ) + + # Needed to sync whether we've reached the wallclock cap. 
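The cap check below is all-reduced with MAX so every rank agrees to stop at the same step even when their wallclock readings differ; otherwise one rank would exit the loop while the others block on the next collective. A plain-Python sketch of that consensus (hypothetical elapsed times, no actual distributed calls):

```python
# Toy stand-in for the MAX all-reduce on the wallclock-cap flag: if any rank
# has exceeded the cap, every rank must adopt the same stop decision,
# otherwise collectives in the next step would deadlock.
max_wallclock_ms = 600_000.0
elapsed_ms_per_rank = [599_950.0, 600_010.0, 599_980.0, 600_001.0]  # hypothetical

local_flags = [int(e >= max_wallclock_ms) for e in elapsed_ms_per_rank]
reached_cap = max(local_flags)  # analogue of dist.all_reduce(..., op=ReduceOp.MAX)

assert local_flags == [0, 1, 0, 1]  # ranks disagree locally...
assert reached_cap == 1             # ...but share one decision after the reduce
```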
+ reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + + log0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + ) + + # ----------------------------- + # SERIALIZATION + ROUNDTRIP VALIDATION + # ----------------------------- + # Save the raw state (useful for debugging/loading in PyTorch directly), then always produce + # the compressed quantized artifact (quant_scheme+compressor, int5_odd+zlib in this config) and validate the round-tripped weights. + + if master_process: + torch.save(base_model.state_dict(), "final_model.pt") + model_bytes = os.path.getsize("final_model.pt") + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model: {model_bytes} bytes") + log0(f"Code size: {code_bytes} bytes") + log0(f"Total submission size: {model_bytes + code_bytes} bytes") + + quant_obj, quant_stats = quantize_state_dict(base_model.state_dict(), args.quant_scheme) + quant_buf = io.BytesIO() + torch.save(quant_obj, quant_buf) + quant_raw = quant_buf.getvalue() + quant_blob = compress_blob(quant_raw, args.compressor, args.compressor_level) + quant_raw_bytes = len(quant_raw) + quant_filename = f"final_model.{args.quant_scheme}.{args.compressor}.ptc" + if master_process: + with open(quant_filename, "wb") as f: + f.write(quant_blob) + quant_file_bytes = os.path.getsize(quant_filename) + code_bytes = len(code.encode("utf-8")) + ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1) + log0( + f"Serialized model {args.quant_scheme}+{args.compressor}: {quant_file_bytes} bytes " + f"(payload:{quant_stats['int8_payload_bytes']} raw_torch:{quant_raw_bytes}
payload_ratio:{ratio:.2f}x)" + ) + log0(f"Total submission size {args.quant_scheme}+{args.compressor}: {quant_file_bytes + code_bytes} bytes") + + if distributed: + dist.barrier() + with open(quant_filename, "rb") as f: + quant_blob_disk = f.read() + quant_state = torch.load(io.BytesIO(decompress_blob(quant_blob_disk, args.compressor)), map_location="cpu") + base_model.load_state_dict(dequantize_state_dict(quant_state), strict=True) + torch.cuda.synchronize() + t_qeval = time.perf_counter() + q_val_loss, q_val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"final_{args.quant_scheme}_{args.compressor}_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" + ) + log0(f"final_{args.quant_scheme}_{args.compressor}_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + + if distributed: + dist.destroy_process_group() + + +if __name__ == "__main__": + main() + +==================================================================================================== +Running Python 3.10.12 (main, Mar 3 2026, 11:56:32) [GCC 11.4.0] +Running PyTorch 2.11.0+cu130 +Tue Mar 31 12:25:14 2026 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 590.48.01 Driver Version: 590.48.01 CUDA Version: 13.1 | ++-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. 
| +|=========================================+========================+======================| +| 0 NVIDIA GeForce RTX 3060 Off | 00000000:01:00.0 Off | N/A | +| 33% 42C P8 18W / 170W | 242MiB / 12288MiB | 10% Default | +| | | N/A | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| 0 N/A N/A 985 G /usr/lib/xorg/Xorg 117MiB | +| 0 N/A N/A 4885 C python3 104MiB | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=/home/hz/Desktop/parameter-golf-main/data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:1 +val_loader:shards pattern=/home/hz/Desktop/parameter-golf-main/data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:2048 +val_token_limit:2048 +model_params:79213728 +world_size:1 grad_accum_steps:8 +enable_compile:False +sdp_backends:cudnn=False flash=True mem_efficient=False math=False +attention_mode:gqa num_heads:8 num_kv_heads:4 +quant_scheme:int5_odd compressor:zlib level:9 structured_mlp:True structured_linear_type:btt_mlp btt_rank:1024 btt_init:mup compile_structured_mlp:False +rate_lambda:0.0002 scale_lambda:0.0002 rate_warmup_start_fraction:0.1 rate_full_fraction:0.4 rate_target_tensors:mlp_only +tie_embeddings:True embed_lr:0.03 head_lr:0.0 matrix_lr:0.02 scalar_lr:0.02 +train_batch_tokens:4096 train_seq_len:256 iterations:0 warmup_steps:0 max_wallclock_seconds:600.000 +seed:1337 +step:0/0 val_loss:6.9448 val_bpb:3.7609 train_time:0ms step_avg:0.02ms +peak memory allocated: 320 MiB reserved: 334 MiB 
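The `int5_odd` scheme named in the config above quantizes weights onto the odd 5-level grid `{-2, -1, 0, 1, 2}` before zlib compression. A minimal round-trip sketch of such a symmetric 5-level scheme (an assumed illustration with a single per-tensor scale, not the repo's actual `quantize_state_dict`):

```python
# Assumed illustration of a symmetric 5-level ("int5_odd") quantizer with a
# per-tensor scale; the repo's actual quantize_state_dict may differ.
LEVELS = (-2, -1, 0, 1, 2)

def quantize(weights, scale):
    # round(w / scale), clamped onto the odd 5-level grid
    return [max(LEVELS[0], min(LEVELS[-1], round(w / scale))) for w in weights]

def dequantize(codes, scale):
    return [c * scale for c in codes]

w = [0.031, -0.012, 0.0, 0.058, -0.049]
scale = 0.025
codes = quantize(w, scale)
w_hat = dequantize(codes, scale)
print(codes)  # [1, 0, 0, 2, -2]
assert all(c in LEVELS for c in codes)
```

Five levels fit in under 3 bits per weight before entropy coding, and the entropy-aware rate regularizer pushes mass toward a few levels so zlib compresses the code stream further, which is how the 11.82x payload ratio in the log becomes plausible.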
+Serialized model: 315893563 bytes
+Code size: 69179 bytes
+Total submission size: 315962742 bytes
+Serialized model int5_odd+zlib: 11155080 bytes (payload:26721531 raw_torch:26810395 payload_ratio:11.82x)
+Total submission size int5_odd+zlib: 11224259 bytes
+final_int5_odd_zlib_roundtrip val_loss:6.9494 val_bpb:3.7634 eval_time:71757ms
+final_int5_odd_zlib_roundtrip_exact val_loss:6.94941086 val_bpb:3.76338039
diff --git a/records/track_non_record_16mb/2026-03-31_EntropyAware_Int5Odd_BTTMLP_3060/logs/R1024_L24.txt b/records/track_non_record_16mb/2026-03-31_EntropyAware_Int5Odd_BTTMLP_3060/logs/R1024_L24.txt
new file mode 100644
index 000000000..ef8fee420
--- /dev/null
+++ b/records/track_non_record_16mb/2026-03-31_EntropyAware_Int5Odd_BTTMLP_3060/logs/R1024_L24.txt
@@ -0,0 +1,1691 @@
+"""
+The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good launching-off points for new participants, not SOTA configs. We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder.
+
+Hard stop: To keep things readable for newcomers, let's make sure `train_gpt.py` and `train_gpt_mlx.py` are never longer than 1500 lines.
+""" + +from __future__ import annotations + +import copy +import glob +import io +import math +import os +import random +import subprocess +import sys +import time +import uuid +import zlib +from pathlib import Path + +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP + +try: + import zstandard as zstd + + _HAS_ZSTD = True +except ImportError: + zstd = None + _HAS_ZSTD = False + +REPO_ROOT = Path(__file__).resolve().parents[3] + +# ----------------------------- +# HYPERPARAMETERS +# ----------------------------- +# Default Simple Baseline run: +# - 9 transformer blocks at width 512 +# - 8 attention heads with 4 KV heads (GQA) and 2x MLP expansion +# - vocab size 1024, sequence length 1024, tied embeddings +# - 524,288 train tokens per step for 20,000 iterations with a ~10 minute cap + +class Hyperparameters: + # Data paths are shard globs produced by the existing preprocessing pipeline. + data_path = os.environ.get("DATA_PATH", str(REPO_ROOT / "data/datasets/fineweb10B_sp1024")) + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", str(REPO_ROOT / "data/tokenizers/fineweb_1024_bpe.model")) + run_id = os.environ.get("RUN_ID", "train") + seed = int(os.environ.get("SEED", 1337)) + + # Validation cadence and batch size. Validation always uses the full fineweb_val split. + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 262_144)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 10)) + val_token_limit = int(os.environ.get("VAL_TOKEN_LIMIT", 1_048_576)) + + # Training length. 
+ iterations = int(os.environ.get("ITERATIONS", 40)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 1200)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 0)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 16_384)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 256)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + enable_compile = bool(int(os.environ.get("ENABLE_COMPILE", "0"))) + + # Model shape. + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers = int(os.environ.get("NUM_LAYERS", 9)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = int(os.environ.get("MLP_MULT", 2)) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + + # Optimizer hyperparameters. + embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.03)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.02)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.02)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.0)) + + # Export / compression. 
+ quant_scheme = os.environ.get("QUANT_SCHEME", "int5_odd") + compressor = os.environ.get("COMPRESSOR", "zlib") + compressor_level = int(os.environ.get("COMPRESSOR_LEVEL", 9)) + + # Entropy-aware rate regularization. + rate_lambda = float(os.environ.get("RATE_LAMBDA", 0.002)) + scale_lambda = float(os.environ.get("SCALE_LAMBDA", 0.0002)) + rate_temperature = float(os.environ.get("RATE_TEMPERATURE", 6.0)) + rate_warmup_start_fraction = float(os.environ.get("RATE_WARMUP_START_FRACTION", 0.1)) + rate_full_fraction = float(os.environ.get("RATE_FULL_FRACTION", 0.4)) + rate_target_tensors = os.environ.get("RATE_TARGET_TENSORS", "mlp_only") + + # Structured MLP. + structured_mlp = bool(int(os.environ.get("STRUCTURED_MLP", "1"))) + structured_linear_type = os.environ.get("STRUCTURED_LINEAR_TYPE", "btt_mlp") + btt_rank = int(os.environ.get("BTT_RANK", 32)) + btt_in_shape = os.environ.get("BTT_IN_SHAPE", "") + btt_out_shape = os.environ.get("BTT_OUT_SHAPE", "") + btt_init = os.environ.get("BTT_INIT", "mup") + btt_init_gain = float(os.environ.get("BTT_INIT_GAIN", 1.0)) + compile_structured_mlp = bool(int(os.environ.get("COMPILE_STRUCTURED_MLP", "1"))) + +# ----------------------------- +# MUON OPTIMIZER +# ----------------------------- +# +# As borrowed from modded-nanogpt +# Background on Muon: https://kellerjordan.github.io/posts/muon/ + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + # Orthogonalize a 2D update matrix with a fast Newton-Schulz iteration. + # Muon uses this to normalize matrix-shaped gradients before applying them. 
+ a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X + + +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True): + super().__init__( + params, + dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov), + ) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + + for group in self.param_groups: + params = group["params"] + if not params: + continue + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + + total_params = sum(int(p.numel()) for p in params) + updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) + + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: + g = g.add(buf, alpha=momentum) + g = zeropower_via_newtonschulz5(g, steps=backend_steps) + # Scale correction from Muon reference implementations. 
+                    g *= max(1, g.size(0) / g.size(1)) ** 0.5
+                    updates_flat[curr : curr + p.numel()] = g.reshape(-1)
+                curr += p.numel()
+
+            if distributed:
+                dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM)
+
+            curr = 0
+            for p in params:
+                g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype)
+                p.add_(g, alpha=-lr)
+                curr += p.numel()
+
+        return loss
+
+
+# -----------------------------
+# TOKENIZER-AGNOSTIC EVALUATION SETUP
+# -----------------------------
+#
+# It's common for small models to have a large fraction of their parameters be embeddings, since the 2 * d_model * d_vocab embedding and unembedding parameters can be gigantic.
+# Instead of locking the tokenizer, we let you bring your own and calculate our validation metrics on the average compression of the validation set.
+# We calculate BPB (bits-per-byte) instead of validation loss, so we need methods to count the number of bytes per token in the tokenizer.
+# Note: Submissions that edit the tokenizer will be examined more carefully, since screwing this up might unjustly improve your score.
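The BPB conversion used by the validation loop below can be sanity-checked in isolation. A minimal sketch (the function name `bits_per_byte` is illustrative, not part of the script; inputs are the three totals the loop accumulates: summed token cross-entropy in nats, token count, and byte count):

```python
import math

def bits_per_byte(loss_nats_sum: float, token_count: float, byte_count: float) -> float:
    # Mean token cross-entropy in nats -> bits per token, then rescale by
    # tokens-per-byte so the metric is comparable across tokenizers.
    bits_per_token = (loss_nats_sum / token_count) / math.log(2.0)
    return bits_per_token * (token_count / byte_count)

# A loss of ln(2) nats per token on text averaging 4 bytes per token
# compresses at 0.25 bits per byte.
print(bits_per_byte(math.log(2.0) * 1000, 1000, 4000))  # -> 0.25
```

A tokenizer with longer tokens raises bytes-per-token and therefore lowers BPB at the same token loss, which is exactly why the metric is tokenizer-agnostic.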
+ +def build_sentencepiece_luts( + sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device +) -> tuple[Tensor, Tensor, Tensor]: + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("▁"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) + + +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*. 
+ tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] + + +def eval_val( + args: Hyperparameters, + model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + grad_accum_steps: int, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, +) -> tuple[float, float]: + # Validation computes two metrics: + # - val_loss: token cross-entropy (natural log) + # - val_bpb: tokenizer-agnostic compression metric used by the challenge + local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) + if local_batch_tokens < args.train_seq_len: + raise ValueError( + "VAL_BATCH_SIZE must provide at least one sequence per rank; " + f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " + f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}" + ) + local_batch_seqs = local_batch_tokens // args.train_seq_len + total_seqs = (val_tokens.numel() - 1) // args.train_seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + + model.eval() + with torch.inference_mode(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * args.train_seq_len + raw_end = batch_seq_end * args.train_seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, args.train_seq_len) + y = local[1:].reshape(-1, args.train_seq_len) + 
                maybe_mark_compile_step_begin(args.enable_compile or args.compile_structured_mlp)
+                with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True):
+                    batch_loss = model(x, y).detach()
+                batch_token_count = float(y.numel())
+                val_loss_sum += batch_loss.to(torch.float64) * batch_token_count
+                val_token_count += batch_token_count
+                prev_ids = x.reshape(-1)
+                tgt_ids = y.reshape(-1)
+                token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16)
+                token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16)
+                val_byte_count += token_bytes.to(torch.float64).sum()
+
+    if dist.is_available() and dist.is_initialized():
+        dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM)
+        dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM)
+        dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM)
+
+    val_loss = val_loss_sum / val_token_count
+    bits_per_token = val_loss.item() / math.log(2.0)
+    tokens_per_byte = val_token_count.item() / val_byte_count.item()
+    model.train()
+    return float(val_loss.item()), float(bits_per_token * tokens_per_byte)
+
+# -----------------------------
+# POST-TRAINING QUANTIZATION
+# -----------------------------
+#
+# It's silly to export our model, which is trained in bf16 and fp32, at that same precision.
+# Instead, we get approximately the same model (with a small accuracy hit) by quantizing it to int8 and zlib-compressing the result.
+# We can then decompress the model and run it in higher precision for evaluation, after coming in under the size limit.
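As a minimal illustration of the quantize/compress/decompress idea described above (the helper `int8_zlib_roundtrip` is hypothetical; the script's actual exporters below use per-row scales, percentile clipping, and base-5 packing for `int5_odd`):

```python
import zlib

import numpy as np

def int8_zlib_roundtrip(w: np.ndarray, level: int = 9) -> tuple[np.ndarray, int]:
    # Symmetric per-tensor quantization: map max|w| to 127, round, clip.
    scale = max(float(np.abs(w).max()), 1e-12) / 127.0
    q = np.clip(np.round(w / scale), -127, 127).astype(np.int8)
    blob = zlib.compress(q.tobytes(), level)  # the exported payload
    # Decompress and dequantize back to float for evaluation.
    deq = np.frombuffer(zlib.decompress(blob), dtype=np.int8).reshape(w.shape)
    return deq.astype(np.float32) * scale, len(blob)

rng = np.random.default_rng(0)
w = rng.standard_normal((256, 512)).astype(np.float32)
w_hat, payload_bytes = int8_zlib_roundtrip(w)
# Worst-case elementwise error is half a quantization step (scale / 2),
# while the payload shrinks well below the 4 bytes/param fp32 baseline.
```

The same structure extends to the 5-level `int5_odd` grid by replacing 127 with 2 and packing three base-5 codes per byte, which is what the `pack_base5_codes` path below implements.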
+ +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights", + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 +INT5_ODD_LEVELS = (-2, -1, 0, 1, 2) + +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) + + +def maybe_mark_compile_step_begin(enabled: bool) -> None: + if not enabled: + return + mark_step = getattr(torch.compiler, "cudagraph_mark_step_begin", None) + if mark_step is not None: + mark_step() + + +def parse_factor_shape(spec: str, dim: int) -> tuple[int, int]: + if spec: + vals = tuple(int(v.strip()) for v in spec.split(",") if v.strip()) + if len(vals) != 2 or vals[0] * vals[1] != dim: + raise ValueError(f"Invalid factor shape '{spec}' for dim={dim}") + return vals + for a in range(int(math.isqrt(dim)), 0, -1): + if dim % a == 0: + return a, dim // a + return 1, dim + +def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: + if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): + return t.float().contiguous() + if t.dtype in {torch.float32, torch.bfloat16}: + passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") + return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() + return t + +def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + # Matrices get one scale per row, which usually tracks output-channel + # ranges much better than a single tensor-wide scale. 
+ clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) + q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() + return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + + # Vectors / scalars use a simpler per-tensor scale. + clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() + return q, scale + + +def pack_base5_codes(q: Tensor) -> tuple[bytes, int]: + vals = q.detach().to("cpu", dtype=torch.uint8).reshape(-1).numpy() + n_vals = int(vals.size) + pad = (3 - (n_vals % 3)) % 3 + if pad: + vals = np.pad(vals, (0, pad)) + groups = vals.reshape(-1, 3).astype(np.uint8) + packed = groups[:, 0] + 5 * groups[:, 1] + 25 * groups[:, 2] + return packed.tobytes(), n_vals + + +def unpack_base5_codes(data: bytes, n_vals: int) -> Tensor: + packed = np.frombuffer(data, dtype=np.uint8).astype(np.int16) + out = np.zeros((packed.size, 3), dtype=np.uint8) + for i in range(3): + out[:, i] = packed % 5 + packed //= 5 + flat = out.reshape(-1)[:n_vals] + return torch.from_numpy(flat.astype(np.uint8, copy=False)) + + +def quantize_float_tensor_int5_odd(t: Tensor) -> tuple[Tensor, Tensor]: + if t.ndim != 2: + raise ValueError("int5_odd quantization only supports 2D tensors") + t32 = t.float() + clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clip_abs = clip_abs.clamp_min(1e-6) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 
2.0).clamp_min(1.0 / 2.0e6) + q = torch.clamp(torch.round(clipped / scale[:, None]), -2, 2).to(torch.int8).contiguous() + return (q + 2).to(torch.uint8).contiguous(), scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + +def quantize_state_dict_int8(state_dict: dict[str, Tensor]): + # Single supported clean-script export format: + # - per-row int8 for 2D float tensors + # - per-tensor int8 for other float tensors + # - exact passthrough for non-floats + # - passthrough for small float tensors, stored as fp16 to save bytes + quantized: dict[str, Tensor] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + qmeta: dict[str, dict[str, object]] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + + # Small float tensors are cheap enough to keep directly. We still downcast + # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size. 
+ if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + + stats["num_float_tensors"] += 1 + q, s = quantize_float_tensor(t) + if s.ndim > 0: + qmeta[name] = {"scheme": "per_row", "axis": 0} + quantized[name] = q + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) + + obj: dict[str, object] = { + "__quant_format__": "int8_clean_per_row_v1", + "quantized": quantized, + "scales": scales, + "dtypes": dtypes, + "passthrough": passthrough, + } + if qmeta: + obj["qmeta"] = qmeta + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats + + +def quantize_state_dict_int5_odd(state_dict: dict[str, Tensor]): + packed: dict[str, bytes] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + shapes: dict[str, tuple[int, ...]] = {} + orig_shapes: dict[str, tuple[int, ...]] = {} + n_codes: dict[str, int] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + + quantize_this = t.ndim >= 2 and (t.numel() > INT8_KEEP_FLOAT_MAX_NUMEL or ".mlp." 
in name) + if not quantize_this: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + + stats["num_float_tensors"] += 1 + t2 = t.reshape(t.shape[0], -1) if t.ndim > 2 else t + q, s = quantize_float_tensor_int5_odd(t2) + packed_bytes, count = pack_base5_codes(q) + packed[name] = packed_bytes + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + shapes[name] = tuple(int(v) for v in t2.shape) + orig_shapes[name] = tuple(int(v) for v in t.shape) + n_codes[name] = count + stats["int8_payload_bytes"] += len(packed_bytes) + tensor_nbytes(s) + + obj: dict[str, object] = { + "__quant_format__": "int5_odd_clean_per_row_v1", + "packed": packed, + "scales": scales, + "dtypes": dtypes, + "shapes": shapes, + "orig_shapes": orig_shapes, + "n_codes": n_codes, + "passthrough": passthrough, + } + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats + +def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + qmeta = obj.get("qmeta", {}) + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, q in obj["quantized"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + s = obj["scales"][name] + if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: + s = s.to(dtype=torch.float32) + # Broadcast the saved row scale back across trailing dimensions. + out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() + else: + scale = float(s.item()) + out[name] = (q.float() * scale).to(dtype=dtype).contiguous() + for name, t in obj["passthrough"].items(): + # Restore small tensors, undoing the temporary fp16 storage cast if needed. 
+ out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out + + +def dequantize_state_dict_int5_odd(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, packed in obj["packed"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + shape = tuple(int(v) for v in obj["shapes"][name]) + orig_shape = tuple(int(v) for v in obj.get("orig_shapes", {}).get(name, shape)) + q = unpack_base5_codes(packed, int(obj["n_codes"][name])).to(torch.float32).reshape(shape) - 2.0 + s = obj["scales"][name].to(dtype=torch.float32) + deq = (q * s.view(shape[0], *([1] * (len(shape) - 1)))).to(dtype=dtype) + out[name] = deq.reshape(orig_shape).contiguous() + for name, t in obj["passthrough"].items(): + out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out + + +def quantize_state_dict(state_dict: dict[str, Tensor], quant_scheme: str): + if quant_scheme == "int8": + return quantize_state_dict_int8(state_dict) + if quant_scheme == "int5_odd": + return quantize_state_dict_int5_odd(state_dict) + raise ValueError(f"Unsupported QUANT_SCHEME={quant_scheme}") + + +def dequantize_state_dict(obj: dict[str, object]) -> dict[str, Tensor]: + quant_format = obj.get("__quant_format__") + if quant_format == "int8_clean_per_row_v1": + return dequantize_state_dict_int8(obj) + if quant_format == "int5_odd_clean_per_row_v1": + return dequantize_state_dict_int5_odd(obj) + raise ValueError(f"Unsupported quant format {quant_format}") + + +def compress_blob(data: bytes, compressor: str, level: int) -> bytes: + if compressor == "zlib": + return zlib.compress(data, level=max(0, 
min(level, 9))) + if compressor == "zstd": + if not _HAS_ZSTD: + raise RuntimeError("COMPRESSOR=zstd requires zstandard to be installed") + return zstd.ZstdCompressor(level=level).compress(data) + raise ValueError(f"Unsupported COMPRESSOR={compressor}") + + +def decompress_blob(data: bytes, compressor: str) -> bytes: + if compressor == "zlib": + return zlib.decompress(data) + if compressor == "zstd": + if not _HAS_ZSTD: + raise RuntimeError("COMPRESSOR=zstd requires zstandard to be installed") + return zstd.ZstdDecompressor().decompress(data) + raise ValueError(f"Unsupported COMPRESSOR={compressor}") + + +def should_rate_regularize_tensor(name: str, p: Tensor, target: str) -> bool: + if not p.requires_grad or not p.is_floating_point() or p.ndim < 2: + return False + if any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS): + return False + if target == "mlp_only": + return ".mlp." in name + if target == "all_matrices": + return p.numel() > INT8_KEEP_FLOAT_MAX_NUMEL or ".mlp." in name + raise ValueError(f"Unsupported RATE_TARGET_TENSORS={target}") + + +def estimate_tensor_rate_proxy(t: Tensor, temperature: float) -> tuple[Tensor, Tensor, Tensor]: + t2 = t.float() + if t2.ndim > 2: + t2 = t2.reshape(t2.shape[0], -1) + elif t2.ndim == 1: + t2 = t2.reshape(1, -1) + row_scale = t2.abs().mean(dim=1, keepdim=True).clamp_min(1e-6) + normalized = torch.clamp(t2 / row_scale, -2.5, 2.5) + centers = t2.new_tensor(INT5_ODD_LEVELS) + logits = -temperature * (normalized.unsqueeze(-1) - centers) ** 2 + probs = torch.softmax(logits, dim=-1) + hist = probs.mean(dim=(0, 1)) + entropy_nats = -(hist * (hist + 1e-8).log()).sum() + entropy_bits = entropy_nats / math.log(2.0) + scale_proxy = row_scale.mean() + est_bits = entropy_bits * float(t2.numel()) + return entropy_bits, scale_proxy, est_bits + + +# ----------------------------- +# DATA LOADING +# ----------------------------- + +def load_data_shard(file: Path) -> Tensor: + header_bytes = 256 * np.dtype(" None: + 
self.file_idx = (self.file_idx + 1) % len(self.files) + self.tokens = load_data_shard(self.files[self.file_idx]) + self.pos = 0 + + def take(self, n: int) -> Tensor: + chunks: list[Tensor] = [] + remaining = n + while remaining > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: + self._advance_file() + continue + k = min(remaining, avail) + chunks.append(self.tokens[self.pos : self.pos + k]) + self.pos += k + remaining -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) + + +class DistributedTokenLoader: + # Each call consumes a contiguous chunk from the shared token stream, then slices out + # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting. + def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank = rank + self.world_size = world_size + self.device = device + self.stream = TokenStream(pattern) + + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + chunk = self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start : start + per_rank_span].to(dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + +# ----------------------------- +# TRANSFORMER MODULES +# ----------------------------- + +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + + +class CastedLinear(nn.Linear): + # Keep weights in fp32 for optimizer/state quality, cast at matmul time for bf16 compute. 
+ def forward(self, x: Tensor) -> Tensor: + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, self.weight.to(x.dtype), bias) + + +class StructuredLinear(nn.Module): + # Lightweight TT-style structured matrix used only in the MLP stack. + def __init__( + self, + in_features: int, + out_features: int, + bias: bool, + linear_type: str, + btt_rank: int, + in_shape_spec: str, + out_shape_spec: str, + btt_init: str, + btt_init_gain: float, + ): + super().__init__() + self.in_features = in_features + self.out_features = out_features + self.linear_type = linear_type + self.rank = max(int(btt_rank), 1) + self.btt_init = btt_init + self.btt_init_gain = btt_init_gain + if linear_type == "dense": + self.impl = CastedLinear(in_features, out_features, bias=bias) + self.register_parameter("core1", None) + self.register_parameter("core2", None) + return + if linear_type != "btt_mlp": + raise ValueError(f"Unsupported STRUCTURED_LINEAR_TYPE={linear_type}") + self.impl = None + self.in_shape = parse_factor_shape(in_shape_spec, in_features) + self.out_shape = parse_factor_shape(out_shape_spec, out_features) + n1, n2 = self.in_shape + m1, m2 = self.out_shape + self.core1 = nn.Parameter(torch.empty(n1, m1, self.rank)) + self.core2 = nn.Parameter(torch.empty(self.rank, n2, m2)) + if bias: + self.bias = nn.Parameter(torch.zeros(out_features)) + else: + self.register_parameter("bias", None) + self.reset_parameters() + + def reset_parameters(self) -> None: + if self.impl is not None: + self.impl.reset_parameters() + return + n1, n2 = self.in_shape + if self.btt_init == "mup": + # muP-inspired scaling: keep the per-rank branch update stable while + # preserving O(1) output variance under the sum over TT ranks. 
+ core1_std = self.btt_init_gain / math.sqrt(max(n1, 1)) + core2_std = self.btt_init_gain / math.sqrt(max(self.rank * n2, 1)) + elif self.btt_init == "xavier": + std = self.btt_init_gain / math.sqrt(max(self.in_features, 1)) + core1_std = std / math.sqrt(self.rank) + core2_std = std + else: + raise ValueError(f"Unsupported BTT_INIT={self.btt_init}") + nn.init.normal_(self.core1, mean=0.0, std=core1_std) + nn.init.normal_(self.core2, mean=0.0, std=core2_std) + if self.bias is not None: + nn.init.zeros_(self.bias) + + def forward(self, x: Tensor) -> Tensor: + if self.impl is not None: + return self.impl(x) + x_dtype = x.dtype + orig_shape = x.shape[:-1] + x2 = x.reshape(-1, self.in_shape[0], self.in_shape[1]) + core1 = self.core1.to(dtype=x_dtype) + core2 = self.core2.to(dtype=x_dtype) + y = x2.new_zeros((x2.shape[0], self.out_shape[0], self.out_shape[1])) + for r in range(self.rank): + left = torch.einsum("buv,um->bmv", x2, core1[:, :, r]) + y = y + torch.einsum("bmv,vn->bmn", left, core2[r]) + y = y.reshape(*orig_shape, self.out_features) + if self.bias is not None: + y = y + self.bias.to(dtype=y.dtype) + return y + + +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + # Keep small/control parameters in fp32 even when the model body runs in bf16. + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() + + +class Rotary(nn.Module): + # Caches cos/sin tables per sequence length on the current device. 
+ def __init__(self, dim: int, base: float = 10000.0): + super().__init__() + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached != seq_len + or self._cos_cached.device != device + ): + t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) + freqs = torch.outer(t, self.inv_freq.to(device)) + self._cos_cached = freqs.cos()[None, None, :, :] + self._sin_cached = freqs.sin()[None, None, :, :] + self._seq_len_cached = seq_len + cos = self._cos_cached.to(dtype=dtype) + sin = self._sin_cached.to(dtype=dtype) + if not torch.is_inference_mode_enabled(): + cos = cos.clone() + sin = sin.clone() + return cos, sin + + +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + + +class CausalSelfAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + kv_dim = self.num_kv_heads * self.head_dim + self.c_q = CastedLinear(dim, dim, bias=False) + self.c_k = CastedLinear(dim, kv_dim, bias=False) + self.c_v = CastedLinear(dim, kv_dim, bias=False) + self.proj 
= CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rotary = Rotary(self.head_dim, base=rope_base) + + def forward(self, x: Tensor) -> Tensor: + bsz, seqlen, dim = x.shape + q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) + k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin) + k = apply_rotary_emb(k, cos, sin) + q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] + y = F.scaled_dot_product_attention( + q, + k, + v, + attn_mask=None, + is_causal=True, + enable_gqa=(self.num_kv_heads != self.num_heads), + ) + y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) + return self.proj(y) + + +class MLP(nn.Module): + # relu^2 MLP from the original modded-nanogpt setup + def __init__( + self, + dim: int, + mlp_mult: int, + structured_mlp: bool, + structured_linear_type: str, + btt_rank: int, + btt_in_shape: str, + btt_out_shape: str, + btt_init: str, + btt_init_gain: float, + compile_structured_mlp: bool, + ): + super().__init__() + hidden = mlp_mult * dim + if structured_mlp: + self.fc = StructuredLinear( + dim, + hidden, + bias=False, + linear_type=structured_linear_type, + btt_rank=btt_rank, + in_shape_spec=btt_in_shape, + out_shape_spec=btt_out_shape, + btt_init=btt_init, + btt_init_gain=btt_init_gain, + ) + self.proj = StructuredLinear( + hidden, + dim, + bias=False, + linear_type=structured_linear_type, + btt_rank=btt_rank, + in_shape_spec=btt_out_shape, + out_shape_spec=btt_in_shape, + btt_init=btt_init, + btt_init_gain=btt_init_gain, + ) + if compile_structured_mlp: + compile_options = 
dict(torch._inductor.list_mode_options("reduce-overhead")) + compile_options["triton.cudagraphs"] = False + self.fc = torch.compile( + self.fc, + options=compile_options, + fullgraph=False, + dynamic=False, + ) + self.proj = torch.compile( + self.proj, + options=compile_options, + fullgraph=False, + dynamic=False, + ) + else: + self.fc = CastedLinear(dim, hidden, bias=False) + self.proj = CastedLinear(hidden, dim, bias=False) + self.proj._zero_init = True + + def forward(self, x: Tensor) -> Tensor: + x = torch.relu(self.fc(x)) + return self.proj(x.square()) + + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + rope_base: float, + qk_gain_init: float, + structured_mlp: bool, + structured_linear_type: str, + btt_rank: int, + btt_in_shape: str, + btt_out_shape: str, + btt_init: str, + btt_init_gain: float, + compile_structured_mlp: bool, + ): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) + self.mlp = MLP( + dim, + mlp_mult, + structured_mlp, + structured_linear_type, + btt_rank, + btt_in_shape, + btt_out_shape, + btt_init, + btt_init_gain, + compile_structured_mlp, + ) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + + def forward(self, x: Tensor, x0: Tensor) -> Tensor: + mix = self.resid_mix.to(dtype=x.dtype) + x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + attn_out = self.attn(self.attn_norm(x)) + x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out + x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) + return x + + +class GPT(nn.Module): + def __init__( + self, + vocab_size: int, + num_layers: int, + model_dim: int, + num_heads: int, + 
num_kv_heads: int, + mlp_mult: int, + tie_embeddings: bool, + tied_embed_init_std: float, + logit_softcap: float, + rope_base: float, + qk_gain_init: float, + structured_mlp: bool, + structured_linear_type: str, + btt_rank: int, + btt_in_shape: str, + btt_out_shape: str, + btt_init: str, + btt_init_gain: float, + compile_structured_mlp: bool, + rate_target_tensors: str, + rate_temperature: float, + ): + super().__init__() + if logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") + self.tie_embeddings = tie_embeddings + self.tied_embed_init_std = tied_embed_init_std + self.logit_softcap = logit_softcap + self.rate_target_tensors = rate_target_tensors + self.rate_temperature = rate_temperature + self.tok_emb = nn.Embedding(vocab_size, model_dim) + self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + self.blocks = nn.ModuleList( + [ + Block( + model_dim, + num_heads, + num_kv_heads, + mlp_mult, + rope_base, + qk_gain_init, + structured_mlp, + structured_linear_type, + btt_rank, + btt_in_shape, + btt_out_shape, + btt_init, + btt_init_gain, + compile_structured_mlp, + ) + for i in range(num_layers) + ] + ) + self.final_norm = RMSNorm() + self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) + if self.lm_head is not None: + self.lm_head._zero_init = True + self._init_weights() + + def _init_weights(self) -> None: + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + for module in self.modules(): + if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + if isinstance(module, StructuredLinear) and getattr(module, "_zero_init", False): + if module.impl is 
not None: + nn.init.zeros_(module.impl.weight) + else: + nn.init.zeros_(module.core2) + + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + x = self.tok_emb(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x0 = x + skips: list[Tensor] = [] + + # First half stores skips; second half reuses them in reverse order. + for i in range(self.num_encoder_layers): + x = self.blocks[i](x, x0) + skips.append(x) + for i in range(self.num_decoder_layers): + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + x = self.blocks[self.num_encoder_layers + i](x, x0) + + x = self.final_norm(x).reshape(-1, x.size(-1)) + targets = target_ids.reshape(-1) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(x) + logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + return F.cross_entropy(logits.float(), targets, reduction="mean") + + def rate_regularizer(self) -> tuple[Tensor, Tensor, Tensor]: + total_symbols = 0.0 + weighted_rate = None + weighted_scale = None + estimated_bits = None + for name, param in self.named_parameters(): + if not should_rate_regularize_tensor(name, param, self.rate_target_tensors): + continue + rate_proxy, scale_proxy, tensor_bits = estimate_tensor_rate_proxy(param, self.rate_temperature) + weight = float(param.numel()) + total_symbols += weight + if weighted_rate is None: + weighted_rate = rate_proxy * weight + weighted_scale = scale_proxy * weight + estimated_bits = tensor_bits + else: + weighted_rate = weighted_rate + rate_proxy * weight + weighted_scale = weighted_scale + scale_proxy * weight + estimated_bits = estimated_bits + tensor_bits + if weighted_rate is None or weighted_scale is None or estimated_bits is None: + zero = self.tok_emb.weight.new_zeros(()) + return zero, zero, zero + denom = max(total_symbols, 1.0) + return 
weighted_rate / denom, weighted_scale / denom, estimated_bits + + +# ----------------------------- +# TRAINING +# ----------------------------- + +def main() -> None: + global zeropower_via_newtonschulz5 + + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + if args.enable_compile: + zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) + + # ----------------------------- + # DISTRIBUTED + CUDA SETUP + # ----------------------------- + + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if 8 % world_size != 0: + raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") + grad_accum_steps = 8 // world_size + grad_scale = 1.0 / grad_accum_steps + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 0 + + # Fast math knobs + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + + logfile = None + if master_process: + if args.run_id == "train": + logfile = "train.log" + else: + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" + print(logfile) + + def log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + 
+ print(msg, file=f) + + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch {torch.__version__}", console=False) + log0( + subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, + console=False, + ) + log0("=" * 100, console=False) + + # ----------------------------- + # TOKENIZER + VALIDATION METRIC SETUP + # ----------------------------- + + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script is only set up for SentencePiece .model files: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) + if args.val_token_limit > 0: + usable = min((args.val_token_limit // args.train_seq_len) * args.train_seq_len, val_tokens.numel() - 1) + if usable <= 0: + raise ValueError(f"VAL_TOKEN_LIMIT={args.val_token_limit} is too small for TRAIN_SEQ_LEN={args.train_seq_len}") + val_tokens = val_tokens[: usable + 1].contiguous() + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + if args.val_token_limit > 0: + log0(f"val_token_limit:{args.val_token_limit}") + + # 
----------------------------- + # MODEL + OPTIMIZER SETUP + # ----------------------------- + + base_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, + structured_mlp=args.structured_mlp, + structured_linear_type=args.structured_linear_type if args.structured_mlp else "dense", + btt_rank=args.btt_rank, + btt_in_shape=args.btt_in_shape, + btt_out_shape=args.btt_out_shape, + btt_init=args.btt_init, + btt_init_gain=args.btt_init_gain, + compile_structured_mlp=args.compile_structured_mlp, + rate_target_tensors=args.rate_target_tensors, + rate_temperature=args.rate_temperature, + ).to(device).bfloat16() + for module in base_model.modules(): + if isinstance(module, (CastedLinear, StructuredLinear)): + module.float() + restore_low_dim_params_to_fp32(base_model) + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) if args.enable_compile else base_model + model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + + # Optimizer split: + # - token embedding (Adam) uses EMBED_LR + # - untied lm_head (Adam) uses HEAD_LR + # - matrix params in transformer blocks use MATRIX_LR via Muon + # - vectors/scalars use SCALAR_LR via Adam + block_named_params = list(base_model.blocks.named_parameters()) + matrix_params = [ + p + for name, p in block_named_params + if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + scalar_params = [ + p + for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.skip_weights.numel() > 0: + 
scalar_params.append(base_model.skip_weights) + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + optimizer_tok = torch.optim.Adam( + [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizer_muon = Muon( + matrix_params, + lr=args.matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + ) + for group in optimizer_muon.param_groups: + group["base_lr"] = args.matrix_lr + optimizer_scalar = torch.optim.Adam( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers.insert(1, optimizer_head) + + n_params = sum(p.numel() for p in base_model.parameters()) + log0(f"model_params:{n_params}") + log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") + log0(f"enable_compile:{args.enable_compile}") + log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") + log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") + log0( + f"quant_scheme:{args.quant_scheme} compressor:{args.compressor} level:{args.compressor_level} " + f"structured_mlp:{args.structured_mlp} structured_linear_type:{args.structured_linear_type} " + f"btt_rank:{args.btt_rank} btt_init:{args.btt_init} compile_structured_mlp:{args.compile_structured_mlp}" + ) + log0( + f"rate_lambda:{args.rate_lambda} scale_lambda:{args.scale_lambda} " + f"rate_warmup_start_fraction:{args.rate_warmup_start_fraction} rate_full_fraction:{args.rate_full_fraction} " + 
f"rate_target_tensors:{args.rate_target_tensors}" + ) + log0( + f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " + f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " + f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" + ) + log0( + f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " + f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " + f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log0(f"seed:{args.seed}") + + # ----------------------------- + # DATA LOADER & MODEL WARMUP + # ----------------------------- + + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + + def lr_mul(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + + def rate_schedule(step: int) -> float: + if args.rate_lambda <= 0.0 and args.scale_lambda <= 0.0: + return 0.0 + if args.iterations <= 0: + return 1.0 + frac = step / max(args.iterations, 1) + start = args.rate_warmup_start_fraction + full = max(args.rate_full_fraction, start + 1e-6) + if frac <= start: + return 0.0 + if frac >= full: + return 1.0 + return (frac - start) / (full - start) + + # Warmup primes the compiled forward/backward/optimizer paths, then we restore the + # initial weights/optimizer state so measured training starts from 
the true init. + if args.warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(args.warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + maybe_mark_compile_step_begin(args.enable_compile or args.compile_structured_mlp) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + for opt in optimizers: + opt.step() + zero_grad_all() + if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: + log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + if distributed: + model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + # ----------------------------- + # MAIN TRAINING LOOP + # ----------------------------- + + training_time_ms = 0.0 + stop_after_step: int | None = None + torch.cuda.synchronize() + t0 = time.perf_counter() + + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + 
base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + log0( + f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" + ) + break + + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + zero_grad_all() + train_loss = torch.zeros((), device=device) + train_ce_loss = torch.zeros((), device=device) + train_rate_proxy = torch.zeros((), device=device) + train_scale_proxy = torch.zeros((), device=device) + train_est_bits = torch.zeros((), device=device) + rate_mul = rate_schedule(step) + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + maybe_mark_compile_step_begin(args.enable_compile or args.compile_structured_mlp) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + ce_loss = model(x, y) + total_loss = ce_loss + if rate_mul > 0.0: + rate_proxy, scale_proxy, est_bits = base_model.rate_regularizer() + total_loss = total_loss + rate_mul * (args.rate_lambda * rate_proxy + args.scale_lambda * scale_proxy) + train_rate_proxy += rate_proxy.detach() + train_scale_proxy += scale_proxy.detach() + train_est_bits += est_bits.detach() + train_ce_loss += ce_loss.detach() + train_loss += total_loss.detach() + (total_loss * grad_scale).backward() + train_loss /= grad_accum_steps + train_ce_loss /= grad_accum_steps + train_rate_proxy /= grad_accum_steps + train_scale_proxy /= grad_accum_steps + train_est_bits /= grad_accum_steps + + frac = min(step / 
args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + for opt in optimizers: + opt.step() + zero_grad_all() + + step += 1 + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + should_log_train = ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + log0( + f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} ce_loss:{train_ce_loss.item():.4f} " + f"rate_mul:{rate_mul:.3f} rate_proxy_bits:{train_rate_proxy.item():.4f} " + f"scale_proxy:{train_scale_proxy.item():.6f} est_model_bits:{train_est_bits.item():.0f} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" + ) + + # Needed to sync whether we've reached the wallclock cap. 
+ reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + + log0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + ) + + # ----------------------------- + # SERIALIZATION + ROUNDTRIP VALIDATION + # ----------------------------- + # Save the raw state (useful for debugging/loading in PyTorch directly), then always produce + # the compressed quantized artifact (args.quant_scheme + args.compressor, int5_odd+zlib in this run) + # and validate the round-tripped weights. + + if master_process: + torch.save(base_model.state_dict(), "final_model.pt") + model_bytes = os.path.getsize("final_model.pt") + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model: {model_bytes} bytes") + log0(f"Code size: {code_bytes} bytes") + log0(f"Total submission size: {model_bytes + code_bytes} bytes") + + quant_obj, quant_stats = quantize_state_dict(base_model.state_dict(), args.quant_scheme) + quant_buf = io.BytesIO() + torch.save(quant_obj, quant_buf) + quant_raw = quant_buf.getvalue() + quant_blob = compress_blob(quant_raw, args.compressor, args.compressor_level) + quant_raw_bytes = len(quant_raw) + quant_filename = f"final_model.{args.quant_scheme}.{args.compressor}.ptc" + if master_process: + with open(quant_filename, "wb") as f: + f.write(quant_blob) + quant_file_bytes = os.path.getsize(quant_filename) + code_bytes = len(code.encode("utf-8")) + ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1) + log0( + f"Serialized model {args.quant_scheme}+{args.compressor}: {quant_file_bytes} bytes " + f"(payload:{quant_stats['int8_payload_bytes']} raw_torch:{quant_raw_bytes} 
payload_ratio:{ratio:.2f}x)" + ) + log0(f"Total submission size {args.quant_scheme}+{args.compressor}: {quant_file_bytes + code_bytes} bytes") + + if distributed: + dist.barrier() + with open(quant_filename, "rb") as f: + quant_blob_disk = f.read() + quant_state = torch.load(io.BytesIO(decompress_blob(quant_blob_disk, args.compressor)), map_location="cpu") + base_model.load_state_dict(dequantize_state_dict(quant_state), strict=True) + torch.cuda.synchronize() + t_qeval = time.perf_counter() + q_val_loss, q_val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"final_{args.quant_scheme}_{args.compressor}_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" + ) + log0(f"final_{args.quant_scheme}_{args.compressor}_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + + if distributed: + dist.destroy_process_group() + + +if __name__ == "__main__": + main() + +==================================================================================================== +Running Python 3.10.12 (main, Mar 3 2026, 11:56:32) [GCC 11.4.0] +Running PyTorch 2.11.0+cu130 +Tue Mar 31 12:27:57 2026 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 590.48.01 Driver Version: 590.48.01 CUDA Version: 13.1 | ++-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. 
| +|=========================================+========================+======================| +| 0 NVIDIA GeForce RTX 3060 Off | 00000000:01:00.0 Off | N/A | +| 30% 35C P8 17W / 170W | 242MiB / 12288MiB | 5% Default | +| | | N/A | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| 0 N/A N/A 985 G /usr/lib/xorg/Xorg 117MiB | +| 0 N/A N/A 4909 C python3 104MiB | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=/home/hz/Desktop/parameter-golf-main/data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:1 +val_loader:shards pattern=/home/hz/Desktop/parameter-golf-main/data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:2048 +val_token_limit:2048 +model_params:94951616 +world_size:1 grad_accum_steps:8 +enable_compile:False +sdp_backends:cudnn=False flash=True mem_efficient=False math=False +attention_mode:gqa num_heads:8 num_kv_heads:4 +quant_scheme:int5_odd compressor:zlib level:9 structured_mlp:True structured_linear_type:btt_mlp btt_rank:1024 btt_init:mup compile_structured_mlp:False +rate_lambda:0.0002 scale_lambda:0.0002 rate_warmup_start_fraction:0.1 rate_full_fraction:0.4 rate_target_tensors:mlp_only +tie_embeddings:True embed_lr:0.03 head_lr:0.0 matrix_lr:0.02 scalar_lr:0.02 +train_batch_tokens:4096 train_seq_len:256 iterations:0 warmup_steps:0 max_wallclock_seconds:600.000 +seed:1337 +step:0/0 val_loss:6.9495 val_bpb:3.7634 train_time:0ms step_avg:0.02ms +peak memory allocated: 381 MiB reserved: 408 MiB 
+Serialized model: 378862427 bytes +Code size: 69179 bytes +Total submission size: 378931606 bytes +Serialized model int5_odd+zlib: 13365996 bytes (payload:32030475 raw_torch:32137147 payload_ratio:11.82x) +Total submission size int5_odd+zlib: 13435175 bytes +final_int5_odd_zlib_roundtrip val_loss:6.9562 val_bpb:3.7671 eval_time:85577ms +final_int5_odd_zlib_roundtrip_exact val_loss:6.95620948 val_bpb:3.76706212 diff --git a/records/track_non_record_16mb/2026-03-31_EntropyAware_Int5Odd_BTTMLP_3060/logs/cloud_prep_local.txt b/records/track_non_record_16mb/2026-03-31_EntropyAware_Int5Odd_BTTMLP_3060/logs/cloud_prep_local.txt new file mode 100644 index 000000000..95d753dc7 --- /dev/null +++ b/records/track_non_record_16mb/2026-03-31_EntropyAware_Int5Odd_BTTMLP_3060/logs/cloud_prep_local.txt @@ -0,0 +1,1762 @@ +""" +The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good jumping-off points for new participants, not SOTA configs. We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder. + +Hard stop: To keep things readable for newcomers, let's make sure `train_gpt.py` and `train_gpt_mlx.py` are never longer than 1500 lines. 
+""" + +from __future__ import annotations + +import copy +import glob +import io +import math +import os +import random +import subprocess +import sys +import time +import uuid +import zlib +from pathlib import Path + +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP + +try: + import zstandard as zstd + + _HAS_ZSTD = True +except ImportError: + zstd = None + _HAS_ZSTD = False + +REPO_ROOT = Path(__file__).resolve().parents[3] + +# ----------------------------- +# HYPERPARAMETERS +# ----------------------------- +# Default Simple Baseline run: +# - 9 transformer blocks at width 512 +# - 8 attention heads with 4 KV heads (GQA) and 2x MLP expansion +# - vocab size 1024, sequence length 1024, tied embeddings +# - 524,288 train tokens per step for 20,000 iterations with a ~10 minute cap + +class Hyperparameters: + # Data paths are shard globs produced by the existing preprocessing pipeline. + data_path = os.environ.get("DATA_PATH", str(REPO_ROOT / "data/datasets/fineweb10B_sp1024")) + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", str(REPO_ROOT / "data/tokenizers/fineweb_1024_bpe.model")) + run_id = os.environ.get("RUN_ID", "train") + seed = int(os.environ.get("SEED", 1337)) + + # Validation cadence and batch size. Validation always uses the full fineweb_val split. + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 262_144)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 10)) + val_token_limit = int(os.environ.get("VAL_TOKEN_LIMIT", 1_048_576)) + + # Training length. 
+ iterations = int(os.environ.get("ITERATIONS", 40)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 1200)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 1)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 16_384)) + train_microbatch_tokens = int(os.environ.get("TRAIN_MICROBATCH_TOKENS", 0)) + grad_accum_steps = int(os.environ.get("GRAD_ACCUM_STEPS", 0)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 256)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + enable_compile = bool(int(os.environ.get("ENABLE_COMPILE", "0"))) + + # Model shape. + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers = int(os.environ.get("NUM_LAYERS", 9)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = int(os.environ.get("MLP_MULT", 2)) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + + # Optimizer hyperparameters. 
+ embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.03)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.02)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.02)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.0)) + lr_scale_method = os.environ.get("LR_SCALE_METHOD", "none") + lr_reference_batch_tokens = int(os.environ.get("LR_REFERENCE_BATCH_TOKENS", 16_384)) + warmup_scale_method = os.environ.get("WARMUP_SCALE_METHOD", "none") + warmup_reference_batch_tokens = int(os.environ.get("WARMUP_REFERENCE_BATCH_TOKENS", 16_384)) + + # Export / compression. + quant_scheme = os.environ.get("QUANT_SCHEME", "int5_odd") + compressor = os.environ.get("COMPRESSOR", "zlib") + compressor_level = int(os.environ.get("COMPRESSOR_LEVEL", 9)) + + # Entropy-aware rate regularization. + rate_lambda = float(os.environ.get("RATE_LAMBDA", 0.00002)) + scale_lambda = float(os.environ.get("SCALE_LAMBDA", 0.0002)) + rate_temperature = float(os.environ.get("RATE_TEMPERATURE", 6.0)) + rate_warmup_start_fraction = float(os.environ.get("RATE_WARMUP_START_FRACTION", 0.1)) + rate_full_fraction = float(os.environ.get("RATE_FULL_FRACTION", 0.4)) + rate_target_tensors = os.environ.get("RATE_TARGET_TENSORS", "mlp_only") + + # Structured MLP. 
+ structured_mlp = bool(int(os.environ.get("STRUCTURED_MLP", "1"))) + structured_linear_type = os.environ.get("STRUCTURED_LINEAR_TYPE", "btt_mlp") + btt_rank = int(os.environ.get("BTT_RANK", 32)) + btt_in_shape = os.environ.get("BTT_IN_SHAPE", "") + btt_out_shape = os.environ.get("BTT_OUT_SHAPE", "") + btt_init = os.environ.get("BTT_INIT", "mup") + btt_init_gain = float(os.environ.get("BTT_INIT_GAIN", 1.0)) + compile_structured_mlp = bool(int(os.environ.get("COMPILE_STRUCTURED_MLP", "1"))) + +# ----------------------------- +# MUON OPTIMIZER +# ----------------------------- +# +# As borrowed from modded-nanogpt +# Background on Muon: https://kellerjordan.github.io/posts/muon/ + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + # Orthogonalize a 2D update matrix with a fast Newton-Schulz iteration. + # Muon uses this to normalize matrix-shaped gradients before applying them. + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X + + +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True): + super().__init__( + params, + dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov), + ) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + + for group in self.param_groups: + params = group["params"] + if not params: + continue + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + + total_params = 
sum(int(p.numel()) for p in params)
+            updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16)
+
+            curr = 0
+            for i, p in enumerate(params):
+                if i % world_size == rank and p.grad is not None:
+                    g = p.grad
+                    state = self.state[p]
+                    if "momentum_buffer" not in state:
+                        state["momentum_buffer"] = torch.zeros_like(g)
+                    buf = state["momentum_buffer"]
+                    buf.mul_(momentum).add_(g)
+                    if nesterov:
+                        g = g.add(buf, alpha=momentum)
+                    g = zeropower_via_newtonschulz5(g, steps=backend_steps)
+                    # Scale correction from Muon reference implementations.
+                    g *= max(1, g.size(0) / g.size(1)) ** 0.5
+                    updates_flat[curr : curr + p.numel()] = g.reshape(-1)
+                curr += p.numel()
+
+            if distributed:
+                dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM)
+
+            curr = 0
+            for p in params:
+                g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype)
+                p.add_(g, alpha=-lr)
+                curr += p.numel()
+
+        return loss
+
+
+# -----------------------------
+# TOKENIZER-AGNOSTIC EVALUATION SETUP
+# -----------------------------
+#
+# It's common for small models to have a large fraction of their parameters in embeddings, since the 2 * d_model * d_vocab embedding vectors can be gigantic.
+# Instead of locking the tokenizer, we let you bring your own and calculate our validation metrics on the average compression of the validation set.
+# We calculate BPB (bits-per-byte) instead of validation loss, so we need methods to count the number of bytes each token decodes to.
+# Note: Submissions that edit the tokenizer will be examined more carefully, since screwing this up might unjustly improve your score.
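The metric described above reduces to a single identity: bpb = (cross-entropy in bits per token) × (tokens per byte). A standalone toy sketch of that conversion, with invented counts (the byte total in the real script comes from the tokenizer byte lookup tables, not from these numbers):

```python
import math

# Toy BPB computation. All counts below are made up for illustration.
val_loss_nats = 6.95        # mean token cross-entropy (natural log)
token_count = 1_048_576     # validation tokens scored
byte_count = 1_935_000      # hypothetical UTF-8 byte total for those tokens

bits_per_token = val_loss_nats / math.log(2.0)
tokens_per_byte = token_count / byte_count
val_bpb = bits_per_token * tokens_per_byte
print(f"bits/token={bits_per_token:.3f} val_bpb={val_bpb:.3f}")
```

A tokenizer with longer average tokens lowers tokens-per-byte, so two models with identical token-level loss can have different BPB; that is exactly why the metric is tokenizer-agnostic.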
+ +def build_sentencepiece_luts( + sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device +) -> tuple[Tensor, Tensor, Tensor]: + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("▁"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) + + +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*. 
+ tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] + + +def eval_val( + args: Hyperparameters, + model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + grad_accum_steps: int, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, +) -> tuple[float, float]: + # Validation computes two metrics: + # - val_loss: token cross-entropy (natural log) + # - val_bpb: tokenizer-agnostic compression metric used by the challenge + local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) + if local_batch_tokens < args.train_seq_len: + raise ValueError( + "VAL_BATCH_SIZE must provide at least one sequence per rank; " + f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " + f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}" + ) + local_batch_seqs = local_batch_tokens // args.train_seq_len + total_seqs = (val_tokens.numel() - 1) // args.train_seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + + model.eval() + with torch.inference_mode(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * args.train_seq_len + raw_end = batch_seq_end * args.train_seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, args.train_seq_len) + y = local[1:].reshape(-1, args.train_seq_len) + 
maybe_mark_compile_step_begin(args.enable_compile or args.compile_structured_mlp)
+            with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True):
+                batch_loss = model(x, y).detach()
+            batch_token_count = float(y.numel())
+            val_loss_sum += batch_loss.to(torch.float64) * batch_token_count
+            val_token_count += batch_token_count
+            prev_ids = x.reshape(-1)
+            tgt_ids = y.reshape(-1)
+            token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16)
+            token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16)
+            val_byte_count += token_bytes.to(torch.float64).sum()
+
+    if dist.is_available() and dist.is_initialized():
+        dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM)
+        dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM)
+        dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM)
+
+    val_loss = val_loss_sum / val_token_count
+    bits_per_token = val_loss.item() / math.log(2.0)
+    tokens_per_byte = val_token_count.item() / val_byte_count.item()
+    model.train()
+    return float(val_loss.item()), float(bits_per_token * tokens_per_byte)
+
+# -----------------------------
+# POST-TRAINING QUANTIZATION
+# -----------------------------
+#
+# It's silly to export our model, which is trained in bf16 and fp32, at that same precision.
+# Instead, we get approximately the same model (with a small quality hit) by quantizing it to int8 and zlib-compressing the result.
+# We can then decompress the model and run it in higher precision for evaluation, after coming in under the size limit.
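For the `int5_odd` scheme used in this submission, much of the size win comes from a base-5 packing step before zlib: quantized levels `{-2..2}` are shifted to codes `{0..4}`, and three codes fit in one byte because 5³ = 125 ≤ 255. A self-contained sketch of that pack/unpack arithmetic (it mirrors the idea of the `pack_base5_codes`/`unpack_base5_codes` helpers but is independent of them):

```python
import numpy as np

# Pack: three base-5 codes per byte.
codes = np.array([0, 4, 2, 1, 3], dtype=np.uint8)   # e.g. levels [-2, 2, 0, -1, 1] shifted by +2
pad = (3 - codes.size % 3) % 3                      # pad to a multiple of 3
groups = np.pad(codes, (0, pad)).reshape(-1, 3)
packed = (groups[:, 0] + 5 * groups[:, 1] + 25 * groups[:, 2]).astype(np.uint8)
print(packed.tolist())  # [70, 16] -> 5 codes stored in 2 bytes

# Unpack: peel the base-5 digits back off and drop the padding.
vals = packed.astype(np.int16)
digits = np.zeros((vals.size, 3), dtype=np.uint8)
for i in range(3):
    digits[:, i] = vals % 5
    vals //= 5
roundtrip = digits.reshape(-1)[: codes.size]
assert (roundtrip == codes).all()
```

This alone gives ~1.67 bytes saved per code versus int8 storage; the zlib pass then exploits whatever statistical redundancy remains in the packed stream.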
+ +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights", + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 +INT5_ODD_LEVELS = (-2, -1, 0, 1, 2) + +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) + + +def maybe_mark_compile_step_begin(enabled: bool) -> None: + if not enabled: + return + mark_step = getattr(torch.compiler, "cudagraph_mark_step_begin", None) + if mark_step is not None: + mark_step() + + +def scale_from_batch_multiplier(value: float, multiplier: float, method: str) -> float: + if method == "none": + return value + if method == "sqrt": + return value * math.sqrt(multiplier) + if method == "linear": + return value * multiplier + raise ValueError(f"Unsupported scaling method: {method}") + + +def scaled_warmup_steps(base_steps: int, multiplier: float, method: str) -> int: + if base_steps <= 0: + return 0 + scaled = scale_from_batch_multiplier(float(base_steps), multiplier, method) + return max(int(math.ceil(scaled)), 1) + + +def parse_factor_shape(spec: str, dim: int) -> tuple[int, int]: + if spec: + vals = tuple(int(v.strip()) for v in spec.split(",") if v.strip()) + if len(vals) != 2 or vals[0] * vals[1] != dim: + raise ValueError(f"Invalid factor shape '{spec}' for dim={dim}") + return vals + for a in range(int(math.isqrt(dim)), 0, -1): + if dim % a == 0: + return a, dim // a + return 1, dim + +def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> 
Tensor: + if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): + return t.float().contiguous() + if t.dtype in {torch.float32, torch.bfloat16}: + passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") + return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() + return t + +def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + # Matrices get one scale per row, which usually tracks output-channel + # ranges much better than a single tensor-wide scale. + clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) + q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() + return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + + # Vectors / scalars use a simpler per-tensor scale. 
+ clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() + return q, scale + + +def pack_base5_codes(q: Tensor) -> tuple[bytes, int]: + vals = q.detach().to("cpu", dtype=torch.uint8).reshape(-1).numpy() + n_vals = int(vals.size) + pad = (3 - (n_vals % 3)) % 3 + if pad: + vals = np.pad(vals, (0, pad)) + groups = vals.reshape(-1, 3).astype(np.uint8) + packed = groups[:, 0] + 5 * groups[:, 1] + 25 * groups[:, 2] + return packed.tobytes(), n_vals + + +def unpack_base5_codes(data: bytes, n_vals: int) -> Tensor: + packed = np.frombuffer(data, dtype=np.uint8).astype(np.int16) + out = np.zeros((packed.size, 3), dtype=np.uint8) + for i in range(3): + out[:, i] = packed % 5 + packed //= 5 + flat = out.reshape(-1)[:n_vals] + return torch.from_numpy(flat.astype(np.uint8, copy=False)) + + +def quantize_float_tensor_int5_odd(t: Tensor) -> tuple[Tensor, Tensor]: + if t.ndim != 2: + raise ValueError("int5_odd quantization only supports 2D tensors") + t32 = t.float() + clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clip_abs = clip_abs.clamp_min(1e-6) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 2.0).clamp_min(1.0 / 2.0e6) + q = torch.clamp(torch.round(clipped / scale[:, None]), -2, 2).to(torch.int8).contiguous() + return (q + 2).to(torch.uint8).contiguous(), scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + +def quantize_state_dict_int8(state_dict: dict[str, Tensor]): + # Single supported clean-script export format: + # - per-row int8 for 2D float tensors + # - per-tensor int8 for other float tensors + # - exact passthrough for non-floats + # - passthrough for small float tensors, 
stored as fp16 to save bytes + quantized: dict[str, Tensor] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + qmeta: dict[str, dict[str, object]] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + + # Small float tensors are cheap enough to keep directly. We still downcast + # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size. + if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + + stats["num_float_tensors"] += 1 + q, s = quantize_float_tensor(t) + if s.ndim > 0: + qmeta[name] = {"scheme": "per_row", "axis": 0} + quantized[name] = q + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) + + obj: dict[str, object] = { + "__quant_format__": "int8_clean_per_row_v1", + "quantized": quantized, + "scales": scales, + "dtypes": dtypes, + "passthrough": passthrough, + } + if qmeta: + obj["qmeta"] = qmeta + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats + + +def quantize_state_dict_int5_odd(state_dict: dict[str, Tensor]): + packed: dict[str, bytes] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + shapes: dict[str, tuple[int, ...]] = {} + orig_shapes: dict[str, 
tuple[int, ...]] = {} + n_codes: dict[str, int] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + + quantize_this = t.ndim >= 2 and (t.numel() > INT8_KEEP_FLOAT_MAX_NUMEL or ".mlp." in name) + if not quantize_this: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + + stats["num_float_tensors"] += 1 + t2 = t.reshape(t.shape[0], -1) if t.ndim > 2 else t + q, s = quantize_float_tensor_int5_odd(t2) + packed_bytes, count = pack_base5_codes(q) + packed[name] = packed_bytes + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + shapes[name] = tuple(int(v) for v in t2.shape) + orig_shapes[name] = tuple(int(v) for v in t.shape) + n_codes[name] = count + stats["int8_payload_bytes"] += len(packed_bytes) + tensor_nbytes(s) + + obj: dict[str, object] = { + "__quant_format__": "int5_odd_clean_per_row_v1", + "packed": packed, + "scales": scales, + "dtypes": dtypes, + "shapes": shapes, + "orig_shapes": orig_shapes, + "n_codes": n_codes, + "passthrough": passthrough, + } + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats + +def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + qmeta = obj.get("qmeta", {}) + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, q in 
obj["quantized"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + s = obj["scales"][name] + if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: + s = s.to(dtype=torch.float32) + # Broadcast the saved row scale back across trailing dimensions. + out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() + else: + scale = float(s.item()) + out[name] = (q.float() * scale).to(dtype=dtype).contiguous() + for name, t in obj["passthrough"].items(): + # Restore small tensors, undoing the temporary fp16 storage cast if needed. + out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out + + +def dequantize_state_dict_int5_odd(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, packed in obj["packed"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + shape = tuple(int(v) for v in obj["shapes"][name]) + orig_shape = tuple(int(v) for v in obj.get("orig_shapes", {}).get(name, shape)) + q = unpack_base5_codes(packed, int(obj["n_codes"][name])).to(torch.float32).reshape(shape) - 2.0 + s = obj["scales"][name].to(dtype=torch.float32) + deq = (q * s.view(shape[0], *([1] * (len(shape) - 1)))).to(dtype=dtype) + out[name] = deq.reshape(orig_shape).contiguous() + for name, t in obj["passthrough"].items(): + out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out + + +def quantize_state_dict(state_dict: dict[str, Tensor], quant_scheme: str): + if quant_scheme == "int8": + return quantize_state_dict_int8(state_dict) + if quant_scheme == "int5_odd": + return 
quantize_state_dict_int5_odd(state_dict) + raise ValueError(f"Unsupported QUANT_SCHEME={quant_scheme}") + + +def dequantize_state_dict(obj: dict[str, object]) -> dict[str, Tensor]: + quant_format = obj.get("__quant_format__") + if quant_format == "int8_clean_per_row_v1": + return dequantize_state_dict_int8(obj) + if quant_format == "int5_odd_clean_per_row_v1": + return dequantize_state_dict_int5_odd(obj) + raise ValueError(f"Unsupported quant format {quant_format}") + + +def compress_blob(data: bytes, compressor: str, level: int) -> bytes: + if compressor == "zlib": + return zlib.compress(data, level=max(0, min(level, 9))) + if compressor == "zstd": + if not _HAS_ZSTD: + raise RuntimeError("COMPRESSOR=zstd requires zstandard to be installed") + return zstd.ZstdCompressor(level=level).compress(data) + raise ValueError(f"Unsupported COMPRESSOR={compressor}") + + +def decompress_blob(data: bytes, compressor: str) -> bytes: + if compressor == "zlib": + return zlib.decompress(data) + if compressor == "zstd": + if not _HAS_ZSTD: + raise RuntimeError("COMPRESSOR=zstd requires zstandard to be installed") + return zstd.ZstdDecompressor().decompress(data) + raise ValueError(f"Unsupported COMPRESSOR={compressor}") + + +def should_rate_regularize_tensor(name: str, p: Tensor, target: str) -> bool: + if not p.requires_grad or not p.is_floating_point() or p.ndim < 2: + return False + if any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS): + return False + if target == "mlp_only": + return ".mlp." in name + if target == "all_matrices": + return p.numel() > INT8_KEEP_FLOAT_MAX_NUMEL or ".mlp." 
in name
+    raise ValueError(f"Unsupported RATE_TARGET_TENSORS={target}")
+
+
+def estimate_tensor_rate_proxy(t: Tensor, temperature: float) -> tuple[Tensor, Tensor, Tensor]:
+    t2 = t.float()
+    if t2.ndim > 2:
+        t2 = t2.reshape(t2.shape[0], -1)
+    elif t2.ndim == 1:
+        t2 = t2.reshape(1, -1)
+    row_scale = t2.abs().mean(dim=1, keepdim=True).clamp_min(1e-6)
+    normalized = torch.clamp(t2 / row_scale, -2.5, 2.5)
+    centers = t2.new_tensor(INT5_ODD_LEVELS)
+    logits = -temperature * (normalized.unsqueeze(-1) - centers) ** 2
+    probs = torch.softmax(logits, dim=-1)
+    hist = probs.mean(dim=(0, 1))
+    entropy_nats = -(hist * (hist + 1e-8).log()).sum()
+    entropy_bits = entropy_nats / math.log(2.0)
+    scale_proxy = row_scale.mean()
+    est_bits = entropy_bits * float(t2.numel())
+    return entropy_bits, scale_proxy, est_bits
+
+
+# -----------------------------
+# DATA LOADING
+# -----------------------------
+
+def load_data_shard(file: Path) -> Tensor:
+    # ASSUMED RECONSTRUCTION: this body was garbled in the log capture. It is
+    # rebuilt here assuming the usual shard layout from the preprocessing
+    # pipeline: a 256-entry int32 header whose third field is the token count,
+    # followed by uint16 token ids.
+    header_bytes = 256 * np.dtype("<i4").itemsize
+    with file.open("rb") as f:
+        header = np.frombuffer(f.read(header_bytes), dtype="<i4")
+        num_tokens = int(header[2])
+        tokens = np.frombuffer(f.read(num_tokens * 2), dtype="<u2")
+    return torch.from_numpy(tokens.astype(np.int32, copy=False))
+
+
+class TokenStream:
+    # Walks the train shards sequentially, handing out contiguous token spans.
+    def __init__(self, pattern: str):
+        # ASSUMED RECONSTRUCTION: the original __init__ was also lost in the
+        # capture; the attributes below are inferred from _advance_file/take.
+        self.files = [Path(p) for p in sorted(glob.glob(pattern))]
+        if not self.files:
+            raise FileNotFoundError(f"No files found for pattern: {pattern}")
+        self.file_idx = 0
+        self.tokens = load_data_shard(self.files[self.file_idx])
+        self.pos = 0
+
+    def _advance_file(self) -> None:
+        self.file_idx = (self.file_idx + 1) % len(self.files)
+        self.tokens = load_data_shard(self.files[self.file_idx])
+        self.pos = 0
+
+    def take(self, n: int) -> Tensor:
+        chunks: list[Tensor] = []
+        remaining = n
+        while remaining > 0:
+            avail = self.tokens.numel() - self.pos
+            if avail <= 0:
+                self._advance_file()
+                continue
+            k = min(remaining, avail)
+            chunks.append(self.tokens[self.pos : self.pos + k])
+            self.pos += k
+            remaining -= k
+        return chunks[0] if len(chunks) == 1 else torch.cat(chunks)
+
+
+class DistributedTokenLoader:
+    # Each call consumes a contiguous chunk from the shared token stream, then slices out
+    # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting.
+ def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank = rank + self.world_size = world_size + self.device = device + self.stream = TokenStream(pattern) + + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + chunk = self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start : start + per_rank_span].to(dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + +# ----------------------------- +# TRANSFORMER MODULES +# ----------------------------- + +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + + +class CastedLinear(nn.Linear): + # Keep weights in fp32 for optimizer/state quality, cast at matmul time for bf16 compute. + def forward(self, x: Tensor) -> Tensor: + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, self.weight.to(x.dtype), bias) + + +class StructuredLinear(nn.Module): + # Lightweight TT-style structured matrix used only in the MLP stack. 
+ def __init__( + self, + in_features: int, + out_features: int, + bias: bool, + linear_type: str, + btt_rank: int, + in_shape_spec: str, + out_shape_spec: str, + btt_init: str, + btt_init_gain: float, + ): + super().__init__() + self.in_features = in_features + self.out_features = out_features + self.linear_type = linear_type + self.rank = max(int(btt_rank), 1) + self.btt_init = btt_init + self.btt_init_gain = btt_init_gain + if linear_type == "dense": + self.impl = CastedLinear(in_features, out_features, bias=bias) + self.register_parameter("core1", None) + self.register_parameter("core2", None) + return + if linear_type != "btt_mlp": + raise ValueError(f"Unsupported STRUCTURED_LINEAR_TYPE={linear_type}") + self.impl = None + self.in_shape = parse_factor_shape(in_shape_spec, in_features) + self.out_shape = parse_factor_shape(out_shape_spec, out_features) + n1, n2 = self.in_shape + m1, m2 = self.out_shape + self.core1 = nn.Parameter(torch.empty(n1, m1, self.rank)) + self.core2 = nn.Parameter(torch.empty(self.rank, n2, m2)) + if bias: + self.bias = nn.Parameter(torch.zeros(out_features)) + else: + self.register_parameter("bias", None) + self.reset_parameters() + + def reset_parameters(self) -> None: + if self.impl is not None: + self.impl.reset_parameters() + return + n1, n2 = self.in_shape + if self.btt_init == "mup": + # muP-inspired scaling: keep the per-rank branch update stable while + # preserving O(1) output variance under the sum over TT ranks. 
+ core1_std = self.btt_init_gain / math.sqrt(max(n1, 1)) + core2_std = self.btt_init_gain / math.sqrt(max(self.rank * n2, 1)) + elif self.btt_init == "xavier": + std = self.btt_init_gain / math.sqrt(max(self.in_features, 1)) + core1_std = std / math.sqrt(self.rank) + core2_std = std + else: + raise ValueError(f"Unsupported BTT_INIT={self.btt_init}") + nn.init.normal_(self.core1, mean=0.0, std=core1_std) + nn.init.normal_(self.core2, mean=0.0, std=core2_std) + if self.bias is not None: + nn.init.zeros_(self.bias) + + def forward(self, x: Tensor) -> Tensor: + if self.impl is not None: + return self.impl(x) + x_dtype = x.dtype + orig_shape = x.shape[:-1] + x2 = x.reshape(-1, self.in_shape[0], self.in_shape[1]) + core1 = self.core1.to(dtype=x_dtype) + core2 = self.core2.to(dtype=x_dtype) + y = x2.new_zeros((x2.shape[0], self.out_shape[0], self.out_shape[1])) + for r in range(self.rank): + left = torch.einsum("buv,um->bmv", x2, core1[:, :, r]) + y = y + torch.einsum("bmv,vn->bmn", left, core2[r]) + y = y.reshape(*orig_shape, self.out_features) + if self.bias is not None: + y = y + self.bias.to(dtype=y.dtype) + return y + + +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + # Keep small/control parameters in fp32 even when the model body runs in bf16. + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() + + +class Rotary(nn.Module): + # Caches cos/sin tables per sequence length on the current device. 
+ def __init__(self, dim: int, base: float = 10000.0): + super().__init__() + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached != seq_len + or self._cos_cached.device != device + ): + t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) + freqs = torch.outer(t, self.inv_freq.to(device)) + self._cos_cached = freqs.cos()[None, None, :, :] + self._sin_cached = freqs.sin()[None, None, :, :] + self._seq_len_cached = seq_len + cos = self._cos_cached.to(dtype=dtype) + sin = self._sin_cached.to(dtype=dtype) + if not torch.is_inference_mode_enabled(): + cos = cos.clone() + sin = sin.clone() + return cos, sin + + +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + + +class CausalSelfAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + kv_dim = self.num_kv_heads * self.head_dim + self.c_q = CastedLinear(dim, dim, bias=False) + self.c_k = CastedLinear(dim, kv_dim, bias=False) + self.c_v = CastedLinear(dim, kv_dim, bias=False) + self.proj 
= CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rotary = Rotary(self.head_dim, base=rope_base) + + def forward(self, x: Tensor) -> Tensor: + bsz, seqlen, dim = x.shape + q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) + k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin) + k = apply_rotary_emb(k, cos, sin) + q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] + y = F.scaled_dot_product_attention( + q, + k, + v, + attn_mask=None, + is_causal=True, + enable_gqa=(self.num_kv_heads != self.num_heads), + ) + y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) + return self.proj(y) + + +class MLP(nn.Module): + # relu^2 MLP from the original modded-nanogpt setup + def __init__( + self, + dim: int, + mlp_mult: int, + structured_mlp: bool, + structured_linear_type: str, + btt_rank: int, + btt_in_shape: str, + btt_out_shape: str, + btt_init: str, + btt_init_gain: float, + compile_structured_mlp: bool, + ): + super().__init__() + hidden = mlp_mult * dim + if structured_mlp: + self.fc = StructuredLinear( + dim, + hidden, + bias=False, + linear_type=structured_linear_type, + btt_rank=btt_rank, + in_shape_spec=btt_in_shape, + out_shape_spec=btt_out_shape, + btt_init=btt_init, + btt_init_gain=btt_init_gain, + ) + self.proj = StructuredLinear( + hidden, + dim, + bias=False, + linear_type=structured_linear_type, + btt_rank=btt_rank, + in_shape_spec=btt_out_shape, + out_shape_spec=btt_in_shape, + btt_init=btt_init, + btt_init_gain=btt_init_gain, + ) + if compile_structured_mlp: + compile_options = 
dict(torch._inductor.list_mode_options("reduce-overhead")) + compile_options["triton.cudagraphs"] = False + self.fc = torch.compile( + self.fc, + options=compile_options, + fullgraph=False, + dynamic=False, + ) + self.proj = torch.compile( + self.proj, + options=compile_options, + fullgraph=False, + dynamic=False, + ) + else: + self.fc = CastedLinear(dim, hidden, bias=False) + self.proj = CastedLinear(hidden, dim, bias=False) + self.proj._zero_init = True + + def forward(self, x: Tensor) -> Tensor: + x = torch.relu(self.fc(x)) + return self.proj(x.square()) + + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + rope_base: float, + qk_gain_init: float, + structured_mlp: bool, + structured_linear_type: str, + btt_rank: int, + btt_in_shape: str, + btt_out_shape: str, + btt_init: str, + btt_init_gain: float, + compile_structured_mlp: bool, + ): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) + self.mlp = MLP( + dim, + mlp_mult, + structured_mlp, + structured_linear_type, + btt_rank, + btt_in_shape, + btt_out_shape, + btt_init, + btt_init_gain, + compile_structured_mlp, + ) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + + def forward(self, x: Tensor, x0: Tensor) -> Tensor: + mix = self.resid_mix.to(dtype=x.dtype) + x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + attn_out = self.attn(self.attn_norm(x)) + x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out + x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) + return x + + +class GPT(nn.Module): + def __init__( + self, + vocab_size: int, + num_layers: int, + model_dim: int, + num_heads: int, + 
num_kv_heads: int, + mlp_mult: int, + tie_embeddings: bool, + tied_embed_init_std: float, + logit_softcap: float, + rope_base: float, + qk_gain_init: float, + structured_mlp: bool, + structured_linear_type: str, + btt_rank: int, + btt_in_shape: str, + btt_out_shape: str, + btt_init: str, + btt_init_gain: float, + compile_structured_mlp: bool, + rate_target_tensors: str, + rate_temperature: float, + ): + super().__init__() + if logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") + self.tie_embeddings = tie_embeddings + self.tied_embed_init_std = tied_embed_init_std + self.logit_softcap = logit_softcap + self.rate_target_tensors = rate_target_tensors + self.rate_temperature = rate_temperature + self.tok_emb = nn.Embedding(vocab_size, model_dim) + self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + self.blocks = nn.ModuleList( + [ + Block( + model_dim, + num_heads, + num_kv_heads, + mlp_mult, + rope_base, + qk_gain_init, + structured_mlp, + structured_linear_type, + btt_rank, + btt_in_shape, + btt_out_shape, + btt_init, + btt_init_gain, + compile_structured_mlp, + ) + for i in range(num_layers) + ] + ) + self.final_norm = RMSNorm() + self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) + if self.lm_head is not None: + self.lm_head._zero_init = True + self._init_weights() + + def _init_weights(self) -> None: + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + for module in self.modules(): + if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + if isinstance(module, StructuredLinear) and getattr(module, "_zero_init", False): + if module.impl is 
not None: + nn.init.zeros_(module.impl.weight) + else: + nn.init.zeros_(module.core2) + + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + x = self.tok_emb(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x0 = x + skips: list[Tensor] = [] + + # First half stores skips; second half reuses them in reverse order. + for i in range(self.num_encoder_layers): + x = self.blocks[i](x, x0) + skips.append(x) + for i in range(self.num_decoder_layers): + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + x = self.blocks[self.num_encoder_layers + i](x, x0) + + x = self.final_norm(x).reshape(-1, x.size(-1)) + targets = target_ids.reshape(-1) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(x) + logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + return F.cross_entropy(logits.float(), targets, reduction="mean") + + def rate_regularizer(self) -> tuple[Tensor, Tensor, Tensor]: + total_symbols = 0.0 + weighted_rate = None + weighted_scale = None + estimated_bits = None + for name, param in self.named_parameters(): + if not should_rate_regularize_tensor(name, param, self.rate_target_tensors): + continue + rate_proxy, scale_proxy, tensor_bits = estimate_tensor_rate_proxy(param, self.rate_temperature) + weight = float(param.numel()) + total_symbols += weight + if weighted_rate is None: + weighted_rate = rate_proxy * weight + weighted_scale = scale_proxy * weight + estimated_bits = tensor_bits + else: + weighted_rate = weighted_rate + rate_proxy * weight + weighted_scale = weighted_scale + scale_proxy * weight + estimated_bits = estimated_bits + tensor_bits + if weighted_rate is None or weighted_scale is None or estimated_bits is None: + zero = self.tok_emb.weight.new_zeros(()) + return zero, zero, zero + denom = max(total_symbols, 1.0) + return 
weighted_rate / denom, weighted_scale / denom, estimated_bits + + +# ----------------------------- +# TRAINING +# ----------------------------- + +def main() -> None: + global zeropower_via_newtonschulz5 + + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + if args.enable_compile: + zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) + + # ----------------------------- + # DISTRIBUTED + CUDA SETUP + # ----------------------------- + + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if args.train_batch_tokens <= 0: + raise ValueError(f"TRAIN_BATCH_TOKENS must be positive, got {args.train_batch_tokens}") + train_microbatch_tokens = args.train_microbatch_tokens if args.train_microbatch_tokens > 0 else max( + args.train_seq_len, args.train_batch_tokens // 8 + ) + if train_microbatch_tokens <= 0: + raise ValueError(f"TRAIN_MICROBATCH_TOKENS must be positive, got {train_microbatch_tokens}") + if train_microbatch_tokens % args.train_seq_len != 0: + raise ValueError( + f"TRAIN_MICROBATCH_TOKENS={train_microbatch_tokens} must be divisible by TRAIN_SEQ_LEN={args.train_seq_len}" + ) + tokens_per_microstep = world_size * train_microbatch_tokens + if args.grad_accum_steps > 0: + grad_accum_steps = args.grad_accum_steps + effective_train_batch_tokens = tokens_per_microstep * grad_accum_steps + else: + if args.train_batch_tokens % tokens_per_microstep != 0: + raise ValueError( + "TRAIN_BATCH_TOKENS must be divisible by WORLD_SIZE * TRAIN_MICROBATCH_TOKENS " + f"unless GRAD_ACCUM_STEPS is set explicitly; got TRAIN_BATCH_TOKENS={args.train_batch_tokens}, " + f"WORLD_SIZE={world_size}, TRAIN_MICROBATCH_TOKENS={train_microbatch_tokens}" + ) + grad_accum_steps = 
args.train_batch_tokens // tokens_per_microstep + effective_train_batch_tokens = args.train_batch_tokens + if grad_accum_steps <= 0: + raise ValueError(f"GRAD_ACCUM_STEPS must be positive, got {grad_accum_steps}") + grad_scale = 1.0 / grad_accum_steps + batch_multiplier = effective_train_batch_tokens / max(args.lr_reference_batch_tokens, 1) + effective_warmup_steps = scaled_warmup_steps( + args.warmup_steps, + effective_train_batch_tokens / max(args.warmup_reference_batch_tokens, 1), + args.warmup_scale_method, + ) + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 0 + + # Fast math knobs + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + + logfile = None + if master_process: + if args.run_id == "train": + logfile = "train.log" + else: + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" + print(logfile) + + def log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch {torch.__version__}", console=False) + log0( + subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, + console=False, + ) + log0("=" * 100, console=False) + + # ----------------------------- + # TOKENIZER + VALIDATION METRIC SETUP + # 
----------------------------- + + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) + if args.val_token_limit > 0: + usable = min(((args.val_token_limit) // args.train_seq_len) * args.train_seq_len, val_tokens.numel() - 1) + if usable <= 0: + raise ValueError(f"VAL_TOKEN_LIMIT={args.val_token_limit} is too small for TRAIN_SEQ_LEN={args.train_seq_len}") + val_tokens = val_tokens[: usable + 1].contiguous() + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + if args.val_token_limit > 0: + log0(f"val_token_limit:{args.val_token_limit}") + + # ----------------------------- + # MODEL + OPTIMIZER SETUP + # ----------------------------- + + base_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, 
+ structured_mlp=args.structured_mlp, + structured_linear_type=args.structured_linear_type if args.structured_mlp else "dense", + btt_rank=args.btt_rank, + btt_in_shape=args.btt_in_shape, + btt_out_shape=args.btt_out_shape, + btt_init=args.btt_init, + btt_init_gain=args.btt_init_gain, + compile_structured_mlp=args.compile_structured_mlp, + rate_target_tensors=args.rate_target_tensors, + rate_temperature=args.rate_temperature, + ).to(device).bfloat16() + for module in base_model.modules(): + if isinstance(module, (CastedLinear, StructuredLinear)): + module.float() + restore_low_dim_params_to_fp32(base_model) + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) if args.enable_compile else base_model + model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + + # Optimizer split: + # - token embedding (Adam) uses EMBED_LR + # - untied lm_head (Adam) uses HEAD_LR + # - matrix params in transformer blocks use MATRIX_LR via Muon + # - vectors/scalars use SCALAR_LR via Adam + block_named_params = list(base_model.blocks.named_parameters()) + matrix_params = [ + p + for name, p in block_named_params + if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + scalar_params = [ + p + for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + scaled_token_lr = scale_from_batch_multiplier(token_lr, batch_multiplier, args.lr_scale_method) + scaled_head_lr = scale_from_batch_multiplier(args.head_lr, batch_multiplier, args.lr_scale_method) + scaled_matrix_lr = scale_from_batch_multiplier(args.matrix_lr, batch_multiplier, args.lr_scale_method) + scaled_scalar_lr = scale_from_batch_multiplier(args.scalar_lr, batch_multiplier, 
args.lr_scale_method) + optimizer_tok = torch.optim.Adam( + [{"params": [base_model.tok_emb.weight], "lr": scaled_token_lr, "base_lr": scaled_token_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizer_muon = Muon( + matrix_params, + lr=scaled_matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + ) + for group in optimizer_muon.param_groups: + group["base_lr"] = scaled_matrix_lr + optimizer_scalar = torch.optim.Adam( + [{"params": scalar_params, "lr": scaled_scalar_lr, "base_lr": scaled_scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": scaled_head_lr, "base_lr": scaled_head_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers.insert(1, optimizer_head) + + n_params = sum(p.numel() for p in base_model.parameters()) + log0(f"model_params:{n_params}") + log0( + f"world_size:{world_size} grad_accum_steps:{grad_accum_steps} " + f"train_microbatch_tokens:{train_microbatch_tokens} effective_train_batch_tokens:{effective_train_batch_tokens}" + ) + log0(f"enable_compile:{args.enable_compile}") + log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") + log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") + log0( + f"quant_scheme:{args.quant_scheme} compressor:{args.compressor} level:{args.compressor_level} " + f"structured_mlp:{args.structured_mlp} structured_linear_type:{args.structured_linear_type} " + f"btt_rank:{args.btt_rank} btt_init:{args.btt_init} compile_structured_mlp:{args.compile_structured_mlp}" + ) + log0( + f"rate_lambda:{args.rate_lambda} scale_lambda:{args.scale_lambda} " + f"rate_warmup_start_fraction:{args.rate_warmup_start_fraction} 
rate_full_fraction:{args.rate_full_fraction} " + f"rate_target_tensors:{args.rate_target_tensors}" + ) + log0( + f"tie_embeddings:{args.tie_embeddings} embed_lr:{scaled_token_lr} " + f"head_lr:{scaled_head_lr if base_model.lm_head is not None else 0.0} " + f"matrix_lr:{scaled_matrix_lr} scalar_lr:{scaled_scalar_lr}" + ) + log0( + f"train_batch_tokens_target:{args.train_batch_tokens} effective_train_batch_tokens:{effective_train_batch_tokens} " + f"train_seq_len:{args.train_seq_len} iterations:{args.iterations} warmup_steps:{effective_warmup_steps} " + f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log0( + f"lr_scale_method:{args.lr_scale_method} lr_reference_batch_tokens:{args.lr_reference_batch_tokens} " + f"batch_multiplier:{batch_multiplier:.4f} warmup_scale_method:{args.warmup_scale_method} " + f"warmup_reference_batch_tokens:{args.warmup_reference_batch_tokens}" + ) + log0(f"seed:{args.seed}") + + # ----------------------------- + # DATA LOADER & MODEL WARMUP + # ----------------------------- + + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + + def lr_mul(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + + def rate_schedule(step: int) -> float: + if args.rate_lambda <= 0.0 and args.scale_lambda <= 0.0: + return 0.0 + if args.iterations <= 
0: + return 1.0 + frac = step / max(args.iterations, 1) + start = args.rate_warmup_start_fraction + full = max(args.rate_full_fraction, start + 1e-6) + if frac <= start: + return 0.0 + if frac >= full: + return 1.0 + return (frac - start) / (full - start) + + # Warmup primes the compiled forward/backward/optimizer paths, then we restore the + # initial weights/optimizer state so measured training starts from the true init. + if effective_warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(effective_warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(effective_train_batch_tokens, args.train_seq_len, grad_accum_steps) + maybe_mark_compile_step_begin(args.enable_compile or args.compile_structured_mlp) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + for opt in optimizers: + opt.step() + zero_grad_all() + if ( + effective_warmup_steps <= 20 + or (warmup_step + 1) % 10 == 0 + or warmup_step + 1 == effective_warmup_steps + ): + log0(f"warmup_step:{warmup_step + 1}/{effective_warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + if distributed: + model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + # ----------------------------- + # MAIN TRAINING LOOP + # ----------------------------- + + training_time_ms = 0.0 + stop_after_step: int | None = None + torch.cuda.synchronize() + t0 = time.perf_counter() + 
+ step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + log0( + f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" + ) + break + + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + zero_grad_all() + train_loss = torch.zeros((), device=device) + train_ce_loss = torch.zeros((), device=device) + train_rate_proxy = torch.zeros((), device=device) + train_scale_proxy = torch.zeros((), device=device) + train_est_bits = torch.zeros((), device=device) + rate_mul = rate_schedule(step) + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(effective_train_batch_tokens, args.train_seq_len, grad_accum_steps) + maybe_mark_compile_step_begin(args.enable_compile or args.compile_structured_mlp) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + ce_loss = model(x, y) + total_loss = ce_loss + if rate_mul > 0.0: + rate_proxy, scale_proxy, est_bits = base_model.rate_regularizer() + total_loss = total_loss + rate_mul * (args.rate_lambda * rate_proxy + args.scale_lambda * 
scale_proxy) + train_rate_proxy += rate_proxy.detach() + train_scale_proxy += scale_proxy.detach() + train_est_bits += est_bits.detach() + train_ce_loss += ce_loss.detach() + train_loss += total_loss.detach() + (total_loss * grad_scale).backward() + train_loss /= grad_accum_steps + train_ce_loss /= grad_accum_steps + train_rate_proxy /= grad_accum_steps + train_scale_proxy /= grad_accum_steps + train_est_bits /= grad_accum_steps + + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + for opt in optimizers: + opt.step() + zero_grad_all() + + step += 1 + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + should_log_train = ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + log0( + f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} ce_loss:{train_ce_loss.item():.4f} " + f"rate_mul:{rate_mul:.3f} rate_proxy_bits:{train_rate_proxy.item():.4f} " + f"scale_proxy:{train_scale_proxy.item():.6f} est_model_bits:{train_est_bits.item():.0f} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" + ) + + # Needed to sync whether we've reached the wallclock cap. 
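The MAX-reduction applied next exists because per-rank wallclock readings drift, so ranks can disagree about whether the cap was hit on the same step; if one rank stopped while another kept stepping, the next gradient all-reduce would hang. The agreement logic can be simulated without `torch.distributed` (hypothetical timings):

```python
# Hypothetical per-rank elapsed wallclock (ms): clocks and step times drift,
# so ranks can disagree about whether the cap was reached on a given step.
per_rank_elapsed = [598_900.0, 599_400.0, 600_300.0, 601_100.0]
cap_ms = 600_000.0

local_reached = [t >= cap_ms for t in per_rank_elapsed]
# Pure-Python stand-in for dist.all_reduce(..., op=dist.ReduceOp.MAX): if any
# rank is over budget, every rank stops, so no rank is left waiting in a
# collective its peers will never enter.
agreed = max(local_reached)

assert local_reached == [False, False, True, True]
assert agreed is True
```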
+ reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + + log0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + ) + + # ----------------------------- + # SERIALIZATION + ROUNDTRIP VALIDATION + # ----------------------------- + # Save the raw state (useful for debugging/loading in PyTorch directly), then always produce + # the compressed int8+zlib artifact and validate the round-tripped weights. + + if master_process: + torch.save(base_model.state_dict(), "final_model.pt") + model_bytes = os.path.getsize("final_model.pt") + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model: {model_bytes} bytes") + log0(f"Code size: {code_bytes} bytes") + log0(f"Total submission size: {model_bytes + code_bytes} bytes") + + quant_obj, quant_stats = quantize_state_dict(base_model.state_dict(), args.quant_scheme) + quant_buf = io.BytesIO() + torch.save(quant_obj, quant_buf) + quant_raw = quant_buf.getvalue() + quant_blob = compress_blob(quant_raw, args.compressor, args.compressor_level) + quant_raw_bytes = len(quant_raw) + quant_filename = f"final_model.{args.quant_scheme}.{args.compressor}.ptc" + if master_process: + with open(quant_filename, "wb") as f: + f.write(quant_blob) + quant_file_bytes = os.path.getsize(quant_filename) + code_bytes = len(code.encode("utf-8")) + ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1) + log0( + f"Serialized model {args.quant_scheme}+{args.compressor}: {quant_file_bytes} bytes " + f"(payload:{quant_stats['int8_payload_bytes']} raw_torch:{quant_raw_bytes} 
payload_ratio:{ratio:.2f}x)" + ) + log0(f"Total submission size {args.quant_scheme}+{args.compressor}: {quant_file_bytes + code_bytes} bytes") + + if distributed: + dist.barrier() + with open(quant_filename, "rb") as f: + quant_blob_disk = f.read() + quant_state = torch.load(io.BytesIO(decompress_blob(quant_blob_disk, args.compressor)), map_location="cpu") + base_model.load_state_dict(dequantize_state_dict(quant_state), strict=True) + torch.cuda.synchronize() + t_qeval = time.perf_counter() + q_val_loss, q_val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"final_{args.quant_scheme}_{args.compressor}_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" + ) + log0(f"final_{args.quant_scheme}_{args.compressor}_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + + if distributed: + dist.destroy_process_group() + + +if __name__ == "__main__": + main() + +==================================================================================================== +Running Python 3.10.12 (main, Mar 3 2026, 11:56:32) [GCC 11.4.0] +Running PyTorch 2.11.0+cu130 +Tue Mar 31 13:06:00 2026 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 590.48.01 Driver Version: 590.48.01 CUDA Version: 13.1 | ++-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. 
| +|=========================================+========================+======================| +| 0 NVIDIA GeForce RTX 3060 Off | 00000000:01:00.0 On | N/A | +| 0% 43C P0 19W / 170W | 242MiB / 12288MiB | 2% Default | +| | | N/A | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| 0 N/A N/A 985 G /usr/lib/xorg/Xorg 117MiB | +| 0 N/A N/A 5616 C python3 104MiB | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=/home/hz/Desktop/parameter-golf-main/data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:1 +val_loader:shards pattern=/home/hz/Desktop/parameter-golf-main/data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:65536 +val_token_limit:65536 +model_params:8507464 +world_size:1 grad_accum_steps:8 train_microbatch_tokens:2048 effective_train_batch_tokens:16384 +enable_compile:False +sdp_backends:cudnn=False flash=True mem_efficient=False math=False +attention_mode:gqa num_heads:8 num_kv_heads:4 +quant_scheme:int5_odd compressor:zlib level:9 structured_mlp:True structured_linear_type:btt_mlp btt_rank:32 btt_init:mup compile_structured_mlp:True +rate_lambda:2e-05 scale_lambda:0.0002 rate_warmup_start_fraction:0.1 rate_full_fraction:0.4 rate_target_tensors:mlp_only +tie_embeddings:True embed_lr:0.03 head_lr:0.0 matrix_lr:0.02 scalar_lr:0.02 +train_batch_tokens_target:16384 effective_train_batch_tokens:16384 train_seq_len:256 iterations:1 warmup_steps:1 max_wallclock_seconds:600.000 +lr_scale_method:none 
lr_reference_batch_tokens:16384 batch_multiplier:1.0000 warmup_scale_method:none warmup_reference_batch_tokens:16384
+seed:1337
+warmup_step:1/1
+step:1/1 train_loss:6.9388 ce_loss:6.9388 rate_mul:0.000 rate_proxy_bits:0.0000 scale_proxy:0.000000 est_model_bits:0 train_time:1588ms step_avg:1588.47ms
+step:1/1 val_loss:12.9889 val_bpb:7.6798 train_time:1590ms step_avg:1589.87ms
+peak memory allocated: 2356 MiB reserved: 2516 MiB
+Serialized model: 33022259 bytes
+Code size: 72970 bytes
+Total submission size: 33095229 bytes
+Serialized model int5_odd+zlib: 2134652 bytes (payload:2942927 raw_torch:2983011 payload_ratio:11.21x)
+Total submission size int5_odd+zlib: 2207622 bytes
+final_int5_odd_zlib_roundtrip val_loss:12.4526 val_bpb:7.3628 eval_time:1865ms
+final_int5_odd_zlib_roundtrip_exact val_loss:12.45264053 val_bpb:7.36275413
diff --git a/records/track_non_record_16mb/2026-03-31_EntropyAware_Int5Odd_BTTMLP_3060/logs/cloud_prep_torchrun.txt b/records/track_non_record_16mb/2026-03-31_EntropyAware_Int5Odd_BTTMLP_3060/logs/cloud_prep_torchrun.txt
new file mode 100644
index 000000000..bdf9fe3c0
--- /dev/null
+++ b/records/track_non_record_16mb/2026-03-31_EntropyAware_Int5Odd_BTTMLP_3060/logs/cloud_prep_torchrun.txt
@@ -0,0 +1,1761 @@
+"""
+The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good launching-off points for new participants, not SOTA configs. We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder.
+
+Hard stop: to keep these scripts readable for newcomers, let's make sure `train_gpt.py` and `train_gpt_mlx.py` are never longer than 1500 lines.
+""" + +from __future__ import annotations + +import copy +import glob +import io +import math +import os +import random +import subprocess +import sys +import time +import uuid +import zlib +from pathlib import Path + +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP + +try: + import zstandard as zstd + + _HAS_ZSTD = True +except ImportError: + zstd = None + _HAS_ZSTD = False + +REPO_ROOT = Path(__file__).resolve().parents[3] + +# ----------------------------- +# HYPERPARAMETERS +# ----------------------------- +# Default Simple Baseline run: +# - 9 transformer blocks at width 512 +# - 8 attention heads with 4 KV heads (GQA) and 2x MLP expansion +# - vocab size 1024, sequence length 1024, tied embeddings +# - 524,288 train tokens per step for 20,000 iterations with a ~10 minute cap + +class Hyperparameters: + # Data paths are shard globs produced by the existing preprocessing pipeline. + data_path = os.environ.get("DATA_PATH", str(REPO_ROOT / "data/datasets/fineweb10B_sp1024")) + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", str(REPO_ROOT / "data/tokenizers/fineweb_1024_bpe.model")) + run_id = os.environ.get("RUN_ID", "train") + seed = int(os.environ.get("SEED", 1337)) + + # Validation cadence and batch size. Validation always uses the full fineweb_val split. + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 262_144)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 10)) + val_token_limit = int(os.environ.get("VAL_TOKEN_LIMIT", 1_048_576)) + + # Training length. 
+ iterations = int(os.environ.get("ITERATIONS", 40)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 1200)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 1)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 16_384)) + train_microbatch_tokens = int(os.environ.get("TRAIN_MICROBATCH_TOKENS", 0)) + grad_accum_steps = int(os.environ.get("GRAD_ACCUM_STEPS", 0)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 256)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + enable_compile = bool(int(os.environ.get("ENABLE_COMPILE", "0"))) + + # Model shape. + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers = int(os.environ.get("NUM_LAYERS", 9)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = int(os.environ.get("MLP_MULT", 2)) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + + # Optimizer hyperparameters. 
+ embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.03)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.02)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.02)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.0)) + lr_scale_method = os.environ.get("LR_SCALE_METHOD", "none") + lr_reference_batch_tokens = int(os.environ.get("LR_REFERENCE_BATCH_TOKENS", 16_384)) + warmup_scale_method = os.environ.get("WARMUP_SCALE_METHOD", "none") + warmup_reference_batch_tokens = int(os.environ.get("WARMUP_REFERENCE_BATCH_TOKENS", 16_384)) + + # Export / compression. + quant_scheme = os.environ.get("QUANT_SCHEME", "int5_odd") + compressor = os.environ.get("COMPRESSOR", "zlib") + compressor_level = int(os.environ.get("COMPRESSOR_LEVEL", 9)) + + # Entropy-aware rate regularization. + rate_lambda = float(os.environ.get("RATE_LAMBDA", 0.00002)) + scale_lambda = float(os.environ.get("SCALE_LAMBDA", 0.0002)) + rate_temperature = float(os.environ.get("RATE_TEMPERATURE", 6.0)) + rate_warmup_start_fraction = float(os.environ.get("RATE_WARMUP_START_FRACTION", 0.1)) + rate_full_fraction = float(os.environ.get("RATE_FULL_FRACTION", 0.4)) + rate_target_tensors = os.environ.get("RATE_TARGET_TENSORS", "mlp_only") + + # Structured MLP. 
+ structured_mlp = bool(int(os.environ.get("STRUCTURED_MLP", "1"))) + structured_linear_type = os.environ.get("STRUCTURED_LINEAR_TYPE", "btt_mlp") + btt_rank = int(os.environ.get("BTT_RANK", 32)) + btt_in_shape = os.environ.get("BTT_IN_SHAPE", "") + btt_out_shape = os.environ.get("BTT_OUT_SHAPE", "") + btt_init = os.environ.get("BTT_INIT", "mup") + btt_init_gain = float(os.environ.get("BTT_INIT_GAIN", 1.0)) + compile_structured_mlp = bool(int(os.environ.get("COMPILE_STRUCTURED_MLP", "1"))) + +# ----------------------------- +# MUON OPTIMIZER +# ----------------------------- +# +# As borrowed from modded-nanogpt +# Background on Muon: https://kellerjordan.github.io/posts/muon/ + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + # Orthogonalize a 2D update matrix with a fast Newton-Schulz iteration. + # Muon uses this to normalize matrix-shaped gradients before applying them. + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X + + +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True): + super().__init__( + params, + dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov), + ) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + + for group in self.param_groups: + params = group["params"] + if not params: + continue + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + + total_params = 
sum(int(p.numel()) for p in params)
+            updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16)
+
+            curr = 0
+            for i, p in enumerate(params):
+                if i % world_size == rank and p.grad is not None:
+                    g = p.grad
+                    state = self.state[p]
+                    if "momentum_buffer" not in state:
+                        state["momentum_buffer"] = torch.zeros_like(g)
+                    buf = state["momentum_buffer"]
+                    buf.mul_(momentum).add_(g)
+                    if nesterov:
+                        g = g.add(buf, alpha=momentum)
+                    g = zeropower_via_newtonschulz5(g, steps=backend_steps)
+                    # Scale correction from Muon reference implementations.
+                    g *= max(1, g.size(0) / g.size(1)) ** 0.5
+                    updates_flat[curr : curr + p.numel()] = g.reshape(-1)
+                curr += p.numel()
+
+            if distributed:
+                dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM)
+
+            curr = 0
+            for p in params:
+                g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype)
+                p.add_(g, alpha=-lr)
+                curr += p.numel()
+
+        return loss
+
+
+# -----------------------------
+# TOKENIZER-AGNOSTIC EVALUATION SETUP
+# -----------------------------
+#
+# It's common for small models to have a large fraction of their parameters in the embeddings, since the 2 * d_vocab * d_model embedding parameters can be gigantic.
+# Instead of locking the tokenizer, we let you bring your own and calculate our validation metrics on the average compression of the validation set.
+# We calculate BPB (bits-per-byte) instead of validation loss, so we need methods to count the number of bytes each token contributes under the tokenizer.
+# Note: Submissions that edit the tokenizer will be examined more carefully, since screwing this up might unjustly improve your score.
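As a quick illustration of the BPB bookkeeping described above (a hedged sketch with made-up numbers, not values from this run): bits-per-byte is just bits-per-token rescaled by the tokens-to-bytes ratio recovered from the tokenizer lookup tables.

```python
import math

# Illustrative numbers only: a mean token cross-entropy in nats, plus the
# token and byte totals that eval_val-style accounting would accumulate.
val_loss_nats = 1.40
total_tokens = 1_000_000
total_bytes = 4_300_000  # UTF-8 bytes the target tokens decode to

bits_per_token = val_loss_nats / math.log(2.0)          # ~2.02 bits/token
val_bpb = bits_per_token * (total_tokens / total_bytes)  # ~0.47 bits/byte
```

Because the byte count comes from the validation text rather than the tokenizer's vocabulary, a tokenizer with longer tokens lowers tokens-per-byte and BPB stays comparable across tokenizers.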
+ +def build_sentencepiece_luts( + sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device +) -> tuple[Tensor, Tensor, Tensor]: + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("▁"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) + + +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*. 
+ tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] + + +def eval_val( + args: Hyperparameters, + model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + grad_accum_steps: int, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, +) -> tuple[float, float]: + # Validation computes two metrics: + # - val_loss: token cross-entropy (natural log) + # - val_bpb: tokenizer-agnostic compression metric used by the challenge + local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) + if local_batch_tokens < args.train_seq_len: + raise ValueError( + "VAL_BATCH_SIZE must provide at least one sequence per rank; " + f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " + f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}" + ) + local_batch_seqs = local_batch_tokens // args.train_seq_len + total_seqs = (val_tokens.numel() - 1) // args.train_seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + + model.eval() + with torch.inference_mode(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * args.train_seq_len + raw_end = batch_seq_end * args.train_seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, args.train_seq_len) + y = local[1:].reshape(-1, args.train_seq_len) + 
maybe_mark_compile_step_begin(args.enable_compile or args.compile_structured_mlp)
+                with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True):
+                    batch_loss = model(x, y).detach()
+                batch_token_count = float(y.numel())
+                val_loss_sum += batch_loss.to(torch.float64) * batch_token_count
+                val_token_count += batch_token_count
+                prev_ids = x.reshape(-1)
+                tgt_ids = y.reshape(-1)
+                token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16)
+                token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16)
+                val_byte_count += token_bytes.to(torch.float64).sum()
+
+    if dist.is_available() and dist.is_initialized():
+        dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM)
+        dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM)
+        dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM)
+
+    val_loss = val_loss_sum / val_token_count
+    bits_per_token = val_loss.item() / math.log(2.0)
+    tokens_per_byte = val_token_count.item() / val_byte_count.item()
+    model.train()
+    return float(val_loss.item()), float(bits_per_token * tokens_per_byte)
+
+# -----------------------------
+# POST-TRAINING QUANTIZATION
+# -----------------------------
+#
+# It's silly to export our model, which is trained in bf16 and fp32, at that same precision.
+# Instead, we get approximately the same model (with a small hit) by quantizing it to int8 and zlib-compressing the result.
+# We can then decompress the model and run it at higher precision for evaluation, after coming in under the size limit.
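The int5_odd export below stores three 5-level codes per byte. A small self-contained sketch of that base-5 round trip (mirroring `pack_base5_codes`/`unpack_base5_codes`, with made-up codes):

```python
import numpy as np

# Quantized levels {-2..2} are stored shifted to codes {0..4}; three codes fit
# in one byte since 5**3 = 125 <= 256, i.e. ~2.67 bits per weight before zlib.
codes = np.array([4, 0, 3, 1, 2, 0], dtype=np.uint8)
pad = (3 - codes.size % 3) % 3
groups = np.pad(codes, (0, pad)).reshape(-1, 3).astype(np.uint8)
packed = (groups[:, 0] + 5 * groups[:, 1] + 25 * groups[:, 2]).astype(np.uint8)

# Unpack with repeated divmod by 5.
p = packed.astype(np.int16)
out = np.zeros((p.size, 3), dtype=np.uint8)
for i in range(3):
    out[:, i] = p % 5
    p //= 5
recovered = out.reshape(-1)[: codes.size]
assert np.array_equal(recovered, codes)  # lossless round trip
```

The packed bytes are highly repetitive for entropy-regularized weights, which is what the downstream zlib pass exploits.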
+ +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights", + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 +INT5_ODD_LEVELS = (-2, -1, 0, 1, 2) + +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) + + +def maybe_mark_compile_step_begin(enabled: bool) -> None: + if not enabled: + return + mark_step = getattr(torch.compiler, "cudagraph_mark_step_begin", None) + if mark_step is not None: + mark_step() + + +def scale_from_batch_multiplier(value: float, multiplier: float, method: str) -> float: + if method == "none": + return value + if method == "sqrt": + return value * math.sqrt(multiplier) + if method == "linear": + return value * multiplier + raise ValueError(f"Unsupported scaling method: {method}") + + +def scaled_warmup_steps(base_steps: int, multiplier: float, method: str) -> int: + if base_steps <= 0: + return 0 + scaled = scale_from_batch_multiplier(float(base_steps), multiplier, method) + return max(int(math.ceil(scaled)), 1) + + +def parse_factor_shape(spec: str, dim: int) -> tuple[int, int]: + if spec: + vals = tuple(int(v.strip()) for v in spec.split(",") if v.strip()) + if len(vals) != 2 or vals[0] * vals[1] != dim: + raise ValueError(f"Invalid factor shape '{spec}' for dim={dim}") + return vals + for a in range(int(math.isqrt(dim)), 0, -1): + if dim % a == 0: + return a, dim // a + return 1, dim + +def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> 
Tensor: + if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): + return t.float().contiguous() + if t.dtype in {torch.float32, torch.bfloat16}: + passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") + return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() + return t + +def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + # Matrices get one scale per row, which usually tracks output-channel + # ranges much better than a single tensor-wide scale. + clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) + q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() + return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + + # Vectors / scalars use a simpler per-tensor scale. 
+ clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() + return q, scale + + +def pack_base5_codes(q: Tensor) -> tuple[bytes, int]: + vals = q.detach().to("cpu", dtype=torch.uint8).reshape(-1).numpy() + n_vals = int(vals.size) + pad = (3 - (n_vals % 3)) % 3 + if pad: + vals = np.pad(vals, (0, pad)) + groups = vals.reshape(-1, 3).astype(np.uint8) + packed = groups[:, 0] + 5 * groups[:, 1] + 25 * groups[:, 2] + return packed.tobytes(), n_vals + + +def unpack_base5_codes(data: bytes, n_vals: int) -> Tensor: + packed = np.frombuffer(data, dtype=np.uint8).astype(np.int16) + out = np.zeros((packed.size, 3), dtype=np.uint8) + for i in range(3): + out[:, i] = packed % 5 + packed //= 5 + flat = out.reshape(-1)[:n_vals] + return torch.from_numpy(flat.astype(np.uint8, copy=False)) + + +def quantize_float_tensor_int5_odd(t: Tensor) -> tuple[Tensor, Tensor]: + if t.ndim != 2: + raise ValueError("int5_odd quantization only supports 2D tensors") + t32 = t.float() + clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clip_abs = clip_abs.clamp_min(1e-6) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 2.0).clamp_min(1.0 / 2.0e6) + q = torch.clamp(torch.round(clipped / scale[:, None]), -2, 2).to(torch.int8).contiguous() + return (q + 2).to(torch.uint8).contiguous(), scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + +def quantize_state_dict_int8(state_dict: dict[str, Tensor]): + # Single supported clean-script export format: + # - per-row int8 for 2D float tensors + # - per-tensor int8 for other float tensors + # - exact passthrough for non-floats + # - passthrough for small float tensors, 
stored as fp16 to save bytes + quantized: dict[str, Tensor] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + qmeta: dict[str, dict[str, object]] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + + # Small float tensors are cheap enough to keep directly. We still downcast + # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size. + if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + + stats["num_float_tensors"] += 1 + q, s = quantize_float_tensor(t) + if s.ndim > 0: + qmeta[name] = {"scheme": "per_row", "axis": 0} + quantized[name] = q + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) + + obj: dict[str, object] = { + "__quant_format__": "int8_clean_per_row_v1", + "quantized": quantized, + "scales": scales, + "dtypes": dtypes, + "passthrough": passthrough, + } + if qmeta: + obj["qmeta"] = qmeta + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats + + +def quantize_state_dict_int5_odd(state_dict: dict[str, Tensor]): + packed: dict[str, bytes] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + shapes: dict[str, tuple[int, ...]] = {} + orig_shapes: dict[str, 
tuple[int, ...]] = {} + n_codes: dict[str, int] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + + quantize_this = t.ndim >= 2 and (t.numel() > INT8_KEEP_FLOAT_MAX_NUMEL or ".mlp." in name) + if not quantize_this: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + + stats["num_float_tensors"] += 1 + t2 = t.reshape(t.shape[0], -1) if t.ndim > 2 else t + q, s = quantize_float_tensor_int5_odd(t2) + packed_bytes, count = pack_base5_codes(q) + packed[name] = packed_bytes + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + shapes[name] = tuple(int(v) for v in t2.shape) + orig_shapes[name] = tuple(int(v) for v in t.shape) + n_codes[name] = count + stats["int8_payload_bytes"] += len(packed_bytes) + tensor_nbytes(s) + + obj: dict[str, object] = { + "__quant_format__": "int5_odd_clean_per_row_v1", + "packed": packed, + "scales": scales, + "dtypes": dtypes, + "shapes": shapes, + "orig_shapes": orig_shapes, + "n_codes": n_codes, + "passthrough": passthrough, + } + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats + +def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + qmeta = obj.get("qmeta", {}) + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, q in 
obj["quantized"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + s = obj["scales"][name] + if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: + s = s.to(dtype=torch.float32) + # Broadcast the saved row scale back across trailing dimensions. + out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() + else: + scale = float(s.item()) + out[name] = (q.float() * scale).to(dtype=dtype).contiguous() + for name, t in obj["passthrough"].items(): + # Restore small tensors, undoing the temporary fp16 storage cast if needed. + out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out + + +def dequantize_state_dict_int5_odd(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, packed in obj["packed"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + shape = tuple(int(v) for v in obj["shapes"][name]) + orig_shape = tuple(int(v) for v in obj.get("orig_shapes", {}).get(name, shape)) + q = unpack_base5_codes(packed, int(obj["n_codes"][name])).to(torch.float32).reshape(shape) - 2.0 + s = obj["scales"][name].to(dtype=torch.float32) + deq = (q * s.view(shape[0], *([1] * (len(shape) - 1)))).to(dtype=dtype) + out[name] = deq.reshape(orig_shape).contiguous() + for name, t in obj["passthrough"].items(): + out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out + + +def quantize_state_dict(state_dict: dict[str, Tensor], quant_scheme: str): + if quant_scheme == "int8": + return quantize_state_dict_int8(state_dict) + if quant_scheme == "int5_odd": + return 
quantize_state_dict_int5_odd(state_dict) + raise ValueError(f"Unsupported QUANT_SCHEME={quant_scheme}") + + +def dequantize_state_dict(obj: dict[str, object]) -> dict[str, Tensor]: + quant_format = obj.get("__quant_format__") + if quant_format == "int8_clean_per_row_v1": + return dequantize_state_dict_int8(obj) + if quant_format == "int5_odd_clean_per_row_v1": + return dequantize_state_dict_int5_odd(obj) + raise ValueError(f"Unsupported quant format {quant_format}") + + +def compress_blob(data: bytes, compressor: str, level: int) -> bytes: + if compressor == "zlib": + return zlib.compress(data, level=max(0, min(level, 9))) + if compressor == "zstd": + if not _HAS_ZSTD: + raise RuntimeError("COMPRESSOR=zstd requires zstandard to be installed") + return zstd.ZstdCompressor(level=level).compress(data) + raise ValueError(f"Unsupported COMPRESSOR={compressor}") + + +def decompress_blob(data: bytes, compressor: str) -> bytes: + if compressor == "zlib": + return zlib.decompress(data) + if compressor == "zstd": + if not _HAS_ZSTD: + raise RuntimeError("COMPRESSOR=zstd requires zstandard to be installed") + return zstd.ZstdDecompressor().decompress(data) + raise ValueError(f"Unsupported COMPRESSOR={compressor}") + + +def should_rate_regularize_tensor(name: str, p: Tensor, target: str) -> bool: + if not p.requires_grad or not p.is_floating_point() or p.ndim < 2: + return False + if any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS): + return False + if target == "mlp_only": + return ".mlp." in name + if target == "all_matrices": + return p.numel() > INT8_KEEP_FLOAT_MAX_NUMEL or ".mlp." 
in name
+    raise ValueError(f"Unsupported RATE_TARGET_TENSORS={target}")
+
+
+def estimate_tensor_rate_proxy(t: Tensor, temperature: float) -> tuple[Tensor, Tensor, Tensor]:
+    t2 = t.float()
+    if t2.ndim > 2:
+        t2 = t2.reshape(t2.shape[0], -1)
+    elif t2.ndim == 1:
+        t2 = t2.reshape(1, -1)
+    row_scale = t2.abs().mean(dim=1, keepdim=True).clamp_min(1e-6)
+    normalized = torch.clamp(t2 / row_scale, -2.5, 2.5)
+    centers = t2.new_tensor(INT5_ODD_LEVELS)
+    logits = -temperature * (normalized.unsqueeze(-1) - centers) ** 2
+    probs = torch.softmax(logits, dim=-1)
+    hist = probs.mean(dim=(0, 1))
+    entropy_nats = -(hist * (hist + 1e-8).log()).sum()
+    entropy_bits = entropy_nats / math.log(2.0)
+    scale_proxy = row_scale.mean()
+    est_bits = entropy_bits * float(t2.numel())
+    return entropy_bits, scale_proxy, est_bits
+
+
+# -----------------------------
+# DATA LOADING
+# -----------------------------
+
+def load_data_shard(file: Path) -> Tensor:
+    # Assumed shard layout (matching the preprocessing pipeline's convention):
+    # a 256-word int32 header with the token count at index 2, followed by
+    # uint16 token ids. Tokens are widened to int32 for indexing.
+    header_bytes = 256 * np.dtype("<i4").itemsize
+    with file.open("rb") as f:
+        header = np.frombuffer(f.read(header_bytes), dtype="<i4")
+        num_tokens = int(header[2])
+        tokens = np.frombuffer(f.read(), dtype="<u2", count=num_tokens)
+    return torch.from_numpy(tokens.astype(np.int32, copy=False))
+
+
+class TokenStream:
+    # Streams tokens sequentially from the sorted shard files, wrapping around
+    # to the first file once the last one is exhausted.
+    def __init__(self, pattern: str):
+        self.files = [Path(p) for p in sorted(glob.glob(pattern))]
+        if not self.files:
+            raise FileNotFoundError(f"No files found for pattern: {pattern}")
+        self.file_idx = 0
+        self.tokens = load_data_shard(self.files[self.file_idx])
+        self.pos = 0
+
+    def _advance_file(self) -> None:
+        self.file_idx = (self.file_idx + 1) % len(self.files)
+        self.tokens = load_data_shard(self.files[self.file_idx])
+        self.pos = 0
+
+    def take(self, n: int) -> Tensor:
+        chunks: list[Tensor] = []
+        remaining = n
+        while remaining > 0:
+            avail = self.tokens.numel() - self.pos
+            if avail <= 0:
+                self._advance_file()
+                continue
+            k = min(remaining, avail)
+            chunks.append(self.tokens[self.pos : self.pos + k])
+            self.pos += k
+            remaining -= k
+        return chunks[0] if len(chunks) == 1 else torch.cat(chunks)
+
+
+class DistributedTokenLoader:
+    # Each call consumes a contiguous chunk from the shared token stream, then slices out
+    # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting.
+ def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank = rank + self.world_size = world_size + self.device = device + self.stream = TokenStream(pattern) + + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + chunk = self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start : start + per_rank_span].to(dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + +# ----------------------------- +# TRANSFORMER MODULES +# ----------------------------- + +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + + +class CastedLinear(nn.Linear): + # Keep weights in fp32 for optimizer/state quality, cast at matmul time for bf16 compute. + def forward(self, x: Tensor) -> Tensor: + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, self.weight.to(x.dtype), bias) + + +class StructuredLinear(nn.Module): + # Lightweight TT-style structured matrix used only in the MLP stack. 
+ def __init__( + self, + in_features: int, + out_features: int, + bias: bool, + linear_type: str, + btt_rank: int, + in_shape_spec: str, + out_shape_spec: str, + btt_init: str, + btt_init_gain: float, + ): + super().__init__() + self.in_features = in_features + self.out_features = out_features + self.linear_type = linear_type + self.rank = max(int(btt_rank), 1) + self.btt_init = btt_init + self.btt_init_gain = btt_init_gain + if linear_type == "dense": + self.impl = CastedLinear(in_features, out_features, bias=bias) + self.register_parameter("core1", None) + self.register_parameter("core2", None) + return + if linear_type != "btt_mlp": + raise ValueError(f"Unsupported STRUCTURED_LINEAR_TYPE={linear_type}") + self.impl = None + self.in_shape = parse_factor_shape(in_shape_spec, in_features) + self.out_shape = parse_factor_shape(out_shape_spec, out_features) + n1, n2 = self.in_shape + m1, m2 = self.out_shape + self.core1 = nn.Parameter(torch.empty(n1, m1, self.rank)) + self.core2 = nn.Parameter(torch.empty(self.rank, n2, m2)) + if bias: + self.bias = nn.Parameter(torch.zeros(out_features)) + else: + self.register_parameter("bias", None) + self.reset_parameters() + + def reset_parameters(self) -> None: + if self.impl is not None: + self.impl.reset_parameters() + return + n1, n2 = self.in_shape + if self.btt_init == "mup": + # muP-inspired scaling: keep the per-rank branch update stable while + # preserving O(1) output variance under the sum over TT ranks. 
+ core1_std = self.btt_init_gain / math.sqrt(max(n1, 1)) + core2_std = self.btt_init_gain / math.sqrt(max(self.rank * n2, 1)) + elif self.btt_init == "xavier": + std = self.btt_init_gain / math.sqrt(max(self.in_features, 1)) + core1_std = std / math.sqrt(self.rank) + core2_std = std + else: + raise ValueError(f"Unsupported BTT_INIT={self.btt_init}") + nn.init.normal_(self.core1, mean=0.0, std=core1_std) + nn.init.normal_(self.core2, mean=0.0, std=core2_std) + if self.bias is not None: + nn.init.zeros_(self.bias) + + def forward(self, x: Tensor) -> Tensor: + if self.impl is not None: + return self.impl(x) + x_dtype = x.dtype + orig_shape = x.shape[:-1] + x2 = x.reshape(-1, self.in_shape[0], self.in_shape[1]) + core1 = self.core1.to(dtype=x_dtype) + core2 = self.core2.to(dtype=x_dtype) + y = x2.new_zeros((x2.shape[0], self.out_shape[0], self.out_shape[1])) + for r in range(self.rank): + left = torch.einsum("buv,um->bmv", x2, core1[:, :, r]) + y = y + torch.einsum("bmv,vn->bmn", left, core2[r]) + y = y.reshape(*orig_shape, self.out_features) + if self.bias is not None: + y = y + self.bias.to(dtype=y.dtype) + return y + + +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + # Keep small/control parameters in fp32 even when the model body runs in bf16. + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() + + +class Rotary(nn.Module): + # Caches cos/sin tables per sequence length on the current device. 
+ def __init__(self, dim: int, base: float = 10000.0): + super().__init__() + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached != seq_len + or self._cos_cached.device != device + ): + t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) + freqs = torch.outer(t, self.inv_freq.to(device)) + self._cos_cached = freqs.cos()[None, None, :, :] + self._sin_cached = freqs.sin()[None, None, :, :] + self._seq_len_cached = seq_len + cos = self._cos_cached.to(dtype=dtype) + sin = self._sin_cached.to(dtype=dtype) + if not torch.is_inference_mode_enabled(): + cos = cos.clone() + sin = sin.clone() + return cos, sin + + +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + + +class CausalSelfAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + kv_dim = self.num_kv_heads * self.head_dim + self.c_q = CastedLinear(dim, dim, bias=False) + self.c_k = CastedLinear(dim, kv_dim, bias=False) + self.c_v = CastedLinear(dim, kv_dim, bias=False) + self.proj 
= CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rotary = Rotary(self.head_dim, base=rope_base) + + def forward(self, x: Tensor) -> Tensor: + bsz, seqlen, dim = x.shape + q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) + k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin) + k = apply_rotary_emb(k, cos, sin) + q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] + y = F.scaled_dot_product_attention( + q, + k, + v, + attn_mask=None, + is_causal=True, + enable_gqa=(self.num_kv_heads != self.num_heads), + ) + y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) + return self.proj(y) + + +class MLP(nn.Module): + # relu^2 MLP from the original modded-nanogpt setup + def __init__( + self, + dim: int, + mlp_mult: int, + structured_mlp: bool, + structured_linear_type: str, + btt_rank: int, + btt_in_shape: str, + btt_out_shape: str, + btt_init: str, + btt_init_gain: float, + compile_structured_mlp: bool, + ): + super().__init__() + hidden = mlp_mult * dim + if structured_mlp: + self.fc = StructuredLinear( + dim, + hidden, + bias=False, + linear_type=structured_linear_type, + btt_rank=btt_rank, + in_shape_spec=btt_in_shape, + out_shape_spec=btt_out_shape, + btt_init=btt_init, + btt_init_gain=btt_init_gain, + ) + self.proj = StructuredLinear( + hidden, + dim, + bias=False, + linear_type=structured_linear_type, + btt_rank=btt_rank, + in_shape_spec=btt_out_shape, + out_shape_spec=btt_in_shape, + btt_init=btt_init, + btt_init_gain=btt_init_gain, + ) + if compile_structured_mlp: + compile_options = 
dict(torch._inductor.list_mode_options("reduce-overhead")) + compile_options["triton.cudagraphs"] = False + self.fc = torch.compile( + self.fc, + options=compile_options, + fullgraph=False, + dynamic=False, + ) + self.proj = torch.compile( + self.proj, + options=compile_options, + fullgraph=False, + dynamic=False, + ) + else: + self.fc = CastedLinear(dim, hidden, bias=False) + self.proj = CastedLinear(hidden, dim, bias=False) + self.proj._zero_init = True + + def forward(self, x: Tensor) -> Tensor: + x = torch.relu(self.fc(x)) + return self.proj(x.square()) + + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + rope_base: float, + qk_gain_init: float, + structured_mlp: bool, + structured_linear_type: str, + btt_rank: int, + btt_in_shape: str, + btt_out_shape: str, + btt_init: str, + btt_init_gain: float, + compile_structured_mlp: bool, + ): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) + self.mlp = MLP( + dim, + mlp_mult, + structured_mlp, + structured_linear_type, + btt_rank, + btt_in_shape, + btt_out_shape, + btt_init, + btt_init_gain, + compile_structured_mlp, + ) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + + def forward(self, x: Tensor, x0: Tensor) -> Tensor: + mix = self.resid_mix.to(dtype=x.dtype) + x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + attn_out = self.attn(self.attn_norm(x)) + x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out + x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) + return x + + +class GPT(nn.Module): + def __init__( + self, + vocab_size: int, + num_layers: int, + model_dim: int, + num_heads: int, + 
num_kv_heads: int, + mlp_mult: int, + tie_embeddings: bool, + tied_embed_init_std: float, + logit_softcap: float, + rope_base: float, + qk_gain_init: float, + structured_mlp: bool, + structured_linear_type: str, + btt_rank: int, + btt_in_shape: str, + btt_out_shape: str, + btt_init: str, + btt_init_gain: float, + compile_structured_mlp: bool, + rate_target_tensors: str, + rate_temperature: float, + ): + super().__init__() + if logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") + self.tie_embeddings = tie_embeddings + self.tied_embed_init_std = tied_embed_init_std + self.logit_softcap = logit_softcap + self.rate_target_tensors = rate_target_tensors + self.rate_temperature = rate_temperature + self.tok_emb = nn.Embedding(vocab_size, model_dim) + self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + self.blocks = nn.ModuleList( + [ + Block( + model_dim, + num_heads, + num_kv_heads, + mlp_mult, + rope_base, + qk_gain_init, + structured_mlp, + structured_linear_type, + btt_rank, + btt_in_shape, + btt_out_shape, + btt_init, + btt_init_gain, + compile_structured_mlp, + ) + for i in range(num_layers) + ] + ) + self.final_norm = RMSNorm() + self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) + if self.lm_head is not None: + self.lm_head._zero_init = True + self._init_weights() + + def _init_weights(self) -> None: + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + for module in self.modules(): + if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + if isinstance(module, StructuredLinear) and getattr(module, "_zero_init", False): + if module.impl is 
not None: + nn.init.zeros_(module.impl.weight) + else: + nn.init.zeros_(module.core2) + + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + x = self.tok_emb(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x0 = x + skips: list[Tensor] = [] + + # First half stores skips; second half reuses them in reverse order. + for i in range(self.num_encoder_layers): + x = self.blocks[i](x, x0) + skips.append(x) + for i in range(self.num_decoder_layers): + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + x = self.blocks[self.num_encoder_layers + i](x, x0) + + x = self.final_norm(x).reshape(-1, x.size(-1)) + targets = target_ids.reshape(-1) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(x) + logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + return F.cross_entropy(logits.float(), targets, reduction="mean") + + def rate_regularizer(self) -> tuple[Tensor, Tensor, Tensor]: + total_symbols = 0.0 + weighted_rate = None + weighted_scale = None + estimated_bits = None + for name, param in self.named_parameters(): + if not should_rate_regularize_tensor(name, param, self.rate_target_tensors): + continue + rate_proxy, scale_proxy, tensor_bits = estimate_tensor_rate_proxy(param, self.rate_temperature) + weight = float(param.numel()) + total_symbols += weight + if weighted_rate is None: + weighted_rate = rate_proxy * weight + weighted_scale = scale_proxy * weight + estimated_bits = tensor_bits + else: + weighted_rate = weighted_rate + rate_proxy * weight + weighted_scale = weighted_scale + scale_proxy * weight + estimated_bits = estimated_bits + tensor_bits + if weighted_rate is None or weighted_scale is None or estimated_bits is None: + zero = self.tok_emb.weight.new_zeros(()) + return zero, zero, zero + denom = max(total_symbols, 1.0) + return 
weighted_rate / denom, weighted_scale / denom, estimated_bits + + +# ----------------------------- +# TRAINING +# ----------------------------- + +def main() -> None: + global zeropower_via_newtonschulz5 + + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + if args.enable_compile: + zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) + + # ----------------------------- + # DISTRIBUTED + CUDA SETUP + # ----------------------------- + + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if args.train_batch_tokens <= 0: + raise ValueError(f"TRAIN_BATCH_TOKENS must be positive, got {args.train_batch_tokens}") + train_microbatch_tokens = args.train_microbatch_tokens if args.train_microbatch_tokens > 0 else max( + args.train_seq_len, args.train_batch_tokens // 8 + ) + if train_microbatch_tokens <= 0: + raise ValueError(f"TRAIN_MICROBATCH_TOKENS must be positive, got {train_microbatch_tokens}") + if train_microbatch_tokens % args.train_seq_len != 0: + raise ValueError( + f"TRAIN_MICROBATCH_TOKENS={train_microbatch_tokens} must be divisible by TRAIN_SEQ_LEN={args.train_seq_len}" + ) + tokens_per_microstep = world_size * train_microbatch_tokens + if args.grad_accum_steps > 0: + grad_accum_steps = args.grad_accum_steps + effective_train_batch_tokens = tokens_per_microstep * grad_accum_steps + else: + if args.train_batch_tokens % tokens_per_microstep != 0: + raise ValueError( + "TRAIN_BATCH_TOKENS must be divisible by WORLD_SIZE * TRAIN_MICROBATCH_TOKENS " + f"unless GRAD_ACCUM_STEPS is set explicitly; got TRAIN_BATCH_TOKENS={args.train_batch_tokens}, " + f"WORLD_SIZE={world_size}, TRAIN_MICROBATCH_TOKENS={train_microbatch_tokens}" + ) + grad_accum_steps = 
args.train_batch_tokens // tokens_per_microstep + effective_train_batch_tokens = args.train_batch_tokens + if grad_accum_steps <= 0: + raise ValueError(f"GRAD_ACCUM_STEPS must be positive, got {grad_accum_steps}") + grad_scale = 1.0 / grad_accum_steps + batch_multiplier = effective_train_batch_tokens / max(args.lr_reference_batch_tokens, 1) + effective_warmup_steps = scaled_warmup_steps( + args.warmup_steps, + effective_train_batch_tokens / max(args.warmup_reference_batch_tokens, 1), + args.warmup_scale_method, + ) + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 0 + + # Fast math knobs + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + + logfile = None + if master_process: + if args.run_id == "train": + logfile = "train.log" + else: + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" + print(logfile) + + def log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch {torch.__version__}", console=False) + log0( + subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, + console=False, + ) + log0("=" * 100, console=False) + + # ----------------------------- + # TOKENIZER + VALIDATION METRIC SETUP + # 
----------------------------- + + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) + if args.val_token_limit > 0: + usable = min(((args.val_token_limit) // args.train_seq_len) * args.train_seq_len, val_tokens.numel() - 1) + if usable <= 0: + raise ValueError(f"VAL_TOKEN_LIMIT={args.val_token_limit} is too small for TRAIN_SEQ_LEN={args.train_seq_len}") + val_tokens = val_tokens[: usable + 1].contiguous() + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + if args.val_token_limit > 0: + log0(f"val_token_limit:{args.val_token_limit}") + + # ----------------------------- + # MODEL + OPTIMIZER SETUP + # ----------------------------- + + base_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, 
+ structured_mlp=args.structured_mlp, + structured_linear_type=args.structured_linear_type if args.structured_mlp else "dense", + btt_rank=args.btt_rank, + btt_in_shape=args.btt_in_shape, + btt_out_shape=args.btt_out_shape, + btt_init=args.btt_init, + btt_init_gain=args.btt_init_gain, + compile_structured_mlp=args.compile_structured_mlp, + rate_target_tensors=args.rate_target_tensors, + rate_temperature=args.rate_temperature, + ).to(device).bfloat16() + for module in base_model.modules(): + if isinstance(module, (CastedLinear, StructuredLinear)): + module.float() + restore_low_dim_params_to_fp32(base_model) + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) if args.enable_compile else base_model + model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + + # Optimizer split: + # - token embedding (Adam) uses EMBED_LR + # - untied lm_head (Adam) uses HEAD_LR + # - matrix params in transformer blocks use MATRIX_LR via Muon + # - vectors/scalars use SCALAR_LR via Adam + block_named_params = list(base_model.blocks.named_parameters()) + matrix_params = [ + p + for name, p in block_named_params + if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + scalar_params = [ + p + for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + scaled_token_lr = scale_from_batch_multiplier(token_lr, batch_multiplier, args.lr_scale_method) + scaled_head_lr = scale_from_batch_multiplier(args.head_lr, batch_multiplier, args.lr_scale_method) + scaled_matrix_lr = scale_from_batch_multiplier(args.matrix_lr, batch_multiplier, args.lr_scale_method) + scaled_scalar_lr = scale_from_batch_multiplier(args.scalar_lr, batch_multiplier, 
args.lr_scale_method) + optimizer_tok = torch.optim.Adam( + [{"params": [base_model.tok_emb.weight], "lr": scaled_token_lr, "base_lr": scaled_token_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizer_muon = Muon( + matrix_params, + lr=scaled_matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + ) + for group in optimizer_muon.param_groups: + group["base_lr"] = scaled_matrix_lr + optimizer_scalar = torch.optim.Adam( + [{"params": scalar_params, "lr": scaled_scalar_lr, "base_lr": scaled_scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": scaled_head_lr, "base_lr": scaled_head_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers.insert(1, optimizer_head) + + n_params = sum(p.numel() for p in base_model.parameters()) + log0(f"model_params:{n_params}") + log0( + f"world_size:{world_size} grad_accum_steps:{grad_accum_steps} " + f"train_microbatch_tokens:{train_microbatch_tokens} effective_train_batch_tokens:{effective_train_batch_tokens}" + ) + log0(f"enable_compile:{args.enable_compile}") + log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") + log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") + log0( + f"quant_scheme:{args.quant_scheme} compressor:{args.compressor} level:{args.compressor_level} " + f"structured_mlp:{args.structured_mlp} structured_linear_type:{args.structured_linear_type} " + f"btt_rank:{args.btt_rank} btt_init:{args.btt_init} compile_structured_mlp:{args.compile_structured_mlp}" + ) + log0( + f"rate_lambda:{args.rate_lambda} scale_lambda:{args.scale_lambda} " + f"rate_warmup_start_fraction:{args.rate_warmup_start_fraction} 
rate_full_fraction:{args.rate_full_fraction} " + f"rate_target_tensors:{args.rate_target_tensors}" + ) + log0( + f"tie_embeddings:{args.tie_embeddings} embed_lr:{scaled_token_lr} " + f"head_lr:{scaled_head_lr if base_model.lm_head is not None else 0.0} " + f"matrix_lr:{scaled_matrix_lr} scalar_lr:{scaled_scalar_lr}" + ) + log0( + f"train_batch_tokens_target:{args.train_batch_tokens} effective_train_batch_tokens:{effective_train_batch_tokens} " + f"train_seq_len:{args.train_seq_len} iterations:{args.iterations} warmup_steps:{effective_warmup_steps} " + f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log0( + f"lr_scale_method:{args.lr_scale_method} lr_reference_batch_tokens:{args.lr_reference_batch_tokens} " + f"batch_multiplier:{batch_multiplier:.4f} warmup_scale_method:{args.warmup_scale_method} " + f"warmup_reference_batch_tokens:{args.warmup_reference_batch_tokens}" + ) + log0(f"seed:{args.seed}") + + # ----------------------------- + # DATA LOADER & MODEL WARMUP + # ----------------------------- + + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + + def lr_mul(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + + def rate_schedule(step: int) -> float: + if args.rate_lambda <= 0.0 and args.scale_lambda <= 0.0: + return 0.0 + if args.iterations <= 
0: + return 1.0 + frac = step / max(args.iterations, 1) + start = args.rate_warmup_start_fraction + full = max(args.rate_full_fraction, start + 1e-6) + if frac <= start: + return 0.0 + if frac >= full: + return 1.0 + return (frac - start) / (full - start) + + # Warmup primes the compiled forward/backward/optimizer paths, then we restore the + # initial weights/optimizer state so measured training starts from the true init. + if effective_warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(effective_warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(effective_train_batch_tokens, args.train_seq_len, grad_accum_steps) + maybe_mark_compile_step_begin(args.enable_compile or args.compile_structured_mlp) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + for opt in optimizers: + opt.step() + zero_grad_all() + if ( + effective_warmup_steps <= 20 + or (warmup_step + 1) % 10 == 0 + or warmup_step + 1 == effective_warmup_steps + ): + log0(f"warmup_step:{warmup_step + 1}/{effective_warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + if distributed: + model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + # ----------------------------- + # MAIN TRAINING LOOP + # ----------------------------- + + training_time_ms = 0.0 + stop_after_step: int | None = None + torch.cuda.synchronize() + t0 = time.perf_counter() + 
+ step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + log0( + f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" + ) + break + + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + zero_grad_all() + train_loss = torch.zeros((), device=device) + train_ce_loss = torch.zeros((), device=device) + train_rate_proxy = torch.zeros((), device=device) + train_scale_proxy = torch.zeros((), device=device) + train_est_bits = torch.zeros((), device=device) + rate_mul = rate_schedule(step) + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(effective_train_batch_tokens, args.train_seq_len, grad_accum_steps) + maybe_mark_compile_step_begin(args.enable_compile or args.compile_structured_mlp) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + ce_loss = model(x, y) + total_loss = ce_loss + if rate_mul > 0.0: + rate_proxy, scale_proxy, est_bits = base_model.rate_regularizer() + total_loss = total_loss + rate_mul * (args.rate_lambda * rate_proxy + args.scale_lambda * 
scale_proxy) + train_rate_proxy += rate_proxy.detach() + train_scale_proxy += scale_proxy.detach() + train_est_bits += est_bits.detach() + train_ce_loss += ce_loss.detach() + train_loss += total_loss.detach() + (total_loss * grad_scale).backward() + train_loss /= grad_accum_steps + train_ce_loss /= grad_accum_steps + train_rate_proxy /= grad_accum_steps + train_scale_proxy /= grad_accum_steps + train_est_bits /= grad_accum_steps + + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + for opt in optimizers: + opt.step() + zero_grad_all() + + step += 1 + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + should_log_train = ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + log0( + f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} ce_loss:{train_ce_loss.item():.4f} " + f"rate_mul:{rate_mul:.3f} rate_proxy_bits:{train_rate_proxy.item():.4f} " + f"scale_proxy:{train_scale_proxy.item():.6f} est_model_bits:{train_est_bits.item():.0f} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" + ) + + # Needed to sync whether we've reached the wallclock cap. 
+ reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + + log0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + ) + + # ----------------------------- + # SERIALIZATION + ROUNDTRIP VALIDATION + # ----------------------------- + # Save the raw state (useful for debugging/loading in PyTorch directly), then always produce + # the compressed quantized artifact (args.quant_scheme + args.compressor) and validate the round-tripped weights. + + if master_process: + torch.save(base_model.state_dict(), "final_model.pt") + model_bytes = os.path.getsize("final_model.pt") + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model: {model_bytes} bytes") + log0(f"Code size: {code_bytes} bytes") + log0(f"Total submission size: {model_bytes + code_bytes} bytes") + + quant_obj, quant_stats = quantize_state_dict(base_model.state_dict(), args.quant_scheme) + quant_buf = io.BytesIO() + torch.save(quant_obj, quant_buf) + quant_raw = quant_buf.getvalue() + quant_blob = compress_blob(quant_raw, args.compressor, args.compressor_level) + quant_raw_bytes = len(quant_raw) + quant_filename = f"final_model.{args.quant_scheme}.{args.compressor}.ptc" + if master_process: + with open(quant_filename, "wb") as f: + f.write(quant_blob) + quant_file_bytes = os.path.getsize(quant_filename) + code_bytes = len(code.encode("utf-8")) + ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1) + log0( + f"Serialized model {args.quant_scheme}+{args.compressor}: {quant_file_bytes} bytes " + f"(payload:{quant_stats['int8_payload_bytes']} raw_torch:{quant_raw_bytes} 
payload_ratio:{ratio:.2f}x)" + ) + log0(f"Total submission size {args.quant_scheme}+{args.compressor}: {quant_file_bytes + code_bytes} bytes") + + if distributed: + dist.barrier() + with open(quant_filename, "rb") as f: + quant_blob_disk = f.read() + quant_state = torch.load(io.BytesIO(decompress_blob(quant_blob_disk, args.compressor)), map_location="cpu") + base_model.load_state_dict(dequantize_state_dict(quant_state), strict=True) + torch.cuda.synchronize() + t_qeval = time.perf_counter() + q_val_loss, q_val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"final_{args.quant_scheme}_{args.compressor}_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" + ) + log0(f"final_{args.quant_scheme}_{args.compressor}_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + + if distributed: + dist.destroy_process_group() + + +if __name__ == "__main__": + main() + +==================================================================================================== +Running Python 3.10.12 (main, Mar 3 2026, 11:56:32) [GCC 11.4.0] +Running PyTorch 2.11.0+cu130 +Tue Mar 31 13:06:34 2026 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 590.48.01 Driver Version: 590.48.01 CUDA Version: 13.1 | ++-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. 
| +|=========================================+========================+======================| +| 0 NVIDIA GeForce RTX 3060 Off | 00000000:01:00.0 On | N/A | +| 32% 41C P8 26W / 170W | 716MiB / 12288MiB | 27% Default | +| | | N/A | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| 0 N/A N/A 985 G /usr/lib/xorg/Xorg 117MiB | +| 0 N/A N/A 5676 C /usr/bin/python3 578MiB | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=/home/hz/Desktop/parameter-golf-main/data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:1 +val_loader:shards pattern=/home/hz/Desktop/parameter-golf-main/data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:2048 +val_token_limit:2048 +model_params:8507464 +world_size:1 grad_accum_steps:8 train_microbatch_tokens:2048 effective_train_batch_tokens:16384 +enable_compile:False +sdp_backends:cudnn=False flash=True mem_efficient=False math=False +attention_mode:gqa num_heads:8 num_kv_heads:4 +quant_scheme:int5_odd compressor:zlib level:9 structured_mlp:True structured_linear_type:btt_mlp btt_rank:32 btt_init:mup compile_structured_mlp:True +rate_lambda:2e-05 scale_lambda:0.0002 rate_warmup_start_fraction:0.1 rate_full_fraction:0.4 rate_target_tensors:mlp_only +tie_embeddings:True embed_lr:0.03 head_lr:0.0 matrix_lr:0.02 scalar_lr:0.02 +train_batch_tokens_target:16384 effective_train_batch_tokens:16384 train_seq_len:256 iterations:0 warmup_steps:1 max_wallclock_seconds:600.000 +lr_scale_method:none 
lr_reference_batch_tokens:16384 batch_multiplier:1.0000 warmup_scale_method:none warmup_reference_batch_tokens:16384 +seed:1337 +warmup_step:1/1 +step:0/0 val_loss:6.9412 val_bpb:3.7589 train_time:0ms step_avg:0.02ms +peak memory allocated: 2388 MiB reserved: 2486 MiB +Serialized model: 33022259 bytes +Code size: 72970 bytes +Total submission size: 33095229 bytes +Serialized model int5_odd+zlib: 1602383 bytes (payload:2942927 raw_torch:2983011 payload_ratio:11.21x) +Total submission size int5_odd+zlib: 1675353 bytes +final_int5_odd_zlib_roundtrip val_loss:6.9472 val_bpb:3.7622 eval_time:636ms +final_int5_odd_zlib_roundtrip_exact val_loss:6.94722307 val_bpb:3.76219562 diff --git a/records/track_non_record_16mb/2026-03-31_EntropyAware_Int5Odd_BTTMLP_3060/logs/compile_bench.txt b/records/track_non_record_16mb/2026-03-31_EntropyAware_Int5Odd_BTTMLP_3060/logs/compile_bench.txt new file mode 100644 index 000000000..7de29ebf5 --- /dev/null +++ b/records/track_non_record_16mb/2026-03-31_EntropyAware_Int5Odd_BTTMLP_3060/logs/compile_bench.txt @@ -0,0 +1,1660 @@ +""" +The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good launching-off points for new participants, not SOTA configs. We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder. + +Hard stop: To keep things readable for newcomers, let's make sure `train_gpt.py` and `train_gpt_mlx.py` are never longer than 1500 lines. 
+""" + +from __future__ import annotations + +import copy +import glob +import io +import math +import os +import random +import subprocess +import sys +import time +import uuid +import zlib +from pathlib import Path + +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP + +try: + import zstandard as zstd + + _HAS_ZSTD = True +except ImportError: + zstd = None + _HAS_ZSTD = False + +REPO_ROOT = Path(__file__).resolve().parents[3] + +# ----------------------------- +# HYPERPARAMETERS +# ----------------------------- +# Default Simple Baseline run: +# - 9 transformer blocks at width 512 +# - 8 attention heads with 4 KV heads (GQA) and 2x MLP expansion +# - vocab size 1024, sequence length 1024, tied embeddings +# - 524,288 train tokens per step for 20,000 iterations with a ~10 minute cap + +class Hyperparameters: + # Data paths are shard globs produced by the existing preprocessing pipeline. + data_path = os.environ.get("DATA_PATH", str(REPO_ROOT / "data/datasets/fineweb10B_sp1024")) + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", str(REPO_ROOT / "data/tokenizers/fineweb_1024_bpe.model")) + run_id = os.environ.get("RUN_ID", "train") + seed = int(os.environ.get("SEED", 1337)) + + # Validation cadence and batch size. Validation always uses the full fineweb_val split. + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 262_144)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 10)) + val_token_limit = int(os.environ.get("VAL_TOKEN_LIMIT", 1_048_576)) + + # Training length. 
+ iterations = int(os.environ.get("ITERATIONS", 40)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 1200)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 0)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 16_384)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 256)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + enable_compile = bool(int(os.environ.get("ENABLE_COMPILE", "0"))) + + # Model shape. + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers = int(os.environ.get("NUM_LAYERS", 9)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = int(os.environ.get("MLP_MULT", 2)) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + + # Optimizer hyperparameters. + embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.03)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.02)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.02)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.0)) + + # Export / compression. 
+ quant_scheme = os.environ.get("QUANT_SCHEME", "int5_odd") + compressor = os.environ.get("COMPRESSOR", "zlib") + compressor_level = int(os.environ.get("COMPRESSOR_LEVEL", 9)) + + # Entropy-aware rate regularization. + rate_lambda = float(os.environ.get("RATE_LAMBDA", 0.002)) + scale_lambda = float(os.environ.get("SCALE_LAMBDA", 0.0002)) + rate_temperature = float(os.environ.get("RATE_TEMPERATURE", 6.0)) + rate_warmup_start_fraction = float(os.environ.get("RATE_WARMUP_START_FRACTION", 0.1)) + rate_full_fraction = float(os.environ.get("RATE_FULL_FRACTION", 0.4)) + rate_target_tensors = os.environ.get("RATE_TARGET_TENSORS", "mlp_only") + + # Structured MLP. + structured_mlp = bool(int(os.environ.get("STRUCTURED_MLP", "1"))) + structured_linear_type = os.environ.get("STRUCTURED_LINEAR_TYPE", "btt_mlp") + btt_rank = int(os.environ.get("BTT_RANK", 32)) + btt_in_shape = os.environ.get("BTT_IN_SHAPE", "") + btt_out_shape = os.environ.get("BTT_OUT_SHAPE", "") + btt_init = os.environ.get("BTT_INIT", "mup") + btt_init_gain = float(os.environ.get("BTT_INIT_GAIN", 1.0)) + compile_structured_mlp = bool(int(os.environ.get("COMPILE_STRUCTURED_MLP", "1"))) + +# ----------------------------- +# MUON OPTIMIZER +# ----------------------------- +# +# As borrowed from modded-nanogpt +# Background on Muon: https://kellerjordan.github.io/posts/muon/ + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + # Orthogonalize a 2D update matrix with a fast Newton-Schulz iteration. + # Muon uses this to normalize matrix-shaped gradients before applying them. 
+ a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X + + +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True): + super().__init__( + params, + dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov), + ) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + + for group in self.param_groups: + params = group["params"] + if not params: + continue + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + + total_params = sum(int(p.numel()) for p in params) + updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) + + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: + g = g.add(buf, alpha=momentum) + g = zeropower_via_newtonschulz5(g, steps=backend_steps) + # Scale correction from Muon reference implementations. 
+ g *= max(1, g.size(0) / g.size(1)) ** 0.5 + updates_flat[curr : curr + p.numel()] = g.reshape(-1) + curr += p.numel() + + if distributed: + dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) + + curr = 0 + for p in params: + g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) + p.add_(g, alpha=-lr) + curr += p.numel() + + return loss + + +# ----------------------------- +# TOKENIZER-AGNOSTIC EVALUATION SETUP +# ----------------------------- +# +# It's common for small models to have a large fraction of their parameters be embeddings, since the 2 * d_model * d_vocab vectors can be gigantic. +# Instead of locking the tokenizer, we let you bring your own and calculate our validation metrics on the average compression of the validation set. +# We calculate BPB (bits-per-byte) instead of validation loss, so we need a way to count the number of bytes each token decodes to under the tokenizer. +# Note: Submissions that edit the tokenizer will be examined more carefully, since screwing this up might unjustly improve your score. 
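The BPB bookkeeping described above reduces to a one-line identity: BPB = (nats per token / ln 2) × (tokens / bytes). A minimal sketch with made-up numbers; the helper name `bits_per_byte` is ours for illustration, not part of the script:

```python
import math

def bits_per_byte(val_loss_nats: float, token_count: float, byte_count: float) -> float:
    # nats/token -> bits/token, then weight by the tokenizer's tokens-per-byte ratio.
    bits_per_token = val_loss_nats / math.log(2.0)
    tokens_per_byte = token_count / byte_count
    return bits_per_token * tokens_per_byte

# At exactly ln(2) nats/token the model pays 1 bit per token, so BPB reduces
# to the token-to-byte ratio of the tokenizer on the validation bytes.
print(bits_per_byte(math.log(2.0), 1000.0, 4200.0))
```

This is the same reduction performed at the end of `eval_val`: `val_loss_sum / val_token_count` supplies the nats-per-token term, and `val_token_count / val_byte_count` supplies the tokens-per-byte factor.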
+ +def build_sentencepiece_luts( + sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device +) -> tuple[Tensor, Tensor, Tensor]: + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("▁"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) + + +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*. 
+ tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] + + +def eval_val( + args: Hyperparameters, + model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + grad_accum_steps: int, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, +) -> tuple[float, float]: + # Validation computes two metrics: + # - val_loss: token cross-entropy (natural log) + # - val_bpb: tokenizer-agnostic compression metric used by the challenge + local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) + if local_batch_tokens < args.train_seq_len: + raise ValueError( + "VAL_BATCH_SIZE must provide at least one sequence per rank; " + f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " + f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}" + ) + local_batch_seqs = local_batch_tokens // args.train_seq_len + total_seqs = (val_tokens.numel() - 1) // args.train_seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + + model.eval() + with torch.inference_mode(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * args.train_seq_len + raw_end = batch_seq_end * args.train_seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, args.train_seq_len) + y = local[1:].reshape(-1, args.train_seq_len) + with 
torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + batch_loss = model(x, y).detach() + batch_token_count = float(y.numel()) + val_loss_sum += batch_loss.to(torch.float64) * batch_token_count + val_token_count += batch_token_count + prev_ids = x.reshape(-1) + tgt_ids = y.reshape(-1) + token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + + val_loss = val_loss_sum / val_token_count + bits_per_token = val_loss.item() / math.log(2.0) + tokens_per_byte = val_token_count.item() / val_byte_count.item() + model.train() + return float(val_loss.item()), float(bits_per_token * tokens_per_byte) + +# ----------------------------- +# POST-TRAINING QUANTIZATION +# ----------------------------- +# +# It's silly to export our model, which is trained in bf16 and fp32, at that same precision. +# Instead, we get approximately the same model (with a small accuracy hit) by quantizing it to a low-bit grid (int8 or int5_odd here) and zlib-compressing the result. +# We can then decompress the model and run it at higher precision for evaluation, after coming in under the size limit. 
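Before the full export machinery below, the core per-row symmetric quantization step can be isolated in a few lines. This is a hedged sketch, not the actual export path: it uses plain max-abs row scaling, whereas the real code clips at a high percentile and stores scales as fp16. The helper `int8_roundtrip` is illustrative and not defined in the script:

```python
import torch

def int8_roundtrip(w: torch.Tensor) -> torch.Tensor:
    # One scale per row: rows with different magnitudes quantize independently.
    scale = (w.abs().amax(dim=1, keepdim=True) / 127.0).clamp_min(1.0 / 127.0)
    q = torch.clamp(torch.round(w / scale), -127, 127).to(torch.int8)
    # Dequantize back to float; per-entry error is at most half a step (scale / 2).
    return q.float() * scale

w = torch.tensor([[-127.0, 0.0, 64.0, 127.0]])
print(int8_roundtrip(w))  # exact here, since every entry is a multiple of the step
```

The per-row choice is the same design decision as in `quantize_float_tensor` below: a single tensor-wide scale would let one large output channel blow up the step size for every other row.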
+ +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights", + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 +INT5_ODD_LEVELS = (-2, -1, 0, 1, 2) + +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) + + +def parse_factor_shape(spec: str, dim: int) -> tuple[int, int]: + if spec: + vals = tuple(int(v.strip()) for v in spec.split(",") if v.strip()) + if len(vals) != 2 or vals[0] * vals[1] != dim: + raise ValueError(f"Invalid factor shape '{spec}' for dim={dim}") + return vals + for a in range(int(math.isqrt(dim)), 0, -1): + if dim % a == 0: + return a, dim // a + return 1, dim + +def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: + if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): + return t.float().contiguous() + if t.dtype in {torch.float32, torch.bfloat16}: + passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") + return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() + return t + +def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + # Matrices get one scale per row, which usually tracks output-channel + # ranges much better than a single tensor-wide scale. 
+ clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) + q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() + return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + + # Vectors / scalars use a simpler per-tensor scale. + clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() + return q, scale + + +def pack_base5_codes(q: Tensor) -> tuple[bytes, int]: + vals = q.detach().to("cpu", dtype=torch.uint8).reshape(-1).numpy() + n_vals = int(vals.size) + pad = (3 - (n_vals % 3)) % 3 + if pad: + vals = np.pad(vals, (0, pad)) + groups = vals.reshape(-1, 3).astype(np.uint8) + packed = groups[:, 0] + 5 * groups[:, 1] + 25 * groups[:, 2] + return packed.tobytes(), n_vals + + +def unpack_base5_codes(data: bytes, n_vals: int) -> Tensor: + packed = np.frombuffer(data, dtype=np.uint8).astype(np.int16) + out = np.zeros((packed.size, 3), dtype=np.uint8) + for i in range(3): + out[:, i] = packed % 5 + packed //= 5 + flat = out.reshape(-1)[:n_vals] + return torch.from_numpy(flat.astype(np.uint8, copy=False)) + + +def quantize_float_tensor_int5_odd(t: Tensor) -> tuple[Tensor, Tensor]: + if t.ndim != 2: + raise ValueError("int5_odd quantization only supports 2D tensors") + t32 = t.float() + clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clip_abs = clip_abs.clamp_min(1e-6) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 
2.0).clamp_min(1.0 / 2.0e6) + q = torch.clamp(torch.round(clipped / scale[:, None]), -2, 2).to(torch.int8).contiguous() + return (q + 2).to(torch.uint8).contiguous(), scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + +def quantize_state_dict_int8(state_dict: dict[str, Tensor]): + # Single supported clean-script export format: + # - per-row int8 for 2D float tensors + # - per-tensor int8 for other float tensors + # - exact passthrough for non-floats + # - passthrough for small float tensors, stored as fp16 to save bytes + quantized: dict[str, Tensor] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + qmeta: dict[str, dict[str, object]] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + + # Small float tensors are cheap enough to keep directly. We still downcast + # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size. 
+ if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + + stats["num_float_tensors"] += 1 + q, s = quantize_float_tensor(t) + if s.ndim > 0: + qmeta[name] = {"scheme": "per_row", "axis": 0} + quantized[name] = q + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) + + obj: dict[str, object] = { + "__quant_format__": "int8_clean_per_row_v1", + "quantized": quantized, + "scales": scales, + "dtypes": dtypes, + "passthrough": passthrough, + } + if qmeta: + obj["qmeta"] = qmeta + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats + + +def quantize_state_dict_int5_odd(state_dict: dict[str, Tensor]): + packed: dict[str, bytes] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + shapes: dict[str, tuple[int, ...]] = {} + orig_shapes: dict[str, tuple[int, ...]] = {} + n_codes: dict[str, int] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + + quantize_this = t.ndim >= 2 and (t.numel() > INT8_KEEP_FLOAT_MAX_NUMEL or ".mlp." 
in name) + if not quantize_this: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + + stats["num_float_tensors"] += 1 + t2 = t.reshape(t.shape[0], -1) if t.ndim > 2 else t + q, s = quantize_float_tensor_int5_odd(t2) + packed_bytes, count = pack_base5_codes(q) + packed[name] = packed_bytes + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + shapes[name] = tuple(int(v) for v in t2.shape) + orig_shapes[name] = tuple(int(v) for v in t.shape) + n_codes[name] = count + stats["int8_payload_bytes"] += len(packed_bytes) + tensor_nbytes(s) + + obj: dict[str, object] = { + "__quant_format__": "int5_odd_clean_per_row_v1", + "packed": packed, + "scales": scales, + "dtypes": dtypes, + "shapes": shapes, + "orig_shapes": orig_shapes, + "n_codes": n_codes, + "passthrough": passthrough, + } + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats + +def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + qmeta = obj.get("qmeta", {}) + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, q in obj["quantized"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + s = obj["scales"][name] + if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: + s = s.to(dtype=torch.float32) + # Broadcast the saved row scale back across trailing dimensions. + out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() + else: + scale = float(s.item()) + out[name] = (q.float() * scale).to(dtype=dtype).contiguous() + for name, t in obj["passthrough"].items(): + # Restore small tensors, undoing the temporary fp16 storage cast if needed. 
+ out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out + + +def dequantize_state_dict_int5_odd(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, packed in obj["packed"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + shape = tuple(int(v) for v in obj["shapes"][name]) + orig_shape = tuple(int(v) for v in obj.get("orig_shapes", {}).get(name, shape)) + q = unpack_base5_codes(packed, int(obj["n_codes"][name])).to(torch.float32).reshape(shape) - 2.0 + s = obj["scales"][name].to(dtype=torch.float32) + deq = (q * s.view(shape[0], *([1] * (len(shape) - 1)))).to(dtype=dtype) + out[name] = deq.reshape(orig_shape).contiguous() + for name, t in obj["passthrough"].items(): + out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out + + +def quantize_state_dict(state_dict: dict[str, Tensor], quant_scheme: str): + if quant_scheme == "int8": + return quantize_state_dict_int8(state_dict) + if quant_scheme == "int5_odd": + return quantize_state_dict_int5_odd(state_dict) + raise ValueError(f"Unsupported QUANT_SCHEME={quant_scheme}") + + +def dequantize_state_dict(obj: dict[str, object]) -> dict[str, Tensor]: + quant_format = obj.get("__quant_format__") + if quant_format == "int8_clean_per_row_v1": + return dequantize_state_dict_int8(obj) + if quant_format == "int5_odd_clean_per_row_v1": + return dequantize_state_dict_int5_odd(obj) + raise ValueError(f"Unsupported quant format {quant_format}") + + +def compress_blob(data: bytes, compressor: str, level: int) -> bytes: + if compressor == "zlib": + return zlib.compress(data, level=max(0, 
min(level, 9))) + if compressor == "zstd": + if not _HAS_ZSTD: + raise RuntimeError("COMPRESSOR=zstd requires zstandard to be installed") + return zstd.ZstdCompressor(level=level).compress(data) + raise ValueError(f"Unsupported COMPRESSOR={compressor}") + + +def decompress_blob(data: bytes, compressor: str) -> bytes: + if compressor == "zlib": + return zlib.decompress(data) + if compressor == "zstd": + if not _HAS_ZSTD: + raise RuntimeError("COMPRESSOR=zstd requires zstandard to be installed") + return zstd.ZstdDecompressor().decompress(data) + raise ValueError(f"Unsupported COMPRESSOR={compressor}") + + +def should_rate_regularize_tensor(name: str, p: Tensor, target: str) -> bool: + if not p.requires_grad or not p.is_floating_point() or p.ndim < 2: + return False + if any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS): + return False + if target == "mlp_only": + return ".mlp." in name + if target == "all_matrices": + return p.numel() > INT8_KEEP_FLOAT_MAX_NUMEL or ".mlp." in name + raise ValueError(f"Unsupported RATE_TARGET_TENSORS={target}") + + +def estimate_tensor_rate_proxy(t: Tensor, temperature: float) -> tuple[Tensor, Tensor, Tensor]: + t2 = t.float() + if t2.ndim > 2: + t2 = t2.reshape(t2.shape[0], -1) + elif t2.ndim == 1: + t2 = t2.reshape(1, -1) + row_scale = t2.abs().mean(dim=1, keepdim=True).clamp_min(1e-6) + normalized = torch.clamp(t2 / row_scale, -2.5, 2.5) + centers = t2.new_tensor(INT5_ODD_LEVELS) + logits = -temperature * (normalized.unsqueeze(-1) - centers) ** 2 + probs = torch.softmax(logits, dim=-1) + hist = probs.mean(dim=(0, 1)) + entropy_nats = -(hist * (hist + 1e-8).log()).sum() + entropy_bits = entropy_nats / math.log(2.0) + scale_proxy = row_scale.mean() + est_bits = entropy_bits * float(t2.numel()) + return entropy_bits, scale_proxy, est_bits + + +# ----------------------------- +# DATA LOADING +# ----------------------------- + +def load_data_shard(file: Path) -> Tensor: + # Shards are assumed to follow the modded-nanogpt layout: a 256-entry int32 header followed by uint16 token ids. + header_bytes = 256 * np.dtype("<i4").itemsize + data = np.fromfile(file, dtype="<u2", offset=header_bytes) + return torch.from_numpy(data.astype(np.int32)) + + +class TokenStream: + # Streams tokens sequentially across the shard files, wrapping back to the first file when exhausted. + def __init__(self, pattern: str): + self.files = [Path(p) for p in sorted(glob.glob(pattern))] + if not self.files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + self.file_idx = 0 + self.tokens = load_data_shard(self.files[self.file_idx]) + self.pos = 0 + + def _advance_file(self) -> None: + 
self.file_idx = (self.file_idx + 1) % len(self.files) + self.tokens = load_data_shard(self.files[self.file_idx]) + self.pos = 0 + + def take(self, n: int) -> Tensor: + chunks: list[Tensor] = [] + remaining = n + while remaining > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: + self._advance_file() + continue + k = min(remaining, avail) + chunks.append(self.tokens[self.pos : self.pos + k]) + self.pos += k + remaining -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) + + +class DistributedTokenLoader: + # Each call consumes a contiguous chunk from the shared token stream, then slices out + # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting. + def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank = rank + self.world_size = world_size + self.device = device + self.stream = TokenStream(pattern) + + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + chunk = self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start : start + per_rank_span].to(dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + +# ----------------------------- +# TRANSFORMER MODULES +# ----------------------------- + +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + + +class CastedLinear(nn.Linear): + # Keep weights in fp32 for optimizer/state quality, cast at matmul time for bf16 compute. 
+ def forward(self, x: Tensor) -> Tensor: + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, self.weight.to(x.dtype), bias) + + +class StructuredLinear(nn.Module): + # Lightweight TT-style structured matrix used only in the MLP stack. + def __init__( + self, + in_features: int, + out_features: int, + bias: bool, + linear_type: str, + btt_rank: int, + in_shape_spec: str, + out_shape_spec: str, + btt_init: str, + btt_init_gain: float, + ): + super().__init__() + self.in_features = in_features + self.out_features = out_features + self.linear_type = linear_type + self.rank = max(int(btt_rank), 1) + self.btt_init = btt_init + self.btt_init_gain = btt_init_gain + if linear_type == "dense": + self.impl = CastedLinear(in_features, out_features, bias=bias) + self.register_parameter("core1", None) + self.register_parameter("core2", None) + return + if linear_type != "btt_mlp": + raise ValueError(f"Unsupported STRUCTURED_LINEAR_TYPE={linear_type}") + self.impl = None + self.in_shape = parse_factor_shape(in_shape_spec, in_features) + self.out_shape = parse_factor_shape(out_shape_spec, out_features) + n1, n2 = self.in_shape + m1, m2 = self.out_shape + self.core1 = nn.Parameter(torch.empty(n1, m1, self.rank)) + self.core2 = nn.Parameter(torch.empty(self.rank, n2, m2)) + if bias: + self.bias = nn.Parameter(torch.zeros(out_features)) + else: + self.register_parameter("bias", None) + self.reset_parameters() + + def reset_parameters(self) -> None: + if self.impl is not None: + self.impl.reset_parameters() + return + n1, n2 = self.in_shape + if self.btt_init == "mup": + # muP-inspired scaling: keep the per-rank branch update stable while + # preserving O(1) output variance under the sum over TT ranks. 
+ core1_std = self.btt_init_gain / math.sqrt(max(n1, 1)) + core2_std = self.btt_init_gain / math.sqrt(max(self.rank * n2, 1)) + elif self.btt_init == "xavier": + std = self.btt_init_gain / math.sqrt(max(self.in_features, 1)) + core1_std = std / math.sqrt(self.rank) + core2_std = std + else: + raise ValueError(f"Unsupported BTT_INIT={self.btt_init}") + nn.init.normal_(self.core1, mean=0.0, std=core1_std) + nn.init.normal_(self.core2, mean=0.0, std=core2_std) + if self.bias is not None: + nn.init.zeros_(self.bias) + + def forward(self, x: Tensor) -> Tensor: + if self.impl is not None: + return self.impl(x) + x_dtype = x.dtype + orig_shape = x.shape[:-1] + x2 = x.reshape(-1, self.in_shape[0], self.in_shape[1]) + core1 = self.core1.to(dtype=x_dtype) + core2 = self.core2.to(dtype=x_dtype) + y = x2.new_zeros((x2.shape[0], self.out_shape[0], self.out_shape[1])) + for r in range(self.rank): + left = torch.einsum("buv,um->bmv", x2, core1[:, :, r]) + y = y + torch.einsum("bmv,vn->bmn", left, core2[r]) + y = y.reshape(*orig_shape, self.out_features) + if self.bias is not None: + y = y + self.bias.to(dtype=y.dtype) + return y + + +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + # Keep small/control parameters in fp32 even when the model body runs in bf16. + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() + + +class Rotary(nn.Module): + # Caches cos/sin tables per sequence length on the current device. 
+ def __init__(self, dim: int, base: float = 10000.0): + super().__init__() + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached != seq_len + or self._cos_cached.device != device + ): + t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) + freqs = torch.outer(t, self.inv_freq.to(device)) + self._cos_cached = freqs.cos()[None, None, :, :] + self._sin_cached = freqs.sin()[None, None, :, :] + self._seq_len_cached = seq_len + cos = self._cos_cached.to(dtype=dtype) + sin = self._sin_cached.to(dtype=dtype) + if not torch.is_inference_mode_enabled(): + cos = cos.clone() + sin = sin.clone() + return cos, sin + + +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + + +class CausalSelfAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + kv_dim = self.num_kv_heads * self.head_dim + self.c_q = CastedLinear(dim, dim, bias=False) + self.c_k = CastedLinear(dim, kv_dim, bias=False) + self.c_v = CastedLinear(dim, kv_dim, bias=False) + self.proj 
= CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rotary = Rotary(self.head_dim, base=rope_base) + + def forward(self, x: Tensor) -> Tensor: + bsz, seqlen, dim = x.shape + q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) + k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin) + k = apply_rotary_emb(k, cos, sin) + q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] + y = F.scaled_dot_product_attention( + q, + k, + v, + attn_mask=None, + is_causal=True, + enable_gqa=(self.num_kv_heads != self.num_heads), + ) + y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) + return self.proj(y) + + +class MLP(nn.Module): + # relu^2 MLP from the original modded-nanogpt setup + def __init__( + self, + dim: int, + mlp_mult: int, + structured_mlp: bool, + structured_linear_type: str, + btt_rank: int, + btt_in_shape: str, + btt_out_shape: str, + btt_init: str, + btt_init_gain: float, + compile_structured_mlp: bool, + ): + super().__init__() + hidden = mlp_mult * dim + if structured_mlp: + self.fc = StructuredLinear( + dim, + hidden, + bias=False, + linear_type=structured_linear_type, + btt_rank=btt_rank, + in_shape_spec=btt_in_shape, + out_shape_spec=btt_out_shape, + btt_init=btt_init, + btt_init_gain=btt_init_gain, + ) + self.proj = StructuredLinear( + hidden, + dim, + bias=False, + linear_type=structured_linear_type, + btt_rank=btt_rank, + in_shape_spec=btt_out_shape, + out_shape_spec=btt_in_shape, + btt_init=btt_init, + btt_init_gain=btt_init_gain, + ) + if compile_structured_mlp: + self.fc = torch.compile(self.fc, mode="reduce-overhead", 
fullgraph=False, dynamic=False) + self.proj = torch.compile(self.proj, mode="reduce-overhead", fullgraph=False, dynamic=False) + else: + self.fc = CastedLinear(dim, hidden, bias=False) + self.proj = CastedLinear(hidden, dim, bias=False) + self.proj._zero_init = True + + def forward(self, x: Tensor) -> Tensor: + x = torch.relu(self.fc(x)) + return self.proj(x.square()) + + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + rope_base: float, + qk_gain_init: float, + structured_mlp: bool, + structured_linear_type: str, + btt_rank: int, + btt_in_shape: str, + btt_out_shape: str, + btt_init: str, + btt_init_gain: float, + compile_structured_mlp: bool, + ): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) + self.mlp = MLP( + dim, + mlp_mult, + structured_mlp, + structured_linear_type, + btt_rank, + btt_in_shape, + btt_out_shape, + btt_init, + btt_init_gain, + compile_structured_mlp, + ) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + + def forward(self, x: Tensor, x0: Tensor) -> Tensor: + mix = self.resid_mix.to(dtype=x.dtype) + x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + attn_out = self.attn(self.attn_norm(x)) + x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out + x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) + return x + + +class GPT(nn.Module): + def __init__( + self, + vocab_size: int, + num_layers: int, + model_dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + tie_embeddings: bool, + tied_embed_init_std: float, + logit_softcap: float, + rope_base: float, + qk_gain_init: float, + structured_mlp: bool, + 
structured_linear_type: str, + btt_rank: int, + btt_in_shape: str, + btt_out_shape: str, + btt_init: str, + btt_init_gain: float, + compile_structured_mlp: bool, + rate_target_tensors: str, + rate_temperature: float, + ): + super().__init__() + if logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") + self.tie_embeddings = tie_embeddings + self.tied_embed_init_std = tied_embed_init_std + self.logit_softcap = logit_softcap + self.rate_target_tensors = rate_target_tensors + self.rate_temperature = rate_temperature + self.tok_emb = nn.Embedding(vocab_size, model_dim) + self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + self.blocks = nn.ModuleList( + [ + Block( + model_dim, + num_heads, + num_kv_heads, + mlp_mult, + rope_base, + qk_gain_init, + structured_mlp, + structured_linear_type, + btt_rank, + btt_in_shape, + btt_out_shape, + btt_init, + btt_init_gain, + compile_structured_mlp, + ) + for i in range(num_layers) + ] + ) + self.final_norm = RMSNorm() + self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) + if self.lm_head is not None: + self.lm_head._zero_init = True + self._init_weights() + + def _init_weights(self) -> None: + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + for module in self.modules(): + if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + if isinstance(module, StructuredLinear) and getattr(module, "_zero_init", False): + if module.impl is not None: + nn.init.zeros_(module.impl.weight) + else: + nn.init.zeros_(module.core2) + + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + x = 
self.tok_emb(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x0 = x + skips: list[Tensor] = [] + + # First half stores skips; second half reuses them in reverse order. + for i in range(self.num_encoder_layers): + x = self.blocks[i](x, x0) + skips.append(x) + for i in range(self.num_decoder_layers): + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + x = self.blocks[self.num_encoder_layers + i](x, x0) + + x = self.final_norm(x).reshape(-1, x.size(-1)) + targets = target_ids.reshape(-1) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(x) + logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + return F.cross_entropy(logits.float(), targets, reduction="mean") + + def rate_regularizer(self) -> tuple[Tensor, Tensor, Tensor]: + total_symbols = 0.0 + weighted_rate = None + weighted_scale = None + estimated_bits = None + for name, param in self.named_parameters(): + if not should_rate_regularize_tensor(name, param, self.rate_target_tensors): + continue + rate_proxy, scale_proxy, tensor_bits = estimate_tensor_rate_proxy(param, self.rate_temperature) + weight = float(param.numel()) + total_symbols += weight + if weighted_rate is None: + weighted_rate = rate_proxy * weight + weighted_scale = scale_proxy * weight + estimated_bits = tensor_bits + else: + weighted_rate = weighted_rate + rate_proxy * weight + weighted_scale = weighted_scale + scale_proxy * weight + estimated_bits = estimated_bits + tensor_bits + if weighted_rate is None or weighted_scale is None or estimated_bits is None: + zero = self.tok_emb.weight.new_zeros(()) + return zero, zero, zero + denom = max(total_symbols, 1.0) + return weighted_rate / denom, weighted_scale / denom, estimated_bits + + +# ----------------------------- +# TRAINING +# ----------------------------- + +def main() -> None: 
+ global zeropower_via_newtonschulz5 + + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + if args.enable_compile: + zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) + + # ----------------------------- + # DISTRIBUTED + CUDA SETUP + # ----------------------------- + + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if 8 % world_size != 0: + raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") + grad_accum_steps = 8 // world_size + grad_scale = 1.0 / grad_accum_steps + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 0 + + # Fast math knobs + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + + logfile = None + if master_process: + if args.run_id == "train": + logfile = "train.log" + else: + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" + print(logfile) + + def log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch 
{torch.__version__}", console=False) + log0( + subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, + console=False, + ) + log0("=" * 100, console=False) + + # ----------------------------- + # TOKENIZER + VALIDATION METRIC SETUP + # ----------------------------- + + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) + if args.val_token_limit > 0: + usable = min(((args.val_token_limit) // args.train_seq_len) * args.train_seq_len, val_tokens.numel() - 1) + if usable <= 0: + raise ValueError(f"VAL_TOKEN_LIMIT={args.val_token_limit} is too small for TRAIN_SEQ_LEN={args.train_seq_len}") + val_tokens = val_tokens[: usable + 1].contiguous() + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + if args.val_token_limit > 0: + log0(f"val_token_limit:{args.val_token_limit}") + + # ----------------------------- + # MODEL + OPTIMIZER SETUP + # ----------------------------- + + base_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + 
model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, + structured_mlp=args.structured_mlp, + structured_linear_type=args.structured_linear_type if args.structured_mlp else "dense", + btt_rank=args.btt_rank, + btt_in_shape=args.btt_in_shape, + btt_out_shape=args.btt_out_shape, + btt_init=args.btt_init, + btt_init_gain=args.btt_init_gain, + compile_structured_mlp=args.compile_structured_mlp, + rate_target_tensors=args.rate_target_tensors, + rate_temperature=args.rate_temperature, + ).to(device).bfloat16() + for module in base_model.modules(): + if isinstance(module, (CastedLinear, StructuredLinear)): + module.float() + restore_low_dim_params_to_fp32(base_model) + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) if args.enable_compile else base_model + model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + + # Optimizer split: + # - token embedding (Adam) uses EMBED_LR + # - untied lm_head (Adam) uses HEAD_LR + # - matrix params in transformer blocks use MATRIX_LR via Muon + # - vectors/scalars use SCALAR_LR via Adam + block_named_params = list(base_model.blocks.named_parameters()) + matrix_params = [ + p + for name, p in block_named_params + if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + scalar_params = [ + p + for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + optimizer_tok = torch.optim.Adam( + [{"params": [base_model.tok_emb.weight], "lr": token_lr, 
"base_lr": token_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizer_muon = Muon( + matrix_params, + lr=args.matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + ) + for group in optimizer_muon.param_groups: + group["base_lr"] = args.matrix_lr + optimizer_scalar = torch.optim.Adam( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers.insert(1, optimizer_head) + + n_params = sum(p.numel() for p in base_model.parameters()) + log0(f"model_params:{n_params}") + log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") + log0(f"enable_compile:{args.enable_compile}") + log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") + log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") + log0( + f"quant_scheme:{args.quant_scheme} compressor:{args.compressor} level:{args.compressor_level} " + f"structured_mlp:{args.structured_mlp} structured_linear_type:{args.structured_linear_type} " + f"btt_rank:{args.btt_rank} btt_init:{args.btt_init} compile_structured_mlp:{args.compile_structured_mlp}" + ) + log0( + f"rate_lambda:{args.rate_lambda} scale_lambda:{args.scale_lambda} " + f"rate_warmup_start_fraction:{args.rate_warmup_start_fraction} rate_full_fraction:{args.rate_full_fraction} " + f"rate_target_tensors:{args.rate_target_tensors}" + ) + log0( + f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " + f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " + f"matrix_lr:{args.matrix_lr} 
scalar_lr:{args.scalar_lr}" + ) + log0( + f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " + f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " + f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log0(f"seed:{args.seed}") + + # ----------------------------- + # DATA LOADER & MODEL WARMUP + # ----------------------------- + + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + + def lr_mul(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + + def rate_schedule(step: int) -> float: + if args.rate_lambda <= 0.0 and args.scale_lambda <= 0.0: + return 0.0 + if args.iterations <= 0: + return 1.0 + frac = step / max(args.iterations, 1) + start = args.rate_warmup_start_fraction + full = max(args.rate_full_fraction, start + 1e-6) + if frac <= start: + return 0.0 + if frac >= full: + return 1.0 + return (frac - start) / (full - start) + + # Warmup primes the compiled forward/backward/optimizer paths, then we restore the + # initial weights/optimizer state so measured training starts from the true init. 
+ if args.warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(args.warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + for opt in optimizers: + opt.step() + zero_grad_all() + if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: + log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + if distributed: + model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + # ----------------------------- + # MAIN TRAINING LOOP + # ----------------------------- + + training_time_ms = 0.0 + stop_after_step: int | None = None + torch.cuda.synchronize() + t0 = time.perf_counter() + + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + log0( + 
f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" + ) + break + + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + zero_grad_all() + train_loss = torch.zeros((), device=device) + train_ce_loss = torch.zeros((), device=device) + train_rate_proxy = torch.zeros((), device=device) + train_scale_proxy = torch.zeros((), device=device) + train_est_bits = torch.zeros((), device=device) + rate_mul = rate_schedule(step) + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + ce_loss = model(x, y) + total_loss = ce_loss + if rate_mul > 0.0: + rate_proxy, scale_proxy, est_bits = base_model.rate_regularizer() + total_loss = total_loss + rate_mul * (args.rate_lambda * rate_proxy + args.scale_lambda * scale_proxy) + train_rate_proxy += rate_proxy.detach() + train_scale_proxy += scale_proxy.detach() + train_est_bits += est_bits.detach() + train_ce_loss += ce_loss.detach() + train_loss += total_loss.detach() + (total_loss * grad_scale).backward() + train_loss /= grad_accum_steps + train_ce_loss /= grad_accum_steps + train_rate_proxy /= grad_accum_steps + train_scale_proxy /= grad_accum_steps + train_est_bits /= grad_accum_steps + + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * 
args.muon_momentum + for group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + for opt in optimizers: + opt.step() + zero_grad_all() + + step += 1 + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + should_log_train = ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + log0( + f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} ce_loss:{train_ce_loss.item():.4f} " + f"rate_mul:{rate_mul:.3f} rate_proxy_bits:{train_rate_proxy.item():.4f} " + f"scale_proxy:{train_scale_proxy.item():.6f} est_model_bits:{train_est_bits.item():.0f} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" + ) + + # Needed to sync whether we've reached the wallclock cap. + reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + + log0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + ) + + # ----------------------------- + # SERIALIZATION + ROUNDTRIP VALIDATION + # ----------------------------- + # Save the raw state (useful for debugging/loading in PyTorch directly), then always produce + the compressed quantized artifact (int5_odd + zlib in this run) and validate the round-tripped weights. 
+ + if master_process: + torch.save(base_model.state_dict(), "final_model.pt") + model_bytes = os.path.getsize("final_model.pt") + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model: {model_bytes} bytes") + log0(f"Code size: {code_bytes} bytes") + log0(f"Total submission size: {model_bytes + code_bytes} bytes") + + quant_obj, quant_stats = quantize_state_dict(base_model.state_dict(), args.quant_scheme) + quant_buf = io.BytesIO() + torch.save(quant_obj, quant_buf) + quant_raw = quant_buf.getvalue() + quant_blob = compress_blob(quant_raw, args.compressor, args.compressor_level) + quant_raw_bytes = len(quant_raw) + quant_filename = f"final_model.{args.quant_scheme}.{args.compressor}.ptc" + if master_process: + with open(quant_filename, "wb") as f: + f.write(quant_blob) + quant_file_bytes = os.path.getsize(quant_filename) + code_bytes = len(code.encode("utf-8")) + ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1) + log0( + f"Serialized model {args.quant_scheme}+{args.compressor}: {quant_file_bytes} bytes " + f"(payload:{quant_stats['int8_payload_bytes']} raw_torch:{quant_raw_bytes} payload_ratio:{ratio:.2f}x)" + ) + log0(f"Total submission size {args.quant_scheme}+{args.compressor}: {quant_file_bytes + code_bytes} bytes") + + if distributed: + dist.barrier() + with open(quant_filename, "rb") as f: + quant_blob_disk = f.read() + quant_state = torch.load(io.BytesIO(decompress_blob(quant_blob_disk, args.compressor)), map_location="cpu") + base_model.load_state_dict(dequantize_state_dict(quant_state), strict=True) + torch.cuda.synchronize() + t_qeval = time.perf_counter() + q_val_loss, q_val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"final_{args.quant_scheme}_{args.compressor}_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + 
f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" + ) + log0(f"final_{args.quant_scheme}_{args.compressor}_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + + if distributed: + dist.destroy_process_group() + + +if __name__ == "__main__": + main() + +==================================================================================================== +Running Python 3.10.12 (main, Mar 3 2026, 11:56:32) [GCC 11.4.0] +Running PyTorch 2.11.0+cu130 +Tue Mar 31 12:09:34 2026 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 590.48.01 Driver Version: 590.48.01 CUDA Version: 13.1 | ++-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. | +|=========================================+========================+======================| +| 0 NVIDIA GeForce RTX 3060 Off | 00000000:01:00.0 On | N/A | +| 30% 39C P8 20W / 170W | 242MiB / 12288MiB | 6% Default | +| | | N/A | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| 0 N/A N/A 985 G /usr/lib/xorg/Xorg 117MiB | +| 0 N/A N/A 3793 C python3 104MiB | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=/home/hz/Desktop/parameter-golf-main/data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:1 +val_loader:shards 
pattern=/home/hz/Desktop/parameter-golf-main/data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:65536 +val_token_limit:65536 +model_params:8507464 +world_size:1 grad_accum_steps:8 +enable_compile:False +sdp_backends:cudnn=False flash=True mem_efficient=False math=False +attention_mode:gqa num_heads:8 num_kv_heads:4 +quant_scheme:int5_odd compressor:zlib level:9 structured_mlp:True structured_linear_type:btt_mlp btt_rank:32 btt_init:mup compile_structured_mlp:True +rate_lambda:2e-05 scale_lambda:0.0002 rate_warmup_start_fraction:0.1 rate_full_fraction:0.4 rate_target_tensors:mlp_only +tie_embeddings:True embed_lr:0.03 head_lr:0.0 matrix_lr:0.02 scalar_lr:0.02 +train_batch_tokens:4096 train_seq_len:256 iterations:4 warmup_steps:0 max_wallclock_seconds:600.000 +seed:1337 +step:0/4 val_loss:6.9362 val_bpb:4.1011 train_time:0ms step_avg:0.02ms diff --git a/records/track_non_record_16mb/2026-03-31_EntropyAware_Int5Odd_BTTMLP_3060/logs/compile_off_bench.txt b/records/track_non_record_16mb/2026-03-31_EntropyAware_Int5Odd_BTTMLP_3060/logs/compile_off_bench.txt new file mode 100644 index 000000000..8a41b1604 --- /dev/null +++ b/records/track_non_record_16mb/2026-03-31_EntropyAware_Int5Odd_BTTMLP_3060/logs/compile_off_bench.txt @@ -0,0 +1,1688 @@ +""" +The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good jumping-off points for new participants, not SOTA configs. We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder. + +Hard stop: to keep these scripts readable for newcomers, let's make sure `train_gpt.py` and `train_gpt_mlx.py` are never longer than 1500 lines. 
+""" + +from __future__ import annotations + +import copy +import glob +import io +import math +import os +import random +import subprocess +import sys +import time +import uuid +import zlib +from pathlib import Path + +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP + +try: + import zstandard as zstd + + _HAS_ZSTD = True +except ImportError: + zstd = None + _HAS_ZSTD = False + +REPO_ROOT = Path(__file__).resolve().parents[3] + +# ----------------------------- +# HYPERPARAMETERS +# ----------------------------- +# Default Simple Baseline run: +# - 9 transformer blocks at width 512 +# - 8 attention heads with 4 KV heads (GQA) and 2x MLP expansion +# - vocab size 1024, sequence length 1024, tied embeddings +# - 524,288 train tokens per step for 20,000 iterations with a ~10 minute cap + +class Hyperparameters: + # Data paths are shard globs produced by the existing preprocessing pipeline. + data_path = os.environ.get("DATA_PATH", str(REPO_ROOT / "data/datasets/fineweb10B_sp1024")) + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", str(REPO_ROOT / "data/tokenizers/fineweb_1024_bpe.model")) + run_id = os.environ.get("RUN_ID", "train") + seed = int(os.environ.get("SEED", 1337)) + + # Validation cadence and batch size. Validation always uses the full fineweb_val split. + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 262_144)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 10)) + val_token_limit = int(os.environ.get("VAL_TOKEN_LIMIT", 1_048_576)) + + # Training length. 
+ iterations = int(os.environ.get("ITERATIONS", 40)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 1200)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 0)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 16_384)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 256)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + enable_compile = bool(int(os.environ.get("ENABLE_COMPILE", "0"))) + + # Model shape. + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers = int(os.environ.get("NUM_LAYERS", 9)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = int(os.environ.get("MLP_MULT", 2)) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + + # Optimizer hyperparameters. + embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.03)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.02)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.02)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.0)) + + # Export / compression. 
+ quant_scheme = os.environ.get("QUANT_SCHEME", "int5_odd") + compressor = os.environ.get("COMPRESSOR", "zlib") + compressor_level = int(os.environ.get("COMPRESSOR_LEVEL", 9)) + + # Entropy-aware rate regularization. + rate_lambda = float(os.environ.get("RATE_LAMBDA", 0.002)) + scale_lambda = float(os.environ.get("SCALE_LAMBDA", 0.0002)) + rate_temperature = float(os.environ.get("RATE_TEMPERATURE", 6.0)) + rate_warmup_start_fraction = float(os.environ.get("RATE_WARMUP_START_FRACTION", 0.1)) + rate_full_fraction = float(os.environ.get("RATE_FULL_FRACTION", 0.4)) + rate_target_tensors = os.environ.get("RATE_TARGET_TENSORS", "mlp_only") + + # Structured MLP. + structured_mlp = bool(int(os.environ.get("STRUCTURED_MLP", "1"))) + structured_linear_type = os.environ.get("STRUCTURED_LINEAR_TYPE", "btt_mlp") + btt_rank = int(os.environ.get("BTT_RANK", 32)) + btt_in_shape = os.environ.get("BTT_IN_SHAPE", "") + btt_out_shape = os.environ.get("BTT_OUT_SHAPE", "") + btt_init = os.environ.get("BTT_INIT", "mup") + btt_init_gain = float(os.environ.get("BTT_INIT_GAIN", 1.0)) + compile_structured_mlp = bool(int(os.environ.get("COMPILE_STRUCTURED_MLP", "1"))) + +# ----------------------------- +# MUON OPTIMIZER +# ----------------------------- +# +# As borrowed from modded-nanogpt +# Background on Muon: https://kellerjordan.github.io/posts/muon/ + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + # Orthogonalize a 2D update matrix with a fast Newton-Schulz iteration. + # Muon uses this to normalize matrix-shaped gradients before applying them. 
+ a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X + + +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True): + super().__init__( + params, + dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov), + ) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + + for group in self.param_groups: + params = group["params"] + if not params: + continue + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + + total_params = sum(int(p.numel()) for p in params) + updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) + + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: + g = g.add(buf, alpha=momentum) + g = zeropower_via_newtonschulz5(g, steps=backend_steps) + # Scale correction from Muon reference implementations. 
+ g *= max(1, g.size(0) / g.size(1)) ** 0.5
+ updates_flat[curr : curr + p.numel()] = g.reshape(-1)
+ curr += p.numel()
+
+ if distributed:
+ dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM)
+
+ curr = 0
+ for p in params:
+ g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype)
+ p.add_(g, alpha=-lr)
+ curr += p.numel()
+
+ return loss
+
+
+# -----------------------------
+# TOKENIZER-AGNOSTIC EVALUATION SETUP
+# -----------------------------
+#
+# It's common for small models to have a large fraction of their parameters in embeddings, since the 2 * d_model * d_vocab embedding entries can be gigantic.
+# Instead of locking the tokenizer, we let you bring your own and compute our validation metric from the average compression of the validation set.
+# We calculate BPB (bits-per-byte) instead of raw validation loss, so we need methods to count the number of bytes each token detokenizes to.
+# Note: Submissions that edit the tokenizer will be examined more carefully, since screwing this up might unjustly improve your score.
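As a sanity check on the metric described above, here is a minimal, self-contained sketch of how token-level cross-entropy and the per-token byte counts combine into bits-per-byte. The helper name `bits_per_byte` is illustrative only, not part of the script's API:

```python
import math

def bits_per_byte(val_loss_nats: float, total_tokens: int, total_bytes: int) -> float:
    # Hypothetical helper mirroring the metric above: convert the mean token
    # cross-entropy (natural log) to bits per token, then renormalize by how
    # many bytes the validation tokens detokenize to.
    bits_per_token = val_loss_nats / math.log(2.0)
    return bits_per_token * (total_tokens / total_bytes)

# A loss of ln(2) nats/token (1 bit/token) on tokens averaging 4 bytes each:
print(bits_per_byte(math.log(2.0), 1000, 4000))  # -> 0.25
```

This is why a model with a larger tokenizer can win on BPB despite a higher per-token loss: the denominator is bytes of text, not tokens.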
+ +def build_sentencepiece_luts( + sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device +) -> tuple[Tensor, Tensor, Tensor]: + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("▁"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) + + +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*. 
+ tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] + + +def eval_val( + args: Hyperparameters, + model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + grad_accum_steps: int, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, +) -> tuple[float, float]: + # Validation computes two metrics: + # - val_loss: token cross-entropy (natural log) + # - val_bpb: tokenizer-agnostic compression metric used by the challenge + local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) + if local_batch_tokens < args.train_seq_len: + raise ValueError( + "VAL_BATCH_SIZE must provide at least one sequence per rank; " + f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " + f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}" + ) + local_batch_seqs = local_batch_tokens // args.train_seq_len + total_seqs = (val_tokens.numel() - 1) // args.train_seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + + model.eval() + with torch.inference_mode(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * args.train_seq_len + raw_end = batch_seq_end * args.train_seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, args.train_seq_len) + y = local[1:].reshape(-1, args.train_seq_len) + 
maybe_mark_compile_step_begin(args.enable_compile or args.compile_structured_mlp)
+ with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True):
+ batch_loss = model(x, y).detach()
+ batch_token_count = float(y.numel())
+ val_loss_sum += batch_loss.to(torch.float64) * batch_token_count
+ val_token_count += batch_token_count
+ prev_ids = x.reshape(-1)
+ tgt_ids = y.reshape(-1)
+ token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16)
+ token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16)
+ val_byte_count += token_bytes.to(torch.float64).sum()
+
+ if dist.is_available() and dist.is_initialized():
+ dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM)
+ dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM)
+ dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM)
+
+ val_loss = val_loss_sum / val_token_count
+ bits_per_token = val_loss.item() / math.log(2.0)
+ tokens_per_byte = val_token_count.item() / val_byte_count.item()
+ model.train()
+ return float(val_loss.item()), float(bits_per_token * tokens_per_byte)
+
+# -----------------------------
+# POST-TRAINING QUANTIZATION
+# -----------------------------
+#
+# It's silly to export our model, which is trained in bf16 and fp32, at that same precision.
+# Instead, we get approximately the same model (with a small accuracy hit) by quantizing it to int8 and zlib-compressing the result.
+# We can then decompress the model and run it at higher precision for evaluation, once it comes in under the size limit.
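Before the actual export code, here is a toy round trip showing the core idea behind the `int5_odd + zlib` artifact: values quantized to the 5-level grid `{-2,-1,0,1,2}` are packed three base-5 digits per byte (since 5**3 = 125 <= 256) and then zlib-compressed. The helper names `pack_int5`/`unpack_int5` are illustrative, not the script's actual functions:

```python
import zlib
import numpy as np

def pack_int5(vals: np.ndarray) -> bytes:
    # Shift {-2..2} to {0..4}, pad to a multiple of 3, and fold each group of
    # three base-5 digits into one byte.
    codes = (vals + 2).astype(np.uint8)
    pad = (-codes.size) % 3
    codes = np.pad(codes, (0, pad))
    g = codes.reshape(-1, 3)
    return (g[:, 0] + 5 * g[:, 1] + 25 * g[:, 2]).astype(np.uint8).tobytes()

def unpack_int5(data: bytes, n: int) -> np.ndarray:
    # Recover the three base-5 digits from each byte and undo the +2 shift.
    packed = np.frombuffer(data, dtype=np.uint8).astype(np.int16)
    digits = np.stack([(packed // 5**i) % 5 for i in range(3)], axis=1)
    return digits.reshape(-1)[:n].astype(np.int8) - 2

vals = np.array([-2, -1, 0, 1, 2, 2, -2], dtype=np.int8)
blob = zlib.compress(pack_int5(vals), 9)          # pack, then entropy-code
restored = unpack_int5(zlib.decompress(blob), vals.size)
assert np.array_equal(restored, vals)             # exact roundtrip
```

Packing alone gets each weight to ~2.32 bits; the entropy-aware rate penalty then skews the code histogram so zlib can compress the packed stream well below that.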
+ +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights", + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 +INT5_ODD_LEVELS = (-2, -1, 0, 1, 2) + +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) + + +def maybe_mark_compile_step_begin(enabled: bool) -> None: + if not enabled: + return + mark_step = getattr(torch.compiler, "cudagraph_mark_step_begin", None) + if mark_step is not None: + mark_step() + + +def parse_factor_shape(spec: str, dim: int) -> tuple[int, int]: + if spec: + vals = tuple(int(v.strip()) for v in spec.split(",") if v.strip()) + if len(vals) != 2 or vals[0] * vals[1] != dim: + raise ValueError(f"Invalid factor shape '{spec}' for dim={dim}") + return vals + for a in range(int(math.isqrt(dim)), 0, -1): + if dim % a == 0: + return a, dim // a + return 1, dim + +def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: + if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): + return t.float().contiguous() + if t.dtype in {torch.float32, torch.bfloat16}: + passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") + return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() + return t + +def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + # Matrices get one scale per row, which usually tracks output-channel + # ranges much better than a single tensor-wide scale. 
+ clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) + q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() + return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + + # Vectors / scalars use a simpler per-tensor scale. + clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() + return q, scale + + +def pack_base5_codes(q: Tensor) -> tuple[bytes, int]: + vals = q.detach().to("cpu", dtype=torch.uint8).reshape(-1).numpy() + n_vals = int(vals.size) + pad = (3 - (n_vals % 3)) % 3 + if pad: + vals = np.pad(vals, (0, pad)) + groups = vals.reshape(-1, 3).astype(np.uint8) + packed = groups[:, 0] + 5 * groups[:, 1] + 25 * groups[:, 2] + return packed.tobytes(), n_vals + + +def unpack_base5_codes(data: bytes, n_vals: int) -> Tensor: + packed = np.frombuffer(data, dtype=np.uint8).astype(np.int16) + out = np.zeros((packed.size, 3), dtype=np.uint8) + for i in range(3): + out[:, i] = packed % 5 + packed //= 5 + flat = out.reshape(-1)[:n_vals] + return torch.from_numpy(flat.astype(np.uint8, copy=False)) + + +def quantize_float_tensor_int5_odd(t: Tensor) -> tuple[Tensor, Tensor]: + if t.ndim != 2: + raise ValueError("int5_odd quantization only supports 2D tensors") + t32 = t.float() + clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clip_abs = clip_abs.clamp_min(1e-6) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 
2.0).clamp_min(1.0 / 2.0e6) + q = torch.clamp(torch.round(clipped / scale[:, None]), -2, 2).to(torch.int8).contiguous() + return (q + 2).to(torch.uint8).contiguous(), scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + +def quantize_state_dict_int8(state_dict: dict[str, Tensor]): + # Single supported clean-script export format: + # - per-row int8 for 2D float tensors + # - per-tensor int8 for other float tensors + # - exact passthrough for non-floats + # - passthrough for small float tensors, stored as fp16 to save bytes + quantized: dict[str, Tensor] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + qmeta: dict[str, dict[str, object]] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + + # Small float tensors are cheap enough to keep directly. We still downcast + # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size. 
+ if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + + stats["num_float_tensors"] += 1 + q, s = quantize_float_tensor(t) + if s.ndim > 0: + qmeta[name] = {"scheme": "per_row", "axis": 0} + quantized[name] = q + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) + + obj: dict[str, object] = { + "__quant_format__": "int8_clean_per_row_v1", + "quantized": quantized, + "scales": scales, + "dtypes": dtypes, + "passthrough": passthrough, + } + if qmeta: + obj["qmeta"] = qmeta + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats + + +def quantize_state_dict_int5_odd(state_dict: dict[str, Tensor]): + packed: dict[str, bytes] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + shapes: dict[str, tuple[int, ...]] = {} + orig_shapes: dict[str, tuple[int, ...]] = {} + n_codes: dict[str, int] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + + quantize_this = t.ndim >= 2 and (t.numel() > INT8_KEEP_FLOAT_MAX_NUMEL or ".mlp." 
in name) + if not quantize_this: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + + stats["num_float_tensors"] += 1 + t2 = t.reshape(t.shape[0], -1) if t.ndim > 2 else t + q, s = quantize_float_tensor_int5_odd(t2) + packed_bytes, count = pack_base5_codes(q) + packed[name] = packed_bytes + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + shapes[name] = tuple(int(v) for v in t2.shape) + orig_shapes[name] = tuple(int(v) for v in t.shape) + n_codes[name] = count + stats["int8_payload_bytes"] += len(packed_bytes) + tensor_nbytes(s) + + obj: dict[str, object] = { + "__quant_format__": "int5_odd_clean_per_row_v1", + "packed": packed, + "scales": scales, + "dtypes": dtypes, + "shapes": shapes, + "orig_shapes": orig_shapes, + "n_codes": n_codes, + "passthrough": passthrough, + } + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats + +def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + qmeta = obj.get("qmeta", {}) + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, q in obj["quantized"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + s = obj["scales"][name] + if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: + s = s.to(dtype=torch.float32) + # Broadcast the saved row scale back across trailing dimensions. + out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() + else: + scale = float(s.item()) + out[name] = (q.float() * scale).to(dtype=dtype).contiguous() + for name, t in obj["passthrough"].items(): + # Restore small tensors, undoing the temporary fp16 storage cast if needed. 
+ out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out + + +def dequantize_state_dict_int5_odd(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, packed in obj["packed"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + shape = tuple(int(v) for v in obj["shapes"][name]) + orig_shape = tuple(int(v) for v in obj.get("orig_shapes", {}).get(name, shape)) + q = unpack_base5_codes(packed, int(obj["n_codes"][name])).to(torch.float32).reshape(shape) - 2.0 + s = obj["scales"][name].to(dtype=torch.float32) + deq = (q * s.view(shape[0], *([1] * (len(shape) - 1)))).to(dtype=dtype) + out[name] = deq.reshape(orig_shape).contiguous() + for name, t in obj["passthrough"].items(): + out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out + + +def quantize_state_dict(state_dict: dict[str, Tensor], quant_scheme: str): + if quant_scheme == "int8": + return quantize_state_dict_int8(state_dict) + if quant_scheme == "int5_odd": + return quantize_state_dict_int5_odd(state_dict) + raise ValueError(f"Unsupported QUANT_SCHEME={quant_scheme}") + + +def dequantize_state_dict(obj: dict[str, object]) -> dict[str, Tensor]: + quant_format = obj.get("__quant_format__") + if quant_format == "int8_clean_per_row_v1": + return dequantize_state_dict_int8(obj) + if quant_format == "int5_odd_clean_per_row_v1": + return dequantize_state_dict_int5_odd(obj) + raise ValueError(f"Unsupported quant format {quant_format}") + + +def compress_blob(data: bytes, compressor: str, level: int) -> bytes: + if compressor == "zlib": + return zlib.compress(data, level=max(0, 
min(level, 9))) + if compressor == "zstd": + if not _HAS_ZSTD: + raise RuntimeError("COMPRESSOR=zstd requires zstandard to be installed") + return zstd.ZstdCompressor(level=level).compress(data) + raise ValueError(f"Unsupported COMPRESSOR={compressor}") + + +def decompress_blob(data: bytes, compressor: str) -> bytes: + if compressor == "zlib": + return zlib.decompress(data) + if compressor == "zstd": + if not _HAS_ZSTD: + raise RuntimeError("COMPRESSOR=zstd requires zstandard to be installed") + return zstd.ZstdDecompressor().decompress(data) + raise ValueError(f"Unsupported COMPRESSOR={compressor}") + + +def should_rate_regularize_tensor(name: str, p: Tensor, target: str) -> bool: + if not p.requires_grad or not p.is_floating_point() or p.ndim < 2: + return False + if any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS): + return False + if target == "mlp_only": + return ".mlp." in name + if target == "all_matrices": + return p.numel() > INT8_KEEP_FLOAT_MAX_NUMEL or ".mlp." in name + raise ValueError(f"Unsupported RATE_TARGET_TENSORS={target}") + + +def estimate_tensor_rate_proxy(t: Tensor, temperature: float) -> tuple[Tensor, Tensor, Tensor]: + t2 = t.float() + if t2.ndim > 2: + t2 = t2.reshape(t2.shape[0], -1) + elif t2.ndim == 1: + t2 = t2.reshape(1, -1) + row_scale = t2.abs().mean(dim=1, keepdim=True).clamp_min(1e-6) + normalized = torch.clamp(t2 / row_scale, -2.5, 2.5) + centers = t2.new_tensor(INT5_ODD_LEVELS) + logits = -temperature * (normalized.unsqueeze(-1) - centers) ** 2 + probs = torch.softmax(logits, dim=-1) + hist = probs.mean(dim=(0, 1)) + entropy_nats = -(hist * (hist + 1e-8).log()).sum() + entropy_bits = entropy_nats / math.log(2.0) + scale_proxy = row_scale.mean() + est_bits = entropy_bits * float(t2.numel()) + return entropy_bits, scale_proxy, est_bits + + +# ----------------------------- +# DATA LOADING +# ----------------------------- + +def load_data_shard(file: Path) -> Tensor: + header_bytes = 256 * np.dtype(" None: + 
self.file_idx = (self.file_idx + 1) % len(self.files) + self.tokens = load_data_shard(self.files[self.file_idx]) + self.pos = 0 + + def take(self, n: int) -> Tensor: + chunks: list[Tensor] = [] + remaining = n + while remaining > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: + self._advance_file() + continue + k = min(remaining, avail) + chunks.append(self.tokens[self.pos : self.pos + k]) + self.pos += k + remaining -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) + + +class DistributedTokenLoader: + # Each call consumes a contiguous chunk from the shared token stream, then slices out + # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting. + def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank = rank + self.world_size = world_size + self.device = device + self.stream = TokenStream(pattern) + + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + chunk = self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start : start + per_rank_span].to(dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + +# ----------------------------- +# TRANSFORMER MODULES +# ----------------------------- + +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + + +class CastedLinear(nn.Linear): + # Keep weights in fp32 for optimizer/state quality, cast at matmul time for bf16 compute. 
+ def forward(self, x: Tensor) -> Tensor: + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, self.weight.to(x.dtype), bias) + + +class StructuredLinear(nn.Module): + # Lightweight TT-style structured matrix used only in the MLP stack. + def __init__( + self, + in_features: int, + out_features: int, + bias: bool, + linear_type: str, + btt_rank: int, + in_shape_spec: str, + out_shape_spec: str, + btt_init: str, + btt_init_gain: float, + ): + super().__init__() + self.in_features = in_features + self.out_features = out_features + self.linear_type = linear_type + self.rank = max(int(btt_rank), 1) + self.btt_init = btt_init + self.btt_init_gain = btt_init_gain + if linear_type == "dense": + self.impl = CastedLinear(in_features, out_features, bias=bias) + self.register_parameter("core1", None) + self.register_parameter("core2", None) + return + if linear_type != "btt_mlp": + raise ValueError(f"Unsupported STRUCTURED_LINEAR_TYPE={linear_type}") + self.impl = None + self.in_shape = parse_factor_shape(in_shape_spec, in_features) + self.out_shape = parse_factor_shape(out_shape_spec, out_features) + n1, n2 = self.in_shape + m1, m2 = self.out_shape + self.core1 = nn.Parameter(torch.empty(n1, m1, self.rank)) + self.core2 = nn.Parameter(torch.empty(self.rank, n2, m2)) + if bias: + self.bias = nn.Parameter(torch.zeros(out_features)) + else: + self.register_parameter("bias", None) + self.reset_parameters() + + def reset_parameters(self) -> None: + if self.impl is not None: + self.impl.reset_parameters() + return + n1, n2 = self.in_shape + if self.btt_init == "mup": + # muP-inspired scaling: keep the per-rank branch update stable while + # preserving O(1) output variance under the sum over TT ranks. 
+ core1_std = self.btt_init_gain / math.sqrt(max(n1, 1)) + core2_std = self.btt_init_gain / math.sqrt(max(self.rank * n2, 1)) + elif self.btt_init == "xavier": + std = self.btt_init_gain / math.sqrt(max(self.in_features, 1)) + core1_std = std / math.sqrt(self.rank) + core2_std = std + else: + raise ValueError(f"Unsupported BTT_INIT={self.btt_init}") + nn.init.normal_(self.core1, mean=0.0, std=core1_std) + nn.init.normal_(self.core2, mean=0.0, std=core2_std) + if self.bias is not None: + nn.init.zeros_(self.bias) + + def forward(self, x: Tensor) -> Tensor: + if self.impl is not None: + return self.impl(x) + x_dtype = x.dtype + orig_shape = x.shape[:-1] + x2 = x.reshape(-1, self.in_shape[0], self.in_shape[1]) + core1 = self.core1.to(dtype=x_dtype) + core2 = self.core2.to(dtype=x_dtype) + y = x2.new_zeros((x2.shape[0], self.out_shape[0], self.out_shape[1])) + for r in range(self.rank): + left = torch.einsum("buv,um->bmv", x2, core1[:, :, r]) + y = y + torch.einsum("bmv,vn->bmn", left, core2[r]) + y = y.reshape(*orig_shape, self.out_features) + if self.bias is not None: + y = y + self.bias.to(dtype=y.dtype) + return y + + +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + # Keep small/control parameters in fp32 even when the model body runs in bf16. + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() + + +class Rotary(nn.Module): + # Caches cos/sin tables per sequence length on the current device. 
+ def __init__(self, dim: int, base: float = 10000.0): + super().__init__() + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached != seq_len + or self._cos_cached.device != device + ): + t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) + freqs = torch.outer(t, self.inv_freq.to(device)) + self._cos_cached = freqs.cos()[None, None, :, :] + self._sin_cached = freqs.sin()[None, None, :, :] + self._seq_len_cached = seq_len + cos = self._cos_cached.to(dtype=dtype) + sin = self._sin_cached.to(dtype=dtype) + if not torch.is_inference_mode_enabled(): + cos = cos.clone() + sin = sin.clone() + return cos, sin + + +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + + +class CausalSelfAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + kv_dim = self.num_kv_heads * self.head_dim + self.c_q = CastedLinear(dim, dim, bias=False) + self.c_k = CastedLinear(dim, kv_dim, bias=False) + self.c_v = CastedLinear(dim, kv_dim, bias=False) + self.proj 
= CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rotary = Rotary(self.head_dim, base=rope_base) + + def forward(self, x: Tensor) -> Tensor: + bsz, seqlen, dim = x.shape + q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) + k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin) + k = apply_rotary_emb(k, cos, sin) + q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] + y = F.scaled_dot_product_attention( + q, + k, + v, + attn_mask=None, + is_causal=True, + enable_gqa=(self.num_kv_heads != self.num_heads), + ) + y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) + return self.proj(y) + + +class MLP(nn.Module): + # relu^2 MLP from the original modded-nanogpt setup + def __init__( + self, + dim: int, + mlp_mult: int, + structured_mlp: bool, + structured_linear_type: str, + btt_rank: int, + btt_in_shape: str, + btt_out_shape: str, + btt_init: str, + btt_init_gain: float, + compile_structured_mlp: bool, + ): + super().__init__() + hidden = mlp_mult * dim + if structured_mlp: + self.fc = StructuredLinear( + dim, + hidden, + bias=False, + linear_type=structured_linear_type, + btt_rank=btt_rank, + in_shape_spec=btt_in_shape, + out_shape_spec=btt_out_shape, + btt_init=btt_init, + btt_init_gain=btt_init_gain, + ) + self.proj = StructuredLinear( + hidden, + dim, + bias=False, + linear_type=structured_linear_type, + btt_rank=btt_rank, + in_shape_spec=btt_out_shape, + out_shape_spec=btt_in_shape, + btt_init=btt_init, + btt_init_gain=btt_init_gain, + ) + if compile_structured_mlp: + self.fc = torch.compile(self.fc, mode="reduce-overhead", 
fullgraph=False, dynamic=False) + self.proj = torch.compile(self.proj, mode="reduce-overhead", fullgraph=False, dynamic=False) + else: + self.fc = CastedLinear(dim, hidden, bias=False) + self.proj = CastedLinear(hidden, dim, bias=False) + self.proj._zero_init = True + + def forward(self, x: Tensor) -> Tensor: + x = torch.relu(self.fc(x)) + return self.proj(x.square()) + + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + rope_base: float, + qk_gain_init: float, + structured_mlp: bool, + structured_linear_type: str, + btt_rank: int, + btt_in_shape: str, + btt_out_shape: str, + btt_init: str, + btt_init_gain: float, + compile_structured_mlp: bool, + ): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) + self.mlp = MLP( + dim, + mlp_mult, + structured_mlp, + structured_linear_type, + btt_rank, + btt_in_shape, + btt_out_shape, + btt_init, + btt_init_gain, + compile_structured_mlp, + ) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + + def forward(self, x: Tensor, x0: Tensor) -> Tensor: + mix = self.resid_mix.to(dtype=x.dtype) + x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + attn_out = self.attn(self.attn_norm(x)) + x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out + x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) + return x + + +class GPT(nn.Module): + def __init__( + self, + vocab_size: int, + num_layers: int, + model_dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + tie_embeddings: bool, + tied_embed_init_std: float, + logit_softcap: float, + rope_base: float, + qk_gain_init: float, + structured_mlp: bool, + 
structured_linear_type: str, + btt_rank: int, + btt_in_shape: str, + btt_out_shape: str, + btt_init: str, + btt_init_gain: float, + compile_structured_mlp: bool, + rate_target_tensors: str, + rate_temperature: float, + ): + super().__init__() + if logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") + self.tie_embeddings = tie_embeddings + self.tied_embed_init_std = tied_embed_init_std + self.logit_softcap = logit_softcap + self.rate_target_tensors = rate_target_tensors + self.rate_temperature = rate_temperature + self.tok_emb = nn.Embedding(vocab_size, model_dim) + self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + self.blocks = nn.ModuleList( + [ + Block( + model_dim, + num_heads, + num_kv_heads, + mlp_mult, + rope_base, + qk_gain_init, + structured_mlp, + structured_linear_type, + btt_rank, + btt_in_shape, + btt_out_shape, + btt_init, + btt_init_gain, + compile_structured_mlp, + ) + for i in range(num_layers) + ] + ) + self.final_norm = RMSNorm() + self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) + if self.lm_head is not None: + self.lm_head._zero_init = True + self._init_weights() + + def _init_weights(self) -> None: + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + for module in self.modules(): + if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + if isinstance(module, StructuredLinear) and getattr(module, "_zero_init", False): + if module.impl is not None: + nn.init.zeros_(module.impl.weight) + else: + nn.init.zeros_(module.core2) + + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + x = 
self.tok_emb(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x0 = x + skips: list[Tensor] = [] + + # First half stores skips; second half reuses them in reverse order. + for i in range(self.num_encoder_layers): + x = self.blocks[i](x, x0) + skips.append(x) + for i in range(self.num_decoder_layers): + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + x = self.blocks[self.num_encoder_layers + i](x, x0) + + x = self.final_norm(x).reshape(-1, x.size(-1)) + targets = target_ids.reshape(-1) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(x) + logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + return F.cross_entropy(logits.float(), targets, reduction="mean") + + def rate_regularizer(self) -> tuple[Tensor, Tensor, Tensor]: + total_symbols = 0.0 + weighted_rate = None + weighted_scale = None + estimated_bits = None + for name, param in self.named_parameters(): + if not should_rate_regularize_tensor(name, param, self.rate_target_tensors): + continue + rate_proxy, scale_proxy, tensor_bits = estimate_tensor_rate_proxy(param, self.rate_temperature) + weight = float(param.numel()) + total_symbols += weight + if weighted_rate is None: + weighted_rate = rate_proxy * weight + weighted_scale = scale_proxy * weight + estimated_bits = tensor_bits + else: + weighted_rate = weighted_rate + rate_proxy * weight + weighted_scale = weighted_scale + scale_proxy * weight + estimated_bits = estimated_bits + tensor_bits + if weighted_rate is None or weighted_scale is None or estimated_bits is None: + zero = self.tok_emb.weight.new_zeros(()) + return zero, zero, zero + denom = max(total_symbols, 1.0) + return weighted_rate / denom, weighted_scale / denom, estimated_bits + + +# ----------------------------- +# TRAINING +# ----------------------------- + +def main() -> None: 
+ global zeropower_via_newtonschulz5 + + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + if args.enable_compile: + zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) + + # ----------------------------- + # DISTRIBUTED + CUDA SETUP + # ----------------------------- + + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if 8 % world_size != 0: + raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") + grad_accum_steps = 8 // world_size + grad_scale = 1.0 / grad_accum_steps + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 0 + + # Fast math knobs + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + + logfile = None + if master_process: + if args.run_id == "train": + logfile = "train.log" + else: + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" + print(logfile) + + def log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch 
{torch.__version__}", console=False) + log0( + subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, + console=False, + ) + log0("=" * 100, console=False) + + # ----------------------------- + # TOKENIZER + VALIDATION METRIC SETUP + # ----------------------------- + + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) + if args.val_token_limit > 0: + usable = min(((args.val_token_limit) // args.train_seq_len) * args.train_seq_len, val_tokens.numel() - 1) + if usable <= 0: + raise ValueError(f"VAL_TOKEN_LIMIT={args.val_token_limit} is too small for TRAIN_SEQ_LEN={args.train_seq_len}") + val_tokens = val_tokens[: usable + 1].contiguous() + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + if args.val_token_limit > 0: + log0(f"val_token_limit:{args.val_token_limit}") + + # ----------------------------- + # MODEL + OPTIMIZER SETUP + # ----------------------------- + + base_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + 
model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, + structured_mlp=args.structured_mlp, + structured_linear_type=args.structured_linear_type if args.structured_mlp else "dense", + btt_rank=args.btt_rank, + btt_in_shape=args.btt_in_shape, + btt_out_shape=args.btt_out_shape, + btt_init=args.btt_init, + btt_init_gain=args.btt_init_gain, + compile_structured_mlp=args.compile_structured_mlp, + rate_target_tensors=args.rate_target_tensors, + rate_temperature=args.rate_temperature, + ).to(device).bfloat16() + for module in base_model.modules(): + if isinstance(module, (CastedLinear, StructuredLinear)): + module.float() + restore_low_dim_params_to_fp32(base_model) + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) if args.enable_compile else base_model + model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + + # Optimizer split: + # - token embedding (Adam) uses EMBED_LR + # - untied lm_head (Adam) uses HEAD_LR + # - matrix params in transformer blocks use MATRIX_LR via Muon + # - vectors/scalars use SCALAR_LR via Adam + block_named_params = list(base_model.blocks.named_parameters()) + matrix_params = [ + p + for name, p in block_named_params + if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + scalar_params = [ + p + for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + optimizer_tok = torch.optim.Adam( + [{"params": [base_model.tok_emb.weight], "lr": token_lr, 
"base_lr": token_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizer_muon = Muon( + matrix_params, + lr=args.matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + ) + for group in optimizer_muon.param_groups: + group["base_lr"] = args.matrix_lr + optimizer_scalar = torch.optim.Adam( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers.insert(1, optimizer_head) + + n_params = sum(p.numel() for p in base_model.parameters()) + log0(f"model_params:{n_params}") + log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") + log0(f"enable_compile:{args.enable_compile}") + log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") + log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") + log0( + f"quant_scheme:{args.quant_scheme} compressor:{args.compressor} level:{args.compressor_level} " + f"structured_mlp:{args.structured_mlp} structured_linear_type:{args.structured_linear_type} " + f"btt_rank:{args.btt_rank} btt_init:{args.btt_init} compile_structured_mlp:{args.compile_structured_mlp}" + ) + log0( + f"rate_lambda:{args.rate_lambda} scale_lambda:{args.scale_lambda} " + f"rate_warmup_start_fraction:{args.rate_warmup_start_fraction} rate_full_fraction:{args.rate_full_fraction} " + f"rate_target_tensors:{args.rate_target_tensors}" + ) + log0( + f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " + f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " + f"matrix_lr:{args.matrix_lr} 
scalar_lr:{args.scalar_lr}" + ) + log0( + f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " + f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " + f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log0(f"seed:{args.seed}") + + # ----------------------------- + # DATA LOADER & MODEL WARMUP + # ----------------------------- + + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + + def lr_mul(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + + def rate_schedule(step: int) -> float: + if args.rate_lambda <= 0.0 and args.scale_lambda <= 0.0: + return 0.0 + if args.iterations <= 0: + return 1.0 + frac = step / max(args.iterations, 1) + start = args.rate_warmup_start_fraction + full = max(args.rate_full_fraction, start + 1e-6) + if frac <= start: + return 0.0 + if frac >= full: + return 1.0 + return (frac - start) / (full - start) + + # Warmup primes the compiled forward/backward/optimizer paths, then we restore the + # initial weights/optimizer state so measured training starts from the true init. 
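The `lr_mul` and `rate_schedule` helpers above are pure functions of step/time, so they can be sanity-checked in isolation. Below is a minimal standalone sketch, not the training code itself — the helper names are hypothetical, and the constants mirror this script's wallclock-capped warmdown and the `0.1 -> 0.4` rate-ramp defaults:

```python
def warmdown_mul(elapsed_ms: float, max_wallclock_ms: float, warmdown_ms: float) -> float:
    # Full LR until the final `warmdown_ms` of the wallclock budget,
    # then linear decay toward 0 as the budget runs out.
    remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0)
    return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0

def rate_mul(frac: float, start: float = 0.1, full: float = 0.4) -> float:
    # Regularizer weight: 0 before `start`, 1 after `full`, linear in between.
    full = max(full, start + 1e-6)
    if frac <= start:
        return 0.0
    if frac >= full:
        return 1.0
    return (frac - start) / (full - start)

assert warmdown_mul(0.0, 600_000.0, 120_000.0) == 1.0        # early: full LR
assert warmdown_mul(540_000.0, 600_000.0, 120_000.0) == 0.5  # halfway into warmdown
assert warmdown_mul(700_000.0, 600_000.0, 120_000.0) == 0.0  # past the cap
assert rate_mul(0.05) == 0.0                                 # regularizer still off
assert abs(rate_mul(0.25) - 0.5) < 1e-9                      # midway between 0.1 and 0.4
assert rate_mul(0.9) == 1.0
```

Tracing the asserts confirms the warmdown only bites inside the final `warmdown_ms` of the budget, and the rate regularizer stays off for the first 10% of training.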
+ if args.warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(args.warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + maybe_mark_compile_step_begin(args.enable_compile or args.compile_structured_mlp) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + for opt in optimizers: + opt.step() + zero_grad_all() + if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: + log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + if distributed: + model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + # ----------------------------- + # MAIN TRAINING LOOP + # ----------------------------- + + training_time_ms = 0.0 + stop_after_step: int | None = None + torch.cuda.synchronize() + t0 = time.perf_counter() + + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + 
has_leading_space_lut, + is_boundary_token_lut, + ) + log0( + f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" + ) + break + + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + zero_grad_all() + train_loss = torch.zeros((), device=device) + train_ce_loss = torch.zeros((), device=device) + train_rate_proxy = torch.zeros((), device=device) + train_scale_proxy = torch.zeros((), device=device) + train_est_bits = torch.zeros((), device=device) + rate_mul = rate_schedule(step) + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + maybe_mark_compile_step_begin(args.enable_compile or args.compile_structured_mlp) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + ce_loss = model(x, y) + total_loss = ce_loss + if rate_mul > 0.0: + rate_proxy, scale_proxy, est_bits = base_model.rate_regularizer() + total_loss = total_loss + rate_mul * (args.rate_lambda * rate_proxy + args.scale_lambda * scale_proxy) + train_rate_proxy += rate_proxy.detach() + train_scale_proxy += scale_proxy.detach() + train_est_bits += est_bits.detach() + train_ce_loss += ce_loss.detach() + train_loss += total_loss.detach() + (total_loss * grad_scale).backward() + train_loss /= grad_accum_steps + train_ce_loss /= grad_accum_steps + train_rate_proxy /= grad_accum_steps + train_scale_proxy /= grad_accum_steps + train_est_bits /= grad_accum_steps + + frac = min(step / 
args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + for opt in optimizers: + opt.step() + zero_grad_all() + + step += 1 + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + should_log_train = ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + log0( + f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} ce_loss:{train_ce_loss.item():.4f} " + f"rate_mul:{rate_mul:.3f} rate_proxy_bits:{train_rate_proxy.item():.4f} " + f"scale_proxy:{train_scale_proxy.item():.6f} est_model_bits:{train_est_bits.item():.0f} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" + ) + + # Needed to sync whether we've reached the wallclock cap. 
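Ranks can measure slightly different elapsed wallclock, so the cap check is reduced with `MAX` across ranks; otherwise one rank could exit the training loop while the others block in their next collective. A CPU-only sketch of that agreement (hypothetical helper; `dist.all_reduce(..., op=ReduceOp.MAX)` is replaced by a plain Python `max` over per-rank flags so it runs without GPUs):

```python
def agree_on_cap(per_rank_elapsed_ms: list[float], max_wallclock_ms: float) -> bool:
    # Mirrors all-reducing a 0/1 flag with MAX: if any rank has hit the
    # wallclock cap, every rank agrees to stop at the same step.
    return max(elapsed >= max_wallclock_ms for elapsed in per_rank_elapsed_ms)

assert agree_on_cap([590_000.0, 601_000.0], 600_000.0)       # one slow rank stops all
assert not agree_on_cap([100_000.0, 100_500.0], 600_000.0)   # nobody has hit the cap
```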
+ reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + + log0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + ) + + # ----------------------------- + # SERIALIZATION + ROUNDTRIP VALIDATION + # ----------------------------- + # Save the raw state (useful for debugging/loading in PyTorch directly), then always produce + # the compressed quantized artifact (QUANT_SCHEME + COMPRESSOR, int5_odd+zlib by default) and validate the round-tripped weights. + + if master_process: + torch.save(base_model.state_dict(), "final_model.pt") + model_bytes = os.path.getsize("final_model.pt") + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model: {model_bytes} bytes") + log0(f"Code size: {code_bytes} bytes") + log0(f"Total submission size: {model_bytes + code_bytes} bytes") + + quant_obj, quant_stats = quantize_state_dict(base_model.state_dict(), args.quant_scheme) + quant_buf = io.BytesIO() + torch.save(quant_obj, quant_buf) + quant_raw = quant_buf.getvalue() + quant_blob = compress_blob(quant_raw, args.compressor, args.compressor_level) + quant_raw_bytes = len(quant_raw) + quant_filename = f"final_model.{args.quant_scheme}.{args.compressor}.ptc" + if master_process: + with open(quant_filename, "wb") as f: + f.write(quant_blob) + quant_file_bytes = os.path.getsize(quant_filename) + code_bytes = len(code.encode("utf-8")) + ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1) + log0( + f"Serialized model {args.quant_scheme}+{args.compressor}: {quant_file_bytes} bytes " + f"(payload:{quant_stats['int8_payload_bytes']} raw_torch:{quant_raw_bytes}
payload_ratio:{ratio:.2f}x)" + ) + log0(f"Total submission size {args.quant_scheme}+{args.compressor}: {quant_file_bytes + code_bytes} bytes") + + if distributed: + dist.barrier() + with open(quant_filename, "rb") as f: + quant_blob_disk = f.read() + quant_state = torch.load(io.BytesIO(decompress_blob(quant_blob_disk, args.compressor)), map_location="cpu") + base_model.load_state_dict(dequantize_state_dict(quant_state), strict=True) + torch.cuda.synchronize() + t_qeval = time.perf_counter() + q_val_loss, q_val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"final_{args.quant_scheme}_{args.compressor}_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" + ) + log0(f"final_{args.quant_scheme}_{args.compressor}_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + + if distributed: + dist.destroy_process_group() + + +if __name__ == "__main__": + main() + +==================================================================================================== +Running Python 3.10.12 (main, Mar 3 2026, 11:56:32) [GCC 11.4.0] +Running PyTorch 2.11.0+cu130 +Tue Mar 31 12:13:08 2026 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 590.48.01 Driver Version: 590.48.01 CUDA Version: 13.1 | ++-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. 
| +|=========================================+========================+======================| +| 0 NVIDIA GeForce RTX 3060 Off | 00000000:01:00.0 On | N/A | +| 0% 33C P0 18W / 170W | 242MiB / 12288MiB | 2% Default | +| | | N/A | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| 0 N/A N/A 985 G /usr/lib/xorg/Xorg 117MiB | +| 0 N/A N/A 4485 C python3 104MiB | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=/home/hz/Desktop/parameter-golf-main/data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:1 +val_loader:shards pattern=/home/hz/Desktop/parameter-golf-main/data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:65536 +val_token_limit:65536 +model_params:8507464 +world_size:1 grad_accum_steps:8 +enable_compile:False +sdp_backends:cudnn=False flash=True mem_efficient=False math=False +attention_mode:gqa num_heads:8 num_kv_heads:4 +quant_scheme:int5_odd compressor:zlib level:9 structured_mlp:True structured_linear_type:btt_mlp btt_rank:32 btt_init:mup compile_structured_mlp:False +rate_lambda:2e-05 scale_lambda:0.0002 rate_warmup_start_fraction:0.1 rate_full_fraction:0.4 rate_target_tensors:mlp_only +tie_embeddings:True embed_lr:0.03 head_lr:0.0 matrix_lr:0.02 scalar_lr:0.02 +train_batch_tokens:4096 train_seq_len:256 iterations:8 warmup_steps:1 max_wallclock_seconds:600.000 +seed:1337 +warmup_step:1/1 +step:1/8 train_loss:6.9357 ce_loss:6.9357 rate_mul:0.000 rate_proxy_bits:0.0000 scale_proxy:0.000000 
est_model_bits:0 train_time:4607ms step_avg:4607.43ms +step:2/8 train_loss:14.0267 ce_loss:14.0267 rate_mul:0.083 rate_proxy_bits:1.4980 scale_proxy:0.065182 est_model_bits:1325360 train_time:9734ms step_avg:4867.25ms +step:3/8 train_loss:13.7511 ce_loss:13.7511 rate_mul:0.500 rate_proxy_bits:1.4980 scale_proxy:0.065182 est_model_bits:1325360 train_time:14816ms step_avg:4938.66ms +step:4/8 train_loss:13.2375 ce_loss:13.2375 rate_mul:0.917 rate_proxy_bits:1.4980 scale_proxy:0.065182 est_model_bits:1325360 train_time:19861ms step_avg:4965.25ms +step:5/8 train_loss:12.3534 ce_loss:12.3533 rate_mul:1.000 rate_proxy_bits:1.4980 scale_proxy:0.065182 est_model_bits:1325360 train_time:24954ms step_avg:4990.80ms +step:6/8 train_loss:11.8068 ce_loss:11.8067 rate_mul:1.000 rate_proxy_bits:1.4980 scale_proxy:0.065182 est_model_bits:1325360 train_time:30078ms step_avg:5013.03ms +step:7/8 train_loss:11.0009 ce_loss:11.0008 rate_mul:1.000 rate_proxy_bits:1.4980 scale_proxy:0.065182 est_model_bits:1325360 train_time:35163ms step_avg:5023.26ms +step:8/8 train_loss:10.2193 ce_loss:10.2193 rate_mul:1.000 rate_proxy_bits:1.4980 scale_proxy:0.065182 est_model_bits:1325360 train_time:40297ms step_avg:5037.11ms +step:8/8 val_loss:9.5285 val_bpb:5.6338 train_time:40297ms step_avg:5037.16ms +peak memory allocated: 1133 MiB reserved: 1390 MiB +Serialized model: 33020979 bytes +Code size: 68825 bytes +Total submission size: 33089804 bytes +Serialized model int5_odd+zlib: 2169740 bytes (payload:2942927 raw_torch:2982627 payload_ratio:11.21x) +Total submission size int5_odd+zlib: 2238565 bytes +final_int5_odd_zlib_roundtrip val_loss:10.2974 val_bpb:6.0884 eval_time:3355ms +final_int5_odd_zlib_roundtrip_exact val_loss:10.29740024 val_bpb:6.08844573 diff --git a/records/track_non_record_16mb/2026-03-31_EntropyAware_Int5Odd_BTTMLP_3060/logs/compile_on_bench.txt b/records/track_non_record_16mb/2026-03-31_EntropyAware_Int5Odd_BTTMLP_3060/logs/compile_on_bench.txt new file mode 100644 index 
000000000..24bb3a9d4 --- /dev/null +++ b/records/track_non_record_16mb/2026-03-31_EntropyAware_Int5Odd_BTTMLP_3060/logs/compile_on_bench.txt @@ -0,0 +1,5043 @@ +""" +The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good jumping-off points for new participants, not SOTA configs. We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder. + +Hard stop: to keep these scripts readable for newcomers, let's make sure `train_gpt.py` and `train_gpt_mlx.py` are never longer than 1500 lines. +""" + +from __future__ import annotations + +import copy +import glob +import io +import math +import os +import random +import subprocess +import sys +import time +import uuid +import zlib +from pathlib import Path + +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP + +try: + import zstandard as zstd + + _HAS_ZSTD = True +except ImportError: + zstd = None + _HAS_ZSTD = False + +REPO_ROOT = Path(__file__).resolve().parents[3] + +# ----------------------------- +# HYPERPARAMETERS +# ----------------------------- +# Default Simple Baseline run: +# - 9 transformer blocks at width 512 +# - 8 attention heads with 4 KV heads (GQA) and 2x MLP expansion +# - vocab size 1024, sequence length 256, tied embeddings +# - 16,384 train tokens per step for 40 iterations with a ~10 minute cap + +class Hyperparameters: + # Data paths are shard globs produced by the existing preprocessing pipeline.
+ data_path = os.environ.get("DATA_PATH", str(REPO_ROOT / "data/datasets/fineweb10B_sp1024")) + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", str(REPO_ROOT / "data/tokenizers/fineweb_1024_bpe.model")) + run_id = os.environ.get("RUN_ID", "train") + seed = int(os.environ.get("SEED", 1337)) + + # Validation cadence and batch size. Validation uses the full fineweb_val split unless VAL_TOKEN_LIMIT caps the number of evaluated tokens. + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 262_144)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 10)) + val_token_limit = int(os.environ.get("VAL_TOKEN_LIMIT", 1_048_576)) + + # Training length. + iterations = int(os.environ.get("ITERATIONS", 40)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 1200)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 0)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 16_384)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 256)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + enable_compile = bool(int(os.environ.get("ENABLE_COMPILE", "0"))) + + # Model shape. + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers = int(os.environ.get("NUM_LAYERS", 9)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = int(os.environ.get("MLP_MULT", 2)) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + + # Optimizer hyperparameters. 
+ embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.03)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.02)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.02)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.0)) + + # Export / compression. + quant_scheme = os.environ.get("QUANT_SCHEME", "int5_odd") + compressor = os.environ.get("COMPRESSOR", "zlib") + compressor_level = int(os.environ.get("COMPRESSOR_LEVEL", 9)) + + # Entropy-aware rate regularization. + rate_lambda = float(os.environ.get("RATE_LAMBDA", 0.002)) + scale_lambda = float(os.environ.get("SCALE_LAMBDA", 0.0002)) + rate_temperature = float(os.environ.get("RATE_TEMPERATURE", 6.0)) + rate_warmup_start_fraction = float(os.environ.get("RATE_WARMUP_START_FRACTION", 0.1)) + rate_full_fraction = float(os.environ.get("RATE_FULL_FRACTION", 0.4)) + rate_target_tensors = os.environ.get("RATE_TARGET_TENSORS", "mlp_only") + + # Structured MLP. 
+ structured_mlp = bool(int(os.environ.get("STRUCTURED_MLP", "1"))) + structured_linear_type = os.environ.get("STRUCTURED_LINEAR_TYPE", "btt_mlp") + btt_rank = int(os.environ.get("BTT_RANK", 32)) + btt_in_shape = os.environ.get("BTT_IN_SHAPE", "") + btt_out_shape = os.environ.get("BTT_OUT_SHAPE", "") + btt_init = os.environ.get("BTT_INIT", "mup") + btt_init_gain = float(os.environ.get("BTT_INIT_GAIN", 1.0)) + compile_structured_mlp = bool(int(os.environ.get("COMPILE_STRUCTURED_MLP", "1"))) + +# ----------------------------- +# MUON OPTIMIZER +# ----------------------------- +# +# As borrowed from modded-nanogpt +# Background on Muon: https://kellerjordan.github.io/posts/muon/ + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + # Orthogonalize a 2D update matrix with a fast Newton-Schulz iteration. + # Muon uses this to normalize matrix-shaped gradients before applying them. + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X + + +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True): + super().__init__( + params, + dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov), + ) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + + for group in self.param_groups: + params = group["params"] + if not params: + continue + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + + total_params = 
sum(int(p.numel()) for p in params) + updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) + + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: + g = g.add(buf, alpha=momentum) + g = zeropower_via_newtonschulz5(g, steps=backend_steps) + # Scale correction from Muon reference implementations. + g *= max(1, g.size(0) / g.size(1)) ** 0.5 + updates_flat[curr : curr + p.numel()] = g.reshape(-1) + curr += p.numel() + + if distributed: + dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) + + curr = 0 + for p in params: + g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) + p.add_(g, alpha=-lr) + curr += p.numel() + + return loss + + +# ----------------------------- +# TOKENIZER-AGNOSTIC EVALUATION SETUP +# ----------------------------- +# +# It's common for small models to have a large fraction of their parameters in embeddings, since the 2 * d_model * d_vocab vectors can be gigantic. +# Instead of locking the tokenizer, we let you bring your own and calculate our validation metrics on the average compression of the validation set. +# We calculate BPB (bits-per-byte) instead of validation loss, so we need methods to count the number of bytes per token in the tokenizer. +# Note: Submissions that edit the tokenizer will be examined more carefully, since screwing this up might unjustly improve your score. 
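Concretely, the BPB metric multiplies two ratios: bits per token (mean cross-entropy in nats divided by ln 2) and tokens per byte (how many UTF-8 bytes the tokenizer attributes to each target token). A minimal standalone sketch of that identity, using made-up per-token losses and byte counts rather than real tokenizer output:

```python
import math

# Hypothetical values: two target tokens with their natural-log
# cross-entropies and their attributed UTF-8 byte counts.
token_losses = [2.0, 3.0]  # nats per token
token_bytes = [3, 4]       # bytes per target token

mean_loss = sum(token_losses) / len(token_losses)
bits_per_token = mean_loss / math.log(2.0)              # nats -> bits
tokens_per_byte = len(token_losses) / sum(token_bytes)  # 2 tokens / 7 bytes
val_bpb = bits_per_token * tokens_per_byte              # bits-per-byte
```

Because the byte count is a property of the text rather than of the tokenizer, `val_bpb` stays comparable across submissions that bring different tokenizers.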
+ +def build_sentencepiece_luts( + sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device +) -> tuple[Tensor, Tensor, Tensor]: + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("▁"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) + + +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*. 
+ tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] + + +def eval_val( + args: Hyperparameters, + model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + grad_accum_steps: int, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, +) -> tuple[float, float]: + # Validation computes two metrics: + # - val_loss: token cross-entropy (natural log) + # - val_bpb: tokenizer-agnostic compression metric used by the challenge + local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) + if local_batch_tokens < args.train_seq_len: + raise ValueError( + "VAL_BATCH_SIZE must provide at least one sequence per rank; " + f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " + f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}" + ) + local_batch_seqs = local_batch_tokens // args.train_seq_len + total_seqs = (val_tokens.numel() - 1) // args.train_seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + + model.eval() + with torch.inference_mode(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * args.train_seq_len + raw_end = batch_seq_end * args.train_seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, args.train_seq_len) + y = local[1:].reshape(-1, args.train_seq_len) + 
maybe_mark_compile_step_begin(args.enable_compile or args.compile_structured_mlp) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + batch_loss = model(x, y).detach() + batch_token_count = float(y.numel()) + val_loss_sum += batch_loss.to(torch.float64) * batch_token_count + val_token_count += batch_token_count + prev_ids = x.reshape(-1) + tgt_ids = y.reshape(-1) + token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + + val_loss = val_loss_sum / val_token_count + bits_per_token = val_loss.item() / math.log(2.0) + tokens_per_byte = val_token_count.item() / val_byte_count.item() + model.train() + return float(val_loss.item()), float(bits_per_token * tokens_per_byte) + +# ----------------------------- +# POST-TRAINING QUANTIZATION +# ----------------------------- +# +# It's silly to export our model, which is trained in bf16 and fp32, at that same precision. +# Instead, we get approximately the same model (with a small hit) by quantizing it to int8 and zlib-compressing the result. +# We can then decompress the model and run it in higher precision for evaluation, after coming in under the size limit. 
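The core of the int8 path below is per-row symmetric quantization: one scale per output row, so rows with very different magnitudes don't share a grid. A simplified standalone NumPy sketch of the roundtrip (it uses plain absolute-max scaling, omitting the exporter's high-percentile clipping and fp16 scale storage):

```python
import numpy as np

rng = np.random.default_rng(0)
w = rng.standard_normal((4, 256)).astype(np.float32)  # toy weight matrix

# One symmetric scale per row, mapping that row's max |w| onto 127.
scale = np.abs(w).max(axis=1, keepdims=True) / 127.0
q = np.clip(np.round(w / scale), -127, 127).astype(np.int8)

# Dequantize by broadcasting the row scale back across the columns.
w_hat = q.astype(np.float32) * scale

# Rounding error is bounded by half a quantization step per element.
max_err = np.abs(w - w_hat).max()
```

The int8 codes compress well afterwards because the quantized values have far lower entropy than raw fp32 bytes, which is what the zlib stage exploits.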
+ +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights", + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 +INT5_ODD_LEVELS = (-2, -1, 0, 1, 2) + +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) + + +def maybe_mark_compile_step_begin(enabled: bool) -> None: + if not enabled: + return + mark_step = getattr(torch.compiler, "cudagraph_mark_step_begin", None) + if mark_step is not None: + mark_step() + + +def parse_factor_shape(spec: str, dim: int) -> tuple[int, int]: + if spec: + vals = tuple(int(v.strip()) for v in spec.split(",") if v.strip()) + if len(vals) != 2 or vals[0] * vals[1] != dim: + raise ValueError(f"Invalid factor shape '{spec}' for dim={dim}") + return vals + for a in range(int(math.isqrt(dim)), 0, -1): + if dim % a == 0: + return a, dim // a + return 1, dim + +def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: + if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): + return t.float().contiguous() + if t.dtype in {torch.float32, torch.bfloat16}: + passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") + return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() + return t + +def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + # Matrices get one scale per row, which usually tracks output-channel + # ranges much better than a single tensor-wide scale. 
+ clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) + q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() + return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + + # Vectors / scalars use a simpler per-tensor scale. + clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() + return q, scale + + +def pack_base5_codes(q: Tensor) -> tuple[bytes, int]: + vals = q.detach().to("cpu", dtype=torch.uint8).reshape(-1).numpy() + n_vals = int(vals.size) + pad = (3 - (n_vals % 3)) % 3 + if pad: + vals = np.pad(vals, (0, pad)) + groups = vals.reshape(-1, 3).astype(np.uint8) + packed = groups[:, 0] + 5 * groups[:, 1] + 25 * groups[:, 2] + return packed.tobytes(), n_vals + + +def unpack_base5_codes(data: bytes, n_vals: int) -> Tensor: + packed = np.frombuffer(data, dtype=np.uint8).astype(np.int16) + out = np.zeros((packed.size, 3), dtype=np.uint8) + for i in range(3): + out[:, i] = packed % 5 + packed //= 5 + flat = out.reshape(-1)[:n_vals] + return torch.from_numpy(flat.astype(np.uint8, copy=False)) + + +def quantize_float_tensor_int5_odd(t: Tensor) -> tuple[Tensor, Tensor]: + if t.ndim != 2: + raise ValueError("int5_odd quantization only supports 2D tensors") + t32 = t.float() + clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clip_abs = clip_abs.clamp_min(1e-6) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 
2.0).clamp_min(1.0 / 2.0e6) + q = torch.clamp(torch.round(clipped / scale[:, None]), -2, 2).to(torch.int8).contiguous() + return (q + 2).to(torch.uint8).contiguous(), scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + +def quantize_state_dict_int8(state_dict: dict[str, Tensor]): + # Single supported clean-script export format: + # - per-row int8 for 2D float tensors + # - per-tensor int8 for other float tensors + # - exact passthrough for non-floats + # - passthrough for small float tensors, stored as fp16 to save bytes + quantized: dict[str, Tensor] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + qmeta: dict[str, dict[str, object]] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + + # Small float tensors are cheap enough to keep directly. We still downcast + # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size. 
+ if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + + stats["num_float_tensors"] += 1 + q, s = quantize_float_tensor(t) + if s.ndim > 0: + qmeta[name] = {"scheme": "per_row", "axis": 0} + quantized[name] = q + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) + + obj: dict[str, object] = { + "__quant_format__": "int8_clean_per_row_v1", + "quantized": quantized, + "scales": scales, + "dtypes": dtypes, + "passthrough": passthrough, + } + if qmeta: + obj["qmeta"] = qmeta + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats + + +def quantize_state_dict_int5_odd(state_dict: dict[str, Tensor]): + packed: dict[str, bytes] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + shapes: dict[str, tuple[int, ...]] = {} + orig_shapes: dict[str, tuple[int, ...]] = {} + n_codes: dict[str, int] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + + quantize_this = t.ndim >= 2 and (t.numel() > INT8_KEEP_FLOAT_MAX_NUMEL or ".mlp." 
in name) + if not quantize_this: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + + stats["num_float_tensors"] += 1 + t2 = t.reshape(t.shape[0], -1) if t.ndim > 2 else t + q, s = quantize_float_tensor_int5_odd(t2) + packed_bytes, count = pack_base5_codes(q) + packed[name] = packed_bytes + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + shapes[name] = tuple(int(v) for v in t2.shape) + orig_shapes[name] = tuple(int(v) for v in t.shape) + n_codes[name] = count + stats["int8_payload_bytes"] += len(packed_bytes) + tensor_nbytes(s) + + obj: dict[str, object] = { + "__quant_format__": "int5_odd_clean_per_row_v1", + "packed": packed, + "scales": scales, + "dtypes": dtypes, + "shapes": shapes, + "orig_shapes": orig_shapes, + "n_codes": n_codes, + "passthrough": passthrough, + } + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats + +def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + qmeta = obj.get("qmeta", {}) + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, q in obj["quantized"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + s = obj["scales"][name] + if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: + s = s.to(dtype=torch.float32) + # Broadcast the saved row scale back across trailing dimensions. + out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() + else: + scale = float(s.item()) + out[name] = (q.float() * scale).to(dtype=dtype).contiguous() + for name, t in obj["passthrough"].items(): + # Restore small tensors, undoing the temporary fp16 storage cast if needed. 
+ out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out + + +def dequantize_state_dict_int5_odd(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, packed in obj["packed"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + shape = tuple(int(v) for v in obj["shapes"][name]) + orig_shape = tuple(int(v) for v in obj.get("orig_shapes", {}).get(name, shape)) + q = unpack_base5_codes(packed, int(obj["n_codes"][name])).to(torch.float32).reshape(shape) - 2.0 + s = obj["scales"][name].to(dtype=torch.float32) + deq = (q * s.view(shape[0], *([1] * (len(shape) - 1)))).to(dtype=dtype) + out[name] = deq.reshape(orig_shape).contiguous() + for name, t in obj["passthrough"].items(): + out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out + + +def quantize_state_dict(state_dict: dict[str, Tensor], quant_scheme: str): + if quant_scheme == "int8": + return quantize_state_dict_int8(state_dict) + if quant_scheme == "int5_odd": + return quantize_state_dict_int5_odd(state_dict) + raise ValueError(f"Unsupported QUANT_SCHEME={quant_scheme}") + + +def dequantize_state_dict(obj: dict[str, object]) -> dict[str, Tensor]: + quant_format = obj.get("__quant_format__") + if quant_format == "int8_clean_per_row_v1": + return dequantize_state_dict_int8(obj) + if quant_format == "int5_odd_clean_per_row_v1": + return dequantize_state_dict_int5_odd(obj) + raise ValueError(f"Unsupported quant format {quant_format}") + + +def compress_blob(data: bytes, compressor: str, level: int) -> bytes: + if compressor == "zlib": + return zlib.compress(data, level=max(0, 
min(level, 9))) + if compressor == "zstd": + if not _HAS_ZSTD: + raise RuntimeError("COMPRESSOR=zstd requires zstandard to be installed") + return zstd.ZstdCompressor(level=level).compress(data) + raise ValueError(f"Unsupported COMPRESSOR={compressor}") + + +def decompress_blob(data: bytes, compressor: str) -> bytes: + if compressor == "zlib": + return zlib.decompress(data) + if compressor == "zstd": + if not _HAS_ZSTD: + raise RuntimeError("COMPRESSOR=zstd requires zstandard to be installed") + return zstd.ZstdDecompressor().decompress(data) + raise ValueError(f"Unsupported COMPRESSOR={compressor}") + + +def should_rate_regularize_tensor(name: str, p: Tensor, target: str) -> bool: + if not p.requires_grad or not p.is_floating_point() or p.ndim < 2: + return False + if any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS): + return False + if target == "mlp_only": + return ".mlp." in name + if target == "all_matrices": + return p.numel() > INT8_KEEP_FLOAT_MAX_NUMEL or ".mlp." in name + raise ValueError(f"Unsupported RATE_TARGET_TENSORS={target}") + + +def estimate_tensor_rate_proxy(t: Tensor, temperature: float) -> tuple[Tensor, Tensor, Tensor]: + t2 = t.float() + if t2.ndim > 2: + t2 = t2.reshape(t2.shape[0], -1) + elif t2.ndim == 1: + t2 = t2.reshape(1, -1) + row_scale = t2.abs().mean(dim=1, keepdim=True).clamp_min(1e-6) + normalized = torch.clamp(t2 / row_scale, -2.5, 2.5) + centers = t2.new_tensor(INT5_ODD_LEVELS) + logits = -temperature * (normalized.unsqueeze(-1) - centers) ** 2 + probs = torch.softmax(logits, dim=-1) + hist = probs.mean(dim=(0, 1)) + entropy_nats = -(hist * (hist + 1e-8).log()).sum() + entropy_bits = entropy_nats / math.log(2.0) + scale_proxy = row_scale.mean() + est_bits = entropy_bits * float(t2.numel()) + return entropy_bits, scale_proxy, est_bits + + +# ----------------------------- +# DATA LOADING +# ----------------------------- + +def load_data_shard(file: Path) -> Tensor: + # Shards start with a 256-int32 header, followed by uint16 token ids. + header_bytes = 256 * np.dtype("<i4").itemsize + tokens = np.fromfile(file, dtype="<u2", offset=header_bytes) + return torch.from_numpy(tokens.astype(np.int32)) + + +class TokenStream: + # Streams tokens from the sorted shard files in order, wrapping around at the end. + def __init__(self, pattern: str): + self.files = [Path(p) for p in sorted(glob.glob(pattern))] + if not self.files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + self.file_idx = 0 + self.tokens = load_data_shard(self.files[self.file_idx]) + self.pos = 0 + + def _advance_file(self) -> None: + 
self.file_idx = (self.file_idx + 1) % len(self.files) + self.tokens = load_data_shard(self.files[self.file_idx]) + self.pos = 0 + + def take(self, n: int) -> Tensor: + chunks: list[Tensor] = [] + remaining = n + while remaining > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: + self._advance_file() + continue + k = min(remaining, avail) + chunks.append(self.tokens[self.pos : self.pos + k]) + self.pos += k + remaining -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) + + +class DistributedTokenLoader: + # Each call consumes a contiguous chunk from the shared token stream, then slices out + # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting. + def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank = rank + self.world_size = world_size + self.device = device + self.stream = TokenStream(pattern) + + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + chunk = self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start : start + per_rank_span].to(dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + +# ----------------------------- +# TRANSFORMER MODULES +# ----------------------------- + +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + + +class CastedLinear(nn.Linear): + # Keep weights in fp32 for optimizer/state quality, cast at matmul time for bf16 compute. 
+ def forward(self, x: Tensor) -> Tensor: + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, self.weight.to(x.dtype), bias) + + +class StructuredLinear(nn.Module): + # Lightweight TT-style structured matrix used only in the MLP stack. + def __init__( + self, + in_features: int, + out_features: int, + bias: bool, + linear_type: str, + btt_rank: int, + in_shape_spec: str, + out_shape_spec: str, + btt_init: str, + btt_init_gain: float, + ): + super().__init__() + self.in_features = in_features + self.out_features = out_features + self.linear_type = linear_type + self.rank = max(int(btt_rank), 1) + self.btt_init = btt_init + self.btt_init_gain = btt_init_gain + if linear_type == "dense": + self.impl = CastedLinear(in_features, out_features, bias=bias) + self.register_parameter("core1", None) + self.register_parameter("core2", None) + return + if linear_type != "btt_mlp": + raise ValueError(f"Unsupported STRUCTURED_LINEAR_TYPE={linear_type}") + self.impl = None + self.in_shape = parse_factor_shape(in_shape_spec, in_features) + self.out_shape = parse_factor_shape(out_shape_spec, out_features) + n1, n2 = self.in_shape + m1, m2 = self.out_shape + self.core1 = nn.Parameter(torch.empty(n1, m1, self.rank)) + self.core2 = nn.Parameter(torch.empty(self.rank, n2, m2)) + if bias: + self.bias = nn.Parameter(torch.zeros(out_features)) + else: + self.register_parameter("bias", None) + self.reset_parameters() + + def reset_parameters(self) -> None: + if self.impl is not None: + self.impl.reset_parameters() + return + n1, n2 = self.in_shape + if self.btt_init == "mup": + # muP-inspired scaling: keep the per-rank branch update stable while + # preserving O(1) output variance under the sum over TT ranks. 
+ core1_std = self.btt_init_gain / math.sqrt(max(n1, 1)) + core2_std = self.btt_init_gain / math.sqrt(max(self.rank * n2, 1)) + elif self.btt_init == "xavier": + std = self.btt_init_gain / math.sqrt(max(self.in_features, 1)) + core1_std = std / math.sqrt(self.rank) + core2_std = std + else: + raise ValueError(f"Unsupported BTT_INIT={self.btt_init}") + nn.init.normal_(self.core1, mean=0.0, std=core1_std) + nn.init.normal_(self.core2, mean=0.0, std=core2_std) + if self.bias is not None: + nn.init.zeros_(self.bias) + + def forward(self, x: Tensor) -> Tensor: + if self.impl is not None: + return self.impl(x) + x_dtype = x.dtype + orig_shape = x.shape[:-1] + x2 = x.reshape(-1, self.in_shape[0], self.in_shape[1]) + core1 = self.core1.to(dtype=x_dtype) + core2 = self.core2.to(dtype=x_dtype) + y = x2.new_zeros((x2.shape[0], self.out_shape[0], self.out_shape[1])) + for r in range(self.rank): + left = torch.einsum("buv,um->bmv", x2, core1[:, :, r]) + y = y + torch.einsum("bmv,vn->bmn", left, core2[r]) + y = y.reshape(*orig_shape, self.out_features) + if self.bias is not None: + y = y + self.bias.to(dtype=y.dtype) + return y + + +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + # Keep small/control parameters in fp32 even when the model body runs in bf16. + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() + + +class Rotary(nn.Module): + # Caches cos/sin tables per sequence length on the current device. 
+ def __init__(self, dim: int, base: float = 10000.0): + super().__init__() + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached != seq_len + or self._cos_cached.device != device + ): + t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) + freqs = torch.outer(t, self.inv_freq.to(device)) + self._cos_cached = freqs.cos()[None, None, :, :] + self._sin_cached = freqs.sin()[None, None, :, :] + self._seq_len_cached = seq_len + cos = self._cos_cached.to(dtype=dtype) + sin = self._sin_cached.to(dtype=dtype) + if not torch.is_inference_mode_enabled(): + cos = cos.clone() + sin = sin.clone() + return cos, sin + + +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + + +class CausalSelfAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + kv_dim = self.num_kv_heads * self.head_dim + self.c_q = CastedLinear(dim, dim, bias=False) + self.c_k = CastedLinear(dim, kv_dim, bias=False) + self.c_v = CastedLinear(dim, kv_dim, bias=False) + self.proj 
= CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rotary = Rotary(self.head_dim, base=rope_base) + + def forward(self, x: Tensor) -> Tensor: + bsz, seqlen, dim = x.shape + q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) + k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin) + k = apply_rotary_emb(k, cos, sin) + q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] + y = F.scaled_dot_product_attention( + q, + k, + v, + attn_mask=None, + is_causal=True, + enable_gqa=(self.num_kv_heads != self.num_heads), + ) + y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) + return self.proj(y) + + +class MLP(nn.Module): + # relu^2 MLP from the original modded-nanogpt setup + def __init__( + self, + dim: int, + mlp_mult: int, + structured_mlp: bool, + structured_linear_type: str, + btt_rank: int, + btt_in_shape: str, + btt_out_shape: str, + btt_init: str, + btt_init_gain: float, + compile_structured_mlp: bool, + ): + super().__init__() + hidden = mlp_mult * dim + if structured_mlp: + self.fc = StructuredLinear( + dim, + hidden, + bias=False, + linear_type=structured_linear_type, + btt_rank=btt_rank, + in_shape_spec=btt_in_shape, + out_shape_spec=btt_out_shape, + btt_init=btt_init, + btt_init_gain=btt_init_gain, + ) + self.proj = StructuredLinear( + hidden, + dim, + bias=False, + linear_type=structured_linear_type, + btt_rank=btt_rank, + in_shape_spec=btt_out_shape, + out_shape_spec=btt_in_shape, + btt_init=btt_init, + btt_init_gain=btt_init_gain, + ) + if compile_structured_mlp: + self.fc = torch.compile(self.fc, mode="reduce-overhead", 
fullgraph=False, dynamic=False) + self.proj = torch.compile(self.proj, mode="reduce-overhead", fullgraph=False, dynamic=False) + else: + self.fc = CastedLinear(dim, hidden, bias=False) + self.proj = CastedLinear(hidden, dim, bias=False) + self.proj._zero_init = True + + def forward(self, x: Tensor) -> Tensor: + x = torch.relu(self.fc(x)) + return self.proj(x.square()) + + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + rope_base: float, + qk_gain_init: float, + structured_mlp: bool, + structured_linear_type: str, + btt_rank: int, + btt_in_shape: str, + btt_out_shape: str, + btt_init: str, + btt_init_gain: float, + compile_structured_mlp: bool, + ): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) + self.mlp = MLP( + dim, + mlp_mult, + structured_mlp, + structured_linear_type, + btt_rank, + btt_in_shape, + btt_out_shape, + btt_init, + btt_init_gain, + compile_structured_mlp, + ) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + + def forward(self, x: Tensor, x0: Tensor) -> Tensor: + mix = self.resid_mix.to(dtype=x.dtype) + x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + attn_out = self.attn(self.attn_norm(x)) + x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out + x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) + return x + + +class GPT(nn.Module): + def __init__( + self, + vocab_size: int, + num_layers: int, + model_dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + tie_embeddings: bool, + tied_embed_init_std: float, + logit_softcap: float, + rope_base: float, + qk_gain_init: float, + structured_mlp: bool, + 
structured_linear_type: str, + btt_rank: int, + btt_in_shape: str, + btt_out_shape: str, + btt_init: str, + btt_init_gain: float, + compile_structured_mlp: bool, + rate_target_tensors: str, + rate_temperature: float, + ): + super().__init__() + if logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") + self.tie_embeddings = tie_embeddings + self.tied_embed_init_std = tied_embed_init_std + self.logit_softcap = logit_softcap + self.rate_target_tensors = rate_target_tensors + self.rate_temperature = rate_temperature + self.tok_emb = nn.Embedding(vocab_size, model_dim) + self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + self.blocks = nn.ModuleList( + [ + Block( + model_dim, + num_heads, + num_kv_heads, + mlp_mult, + rope_base, + qk_gain_init, + structured_mlp, + structured_linear_type, + btt_rank, + btt_in_shape, + btt_out_shape, + btt_init, + btt_init_gain, + compile_structured_mlp, + ) + for i in range(num_layers) + ] + ) + self.final_norm = RMSNorm() + self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) + if self.lm_head is not None: + self.lm_head._zero_init = True + self._init_weights() + + def _init_weights(self) -> None: + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + for module in self.modules(): + if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + if isinstance(module, StructuredLinear) and getattr(module, "_zero_init", False): + if module.impl is not None: + nn.init.zeros_(module.impl.weight) + else: + nn.init.zeros_(module.core2) + + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + x = 
self.tok_emb(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x0 = x + skips: list[Tensor] = [] + + # First half stores skips; second half reuses them in reverse order. + for i in range(self.num_encoder_layers): + x = self.blocks[i](x, x0) + skips.append(x) + for i in range(self.num_decoder_layers): + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + x = self.blocks[self.num_encoder_layers + i](x, x0) + + x = self.final_norm(x).reshape(-1, x.size(-1)) + targets = target_ids.reshape(-1) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(x) + logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + return F.cross_entropy(logits.float(), targets, reduction="mean") + + def rate_regularizer(self) -> tuple[Tensor, Tensor, Tensor]: + total_symbols = 0.0 + weighted_rate = None + weighted_scale = None + estimated_bits = None + for name, param in self.named_parameters(): + if not should_rate_regularize_tensor(name, param, self.rate_target_tensors): + continue + rate_proxy, scale_proxy, tensor_bits = estimate_tensor_rate_proxy(param, self.rate_temperature) + weight = float(param.numel()) + total_symbols += weight + if weighted_rate is None: + weighted_rate = rate_proxy * weight + weighted_scale = scale_proxy * weight + estimated_bits = tensor_bits + else: + weighted_rate = weighted_rate + rate_proxy * weight + weighted_scale = weighted_scale + scale_proxy * weight + estimated_bits = estimated_bits + tensor_bits + if weighted_rate is None or weighted_scale is None or estimated_bits is None: + zero = self.tok_emb.weight.new_zeros(()) + return zero, zero, zero + denom = max(total_symbols, 1.0) + return weighted_rate / denom, weighted_scale / denom, estimated_bits + + +# ----------------------------- +# TRAINING +# ----------------------------- + +def main() -> None: 
+ global zeropower_via_newtonschulz5 + + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + if args.enable_compile: + zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) + + # ----------------------------- + # DISTRIBUTED + CUDA SETUP + # ----------------------------- + + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if 8 % world_size != 0: + raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") + grad_accum_steps = 8 // world_size + grad_scale = 1.0 / grad_accum_steps + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 0 + + # Fast math knobs + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + + logfile = None + if master_process: + if args.run_id == "train": + logfile = "train.log" + else: + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" + print(logfile) + + def log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch 
{torch.__version__}", console=False) + log0( + subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, + console=False, + ) + log0("=" * 100, console=False) + + # ----------------------------- + # TOKENIZER + VALIDATION METRIC SETUP + # ----------------------------- + + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) + if args.val_token_limit > 0: + usable = min(((args.val_token_limit) // args.train_seq_len) * args.train_seq_len, val_tokens.numel() - 1) + if usable <= 0: + raise ValueError(f"VAL_TOKEN_LIMIT={args.val_token_limit} is too small for TRAIN_SEQ_LEN={args.train_seq_len}") + val_tokens = val_tokens[: usable + 1].contiguous() + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + if args.val_token_limit > 0: + log0(f"val_token_limit:{args.val_token_limit}") + + # ----------------------------- + # MODEL + OPTIMIZER SETUP + # ----------------------------- + + base_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + 
model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, + structured_mlp=args.structured_mlp, + structured_linear_type=args.structured_linear_type if args.structured_mlp else "dense", + btt_rank=args.btt_rank, + btt_in_shape=args.btt_in_shape, + btt_out_shape=args.btt_out_shape, + btt_init=args.btt_init, + btt_init_gain=args.btt_init_gain, + compile_structured_mlp=args.compile_structured_mlp, + rate_target_tensors=args.rate_target_tensors, + rate_temperature=args.rate_temperature, + ).to(device).bfloat16() + for module in base_model.modules(): + if isinstance(module, (CastedLinear, StructuredLinear)): + module.float() + restore_low_dim_params_to_fp32(base_model) + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) if args.enable_compile else base_model + model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + + # Optimizer split: + # - token embedding (Adam) uses EMBED_LR + # - untied lm_head (Adam) uses HEAD_LR + # - matrix params in transformer blocks use MATRIX_LR via Muon + # - vectors/scalars use SCALAR_LR via Adam + block_named_params = list(base_model.blocks.named_parameters()) + matrix_params = [ + p + for name, p in block_named_params + if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + scalar_params = [ + p + for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + optimizer_tok = torch.optim.Adam( + [{"params": [base_model.tok_emb.weight], "lr": token_lr, 
"base_lr": token_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizer_muon = Muon( + matrix_params, + lr=args.matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + ) + for group in optimizer_muon.param_groups: + group["base_lr"] = args.matrix_lr + optimizer_scalar = torch.optim.Adam( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers.insert(1, optimizer_head) + + n_params = sum(p.numel() for p in base_model.parameters()) + log0(f"model_params:{n_params}") + log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") + log0(f"enable_compile:{args.enable_compile}") + log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") + log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") + log0( + f"quant_scheme:{args.quant_scheme} compressor:{args.compressor} level:{args.compressor_level} " + f"structured_mlp:{args.structured_mlp} structured_linear_type:{args.structured_linear_type} " + f"btt_rank:{args.btt_rank} btt_init:{args.btt_init} compile_structured_mlp:{args.compile_structured_mlp}" + ) + log0( + f"rate_lambda:{args.rate_lambda} scale_lambda:{args.scale_lambda} " + f"rate_warmup_start_fraction:{args.rate_warmup_start_fraction} rate_full_fraction:{args.rate_full_fraction} " + f"rate_target_tensors:{args.rate_target_tensors}" + ) + log0( + f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " + f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " + f"matrix_lr:{args.matrix_lr} 
scalar_lr:{args.scalar_lr}" + ) + log0( + f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " + f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " + f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log0(f"seed:{args.seed}") + + # ----------------------------- + # DATA LOADER & MODEL WARMUP + # ----------------------------- + + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + + def lr_mul(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + + def rate_schedule(step: int) -> float: + if args.rate_lambda <= 0.0 and args.scale_lambda <= 0.0: + return 0.0 + if args.iterations <= 0: + return 1.0 + frac = step / max(args.iterations, 1) + start = args.rate_warmup_start_fraction + full = max(args.rate_full_fraction, start + 1e-6) + if frac <= start: + return 0.0 + if frac >= full: + return 1.0 + return (frac - start) / (full - start) + + # Warmup primes the compiled forward/backward/optimizer paths, then we restore the + # initial weights/optimizer state so measured training starts from the true init. 
+ if args.warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(args.warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + maybe_mark_compile_step_begin(args.enable_compile or args.compile_structured_mlp) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + for opt in optimizers: + opt.step() + zero_grad_all() + if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: + log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + if distributed: + model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + # ----------------------------- + # MAIN TRAINING LOOP + # ----------------------------- + + training_time_ms = 0.0 + stop_after_step: int | None = None + torch.cuda.synchronize() + t0 = time.perf_counter() + + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + 
has_leading_space_lut, + is_boundary_token_lut, + ) + log0( + f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" + ) + break + + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + zero_grad_all() + train_loss = torch.zeros((), device=device) + train_ce_loss = torch.zeros((), device=device) + train_rate_proxy = torch.zeros((), device=device) + train_scale_proxy = torch.zeros((), device=device) + train_est_bits = torch.zeros((), device=device) + rate_mul = rate_schedule(step) + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + maybe_mark_compile_step_begin(args.enable_compile or args.compile_structured_mlp) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + ce_loss = model(x, y) + total_loss = ce_loss + if rate_mul > 0.0: + rate_proxy, scale_proxy, est_bits = base_model.rate_regularizer() + total_loss = total_loss + rate_mul * (args.rate_lambda * rate_proxy + args.scale_lambda * scale_proxy) + train_rate_proxy += rate_proxy.detach() + train_scale_proxy += scale_proxy.detach() + train_est_bits += est_bits.detach() + train_ce_loss += ce_loss.detach() + train_loss += total_loss.detach() + (total_loss * grad_scale).backward() + train_loss /= grad_accum_steps + train_ce_loss /= grad_accum_steps + train_rate_proxy /= grad_accum_steps + train_scale_proxy /= grad_accum_steps + train_est_bits /= grad_accum_steps + + frac = min(step / 
args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + for opt in optimizers: + opt.step() + zero_grad_all() + + step += 1 + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + should_log_train = ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + log0( + f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} ce_loss:{train_ce_loss.item():.4f} " + f"rate_mul:{rate_mul:.3f} rate_proxy_bits:{train_rate_proxy.item():.4f} " + f"scale_proxy:{train_scale_proxy.item():.6f} est_model_bits:{train_est_bits.item():.0f} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" + ) + + # Needed to sync whether we've reached the wallclock cap. 
+ reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms
+ if distributed and max_wallclock_ms is not None:
+ reached_cap_tensor = torch.tensor(int(reached_cap), device=device)
+ dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX)
+ reached_cap = bool(reached_cap_tensor.item())
+ if stop_after_step is None and reached_cap:
+ stop_after_step = step
+
+ log0(
+ f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB "
+ f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB"
+ )
+
+ # -----------------------------
+ # SERIALIZATION + ROUNDTRIP VALIDATION
+ # -----------------------------
+ # Save the raw state (useful for debugging/loading in PyTorch directly), then always produce
+ # the compressed quantized artifact (int5_odd + zlib by default) and validate the round-tripped weights.
+
+ if master_process:
+ torch.save(base_model.state_dict(), "final_model.pt")
+ model_bytes = os.path.getsize("final_model.pt")
+ code_bytes = len(code.encode("utf-8"))
+ log0(f"Serialized model: {model_bytes} bytes")
+ log0(f"Code size: {code_bytes} bytes")
+ log0(f"Total submission size: {model_bytes + code_bytes} bytes")
+
+ quant_obj, quant_stats = quantize_state_dict(base_model.state_dict(), args.quant_scheme)
+ quant_buf = io.BytesIO()
+ torch.save(quant_obj, quant_buf)
+ quant_raw = quant_buf.getvalue()
+ quant_blob = compress_blob(quant_raw, args.compressor, args.compressor_level)
+ quant_raw_bytes = len(quant_raw)
+ quant_filename = f"final_model.{args.quant_scheme}.{args.compressor}.ptc"
+ if master_process:
+ with open(quant_filename, "wb") as f:
+ f.write(quant_blob)
+ quant_file_bytes = os.path.getsize(quant_filename)
+ code_bytes = len(code.encode("utf-8"))
+ ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1)
+ log0(
+ f"Serialized model {args.quant_scheme}+{args.compressor}: {quant_file_bytes} bytes "
+ f"(payload:{quant_stats['int8_payload_bytes']} raw_torch:{quant_raw_bytes} 
payload_ratio:{ratio:.2f}x)" + ) + log0(f"Total submission size {args.quant_scheme}+{args.compressor}: {quant_file_bytes + code_bytes} bytes") + + if distributed: + dist.barrier() + with open(quant_filename, "rb") as f: + quant_blob_disk = f.read() + quant_state = torch.load(io.BytesIO(decompress_blob(quant_blob_disk, args.compressor)), map_location="cpu") + base_model.load_state_dict(dequantize_state_dict(quant_state), strict=True) + torch.cuda.synchronize() + t_qeval = time.perf_counter() + q_val_loss, q_val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"final_{args.quant_scheme}_{args.compressor}_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" + ) + log0(f"final_{args.quant_scheme}_{args.compressor}_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + + if distributed: + dist.destroy_process_group() + + +if __name__ == "__main__": + main() + +==================================================================================================== +Running Python 3.10.12 (main, Mar 3 2026, 11:56:32) [GCC 11.4.0] +Running PyTorch 2.11.0+cu130 +Tue Mar 31 12:14:11 2026 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 590.48.01 Driver Version: 590.48.01 CUDA Version: 13.1 | ++-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. 
| +|=========================================+========================+======================| +| 0 NVIDIA GeForce RTX 3060 Off | 00000000:01:00.0 On | N/A | +| 31% 38C P8 19W / 170W | 252MiB / 12288MiB | 9% Default | +| | | N/A | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| 0 N/A N/A 985 G /usr/lib/xorg/Xorg 127MiB | +| 0 N/A N/A 4507 C python3 104MiB | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=/home/hz/Desktop/parameter-golf-main/data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:1 +val_loader:shards pattern=/home/hz/Desktop/parameter-golf-main/data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:65536 +val_token_limit:65536 +model_params:8507464 +world_size:1 grad_accum_steps:8 +enable_compile:False +sdp_backends:cudnn=False flash=True mem_efficient=False math=False +attention_mode:gqa num_heads:8 num_kv_heads:4 +quant_scheme:int5_odd compressor:zlib level:9 structured_mlp:True structured_linear_type:btt_mlp btt_rank:32 btt_init:mup compile_structured_mlp:True +rate_lambda:2e-05 scale_lambda:0.0002 rate_warmup_start_fraction:0.1 rate_full_fraction:0.4 rate_target_tensors:mlp_only +tie_embeddings:True embed_lr:0.03 head_lr:0.0 matrix_lr:0.02 scalar_lr:0.02 +train_batch_tokens:4096 train_seq_len:256 iterations:8 warmup_steps:1 max_wallclock_seconds:600.000 +seed:1337 +""" +The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good launching-off points for new participants, not SOTA 
configs. We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder.
+
+Hard stop: to keep things readable for newcomers, let's make sure `train_gpt.py` and `train_gpt_mlx.py` never exceed 1500 lines.
+"""
+
+from __future__ import annotations
+
+import copy
+import glob
+import io
+import math
+import os
+import random
+import subprocess
+import sys
+import time
+import uuid
+import zlib
+from pathlib import Path
+
+import numpy as np
+import sentencepiece as spm
+import torch
+import torch.distributed as dist
+import torch.nn.functional as F
+from torch import Tensor, nn
+from torch.nn.parallel import DistributedDataParallel as DDP
+
+try:
+ import zstandard as zstd
+
+ _HAS_ZSTD = True
+except ImportError:
+ zstd = None
+ _HAS_ZSTD = False
+
+REPO_ROOT = Path(__file__).resolve().parents[3]
+
+# -----------------------------
+# HYPERPARAMETERS
+# -----------------------------
+# Default run (the Simple Baseline shape, with batch/iteration settings tuned down for short local runs):
+# - 9 transformer blocks at width 512
+# - 8 attention heads with 4 KV heads (GQA) and 2x MLP expansion
+# - vocab size 1024, tied embeddings
+# - 16,384 train tokens per step at sequence length 256 for 40 iterations with a ~10 minute cap
+
+class Hyperparameters:
+ # Data paths are shard globs produced by the existing preprocessing pipeline.
+ data_path = os.environ.get("DATA_PATH", str(REPO_ROOT / "data/datasets/fineweb10B_sp1024"))
+ train_files = os.path.join(data_path, "fineweb_train_*.bin")
+ val_files = os.path.join(data_path, "fineweb_val_*.bin")
+ tokenizer_path = os.environ.get("TOKENIZER_PATH", str(REPO_ROOT / "data/tokenizers/fineweb_1024_bpe.model"))
+ run_id = os.environ.get("RUN_ID", "train")
+ seed = int(os.environ.get("SEED", 1337))
+
+ # Validation cadence and batch size. Validation uses the full fineweb_val split unless VAL_TOKEN_LIMIT truncates it. 
+ val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 262_144)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 10)) + val_token_limit = int(os.environ.get("VAL_TOKEN_LIMIT", 1_048_576)) + + # Training length. + iterations = int(os.environ.get("ITERATIONS", 40)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 1200)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 0)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 16_384)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 256)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + enable_compile = bool(int(os.environ.get("ENABLE_COMPILE", "0"))) + + # Model shape. + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers = int(os.environ.get("NUM_LAYERS", 9)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = int(os.environ.get("MLP_MULT", 2)) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + + # Optimizer hyperparameters. 
+ embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.03)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.02)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.02)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.0)) + + # Export / compression. + quant_scheme = os.environ.get("QUANT_SCHEME", "int5_odd") + compressor = os.environ.get("COMPRESSOR", "zlib") + compressor_level = int(os.environ.get("COMPRESSOR_LEVEL", 9)) + + # Entropy-aware rate regularization. + rate_lambda = float(os.environ.get("RATE_LAMBDA", 0.002)) + scale_lambda = float(os.environ.get("SCALE_LAMBDA", 0.0002)) + rate_temperature = float(os.environ.get("RATE_TEMPERATURE", 6.0)) + rate_warmup_start_fraction = float(os.environ.get("RATE_WARMUP_START_FRACTION", 0.1)) + rate_full_fraction = float(os.environ.get("RATE_FULL_FRACTION", 0.4)) + rate_target_tensors = os.environ.get("RATE_TARGET_TENSORS", "mlp_only") + + # Structured MLP. 
+ structured_mlp = bool(int(os.environ.get("STRUCTURED_MLP", "1"))) + structured_linear_type = os.environ.get("STRUCTURED_LINEAR_TYPE", "btt_mlp") + btt_rank = int(os.environ.get("BTT_RANK", 32)) + btt_in_shape = os.environ.get("BTT_IN_SHAPE", "") + btt_out_shape = os.environ.get("BTT_OUT_SHAPE", "") + btt_init = os.environ.get("BTT_INIT", "mup") + btt_init_gain = float(os.environ.get("BTT_INIT_GAIN", 1.0)) + compile_structured_mlp = bool(int(os.environ.get("COMPILE_STRUCTURED_MLP", "1"))) + +# ----------------------------- +# MUON OPTIMIZER +# ----------------------------- +# +# As borrowed from modded-nanogpt +# Background on Muon: https://kellerjordan.github.io/posts/muon/ + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + # Orthogonalize a 2D update matrix with a fast Newton-Schulz iteration. + # Muon uses this to normalize matrix-shaped gradients before applying them. + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X + + +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True): + super().__init__( + params, + dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov), + ) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + + for group in self.param_groups: + params = group["params"] + if not params: + continue + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + + total_params = 
sum(int(p.numel()) for p in params)
+            updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16)
+
+            curr = 0
+            for i, p in enumerate(params):
+                if i % world_size == rank and p.grad is not None:
+                    g = p.grad
+                    state = self.state[p]
+                    if "momentum_buffer" not in state:
+                        state["momentum_buffer"] = torch.zeros_like(g)
+                    buf = state["momentum_buffer"]
+                    buf.mul_(momentum).add_(g)
+                    if nesterov:
+                        g = g.add(buf, alpha=momentum)
+                    g = zeropower_via_newtonschulz5(g, steps=backend_steps)
+                    # Scale correction from Muon reference implementations.
+                    g *= max(1, g.size(0) / g.size(1)) ** 0.5
+                    updates_flat[curr : curr + p.numel()] = g.reshape(-1)
+                # Advance for every param (not just owned ones) so the flat offsets
+                # line up with the unpacking loop below on every rank.
+                curr += p.numel()
+
+            if distributed:
+                dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM)
+
+            curr = 0
+            for p in params:
+                g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype)
+                p.add_(g, alpha=-lr)
+                curr += p.numel()
+
+        return loss
+
+
+# -----------------------------
+# TOKENIZER-AGNOSTIC EVALUATION SETUP
+# -----------------------------
+#
+# It's common for small models to have a large fraction of their parameters in the embeddings, since the 2 * d_vocab * d_model input and output embedding weights can be gigantic.
+# Instead of locking in one tokenizer, we let you bring your own and compute our validation metrics as the average compression of the validation set.
+# We report BPB (bits-per-byte) instead of validation loss, so we need methods to count how many bytes of text each token in the tokenizer covers.
+# Note: Submissions that edit the tokenizer will be examined more carefully, since screwing this up might unjustly improve your score.
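+# The BPB arithmetic this setup relies on can be sketched numerically. A minimal,
+# self-contained example with made-up token/byte counts (not tied to any real tokenizer):

```python
import math

def bits_per_byte(val_loss_nats: float, total_tokens: float, total_bytes: float) -> float:
    # nats/token -> bits/token, then rescale by the tokenizer's tokens-per-byte ratio.
    bits_per_token = val_loss_nats / math.log(2.0)
    return bits_per_token * (total_tokens / total_bytes)

# A loss of ln(4) nats/token is exactly 2 bits/token; if each token covers
# 4 bytes of text on average, the model compresses at 0.5 bits per byte.
print(bits_per_byte(math.log(4.0), 1000, 4000))  # 0.5
```

+# A tokenizer with longer pieces (more bytes per token) lowers BPB for the same
+# per-token loss, which is exactly why the metric is computed over bytes.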
+ +def build_sentencepiece_luts( + sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device +) -> tuple[Tensor, Tensor, Tensor]: + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("▁"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) + + +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*. 
+ tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] + + +def eval_val( + args: Hyperparameters, + model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + grad_accum_steps: int, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, +) -> tuple[float, float]: + # Validation computes two metrics: + # - val_loss: token cross-entropy (natural log) + # - val_bpb: tokenizer-agnostic compression metric used by the challenge + local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) + if local_batch_tokens < args.train_seq_len: + raise ValueError( + "VAL_BATCH_SIZE must provide at least one sequence per rank; " + f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " + f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}" + ) + local_batch_seqs = local_batch_tokens // args.train_seq_len + total_seqs = (val_tokens.numel() - 1) // args.train_seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + + model.eval() + with torch.inference_mode(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * args.train_seq_len + raw_end = batch_seq_end * args.train_seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, args.train_seq_len) + y = local[1:].reshape(-1, args.train_seq_len) + 
maybe_mark_compile_step_begin(args.enable_compile or args.compile_structured_mlp)
+            with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True):
+                batch_loss = model(x, y).detach()
+            batch_token_count = float(y.numel())
+            val_loss_sum += batch_loss.to(torch.float64) * batch_token_count
+            val_token_count += batch_token_count
+            prev_ids = x.reshape(-1)
+            tgt_ids = y.reshape(-1)
+            token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16)
+            token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16)
+            val_byte_count += token_bytes.to(torch.float64).sum()
+
+    if dist.is_available() and dist.is_initialized():
+        dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM)
+        dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM)
+        dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM)
+
+    val_loss = val_loss_sum / val_token_count
+    bits_per_token = val_loss.item() / math.log(2.0)
+    tokens_per_byte = val_token_count.item() / val_byte_count.item()
+    model.train()
+    return float(val_loss.item()), float(bits_per_token * tokens_per_byte)
+
+# -----------------------------
+# POST-TRAINING QUANTIZATION
+# -----------------------------
+#
+# It's silly to export our model, which is trained in bf16 and fp32, at that same precision.
+# Instead, we get approximately the same model (with a small quality hit) by quantizing the weights (int8 or int5_odd) and zlib-compressing the result.
+# We can then decompress the model and run it at higher precision for evaluation, after coming in under the size limit.
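+# The int5_odd payload relies on base-5 packing: three codes in {0..4} fit in one
+# byte because 5**3 = 125 <= 256, i.e. ~2.67 bits per weight before zlib sees the
+# stream. A minimal numpy sketch of the round trip, mirroring the
+# pack_base5_codes / unpack_base5_codes helpers defined below (names here are
+# illustrative, not the real API):

```python
import numpy as np

def pack3(codes):
    # Pack three base-5 digits per byte (5**3 = 125 fits in uint8).
    pad = (-len(codes)) % 3
    c = np.pad(np.asarray(codes, dtype=np.uint8), (0, pad))
    g = c.reshape(-1, 3)
    return (g[:, 0] + 5 * g[:, 1] + 25 * g[:, 2]).astype(np.uint8), len(codes)

def unpack3(packed, n):
    # Recover base-5 digits least-significant first, then drop the zero padding.
    p = packed.astype(np.int16)
    digits = np.stack([(p // 5**i) % 5 for i in range(3)], axis=1)
    return digits.reshape(-1)[:n].astype(np.uint8)

codes = np.array([0, 4, 2, 3, 1], dtype=np.uint8)
packed, n = pack3(codes)
print(unpack3(packed, n).tolist())  # [0, 4, 2, 3, 1]
```

+# Because every byte value stays in 0..124 and the code histogram is skewed toward
+# the zero level by the rate regularizer, zlib can squeeze well below 8/3 bits/weight.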
+ +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights", + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 +INT5_ODD_LEVELS = (-2, -1, 0, 1, 2) + +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) + + +def maybe_mark_compile_step_begin(enabled: bool) -> None: + if not enabled: + return + mark_step = getattr(torch.compiler, "cudagraph_mark_step_begin", None) + if mark_step is not None: + mark_step() + + +def parse_factor_shape(spec: str, dim: int) -> tuple[int, int]: + if spec: + vals = tuple(int(v.strip()) for v in spec.split(",") if v.strip()) + if len(vals) != 2 or vals[0] * vals[1] != dim: + raise ValueError(f"Invalid factor shape '{spec}' for dim={dim}") + return vals + for a in range(int(math.isqrt(dim)), 0, -1): + if dim % a == 0: + return a, dim // a + return 1, dim + +def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: + if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): + return t.float().contiguous() + if t.dtype in {torch.float32, torch.bfloat16}: + passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") + return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() + return t + +def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + # Matrices get one scale per row, which usually tracks output-channel + # ranges much better than a single tensor-wide scale. 
+ clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) + q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() + return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + + # Vectors / scalars use a simpler per-tensor scale. + clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() + return q, scale + + +def pack_base5_codes(q: Tensor) -> tuple[bytes, int]: + vals = q.detach().to("cpu", dtype=torch.uint8).reshape(-1).numpy() + n_vals = int(vals.size) + pad = (3 - (n_vals % 3)) % 3 + if pad: + vals = np.pad(vals, (0, pad)) + groups = vals.reshape(-1, 3).astype(np.uint8) + packed = groups[:, 0] + 5 * groups[:, 1] + 25 * groups[:, 2] + return packed.tobytes(), n_vals + + +def unpack_base5_codes(data: bytes, n_vals: int) -> Tensor: + packed = np.frombuffer(data, dtype=np.uint8).astype(np.int16) + out = np.zeros((packed.size, 3), dtype=np.uint8) + for i in range(3): + out[:, i] = packed % 5 + packed //= 5 + flat = out.reshape(-1)[:n_vals] + return torch.from_numpy(flat.astype(np.uint8, copy=False)) + + +def quantize_float_tensor_int5_odd(t: Tensor) -> tuple[Tensor, Tensor]: + if t.ndim != 2: + raise ValueError("int5_odd quantization only supports 2D tensors") + t32 = t.float() + clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clip_abs = clip_abs.clamp_min(1e-6) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 
2.0).clamp_min(1.0 / 2.0e6) + q = torch.clamp(torch.round(clipped / scale[:, None]), -2, 2).to(torch.int8).contiguous() + return (q + 2).to(torch.uint8).contiguous(), scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + +def quantize_state_dict_int8(state_dict: dict[str, Tensor]): + # Single supported clean-script export format: + # - per-row int8 for 2D float tensors + # - per-tensor int8 for other float tensors + # - exact passthrough for non-floats + # - passthrough for small float tensors, stored as fp16 to save bytes + quantized: dict[str, Tensor] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + qmeta: dict[str, dict[str, object]] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + + # Small float tensors are cheap enough to keep directly. We still downcast + # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size. 
+ if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + + stats["num_float_tensors"] += 1 + q, s = quantize_float_tensor(t) + if s.ndim > 0: + qmeta[name] = {"scheme": "per_row", "axis": 0} + quantized[name] = q + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) + + obj: dict[str, object] = { + "__quant_format__": "int8_clean_per_row_v1", + "quantized": quantized, + "scales": scales, + "dtypes": dtypes, + "passthrough": passthrough, + } + if qmeta: + obj["qmeta"] = qmeta + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats + + +def quantize_state_dict_int5_odd(state_dict: dict[str, Tensor]): + packed: dict[str, bytes] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + shapes: dict[str, tuple[int, ...]] = {} + orig_shapes: dict[str, tuple[int, ...]] = {} + n_codes: dict[str, int] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + + quantize_this = t.ndim >= 2 and (t.numel() > INT8_KEEP_FLOAT_MAX_NUMEL or ".mlp." 
in name) + if not quantize_this: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + + stats["num_float_tensors"] += 1 + t2 = t.reshape(t.shape[0], -1) if t.ndim > 2 else t + q, s = quantize_float_tensor_int5_odd(t2) + packed_bytes, count = pack_base5_codes(q) + packed[name] = packed_bytes + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + shapes[name] = tuple(int(v) for v in t2.shape) + orig_shapes[name] = tuple(int(v) for v in t.shape) + n_codes[name] = count + stats["int8_payload_bytes"] += len(packed_bytes) + tensor_nbytes(s) + + obj: dict[str, object] = { + "__quant_format__": "int5_odd_clean_per_row_v1", + "packed": packed, + "scales": scales, + "dtypes": dtypes, + "shapes": shapes, + "orig_shapes": orig_shapes, + "n_codes": n_codes, + "passthrough": passthrough, + } + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats + +def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + qmeta = obj.get("qmeta", {}) + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, q in obj["quantized"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + s = obj["scales"][name] + if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: + s = s.to(dtype=torch.float32) + # Broadcast the saved row scale back across trailing dimensions. + out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() + else: + scale = float(s.item()) + out[name] = (q.float() * scale).to(dtype=dtype).contiguous() + for name, t in obj["passthrough"].items(): + # Restore small tensors, undoing the temporary fp16 storage cast if needed. 
+ out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out + + +def dequantize_state_dict_int5_odd(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, packed in obj["packed"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + shape = tuple(int(v) for v in obj["shapes"][name]) + orig_shape = tuple(int(v) for v in obj.get("orig_shapes", {}).get(name, shape)) + q = unpack_base5_codes(packed, int(obj["n_codes"][name])).to(torch.float32).reshape(shape) - 2.0 + s = obj["scales"][name].to(dtype=torch.float32) + deq = (q * s.view(shape[0], *([1] * (len(shape) - 1)))).to(dtype=dtype) + out[name] = deq.reshape(orig_shape).contiguous() + for name, t in obj["passthrough"].items(): + out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out + + +def quantize_state_dict(state_dict: dict[str, Tensor], quant_scheme: str): + if quant_scheme == "int8": + return quantize_state_dict_int8(state_dict) + if quant_scheme == "int5_odd": + return quantize_state_dict_int5_odd(state_dict) + raise ValueError(f"Unsupported QUANT_SCHEME={quant_scheme}") + + +def dequantize_state_dict(obj: dict[str, object]) -> dict[str, Tensor]: + quant_format = obj.get("__quant_format__") + if quant_format == "int8_clean_per_row_v1": + return dequantize_state_dict_int8(obj) + if quant_format == "int5_odd_clean_per_row_v1": + return dequantize_state_dict_int5_odd(obj) + raise ValueError(f"Unsupported quant format {quant_format}") + + +def compress_blob(data: bytes, compressor: str, level: int) -> bytes: + if compressor == "zlib": + return zlib.compress(data, level=max(0, 
min(level, 9))) + if compressor == "zstd": + if not _HAS_ZSTD: + raise RuntimeError("COMPRESSOR=zstd requires zstandard to be installed") + return zstd.ZstdCompressor(level=level).compress(data) + raise ValueError(f"Unsupported COMPRESSOR={compressor}") + + +def decompress_blob(data: bytes, compressor: str) -> bytes: + if compressor == "zlib": + return zlib.decompress(data) + if compressor == "zstd": + if not _HAS_ZSTD: + raise RuntimeError("COMPRESSOR=zstd requires zstandard to be installed") + return zstd.ZstdDecompressor().decompress(data) + raise ValueError(f"Unsupported COMPRESSOR={compressor}") + + +def should_rate_regularize_tensor(name: str, p: Tensor, target: str) -> bool: + if not p.requires_grad or not p.is_floating_point() or p.ndim < 2: + return False + if any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS): + return False + if target == "mlp_only": + return ".mlp." in name + if target == "all_matrices": + return p.numel() > INT8_KEEP_FLOAT_MAX_NUMEL or ".mlp." in name + raise ValueError(f"Unsupported RATE_TARGET_TENSORS={target}") + + +def estimate_tensor_rate_proxy(t: Tensor, temperature: float) -> tuple[Tensor, Tensor, Tensor]: + t2 = t.float() + if t2.ndim > 2: + t2 = t2.reshape(t2.shape[0], -1) + elif t2.ndim == 1: + t2 = t2.reshape(1, -1) + row_scale = t2.abs().mean(dim=1, keepdim=True).clamp_min(1e-6) + normalized = torch.clamp(t2 / row_scale, -2.5, 2.5) + centers = t2.new_tensor(INT5_ODD_LEVELS) + logits = -temperature * (normalized.unsqueeze(-1) - centers) ** 2 + probs = torch.softmax(logits, dim=-1) + hist = probs.mean(dim=(0, 1)) + entropy_nats = -(hist * (hist + 1e-8).log()).sum() + entropy_bits = entropy_nats / math.log(2.0) + scale_proxy = row_scale.mean() + est_bits = entropy_bits * float(t2.numel()) + return entropy_bits, scale_proxy, est_bits + + +# ----------------------------- +# DATA LOADING +# ----------------------------- + +def load_data_shard(file: Path) -> Tensor: + header_bytes = 256 * np.dtype(" None: + 
self.file_idx = (self.file_idx + 1) % len(self.files) + self.tokens = load_data_shard(self.files[self.file_idx]) + self.pos = 0 + + def take(self, n: int) -> Tensor: + chunks: list[Tensor] = [] + remaining = n + while remaining > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: + self._advance_file() + continue + k = min(remaining, avail) + chunks.append(self.tokens[self.pos : self.pos + k]) + self.pos += k + remaining -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) + + +class DistributedTokenLoader: + # Each call consumes a contiguous chunk from the shared token stream, then slices out + # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting. + def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank = rank + self.world_size = world_size + self.device = device + self.stream = TokenStream(pattern) + + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + chunk = self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start : start + per_rank_span].to(dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + +# ----------------------------- +# TRANSFORMER MODULES +# ----------------------------- + +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + + +class CastedLinear(nn.Linear): + # Keep weights in fp32 for optimizer/state quality, cast at matmul time for bf16 compute. 
+ def forward(self, x: Tensor) -> Tensor: + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, self.weight.to(x.dtype), bias) + + +class StructuredLinear(nn.Module): + # Lightweight TT-style structured matrix used only in the MLP stack. + def __init__( + self, + in_features: int, + out_features: int, + bias: bool, + linear_type: str, + btt_rank: int, + in_shape_spec: str, + out_shape_spec: str, + btt_init: str, + btt_init_gain: float, + ): + super().__init__() + self.in_features = in_features + self.out_features = out_features + self.linear_type = linear_type + self.rank = max(int(btt_rank), 1) + self.btt_init = btt_init + self.btt_init_gain = btt_init_gain + if linear_type == "dense": + self.impl = CastedLinear(in_features, out_features, bias=bias) + self.register_parameter("core1", None) + self.register_parameter("core2", None) + return + if linear_type != "btt_mlp": + raise ValueError(f"Unsupported STRUCTURED_LINEAR_TYPE={linear_type}") + self.impl = None + self.in_shape = parse_factor_shape(in_shape_spec, in_features) + self.out_shape = parse_factor_shape(out_shape_spec, out_features) + n1, n2 = self.in_shape + m1, m2 = self.out_shape + self.core1 = nn.Parameter(torch.empty(n1, m1, self.rank)) + self.core2 = nn.Parameter(torch.empty(self.rank, n2, m2)) + if bias: + self.bias = nn.Parameter(torch.zeros(out_features)) + else: + self.register_parameter("bias", None) + self.reset_parameters() + + def reset_parameters(self) -> None: + if self.impl is not None: + self.impl.reset_parameters() + return + n1, n2 = self.in_shape + if self.btt_init == "mup": + # muP-inspired scaling: keep the per-rank branch update stable while + # preserving O(1) output variance under the sum over TT ranks. 
+ core1_std = self.btt_init_gain / math.sqrt(max(n1, 1)) + core2_std = self.btt_init_gain / math.sqrt(max(self.rank * n2, 1)) + elif self.btt_init == "xavier": + std = self.btt_init_gain / math.sqrt(max(self.in_features, 1)) + core1_std = std / math.sqrt(self.rank) + core2_std = std + else: + raise ValueError(f"Unsupported BTT_INIT={self.btt_init}") + nn.init.normal_(self.core1, mean=0.0, std=core1_std) + nn.init.normal_(self.core2, mean=0.0, std=core2_std) + if self.bias is not None: + nn.init.zeros_(self.bias) + + def forward(self, x: Tensor) -> Tensor: + if self.impl is not None: + return self.impl(x) + x_dtype = x.dtype + orig_shape = x.shape[:-1] + x2 = x.reshape(-1, self.in_shape[0], self.in_shape[1]) + core1 = self.core1.to(dtype=x_dtype) + core2 = self.core2.to(dtype=x_dtype) + y = x2.new_zeros((x2.shape[0], self.out_shape[0], self.out_shape[1])) + for r in range(self.rank): + left = torch.einsum("buv,um->bmv", x2, core1[:, :, r]) + y = y + torch.einsum("bmv,vn->bmn", left, core2[r]) + y = y.reshape(*orig_shape, self.out_features) + if self.bias is not None: + y = y + self.bias.to(dtype=y.dtype) + return y + + +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + # Keep small/control parameters in fp32 even when the model body runs in bf16. + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() + + +class Rotary(nn.Module): + # Caches cos/sin tables per sequence length on the current device. 
+ def __init__(self, dim: int, base: float = 10000.0): + super().__init__() + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached != seq_len + or self._cos_cached.device != device + ): + t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) + freqs = torch.outer(t, self.inv_freq.to(device)) + self._cos_cached = freqs.cos()[None, None, :, :] + self._sin_cached = freqs.sin()[None, None, :, :] + self._seq_len_cached = seq_len + cos = self._cos_cached.to(dtype=dtype) + sin = self._sin_cached.to(dtype=dtype) + if not torch.is_inference_mode_enabled(): + cos = cos.clone() + sin = sin.clone() + return cos, sin + + +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + + +class CausalSelfAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + kv_dim = self.num_kv_heads * self.head_dim + self.c_q = CastedLinear(dim, dim, bias=False) + self.c_k = CastedLinear(dim, kv_dim, bias=False) + self.c_v = CastedLinear(dim, kv_dim, bias=False) + self.proj 
= CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rotary = Rotary(self.head_dim, base=rope_base) + + def forward(self, x: Tensor) -> Tensor: + bsz, seqlen, dim = x.shape + q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) + k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin) + k = apply_rotary_emb(k, cos, sin) + q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] + y = F.scaled_dot_product_attention( + q, + k, + v, + attn_mask=None, + is_causal=True, + enable_gqa=(self.num_kv_heads != self.num_heads), + ) + y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) + return self.proj(y) + + +class MLP(nn.Module): + # relu^2 MLP from the original modded-nanogpt setup + def __init__( + self, + dim: int, + mlp_mult: int, + structured_mlp: bool, + structured_linear_type: str, + btt_rank: int, + btt_in_shape: str, + btt_out_shape: str, + btt_init: str, + btt_init_gain: float, + compile_structured_mlp: bool, + ): + super().__init__() + hidden = mlp_mult * dim + if structured_mlp: + self.fc = StructuredLinear( + dim, + hidden, + bias=False, + linear_type=structured_linear_type, + btt_rank=btt_rank, + in_shape_spec=btt_in_shape, + out_shape_spec=btt_out_shape, + btt_init=btt_init, + btt_init_gain=btt_init_gain, + ) + self.proj = StructuredLinear( + hidden, + dim, + bias=False, + linear_type=structured_linear_type, + btt_rank=btt_rank, + in_shape_spec=btt_out_shape, + out_shape_spec=btt_in_shape, + btt_init=btt_init, + btt_init_gain=btt_init_gain, + ) + if compile_structured_mlp: + compile_options = {"triton.cudagraphs": False} + self.fc = 
torch.compile( + self.fc, + mode="reduce-overhead", + options=compile_options, + fullgraph=False, + dynamic=False, + ) + self.proj = torch.compile( + self.proj, + mode="reduce-overhead", + options=compile_options, + fullgraph=False, + dynamic=False, + ) + else: + self.fc = CastedLinear(dim, hidden, bias=False) + self.proj = CastedLinear(hidden, dim, bias=False) + self.proj._zero_init = True + + def forward(self, x: Tensor) -> Tensor: + x = torch.relu(self.fc(x)) + return self.proj(x.square()) + + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + rope_base: float, + qk_gain_init: float, + structured_mlp: bool, + structured_linear_type: str, + btt_rank: int, + btt_in_shape: str, + btt_out_shape: str, + btt_init: str, + btt_init_gain: float, + compile_structured_mlp: bool, + ): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) + self.mlp = MLP( + dim, + mlp_mult, + structured_mlp, + structured_linear_type, + btt_rank, + btt_in_shape, + btt_out_shape, + btt_init, + btt_init_gain, + compile_structured_mlp, + ) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + + def forward(self, x: Tensor, x0: Tensor) -> Tensor: + mix = self.resid_mix.to(dtype=x.dtype) + x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + attn_out = self.attn(self.attn_norm(x)) + x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out + x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) + return x + + +class GPT(nn.Module): + def __init__( + self, + vocab_size: int, + num_layers: int, + model_dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + tie_embeddings: bool, + 
tied_embed_init_std: float, + logit_softcap: float, + rope_base: float, + qk_gain_init: float, + structured_mlp: bool, + structured_linear_type: str, + btt_rank: int, + btt_in_shape: str, + btt_out_shape: str, + btt_init: str, + btt_init_gain: float, + compile_structured_mlp: bool, + rate_target_tensors: str, + rate_temperature: float, + ): + super().__init__() + if logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") + self.tie_embeddings = tie_embeddings + self.tied_embed_init_std = tied_embed_init_std + self.logit_softcap = logit_softcap + self.rate_target_tensors = rate_target_tensors + self.rate_temperature = rate_temperature + self.tok_emb = nn.Embedding(vocab_size, model_dim) + self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + self.blocks = nn.ModuleList( + [ + Block( + model_dim, + num_heads, + num_kv_heads, + mlp_mult, + rope_base, + qk_gain_init, + structured_mlp, + structured_linear_type, + btt_rank, + btt_in_shape, + btt_out_shape, + btt_init, + btt_init_gain, + compile_structured_mlp, + ) + for i in range(num_layers) + ] + ) + self.final_norm = RMSNorm() + self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) + if self.lm_head is not None: + self.lm_head._zero_init = True + self._init_weights() + + def _init_weights(self) -> None: + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + for module in self.modules(): + if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + if isinstance(module, StructuredLinear) and getattr(module, "_zero_init", False): + if module.impl is not None: + nn.init.zeros_(module.impl.weight) + else: + 
nn.init.zeros_(module.core2) + + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + x = self.tok_emb(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x0 = x + skips: list[Tensor] = [] + + # First half stores skips; second half reuses them in reverse order. + for i in range(self.num_encoder_layers): + x = self.blocks[i](x, x0) + skips.append(x) + for i in range(self.num_decoder_layers): + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + x = self.blocks[self.num_encoder_layers + i](x, x0) + + x = self.final_norm(x).reshape(-1, x.size(-1)) + targets = target_ids.reshape(-1) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(x) + logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + return F.cross_entropy(logits.float(), targets, reduction="mean") + + def rate_regularizer(self) -> tuple[Tensor, Tensor, Tensor]: + total_symbols = 0.0 + weighted_rate = None + weighted_scale = None + estimated_bits = None + for name, param in self.named_parameters(): + if not should_rate_regularize_tensor(name, param, self.rate_target_tensors): + continue + rate_proxy, scale_proxy, tensor_bits = estimate_tensor_rate_proxy(param, self.rate_temperature) + weight = float(param.numel()) + total_symbols += weight + if weighted_rate is None: + weighted_rate = rate_proxy * weight + weighted_scale = scale_proxy * weight + estimated_bits = tensor_bits + else: + weighted_rate = weighted_rate + rate_proxy * weight + weighted_scale = weighted_scale + scale_proxy * weight + estimated_bits = estimated_bits + tensor_bits + if weighted_rate is None or weighted_scale is None or estimated_bits is None: + zero = self.tok_emb.weight.new_zeros(()) + return zero, zero, zero + denom = max(total_symbols, 1.0) + return weighted_rate / denom, weighted_scale / denom, 
estimated_bits + + +# ----------------------------- +# TRAINING +# ----------------------------- + +def main() -> None: + global zeropower_via_newtonschulz5 + + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + if args.enable_compile: + zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) + + # ----------------------------- + # DISTRIBUTED + CUDA SETUP + # ----------------------------- + + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if 8 % world_size != 0: + raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") + grad_accum_steps = 8 // world_size + grad_scale = 1.0 / grad_accum_steps + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 0 + + # Fast math knobs + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + + logfile = None + if master_process: + if args.run_id == "train": + logfile = "train.log" + else: + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" + print(logfile) + + def log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + + log0(code, console=False) + 
log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch {torch.__version__}", console=False) + log0( + subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, + console=False, + ) + log0("=" * 100, console=False) + + # ----------------------------- + # TOKENIZER + VALIDATION METRIC SETUP + # ----------------------------- + + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) + if args.val_token_limit > 0: + usable = min(((args.val_token_limit) // args.train_seq_len) * args.train_seq_len, val_tokens.numel() - 1) + if usable <= 0: + raise ValueError(f"VAL_TOKEN_LIMIT={args.val_token_limit} is too small for TRAIN_SEQ_LEN={args.train_seq_len}") + val_tokens = val_tokens[: usable + 1].contiguous() + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + if args.val_token_limit > 0: + log0(f"val_token_limit:{args.val_token_limit}") + + # ----------------------------- + # MODEL + OPTIMIZER SETUP + # 
----------------------------- + + base_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, + structured_mlp=args.structured_mlp, + structured_linear_type=args.structured_linear_type if args.structured_mlp else "dense", + btt_rank=args.btt_rank, + btt_in_shape=args.btt_in_shape, + btt_out_shape=args.btt_out_shape, + btt_init=args.btt_init, + btt_init_gain=args.btt_init_gain, + compile_structured_mlp=args.compile_structured_mlp, + rate_target_tensors=args.rate_target_tensors, + rate_temperature=args.rate_temperature, + ).to(device).bfloat16() + for module in base_model.modules(): + if isinstance(module, (CastedLinear, StructuredLinear)): + module.float() + restore_low_dim_params_to_fp32(base_model) + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) if args.enable_compile else base_model + model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + + # Optimizer split: + # - token embedding (Adam) uses EMBED_LR + # - untied lm_head (Adam) uses HEAD_LR + # - matrix params in transformer blocks use MATRIX_LR via Muon + # - vectors/scalars use SCALAR_LR via Adam + block_named_params = list(base_model.blocks.named_parameters()) + matrix_params = [ + p + for name, p in block_named_params + if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + scalar_params = [ + p + for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) + token_lr = args.tied_embed_lr if args.tie_embeddings 
else args.embed_lr + optimizer_tok = torch.optim.Adam( + [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizer_muon = Muon( + matrix_params, + lr=args.matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + ) + for group in optimizer_muon.param_groups: + group["base_lr"] = args.matrix_lr + optimizer_scalar = torch.optim.Adam( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers.insert(1, optimizer_head) + + n_params = sum(p.numel() for p in base_model.parameters()) + log0(f"model_params:{n_params}") + log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") + log0(f"enable_compile:{args.enable_compile}") + log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") + log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") + log0( + f"quant_scheme:{args.quant_scheme} compressor:{args.compressor} level:{args.compressor_level} " + f"structured_mlp:{args.structured_mlp} structured_linear_type:{args.structured_linear_type} " + f"btt_rank:{args.btt_rank} btt_init:{args.btt_init} compile_structured_mlp:{args.compile_structured_mlp}" + ) + log0( + f"rate_lambda:{args.rate_lambda} scale_lambda:{args.scale_lambda} " + f"rate_warmup_start_fraction:{args.rate_warmup_start_fraction} rate_full_fraction:{args.rate_full_fraction} " + f"rate_target_tensors:{args.rate_target_tensors}" + ) + log0( + f"tie_embeddings:{args.tie_embeddings} 
embed_lr:{token_lr} " + f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " + f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" + ) + log0( + f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " + f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " + f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log0(f"seed:{args.seed}") + + # ----------------------------- + # DATA LOADER & MODEL WARMUP + # ----------------------------- + + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + + def lr_mul(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + + def rate_schedule(step: int) -> float: + if args.rate_lambda <= 0.0 and args.scale_lambda <= 0.0: + return 0.0 + if args.iterations <= 0: + return 1.0 + frac = step / max(args.iterations, 1) + start = args.rate_warmup_start_fraction + full = max(args.rate_full_fraction, start + 1e-6) + if frac <= start: + return 0.0 + if frac >= full: + return 1.0 + return (frac - start) / (full - start) + + # Warmup primes the compiled forward/backward/optimizer paths, then we restore the + # initial weights/optimizer state so measured training starts from the true init. 
+ if args.warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(args.warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + maybe_mark_compile_step_begin(args.enable_compile or args.compile_structured_mlp) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + for opt in optimizers: + opt.step() + zero_grad_all() + if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: + log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + if distributed: + model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + # ----------------------------- + # MAIN TRAINING LOOP + # ----------------------------- + + training_time_ms = 0.0 + stop_after_step: int | None = None + torch.cuda.synchronize() + t0 = time.perf_counter() + + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + 
has_leading_space_lut, + is_boundary_token_lut, + ) + log0( + f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" + ) + break + + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + zero_grad_all() + train_loss = torch.zeros((), device=device) + train_ce_loss = torch.zeros((), device=device) + train_rate_proxy = torch.zeros((), device=device) + train_scale_proxy = torch.zeros((), device=device) + train_est_bits = torch.zeros((), device=device) + rate_mul = rate_schedule(step) + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + maybe_mark_compile_step_begin(args.enable_compile or args.compile_structured_mlp) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + ce_loss = model(x, y) + total_loss = ce_loss + if rate_mul > 0.0: + rate_proxy, scale_proxy, est_bits = base_model.rate_regularizer() + total_loss = total_loss + rate_mul * (args.rate_lambda * rate_proxy + args.scale_lambda * scale_proxy) + train_rate_proxy += rate_proxy.detach() + train_scale_proxy += scale_proxy.detach() + train_est_bits += est_bits.detach() + train_ce_loss += ce_loss.detach() + train_loss += total_loss.detach() + (total_loss * grad_scale).backward() + train_loss /= grad_accum_steps + train_ce_loss /= grad_accum_steps + train_rate_proxy /= grad_accum_steps + train_scale_proxy /= grad_accum_steps + train_est_bits /= grad_accum_steps + + frac = min(step / 
args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + for opt in optimizers: + opt.step() + zero_grad_all() + + step += 1 + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + should_log_train = ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + log0( + f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} ce_loss:{train_ce_loss.item():.4f} " + f"rate_mul:{rate_mul:.3f} rate_proxy_bits:{train_rate_proxy.item():.4f} " + f"scale_proxy:{train_scale_proxy.item():.6f} est_model_bits:{train_est_bits.item():.0f} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" + ) + + # Needed to sync whether we've reached the wallclock cap. 
+ reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms
+ if distributed and max_wallclock_ms is not None:
+ reached_cap_tensor = torch.tensor(int(reached_cap), device=device)
+ dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX)
+ reached_cap = bool(reached_cap_tensor.item())
+ if stop_after_step is None and reached_cap:
+ stop_after_step = step
+
+ log0(
+ f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB "
+ f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB"
+ )
+
+ # -----------------------------
+ # SERIALIZATION + ROUNDTRIP VALIDATION
+ # -----------------------------
+ # Save the raw state (useful for debugging/loading in PyTorch directly), then always produce
+ # the compressed quantized artifact (QUANT_SCHEME + COMPRESSOR, int5_odd + zlib by default)
+ # and validate the round-tripped weights.
+
+ if master_process:
+ torch.save(base_model.state_dict(), "final_model.pt")
+ model_bytes = os.path.getsize("final_model.pt")
+ code_bytes = len(code.encode("utf-8"))
+ log0(f"Serialized model: {model_bytes} bytes")
+ log0(f"Code size: {code_bytes} bytes")
+ log0(f"Total submission size: {model_bytes + code_bytes} bytes")
+
+ quant_obj, quant_stats = quantize_state_dict(base_model.state_dict(), args.quant_scheme)
+ quant_buf = io.BytesIO()
+ torch.save(quant_obj, quant_buf)
+ quant_raw = quant_buf.getvalue()
+ quant_blob = compress_blob(quant_raw, args.compressor, args.compressor_level)
+ quant_raw_bytes = len(quant_raw)
+ quant_filename = f"final_model.{args.quant_scheme}.{args.compressor}.ptc"
+ if master_process:
+ with open(quant_filename, "wb") as f:
+ f.write(quant_blob)
+ quant_file_bytes = os.path.getsize(quant_filename)
+ code_bytes = len(code.encode("utf-8"))
+ ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1)
+ log0(
+ f"Serialized model {args.quant_scheme}+{args.compressor}: {quant_file_bytes} bytes "
+ f"(payload:{quant_stats['int8_payload_bytes']} raw_torch:{quant_raw_bytes} 
payload_ratio:{ratio:.2f}x)" + ) + log0(f"Total submission size {args.quant_scheme}+{args.compressor}: {quant_file_bytes + code_bytes} bytes") + + if distributed: + dist.barrier() + with open(quant_filename, "rb") as f: + quant_blob_disk = f.read() + quant_state = torch.load(io.BytesIO(decompress_blob(quant_blob_disk, args.compressor)), map_location="cpu") + base_model.load_state_dict(dequantize_state_dict(quant_state), strict=True) + torch.cuda.synchronize() + t_qeval = time.perf_counter() + q_val_loss, q_val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"final_{args.quant_scheme}_{args.compressor}_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" + ) + log0(f"final_{args.quant_scheme}_{args.compressor}_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + + if distributed: + dist.destroy_process_group() + + +if __name__ == "__main__": + main() + +==================================================================================================== +Running Python 3.10.12 (main, Mar 3 2026, 11:56:32) [GCC 11.4.0] +Running PyTorch 2.11.0+cu130 +Tue Mar 31 12:15:23 2026 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 590.48.01 Driver Version: 590.48.01 CUDA Version: 13.1 | ++-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. 
|
+|=========================================+========================+======================|
+| 0 NVIDIA GeForce RTX 3060 Off | 00000000:01:00.0 On | N/A |
+| 30% 33C P8 19W / 170W | 242MiB / 12288MiB | 8% Default |
+| | | N/A |
++-----------------------------------------+------------------------+----------------------+
+
++-----------------------------------------------------------------------------------------+
+| Processes: |
+| GPU GI CI PID Type Process name GPU Memory |
+| ID ID Usage |
+|=========================================================================================|
+| 0 N/A N/A 985 G /usr/lib/xorg/Xorg 117MiB |
+| 0 N/A N/A 4638 C python3 104MiB |
++-----------------------------------------------------------------------------------------+
+
+====================================================================================================
+val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=/home/hz/Desktop/parameter-golf-main/data/tokenizers/fineweb_1024_bpe.model
+train_loader:dataset:fineweb10B_sp1024 train_shards:1
+val_loader:shards pattern=/home/hz/Desktop/parameter-golf-main/data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:65536
+val_token_limit:65536
+"""
+The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good launching-off points for new participants, not SOTA configs. We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder.
+
+Hard stop: to keep these scripts readable for newcomers, let's make sure `train_gpt.py` and `train_gpt_mlx.py` are never longer than 1500 lines. 
+""" + +from __future__ import annotations + +import copy +import glob +import io +import math +import os +import random +import subprocess +import sys +import time +import uuid +import zlib +from pathlib import Path + +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP + +try: + import zstandard as zstd + + _HAS_ZSTD = True +except ImportError: + zstd = None + _HAS_ZSTD = False + +REPO_ROOT = Path(__file__).resolve().parents[3] + +# ----------------------------- +# HYPERPARAMETERS +# ----------------------------- +# Default Simple Baseline run: +# - 9 transformer blocks at width 512 +# - 8 attention heads with 4 KV heads (GQA) and 2x MLP expansion +# - vocab size 1024, sequence length 1024, tied embeddings +# - 524,288 train tokens per step for 20,000 iterations with a ~10 minute cap + +class Hyperparameters: + # Data paths are shard globs produced by the existing preprocessing pipeline. + data_path = os.environ.get("DATA_PATH", str(REPO_ROOT / "data/datasets/fineweb10B_sp1024")) + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", str(REPO_ROOT / "data/tokenizers/fineweb_1024_bpe.model")) + run_id = os.environ.get("RUN_ID", "train") + seed = int(os.environ.get("SEED", 1337)) + + # Validation cadence and batch size. Validation always uses the full fineweb_val split. + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 262_144)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 10)) + val_token_limit = int(os.environ.get("VAL_TOKEN_LIMIT", 1_048_576)) + + # Training length. 
+ iterations = int(os.environ.get("ITERATIONS", 40)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 1200)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 0)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 16_384)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 256)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + enable_compile = bool(int(os.environ.get("ENABLE_COMPILE", "0"))) + + # Model shape. + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers = int(os.environ.get("NUM_LAYERS", 9)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = int(os.environ.get("MLP_MULT", 2)) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + + # Optimizer hyperparameters. + embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.03)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.02)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.02)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.0)) + + # Export / compression. 
+ quant_scheme = os.environ.get("QUANT_SCHEME", "int5_odd") + compressor = os.environ.get("COMPRESSOR", "zlib") + compressor_level = int(os.environ.get("COMPRESSOR_LEVEL", 9)) + + # Entropy-aware rate regularization. + rate_lambda = float(os.environ.get("RATE_LAMBDA", 0.002)) + scale_lambda = float(os.environ.get("SCALE_LAMBDA", 0.0002)) + rate_temperature = float(os.environ.get("RATE_TEMPERATURE", 6.0)) + rate_warmup_start_fraction = float(os.environ.get("RATE_WARMUP_START_FRACTION", 0.1)) + rate_full_fraction = float(os.environ.get("RATE_FULL_FRACTION", 0.4)) + rate_target_tensors = os.environ.get("RATE_TARGET_TENSORS", "mlp_only") + + # Structured MLP. + structured_mlp = bool(int(os.environ.get("STRUCTURED_MLP", "1"))) + structured_linear_type = os.environ.get("STRUCTURED_LINEAR_TYPE", "btt_mlp") + btt_rank = int(os.environ.get("BTT_RANK", 32)) + btt_in_shape = os.environ.get("BTT_IN_SHAPE", "") + btt_out_shape = os.environ.get("BTT_OUT_SHAPE", "") + btt_init = os.environ.get("BTT_INIT", "mup") + btt_init_gain = float(os.environ.get("BTT_INIT_GAIN", 1.0)) + compile_structured_mlp = bool(int(os.environ.get("COMPILE_STRUCTURED_MLP", "1"))) + +# ----------------------------- +# MUON OPTIMIZER +# ----------------------------- +# +# As borrowed from modded-nanogpt +# Background on Muon: https://kellerjordan.github.io/posts/muon/ + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + # Orthogonalize a 2D update matrix with a fast Newton-Schulz iteration. + # Muon uses this to normalize matrix-shaped gradients before applying them. 
+ a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X + + +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True): + super().__init__( + params, + dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov), + ) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + + for group in self.param_groups: + params = group["params"] + if not params: + continue + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + + total_params = sum(int(p.numel()) for p in params) + updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) + + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: + g = g.add(buf, alpha=momentum) + g = zeropower_via_newtonschulz5(g, steps=backend_steps) + # Scale correction from Muon reference implementations. 
+ g *= max(1, g.size(0) / g.size(1)) ** 0.5
+ updates_flat[curr : curr + p.numel()] = g.reshape(-1)
+ curr += p.numel()
+
+ if distributed:
+ dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM)
+
+ curr = 0
+ for p in params:
+ g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype)
+ p.add_(g, alpha=-lr)
+ curr += p.numel()
+
+ return loss
+
+
+# -----------------------------
+# TOKENIZER-AGNOSTIC EVALUATION SETUP
+# -----------------------------
+#
+# It's common for small models to have a large fraction of their parameters in embeddings, since the two d_model x d_vocab matrices can be gigantic.
+# Instead of locking the tokenizer, we let you bring your own and compute our validation metrics on the average compression of the validation set.
+# We report BPB (bits-per-byte) instead of validation loss, so we need a way to count the number of bytes each token contributes.
+# Note: Submissions that edit the tokenizer will be examined more carefully, since screwing this up might unjustly improve your score. 
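To make the BPB bookkeeping above concrete, here is a small standalone sketch with made-up values (not from a real run); `token_bytes` plays the role of the per-token byte counts that `build_sentencepiece_luts` below derives from the tokenizer:

```python
import math

# Hypothetical per-token values: the byte count each predicted token
# contributes to the detokenized text, and its cross-entropy in nats.
token_bytes = [3, 5, 1, 4]
token_nll_nats = [2.1, 3.4, 0.7, 2.9]

total_bits = sum(token_nll_nats) / math.log(2)  # convert nats -> bits
total_bytes = sum(token_bytes)
bpb = total_bits / total_bytes                  # bits-per-byte, ~1.01 here
print(f"bpb={bpb:.4f}")
```

Because the denominator is bytes rather than tokens, the metric stays comparable across submissions that swap in different tokenizers.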
+ +def build_sentencepiece_luts( + sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device +) -> tuple[Tensor, Tensor, Tensor]: + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("▁"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) + + +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*. 
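+    # Shape math (illustrative numbers): with 10_001 concatenated tokens and seq_len=1024,
+    # usable = ((10_001 - 1) // 1024) * 1024 = 9216; we keep usable + 1 = 9217 tokens so the
+    # shifted (input, target) pairs cover exactly 9 full sequences.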
+ tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] + + +def eval_val( + args: Hyperparameters, + model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + grad_accum_steps: int, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, +) -> tuple[float, float]: + # Validation computes two metrics: + # - val_loss: token cross-entropy (natural log) + # - val_bpb: tokenizer-agnostic compression metric used by the challenge + local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) + if local_batch_tokens < args.train_seq_len: + raise ValueError( + "VAL_BATCH_SIZE must provide at least one sequence per rank; " + f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " + f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}" + ) + local_batch_seqs = local_batch_tokens // args.train_seq_len + total_seqs = (val_tokens.numel() - 1) // args.train_seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + + model.eval() + with torch.inference_mode(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * args.train_seq_len + raw_end = batch_seq_end * args.train_seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, args.train_seq_len) + y = local[1:].reshape(-1, args.train_seq_len) + 
maybe_mark_compile_step_begin(args.enable_compile or args.compile_structured_mlp)
+            with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True):
+                batch_loss = model(x, y).detach()
+            batch_token_count = float(y.numel())
+            val_loss_sum += batch_loss.to(torch.float64) * batch_token_count
+            val_token_count += batch_token_count
+            prev_ids = x.reshape(-1)
+            tgt_ids = y.reshape(-1)
+            token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16)
+            token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16)
+            val_byte_count += token_bytes.to(torch.float64).sum()
+
+    if dist.is_available() and dist.is_initialized():
+        dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM)
+        dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM)
+        dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM)
+
+    val_loss = val_loss_sum / val_token_count
+    bits_per_token = val_loss.item() / math.log(2.0)
+    tokens_per_byte = val_token_count.item() / val_byte_count.item()
+    model.train()
+    return float(val_loss.item()), float(bits_per_token * tokens_per_byte)
+
+# -----------------------------
+# POST-TRAINING QUANTIZATION
+# -----------------------------
+#
+# It's silly to export our model, which is trained in bf16 and fp32, at that same precision.
+# Instead, we get approximately the same model (with a small quality hit) by quantizing it to int8 and zlib-compressing the result.
+# We can then decompress the model and run it at higher precision for evaluation, after coming in under the size limit.
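+# Worked example (illustrative numbers): per-row int8 stores round(w / scale) with
+# scale = clip_abs / 127. A row with clip_abs = 0.254 gets scale = 0.002, so w = 0.100
+# becomes code 50 and dequantizes back to 50 * 0.002 = 0.100; outliers beyond the
+# clip percentile saturate at codes ±127.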
+ +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights", + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 +INT5_ODD_LEVELS = (-2, -1, 0, 1, 2) + +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) + + +def maybe_mark_compile_step_begin(enabled: bool) -> None: + if not enabled: + return + mark_step = getattr(torch.compiler, "cudagraph_mark_step_begin", None) + if mark_step is not None: + mark_step() + + +def parse_factor_shape(spec: str, dim: int) -> tuple[int, int]: + if spec: + vals = tuple(int(v.strip()) for v in spec.split(",") if v.strip()) + if len(vals) != 2 or vals[0] * vals[1] != dim: + raise ValueError(f"Invalid factor shape '{spec}' for dim={dim}") + return vals + for a in range(int(math.isqrt(dim)), 0, -1): + if dim % a == 0: + return a, dim // a + return 1, dim + +def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: + if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): + return t.float().contiguous() + if t.dtype in {torch.float32, torch.bfloat16}: + passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") + return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() + return t + +def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + # Matrices get one scale per row, which usually tracks output-channel + # ranges much better than a single tensor-wide scale. 
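+        # Illustration: if row A spans |w| <= 0.5 but row B only |w| <= 0.01, one
+        # tensor-wide scale of 0.5 / 127 would leave row B only about +/-3 usable codes,
+        # while a per-row scale of 0.01 / 127 keeps all 255 levels for row B.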
+ clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) + q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() + return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + + # Vectors / scalars use a simpler per-tensor scale. + clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() + return q, scale + + +def pack_base5_codes(q: Tensor) -> tuple[bytes, int]: + vals = q.detach().to("cpu", dtype=torch.uint8).reshape(-1).numpy() + n_vals = int(vals.size) + pad = (3 - (n_vals % 3)) % 3 + if pad: + vals = np.pad(vals, (0, pad)) + groups = vals.reshape(-1, 3).astype(np.uint8) + packed = groups[:, 0] + 5 * groups[:, 1] + 25 * groups[:, 2] + return packed.tobytes(), n_vals + + +def unpack_base5_codes(data: bytes, n_vals: int) -> Tensor: + packed = np.frombuffer(data, dtype=np.uint8).astype(np.int16) + out = np.zeros((packed.size, 3), dtype=np.uint8) + for i in range(3): + out[:, i] = packed % 5 + packed //= 5 + flat = out.reshape(-1)[:n_vals] + return torch.from_numpy(flat.astype(np.uint8, copy=False)) + + +def quantize_float_tensor_int5_odd(t: Tensor) -> tuple[Tensor, Tensor]: + if t.ndim != 2: + raise ValueError("int5_odd quantization only supports 2D tensors") + t32 = t.float() + clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clip_abs = clip_abs.clamp_min(1e-6) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 
2.0).clamp_min(1.0 / 2.0e6) + q = torch.clamp(torch.round(clipped / scale[:, None]), -2, 2).to(torch.int8).contiguous() + return (q + 2).to(torch.uint8).contiguous(), scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + +def quantize_state_dict_int8(state_dict: dict[str, Tensor]): + # Single supported clean-script export format: + # - per-row int8 for 2D float tensors + # - per-tensor int8 for other float tensors + # - exact passthrough for non-floats + # - passthrough for small float tensors, stored as fp16 to save bytes + quantized: dict[str, Tensor] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + qmeta: dict[str, dict[str, object]] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + + # Small float tensors are cheap enough to keep directly. We still downcast + # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size. 
+ if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + + stats["num_float_tensors"] += 1 + q, s = quantize_float_tensor(t) + if s.ndim > 0: + qmeta[name] = {"scheme": "per_row", "axis": 0} + quantized[name] = q + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) + + obj: dict[str, object] = { + "__quant_format__": "int8_clean_per_row_v1", + "quantized": quantized, + "scales": scales, + "dtypes": dtypes, + "passthrough": passthrough, + } + if qmeta: + obj["qmeta"] = qmeta + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats + + +def quantize_state_dict_int5_odd(state_dict: dict[str, Tensor]): + packed: dict[str, bytes] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + shapes: dict[str, tuple[int, ...]] = {} + orig_shapes: dict[str, tuple[int, ...]] = {} + n_codes: dict[str, int] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + + quantize_this = t.ndim >= 2 and (t.numel() > INT8_KEEP_FLOAT_MAX_NUMEL or ".mlp." 
in name) + if not quantize_this: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + + stats["num_float_tensors"] += 1 + t2 = t.reshape(t.shape[0], -1) if t.ndim > 2 else t + q, s = quantize_float_tensor_int5_odd(t2) + packed_bytes, count = pack_base5_codes(q) + packed[name] = packed_bytes + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + shapes[name] = tuple(int(v) for v in t2.shape) + orig_shapes[name] = tuple(int(v) for v in t.shape) + n_codes[name] = count + stats["int8_payload_bytes"] += len(packed_bytes) + tensor_nbytes(s) + + obj: dict[str, object] = { + "__quant_format__": "int5_odd_clean_per_row_v1", + "packed": packed, + "scales": scales, + "dtypes": dtypes, + "shapes": shapes, + "orig_shapes": orig_shapes, + "n_codes": n_codes, + "passthrough": passthrough, + } + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats + +def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + qmeta = obj.get("qmeta", {}) + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, q in obj["quantized"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + s = obj["scales"][name] + if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: + s = s.to(dtype=torch.float32) + # Broadcast the saved row scale back across trailing dimensions. + out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() + else: + scale = float(s.item()) + out[name] = (q.float() * scale).to(dtype=dtype).contiguous() + for name, t in obj["passthrough"].items(): + # Restore small tensors, undoing the temporary fp16 storage cast if needed. 
+ out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out + + +def dequantize_state_dict_int5_odd(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, packed in obj["packed"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + shape = tuple(int(v) for v in obj["shapes"][name]) + orig_shape = tuple(int(v) for v in obj.get("orig_shapes", {}).get(name, shape)) + q = unpack_base5_codes(packed, int(obj["n_codes"][name])).to(torch.float32).reshape(shape) - 2.0 + s = obj["scales"][name].to(dtype=torch.float32) + deq = (q * s.view(shape[0], *([1] * (len(shape) - 1)))).to(dtype=dtype) + out[name] = deq.reshape(orig_shape).contiguous() + for name, t in obj["passthrough"].items(): + out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out + + +def quantize_state_dict(state_dict: dict[str, Tensor], quant_scheme: str): + if quant_scheme == "int8": + return quantize_state_dict_int8(state_dict) + if quant_scheme == "int5_odd": + return quantize_state_dict_int5_odd(state_dict) + raise ValueError(f"Unsupported QUANT_SCHEME={quant_scheme}") + + +def dequantize_state_dict(obj: dict[str, object]) -> dict[str, Tensor]: + quant_format = obj.get("__quant_format__") + if quant_format == "int8_clean_per_row_v1": + return dequantize_state_dict_int8(obj) + if quant_format == "int5_odd_clean_per_row_v1": + return dequantize_state_dict_int5_odd(obj) + raise ValueError(f"Unsupported quant format {quant_format}") + + +def compress_blob(data: bytes, compressor: str, level: int) -> bytes: + if compressor == "zlib": + return zlib.compress(data, level=max(0, 
min(level, 9)))
+    if compressor == "zstd":
+        if not _HAS_ZSTD:
+            raise RuntimeError("COMPRESSOR=zstd requires zstandard to be installed")
+        return zstd.ZstdCompressor(level=level).compress(data)
+    raise ValueError(f"Unsupported COMPRESSOR={compressor}")
+
+
+def decompress_blob(data: bytes, compressor: str) -> bytes:
+    if compressor == "zlib":
+        return zlib.decompress(data)
+    if compressor == "zstd":
+        if not _HAS_ZSTD:
+            raise RuntimeError("COMPRESSOR=zstd requires zstandard to be installed")
+        return zstd.ZstdDecompressor().decompress(data)
+    raise ValueError(f"Unsupported COMPRESSOR={compressor}")
+
+
+def should_rate_regularize_tensor(name: str, p: Tensor, target: str) -> bool:
+    if not p.requires_grad or not p.is_floating_point() or p.ndim < 2:
+        return False
+    if any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS):
+        return False
+    if target == "mlp_only":
+        return ".mlp." in name
+    if target == "all_matrices":
+        return p.numel() > INT8_KEEP_FLOAT_MAX_NUMEL or ".mlp." in name
+    raise ValueError(f"Unsupported RATE_TARGET_TENSORS={target}")
+
+
+def estimate_tensor_rate_proxy(t: Tensor, temperature: float) -> tuple[Tensor, Tensor, Tensor]:
+    t2 = t.float()
+    if t2.ndim > 2:
+        t2 = t2.reshape(t2.shape[0], -1)
+    elif t2.ndim == 1:
+        t2 = t2.reshape(1, -1)
+    row_scale = t2.abs().mean(dim=1, keepdim=True).clamp_min(1e-6)
+    normalized = torch.clamp(t2 / row_scale, -2.5, 2.5)
+    centers = t2.new_tensor(INT5_ODD_LEVELS)
+    logits = -temperature * (normalized.unsqueeze(-1) - centers) ** 2
+    probs = torch.softmax(logits, dim=-1)
+    hist = probs.mean(dim=(0, 1))
+    entropy_nats = -(hist * (hist + 1e-8).log()).sum()
+    entropy_bits = entropy_nats / math.log(2.0)
+    scale_proxy = row_scale.mean()
+    est_bits = entropy_bits * float(t2.numel())
+    return entropy_bits, scale_proxy, est_bits
+
+
+# -----------------------------
+# DATA LOADING
+# -----------------------------
+
+def load_data_shard(file: Path) -> Tensor:
+    # Shard layout: 256-entry int32 header (magic, version, token count) followed by uint16 tokens.
+    header_bytes = 256 * np.dtype("<i4").itemsize
+    with file.open("rb") as f:
+        header = np.frombuffer(f.read(header_bytes), dtype="<i4")
+        num_tokens = int(header[2])
+        data = np.frombuffer(f.read(num_tokens * np.dtype("<u2").itemsize), dtype="<u2")
+    if int(data.size) != num_tokens:
+        raise ValueError(f"Token shard {file} is truncated")
+    return torch.from_numpy(data.astype(np.int32, copy=False))
+
+
+class TokenStream:
+    # Sequential reader over the sorted shard files, wrapping back to the first shard at the end.
+    def __init__(self, pattern: str):
+        self.files = [Path(p) for p in sorted(glob.glob(pattern))]
+        if not self.files:
+            raise FileNotFoundError(f"No files found for pattern: {pattern}")
+        self.file_idx = 0
+        self.tokens = load_data_shard(self.files[self.file_idx])
+        self.pos = 0
+
+    def _advance_file(self) -> None:
self.file_idx = (self.file_idx + 1) % len(self.files) + self.tokens = load_data_shard(self.files[self.file_idx]) + self.pos = 0 + + def take(self, n: int) -> Tensor: + chunks: list[Tensor] = [] + remaining = n + while remaining > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: + self._advance_file() + continue + k = min(remaining, avail) + chunks.append(self.tokens[self.pos : self.pos + k]) + self.pos += k + remaining -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) + + +class DistributedTokenLoader: + # Each call consumes a contiguous chunk from the shared token stream, then slices out + # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting. + def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank = rank + self.world_size = world_size + self.device = device + self.stream = TokenStream(pattern) + + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + chunk = self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start : start + per_rank_span].to(dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + +# ----------------------------- +# TRANSFORMER MODULES +# ----------------------------- + +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + + +class CastedLinear(nn.Linear): + # Keep weights in fp32 for optimizer/state quality, cast at matmul time for bf16 compute. 
+ def forward(self, x: Tensor) -> Tensor: + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, self.weight.to(x.dtype), bias) + + +class StructuredLinear(nn.Module): + # Lightweight TT-style structured matrix used only in the MLP stack. + def __init__( + self, + in_features: int, + out_features: int, + bias: bool, + linear_type: str, + btt_rank: int, + in_shape_spec: str, + out_shape_spec: str, + btt_init: str, + btt_init_gain: float, + ): + super().__init__() + self.in_features = in_features + self.out_features = out_features + self.linear_type = linear_type + self.rank = max(int(btt_rank), 1) + self.btt_init = btt_init + self.btt_init_gain = btt_init_gain + if linear_type == "dense": + self.impl = CastedLinear(in_features, out_features, bias=bias) + self.register_parameter("core1", None) + self.register_parameter("core2", None) + return + if linear_type != "btt_mlp": + raise ValueError(f"Unsupported STRUCTURED_LINEAR_TYPE={linear_type}") + self.impl = None + self.in_shape = parse_factor_shape(in_shape_spec, in_features) + self.out_shape = parse_factor_shape(out_shape_spec, out_features) + n1, n2 = self.in_shape + m1, m2 = self.out_shape + self.core1 = nn.Parameter(torch.empty(n1, m1, self.rank)) + self.core2 = nn.Parameter(torch.empty(self.rank, n2, m2)) + if bias: + self.bias = nn.Parameter(torch.zeros(out_features)) + else: + self.register_parameter("bias", None) + self.reset_parameters() + + def reset_parameters(self) -> None: + if self.impl is not None: + self.impl.reset_parameters() + return + n1, n2 = self.in_shape + if self.btt_init == "mup": + # muP-inspired scaling: keep the per-rank branch update stable while + # preserving O(1) output variance under the sum over TT ranks. 
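+            # Example (illustrative shapes): in_shape = (32, 24) with rank 4 gives
+            # core1_std = gain / sqrt(32) and core2_std = gain / sqrt(4 * 24) = gain / sqrt(96).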
+ core1_std = self.btt_init_gain / math.sqrt(max(n1, 1)) + core2_std = self.btt_init_gain / math.sqrt(max(self.rank * n2, 1)) + elif self.btt_init == "xavier": + std = self.btt_init_gain / math.sqrt(max(self.in_features, 1)) + core1_std = std / math.sqrt(self.rank) + core2_std = std + else: + raise ValueError(f"Unsupported BTT_INIT={self.btt_init}") + nn.init.normal_(self.core1, mean=0.0, std=core1_std) + nn.init.normal_(self.core2, mean=0.0, std=core2_std) + if self.bias is not None: + nn.init.zeros_(self.bias) + + def forward(self, x: Tensor) -> Tensor: + if self.impl is not None: + return self.impl(x) + x_dtype = x.dtype + orig_shape = x.shape[:-1] + x2 = x.reshape(-1, self.in_shape[0], self.in_shape[1]) + core1 = self.core1.to(dtype=x_dtype) + core2 = self.core2.to(dtype=x_dtype) + y = x2.new_zeros((x2.shape[0], self.out_shape[0], self.out_shape[1])) + for r in range(self.rank): + left = torch.einsum("buv,um->bmv", x2, core1[:, :, r]) + y = y + torch.einsum("bmv,vn->bmn", left, core2[r]) + y = y.reshape(*orig_shape, self.out_features) + if self.bias is not None: + y = y + self.bias.to(dtype=y.dtype) + return y + + +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + # Keep small/control parameters in fp32 even when the model body runs in bf16. + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() + + +class Rotary(nn.Module): + # Caches cos/sin tables per sequence length on the current device. 
+ def __init__(self, dim: int, base: float = 10000.0): + super().__init__() + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached != seq_len + or self._cos_cached.device != device + ): + t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) + freqs = torch.outer(t, self.inv_freq.to(device)) + self._cos_cached = freqs.cos()[None, None, :, :] + self._sin_cached = freqs.sin()[None, None, :, :] + self._seq_len_cached = seq_len + cos = self._cos_cached.to(dtype=dtype) + sin = self._sin_cached.to(dtype=dtype) + if not torch.is_inference_mode_enabled(): + cos = cos.clone() + sin = sin.clone() + return cos, sin + + +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + + +class CausalSelfAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + kv_dim = self.num_kv_heads * self.head_dim + self.c_q = CastedLinear(dim, dim, bias=False) + self.c_k = CastedLinear(dim, kv_dim, bias=False) + self.c_v = CastedLinear(dim, kv_dim, bias=False) + self.proj 
= CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rotary = Rotary(self.head_dim, base=rope_base) + + def forward(self, x: Tensor) -> Tensor: + bsz, seqlen, dim = x.shape + q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) + k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin) + k = apply_rotary_emb(k, cos, sin) + q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] + y = F.scaled_dot_product_attention( + q, + k, + v, + attn_mask=None, + is_causal=True, + enable_gqa=(self.num_kv_heads != self.num_heads), + ) + y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) + return self.proj(y) + + +class MLP(nn.Module): + # relu^2 MLP from the original modded-nanogpt setup + def __init__( + self, + dim: int, + mlp_mult: int, + structured_mlp: bool, + structured_linear_type: str, + btt_rank: int, + btt_in_shape: str, + btt_out_shape: str, + btt_init: str, + btt_init_gain: float, + compile_structured_mlp: bool, + ): + super().__init__() + hidden = mlp_mult * dim + if structured_mlp: + self.fc = StructuredLinear( + dim, + hidden, + bias=False, + linear_type=structured_linear_type, + btt_rank=btt_rank, + in_shape_spec=btt_in_shape, + out_shape_spec=btt_out_shape, + btt_init=btt_init, + btt_init_gain=btt_init_gain, + ) + self.proj = StructuredLinear( + hidden, + dim, + bias=False, + linear_type=structured_linear_type, + btt_rank=btt_rank, + in_shape_spec=btt_out_shape, + out_shape_spec=btt_in_shape, + btt_init=btt_init, + btt_init_gain=btt_init_gain, + ) + if compile_structured_mlp: + compile_options = 
dict(torch._inductor.list_mode_options("reduce-overhead")) + compile_options["triton.cudagraphs"] = False + self.fc = torch.compile( + self.fc, + options=compile_options, + fullgraph=False, + dynamic=False, + ) + self.proj = torch.compile( + self.proj, + options=compile_options, + fullgraph=False, + dynamic=False, + ) + else: + self.fc = CastedLinear(dim, hidden, bias=False) + self.proj = CastedLinear(hidden, dim, bias=False) + self.proj._zero_init = True + + def forward(self, x: Tensor) -> Tensor: + x = torch.relu(self.fc(x)) + return self.proj(x.square()) + + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + rope_base: float, + qk_gain_init: float, + structured_mlp: bool, + structured_linear_type: str, + btt_rank: int, + btt_in_shape: str, + btt_out_shape: str, + btt_init: str, + btt_init_gain: float, + compile_structured_mlp: bool, + ): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) + self.mlp = MLP( + dim, + mlp_mult, + structured_mlp, + structured_linear_type, + btt_rank, + btt_in_shape, + btt_out_shape, + btt_init, + btt_init_gain, + compile_structured_mlp, + ) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + + def forward(self, x: Tensor, x0: Tensor) -> Tensor: + mix = self.resid_mix.to(dtype=x.dtype) + x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + attn_out = self.attn(self.attn_norm(x)) + x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out + x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) + return x + + +class GPT(nn.Module): + def __init__( + self, + vocab_size: int, + num_layers: int, + model_dim: int, + num_heads: int, + 
num_kv_heads: int, + mlp_mult: int, + tie_embeddings: bool, + tied_embed_init_std: float, + logit_softcap: float, + rope_base: float, + qk_gain_init: float, + structured_mlp: bool, + structured_linear_type: str, + btt_rank: int, + btt_in_shape: str, + btt_out_shape: str, + btt_init: str, + btt_init_gain: float, + compile_structured_mlp: bool, + rate_target_tensors: str, + rate_temperature: float, + ): + super().__init__() + if logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") + self.tie_embeddings = tie_embeddings + self.tied_embed_init_std = tied_embed_init_std + self.logit_softcap = logit_softcap + self.rate_target_tensors = rate_target_tensors + self.rate_temperature = rate_temperature + self.tok_emb = nn.Embedding(vocab_size, model_dim) + self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + self.blocks = nn.ModuleList( + [ + Block( + model_dim, + num_heads, + num_kv_heads, + mlp_mult, + rope_base, + qk_gain_init, + structured_mlp, + structured_linear_type, + btt_rank, + btt_in_shape, + btt_out_shape, + btt_init, + btt_init_gain, + compile_structured_mlp, + ) + for i in range(num_layers) + ] + ) + self.final_norm = RMSNorm() + self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) + if self.lm_head is not None: + self.lm_head._zero_init = True + self._init_weights() + + def _init_weights(self) -> None: + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + for module in self.modules(): + if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + if isinstance(module, StructuredLinear) and getattr(module, "_zero_init", False): + if module.impl is 
not None: + nn.init.zeros_(module.impl.weight) + else: + nn.init.zeros_(module.core2) + + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + x = self.tok_emb(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x0 = x + skips: list[Tensor] = [] + + # First half stores skips; second half reuses them in reverse order. + for i in range(self.num_encoder_layers): + x = self.blocks[i](x, x0) + skips.append(x) + for i in range(self.num_decoder_layers): + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + x = self.blocks[self.num_encoder_layers + i](x, x0) + + x = self.final_norm(x).reshape(-1, x.size(-1)) + targets = target_ids.reshape(-1) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(x) + logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + return F.cross_entropy(logits.float(), targets, reduction="mean") + + def rate_regularizer(self) -> tuple[Tensor, Tensor, Tensor]: + total_symbols = 0.0 + weighted_rate = None + weighted_scale = None + estimated_bits = None + for name, param in self.named_parameters(): + if not should_rate_regularize_tensor(name, param, self.rate_target_tensors): + continue + rate_proxy, scale_proxy, tensor_bits = estimate_tensor_rate_proxy(param, self.rate_temperature) + weight = float(param.numel()) + total_symbols += weight + if weighted_rate is None: + weighted_rate = rate_proxy * weight + weighted_scale = scale_proxy * weight + estimated_bits = tensor_bits + else: + weighted_rate = weighted_rate + rate_proxy * weight + weighted_scale = weighted_scale + scale_proxy * weight + estimated_bits = estimated_bits + tensor_bits + if weighted_rate is None or weighted_scale is None or estimated_bits is None: + zero = self.tok_emb.weight.new_zeros(()) + return zero, zero, zero + denom = max(total_symbols, 1.0) + return 
weighted_rate / denom, weighted_scale / denom, estimated_bits + + +# ----------------------------- +# TRAINING +# ----------------------------- + +def main() -> None: + global zeropower_via_newtonschulz5 + + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + if args.enable_compile: + zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) + + # ----------------------------- + # DISTRIBUTED + CUDA SETUP + # ----------------------------- + + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if 8 % world_size != 0: + raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") + grad_accum_steps = 8 // world_size + grad_scale = 1.0 / grad_accum_steps + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 0 + + # Fast math knobs + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + + logfile = None + if master_process: + if args.run_id == "train": + logfile = "train.log" + else: + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" + print(logfile) + + def log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + 
print(msg, file=f) + + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch {torch.__version__}", console=False) + log0( + subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, + console=False, + ) + log0("=" * 100, console=False) + + # ----------------------------- + # TOKENIZER + VALIDATION METRIC SETUP + # ----------------------------- + + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) + if args.val_token_limit > 0: + usable = min(((args.val_token_limit) // args.train_seq_len) * args.train_seq_len, val_tokens.numel() - 1) + if usable <= 0: + raise ValueError(f"VAL_TOKEN_LIMIT={args.val_token_limit} is too small for TRAIN_SEQ_LEN={args.train_seq_len}") + val_tokens = val_tokens[: usable + 1].contiguous() + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + if args.val_token_limit > 0: + log0(f"val_token_limit:{args.val_token_limit}") + + # 
----------------------------- + # MODEL + OPTIMIZER SETUP + # ----------------------------- + + base_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, + structured_mlp=args.structured_mlp, + structured_linear_type=args.structured_linear_type if args.structured_mlp else "dense", + btt_rank=args.btt_rank, + btt_in_shape=args.btt_in_shape, + btt_out_shape=args.btt_out_shape, + btt_init=args.btt_init, + btt_init_gain=args.btt_init_gain, + compile_structured_mlp=args.compile_structured_mlp, + rate_target_tensors=args.rate_target_tensors, + rate_temperature=args.rate_temperature, + ).to(device).bfloat16() + for module in base_model.modules(): + if isinstance(module, (CastedLinear, StructuredLinear)): + module.float() + restore_low_dim_params_to_fp32(base_model) + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) if args.enable_compile else base_model + model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + + # Optimizer split: + # - token embedding (Adam) uses EMBED_LR + # - untied lm_head (Adam) uses HEAD_LR + # - matrix params in transformer blocks use MATRIX_LR via Muon + # - vectors/scalars use SCALAR_LR via Adam + block_named_params = list(base_model.blocks.named_parameters()) + matrix_params = [ + p + for name, p in block_named_params + if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + scalar_params = [ + p + for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.skip_weights.numel() > 0: + 
scalar_params.append(base_model.skip_weights) + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + optimizer_tok = torch.optim.Adam( + [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizer_muon = Muon( + matrix_params, + lr=args.matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + ) + for group in optimizer_muon.param_groups: + group["base_lr"] = args.matrix_lr + optimizer_scalar = torch.optim.Adam( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers.insert(1, optimizer_head) + + n_params = sum(p.numel() for p in base_model.parameters()) + log0(f"model_params:{n_params}") + log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") + log0(f"enable_compile:{args.enable_compile}") + log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") + log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") + log0( + f"quant_scheme:{args.quant_scheme} compressor:{args.compressor} level:{args.compressor_level} " + f"structured_mlp:{args.structured_mlp} structured_linear_type:{args.structured_linear_type} " + f"btt_rank:{args.btt_rank} btt_init:{args.btt_init} compile_structured_mlp:{args.compile_structured_mlp}" + ) + log0( + f"rate_lambda:{args.rate_lambda} scale_lambda:{args.scale_lambda} " + f"rate_warmup_start_fraction:{args.rate_warmup_start_fraction} rate_full_fraction:{args.rate_full_fraction} " + 
f"rate_target_tensors:{args.rate_target_tensors}" + ) + log0( + f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " + f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " + f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" + ) + log0( + f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " + f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " + f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log0(f"seed:{args.seed}") + + # ----------------------------- + # DATA LOADER & MODEL WARMUP + # ----------------------------- + + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + + def lr_mul(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + + def rate_schedule(step: int) -> float: + if args.rate_lambda <= 0.0 and args.scale_lambda <= 0.0: + return 0.0 + if args.iterations <= 0: + return 1.0 + frac = step / max(args.iterations, 1) + start = args.rate_warmup_start_fraction + full = max(args.rate_full_fraction, start + 1e-6) + if frac <= start: + return 0.0 + if frac >= full: + return 1.0 + return (frac - start) / (full - start) + + # Warmup primes the compiled forward/backward/optimizer paths, then we restore the + # initial weights/optimizer state so measured training starts from 
the true init. + if args.warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(args.warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + maybe_mark_compile_step_begin(args.enable_compile or args.compile_structured_mlp) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + for opt in optimizers: + opt.step() + zero_grad_all() + if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: + log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + if distributed: + model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + # ----------------------------- + # MAIN TRAINING LOOP + # ----------------------------- + + training_time_ms = 0.0 + stop_after_step: int | None = None + torch.cuda.synchronize() + t0 = time.perf_counter() + + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + 
base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + log0( + f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" + ) + break + + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + zero_grad_all() + train_loss = torch.zeros((), device=device) + train_ce_loss = torch.zeros((), device=device) + train_rate_proxy = torch.zeros((), device=device) + train_scale_proxy = torch.zeros((), device=device) + train_est_bits = torch.zeros((), device=device) + rate_mul = rate_schedule(step) + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + maybe_mark_compile_step_begin(args.enable_compile or args.compile_structured_mlp) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + ce_loss = model(x, y) + total_loss = ce_loss + if rate_mul > 0.0: + rate_proxy, scale_proxy, est_bits = base_model.rate_regularizer() + total_loss = total_loss + rate_mul * (args.rate_lambda * rate_proxy + args.scale_lambda * scale_proxy) + train_rate_proxy += rate_proxy.detach() + train_scale_proxy += scale_proxy.detach() + train_est_bits += est_bits.detach() + train_ce_loss += ce_loss.detach() + train_loss += total_loss.detach() + (total_loss * grad_scale).backward() + train_loss /= grad_accum_steps + train_ce_loss /= grad_accum_steps + train_rate_proxy /= grad_accum_steps + train_scale_proxy /= grad_accum_steps + train_est_bits /= grad_accum_steps + + frac = min(step / 
args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + for opt in optimizers: + opt.step() + zero_grad_all() + + step += 1 + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + should_log_train = ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + log0( + f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} ce_loss:{train_ce_loss.item():.4f} " + f"rate_mul:{rate_mul:.3f} rate_proxy_bits:{train_rate_proxy.item():.4f} " + f"scale_proxy:{train_scale_proxy.item():.6f} est_model_bits:{train_est_bits.item():.0f} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" + ) + + # Needed to sync whether we've reached the wallclock cap. 
+ reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + + log0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + ) + + # ----------------------------- + # SERIALIZATION + ROUNDTRIP VALIDATION + # ----------------------------- + # Save the raw state (useful for debugging/loading in PyTorch directly), then always produce + # the compressed quantized artifact (args.quant_scheme + args.compressor; int5_odd+zlib by default) + # and validate the round-tripped weights. + + if master_process: + torch.save(base_model.state_dict(), "final_model.pt") + model_bytes = os.path.getsize("final_model.pt") + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model: {model_bytes} bytes") + log0(f"Code size: {code_bytes} bytes") + log0(f"Total submission size: {model_bytes + code_bytes} bytes") + + quant_obj, quant_stats = quantize_state_dict(base_model.state_dict(), args.quant_scheme) + quant_buf = io.BytesIO() + torch.save(quant_obj, quant_buf) + quant_raw = quant_buf.getvalue() + quant_blob = compress_blob(quant_raw, args.compressor, args.compressor_level) + quant_raw_bytes = len(quant_raw) + quant_filename = f"final_model.{args.quant_scheme}.{args.compressor}.ptc" + if master_process: + with open(quant_filename, "wb") as f: + f.write(quant_blob) + quant_file_bytes = os.path.getsize(quant_filename) + code_bytes = len(code.encode("utf-8")) + ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1) + log0( + f"Serialized model {args.quant_scheme}+{args.compressor}: {quant_file_bytes} bytes " + f"(payload:{quant_stats['int8_payload_bytes']} raw_torch:{quant_raw_bytes} 
payload_ratio:{ratio:.2f}x)" + ) + log0(f"Total submission size {args.quant_scheme}+{args.compressor}: {quant_file_bytes + code_bytes} bytes") + + if distributed: + dist.barrier() + with open(quant_filename, "rb") as f: + quant_blob_disk = f.read() + quant_state = torch.load(io.BytesIO(decompress_blob(quant_blob_disk, args.compressor)), map_location="cpu") + base_model.load_state_dict(dequantize_state_dict(quant_state), strict=True) + torch.cuda.synchronize() + t_qeval = time.perf_counter() + q_val_loss, q_val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"final_{args.quant_scheme}_{args.compressor}_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" + ) + log0(f"final_{args.quant_scheme}_{args.compressor}_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + + if distributed: + dist.destroy_process_group() + + +if __name__ == "__main__": + main() + +==================================================================================================== +Running Python 3.10.12 (main, Mar 3 2026, 11:56:32) [GCC 11.4.0] +Running PyTorch 2.11.0+cu130 +Tue Mar 31 12:15:46 2026 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 590.48.01 Driver Version: 590.48.01 CUDA Version: 13.1 | ++-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. 
| +|=========================================+========================+======================| +| 0 NVIDIA GeForce RTX 3060 Off | 00000000:01:00.0 On | N/A | +| 0% 33C P8 19W / 170W | 242MiB / 12288MiB | 8% Default | +| | | N/A | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| 0 N/A N/A 985 G /usr/lib/xorg/Xorg 117MiB | +| 0 N/A N/A 4664 C python3 104MiB | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=/home/hz/Desktop/parameter-golf-main/data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:1 +val_loader:shards pattern=/home/hz/Desktop/parameter-golf-main/data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:65536 +val_token_limit:65536 +model_params:8507464 +world_size:1 grad_accum_steps:8 +enable_compile:False +sdp_backends:cudnn=False flash=True mem_efficient=False math=False +attention_mode:gqa num_heads:8 num_kv_heads:4 +quant_scheme:int5_odd compressor:zlib level:9 structured_mlp:True structured_linear_type:btt_mlp btt_rank:32 btt_init:mup compile_structured_mlp:True +rate_lambda:2e-05 scale_lambda:0.0002 rate_warmup_start_fraction:0.1 rate_full_fraction:0.4 rate_target_tensors:mlp_only +tie_embeddings:True embed_lr:0.03 head_lr:0.0 matrix_lr:0.02 scalar_lr:0.02 +train_batch_tokens:4096 train_seq_len:256 iterations:8 warmup_steps:1 max_wallclock_seconds:600.000 +seed:1337 +warmup_step:1/1 +step:1/8 train_loss:6.9357 ce_loss:6.9357 rate_mul:0.000 rate_proxy_bits:0.0000 scale_proxy:0.000000 
est_model_bits:0 train_time:1680ms step_avg:1680.21ms +step:2/8 train_loss:14.0267 ce_loss:14.0267 rate_mul:0.083 rate_proxy_bits:1.4980 scale_proxy:0.065182 est_model_bits:1325360 train_time:3781ms step_avg:1890.64ms +step:3/8 train_loss:13.2241 ce_loss:13.2241 rate_mul:0.500 rate_proxy_bits:1.4980 scale_proxy:0.065182 est_model_bits:1325360 train_time:5832ms step_avg:1944.12ms +step:4/8 train_loss:11.8367 ce_loss:11.8366 rate_mul:0.917 rate_proxy_bits:1.4980 scale_proxy:0.065182 est_model_bits:1325360 train_time:7854ms step_avg:1963.61ms +step:5/8 train_loss:10.1047 ce_loss:10.1046 rate_mul:1.000 rate_proxy_bits:1.4980 scale_proxy:0.065182 est_model_bits:1325360 train_time:9887ms step_avg:1977.30ms +step:6/8 train_loss:9.2923 ce_loss:9.2922 rate_mul:1.000 rate_proxy_bits:1.4980 scale_proxy:0.065182 est_model_bits:1325360 train_time:11911ms step_avg:1985.11ms +step:7/8 train_loss:8.2568 ce_loss:8.2568 rate_mul:1.000 rate_proxy_bits:1.4980 scale_proxy:0.065182 est_model_bits:1325360 train_time:13905ms step_avg:1986.46ms +step:8/8 train_loss:7.4875 ce_loss:7.4875 rate_mul:1.000 rate_proxy_bits:1.4980 scale_proxy:0.065182 est_model_bits:1325360 train_time:15933ms step_avg:1991.63ms +step:8/8 val_loss:7.0360 val_bpb:4.1601 train_time:15933ms step_avg:1991.68ms +peak memory allocated: 736 MiB reserved: 1422 MiB +Serialized model: 33022259 bytes +Code size: 69179 bytes +Total submission size: 33091438 bytes +Serialized model int5_odd+zlib: 2139278 bytes (payload:2942927 raw_torch:2983011 payload_ratio:11.21x) +Total submission size int5_odd+zlib: 2208457 bytes +final_int5_odd_zlib_roundtrip val_loss:7.4007 val_bpb:4.3758 eval_time:1856ms +final_int5_odd_zlib_roundtrip_exact val_loss:7.40071642 val_bpb:4.37575109 diff --git a/records/track_non_record_16mb/2026-03-31_EntropyAware_Int5Odd_BTTMLP_3060/logs/eval_bench_r256_l16_materialized_vb1048576.txt 
b/records/track_non_record_16mb/2026-03-31_EntropyAware_Int5Odd_BTTMLP_3060/logs/eval_bench_r256_l16_materialized_vb1048576.txt new file mode 100644 index 000000000..a8bd08e02 --- /dev/null +++ b/records/track_non_record_16mb/2026-03-31_EntropyAware_Int5Odd_BTTMLP_3060/logs/eval_bench_r256_l16_materialized_vb1048576.txt @@ -0,0 +1,1931 @@ +logs/eval_bench_r256_l16_materialized_vb1048576.txt +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=/workspace/parameter-golf/data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:80 +val_loader:shards pattern=/workspace/parameter-golf/data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:1048576 +val_token_limit:1048576 +model_params:25727104 +world_size:1 grad_accum_steps:8 train_microbatch_tokens:2048 effective_train_batch_tokens:16384 +enable_compile:False +sdp_backends:cudnn=False flash=True mem_efficient=False math=False +attention_mode:gqa num_heads:8 num_kv_heads:4 +quant_scheme:int5_odd compressor:zlib level:9 structured_mlp:True structured_linear_type:btt_mlp btt_rank:256 btt_init:mup compile_structured_mlp:False structured_chunk_ranks:64 structured_eval_chunk_ranks:1 structured_compile_mode:reduce-overhead +rate_lambda:2e-05 scale_lambda:0.0002 rate_warmup_start_fraction:0.1 rate_full_fraction:0.4 rate_target_tensors:mlp_only rate_schedule_mode:auto max_expected_steps:0 +tie_embeddings:True embed_lr:0.03 head_lr:0.0 matrix_lr:0.02 scalar_lr:0.02 +train_batch_tokens_target:16384 effective_train_batch_tokens:16384 train_seq_len:1024 iterations:0 warmup_steps:0 max_wallclock_seconds:0.000 +lr_scale_method:none lr_reference_batch_tokens:16384 batch_multiplier:1.0000 warmup_scale_method:none warmup_reference_batch_tokens:16384 +seed:1337 +step:0/0 val_loss:6.9379 val_bpb:4.1568 train_time:0ms step_avg:0.02ms +peak memory allocated: 22735 MiB reserved: 28940 MiB +Serialized model: 101929819 bytes +Code size: 83124 bytes +Total submission size: 102012943 bytes +Serialized 
model int5_odd+zlib: 4207457 bytes (payload:8780523 raw_torch:8851579 payload_ratio:11.60x) +Total submission size int5_odd+zlib: 4290581 bytes +final_int5_odd_zlib_roundtrip val_loss:6.9416 val_bpb:4.1590 eval_time:851ms +final_int5_odd_zlib_roundtrip_exact val_loss:6.94156265 val_bpb:4.15901457 +ence and batch size. Validation always uses the full fineweb_val split. + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 32_768)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 10)) + val_token_limit = int(os.environ.get("VAL_TOKEN_LIMIT", 1_048_576)) + + # Training length. + iterations = int(os.environ.get("ITERATIONS", 40)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 1200)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 1)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 16_384)) + train_microbatch_tokens = int(os.environ.get("TRAIN_MICROBATCH_TOKENS", 0)) + grad_accum_steps = int(os.environ.get("GRAD_ACCUM_STEPS", 0)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 256)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + enable_compile = bool(int(os.environ.get("ENABLE_COMPILE", "0"))) + + # Model shape. + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers = int(os.environ.get("NUM_LAYERS", 9)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = int(os.environ.get("MLP_MULT", 2)) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + + # Optimizer hyperparameters. 
+ embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.03)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.02)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.02)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.0)) + lr_scale_method = os.environ.get("LR_SCALE_METHOD", "none") + lr_reference_batch_tokens = int(os.environ.get("LR_REFERENCE_BATCH_TOKENS", 16_384)) + warmup_scale_method = os.environ.get("WARMUP_SCALE_METHOD", "none") + warmup_reference_batch_tokens = int(os.environ.get("WARMUP_REFERENCE_BATCH_TOKENS", 16_384)) + simulated_world_size = int(os.environ.get("SIMULATED_WORLD_SIZE", 0)) + simulated_rank = int(os.environ.get("SIMULATED_RANK", 0)) + dry_run_init_only = bool(int(os.environ.get("DRY_RUN_INIT_ONLY", "0"))) + debug_data_sharding_steps = int(os.environ.get("DEBUG_DATA_SHARDING_STEPS", 0)) + + # Export / compression. + quant_scheme = os.environ.get("QUANT_SCHEME", "int5_odd") + compressor = os.environ.get("COMPRESSOR", "zlib") + compressor_level = int(os.environ.get("COMPRESSOR_LEVEL", 9)) + + # Entropy-aware rate regularization. 
+ rate_lambda = float(os.environ.get("RATE_LAMBDA", 0.00002)) + scale_lambda = float(os.environ.get("SCALE_LAMBDA", 0.0002)) + rate_temperature = float(os.environ.get("RATE_TEMPERATURE", 6.0)) + rate_warmup_start_fraction = float(os.environ.get("RATE_WARMUP_START_FRACTION", 0.1)) + rate_full_fraction = float(os.environ.get("RATE_FULL_FRACTION", 0.4)) + rate_target_tensors = os.environ.get("RATE_TARGET_TENSORS", "mlp_only") + rate_schedule_mode = os.environ.get("RATE_SCHEDULE_MODE", "auto") + max_expected_steps = int(os.environ.get("MAX_EXPECTED_STEPS", 0)) + + # Structured MLP. + structured_mlp = bool(int(os.environ.get("STRUCTURED_MLP", "1"))) + structured_linear_type = os.environ.get("STRUCTURED_LINEAR_TYPE", "btt_mlp") + btt_rank = int(os.environ.get("BTT_RANK", 32)) + btt_in_shape = os.environ.get("BTT_IN_SHAPE", "") + btt_out_shape = os.environ.get("BTT_OUT_SHAPE", "") + btt_init = os.environ.get("BTT_INIT", "mup") + btt_init_gain = float(os.environ.get("BTT_INIT_GAIN", 1.0)) + compile_structured_mlp = bool(int(os.environ.get("COMPILE_STRUCTURED_MLP", "1"))) + structured_chunk_ranks = int(os.environ.get("STRUCTURED_CHUNK_RANKS", 64)) + structured_eval_chunk_ranks = int(os.environ.get("STRUCTURED_EVAL_CHUNK_RANKS", 1)) + structured_compile_mode = os.environ.get("STRUCTURED_COMPILE_MODE", "reduce-overhead") + +# ----------------------------- +# MUON OPTIMIZER +# ----------------------------- +# +# As borrowed from modded-nanogpt +# Background on Muon: https://kellerjordan.github.io/posts/muon/ + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + # Orthogonalize a 2D update matrix with a fast Newton-Schulz iteration. + # Muon uses this to normalize matrix-shaped gradients before applying them. 
+ a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X + + +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True): + super().__init__( + params, + dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov), + ) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + + for group in self.param_groups: + params = group["params"] + if not params: + continue + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + + total_params = sum(int(p.numel()) for p in params) + updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) + + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: + g = g.add(buf, alpha=momentum) + g = zeropower_via_newtonschulz5(g, steps=backend_steps) + # Scale correction from Muon reference implementations. 
+ g *= max(1, g.size(0) / g.size(1)) ** 0.5 + updates_flat[curr : curr + p.numel()] = g.reshape(-1) + curr += p.numel() + + if distributed: + dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) + + curr = 0 + for p in params: + g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) + p.add_(g, alpha=-lr) + curr += p.numel() + + return loss + + +# ----------------------------- +# TOKENIZER-AGNOSTIC EVALUATION SETUP +# ----------------------------- +# +# It's common for small models to have a large fraction of their parameters in embeddings, since the 2 * d_model * d_vocab embedding and head parameters can be gigantic. +# Instead of locking the tokenizer, we let you bring your own and compute the validation metric from the average compression of the validation set. +# We calculate BPB (bits-per-byte) instead of validation loss, so we need methods to count the number of bytes each token decodes to. +# Note: Submissions that edit the tokenizer will be examined more carefully, since screwing this up might unjustly improve your score. 
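The BPB conversion described in the comment above can be sketched in a few lines: summed token cross-entropy (in nats, as produced by `F.cross_entropy`) is converted to bits and divided by the number of UTF-8 bytes the predicted tokens decode to. The helper below is a hypothetical standalone illustration, not the harness's `eval_val`; it assumes per-token byte counts are already known (the role the `base_bytes` / leading-space LUTs play in the real code).

```python
import math

def bits_per_byte(total_ce_nats: float, token_byte_counts: list[int]) -> float:
    # Cross-entropy from F.cross_entropy is in nats; divide by ln(2) for bits.
    total_bits = total_ce_nats / math.log(2)
    # Bytes the evaluated tokens decode to (byte-fallback tokens count as 1,
    # pieces with a leading "▁" would also add one byte for the space).
    total_bytes = sum(token_byte_counts)
    return total_bits / max(total_bytes, 1)

# Example: 1000 tokens at mean CE of 2.88 nats, averaging 4 bytes per token,
# gives (1000 * 2.88 / ln 2) / 4000 ≈ 1.039 bits per byte.
```

This is why BPB is comparable across tokenizers while raw token cross-entropy is not: a tokenizer with longer pieces gets fewer, harder predictions over the same underlying bytes.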
+ +def build_sentencepiece_luts( + sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device +) -> tuple[Tensor, Tensor, Tensor]: + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("▁"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) + + +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*. 
+ tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] + + +def eval_val( + args: Hyperparameters, + model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + grad_accum_steps: int, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, +) -> tuple[float, float]: + # Validation computes two metrics: + # - val_loss: token cross-entropy (natural log) + # - val_bpb: tokenizer-agnostic compression metric used by the challenge + local_batch_tokens = args.val_batch_size // world_size + if local_batch_tokens < args.train_seq_len: + raise ValueError( + "VAL_BATCH_SIZE must provide at least one sequence per rank; " + f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " + f"TRAIN_SEQ_LEN={args.train_seq_len}" + ) + local_batch_seqs = local_batch_tokens // args.train_seq_len + total_seqs = (val_tokens.numel() - 1) // args.train_seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + + model.eval() + eval_dtype = torch.bfloat16 if device.type == "cuda" else torch.float32 + eval_module = model.module if isinstance(model, DDP) else model + for module in eval_module.modules(): + if isinstance(module, StructuredLinear): + module.materialize(eval_dtype, device) + with torch.inference_mode(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * args.train_seq_len + raw_end = batch_seq_end * args.train_seq_len + 1 
+            local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True)
+            x = local[:-1].reshape(-1, args.train_seq_len)
+            y = local[1:].reshape(-1, args.train_seq_len)
+            maybe_mark_compile_step_begin(args.enable_compile or args.compile_structured_mlp)
+            with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True):
+                batch_loss = model(x, y).detach()
+            batch_token_count = float(y.numel())
+            val_loss_sum += batch_loss.to(torch.float64) * batch_token_count
+            val_token_count += batch_token_count
+            prev_ids = x.reshape(-1)
+            tgt_ids = y.reshape(-1)
+            token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16)
+            token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16)
+            val_byte_count += token_bytes.to(torch.float64).sum()
+
+    if dist.is_available() and dist.is_initialized():
+        dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM)
+        dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM)
+        dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM)
+
+    val_loss = val_loss_sum / val_token_count
+    bits_per_token = val_loss.item() / math.log(2.0)
+    tokens_per_byte = val_token_count.item() / val_byte_count.item()
+    model.train()
+    return float(val_loss.item()), float(bits_per_token * tokens_per_byte)
+
+# -----------------------------
+# POST-TRAINING QUANTIZATION
+# -----------------------------
+#
+# It's silly to export our model, which is trained in bf16 and fp32, at that same precision.
+# Instead, we get approximately the same model (with a small quality hit) by quantizing it to int8 and zlib-compressing the result.
+# We can then decompress the model and run it at higher precision for evaluation, after coming in under the size limit.
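As a toy sanity check on this export idea, the sketch below quantizes one random matrix to symmetric int8 with a single per-tensor scale (the real exporter in this file uses per-row scales and an int5-odd variant), zlib-compresses it, and verifies both the size win over fp32 storage and the bounded round-trip error. All names here are illustrative.

```python
import zlib
import numpy as np

rng = np.random.default_rng(0)
w = rng.standard_normal((256, 256)).astype(np.float32)

# One per-tensor scale mapping the max magnitude to the int8 extreme.
scale = float(np.abs(w).max()) / 127.0
q = np.clip(np.round(w / scale), -127, 127).astype(np.int8)
blob = zlib.compress(q.tobytes(), level=9)

# Decompress and dequantize back to float for evaluation.
w_hat = np.frombuffer(zlib.decompress(blob), dtype=np.int8).reshape(w.shape).astype(np.float32) * scale

assert len(blob) < w.nbytes                       # int8+zlib payload is far below fp32 bytes
assert np.max(np.abs(w - w_hat)) <= scale / 2 + 1e-6  # rounding error bounded by half a step
```

The max error bound of half a quantization step is what "approximately the same model" means above; per-row scales shrink that step wherever a row's dynamic range is small.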
+ +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights", + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 +INT5_ODD_LEVELS = (-2, -1, 0, 1, 2) + +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) + + +def maybe_mark_compile_step_begin(enabled: bool) -> None: + if not enabled: + return + mark_step = getattr(torch.compiler, "cudagraph_mark_step_begin", None) + if mark_step is not None: + mark_step() + + +def scale_from_batch_multiplier(value: float, multiplier: float, method: str) -> float: + if method == "none": + return value + if method == "sqrt": + return value * math.sqrt(multiplier) + if method == "linear": + return value * multiplier + raise ValueError(f"Unsupported scaling method: {method}") + + +def scaled_warmup_steps(base_steps: int, multiplier: float, method: str) -> int: + if base_steps <= 0: + return 0 + scaled = scale_from_batch_multiplier(float(base_steps), multiplier, method) + return max(int(math.ceil(scaled)), 1) + + +def parse_factor_shape(spec: str, dim: int) -> tuple[int, int]: + if spec: + vals = tuple(int(v.strip()) for v in spec.split(",") if v.strip()) + if len(vals) != 2 or vals[0] * vals[1] != dim: + raise ValueError(f"Invalid factor shape '{spec}' for dim={dim}") + return vals + for a in range(int(math.isqrt(dim)), 0, -1): + if dim % a == 0: + return a, dim // a + return 1, dim + +def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> 
Tensor: + if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): + return t.float().contiguous() + if t.dtype in {torch.float32, torch.bfloat16}: + passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") + return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() + return t + +def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + # Matrices get one scale per row, which usually tracks output-channel + # ranges much better than a single tensor-wide scale. + clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) + q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() + return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + + # Vectors / scalars use a simpler per-tensor scale. 
+ clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() + return q, scale + + +def pack_base5_codes(q: Tensor) -> tuple[bytes, int]: + vals = q.detach().to("cpu", dtype=torch.uint8).reshape(-1).numpy() + n_vals = int(vals.size) + pad = (3 - (n_vals % 3)) % 3 + if pad: + vals = np.pad(vals, (0, pad)) + groups = vals.reshape(-1, 3).astype(np.uint8) + packed = groups[:, 0] + 5 * groups[:, 1] + 25 * groups[:, 2] + return packed.tobytes(), n_vals + + +def unpack_base5_codes(data: bytes, n_vals: int) -> Tensor: + packed = np.frombuffer(data, dtype=np.uint8).astype(np.int16) + out = np.zeros((packed.size, 3), dtype=np.uint8) + for i in range(3): + out[:, i] = packed % 5 + packed //= 5 + flat = out.reshape(-1)[:n_vals] + return torch.from_numpy(flat.astype(np.uint8, copy=False)) + + +def quantize_float_tensor_int5_odd(t: Tensor) -> tuple[Tensor, Tensor]: + if t.ndim != 2: + raise ValueError("int5_odd quantization only supports 2D tensors") + t32 = t.float() + clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clip_abs = clip_abs.clamp_min(1e-6) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 2.0).clamp_min(1.0 / 2.0e6) + q = torch.clamp(torch.round(clipped / scale[:, None]), -2, 2).to(torch.int8).contiguous() + return (q + 2).to(torch.uint8).contiguous(), scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + +def quantize_state_dict_int8(state_dict: dict[str, Tensor]): + # Single supported clean-script export format: + # - per-row int8 for 2D float tensors + # - per-tensor int8 for other float tensors + # - exact passthrough for non-floats + # - passthrough for small float tensors, 
stored as fp16 to save bytes + quantized: dict[str, Tensor] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + qmeta: dict[str, dict[str, object]] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + + # Small float tensors are cheap enough to keep directly. We still downcast + # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size. + if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + + stats["num_float_tensors"] += 1 + q, s = quantize_float_tensor(t) + if s.ndim > 0: + qmeta[name] = {"scheme": "per_row", "axis": 0} + quantized[name] = q + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) + + obj: dict[str, object] = { + "__quant_format__": "int8_clean_per_row_v1", + "quantized": quantized, + "scales": scales, + "dtypes": dtypes, + "passthrough": passthrough, + } + if qmeta: + obj["qmeta"] = qmeta + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats + + +def quantize_state_dict_int5_odd(state_dict: dict[str, Tensor]): + packed: dict[str, bytes] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + shapes: dict[str, tuple[int, ...]] = {} + orig_shapes: dict[str, 
tuple[int, ...]] = {} + n_codes: dict[str, int] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + + quantize_this = t.ndim >= 2 and (t.numel() > INT8_KEEP_FLOAT_MAX_NUMEL or ".mlp." in name) + if not quantize_this: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + + stats["num_float_tensors"] += 1 + t2 = t.reshape(t.shape[0], -1) if t.ndim > 2 else t + q, s = quantize_float_tensor_int5_odd(t2) + packed_bytes, count = pack_base5_codes(q) + packed[name] = packed_bytes + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + shapes[name] = tuple(int(v) for v in t2.shape) + orig_shapes[name] = tuple(int(v) for v in t.shape) + n_codes[name] = count + stats["int8_payload_bytes"] += len(packed_bytes) + tensor_nbytes(s) + + obj: dict[str, object] = { + "__quant_format__": "int5_odd_clean_per_row_v1", + "packed": packed, + "scales": scales, + "dtypes": dtypes, + "shapes": shapes, + "orig_shapes": orig_shapes, + "n_codes": n_codes, + "passthrough": passthrough, + } + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats + +def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + qmeta = obj.get("qmeta", {}) + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, q in 
obj["quantized"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + s = obj["scales"][name] + if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: + s = s.to(dtype=torch.float32) + # Broadcast the saved row scale back across trailing dimensions. + out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() + else: + scale = float(s.item()) + out[name] = (q.float() * scale).to(dtype=dtype).contiguous() + for name, t in obj["passthrough"].items(): + # Restore small tensors, undoing the temporary fp16 storage cast if needed. + out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out + + +def dequantize_state_dict_int5_odd(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, packed in obj["packed"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + shape = tuple(int(v) for v in obj["shapes"][name]) + orig_shape = tuple(int(v) for v in obj.get("orig_shapes", {}).get(name, shape)) + q = unpack_base5_codes(packed, int(obj["n_codes"][name])).to(torch.float32).reshape(shape) - 2.0 + s = obj["scales"][name].to(dtype=torch.float32) + deq = (q * s.view(shape[0], *([1] * (len(shape) - 1)))).to(dtype=dtype) + out[name] = deq.reshape(orig_shape).contiguous() + for name, t in obj["passthrough"].items(): + out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out + + +def quantize_state_dict(state_dict: dict[str, Tensor], quant_scheme: str): + if quant_scheme == "int8": + return quantize_state_dict_int8(state_dict) + if quant_scheme == "int5_odd": + return 
quantize_state_dict_int5_odd(state_dict) + raise ValueError(f"Unsupported QUANT_SCHEME={quant_scheme}") + + +def dequantize_state_dict(obj: dict[str, object]) -> dict[str, Tensor]: + quant_format = obj.get("__quant_format__") + if quant_format == "int8_clean_per_row_v1": + return dequantize_state_dict_int8(obj) + if quant_format == "int5_odd_clean_per_row_v1": + return dequantize_state_dict_int5_odd(obj) + raise ValueError(f"Unsupported quant format {quant_format}") + + +def compress_blob(data: bytes, compressor: str, level: int) -> bytes: + if compressor == "zlib": + return zlib.compress(data, level=max(0, min(level, 9))) + if compressor == "zstd": + if not _HAS_ZSTD: + raise RuntimeError("COMPRESSOR=zstd requires zstandard to be installed") + return zstd.ZstdCompressor(level=level).compress(data) + raise ValueError(f"Unsupported COMPRESSOR={compressor}") + + +def decompress_blob(data: bytes, compressor: str) -> bytes: + if compressor == "zlib": + return zlib.decompress(data) + if compressor == "zstd": + if not _HAS_ZSTD: + raise RuntimeError("COMPRESSOR=zstd requires zstandard to be installed") + return zstd.ZstdDecompressor().decompress(data) + raise ValueError(f"Unsupported COMPRESSOR={compressor}") + + +def should_rate_regularize_tensor(name: str, p: Tensor, target: str) -> bool: + if not p.requires_grad or not p.is_floating_point() or p.ndim < 2: + return False + if any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS): + return False + if target == "mlp_only": + return ".mlp." in name + if target == "all_matrices": + return p.numel() > INT8_KEEP_FLOAT_MAX_NUMEL or ".mlp." 
in name
+    raise ValueError(f"Unsupported RATE_TARGET_TENSORS={target}")
+
+
+def estimate_tensor_rate_proxy(t: Tensor, temperature: float) -> tuple[Tensor, Tensor, Tensor]:
+    t2 = t.float()
+    if t2.ndim > 2:
+        t2 = t2.reshape(t2.shape[0], -1)
+    elif t2.ndim == 1:
+        t2 = t2.reshape(1, -1)
+    row_scale = t2.abs().mean(dim=1, keepdim=True).clamp_min(1e-6)
+    normalized = torch.clamp(t2 / row_scale, -2.5, 2.5)
+    centers = t2.new_tensor(INT5_ODD_LEVELS)
+    logits = -temperature * (normalized.unsqueeze(-1) - centers) ** 2
+    probs = torch.softmax(logits, dim=-1)
+    hist = probs.mean(dim=(0, 1))
+    entropy_nats = -(hist * (hist + 1e-8).log()).sum()
+    entropy_bits = entropy_nats / math.log(2.0)
+    scale_proxy = row_scale.mean()
+    est_bits = entropy_bits * float(t2.numel())
+    return entropy_bits, scale_proxy, est_bits
+
+
+# -----------------------------
+# DATA LOADING
+# -----------------------------
+
+def load_data_shard(file: Path) -> Tensor:
+    # Assumes the llm.c-style shard layout: a 256-entry little-endian int32
+    # header ([magic, version, num_tokens, ...]) followed by uint16 token ids.
+    header_bytes = 256 * np.dtype("<i4").itemsize
+    with file.open("rb") as f:
+        header = np.frombuffer(f.read(header_bytes), dtype="<i4")
+        num_tokens = int(header[2])
+        raw = np.frombuffer(f.read(num_tokens * 2), dtype="<u2")
+    # Widen to int32 so downstream torch ops (cat, indexing) are unrestricted.
+    return torch.from_numpy(raw.astype(np.int32, copy=False))
+
+
+class TokenStream:
+    # Endless reader that cycles through the shard files matching `pattern`.
+    def __init__(self, pattern: str):
+        self.files = [Path(p) for p in sorted(glob.glob(pattern))]
+        if not self.files:
+            raise FileNotFoundError(f"No files found for pattern: {pattern}")
+        self.file_idx = 0
+        self.tokens = load_data_shard(self.files[self.file_idx])
+        self.pos = 0
+
+    def _advance_file(self) -> None:
+        self.file_idx = (self.file_idx + 1) % len(self.files)
+        self.tokens = load_data_shard(self.files[self.file_idx])
+        self.pos = 0
+
+    def take(self, n: int) -> Tensor:
+        chunks: list[Tensor] = []
+        remaining = n
+        while remaining > 0:
+            avail = self.tokens.numel() - self.pos
+            if avail <= 0:
+                self._advance_file()
+                continue
+            k = min(remaining, avail)
+            chunks.append(self.tokens[self.pos : self.pos + k])
+            self.pos += k
+            remaining -= k
+        return chunks[0] if len(chunks) == 1 else torch.cat(chunks)
+
+
+class DistributedTokenLoader:
+    # Each call consumes a contiguous chunk from the shared token stream, then slices out
+    # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting.
+ def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank = rank + self.world_size = world_size + self.device = device + self.stream = TokenStream(pattern) + + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + chunk = self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start : start + per_rank_span].to(dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + + +def debug_log_data_sharding( + pattern: str, + world_size: int, + global_tokens: int, + seq_len: int, + grad_accum_steps: int, + steps: int, + log0, +) -> None: + stream = TokenStream(pattern) + local_tokens = global_tokens // (world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + global_cursor = 0 + log0( + f"debug_sharding:world_size:{world_size} global_tokens:{global_tokens} " + f"grad_accum_steps:{grad_accum_steps} local_tokens_per_rank:{local_tokens} per_rank_span:{per_rank_span}" + ) + for step_idx in range(steps): + chunk = stream.take(per_rank_span * world_size) + chunk_start = global_cursor + chunk_end = global_cursor + chunk.numel() + log0(f"debug_sharding_step:{step_idx} shared_chunk_token_range:[{chunk_start},{chunk_end})") + for shard_rank in range(world_size): + start = shard_rank * per_rank_span + end = start + per_rank_span + local = chunk[start:end] + preview = ",".join(str(int(t)) for t in local[: min(6, local.numel())].tolist()) + log0( + f"debug_shard rank:{shard_rank} token_range:[{chunk_start + start},{chunk_start + end}) " + f"preview:{preview}" + ) + global_cursor = chunk_end + +# ----------------------------- +# TRANSFORMER MODULES +# ----------------------------- + +class RMSNorm(nn.Module): + def 
__init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + + +class CastedLinear(nn.Linear): + # Keep weights in fp32 for optimizer/state quality, cast at matmul time for bf16 compute. + def __init__(self, in_features: int, out_features: int, bias: bool = True): + super().__init__(in_features, out_features, bias=bias) + self._cached_weight: Tensor | None = None + self._cached_bias: Tensor | None = None + self._cached_key: tuple[torch.dtype, torch.device, int, int | None] | None = None + + def _cast_params(self, x: Tensor) -> tuple[Tensor, Tensor | None]: + bias_version = self.bias._version if self.bias is not None else None + cache_key = (x.dtype, x.device, self.weight._version, bias_version) + if torch.is_inference_mode_enabled() and self._cached_key == cache_key: + return self._cached_weight, self._cached_bias + weight = self.weight.to(dtype=x.dtype) + bias = self.bias.to(x.dtype) if self.bias is not None else None + if torch.is_inference_mode_enabled(): + self._cached_weight = weight + self._cached_bias = bias + self._cached_key = cache_key + return weight, bias + + def train(self, mode: bool = True): + if mode: + self._cached_weight = None + self._cached_bias = None + self._cached_key = None + return super().train(mode) + + def forward(self, x: Tensor) -> Tensor: + weight, bias = self._cast_params(x) + return F.linear(x, weight, bias) + + +class StructuredLinear(nn.Module): + # Lightweight TT-style structured matrix used only in the MLP stack. 
+ def __init__( + self, + in_features: int, + out_features: int, + bias: bool, + linear_type: str, + btt_rank: int, + in_shape_spec: str, + out_shape_spec: str, + btt_init: str, + btt_init_gain: float, + chunk_ranks: int, + eval_chunk_ranks: int, + ): + super().__init__() + self.in_features = in_features + self.out_features = out_features + self.linear_type = linear_type + self.rank = max(int(btt_rank), 1) + self.chunk_ranks = max(int(chunk_ranks), 1) + self.eval_chunk_ranks = max(int(eval_chunk_ranks), 1) + self.btt_init = btt_init + self.btt_init_gain = btt_init_gain + self._cached_core1: Tensor | None = None + self._cached_core2: Tensor | None = None + self._cached_bias: Tensor | None = None + self._cached_key: tuple[torch.dtype, torch.device, int, int, int | None] | None = None + self._dense_eval_weight: Tensor | None = None + self._dense_eval_bias: Tensor | None = None + self._dense_eval_key: tuple[torch.dtype, torch.device, int, int, int | None] | None = None + if linear_type == "dense": + self.impl = CastedLinear(in_features, out_features, bias=bias) + self.register_parameter("core1", None) + self.register_parameter("core2", None) + return + if linear_type != "btt_mlp": + raise ValueError(f"Unsupported STRUCTURED_LINEAR_TYPE={linear_type}") + self.impl = None + self.in_shape = parse_factor_shape(in_shape_spec, in_features) + self.out_shape = parse_factor_shape(out_shape_spec, out_features) + n1, n2 = self.in_shape + m1, m2 = self.out_shape + self.core1 = nn.Parameter(torch.empty(n1, m1, self.rank)) + self.core2 = nn.Parameter(torch.empty(self.rank, n2, m2)) + if bias: + self.bias = nn.Parameter(torch.zeros(out_features)) + else: + self.register_parameter("bias", None) + self.reset_parameters() + + def reset_parameters(self) -> None: + if self.impl is not None: + self.impl.reset_parameters() + return + n1, n2 = self.in_shape + if self.btt_init == "mup": + # muP-inspired scaling: keep the per-rank branch update stable while + # preserving O(1) output variance 
under the sum over TT ranks. + core1_std = self.btt_init_gain / math.sqrt(max(n1, 1)) + core2_std = self.btt_init_gain / math.sqrt(max(self.rank * n2, 1)) + elif self.btt_init == "xavier": + std = self.btt_init_gain / math.sqrt(max(self.in_features, 1)) + core1_std = std / math.sqrt(self.rank) + core2_std = std + else: + raise ValueError(f"Unsupported BTT_INIT={self.btt_init}") + nn.init.normal_(self.core1, mean=0.0, std=core1_std) + nn.init.normal_(self.core2, mean=0.0, std=core2_std) + if self.bias is not None: + nn.init.zeros_(self.bias) + + def _cast_factors(self, x: Tensor) -> tuple[Tensor, Tensor, Tensor | None]: + bias_version = self.bias._version if self.bias is not None else None + cache_key = (x.dtype, x.device, self.core1._version, self.core2._version, bias_version) + if torch.is_inference_mode_enabled() and self._cached_key == cache_key: + return self._cached_core1, self._cached_core2, self._cached_bias + core1 = self.core1.to(dtype=x.dtype).permute(2, 0, 1).contiguous() + core2 = self.core2.to(dtype=x.dtype).contiguous() + bias = self.bias.to(dtype=x.dtype) if self.bias is not None else None + if torch.is_inference_mode_enabled(): + self._cached_core1 = core1 + self._cached_core2 = core2 + self._cached_bias = bias + self._cached_key = cache_key + return core1, core2, bias + + def materialize(self, dtype: torch.dtype, device: torch.device) -> None: + if self.impl is not None: + return + bias_version = self.bias._version if self.bias is not None else None + cache_key = (dtype, device, self.core1._version, self.core2._version, bias_version) + if self._dense_eval_key == cache_key and self._dense_eval_weight is not None: + return + core1 = self.core1.to(device=device, dtype=dtype) + core2 = self.core2.to(device=device, dtype=dtype) + dense = torch.einsum("umr,rvn->mnuv", core1, core2).reshape(self.out_features, self.in_features).contiguous() + self._dense_eval_weight = dense + self._dense_eval_bias = self.bias.to(device=device, dtype=dtype) if self.bias is 
not None else None + self._dense_eval_key = cache_key + + def train(self, mode: bool = True): + if mode: + self._cached_core1 = None + self._cached_core2 = None + self._cached_bias = None + self._cached_key = None + self._dense_eval_weight = None + self._dense_eval_bias = None + self._dense_eval_key = None + return super().train(mode) + + def forward(self, x: Tensor) -> Tensor: + if self.impl is not None: + return self.impl(x) + if not self.training: + self.materialize(x.dtype, x.device) + return F.linear(x, self._dense_eval_weight, self._dense_eval_bias) + orig_shape = x.shape[:-1] + x2 = x.reshape(-1, self.in_shape[0], self.in_shape[1]) + x2_t = x2.transpose(1, 2).contiguous() + core1, core2, bias = self._cast_factors(x) + y = x2.new_zeros((x2.shape[0], self.out_shape[0], self.out_shape[1])) + for start in range(0, self.rank, self.chunk_ranks): + end = min(start + self.chunk_ranks, self.rank) + for offset, core2_r in enumerate(core2[start:end], start=start): + left = torch.matmul(x2_t, core1[offset]) + y = y + torch.matmul(left.transpose(1, 2), core2_r) + y = y.reshape(*orig_shape, self.out_features) + if bias is not None: + y = y + bias + return y + + +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + # Keep small/control parameters in fp32 even when the model body runs in bf16. + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() + + +class Rotary(nn.Module): + # Caches cos/sin tables per sequence length on the current device. 
+ def __init__(self, dim: int, base: float = 10000.0): + super().__init__() + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached != seq_len + or self._cos_cached.device != device + ): + t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) + freqs = torch.outer(t, self.inv_freq.to(device)) + self._cos_cached = freqs.cos()[None, None, :, :] + self._sin_cached = freqs.sin()[None, None, :, :] + self._seq_len_cached = seq_len + cos = self._cos_cached.to(dtype=dtype) + sin = self._sin_cached.to(dtype=dtype) + if not torch.is_inference_mode_enabled(): + cos = cos.clone() + sin = sin.clone() + return cos, sin + + +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + + +class CausalSelfAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + kv_dim = self.num_kv_heads * self.head_dim + self.c_q = CastedLinear(dim, dim, bias=False) + self.c_k = CastedLinear(dim, kv_dim, bias=False) + self.c_v = CastedLinear(dim, kv_dim, bias=False) + self.proj 
= CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rotary = Rotary(self.head_dim, base=rope_base) + + def forward(self, x: Tensor) -> Tensor: + bsz, seqlen, dim = x.shape + q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) + k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin) + k = apply_rotary_emb(k, cos, sin) + q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] + y = F.scaled_dot_product_attention( + q, + k, + v, + attn_mask=None, + is_causal=True, + enable_gqa=(self.num_kv_heads != self.num_heads), + ) + y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) + return self.proj(y) + + +class MLP(nn.Module): + # relu^2 MLP from the original modded-nanogpt setup + def __init__( + self, + dim: int, + mlp_mult: int, + structured_mlp: bool, + structured_linear_type: str, + btt_rank: int, + btt_in_shape: str, + btt_out_shape: str, + btt_init: str, + btt_init_gain: float, + compile_structured_mlp: bool, + structured_chunk_ranks: int, + structured_eval_chunk_ranks: int, + structured_compile_mode: str, + ): + super().__init__() + hidden = mlp_mult * dim + if structured_mlp: + self.fc = StructuredLinear( + dim, + hidden, + bias=False, + linear_type=structured_linear_type, + btt_rank=btt_rank, + in_shape_spec=btt_in_shape, + out_shape_spec=btt_out_shape, + btt_init=btt_init, + btt_init_gain=btt_init_gain, + chunk_ranks=structured_chunk_ranks, + eval_chunk_ranks=structured_eval_chunk_ranks, + ) + self.proj = StructuredLinear( + hidden, + dim, + bias=False, + linear_type=structured_linear_type, + btt_rank=btt_rank, + 
in_shape_spec=btt_out_shape, + out_shape_spec=btt_in_shape, + btt_init=btt_init, + btt_init_gain=btt_init_gain, + chunk_ranks=structured_chunk_ranks, + eval_chunk_ranks=structured_eval_chunk_ranks, + ) + if compile_structured_mlp: + compile_kwargs = dict(fullgraph=False, dynamic=False) + if structured_compile_mode == "reduce-overhead": + compile_options = dict(torch._inductor.list_mode_options("reduce-overhead")) + compile_options["triton.cudagraphs"] = False + compile_kwargs["options"] = compile_options + else: + compile_kwargs["mode"] = structured_compile_mode + self.fc = torch.compile( + self.fc, + **compile_kwargs, + ) + self.proj = torch.compile( + self.proj, + **compile_kwargs, + ) + else: + self.fc = CastedLinear(dim, hidden, bias=False) + self.proj = CastedLinear(hidden, dim, bias=False) + self.proj._zero_init = True + + def forward(self, x: Tensor) -> Tensor: + x = torch.relu(self.fc(x)) + return self.proj(x.square()) + + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + rope_base: float, + qk_gain_init: float, + structured_mlp: bool, + structured_linear_type: str, + btt_rank: int, + btt_in_shape: str, + btt_out_shape: str, + btt_init: str, + btt_init_gain: float, + compile_structured_mlp: bool, + structured_chunk_ranks: int, + structured_eval_chunk_ranks: int, + structured_compile_mode: str, + ): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) + self.mlp = MLP( + dim, + mlp_mult, + structured_mlp, + structured_linear_type, + btt_rank, + btt_in_shape, + btt_out_shape, + btt_init, + btt_init_gain, + compile_structured_mlp, + structured_chunk_ranks, + structured_eval_chunk_ranks, + structured_compile_mode, + ) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = 
nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + + def forward(self, x: Tensor, x0: Tensor) -> Tensor: + mix = self.resid_mix.to(dtype=x.dtype) + x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + attn_out = self.attn(self.attn_norm(x)) + x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out + x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) + return x + + +class GPT(nn.Module): + def __init__( + self, + vocab_size: int, + num_layers: int, + model_dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + tie_embeddings: bool, + tied_embed_init_std: float, + logit_softcap: float, + rope_base: float, + qk_gain_init: float, + structured_mlp: bool, + structured_linear_type: str, + btt_rank: int, + btt_in_shape: str, + btt_out_shape: str, + btt_init: str, + btt_init_gain: float, + compile_structured_mlp: bool, + structured_chunk_ranks: int, + structured_eval_chunk_ranks: int, + structured_compile_mode: str, + rate_target_tensors: str, + rate_temperature: float, + ): + super().__init__() + if logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") + self.tie_embeddings = tie_embeddings + self.tied_embed_init_std = tied_embed_init_std + self.logit_softcap = logit_softcap + self.rate_target_tensors = rate_target_tensors + self.rate_temperature = rate_temperature + self.tok_emb = nn.Embedding(vocab_size, model_dim) + self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + self.blocks = nn.ModuleList( + [ + Block( + model_dim, + num_heads, + num_kv_heads, + mlp_mult, + rope_base, + qk_gain_init, + structured_mlp, + structured_linear_type, + btt_rank, + btt_in_shape, + btt_out_shape, + btt_init, + btt_init_gain, 
+ compile_structured_mlp, + structured_chunk_ranks, + structured_eval_chunk_ranks, + structured_compile_mode, + ) + for i in range(num_layers) + ] + ) + self.final_norm = RMSNorm() + self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) + if self.lm_head is not None: + self.lm_head._zero_init = True + self._init_weights() + + def _init_weights(self) -> None: + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + for module in self.modules(): + if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + if isinstance(module, StructuredLinear) and getattr(module, "_zero_init", False): + if module.impl is not None: + nn.init.zeros_(module.impl.weight) + else: + nn.init.zeros_(module.core2) + + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + x = self.tok_emb(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x0 = x + skips: list[Tensor] = [] + + # First half stores skips; second half reuses them in reverse order. 
+ for i in range(self.num_encoder_layers): + x = self.blocks[i](x, x0) + skips.append(x) + for i in range(self.num_decoder_layers): + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + x = self.blocks[self.num_encoder_layers + i](x, x0) + + x = self.final_norm(x).reshape(-1, x.size(-1)) + targets = target_ids.reshape(-1) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(x) + logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + return F.cross_entropy(logits.float(), targets, reduction="mean") + + def rate_regularizer(self) -> tuple[Tensor, Tensor, Tensor]: + total_symbols = 0.0 + weighted_rate = None + weighted_scale = None + estimated_bits = None + for name, param in self.named_parameters(): + if not should_rate_regularize_tensor(name, param, self.rate_target_tensors): + continue + rate_proxy, scale_proxy, tensor_bits = estimate_tensor_rate_proxy(param, self.rate_temperature) + weight = float(param.numel()) + total_symbols += weight + if weighted_rate is None: + weighted_rate = rate_proxy * weight + weighted_scale = scale_proxy * weight + estimated_bits = tensor_bits + else: + weighted_rate = weighted_rate + rate_proxy * weight + weighted_scale = weighted_scale + scale_proxy * weight + estimated_bits = estimated_bits + tensor_bits + if weighted_rate is None or weighted_scale is None or estimated_bits is None: + zero = self.tok_emb.weight.new_zeros(()) + return zero, zero, zero + denom = max(total_symbols, 1.0) + return weighted_rate / denom, weighted_scale / denom, estimated_bits + + +# ----------------------------- +# TRAINING +# ----------------------------- + +def main() -> None: + global zeropower_via_newtonschulz5 + + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + if args.enable_compile: + 
zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) + + # ----------------------------- + # DISTRIBUTED + CUDA SETUP + # ----------------------------- + + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + simulated_world_size = args.simulated_world_size if not distributed and args.simulated_world_size > 0 else 0 + if simulated_world_size > 0: + world_size = simulated_world_size + rank = args.simulated_rank + local_rank = 0 + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if rank < 0 or rank >= world_size: + raise ValueError(f"RANK={rank} must satisfy 0 <= RANK < WORLD_SIZE={world_size}") + if simulated_world_size > 1 and not args.dry_run_init_only: + raise ValueError("SIMULATED_WORLD_SIZE>1 is only supported together with DRY_RUN_INIT_ONLY=1") + if args.train_batch_tokens <= 0: + raise ValueError(f"TRAIN_BATCH_TOKENS must be positive, got {args.train_batch_tokens}") + train_microbatch_tokens = args.train_microbatch_tokens if args.train_microbatch_tokens > 0 else max( + args.train_seq_len, args.train_batch_tokens // 8 + ) + if train_microbatch_tokens <= 0: + raise ValueError(f"TRAIN_MICROBATCH_TOKENS must be positive, got {train_microbatch_tokens}") + if train_microbatch_tokens % args.train_seq_len != 0: + raise ValueError( + f"TRAIN_MICROBATCH_TOKENS={train_microbatch_tokens} must be divisible by TRAIN_SEQ_LEN={args.train_seq_len}" + ) + tokens_per_microstep = world_size * train_microbatch_tokens + if args.grad_accum_steps > 0: + grad_accum_steps = args.grad_accum_steps + effective_train_batch_tokens = tokens_per_microstep * grad_accum_steps + else: + if args.train_batch_tokens % tokens_per_microstep != 0: + raise ValueError( + "TRAIN_BATCH_TOKENS must be divisible by WORLD_SIZE * TRAIN_MICROBATCH_TOKENS " + f"unless GRAD_ACCUM_STEPS is 
set explicitly; got TRAIN_BATCH_TOKENS={args.train_batch_tokens}, " + f"WORLD_SIZE={world_size}, TRAIN_MICROBATCH_TOKENS={train_microbatch_tokens}" + ) + grad_accum_steps = args.train_batch_tokens // tokens_per_microstep + effective_train_batch_tokens = args.train_batch_tokens + if grad_accum_steps <= 0: + raise ValueError(f"GRAD_ACCUM_STEPS must be positive, got {grad_accum_steps}") + grad_scale = 1.0 / grad_accum_steps + batch_multiplier = effective_train_batch_tokens / max(args.lr_reference_batch_tokens, 1) + effective_warmup_steps = scaled_warmup_steps( + args.warmup_steps, + effective_train_batch_tokens / max(args.warmup_reference_batch_tokens, 1), + args.warmup_scale_method, + ) + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 0 + + # Fast math knobs + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + + logfile = None + if master_process: + if args.run_id == "train": + logfile = "train.log" + else: + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" + print(logfile) + + def log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch {torch.__version__}", console=False) + log0( + subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, 
check=False).stdout, + console=False, + ) + log0("=" * 100, console=False) + + # ----------------------------- + # TOKENIZER + VALIDATION METRIC SETUP + # ----------------------------- + + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) + if args.val_token_limit > 0: + usable = min(((args.val_token_limit) // args.train_seq_len) * args.train_seq_len, val_tokens.numel() - 1) + if usable <= 0: + raise ValueError(f"VAL_TOKEN_LIMIT={args.val_token_limit} is too small for TRAIN_SEQ_LEN={args.train_seq_len}") + val_tokens = val_tokens[: usable + 1].contiguous() + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + if args.val_token_limit > 0: + log0(f"val_token_limit:{args.val_token_limit}") + + # ----------------------------- + # MODEL + OPTIMIZER SETUP + # ----------------------------- + + base_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + 
tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, + structured_mlp=args.structured_mlp, + structured_linear_type=args.structured_linear_type if args.structured_mlp else "dense", + btt_rank=args.btt_rank, + btt_in_shape=args.btt_in_shape, + btt_out_shape=args.btt_out_shape, + btt_init=args.btt_init, + btt_init_gain=args.btt_init_gain, + compile_structured_mlp=args.compile_structured_mlp, + structured_chunk_ranks=args.structured_chunk_ranks, + structured_eval_chunk_ranks=args.structured_eval_chunk_ranks, + structured_compile_mode=args.structured_compile_mode, + rate_target_tensors=args.rate_target_tensors, + rate_temperature=args.rate_temperature, + ).to(device).bfloat16() + for module in base_model.modules(): + if isinstance(module, (CastedLinear, StructuredLinear)): + module.float() + restore_low_dim_params_to_fp32(base_model) + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) if args.enable_compile else base_model + model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + + # Optimizer split: + # - token embedding (Adam) uses EMBED_LR + # - untied lm_head (Adam) uses HEAD_LR + # - matrix params in transformer blocks use MATRIX_LR via Muon + # - vectors/scalars use SCALAR_LR via Adam + block_named_params = list(base_model.blocks.named_parameters()) + matrix_params = [ + p + for name, p in block_named_params + if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + scalar_params = [ + p + for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + scaled_token_lr = 
scale_from_batch_multiplier(token_lr, batch_multiplier, args.lr_scale_method) + scaled_head_lr = scale_from_batch_multiplier(args.head_lr, batch_multiplier, args.lr_scale_method) + scaled_matrix_lr = scale_from_batch_multiplier(args.matrix_lr, batch_multiplier, args.lr_scale_method) + scaled_scalar_lr = scale_from_batch_multiplier(args.scalar_lr, batch_multiplier, args.lr_scale_method) + optimizer_tok = torch.optim.Adam( + [{"params": [base_model.tok_emb.weight], "lr": scaled_token_lr, "base_lr": scaled_token_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizer_muon = Muon( + matrix_params, + lr=scaled_matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + ) + for group in optimizer_muon.param_groups: + group["base_lr"] = scaled_matrix_lr + optimizer_scalar = torch.optim.Adam( + [{"params": scalar_params, "lr": scaled_scalar_lr, "base_lr": scaled_scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": scaled_head_lr, "base_lr": scaled_head_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers.insert(1, optimizer_head) + + n_params = sum(p.numel() for p in base_model.parameters()) + log0(f"model_params:{n_params}") + log0( + f"world_size:{world_size} grad_accum_steps:{grad_accum_steps} " + f"train_microbatch_tokens:{train_microbatch_tokens} effective_train_batch_tokens:{effective_train_batch_tokens}" + ) + if simulated_world_size > 0: + log0(f"simulated_world_size:{simulated_world_size} simulated_rank:{rank}") + log0(f"enable_compile:{args.enable_compile}") + log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") + log0(f"attention_mode:gqa num_heads:{args.num_heads} 
num_kv_heads:{args.num_kv_heads}") + log0( + f"quant_scheme:{args.quant_scheme} compressor:{args.compressor} level:{args.compressor_level} " + f"structured_mlp:{args.structured_mlp} structured_linear_type:{args.structured_linear_type} " + f"btt_rank:{args.btt_rank} btt_init:{args.btt_init} compile_structured_mlp:{args.compile_structured_mlp} " + f"structured_chunk_ranks:{args.structured_chunk_ranks} structured_eval_chunk_ranks:{args.structured_eval_chunk_ranks} " + f"structured_compile_mode:{args.structured_compile_mode}" + ) + log0( + f"rate_lambda:{args.rate_lambda} scale_lambda:{args.scale_lambda} " + f"rate_warmup_start_fraction:{args.rate_warmup_start_fraction} rate_full_fraction:{args.rate_full_fraction} " + f"rate_target_tensors:{args.rate_target_tensors} rate_schedule_mode:{args.rate_schedule_mode} " + f"max_expected_steps:{args.max_expected_steps}" + ) + log0( + f"tie_embeddings:{args.tie_embeddings} embed_lr:{scaled_token_lr} " + f"head_lr:{scaled_head_lr if base_model.lm_head is not None else 0.0} " + f"matrix_lr:{scaled_matrix_lr} scalar_lr:{scaled_scalar_lr}" + ) + log0( + f"train_batch_tokens_target:{args.train_batch_tokens} effective_train_batch_tokens:{effective_train_batch_tokens} " + f"train_seq_len:{args.train_seq_len} iterations:{args.iterations} warmup_steps:{effective_warmup_steps} " + f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log0( + f"lr_scale_method:{args.lr_scale_method} lr_reference_batch_tokens:{args.lr_reference_batch_tokens} " + f"batch_multiplier:{batch_multiplier:.4f} warmup_scale_method:{args.warmup_scale_method} " + f"warmup_reference_batch_tokens:{args.warmup_reference_batch_tokens}" + ) + log0(f"seed:{args.seed}") + + if args.debug_data_sharding_steps > 0: + debug_log_data_sharding( + args.train_files, + world_size, + effective_train_batch_tokens, + args.train_seq_len, + grad_accum_steps, + args.debug_data_sharding_steps, + log0, + ) + if args.dry_run_init_only: + log0("dry_run_init_only: exiting after 
init/logging") + return + + # ----------------------------- + # DATA LOADER & MODEL WARMUP + # ----------------------------- + + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + + def lr_mul(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + + def rate_progress(step: int, elapsed_ms: float) -> float: + if args.rate_schedule_mode == "steps": + if args.iterations <= 0: + return 1.0 + return step / max(args.iterations, 1) + if args.rate_schedule_mode == "expected_steps": + if args.max_expected_steps <= 0: + raise ValueError("MAX_EXPECTED_STEPS must be positive when RATE_SCHEDULE_MODE=expected_steps") + return step / max(args.max_expected_steps, 1) + if args.rate_schedule_mode == "wallclock": + if max_wallclock_ms is None: + raise ValueError("RATE_SCHEDULE_MODE=wallclock requires MAX_WALLCLOCK_SECONDS > 0") + return elapsed_ms / max(max_wallclock_ms, 1e-9) + if args.rate_schedule_mode != "auto": + raise ValueError(f"Unsupported RATE_SCHEDULE_MODE={args.rate_schedule_mode}") + if args.max_expected_steps > 0: + return step / max(args.max_expected_steps, 1) + if max_wallclock_ms is not None and step > 0: + step_ms = elapsed_ms / max(step, 1) + expected_steps = max(int(max_wallclock_ms / max(step_ms, 1e-9)), 1) + return step / expected_steps + if args.iterations <= 0: + 
return 1.0 + return step / max(args.iterations, 1) + + def rate_schedule(step: int, elapsed_ms: float) -> float: + if args.rate_lambda <= 0.0 and args.scale_lambda <= 0.0: + return 0.0 + frac = rate_progress(step, elapsed_ms) + start = args.rate_warmup_start_fraction + full = max(args.rate_full_fraction, start + 1e-6) + if frac <= start: + return 0.0 + if frac >= full: + return 1.0 + return (frac - start) / (full - start) + + # Warmup primes the compiled forward/backward/optimizer paths, then we restore the + # initial weights/optimizer state so measured training starts from the true init. + if effective_warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(effective_warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(effective_train_batch_tokens, args.train_seq_len, grad_accum_steps) + maybe_mark_compile_step_begin(args.enable_compile or args.compile_structured_mlp) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + for opt in optimizers: + opt.step() + zero_grad_all() + if ( + effective_warmup_steps <= 20 + or (warmup_step + 1) % 10 == 0 + or warmup_step + 1 == effective_warmup_steps + ): + log0(f"warmup_step:{warmup_step + 1}/{effective_warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + if distributed: + model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + # ----------------------------- + # 
MAIN TRAINING LOOP + # ----------------------------- + + training_time_ms = 0.0 + stop_after_step: int | None = None + torch.cuda.synchronize() + t0 = time.perf_counter() + + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + log0( + f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" + ) + break + + model.train() + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + zero_grad_all() + train_loss = torch.zeros((), device=device) + train_ce_loss = torch.zeros((), device=device) + train_rate_proxy = torch.zeros((), device=device) + train_scale_proxy = torch.zeros((), device=device) + train_est_bits = torch.zeros((), device=device) + rate_mul = rate_schedule(step, elapsed_ms) + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(effective_train_batch_tokens, args.train_seq_len, grad_accum_steps) + maybe_mark_compile_step_begin(args.enable_compile or args.compile_structured_mlp) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + ce_loss = model(x, y) + 
total_loss = ce_loss + if rate_mul > 0.0: + rate_proxy, scale_proxy, est_bits = base_model.rate_regularizer() + total_loss = total_loss + rate_mul * (args.rate_lambda * rate_proxy + args.scale_lambda * scale_proxy) + train_rate_proxy += rate_proxy.detach() + train_scale_proxy += scale_proxy.detach() + train_est_bits += est_bits.detach() + train_ce_loss += ce_loss.detach() + train_loss += total_loss.detach() + (total_loss * grad_scale).backward() + train_loss /= grad_accum_steps + train_ce_loss /= grad_accum_steps + train_rate_proxy /= grad_accum_steps + train_scale_proxy /= grad_accum_steps + train_est_bits /= grad_accum_steps + + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + for opt in optimizers: + opt.step() + zero_grad_all() + + step += 1 + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + should_log_train = ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + log0( + f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} ce_loss:{train_ce_loss.item():.4f} " + f"rate_mul:{rate_mul:.3f} rate_proxy_bits:{train_rate_proxy.item():.4f} " + f"scale_proxy:{train_scale_proxy.item():.6f} est_model_bits:{train_est_bits.item():.0f} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" + ) + + # Needed to sync whether we've reached the wallclock cap. 
+ reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + + log0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + ) + + # ----------------------------- + # SERIALIZATION + ROUNDTRIP VALIDATION + # ----------------------------- + # Save the raw state (useful for debugging/loading in PyTorch directly), then always produce + # the compressed quantized artifact (args.quant_scheme + args.compressor, e.g. int5_odd+zlib) + # and validate the round-tripped weights. + + if master_process: + torch.save(base_model.state_dict(), "final_model.pt") + model_bytes = os.path.getsize("final_model.pt") + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model: {model_bytes} bytes") + log0(f"Code size: {code_bytes} bytes") + log0(f"Total submission size: {model_bytes + code_bytes} bytes") + + quant_obj, quant_stats = quantize_state_dict(base_model.state_dict(), args.quant_scheme) + quant_buf = io.BytesIO() + torch.save(quant_obj, quant_buf) + quant_raw = quant_buf.getvalue() + quant_blob = compress_blob(quant_raw, args.compressor, args.compressor_level) + quant_raw_bytes = len(quant_raw) + quant_filename = f"final_model.{args.quant_scheme}.{args.compressor}.ptc" + if master_process: + with open(quant_filename, "wb") as f: + f.write(quant_blob) + quant_file_bytes = os.path.getsize(quant_filename) + code_bytes = len(code.encode("utf-8")) + ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1) + log0( + f"Serialized model {args.quant_scheme}+{args.compressor}: {quant_file_bytes} bytes " + f"(payload:{quant_stats['int8_payload_bytes']} raw_torch:{quant_raw_bytes}
payload_ratio:{ratio:.2f}x)" + ) + log0(f"Total submission size {args.quant_scheme}+{args.compressor}: {quant_file_bytes + code_bytes} bytes") + + if distributed: + dist.barrier() + with open(quant_filename, "rb") as f: + quant_blob_disk = f.read() + quant_state = torch.load(io.BytesIO(decompress_blob(quant_blob_disk, args.compressor)), map_location="cpu") + base_model.load_state_dict(dequantize_state_dict(quant_state), strict=True) + torch.cuda.synchronize() + t_qeval = time.perf_counter() + q_val_loss, q_val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"final_{args.quant_scheme}_{args.compressor}_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" + ) + log0(f"final_{args.quant_scheme}_{args.compressor}_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + + if distributed: + dist.destroy_process_group() + + +if __name__ == "__main__": + main() + +==================================================================================================== +Running Python 3.12.3 (main, Nov 6 2025, 13:44:16) [GCC 13.3.0] +Running PyTorch 2.9.1+cu128 +Tue Mar 31 19:02:53 2026 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 580.126.09 Driver Version: 580.126.09 CUDA Version: 13.0 | ++-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. 
| +|=========================================+========================+======================| +| 0 NVIDIA H100 80GB HBM3 On | 00000000:CB:00.0 Off | 0 | +| N/A 33C P0 76W / 700W | 527MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| No running processes found | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=/workspace/parameter-golf/data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:80 +val_loader:shards pattern=/workspace/parameter-golf/data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:1048576 +val_token_limit:1048576 +model_params:25727104 +world_size:1 grad_accum_steps:8 train_microbatch_tokens:2048 effective_train_batch_tokens:16384 +enable_compile:False +sdp_backends:cudnn=False flash=True mem_efficient=False math=False +attention_mode:gqa num_heads:8 num_kv_heads:4 +quant_scheme:int5_odd compressor:zlib level:9 structured_mlp:True structured_linear_type:btt_mlp btt_rank:256 btt_init:mup compile_structured_mlp:False structured_chunk_ranks:64 structured_eval_chunk_ranks:1 structured_compile_mode:reduce-overhead +rate_lambda:2e-05 scale_lambda:0.0002 rate_warmup_start_fraction:0.1 rate_full_fraction:0.4 rate_target_tensors:mlp_only rate_schedule_mode:auto max_expected_steps:0 +tie_embeddings:True embed_lr:0.03 head_lr:0.0 matrix_lr:0.02 scalar_lr:0.02 +train_batch_tokens_target:16384 effective_train_batch_tokens:16384 train_seq_len:1024 iterations:0 
warmup_steps:0 max_wallclock_seconds:0.000 +lr_scale_method:none lr_reference_batch_tokens:16384 batch_multiplier:1.0000 warmup_scale_method:none warmup_reference_batch_tokens:16384 +seed:1337 +step:0/0 val_loss:6.9379 val_bpb:4.1568 train_time:0ms step_avg:0.02ms +peak memory allocated: 22735 MiB reserved: 28940 MiB +Serialized model: 101929819 bytes +Code size: 83124 bytes +Total submission size: 102012943 bytes +Serialized model int5_odd+zlib: 4207457 bytes (payload:8780523 raw_torch:8851579 payload_ratio:11.60x) +Total submission size int5_odd+zlib: 4290581 bytes +final_int5_odd_zlib_roundtrip val_loss:6.9416 val_bpb:4.1590 eval_time:851ms +final_int5_odd_zlib_roundtrip_exact val_loss:6.94156265 val_bpb:4.15901457 diff --git a/records/track_non_record_16mb/2026-03-31_EntropyAware_Int5Odd_BTTMLP_3060/logs/eval_bench_r256_l16_seq1024_cache.txt b/records/track_non_record_16mb/2026-03-31_EntropyAware_Int5Odd_BTTMLP_3060/logs/eval_bench_r256_l16_seq1024_cache.txt new file mode 100644 index 000000000..d32c45e65 --- /dev/null +++ b/records/track_non_record_16mb/2026-03-31_EntropyAware_Int5Odd_BTTMLP_3060/logs/eval_bench_r256_l16_seq1024_cache.txt @@ -0,0 +1,1898 @@ +logs/eval_bench_r256_l16_seq1024_cache.txt +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=/workspace/parameter-golf/data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:80 +val_loader:shards pattern=/workspace/parameter-golf/data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:1048576 +val_token_limit:1048576 +model_params:25727104 +world_size:1 grad_accum_steps:8 train_microbatch_tokens:2048 effective_train_batch_tokens:16384 +enable_compile:False +sdp_backends:cudnn=False flash=True mem_efficient=False math=False +attention_mode:gqa num_heads:8 num_kv_heads:4 +quant_scheme:int5_odd compressor:zlib level:9 structured_mlp:True structured_linear_type:btt_mlp btt_rank:256 btt_init:mup compile_structured_mlp:False structured_chunk_ranks:64 
structured_eval_chunk_ranks:1 structured_compile_mode:reduce-overhead +rate_lambda:2e-05 scale_lambda:0.0002 rate_warmup_start_fraction:0.1 rate_full_fraction:0.4 rate_target_tensors:mlp_only rate_schedule_mode:auto max_expected_steps:0 +tie_embeddings:True embed_lr:0.03 head_lr:0.0 matrix_lr:0.02 scalar_lr:0.02 +train_batch_tokens_target:16384 effective_train_batch_tokens:16384 train_seq_len:1024 iterations:0 warmup_steps:0 max_wallclock_seconds:0.000 +lr_scale_method:none lr_reference_batch_tokens:16384 batch_multiplier:1.0000 warmup_scale_method:none warmup_reference_batch_tokens:16384 +seed:1337 +step:0/0 val_loss:6.9379 val_bpb:4.1568 train_time:0ms step_avg:0.02ms +peak memory allocated: 2822 MiB reserved: 3110 MiB +Serialized model: 101939739 bytes +Code size: 81599 bytes +Total submission size: 102021338 bytes +Serialized model int5_odd+zlib: 7256271 bytes (payload:14157417 raw_torch:14260195 payload_ratio:11.71x) +Total submission size int5_odd+zlib: 7337870 bytes +final_int5_odd_zlib_roundtrip val_loss:6.9416 val_bpb:4.1590 eval_time:99311ms +final_int5_odd_zlib_roundtrip_exact val_loss:6.94156116 val_bpb:4.15901368 +s.environ.get("SEED", 1337)) + + # Validation cadence and batch size. Validation uses the fineweb_val split, truncated to VAL_TOKEN_LIMIT tokens when the limit is nonzero. + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 32_768)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 10)) + val_token_limit = int(os.environ.get("VAL_TOKEN_LIMIT", 1_048_576)) + + # Training length.
+ iterations = int(os.environ.get("ITERATIONS", 40)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 1200)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 1)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 16_384)) + train_microbatch_tokens = int(os.environ.get("TRAIN_MICROBATCH_TOKENS", 0)) + grad_accum_steps = int(os.environ.get("GRAD_ACCUM_STEPS", 0)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 256)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + enable_compile = bool(int(os.environ.get("ENABLE_COMPILE", "0"))) + + # Model shape. + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers = int(os.environ.get("NUM_LAYERS", 9)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = int(os.environ.get("MLP_MULT", 2)) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + + # Optimizer hyperparameters. 
+ embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.03)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.02)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.02)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.0)) + lr_scale_method = os.environ.get("LR_SCALE_METHOD", "none") + lr_reference_batch_tokens = int(os.environ.get("LR_REFERENCE_BATCH_TOKENS", 16_384)) + warmup_scale_method = os.environ.get("WARMUP_SCALE_METHOD", "none") + warmup_reference_batch_tokens = int(os.environ.get("WARMUP_REFERENCE_BATCH_TOKENS", 16_384)) + simulated_world_size = int(os.environ.get("SIMULATED_WORLD_SIZE", 0)) + simulated_rank = int(os.environ.get("SIMULATED_RANK", 0)) + dry_run_init_only = bool(int(os.environ.get("DRY_RUN_INIT_ONLY", "0"))) + debug_data_sharding_steps = int(os.environ.get("DEBUG_DATA_SHARDING_STEPS", 0)) + + # Export / compression. + quant_scheme = os.environ.get("QUANT_SCHEME", "int5_odd") + compressor = os.environ.get("COMPRESSOR", "zlib") + compressor_level = int(os.environ.get("COMPRESSOR_LEVEL", 9)) + + # Entropy-aware rate regularization. 
+ rate_lambda = float(os.environ.get("RATE_LAMBDA", 0.00002)) + scale_lambda = float(os.environ.get("SCALE_LAMBDA", 0.0002)) + rate_temperature = float(os.environ.get("RATE_TEMPERATURE", 6.0)) + rate_warmup_start_fraction = float(os.environ.get("RATE_WARMUP_START_FRACTION", 0.1)) + rate_full_fraction = float(os.environ.get("RATE_FULL_FRACTION", 0.4)) + rate_target_tensors = os.environ.get("RATE_TARGET_TENSORS", "mlp_only") + rate_schedule_mode = os.environ.get("RATE_SCHEDULE_MODE", "auto") + max_expected_steps = int(os.environ.get("MAX_EXPECTED_STEPS", 0)) + + # Structured MLP. + structured_mlp = bool(int(os.environ.get("STRUCTURED_MLP", "1"))) + structured_linear_type = os.environ.get("STRUCTURED_LINEAR_TYPE", "btt_mlp") + btt_rank = int(os.environ.get("BTT_RANK", 32)) + btt_in_shape = os.environ.get("BTT_IN_SHAPE", "") + btt_out_shape = os.environ.get("BTT_OUT_SHAPE", "") + btt_init = os.environ.get("BTT_INIT", "mup") + btt_init_gain = float(os.environ.get("BTT_INIT_GAIN", 1.0)) + compile_structured_mlp = bool(int(os.environ.get("COMPILE_STRUCTURED_MLP", "1"))) + structured_chunk_ranks = int(os.environ.get("STRUCTURED_CHUNK_RANKS", 64)) + structured_eval_chunk_ranks = int(os.environ.get("STRUCTURED_EVAL_CHUNK_RANKS", 4)) + structured_compile_mode = os.environ.get("STRUCTURED_COMPILE_MODE", "reduce-overhead") + +# ----------------------------- +# MUON OPTIMIZER +# ----------------------------- +# +# As borrowed from modded-nanogpt +# Background on Muon: https://kellerjordan.github.io/posts/muon/ + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + # Orthogonalize a 2D update matrix with a fast Newton-Schulz iteration. + # Muon uses this to normalize matrix-shaped gradients before applying them. 
+ a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X + + +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True): + super().__init__( + params, + dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov), + ) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + + for group in self.param_groups: + params = group["params"] + if not params: + continue + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + + total_params = sum(int(p.numel()) for p in params) + updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) + + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: + g = g.add(buf, alpha=momentum) + g = zeropower_via_newtonschulz5(g, steps=backend_steps) + # Scale correction from Muon reference implementations. 
+ g *= max(1, g.size(0) / g.size(1)) ** 0.5 + updates_flat[curr : curr + p.numel()] = g.reshape(-1) + curr += p.numel() + + if distributed: + dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) + + curr = 0 + for p in params: + g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) + p.add_(g, alpha=-lr) + curr += p.numel() + + return loss + + +# ----------------------------- +# TOKENIZER-AGNOSTIC EVALUATION SETUP +# ----------------------------- +# +# It's common for small models to have a large fraction of their parameters in embeddings, since the 2 * d_model * d_vocab embedding parameters can be gigantic. +# Instead of locking the tokenizer, we let you bring your own and calculate our validation metrics on the average compression of the validation set. +# We calculate BPB (bits-per-byte) instead of validation loss, so we need methods to count the number of bytes per token in the tokenizer. +# Note: Submissions that edit the tokenizer will be examined more carefully, since screwing this up might unjustly improve your score.
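The BPB bookkeeping reduces to simple arithmetic: cross-entropy in nats becomes bits per token after dividing by ln 2, and multiplying by the tokenizer's tokens-per-byte ratio gives bits per byte. A minimal sketch, where `val_loss`, `token_count`, and `byte_count` are made-up illustrative numbers rather than values from this run:

```python
import math

# Sketch of the val_bpb arithmetic used by eval_val; all numbers are hypothetical.
val_loss = 6.9416          # mean token cross-entropy in nats (hypothetical)
token_count = 1_048_576.0  # validation tokens scored (hypothetical)
byte_count = 2_500_000.0   # UTF-8 bytes those tokens decode to (hypothetical)

bits_per_token = val_loss / math.log(2.0)   # nats -> bits
tokens_per_byte = token_count / byte_count  # tokenizer compression ratio
val_bpb = bits_per_token * tokens_per_byte  # bits-per-byte metric
print(f"val_bpb:{val_bpb:.4f}")
```

A stronger tokenizer lowers `tokens_per_byte`, so BPB rewards tokenizer quality as well as model quality, which is why the byte counts come from the lookup tables rather than from the token count alone.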
+ +def build_sentencepiece_luts( + sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device +) -> tuple[Tensor, Tensor, Tensor]: + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("▁"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) + + +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*. 
+ tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] + + +def eval_val( + args: Hyperparameters, + model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + grad_accum_steps: int, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, +) -> tuple[float, float]: + # Validation computes two metrics: + # - val_loss: token cross-entropy (natural log) + # - val_bpb: tokenizer-agnostic compression metric used by the challenge + local_batch_tokens = args.val_batch_size // world_size + if local_batch_tokens < args.train_seq_len: + raise ValueError( + "VAL_BATCH_SIZE must provide at least one sequence per rank; " + f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " + f"TRAIN_SEQ_LEN={args.train_seq_len}" + ) + local_batch_seqs = local_batch_tokens // args.train_seq_len + total_seqs = (val_tokens.numel() - 1) // args.train_seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + + model.eval() + with torch.inference_mode(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * args.train_seq_len + raw_end = batch_seq_end * args.train_seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, args.train_seq_len) + y = local[1:].reshape(-1, args.train_seq_len) + maybe_mark_compile_step_begin(args.enable_compile or 
args.compile_structured_mlp) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + batch_loss = model(x, y).detach() + batch_token_count = float(y.numel()) + val_loss_sum += batch_loss.to(torch.float64) * batch_token_count + val_token_count += batch_token_count + prev_ids = x.reshape(-1) + tgt_ids = y.reshape(-1) + token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + + val_loss = val_loss_sum / val_token_count + bits_per_token = val_loss.item() / math.log(2.0) + tokens_per_byte = val_token_count.item() / val_byte_count.item() + model.train() + return float(val_loss.item()), float(bits_per_token * tokens_per_byte) + +# ----------------------------- +# POST-TRAINING QUANTIZATION +# ----------------------------- +# +# It's silly to export our model, which is trained in bf16 and fp32, at that same precision. +# Instead, we get approximately the same model (with a small hit) by quantizing it to int8 and zlib-compressing the result. +# We can then decompress the model and run in higher precision for evaluation, after coming in under the size limit.
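The per-row round trip at the heart of this export can be sketched in a few lines. This is a simplified illustration, not the exported format: it uses plain abs-max row scales with no percentile clipping, fp32 scale storage, and made-up tensor shapes.

```python
import torch

# Simplified per-row int8 round trip (abs-max scale, no percentile clipping).
torch.manual_seed(0)
w = torch.randn(4, 8)
scale = (w.abs().amax(dim=1) / 127.0).clamp_min(1e-8)  # one scale per row
q = torch.clamp(torch.round(w / scale[:, None]), -127, 127).to(torch.int8)
w_hat = q.float() * scale[:, None]                     # dequantize
# Round-to-nearest error is at most half a quantization step per element.
assert (w - w_hat).abs().max().item() <= scale.max().item() / 2 + 1e-6
```

Per-row scales matter because output channels in trained matrices often differ in magnitude by an order of magnitude or more; a single tensor-wide scale would waste most of the int8 range on the quiet rows.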
+ +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights", + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 +INT5_ODD_LEVELS = (-2, -1, 0, 1, 2) + +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) + + +def maybe_mark_compile_step_begin(enabled: bool) -> None: + if not enabled: + return + mark_step = getattr(torch.compiler, "cudagraph_mark_step_begin", None) + if mark_step is not None: + mark_step() + + +def scale_from_batch_multiplier(value: float, multiplier: float, method: str) -> float: + if method == "none": + return value + if method == "sqrt": + return value * math.sqrt(multiplier) + if method == "linear": + return value * multiplier + raise ValueError(f"Unsupported scaling method: {method}") + + +def scaled_warmup_steps(base_steps: int, multiplier: float, method: str) -> int: + if base_steps <= 0: + return 0 + scaled = scale_from_batch_multiplier(float(base_steps), multiplier, method) + return max(int(math.ceil(scaled)), 1) + + +def parse_factor_shape(spec: str, dim: int) -> tuple[int, int]: + if spec: + vals = tuple(int(v.strip()) for v in spec.split(",") if v.strip()) + if len(vals) != 2 or vals[0] * vals[1] != dim: + raise ValueError(f"Invalid factor shape '{spec}' for dim={dim}") + return vals + for a in range(int(math.isqrt(dim)), 0, -1): + if dim % a == 0: + return a, dim // a + return 1, dim + +def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> 
Tensor: + if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): + return t.float().contiguous() + if t.dtype in {torch.float32, torch.bfloat16}: + passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") + return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() + return t + +def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + # Matrices get one scale per row, which usually tracks output-channel + # ranges much better than a single tensor-wide scale. + clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) + q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() + return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + + # Vectors / scalars use a simpler per-tensor scale. 
+ clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() + return q, scale + + +def pack_base5_codes(q: Tensor) -> tuple[bytes, int]: + vals = q.detach().to("cpu", dtype=torch.uint8).reshape(-1).numpy() + n_vals = int(vals.size) + pad = (3 - (n_vals % 3)) % 3 + if pad: + vals = np.pad(vals, (0, pad)) + groups = vals.reshape(-1, 3).astype(np.uint8) + packed = groups[:, 0] + 5 * groups[:, 1] + 25 * groups[:, 2] + return packed.tobytes(), n_vals + + +def unpack_base5_codes(data: bytes, n_vals: int) -> Tensor: + packed = np.frombuffer(data, dtype=np.uint8).astype(np.int16) + out = np.zeros((packed.size, 3), dtype=np.uint8) + for i in range(3): + out[:, i] = packed % 5 + packed //= 5 + flat = out.reshape(-1)[:n_vals] + return torch.from_numpy(flat.astype(np.uint8, copy=False)) + + +def quantize_float_tensor_int5_odd(t: Tensor) -> tuple[Tensor, Tensor]: + if t.ndim != 2: + raise ValueError("int5_odd quantization only supports 2D tensors") + t32 = t.float() + clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clip_abs = clip_abs.clamp_min(1e-6) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 2.0).clamp_min(1.0 / 2.0e6) + q = torch.clamp(torch.round(clipped / scale[:, None]), -2, 2).to(torch.int8).contiguous() + return (q + 2).to(torch.uint8).contiguous(), scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + +def quantize_state_dict_int8(state_dict: dict[str, Tensor]): + # Single supported clean-script export format: + # - per-row int8 for 2D float tensors + # - per-tensor int8 for other float tensors + # - exact passthrough for non-floats + # - passthrough for small float tensors, 
stored as fp16 to save bytes + quantized: dict[str, Tensor] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + qmeta: dict[str, dict[str, object]] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + + # Small float tensors are cheap enough to keep directly. We still downcast + # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size. + if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + + stats["num_float_tensors"] += 1 + q, s = quantize_float_tensor(t) + if s.ndim > 0: + qmeta[name] = {"scheme": "per_row", "axis": 0} + quantized[name] = q + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) + + obj: dict[str, object] = { + "__quant_format__": "int8_clean_per_row_v1", + "quantized": quantized, + "scales": scales, + "dtypes": dtypes, + "passthrough": passthrough, + } + if qmeta: + obj["qmeta"] = qmeta + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats + + +def quantize_state_dict_int5_odd(state_dict: dict[str, Tensor]): + packed: dict[str, bytes] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + shapes: dict[str, tuple[int, ...]] = {} + orig_shapes: dict[str, 
tuple[int, ...]] = {} + n_codes: dict[str, int] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + + quantize_this = t.ndim >= 2 and (t.numel() > INT8_KEEP_FLOAT_MAX_NUMEL or ".mlp." in name) + if not quantize_this: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + + stats["num_float_tensors"] += 1 + t2 = t.reshape(t.shape[0], -1) if t.ndim > 2 else t + q, s = quantize_float_tensor_int5_odd(t2) + packed_bytes, count = pack_base5_codes(q) + packed[name] = packed_bytes + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + shapes[name] = tuple(int(v) for v in t2.shape) + orig_shapes[name] = tuple(int(v) for v in t.shape) + n_codes[name] = count + stats["int8_payload_bytes"] += len(packed_bytes) + tensor_nbytes(s) + + obj: dict[str, object] = { + "__quant_format__": "int5_odd_clean_per_row_v1", + "packed": packed, + "scales": scales, + "dtypes": dtypes, + "shapes": shapes, + "orig_shapes": orig_shapes, + "n_codes": n_codes, + "passthrough": passthrough, + } + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats + +def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + qmeta = obj.get("qmeta", {}) + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, q in 
obj["quantized"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + s = obj["scales"][name] + if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: + s = s.to(dtype=torch.float32) + # Broadcast the saved row scale back across trailing dimensions. + out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() + else: + scale = float(s.item()) + out[name] = (q.float() * scale).to(dtype=dtype).contiguous() + for name, t in obj["passthrough"].items(): + # Restore small tensors, undoing the temporary fp16 storage cast if needed. + out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out + + +def dequantize_state_dict_int5_odd(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, packed in obj["packed"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + shape = tuple(int(v) for v in obj["shapes"][name]) + orig_shape = tuple(int(v) for v in obj.get("orig_shapes", {}).get(name, shape)) + q = unpack_base5_codes(packed, int(obj["n_codes"][name])).to(torch.float32).reshape(shape) - 2.0 + s = obj["scales"][name].to(dtype=torch.float32) + deq = (q * s.view(shape[0], *([1] * (len(shape) - 1)))).to(dtype=dtype) + out[name] = deq.reshape(orig_shape).contiguous() + for name, t in obj["passthrough"].items(): + out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out + + +def quantize_state_dict(state_dict: dict[str, Tensor], quant_scheme: str): + if quant_scheme == "int8": + return quantize_state_dict_int8(state_dict) + if quant_scheme == "int5_odd": + return 
quantize_state_dict_int5_odd(state_dict) + raise ValueError(f"Unsupported QUANT_SCHEME={quant_scheme}") + + +def dequantize_state_dict(obj: dict[str, object]) -> dict[str, Tensor]: + quant_format = obj.get("__quant_format__") + if quant_format == "int8_clean_per_row_v1": + return dequantize_state_dict_int8(obj) + if quant_format == "int5_odd_clean_per_row_v1": + return dequantize_state_dict_int5_odd(obj) + raise ValueError(f"Unsupported quant format {quant_format}") + + +def compress_blob(data: bytes, compressor: str, level: int) -> bytes: + if compressor == "zlib": + return zlib.compress(data, level=max(0, min(level, 9))) + if compressor == "zstd": + if not _HAS_ZSTD: + raise RuntimeError("COMPRESSOR=zstd requires zstandard to be installed") + return zstd.ZstdCompressor(level=level).compress(data) + raise ValueError(f"Unsupported COMPRESSOR={compressor}") + + +def decompress_blob(data: bytes, compressor: str) -> bytes: + if compressor == "zlib": + return zlib.decompress(data) + if compressor == "zstd": + if not _HAS_ZSTD: + raise RuntimeError("COMPRESSOR=zstd requires zstandard to be installed") + return zstd.ZstdDecompressor().decompress(data) + raise ValueError(f"Unsupported COMPRESSOR={compressor}") + + +def should_rate_regularize_tensor(name: str, p: Tensor, target: str) -> bool: + if not p.requires_grad or not p.is_floating_point() or p.ndim < 2: + return False + if any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS): + return False + if target == "mlp_only": + return ".mlp." in name + if target == "all_matrices": + return p.numel() > INT8_KEEP_FLOAT_MAX_NUMEL or ".mlp." 
in name + raise ValueError(f"Unsupported RATE_TARGET_TENSORS={target}") + + +def estimate_tensor_rate_proxy(t: Tensor, temperature: float) -> tuple[Tensor, Tensor, Tensor]: + t2 = t.float() + if t2.ndim > 2: + t2 = t2.reshape(t2.shape[0], -1) + elif t2.ndim == 1: + t2 = t2.reshape(1, -1) + row_scale = t2.abs().mean(dim=1, keepdim=True).clamp_min(1e-6) + normalized = torch.clamp(t2 / row_scale, -2.5, 2.5) + centers = t2.new_tensor(INT5_ODD_LEVELS) + logits = -temperature * (normalized.unsqueeze(-1) - centers) ** 2 + probs = torch.softmax(logits, dim=-1) + hist = probs.mean(dim=(0, 1)) + entropy_nats = -(hist * (hist + 1e-8).log()).sum() + entropy_bits = entropy_nats / math.log(2.0) + scale_proxy = row_scale.mean() + est_bits = entropy_bits * float(t2.numel()) + return entropy_bits, scale_proxy, est_bits + + +# ----------------------------- +# DATA LOADING +# ----------------------------- + +def load_data_shard(file: Path) -> Tensor: + header_bytes = 256 * np.dtype(" None: + self.file_idx = (self.file_idx + 1) % len(self.files) + self.tokens = load_data_shard(self.files[self.file_idx]) + self.pos = 0 + + def take(self, n: int) -> Tensor: + chunks: list[Tensor] = [] + remaining = n + while remaining > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: + self._advance_file() + continue + k = min(remaining, avail) + chunks.append(self.tokens[self.pos : self.pos + k]) + self.pos += k + remaining -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) + + +class DistributedTokenLoader: + # Each call consumes a contiguous chunk from the shared token stream, then slices out + # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting. 
+ def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank = rank + self.world_size = world_size + self.device = device + self.stream = TokenStream(pattern) + + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + chunk = self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start : start + per_rank_span].to(dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + + +def debug_log_data_sharding( + pattern: str, + world_size: int, + global_tokens: int, + seq_len: int, + grad_accum_steps: int, + steps: int, + log0, +) -> None: + stream = TokenStream(pattern) + local_tokens = global_tokens // (world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + global_cursor = 0 + log0( + f"debug_sharding:world_size:{world_size} global_tokens:{global_tokens} " + f"grad_accum_steps:{grad_accum_steps} local_tokens_per_rank:{local_tokens} per_rank_span:{per_rank_span}" + ) + for step_idx in range(steps): + chunk = stream.take(per_rank_span * world_size) + chunk_start = global_cursor + chunk_end = global_cursor + chunk.numel() + log0(f"debug_sharding_step:{step_idx} shared_chunk_token_range:[{chunk_start},{chunk_end})") + for shard_rank in range(world_size): + start = shard_rank * per_rank_span + end = start + per_rank_span + local = chunk[start:end] + preview = ",".join(str(int(t)) for t in local[: min(6, local.numel())].tolist()) + log0( + f"debug_shard rank:{shard_rank} token_range:[{chunk_start + start},{chunk_start + end}) " + f"preview:{preview}" + ) + global_cursor = chunk_end + +# ----------------------------- +# TRANSFORMER MODULES +# ----------------------------- + +class RMSNorm(nn.Module): + def 
__init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + + +class CastedLinear(nn.Linear): + # Keep weights in fp32 for optimizer/state quality, cast at matmul time for bf16 compute. + def __init__(self, in_features: int, out_features: int, bias: bool = True): + super().__init__(in_features, out_features, bias=bias) + self._cached_weight: Tensor | None = None + self._cached_bias: Tensor | None = None + self._cached_key: tuple[torch.dtype, torch.device, int, int | None] | None = None + + def _cast_params(self, x: Tensor) -> tuple[Tensor, Tensor | None]: + bias_version = self.bias._version if self.bias is not None else None + cache_key = (x.dtype, x.device, self.weight._version, bias_version) + if torch.is_inference_mode_enabled() and self._cached_key == cache_key: + return self._cached_weight, self._cached_bias + weight = self.weight.to(dtype=x.dtype) + bias = self.bias.to(x.dtype) if self.bias is not None else None + if torch.is_inference_mode_enabled(): + self._cached_weight = weight + self._cached_bias = bias + self._cached_key = cache_key + return weight, bias + + def forward(self, x: Tensor) -> Tensor: + weight, bias = self._cast_params(x) + return F.linear(x, weight, bias) + + +class StructuredLinear(nn.Module): + # Lightweight TT-style structured matrix used only in the MLP stack. 
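+ # core1 has shape (n1, m1, rank) and core2 has shape (rank, n2, m2), with + # n1 * n2 == in_features and m1 * m2 == out_features, so storage is + # rank * (n1 * m1 + n2 * m2) values instead of in_features * out_features. + # E.g. a 1024x4096 layer factored as (32, 32) -> (64, 64) at rank 4 stores + # 4 * (32 * 64 + 32 * 64) = 16384 parameters versus 4194304 dense.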
+ def __init__( + self, + in_features: int, + out_features: int, + bias: bool, + linear_type: str, + btt_rank: int, + in_shape_spec: str, + out_shape_spec: str, + btt_init: str, + btt_init_gain: float, + chunk_ranks: int, + eval_chunk_ranks: int, + ): + super().__init__() + self.in_features = in_features + self.out_features = out_features + self.linear_type = linear_type + self.rank = max(int(btt_rank), 1) + self.chunk_ranks = max(int(chunk_ranks), 1) + self.eval_chunk_ranks = max(int(eval_chunk_ranks), 1) + self.btt_init = btt_init + self.btt_init_gain = btt_init_gain + self._cached_core1: Tensor | None = None + self._cached_core2: Tensor | None = None + self._cached_bias: Tensor | None = None + self._cached_key: tuple[torch.dtype, torch.device, int, int, int | None] | None = None + if linear_type == "dense": + self.impl = CastedLinear(in_features, out_features, bias=bias) + self.register_parameter("core1", None) + self.register_parameter("core2", None) + return + if linear_type != "btt_mlp": + raise ValueError(f"Unsupported STRUCTURED_LINEAR_TYPE={linear_type}") + self.impl = None + self.in_shape = parse_factor_shape(in_shape_spec, in_features) + self.out_shape = parse_factor_shape(out_shape_spec, out_features) + n1, n2 = self.in_shape + m1, m2 = self.out_shape + self.core1 = nn.Parameter(torch.empty(n1, m1, self.rank)) + self.core2 = nn.Parameter(torch.empty(self.rank, n2, m2)) + if bias: + self.bias = nn.Parameter(torch.zeros(out_features)) + else: + self.register_parameter("bias", None) + self.reset_parameters() + + def reset_parameters(self) -> None: + if self.impl is not None: + self.impl.reset_parameters() + return + n1, n2 = self.in_shape + if self.btt_init == "mup": + # muP-inspired scaling: keep the per-rank branch update stable while + # preserving O(1) output variance under the sum over TT ranks. 
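+ # Per rank branch, (x @ core1_r) contracts over n1 (variance factor n1 * core1_std^2) + # and the second matmul contracts over n2 and sums over the rank branches (factor + # rank * n2 * core2_std^2); the 1/sqrt(n1) and 1/sqrt(rank * n2) stds cancel both.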
+ core1_std = self.btt_init_gain / math.sqrt(max(n1, 1)) + core2_std = self.btt_init_gain / math.sqrt(max(self.rank * n2, 1)) + elif self.btt_init == "xavier": + std = self.btt_init_gain / math.sqrt(max(self.in_features, 1)) + core1_std = std / math.sqrt(self.rank) + core2_std = std + else: + raise ValueError(f"Unsupported BTT_INIT={self.btt_init}") + nn.init.normal_(self.core1, mean=0.0, std=core1_std) + nn.init.normal_(self.core2, mean=0.0, std=core2_std) + if self.bias is not None: + nn.init.zeros_(self.bias) + + def _cast_factors(self, x: Tensor) -> tuple[Tensor, Tensor, Tensor | None]: + bias_version = self.bias._version if self.bias is not None else None + cache_key = (x.dtype, x.device, self.core1._version, self.core2._version, bias_version) + if torch.is_inference_mode_enabled() and self._cached_key == cache_key: + return self._cached_core1, self._cached_core2, self._cached_bias + core1 = self.core1.to(dtype=x.dtype).permute(2, 0, 1).contiguous() + core2 = self.core2.to(dtype=x.dtype).contiguous() + bias = self.bias.to(dtype=x.dtype) if self.bias is not None else None + if torch.is_inference_mode_enabled(): + self._cached_core1 = core1 + self._cached_core2 = core2 + self._cached_bias = bias + self._cached_key = cache_key + return core1, core2, bias + + def forward(self, x: Tensor) -> Tensor: + if self.impl is not None: + return self.impl(x) + orig_shape = x.shape[:-1] + x2 = x.reshape(-1, self.in_shape[0], self.in_shape[1]) + x2_t = x2.transpose(1, 2).contiguous() + core1, core2, bias = self._cast_factors(x) + y = x2.new_zeros((x2.shape[0], self.out_shape[0], self.out_shape[1])) + if torch.is_inference_mode_enabled() and self.eval_chunk_ranks > 1: + for start in range(0, self.rank, self.eval_chunk_ranks): + end = min(start + self.eval_chunk_ranks, self.rank) + left = torch.matmul(x2_t.unsqueeze(1), core1[start:end].unsqueeze(0)) + y = y + torch.matmul(left.transpose(-1, -2), core2[start:end].unsqueeze(0)).sum(dim=1) + y = y.reshape(*orig_shape, 
self.out_features) + if bias is not None: + y = y + bias + return y + for start in range(0, self.rank, self.chunk_ranks): + end = min(start + self.chunk_ranks, self.rank) + for offset, core2_r in enumerate(core2[start:end], start=start): + left = torch.matmul(x2_t, core1[offset]) + y = y + torch.matmul(left.transpose(1, 2), core2_r) + y = y.reshape(*orig_shape, self.out_features) + if bias is not None: + y = y + bias + return y + + +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + # Keep small/control parameters in fp32 even when the model body runs in bf16. + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() + + +class Rotary(nn.Module): + # Caches cos/sin tables per sequence length on the current device. + def __init__(self, dim: int, base: float = 10000.0): + super().__init__() + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached != seq_len + or self._cos_cached.device != device + ): + t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) + freqs = torch.outer(t, self.inv_freq.to(device)) + self._cos_cached = freqs.cos()[None, None, :, :] + self._sin_cached = freqs.sin()[None, None, :, :] + self._seq_len_cached = seq_len + cos = self._cos_cached.to(dtype=dtype) + sin = self._sin_cached.to(dtype=dtype) + if not torch.is_inference_mode_enabled(): + cos = cos.clone() + sin = sin.clone() + return cos, sin + + +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> 
Tensor: + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + + +class CausalSelfAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + kv_dim = self.num_kv_heads * self.head_dim + self.c_q = CastedLinear(dim, dim, bias=False) + self.c_k = CastedLinear(dim, kv_dim, bias=False) + self.c_v = CastedLinear(dim, kv_dim, bias=False) + self.proj = CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rotary = Rotary(self.head_dim, base=rope_base) + + def forward(self, x: Tensor) -> Tensor: + bsz, seqlen, dim = x.shape + q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) + k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin) + k = apply_rotary_emb(k, cos, sin) + q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] + y = F.scaled_dot_product_attention( + q, + k, + v, + attn_mask=None, + is_causal=True, + enable_gqa=(self.num_kv_heads != self.num_heads), + ) + y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) + return self.proj(y) + + +class MLP(nn.Module): + # relu^2 MLP from the 
original modded-nanogpt setup + def __init__( + self, + dim: int, + mlp_mult: int, + structured_mlp: bool, + structured_linear_type: str, + btt_rank: int, + btt_in_shape: str, + btt_out_shape: str, + btt_init: str, + btt_init_gain: float, + compile_structured_mlp: bool, + structured_chunk_ranks: int, + structured_eval_chunk_ranks: int, + structured_compile_mode: str, + ): + super().__init__() + hidden = mlp_mult * dim + if structured_mlp: + self.fc = StructuredLinear( + dim, + hidden, + bias=False, + linear_type=structured_linear_type, + btt_rank=btt_rank, + in_shape_spec=btt_in_shape, + out_shape_spec=btt_out_shape, + btt_init=btt_init, + btt_init_gain=btt_init_gain, + chunk_ranks=structured_chunk_ranks, + eval_chunk_ranks=structured_eval_chunk_ranks, + ) + self.proj = StructuredLinear( + hidden, + dim, + bias=False, + linear_type=structured_linear_type, + btt_rank=btt_rank, + in_shape_spec=btt_out_shape, + out_shape_spec=btt_in_shape, + btt_init=btt_init, + btt_init_gain=btt_init_gain, + chunk_ranks=structured_chunk_ranks, + eval_chunk_ranks=structured_eval_chunk_ranks, + ) + if compile_structured_mlp: + compile_kwargs = dict(fullgraph=False, dynamic=False) + if structured_compile_mode == "reduce-overhead": + compile_options = dict(torch._inductor.list_mode_options("reduce-overhead")) + compile_options["triton.cudagraphs"] = False + compile_kwargs["options"] = compile_options + else: + compile_kwargs["mode"] = structured_compile_mode + self.fc = torch.compile( + self.fc, + **compile_kwargs, + ) + self.proj = torch.compile( + self.proj, + **compile_kwargs, + ) + else: + self.fc = CastedLinear(dim, hidden, bias=False) + self.proj = CastedLinear(hidden, dim, bias=False) + self.proj._zero_init = True + + def forward(self, x: Tensor) -> Tensor: + x = torch.relu(self.fc(x)) + return self.proj(x.square()) + + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + rope_base: float, + qk_gain_init: float, + 
structured_mlp: bool, + structured_linear_type: str, + btt_rank: int, + btt_in_shape: str, + btt_out_shape: str, + btt_init: str, + btt_init_gain: float, + compile_structured_mlp: bool, + structured_chunk_ranks: int, + structured_eval_chunk_ranks: int, + structured_compile_mode: str, + ): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) + self.mlp = MLP( + dim, + mlp_mult, + structured_mlp, + structured_linear_type, + btt_rank, + btt_in_shape, + btt_out_shape, + btt_init, + btt_init_gain, + compile_structured_mlp, + structured_chunk_ranks, + structured_eval_chunk_ranks, + structured_compile_mode, + ) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + + def forward(self, x: Tensor, x0: Tensor) -> Tensor: + mix = self.resid_mix.to(dtype=x.dtype) + x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + attn_out = self.attn(self.attn_norm(x)) + x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out + x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) + return x + + +class GPT(nn.Module): + def __init__( + self, + vocab_size: int, + num_layers: int, + model_dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + tie_embeddings: bool, + tied_embed_init_std: float, + logit_softcap: float, + rope_base: float, + qk_gain_init: float, + structured_mlp: bool, + structured_linear_type: str, + btt_rank: int, + btt_in_shape: str, + btt_out_shape: str, + btt_init: str, + btt_init_gain: float, + compile_structured_mlp: bool, + structured_chunk_ranks: int, + structured_eval_chunk_ranks: int, + structured_compile_mode: str, + rate_target_tensors: str, + rate_temperature: float, + ): + super().__init__() + if logit_softcap <= 
0.0: + raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") + self.tie_embeddings = tie_embeddings + self.tied_embed_init_std = tied_embed_init_std + self.logit_softcap = logit_softcap + self.rate_target_tensors = rate_target_tensors + self.rate_temperature = rate_temperature + self.tok_emb = nn.Embedding(vocab_size, model_dim) + self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + self.blocks = nn.ModuleList( + [ + Block( + model_dim, + num_heads, + num_kv_heads, + mlp_mult, + rope_base, + qk_gain_init, + structured_mlp, + structured_linear_type, + btt_rank, + btt_in_shape, + btt_out_shape, + btt_init, + btt_init_gain, + compile_structured_mlp, + structured_chunk_ranks, + structured_eval_chunk_ranks, + structured_compile_mode, + ) + for i in range(num_layers) + ] + ) + self.final_norm = RMSNorm() + self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) + if self.lm_head is not None: + self.lm_head._zero_init = True + self._init_weights() + + def _init_weights(self) -> None: + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + for module in self.modules(): + if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + if isinstance(module, StructuredLinear) and getattr(module, "_zero_init", False): + if module.impl is not None: + nn.init.zeros_(module.impl.weight) + else: + nn.init.zeros_(module.core2) + + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + x = self.tok_emb(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x0 = x + skips: list[Tensor] = [] + + # First half stores skips; second half reuses them in reverse order. 
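+ # E.g. with 16 layers: decoder block 8 adds skip_weights[0] * (block 7's output), + # block 9 adds skip_weights[1] * (block 6's output), and so on.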
+ for i in range(self.num_encoder_layers): + x = self.blocks[i](x, x0) + skips.append(x) + for i in range(self.num_decoder_layers): + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + x = self.blocks[self.num_encoder_layers + i](x, x0) + + x = self.final_norm(x).reshape(-1, x.size(-1)) + targets = target_ids.reshape(-1) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(x) + logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + return F.cross_entropy(logits.float(), targets, reduction="mean") + + def rate_regularizer(self) -> tuple[Tensor, Tensor, Tensor]: + total_symbols = 0.0 + weighted_rate = None + weighted_scale = None + estimated_bits = None + for name, param in self.named_parameters(): + if not should_rate_regularize_tensor(name, param, self.rate_target_tensors): + continue + rate_proxy, scale_proxy, tensor_bits = estimate_tensor_rate_proxy(param, self.rate_temperature) + weight = float(param.numel()) + total_symbols += weight + if weighted_rate is None: + weighted_rate = rate_proxy * weight + weighted_scale = scale_proxy * weight + estimated_bits = tensor_bits + else: + weighted_rate = weighted_rate + rate_proxy * weight + weighted_scale = weighted_scale + scale_proxy * weight + estimated_bits = estimated_bits + tensor_bits + if weighted_rate is None or weighted_scale is None or estimated_bits is None: + zero = self.tok_emb.weight.new_zeros(()) + return zero, zero, zero + denom = max(total_symbols, 1.0) + return weighted_rate / denom, weighted_scale / denom, estimated_bits + + +# ----------------------------- +# TRAINING +# ----------------------------- + +def main() -> None: + global zeropower_via_newtonschulz5 + + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + if args.enable_compile: + 
zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) + + # ----------------------------- + # DISTRIBUTED + CUDA SETUP + # ----------------------------- + + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + simulated_world_size = args.simulated_world_size if not distributed and args.simulated_world_size > 0 else 0 + if simulated_world_size > 0: + world_size = simulated_world_size + rank = args.simulated_rank + local_rank = 0 + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if rank < 0 or rank >= world_size: + raise ValueError(f"RANK={rank} must satisfy 0 <= RANK < WORLD_SIZE={world_size}") + if simulated_world_size > 1 and not args.dry_run_init_only: + raise ValueError("SIMULATED_WORLD_SIZE>1 is only supported together with DRY_RUN_INIT_ONLY=1") + if args.train_batch_tokens <= 0: + raise ValueError(f"TRAIN_BATCH_TOKENS must be positive, got {args.train_batch_tokens}") + train_microbatch_tokens = args.train_microbatch_tokens if args.train_microbatch_tokens > 0 else max( + args.train_seq_len, args.train_batch_tokens // 8 + ) + if train_microbatch_tokens <= 0: + raise ValueError(f"TRAIN_MICROBATCH_TOKENS must be positive, got {train_microbatch_tokens}") + if train_microbatch_tokens % args.train_seq_len != 0: + raise ValueError( + f"TRAIN_MICROBATCH_TOKENS={train_microbatch_tokens} must be divisible by TRAIN_SEQ_LEN={args.train_seq_len}" + ) + tokens_per_microstep = world_size * train_microbatch_tokens + if args.grad_accum_steps > 0: + grad_accum_steps = args.grad_accum_steps + effective_train_batch_tokens = tokens_per_microstep * grad_accum_steps + else: + if args.train_batch_tokens % tokens_per_microstep != 0: + raise ValueError( + "TRAIN_BATCH_TOKENS must be divisible by WORLD_SIZE * TRAIN_MICROBATCH_TOKENS " + f"unless GRAD_ACCUM_STEPS is 
set explicitly; got TRAIN_BATCH_TOKENS={args.train_batch_tokens}, " + f"WORLD_SIZE={world_size}, TRAIN_MICROBATCH_TOKENS={train_microbatch_tokens}" + ) + grad_accum_steps = args.train_batch_tokens // tokens_per_microstep + effective_train_batch_tokens = args.train_batch_tokens + if grad_accum_steps <= 0: + raise ValueError(f"GRAD_ACCUM_STEPS must be positive, got {grad_accum_steps}") + grad_scale = 1.0 / grad_accum_steps + batch_multiplier = effective_train_batch_tokens / max(args.lr_reference_batch_tokens, 1) + effective_warmup_steps = scaled_warmup_steps( + args.warmup_steps, + effective_train_batch_tokens / max(args.warmup_reference_batch_tokens, 1), + args.warmup_scale_method, + ) + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 0 + + # Fast math knobs + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + + logfile = None + if master_process: + if args.run_id == "train": + logfile = "train.log" + else: + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" + print(logfile) + + def log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch {torch.__version__}", console=False) + log0( + subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, 
check=False).stdout, + console=False, + ) + log0("=" * 100, console=False) + + # ----------------------------- + # TOKENIZER + VALIDATION METRIC SETUP + # ----------------------------- + + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) + if args.val_token_limit > 0: + usable = min(((args.val_token_limit) // args.train_seq_len) * args.train_seq_len, val_tokens.numel() - 1) + if usable <= 0: + raise ValueError(f"VAL_TOKEN_LIMIT={args.val_token_limit} is too small for TRAIN_SEQ_LEN={args.train_seq_len}") + val_tokens = val_tokens[: usable + 1].contiguous() + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + if args.val_token_limit > 0: + log0(f"val_token_limit:{args.val_token_limit}") + + # ----------------------------- + # MODEL + OPTIMIZER SETUP + # ----------------------------- + + base_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + 
tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, + structured_mlp=args.structured_mlp, + structured_linear_type=args.structured_linear_type if args.structured_mlp else "dense", + btt_rank=args.btt_rank, + btt_in_shape=args.btt_in_shape, + btt_out_shape=args.btt_out_shape, + btt_init=args.btt_init, + btt_init_gain=args.btt_init_gain, + compile_structured_mlp=args.compile_structured_mlp, + structured_chunk_ranks=args.structured_chunk_ranks, + structured_eval_chunk_ranks=args.structured_eval_chunk_ranks, + structured_compile_mode=args.structured_compile_mode, + rate_target_tensors=args.rate_target_tensors, + rate_temperature=args.rate_temperature, + ).to(device).bfloat16() + for module in base_model.modules(): + if isinstance(module, (CastedLinear, StructuredLinear)): + module.float() + restore_low_dim_params_to_fp32(base_model) + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) if args.enable_compile else base_model + model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + + # Optimizer split: + # - token embedding (Adam) uses EMBED_LR + # - untied lm_head (Adam) uses HEAD_LR + # - matrix params in transformer blocks use MATRIX_LR via Muon + # - vectors/scalars use SCALAR_LR via Adam + block_named_params = list(base_model.blocks.named_parameters()) + matrix_params = [ + p + for name, p in block_named_params + if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + scalar_params = [ + p + for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + scaled_token_lr = 
scale_from_batch_multiplier(token_lr, batch_multiplier, args.lr_scale_method) + scaled_head_lr = scale_from_batch_multiplier(args.head_lr, batch_multiplier, args.lr_scale_method) + scaled_matrix_lr = scale_from_batch_multiplier(args.matrix_lr, batch_multiplier, args.lr_scale_method) + scaled_scalar_lr = scale_from_batch_multiplier(args.scalar_lr, batch_multiplier, args.lr_scale_method) + optimizer_tok = torch.optim.Adam( + [{"params": [base_model.tok_emb.weight], "lr": scaled_token_lr, "base_lr": scaled_token_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizer_muon = Muon( + matrix_params, + lr=scaled_matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + ) + for group in optimizer_muon.param_groups: + group["base_lr"] = scaled_matrix_lr + optimizer_scalar = torch.optim.Adam( + [{"params": scalar_params, "lr": scaled_scalar_lr, "base_lr": scaled_scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": scaled_head_lr, "base_lr": scaled_head_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers.insert(1, optimizer_head) + + n_params = sum(p.numel() for p in base_model.parameters()) + log0(f"model_params:{n_params}") + log0( + f"world_size:{world_size} grad_accum_steps:{grad_accum_steps} " + f"train_microbatch_tokens:{train_microbatch_tokens} effective_train_batch_tokens:{effective_train_batch_tokens}" + ) + if simulated_world_size > 0: + log0(f"simulated_world_size:{simulated_world_size} simulated_rank:{rank}") + log0(f"enable_compile:{args.enable_compile}") + log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") + log0(f"attention_mode:gqa num_heads:{args.num_heads} 
num_kv_heads:{args.num_kv_heads}") + log0( + f"quant_scheme:{args.quant_scheme} compressor:{args.compressor} level:{args.compressor_level} " + f"structured_mlp:{args.structured_mlp} structured_linear_type:{args.structured_linear_type} " + f"btt_rank:{args.btt_rank} btt_init:{args.btt_init} compile_structured_mlp:{args.compile_structured_mlp} " + f"structured_chunk_ranks:{args.structured_chunk_ranks} structured_eval_chunk_ranks:{args.structured_eval_chunk_ranks} " + f"structured_compile_mode:{args.structured_compile_mode}" + ) + log0( + f"rate_lambda:{args.rate_lambda} scale_lambda:{args.scale_lambda} " + f"rate_warmup_start_fraction:{args.rate_warmup_start_fraction} rate_full_fraction:{args.rate_full_fraction} " + f"rate_target_tensors:{args.rate_target_tensors} rate_schedule_mode:{args.rate_schedule_mode} " + f"max_expected_steps:{args.max_expected_steps}" + ) + log0( + f"tie_embeddings:{args.tie_embeddings} embed_lr:{scaled_token_lr} " + f"head_lr:{scaled_head_lr if base_model.lm_head is not None else 0.0} " + f"matrix_lr:{scaled_matrix_lr} scalar_lr:{scaled_scalar_lr}" + ) + log0( + f"train_batch_tokens_target:{args.train_batch_tokens} effective_train_batch_tokens:{effective_train_batch_tokens} " + f"train_seq_len:{args.train_seq_len} iterations:{args.iterations} warmup_steps:{effective_warmup_steps} " + f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log0( + f"lr_scale_method:{args.lr_scale_method} lr_reference_batch_tokens:{args.lr_reference_batch_tokens} " + f"batch_multiplier:{batch_multiplier:.4f} warmup_scale_method:{args.warmup_scale_method} " + f"warmup_reference_batch_tokens:{args.warmup_reference_batch_tokens}" + ) + log0(f"seed:{args.seed}") + + if args.debug_data_sharding_steps > 0: + debug_log_data_sharding( + args.train_files, + world_size, + effective_train_batch_tokens, + args.train_seq_len, + grad_accum_steps, + args.debug_data_sharding_steps, + log0, + ) + if args.dry_run_init_only: + log0("dry_run_init_only: exiting after 
init/logging") + return + + # ----------------------------- + # DATA LOADER & MODEL WARMUP + # ----------------------------- + + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + + def lr_mul(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + + def rate_progress(step: int, elapsed_ms: float) -> float: + if args.rate_schedule_mode == "steps": + if args.iterations <= 0: + return 1.0 + return step / max(args.iterations, 1) + if args.rate_schedule_mode == "expected_steps": + if args.max_expected_steps <= 0: + raise ValueError("MAX_EXPECTED_STEPS must be positive when RATE_SCHEDULE_MODE=expected_steps") + return step / max(args.max_expected_steps, 1) + if args.rate_schedule_mode == "wallclock": + if max_wallclock_ms is None: + raise ValueError("RATE_SCHEDULE_MODE=wallclock requires MAX_WALLCLOCK_SECONDS > 0") + return elapsed_ms / max(max_wallclock_ms, 1e-9) + if args.rate_schedule_mode != "auto": + raise ValueError(f"Unsupported RATE_SCHEDULE_MODE={args.rate_schedule_mode}") + if args.max_expected_steps > 0: + return step / max(args.max_expected_steps, 1) + if max_wallclock_ms is not None and step > 0: + step_ms = elapsed_ms / max(step, 1) + expected_steps = max(int(max_wallclock_ms / max(step_ms, 1e-9)), 1) + return step / expected_steps + if args.iterations <= 0: + 
return 1.0 + return step / max(args.iterations, 1) + + def rate_schedule(step: int, elapsed_ms: float) -> float: + if args.rate_lambda <= 0.0 and args.scale_lambda <= 0.0: + return 0.0 + frac = rate_progress(step, elapsed_ms) + start = args.rate_warmup_start_fraction + full = max(args.rate_full_fraction, start + 1e-6) + if frac <= start: + return 0.0 + if frac >= full: + return 1.0 + return (frac - start) / (full - start) + + # Warmup primes the compiled forward/backward/optimizer paths, then we restore the + # initial weights/optimizer state so measured training starts from the true init. + if effective_warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(effective_warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(effective_train_batch_tokens, args.train_seq_len, grad_accum_steps) + maybe_mark_compile_step_begin(args.enable_compile or args.compile_structured_mlp) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + for opt in optimizers: + opt.step() + zero_grad_all() + if ( + effective_warmup_steps <= 20 + or (warmup_step + 1) % 10 == 0 + or warmup_step + 1 == effective_warmup_steps + ): + log0(f"warmup_step:{warmup_step + 1}/{effective_warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + if distributed: + model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + # ----------------------------- + # 
MAIN TRAINING LOOP + # ----------------------------- + + training_time_ms = 0.0 + stop_after_step: int | None = None + torch.cuda.synchronize() + t0 = time.perf_counter() + + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + log0( + f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" + ) + break + + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + zero_grad_all() + train_loss = torch.zeros((), device=device) + train_ce_loss = torch.zeros((), device=device) + train_rate_proxy = torch.zeros((), device=device) + train_scale_proxy = torch.zeros((), device=device) + train_est_bits = torch.zeros((), device=device) + rate_mul = rate_schedule(step, elapsed_ms) + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(effective_train_batch_tokens, args.train_seq_len, grad_accum_steps) + maybe_mark_compile_step_begin(args.enable_compile or args.compile_structured_mlp) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + ce_loss = model(x, y) + total_loss = ce_loss + 
if rate_mul > 0.0: + rate_proxy, scale_proxy, est_bits = base_model.rate_regularizer() + total_loss = total_loss + rate_mul * (args.rate_lambda * rate_proxy + args.scale_lambda * scale_proxy) + train_rate_proxy += rate_proxy.detach() + train_scale_proxy += scale_proxy.detach() + train_est_bits += est_bits.detach() + train_ce_loss += ce_loss.detach() + train_loss += total_loss.detach() + (total_loss * grad_scale).backward() + train_loss /= grad_accum_steps + train_ce_loss /= grad_accum_steps + train_rate_proxy /= grad_accum_steps + train_scale_proxy /= grad_accum_steps + train_est_bits /= grad_accum_steps + + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + for opt in optimizers: + opt.step() + zero_grad_all() + + step += 1 + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + should_log_train = ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + log0( + f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} ce_loss:{train_ce_loss.item():.4f} " + f"rate_mul:{rate_mul:.3f} rate_proxy_bits:{train_rate_proxy.item():.4f} " + f"scale_proxy:{train_scale_proxy.item():.6f} est_model_bits:{train_est_bits.item():.0f} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" + ) + + # Needed to sync whether we've reached the wallclock cap. 
+ reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + + log0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + ) + + # ----------------------------- + # SERIALIZATION + ROUNDTRIP VALIDATION + # ----------------------------- + # Save the raw state (useful for debugging/loading in PyTorch directly), then always produce + # the compressed quantized artifact (QUANT_SCHEME + COMPRESSOR, e.g. int5_odd+zlib) and validate the round-tripped weights. + + if master_process: + torch.save(base_model.state_dict(), "final_model.pt") + model_bytes = os.path.getsize("final_model.pt") + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model: {model_bytes} bytes") + log0(f"Code size: {code_bytes} bytes") + log0(f"Total submission size: {model_bytes + code_bytes} bytes") + + quant_obj, quant_stats = quantize_state_dict(base_model.state_dict(), args.quant_scheme) + quant_buf = io.BytesIO() + torch.save(quant_obj, quant_buf) + quant_raw = quant_buf.getvalue() + quant_blob = compress_blob(quant_raw, args.compressor, args.compressor_level) + quant_raw_bytes = len(quant_raw) + quant_filename = f"final_model.{args.quant_scheme}.{args.compressor}.ptc" + if master_process: + with open(quant_filename, "wb") as f: + f.write(quant_blob) + quant_file_bytes = os.path.getsize(quant_filename) + code_bytes = len(code.encode("utf-8")) + ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1) + log0( + f"Serialized model {args.quant_scheme}+{args.compressor}: {quant_file_bytes} bytes " + f"(payload:{quant_stats['int8_payload_bytes']} raw_torch:{quant_raw_bytes}
payload_ratio:{ratio:.2f}x)" + ) + log0(f"Total submission size {args.quant_scheme}+{args.compressor}: {quant_file_bytes + code_bytes} bytes") + + if distributed: + dist.barrier() + with open(quant_filename, "rb") as f: + quant_blob_disk = f.read() + quant_state = torch.load(io.BytesIO(decompress_blob(quant_blob_disk, args.compressor)), map_location="cpu") + base_model.load_state_dict(dequantize_state_dict(quant_state), strict=True) + torch.cuda.synchronize() + t_qeval = time.perf_counter() + q_val_loss, q_val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"final_{args.quant_scheme}_{args.compressor}_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" + ) + log0(f"final_{args.quant_scheme}_{args.compressor}_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + + if distributed: + dist.destroy_process_group() + + +if __name__ == "__main__": + main() + +==================================================================================================== +Running Python 3.12.3 (main, Nov 6 2025, 13:44:16) [GCC 13.3.0] +Running PyTorch 2.9.1+cu128 +Tue Mar 31 18:27:11 2026 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 580.126.09 Driver Version: 580.126.09 CUDA Version: 13.0 | ++-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. 
| +|=========================================+========================+======================| +| 0 NVIDIA H100 80GB HBM3 On | 00000000:CB:00.0 Off | 0 | +| N/A 35C P0 92W / 700W | 1185MiB / 81559MiB | 2% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| No running processes found | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=/workspace/parameter-golf/data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:80 +val_loader:shards pattern=/workspace/parameter-golf/data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:1048576 +val_token_limit:1048576 +model_params:25727104 +world_size:1 grad_accum_steps:8 train_microbatch_tokens:2048 effective_train_batch_tokens:16384 +enable_compile:False +sdp_backends:cudnn=False flash=True mem_efficient=False math=False +attention_mode:gqa num_heads:8 num_kv_heads:4 +quant_scheme:int5_odd compressor:zlib level:9 structured_mlp:True structured_linear_type:btt_mlp btt_rank:256 btt_init:mup compile_structured_mlp:False structured_chunk_ranks:64 structured_eval_chunk_ranks:1 structured_compile_mode:reduce-overhead +rate_lambda:2e-05 scale_lambda:0.0002 rate_warmup_start_fraction:0.1 rate_full_fraction:0.4 rate_target_tensors:mlp_only rate_schedule_mode:auto max_expected_steps:0 +tie_embeddings:True embed_lr:0.03 head_lr:0.0 matrix_lr:0.02 scalar_lr:0.02 +train_batch_tokens_target:16384 effective_train_batch_tokens:16384 train_seq_len:1024 iterations:0 
warmup_steps:0 max_wallclock_seconds:0.000 +lr_scale_method:none lr_reference_batch_tokens:16384 batch_multiplier:1.0000 warmup_scale_method:none warmup_reference_batch_tokens:16384 +seed:1337 +step:0/0 val_loss:6.9379 val_bpb:4.1568 train_time:0ms step_avg:0.02ms +peak memory allocated: 2822 MiB reserved: 3110 MiB +Serialized model: 101939739 bytes +Code size: 81599 bytes +Total submission size: 102021338 bytes +Serialized model int5_odd+zlib: 7256271 bytes (payload:14157417 raw_torch:14260195 payload_ratio:11.71x) +Total submission size int5_odd+zlib: 7337870 bytes +final_int5_odd_zlib_roundtrip val_loss:6.9416 val_bpb:4.1590 eval_time:99311ms +final_int5_odd_zlib_roundtrip_exact val_loss:6.94156116 val_bpb:4.15901368 diff --git a/records/track_non_record_16mb/2026-03-31_EntropyAware_Int5Odd_BTTMLP_3060/logs/eval_bench_r256_l16_seq1024_cache_vb1048576.txt b/records/track_non_record_16mb/2026-03-31_EntropyAware_Int5Odd_BTTMLP_3060/logs/eval_bench_r256_l16_seq1024_cache_vb1048576.txt new file mode 100644 index 000000000..fe6bf6954 --- /dev/null +++ b/records/track_non_record_16mb/2026-03-31_EntropyAware_Int5Odd_BTTMLP_3060/logs/eval_bench_r256_l16_seq1024_cache_vb1048576.txt @@ -0,0 +1,1898 @@ +logs/eval_bench_r256_l16_seq1024_cache_vb1048576.txt +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=/workspace/parameter-golf/data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:80 +val_loader:shards pattern=/workspace/parameter-golf/data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:1048576 +val_token_limit:1048576 +model_params:25727104 +world_size:1 grad_accum_steps:8 train_microbatch_tokens:2048 effective_train_batch_tokens:16384 +enable_compile:False +sdp_backends:cudnn=False flash=True mem_efficient=False math=False +attention_mode:gqa num_heads:8 num_kv_heads:4 +quant_scheme:int5_odd compressor:zlib level:9 structured_mlp:True structured_linear_type:btt_mlp btt_rank:256 btt_init:mup 
compile_structured_mlp:False structured_chunk_ranks:64 structured_eval_chunk_ranks:1 structured_compile_mode:reduce-overhead +rate_lambda:2e-05 scale_lambda:0.0002 rate_warmup_start_fraction:0.1 rate_full_fraction:0.4 rate_target_tensors:mlp_only rate_schedule_mode:auto max_expected_steps:0 +tie_embeddings:True embed_lr:0.03 head_lr:0.0 matrix_lr:0.02 scalar_lr:0.02 +train_batch_tokens_target:16384 effective_train_batch_tokens:16384 train_seq_len:1024 iterations:0 warmup_steps:0 max_wallclock_seconds:0.000 +lr_scale_method:none lr_reference_batch_tokens:16384 batch_multiplier:1.0000 warmup_scale_method:none warmup_reference_batch_tokens:16384 +seed:1337 +step:0/0 val_loss:6.9379 val_bpb:4.1568 train_time:0ms step_avg:0.02ms +peak memory allocated: 41218 MiB reserved: 45380 MiB +Serialized model: 101939739 bytes +Code size: 81599 bytes +Total submission size: 102021338 bytes +Serialized model int5_odd+zlib: 7256271 bytes (payload:14157417 raw_torch:14260195 payload_ratio:11.71x) +Total submission size int5_odd+zlib: 7337870 bytes +final_int5_odd_zlib_roundtrip val_loss:6.9416 val_bpb:4.1590 eval_time:96988ms +final_int5_odd_zlib_roundtrip_exact val_loss:6.94156265 val_bpb:4.15901457 +environ.get("SEED", 1337)) + + # Validation cadence and batch size. Validation always uses the full fineweb_val split. + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 32_768)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 10)) + val_token_limit = int(os.environ.get("VAL_TOKEN_LIMIT", 1_048_576)) + + # Training length. 
+ iterations = int(os.environ.get("ITERATIONS", 40)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 1200)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 1)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 16_384)) + train_microbatch_tokens = int(os.environ.get("TRAIN_MICROBATCH_TOKENS", 0)) + grad_accum_steps = int(os.environ.get("GRAD_ACCUM_STEPS", 0)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 256)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + enable_compile = bool(int(os.environ.get("ENABLE_COMPILE", "0"))) + + # Model shape. + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers = int(os.environ.get("NUM_LAYERS", 9)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = int(os.environ.get("MLP_MULT", 2)) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + + # Optimizer hyperparameters. 
+ embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.03)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.02)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.02)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.0)) + lr_scale_method = os.environ.get("LR_SCALE_METHOD", "none") + lr_reference_batch_tokens = int(os.environ.get("LR_REFERENCE_BATCH_TOKENS", 16_384)) + warmup_scale_method = os.environ.get("WARMUP_SCALE_METHOD", "none") + warmup_reference_batch_tokens = int(os.environ.get("WARMUP_REFERENCE_BATCH_TOKENS", 16_384)) + simulated_world_size = int(os.environ.get("SIMULATED_WORLD_SIZE", 0)) + simulated_rank = int(os.environ.get("SIMULATED_RANK", 0)) + dry_run_init_only = bool(int(os.environ.get("DRY_RUN_INIT_ONLY", "0"))) + debug_data_sharding_steps = int(os.environ.get("DEBUG_DATA_SHARDING_STEPS", 0)) + + # Export / compression. + quant_scheme = os.environ.get("QUANT_SCHEME", "int5_odd") + compressor = os.environ.get("COMPRESSOR", "zlib") + compressor_level = int(os.environ.get("COMPRESSOR_LEVEL", 9)) + + # Entropy-aware rate regularization. 
+ rate_lambda = float(os.environ.get("RATE_LAMBDA", 0.00002)) + scale_lambda = float(os.environ.get("SCALE_LAMBDA", 0.0002)) + rate_temperature = float(os.environ.get("RATE_TEMPERATURE", 6.0)) + rate_warmup_start_fraction = float(os.environ.get("RATE_WARMUP_START_FRACTION", 0.1)) + rate_full_fraction = float(os.environ.get("RATE_FULL_FRACTION", 0.4)) + rate_target_tensors = os.environ.get("RATE_TARGET_TENSORS", "mlp_only") + rate_schedule_mode = os.environ.get("RATE_SCHEDULE_MODE", "auto") + max_expected_steps = int(os.environ.get("MAX_EXPECTED_STEPS", 0)) + + # Structured MLP. + structured_mlp = bool(int(os.environ.get("STRUCTURED_MLP", "1"))) + structured_linear_type = os.environ.get("STRUCTURED_LINEAR_TYPE", "btt_mlp") + btt_rank = int(os.environ.get("BTT_RANK", 32)) + btt_in_shape = os.environ.get("BTT_IN_SHAPE", "") + btt_out_shape = os.environ.get("BTT_OUT_SHAPE", "") + btt_init = os.environ.get("BTT_INIT", "mup") + btt_init_gain = float(os.environ.get("BTT_INIT_GAIN", 1.0)) + compile_structured_mlp = bool(int(os.environ.get("COMPILE_STRUCTURED_MLP", "1"))) + structured_chunk_ranks = int(os.environ.get("STRUCTURED_CHUNK_RANKS", 64)) + structured_eval_chunk_ranks = int(os.environ.get("STRUCTURED_EVAL_CHUNK_RANKS", 4)) + structured_compile_mode = os.environ.get("STRUCTURED_COMPILE_MODE", "reduce-overhead") + +# ----------------------------- +# MUON OPTIMIZER +# ----------------------------- +# +# As borrowed from modded-nanogpt +# Background on Muon: https://kellerjordan.github.io/posts/muon/ + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + # Orthogonalize a 2D update matrix with a fast Newton-Schulz iteration. + # Muon uses this to normalize matrix-shaped gradients before applying them. 
+ a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X + + +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True): + super().__init__( + params, + dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov), + ) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + + for group in self.param_groups: + params = group["params"] + if not params: + continue + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + + total_params = sum(int(p.numel()) for p in params) + updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) + + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: + g = g.add(buf, alpha=momentum) + g = zeropower_via_newtonschulz5(g, steps=backend_steps) + # Scale correction from Muon reference implementations. 
+ g *= max(1, g.size(0) / g.size(1)) ** 0.5 + updates_flat[curr : curr + p.numel()] = g.reshape(-1) + curr += p.numel() + + if distributed: + dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) + + curr = 0 + for p in params: + g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) + p.add_(g, alpha=-lr) + curr += p.numel() + + return loss + + +# ----------------------------- +# TOKENIZER-AGNOSTIC EVALUATION SETUP +# ----------------------------- +# +# It's common for small models to have a large fraction of their parameters in embeddings, since the 2 * d_model * d_vocab embedding parameters can be gigantic. +# Instead of locking the tokenizer, we let you bring your own and calculate our validation metrics on the average compression of the validation set. +# We calculate BPB (bits-per-byte) instead of validation loss, so we need lookup tables giving the number of UTF-8 bytes each token decodes to. +# Note: Submissions that edit the tokenizer will be examined more carefully, since screwing this up might unjustly improve your score.
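The BPB reduction that `eval_val` performs below can be sketched as a standalone calculation. The loss and byte count here are made-up illustrative numbers, not values from this run:

```python
import math

# Hypothetical inputs (illustrative only): mean token cross-entropy in nats,
# number of tokens scored, and the UTF-8 bytes those tokens decode to.
val_loss_nats = 6.9379
val_token_count = 1_048_576
val_byte_count = 1_750_000  # assumed byte total, not from this run

# bits/byte = (nats/token / ln 2) * (tokens / bytes)
bits_per_token = val_loss_nats / math.log(2.0)
tokens_per_byte = val_token_count / val_byte_count
val_bpb = bits_per_token * tokens_per_byte
print(f"val_bpb:{val_bpb:.4f}")
```

A tokenizer that packs more bytes into each token lowers `tokens_per_byte` and hence BPB at the same token-level loss, which is why the per-token byte tables are part of the metric rather than the raw cross-entropy.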
+ +def build_sentencepiece_luts( + sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device +) -> tuple[Tensor, Tensor, Tensor]: + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("▁"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) + + +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*. 
+ tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] + + +def eval_val( + args: Hyperparameters, + model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + grad_accum_steps: int, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, +) -> tuple[float, float]: + # Validation computes two metrics: + # - val_loss: token cross-entropy (natural log) + # - val_bpb: tokenizer-agnostic compression metric used by the challenge + local_batch_tokens = args.val_batch_size // world_size + if local_batch_tokens < args.train_seq_len: + raise ValueError( + "VAL_BATCH_SIZE must provide at least one sequence per rank; " + f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " + f"TRAIN_SEQ_LEN={args.train_seq_len}" + ) + local_batch_seqs = local_batch_tokens // args.train_seq_len + total_seqs = (val_tokens.numel() - 1) // args.train_seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + + model.eval() + with torch.inference_mode(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * args.train_seq_len + raw_end = batch_seq_end * args.train_seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, args.train_seq_len) + y = local[1:].reshape(-1, args.train_seq_len) + maybe_mark_compile_step_begin(args.enable_compile or 
args.compile_structured_mlp) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + batch_loss = model(x, y).detach() + batch_token_count = float(y.numel()) + val_loss_sum += batch_loss.to(torch.float64) * batch_token_count + val_token_count += batch_token_count + prev_ids = x.reshape(-1) + tgt_ids = y.reshape(-1) + token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + + val_loss = val_loss_sum / val_token_count + bits_per_token = val_loss.item() / math.log(2.0) + tokens_per_byte = val_token_count.item() / val_byte_count.item() + model.train() + return float(val_loss.item()), float(bits_per_token * tokens_per_byte) + +# ----------------------------- +# POST-TRAINING QUANTIZATION +# ----------------------------- +# +# It's silly to export our model, which is trained in bf16 and fp32, at that same precision. +# Instead, we get approximately the same model (with a small quality hit) by quantizing it to int8 or int5_odd and zlib-compressing the result. +# We can then decompress the model and run it at higher precision for evaluation, while coming in under the size limit.
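As a minimal, self-contained sketch of this export path (per-row symmetric int8 plus zlib), independent of the helpers defined below: max-abs scaling stands in here for the script's percentile clipping, and all names are illustrative rather than the script's API.

```python
import zlib
import numpy as np

rng = np.random.default_rng(0)
w = rng.standard_normal((256, 512)).astype(np.float32)  # stand-in weight matrix

# Per-row symmetric int8: one scale per output row tracks per-row ranges
# better than a single tensor-wide scale.
scale = np.abs(w).max(axis=1, keepdims=True) / 127.0
q = np.clip(np.round(w / scale), -127, 127).astype(np.int8)

blob = zlib.compress(q.tobytes(), level=9)  # artifact payload
deq = np.frombuffer(zlib.decompress(blob), dtype=np.int8).reshape(w.shape) * scale

# Round-trip error is bounded by half a quantization step per row.
assert np.abs(deq - w).max() <= scale.max() / 2 + 1e-6
print(len(blob), w.nbytes)  # compressed payload vs raw fp32 bytes (524288)
```

The real exporter clips each row at a high percentile before scaling so a single outlier cannot inflate the row's step size; max-abs is used above only to keep the sketch short.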
+ +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights", + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 +INT5_ODD_LEVELS = (-2, -1, 0, 1, 2) + +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) + + +def maybe_mark_compile_step_begin(enabled: bool) -> None: + if not enabled: + return + mark_step = getattr(torch.compiler, "cudagraph_mark_step_begin", None) + if mark_step is not None: + mark_step() + + +def scale_from_batch_multiplier(value: float, multiplier: float, method: str) -> float: + if method == "none": + return value + if method == "sqrt": + return value * math.sqrt(multiplier) + if method == "linear": + return value * multiplier + raise ValueError(f"Unsupported scaling method: {method}") + + +def scaled_warmup_steps(base_steps: int, multiplier: float, method: str) -> int: + if base_steps <= 0: + return 0 + scaled = scale_from_batch_multiplier(float(base_steps), multiplier, method) + return max(int(math.ceil(scaled)), 1) + + +def parse_factor_shape(spec: str, dim: int) -> tuple[int, int]: + if spec: + vals = tuple(int(v.strip()) for v in spec.split(",") if v.strip()) + if len(vals) != 2 or vals[0] * vals[1] != dim: + raise ValueError(f"Invalid factor shape '{spec}' for dim={dim}") + return vals + for a in range(int(math.isqrt(dim)), 0, -1): + if dim % a == 0: + return a, dim // a + return 1, dim + +def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> 
Tensor: + if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): + return t.float().contiguous() + if t.dtype in {torch.float32, torch.bfloat16}: + passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") + return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() + return t + +def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + # Matrices get one scale per row, which usually tracks output-channel + # ranges much better than a single tensor-wide scale. + clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) + q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() + return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + + # Vectors / scalars use a simpler per-tensor scale. 
+ clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() + return q, scale + + +def pack_base5_codes(q: Tensor) -> tuple[bytes, int]: + vals = q.detach().to("cpu", dtype=torch.uint8).reshape(-1).numpy() + n_vals = int(vals.size) + pad = (3 - (n_vals % 3)) % 3 + if pad: + vals = np.pad(vals, (0, pad)) + groups = vals.reshape(-1, 3).astype(np.uint8) + packed = groups[:, 0] + 5 * groups[:, 1] + 25 * groups[:, 2] + return packed.tobytes(), n_vals + + +def unpack_base5_codes(data: bytes, n_vals: int) -> Tensor: + packed = np.frombuffer(data, dtype=np.uint8).astype(np.int16) + out = np.zeros((packed.size, 3), dtype=np.uint8) + for i in range(3): + out[:, i] = packed % 5 + packed //= 5 + flat = out.reshape(-1)[:n_vals] + return torch.from_numpy(flat.astype(np.uint8, copy=False)) + + +def quantize_float_tensor_int5_odd(t: Tensor) -> tuple[Tensor, Tensor]: + if t.ndim != 2: + raise ValueError("int5_odd quantization only supports 2D tensors") + t32 = t.float() + clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clip_abs = clip_abs.clamp_min(1e-6) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 2.0).clamp_min(1.0 / 2.0e6) + q = torch.clamp(torch.round(clipped / scale[:, None]), -2, 2).to(torch.int8).contiguous() + return (q + 2).to(torch.uint8).contiguous(), scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + +def quantize_state_dict_int8(state_dict: dict[str, Tensor]): + # Single supported clean-script export format: + # - per-row int8 for 2D float tensors + # - per-tensor int8 for other float tensors + # - exact passthrough for non-floats + # - passthrough for small float tensors, 
stored as fp16 to save bytes + quantized: dict[str, Tensor] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + qmeta: dict[str, dict[str, object]] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + + # Small float tensors are cheap enough to keep directly. We still downcast + # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size. + if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + + stats["num_float_tensors"] += 1 + q, s = quantize_float_tensor(t) + if s.ndim > 0: + qmeta[name] = {"scheme": "per_row", "axis": 0} + quantized[name] = q + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) + + obj: dict[str, object] = { + "__quant_format__": "int8_clean_per_row_v1", + "quantized": quantized, + "scales": scales, + "dtypes": dtypes, + "passthrough": passthrough, + } + if qmeta: + obj["qmeta"] = qmeta + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats + + +def quantize_state_dict_int5_odd(state_dict: dict[str, Tensor]): + packed: dict[str, bytes] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + shapes: dict[str, tuple[int, ...]] = {} + orig_shapes: dict[str, 
tuple[int, ...]] = {} + n_codes: dict[str, int] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + + quantize_this = t.ndim >= 2 and (t.numel() > INT8_KEEP_FLOAT_MAX_NUMEL or ".mlp." in name) + if not quantize_this: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + + stats["num_float_tensors"] += 1 + t2 = t.reshape(t.shape[0], -1) if t.ndim > 2 else t + q, s = quantize_float_tensor_int5_odd(t2) + packed_bytes, count = pack_base5_codes(q) + packed[name] = packed_bytes + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + shapes[name] = tuple(int(v) for v in t2.shape) + orig_shapes[name] = tuple(int(v) for v in t.shape) + n_codes[name] = count + stats["int8_payload_bytes"] += len(packed_bytes) + tensor_nbytes(s) + + obj: dict[str, object] = { + "__quant_format__": "int5_odd_clean_per_row_v1", + "packed": packed, + "scales": scales, + "dtypes": dtypes, + "shapes": shapes, + "orig_shapes": orig_shapes, + "n_codes": n_codes, + "passthrough": passthrough, + } + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats + +def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + qmeta = obj.get("qmeta", {}) + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, q in 
obj["quantized"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + s = obj["scales"][name] + if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: + s = s.to(dtype=torch.float32) + # Broadcast the saved row scale back across trailing dimensions. + out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() + else: + scale = float(s.item()) + out[name] = (q.float() * scale).to(dtype=dtype).contiguous() + for name, t in obj["passthrough"].items(): + # Restore small tensors, undoing the temporary fp16 storage cast if needed. + out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out + + +def dequantize_state_dict_int5_odd(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, packed in obj["packed"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + shape = tuple(int(v) for v in obj["shapes"][name]) + orig_shape = tuple(int(v) for v in obj.get("orig_shapes", {}).get(name, shape)) + q = unpack_base5_codes(packed, int(obj["n_codes"][name])).to(torch.float32).reshape(shape) - 2.0 + s = obj["scales"][name].to(dtype=torch.float32) + deq = (q * s.view(shape[0], *([1] * (len(shape) - 1)))).to(dtype=dtype) + out[name] = deq.reshape(orig_shape).contiguous() + for name, t in obj["passthrough"].items(): + out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out + + +def quantize_state_dict(state_dict: dict[str, Tensor], quant_scheme: str): + if quant_scheme == "int8": + return quantize_state_dict_int8(state_dict) + if quant_scheme == "int5_odd": + return 
quantize_state_dict_int5_odd(state_dict) + raise ValueError(f"Unsupported QUANT_SCHEME={quant_scheme}") + + +def dequantize_state_dict(obj: dict[str, object]) -> dict[str, Tensor]: + quant_format = obj.get("__quant_format__") + if quant_format == "int8_clean_per_row_v1": + return dequantize_state_dict_int8(obj) + if quant_format == "int5_odd_clean_per_row_v1": + return dequantize_state_dict_int5_odd(obj) + raise ValueError(f"Unsupported quant format {quant_format}") + + +def compress_blob(data: bytes, compressor: str, level: int) -> bytes: + if compressor == "zlib": + return zlib.compress(data, level=max(0, min(level, 9))) + if compressor == "zstd": + if not _HAS_ZSTD: + raise RuntimeError("COMPRESSOR=zstd requires zstandard to be installed") + return zstd.ZstdCompressor(level=level).compress(data) + raise ValueError(f"Unsupported COMPRESSOR={compressor}") + + +def decompress_blob(data: bytes, compressor: str) -> bytes: + if compressor == "zlib": + return zlib.decompress(data) + if compressor == "zstd": + if not _HAS_ZSTD: + raise RuntimeError("COMPRESSOR=zstd requires zstandard to be installed") + return zstd.ZstdDecompressor().decompress(data) + raise ValueError(f"Unsupported COMPRESSOR={compressor}") + + +def should_rate_regularize_tensor(name: str, p: Tensor, target: str) -> bool: + if not p.requires_grad or not p.is_floating_point() or p.ndim < 2: + return False + if any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS): + return False + if target == "mlp_only": + return ".mlp." in name + if target == "all_matrices": + return p.numel() > INT8_KEEP_FLOAT_MAX_NUMEL or ".mlp." 
in name
+    raise ValueError(f"Unsupported RATE_TARGET_TENSORS={target}")
+
+
+def estimate_tensor_rate_proxy(t: Tensor, temperature: float) -> tuple[Tensor, Tensor, Tensor]:
+    t2 = t.float()
+    if t2.ndim > 2:
+        t2 = t2.reshape(t2.shape[0], -1)
+    elif t2.ndim == 1:
+        t2 = t2.reshape(1, -1)
+    row_scale = t2.abs().mean(dim=1, keepdim=True).clamp_min(1e-6)
+    normalized = torch.clamp(t2 / row_scale, -2.5, 2.5)
+    centers = t2.new_tensor(INT5_ODD_LEVELS)
+    logits = -temperature * (normalized.unsqueeze(-1) - centers) ** 2
+    probs = torch.softmax(logits, dim=-1)
+    hist = probs.mean(dim=(0, 1))
+    entropy_nats = -(hist * (hist + 1e-8).log()).sum()
+    entropy_bits = entropy_nats / math.log(2.0)
+    scale_proxy = row_scale.mean()
+    est_bits = entropy_bits * float(t2.numel())
+    return entropy_bits, scale_proxy, est_bits
+
+
+# -----------------------------
+# DATA LOADING
+# -----------------------------
+
+def load_data_shard(file: Path) -> Tensor:
+    # Shard layout: a 256-entry int32 header (magic, version, token count),
+    # followed by the tokens as little-endian uint16.
+    header_bytes = 256 * np.dtype("<i4").itemsize
+    with file.open("rb") as f:
+        header = np.frombuffer(f.read(header_bytes), dtype="<i4")
+        num_tokens = int(header[2])
+        tokens = np.frombuffer(f.read(), dtype="<u2")[:num_tokens]
+    return torch.from_numpy(tokens.astype(np.int32))
+
+
+class TokenStream:
+    # Flat token stream over all shards matching the glob pattern, in sorted order.
+    def __init__(self, pattern: str):
+        pattern_path = Path(pattern)
+        self.files = sorted(pattern_path.parent.glob(pattern_path.name))
+        if not self.files:
+            raise FileNotFoundError(f"no data shards match pattern {pattern}")
+        self.file_idx = 0
+        self.tokens = load_data_shard(self.files[self.file_idx])
+        self.pos = 0
+
+    def _advance_file(self) -> None:
+        self.file_idx = (self.file_idx + 1) % len(self.files)
+        self.tokens = load_data_shard(self.files[self.file_idx])
+        self.pos = 0
+
+    def take(self, n: int) -> Tensor:
+        chunks: list[Tensor] = []
+        remaining = n
+        while remaining > 0:
+            avail = self.tokens.numel() - self.pos
+            if avail <= 0:
+                self._advance_file()
+                continue
+            k = min(remaining, avail)
+            chunks.append(self.tokens[self.pos : self.pos + k])
+            self.pos += k
+            remaining -= k
+        return chunks[0] if len(chunks) == 1 else torch.cat(chunks)
+
+
+class DistributedTokenLoader:
+    # Each call consumes a contiguous chunk from the shared token stream, then slices out
+    # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting.
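+    # Worked example with illustrative numbers (not the configured defaults):
+    # global_tokens=131072, world_size=8, grad_accum_steps=2 gives
+    # local_tokens = 131072 // (8 * 2) = 8192 tokens per rank per microstep,
+    # per_rank_span = 8193, a shared chunk of 8 * 8193 = 65544 tokens, and
+    # rank r reading chunk[r * 8193 : (r + 1) * 8193).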
+ def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank = rank + self.world_size = world_size + self.device = device + self.stream = TokenStream(pattern) + + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + chunk = self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start : start + per_rank_span].to(dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + + +def debug_log_data_sharding( + pattern: str, + world_size: int, + global_tokens: int, + seq_len: int, + grad_accum_steps: int, + steps: int, + log0, +) -> None: + stream = TokenStream(pattern) + local_tokens = global_tokens // (world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + global_cursor = 0 + log0( + f"debug_sharding:world_size:{world_size} global_tokens:{global_tokens} " + f"grad_accum_steps:{grad_accum_steps} local_tokens_per_rank:{local_tokens} per_rank_span:{per_rank_span}" + ) + for step_idx in range(steps): + chunk = stream.take(per_rank_span * world_size) + chunk_start = global_cursor + chunk_end = global_cursor + chunk.numel() + log0(f"debug_sharding_step:{step_idx} shared_chunk_token_range:[{chunk_start},{chunk_end})") + for shard_rank in range(world_size): + start = shard_rank * per_rank_span + end = start + per_rank_span + local = chunk[start:end] + preview = ",".join(str(int(t)) for t in local[: min(6, local.numel())].tolist()) + log0( + f"debug_shard rank:{shard_rank} token_range:[{chunk_start + start},{chunk_start + end}) " + f"preview:{preview}" + ) + global_cursor = chunk_end + +# ----------------------------- +# TRANSFORMER MODULES +# ----------------------------- + +class RMSNorm(nn.Module): + def 
__init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + + +class CastedLinear(nn.Linear): + # Keep weights in fp32 for optimizer/state quality, cast at matmul time for bf16 compute. + def __init__(self, in_features: int, out_features: int, bias: bool = True): + super().__init__(in_features, out_features, bias=bias) + self._cached_weight: Tensor | None = None + self._cached_bias: Tensor | None = None + self._cached_key: tuple[torch.dtype, torch.device, int, int | None] | None = None + + def _cast_params(self, x: Tensor) -> tuple[Tensor, Tensor | None]: + bias_version = self.bias._version if self.bias is not None else None + cache_key = (x.dtype, x.device, self.weight._version, bias_version) + if torch.is_inference_mode_enabled() and self._cached_key == cache_key: + return self._cached_weight, self._cached_bias + weight = self.weight.to(dtype=x.dtype) + bias = self.bias.to(x.dtype) if self.bias is not None else None + if torch.is_inference_mode_enabled(): + self._cached_weight = weight + self._cached_bias = bias + self._cached_key = cache_key + return weight, bias + + def forward(self, x: Tensor) -> Tensor: + weight, bias = self._cast_params(x) + return F.linear(x, weight, bias) + + +class StructuredLinear(nn.Module): + # Lightweight TT-style structured matrix used only in the MLP stack. 
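+    # Size intuition (hypothetical shapes, not the tuned defaults): factoring
+    # in_features = n1 * n2 and out_features = m1 * m2, the two cores hold
+    # n1 * m1 * rank + rank * n2 * m2 parameters versus n1 * n2 * m1 * m2 dense,
+    # e.g. 1024 -> 4096 as (32, 32) -> (64, 64) at rank 4 stores
+    # 32 * 64 * 4 + 4 * 32 * 64 = 16384 weights instead of 4194304.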
+ def __init__( + self, + in_features: int, + out_features: int, + bias: bool, + linear_type: str, + btt_rank: int, + in_shape_spec: str, + out_shape_spec: str, + btt_init: str, + btt_init_gain: float, + chunk_ranks: int, + eval_chunk_ranks: int, + ): + super().__init__() + self.in_features = in_features + self.out_features = out_features + self.linear_type = linear_type + self.rank = max(int(btt_rank), 1) + self.chunk_ranks = max(int(chunk_ranks), 1) + self.eval_chunk_ranks = max(int(eval_chunk_ranks), 1) + self.btt_init = btt_init + self.btt_init_gain = btt_init_gain + self._cached_core1: Tensor | None = None + self._cached_core2: Tensor | None = None + self._cached_bias: Tensor | None = None + self._cached_key: tuple[torch.dtype, torch.device, int, int, int | None] | None = None + if linear_type == "dense": + self.impl = CastedLinear(in_features, out_features, bias=bias) + self.register_parameter("core1", None) + self.register_parameter("core2", None) + return + if linear_type != "btt_mlp": + raise ValueError(f"Unsupported STRUCTURED_LINEAR_TYPE={linear_type}") + self.impl = None + self.in_shape = parse_factor_shape(in_shape_spec, in_features) + self.out_shape = parse_factor_shape(out_shape_spec, out_features) + n1, n2 = self.in_shape + m1, m2 = self.out_shape + self.core1 = nn.Parameter(torch.empty(n1, m1, self.rank)) + self.core2 = nn.Parameter(torch.empty(self.rank, n2, m2)) + if bias: + self.bias = nn.Parameter(torch.zeros(out_features)) + else: + self.register_parameter("bias", None) + self.reset_parameters() + + def reset_parameters(self) -> None: + if self.impl is not None: + self.impl.reset_parameters() + return + n1, n2 = self.in_shape + if self.btt_init == "mup": + # muP-inspired scaling: keep the per-rank branch update stable while + # preserving O(1) output variance under the sum over TT ranks. 
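+            # Why this stays O(1): each output entry sums roughly n1 * n2
+            # products per rank branch, so its variance is about
+            # rank * n2 * (core2_std ** 2) * n1 * (core1_std ** 2) * Var(x)
+            # = btt_init_gain ** 4 * Var(x), independent of rank, n1, and n2.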
+ core1_std = self.btt_init_gain / math.sqrt(max(n1, 1)) + core2_std = self.btt_init_gain / math.sqrt(max(self.rank * n2, 1)) + elif self.btt_init == "xavier": + std = self.btt_init_gain / math.sqrt(max(self.in_features, 1)) + core1_std = std / math.sqrt(self.rank) + core2_std = std + else: + raise ValueError(f"Unsupported BTT_INIT={self.btt_init}") + nn.init.normal_(self.core1, mean=0.0, std=core1_std) + nn.init.normal_(self.core2, mean=0.0, std=core2_std) + if self.bias is not None: + nn.init.zeros_(self.bias) + + def _cast_factors(self, x: Tensor) -> tuple[Tensor, Tensor, Tensor | None]: + bias_version = self.bias._version if self.bias is not None else None + cache_key = (x.dtype, x.device, self.core1._version, self.core2._version, bias_version) + if torch.is_inference_mode_enabled() and self._cached_key == cache_key: + return self._cached_core1, self._cached_core2, self._cached_bias + core1 = self.core1.to(dtype=x.dtype).permute(2, 0, 1).contiguous() + core2 = self.core2.to(dtype=x.dtype).contiguous() + bias = self.bias.to(dtype=x.dtype) if self.bias is not None else None + if torch.is_inference_mode_enabled(): + self._cached_core1 = core1 + self._cached_core2 = core2 + self._cached_bias = bias + self._cached_key = cache_key + return core1, core2, bias + + def forward(self, x: Tensor) -> Tensor: + if self.impl is not None: + return self.impl(x) + orig_shape = x.shape[:-1] + x2 = x.reshape(-1, self.in_shape[0], self.in_shape[1]) + x2_t = x2.transpose(1, 2).contiguous() + core1, core2, bias = self._cast_factors(x) + y = x2.new_zeros((x2.shape[0], self.out_shape[0], self.out_shape[1])) + if torch.is_inference_mode_enabled() and self.eval_chunk_ranks > 1: + for start in range(0, self.rank, self.eval_chunk_ranks): + end = min(start + self.eval_chunk_ranks, self.rank) + left = torch.matmul(x2_t.unsqueeze(1), core1[start:end].unsqueeze(0)) + y = y + torch.matmul(left.transpose(-1, -2), core2[start:end].unsqueeze(0)).sum(dim=1) + y = y.reshape(*orig_shape, 
self.out_features) + if bias is not None: + y = y + bias + return y + for start in range(0, self.rank, self.chunk_ranks): + end = min(start + self.chunk_ranks, self.rank) + for offset, core2_r in enumerate(core2[start:end], start=start): + left = torch.matmul(x2_t, core1[offset]) + y = y + torch.matmul(left.transpose(1, 2), core2_r) + y = y.reshape(*orig_shape, self.out_features) + if bias is not None: + y = y + bias + return y + + +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + # Keep small/control parameters in fp32 even when the model body runs in bf16. + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() + + +class Rotary(nn.Module): + # Caches cos/sin tables per sequence length on the current device. + def __init__(self, dim: int, base: float = 10000.0): + super().__init__() + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached != seq_len + or self._cos_cached.device != device + ): + t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) + freqs = torch.outer(t, self.inv_freq.to(device)) + self._cos_cached = freqs.cos()[None, None, :, :] + self._sin_cached = freqs.sin()[None, None, :, :] + self._seq_len_cached = seq_len + cos = self._cos_cached.to(dtype=dtype) + sin = self._sin_cached.to(dtype=dtype) + if not torch.is_inference_mode_enabled(): + cos = cos.clone() + sin = sin.clone() + return cos, sin + + +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> 
Tensor: + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + + +class CausalSelfAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + kv_dim = self.num_kv_heads * self.head_dim + self.c_q = CastedLinear(dim, dim, bias=False) + self.c_k = CastedLinear(dim, kv_dim, bias=False) + self.c_v = CastedLinear(dim, kv_dim, bias=False) + self.proj = CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rotary = Rotary(self.head_dim, base=rope_base) + + def forward(self, x: Tensor) -> Tensor: + bsz, seqlen, dim = x.shape + q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) + k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin) + k = apply_rotary_emb(k, cos, sin) + q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] + y = F.scaled_dot_product_attention( + q, + k, + v, + attn_mask=None, + is_causal=True, + enable_gqa=(self.num_kv_heads != self.num_heads), + ) + y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) + return self.proj(y) + + +class MLP(nn.Module): + # relu^2 MLP from the 
original modded-nanogpt setup + def __init__( + self, + dim: int, + mlp_mult: int, + structured_mlp: bool, + structured_linear_type: str, + btt_rank: int, + btt_in_shape: str, + btt_out_shape: str, + btt_init: str, + btt_init_gain: float, + compile_structured_mlp: bool, + structured_chunk_ranks: int, + structured_eval_chunk_ranks: int, + structured_compile_mode: str, + ): + super().__init__() + hidden = mlp_mult * dim + if structured_mlp: + self.fc = StructuredLinear( + dim, + hidden, + bias=False, + linear_type=structured_linear_type, + btt_rank=btt_rank, + in_shape_spec=btt_in_shape, + out_shape_spec=btt_out_shape, + btt_init=btt_init, + btt_init_gain=btt_init_gain, + chunk_ranks=structured_chunk_ranks, + eval_chunk_ranks=structured_eval_chunk_ranks, + ) + self.proj = StructuredLinear( + hidden, + dim, + bias=False, + linear_type=structured_linear_type, + btt_rank=btt_rank, + in_shape_spec=btt_out_shape, + out_shape_spec=btt_in_shape, + btt_init=btt_init, + btt_init_gain=btt_init_gain, + chunk_ranks=structured_chunk_ranks, + eval_chunk_ranks=structured_eval_chunk_ranks, + ) + if compile_structured_mlp: + compile_kwargs = dict(fullgraph=False, dynamic=False) + if structured_compile_mode == "reduce-overhead": + compile_options = dict(torch._inductor.list_mode_options("reduce-overhead")) + compile_options["triton.cudagraphs"] = False + compile_kwargs["options"] = compile_options + else: + compile_kwargs["mode"] = structured_compile_mode + self.fc = torch.compile( + self.fc, + **compile_kwargs, + ) + self.proj = torch.compile( + self.proj, + **compile_kwargs, + ) + else: + self.fc = CastedLinear(dim, hidden, bias=False) + self.proj = CastedLinear(hidden, dim, bias=False) + self.proj._zero_init = True + + def forward(self, x: Tensor) -> Tensor: + x = torch.relu(self.fc(x)) + return self.proj(x.square()) + + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + rope_base: float, + qk_gain_init: float, + 
structured_mlp: bool, + structured_linear_type: str, + btt_rank: int, + btt_in_shape: str, + btt_out_shape: str, + btt_init: str, + btt_init_gain: float, + compile_structured_mlp: bool, + structured_chunk_ranks: int, + structured_eval_chunk_ranks: int, + structured_compile_mode: str, + ): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) + self.mlp = MLP( + dim, + mlp_mult, + structured_mlp, + structured_linear_type, + btt_rank, + btt_in_shape, + btt_out_shape, + btt_init, + btt_init_gain, + compile_structured_mlp, + structured_chunk_ranks, + structured_eval_chunk_ranks, + structured_compile_mode, + ) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + + def forward(self, x: Tensor, x0: Tensor) -> Tensor: + mix = self.resid_mix.to(dtype=x.dtype) + x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + attn_out = self.attn(self.attn_norm(x)) + x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out + x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) + return x + + +class GPT(nn.Module): + def __init__( + self, + vocab_size: int, + num_layers: int, + model_dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + tie_embeddings: bool, + tied_embed_init_std: float, + logit_softcap: float, + rope_base: float, + qk_gain_init: float, + structured_mlp: bool, + structured_linear_type: str, + btt_rank: int, + btt_in_shape: str, + btt_out_shape: str, + btt_init: str, + btt_init_gain: float, + compile_structured_mlp: bool, + structured_chunk_ranks: int, + structured_eval_chunk_ranks: int, + structured_compile_mode: str, + rate_target_tensors: str, + rate_temperature: float, + ): + super().__init__() + if logit_softcap <= 
0.0: + raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") + self.tie_embeddings = tie_embeddings + self.tied_embed_init_std = tied_embed_init_std + self.logit_softcap = logit_softcap + self.rate_target_tensors = rate_target_tensors + self.rate_temperature = rate_temperature + self.tok_emb = nn.Embedding(vocab_size, model_dim) + self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + self.blocks = nn.ModuleList( + [ + Block( + model_dim, + num_heads, + num_kv_heads, + mlp_mult, + rope_base, + qk_gain_init, + structured_mlp, + structured_linear_type, + btt_rank, + btt_in_shape, + btt_out_shape, + btt_init, + btt_init_gain, + compile_structured_mlp, + structured_chunk_ranks, + structured_eval_chunk_ranks, + structured_compile_mode, + ) + for i in range(num_layers) + ] + ) + self.final_norm = RMSNorm() + self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) + if self.lm_head is not None: + self.lm_head._zero_init = True + self._init_weights() + + def _init_weights(self) -> None: + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + for module in self.modules(): + if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + if isinstance(module, StructuredLinear) and getattr(module, "_zero_init", False): + if module.impl is not None: + nn.init.zeros_(module.impl.weight) + else: + nn.init.zeros_(module.core2) + + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + x = self.tok_emb(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x0 = x + skips: list[Tensor] = [] + + # First half stores skips; second half reuses them in reverse order. 
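+        # e.g. with 16 layers: blocks 0..7 push their outputs, block 8 reuses
+        # block 7's output, ..., block 15 reuses block 0's, each weighted by a
+        # learned per-channel skip_weights row before entering the block.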
+ for i in range(self.num_encoder_layers): + x = self.blocks[i](x, x0) + skips.append(x) + for i in range(self.num_decoder_layers): + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + x = self.blocks[self.num_encoder_layers + i](x, x0) + + x = self.final_norm(x).reshape(-1, x.size(-1)) + targets = target_ids.reshape(-1) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(x) + logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + return F.cross_entropy(logits.float(), targets, reduction="mean") + + def rate_regularizer(self) -> tuple[Tensor, Tensor, Tensor]: + total_symbols = 0.0 + weighted_rate = None + weighted_scale = None + estimated_bits = None + for name, param in self.named_parameters(): + if not should_rate_regularize_tensor(name, param, self.rate_target_tensors): + continue + rate_proxy, scale_proxy, tensor_bits = estimate_tensor_rate_proxy(param, self.rate_temperature) + weight = float(param.numel()) + total_symbols += weight + if weighted_rate is None: + weighted_rate = rate_proxy * weight + weighted_scale = scale_proxy * weight + estimated_bits = tensor_bits + else: + weighted_rate = weighted_rate + rate_proxy * weight + weighted_scale = weighted_scale + scale_proxy * weight + estimated_bits = estimated_bits + tensor_bits + if weighted_rate is None or weighted_scale is None or estimated_bits is None: + zero = self.tok_emb.weight.new_zeros(()) + return zero, zero, zero + denom = max(total_symbols, 1.0) + return weighted_rate / denom, weighted_scale / denom, estimated_bits + + +# ----------------------------- +# TRAINING +# ----------------------------- + +def main() -> None: + global zeropower_via_newtonschulz5 + + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + if args.enable_compile: + 
zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) + + # ----------------------------- + # DISTRIBUTED + CUDA SETUP + # ----------------------------- + + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + simulated_world_size = args.simulated_world_size if not distributed and args.simulated_world_size > 0 else 0 + if simulated_world_size > 0: + world_size = simulated_world_size + rank = args.simulated_rank + local_rank = 0 + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if rank < 0 or rank >= world_size: + raise ValueError(f"RANK={rank} must satisfy 0 <= RANK < WORLD_SIZE={world_size}") + if simulated_world_size > 1 and not args.dry_run_init_only: + raise ValueError("SIMULATED_WORLD_SIZE>1 is only supported together with DRY_RUN_INIT_ONLY=1") + if args.train_batch_tokens <= 0: + raise ValueError(f"TRAIN_BATCH_TOKENS must be positive, got {args.train_batch_tokens}") + train_microbatch_tokens = args.train_microbatch_tokens if args.train_microbatch_tokens > 0 else max( + args.train_seq_len, args.train_batch_tokens // 8 + ) + if train_microbatch_tokens <= 0: + raise ValueError(f"TRAIN_MICROBATCH_TOKENS must be positive, got {train_microbatch_tokens}") + if train_microbatch_tokens % args.train_seq_len != 0: + raise ValueError( + f"TRAIN_MICROBATCH_TOKENS={train_microbatch_tokens} must be divisible by TRAIN_SEQ_LEN={args.train_seq_len}" + ) + tokens_per_microstep = world_size * train_microbatch_tokens + if args.grad_accum_steps > 0: + grad_accum_steps = args.grad_accum_steps + effective_train_batch_tokens = tokens_per_microstep * grad_accum_steps + else: + if args.train_batch_tokens % tokens_per_microstep != 0: + raise ValueError( + "TRAIN_BATCH_TOKENS must be divisible by WORLD_SIZE * TRAIN_MICROBATCH_TOKENS " + f"unless GRAD_ACCUM_STEPS is 
set explicitly; got TRAIN_BATCH_TOKENS={args.train_batch_tokens}, " + f"WORLD_SIZE={world_size}, TRAIN_MICROBATCH_TOKENS={train_microbatch_tokens}" + ) + grad_accum_steps = args.train_batch_tokens // tokens_per_microstep + effective_train_batch_tokens = args.train_batch_tokens + if grad_accum_steps <= 0: + raise ValueError(f"GRAD_ACCUM_STEPS must be positive, got {grad_accum_steps}") + grad_scale = 1.0 / grad_accum_steps + batch_multiplier = effective_train_batch_tokens / max(args.lr_reference_batch_tokens, 1) + effective_warmup_steps = scaled_warmup_steps( + args.warmup_steps, + effective_train_batch_tokens / max(args.warmup_reference_batch_tokens, 1), + args.warmup_scale_method, + ) + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 0 + + # Fast math knobs + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + + logfile = None + if master_process: + if args.run_id == "train": + logfile = "train.log" + else: + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" + print(logfile) + + def log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch {torch.__version__}", console=False) + log0( + subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, 
check=False).stdout, + console=False, + ) + log0("=" * 100, console=False) + + # ----------------------------- + # TOKENIZER + VALIDATION METRIC SETUP + # ----------------------------- + + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) + if args.val_token_limit > 0: + usable = min(((args.val_token_limit) // args.train_seq_len) * args.train_seq_len, val_tokens.numel() - 1) + if usable <= 0: + raise ValueError(f"VAL_TOKEN_LIMIT={args.val_token_limit} is too small for TRAIN_SEQ_LEN={args.train_seq_len}") + val_tokens = val_tokens[: usable + 1].contiguous() + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + if args.val_token_limit > 0: + log0(f"val_token_limit:{args.val_token_limit}") + + # ----------------------------- + # MODEL + OPTIMIZER SETUP + # ----------------------------- + + base_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + 
tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, + structured_mlp=args.structured_mlp, + structured_linear_type=args.structured_linear_type if args.structured_mlp else "dense", + btt_rank=args.btt_rank, + btt_in_shape=args.btt_in_shape, + btt_out_shape=args.btt_out_shape, + btt_init=args.btt_init, + btt_init_gain=args.btt_init_gain, + compile_structured_mlp=args.compile_structured_mlp, + structured_chunk_ranks=args.structured_chunk_ranks, + structured_eval_chunk_ranks=args.structured_eval_chunk_ranks, + structured_compile_mode=args.structured_compile_mode, + rate_target_tensors=args.rate_target_tensors, + rate_temperature=args.rate_temperature, + ).to(device).bfloat16() + for module in base_model.modules(): + if isinstance(module, (CastedLinear, StructuredLinear)): + module.float() + restore_low_dim_params_to_fp32(base_model) + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) if args.enable_compile else base_model + model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + + # Optimizer split: + # - token embedding (Adam) uses EMBED_LR + # - untied lm_head (Adam) uses HEAD_LR + # - matrix params in transformer blocks use MATRIX_LR via Muon + # - vectors/scalars use SCALAR_LR via Adam + block_named_params = list(base_model.blocks.named_parameters()) + matrix_params = [ + p + for name, p in block_named_params + if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + scalar_params = [ + p + for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + scaled_token_lr = 
scale_from_batch_multiplier(token_lr, batch_multiplier, args.lr_scale_method) + scaled_head_lr = scale_from_batch_multiplier(args.head_lr, batch_multiplier, args.lr_scale_method) + scaled_matrix_lr = scale_from_batch_multiplier(args.matrix_lr, batch_multiplier, args.lr_scale_method) + scaled_scalar_lr = scale_from_batch_multiplier(args.scalar_lr, batch_multiplier, args.lr_scale_method) + optimizer_tok = torch.optim.Adam( + [{"params": [base_model.tok_emb.weight], "lr": scaled_token_lr, "base_lr": scaled_token_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizer_muon = Muon( + matrix_params, + lr=scaled_matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + ) + for group in optimizer_muon.param_groups: + group["base_lr"] = scaled_matrix_lr + optimizer_scalar = torch.optim.Adam( + [{"params": scalar_params, "lr": scaled_scalar_lr, "base_lr": scaled_scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": scaled_head_lr, "base_lr": scaled_head_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers.insert(1, optimizer_head) + + n_params = sum(p.numel() for p in base_model.parameters()) + log0(f"model_params:{n_params}") + log0( + f"world_size:{world_size} grad_accum_steps:{grad_accum_steps} " + f"train_microbatch_tokens:{train_microbatch_tokens} effective_train_batch_tokens:{effective_train_batch_tokens}" + ) + if simulated_world_size > 0: + log0(f"simulated_world_size:{simulated_world_size} simulated_rank:{rank}") + log0(f"enable_compile:{args.enable_compile}") + log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") + log0(f"attention_mode:gqa num_heads:{args.num_heads} 
num_kv_heads:{args.num_kv_heads}") + log0( + f"quant_scheme:{args.quant_scheme} compressor:{args.compressor} level:{args.compressor_level} " + f"structured_mlp:{args.structured_mlp} structured_linear_type:{args.structured_linear_type} " + f"btt_rank:{args.btt_rank} btt_init:{args.btt_init} compile_structured_mlp:{args.compile_structured_mlp} " + f"structured_chunk_ranks:{args.structured_chunk_ranks} structured_eval_chunk_ranks:{args.structured_eval_chunk_ranks} " + f"structured_compile_mode:{args.structured_compile_mode}" + ) + log0( + f"rate_lambda:{args.rate_lambda} scale_lambda:{args.scale_lambda} " + f"rate_warmup_start_fraction:{args.rate_warmup_start_fraction} rate_full_fraction:{args.rate_full_fraction} " + f"rate_target_tensors:{args.rate_target_tensors} rate_schedule_mode:{args.rate_schedule_mode} " + f"max_expected_steps:{args.max_expected_steps}" + ) + log0( + f"tie_embeddings:{args.tie_embeddings} embed_lr:{scaled_token_lr} " + f"head_lr:{scaled_head_lr if base_model.lm_head is not None else 0.0} " + f"matrix_lr:{scaled_matrix_lr} scalar_lr:{scaled_scalar_lr}" + ) + log0( + f"train_batch_tokens_target:{args.train_batch_tokens} effective_train_batch_tokens:{effective_train_batch_tokens} " + f"train_seq_len:{args.train_seq_len} iterations:{args.iterations} warmup_steps:{effective_warmup_steps} " + f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log0( + f"lr_scale_method:{args.lr_scale_method} lr_reference_batch_tokens:{args.lr_reference_batch_tokens} " + f"batch_multiplier:{batch_multiplier:.4f} warmup_scale_method:{args.warmup_scale_method} " + f"warmup_reference_batch_tokens:{args.warmup_reference_batch_tokens}" + ) + log0(f"seed:{args.seed}") + + if args.debug_data_sharding_steps > 0: + debug_log_data_sharding( + args.train_files, + world_size, + effective_train_batch_tokens, + args.train_seq_len, + grad_accum_steps, + args.debug_data_sharding_steps, + log0, + ) + if args.dry_run_init_only: + log0("dry_run_init_only: exiting after 
init/logging") + return + + # ----------------------------- + # DATA LOADER & MODEL WARMUP + # ----------------------------- + + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + + def lr_mul(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + + def rate_progress(step: int, elapsed_ms: float) -> float: + if args.rate_schedule_mode == "steps": + if args.iterations <= 0: + return 1.0 + return step / max(args.iterations, 1) + if args.rate_schedule_mode == "expected_steps": + if args.max_expected_steps <= 0: + raise ValueError("MAX_EXPECTED_STEPS must be positive when RATE_SCHEDULE_MODE=expected_steps") + return step / max(args.max_expected_steps, 1) + if args.rate_schedule_mode == "wallclock": + if max_wallclock_ms is None: + raise ValueError("RATE_SCHEDULE_MODE=wallclock requires MAX_WALLCLOCK_SECONDS > 0") + return elapsed_ms / max(max_wallclock_ms, 1e-9) + if args.rate_schedule_mode != "auto": + raise ValueError(f"Unsupported RATE_SCHEDULE_MODE={args.rate_schedule_mode}") + if args.max_expected_steps > 0: + return step / max(args.max_expected_steps, 1) + if max_wallclock_ms is not None and step > 0: + step_ms = elapsed_ms / max(step, 1) + expected_steps = max(int(max_wallclock_ms / max(step_ms, 1e-9)), 1) + return step / expected_steps + if args.iterations <= 0: + 
return 1.0 + return step / max(args.iterations, 1) + + def rate_schedule(step: int, elapsed_ms: float) -> float: + if args.rate_lambda <= 0.0 and args.scale_lambda <= 0.0: + return 0.0 + frac = rate_progress(step, elapsed_ms) + start = args.rate_warmup_start_fraction + full = max(args.rate_full_fraction, start + 1e-6) + if frac <= start: + return 0.0 + if frac >= full: + return 1.0 + return (frac - start) / (full - start) + + # Warmup primes the compiled forward/backward/optimizer paths, then we restore the + # initial weights/optimizer state so measured training starts from the true init. + if effective_warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(effective_warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(effective_train_batch_tokens, args.train_seq_len, grad_accum_steps) + maybe_mark_compile_step_begin(args.enable_compile or args.compile_structured_mlp) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + for opt in optimizers: + opt.step() + zero_grad_all() + if ( + effective_warmup_steps <= 20 + or (warmup_step + 1) % 10 == 0 + or warmup_step + 1 == effective_warmup_steps + ): + log0(f"warmup_step:{warmup_step + 1}/{effective_warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + if distributed: + model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + # ----------------------------- + # 
MAIN TRAINING LOOP + # ----------------------------- + + training_time_ms = 0.0 + stop_after_step: int | None = None + torch.cuda.synchronize() + t0 = time.perf_counter() + + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + log0( + f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" + ) + break + + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + zero_grad_all() + train_loss = torch.zeros((), device=device) + train_ce_loss = torch.zeros((), device=device) + train_rate_proxy = torch.zeros((), device=device) + train_scale_proxy = torch.zeros((), device=device) + train_est_bits = torch.zeros((), device=device) + rate_mul = rate_schedule(step, elapsed_ms) + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(effective_train_batch_tokens, args.train_seq_len, grad_accum_steps) + maybe_mark_compile_step_begin(args.enable_compile or args.compile_structured_mlp) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + ce_loss = model(x, y) + total_loss = ce_loss + 
if rate_mul > 0.0: + rate_proxy, scale_proxy, est_bits = base_model.rate_regularizer() + total_loss = total_loss + rate_mul * (args.rate_lambda * rate_proxy + args.scale_lambda * scale_proxy) + train_rate_proxy += rate_proxy.detach() + train_scale_proxy += scale_proxy.detach() + train_est_bits += est_bits.detach() + train_ce_loss += ce_loss.detach() + train_loss += total_loss.detach() + (total_loss * grad_scale).backward() + train_loss /= grad_accum_steps + train_ce_loss /= grad_accum_steps + train_rate_proxy /= grad_accum_steps + train_scale_proxy /= grad_accum_steps + train_est_bits /= grad_accum_steps + + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + for opt in optimizers: + opt.step() + zero_grad_all() + + step += 1 + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + should_log_train = ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + log0( + f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} ce_loss:{train_ce_loss.item():.4f} " + f"rate_mul:{rate_mul:.3f} rate_proxy_bits:{train_rate_proxy.item():.4f} " + f"scale_proxy:{train_scale_proxy.item():.6f} est_model_bits:{train_est_bits.item():.0f} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" + ) + + # Needed to sync whether we've reached the wallclock cap. 
+ reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + + log0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + ) + + # ----------------------------- + # SERIALIZATION + ROUNDTRIP VALIDATION + # ----------------------------- + # Save the raw state (useful for debugging/loading in PyTorch directly), then always produce + # the quantized, compressed artifact (QUANT_SCHEME + COMPRESSOR; int5_odd+zlib in this run) and validate the round-tripped weights. + + if master_process: + torch.save(base_model.state_dict(), "final_model.pt") + model_bytes = os.path.getsize("final_model.pt") + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model: {model_bytes} bytes") + log0(f"Code size: {code_bytes} bytes") + log0(f"Total submission size: {model_bytes + code_bytes} bytes") + + quant_obj, quant_stats = quantize_state_dict(base_model.state_dict(), args.quant_scheme) + quant_buf = io.BytesIO() + torch.save(quant_obj, quant_buf) + quant_raw = quant_buf.getvalue() + quant_blob = compress_blob(quant_raw, args.compressor, args.compressor_level) + quant_raw_bytes = len(quant_raw) + quant_filename = f"final_model.{args.quant_scheme}.{args.compressor}.ptc" + if master_process: + with open(quant_filename, "wb") as f: + f.write(quant_blob) + quant_file_bytes = os.path.getsize(quant_filename) + code_bytes = len(code.encode("utf-8")) + ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1) + log0( + f"Serialized model {args.quant_scheme}+{args.compressor}: {quant_file_bytes} bytes " + f"(payload:{quant_stats['int8_payload_bytes']} raw_torch:{quant_raw_bytes}
payload_ratio:{ratio:.2f}x)" + ) + log0(f"Total submission size {args.quant_scheme}+{args.compressor}: {quant_file_bytes + code_bytes} bytes") + + if distributed: + dist.barrier() + with open(quant_filename, "rb") as f: + quant_blob_disk = f.read() + quant_state = torch.load(io.BytesIO(decompress_blob(quant_blob_disk, args.compressor)), map_location="cpu") + base_model.load_state_dict(dequantize_state_dict(quant_state), strict=True) + torch.cuda.synchronize() + t_qeval = time.perf_counter() + q_val_loss, q_val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"final_{args.quant_scheme}_{args.compressor}_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" + ) + log0(f"final_{args.quant_scheme}_{args.compressor}_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + + if distributed: + dist.destroy_process_group() + + +if __name__ == "__main__": + main() + +==================================================================================================== +Running Python 3.12.3 (main, Nov 6 2025, 13:44:16) [GCC 13.3.0] +Running PyTorch 2.9.1+cu128 +Tue Mar 31 18:38:29 2026 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 580.126.09 Driver Version: 580.126.09 CUDA Version: 13.0 | ++-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. 
| +|=========================================+========================+======================| +| 0 NVIDIA H100 80GB HBM3 On | 00000000:CB:00.0 Off | 0 | +| N/A 40C P0 90W / 700W | 1185MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| No running processes found | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=/workspace/parameter-golf/data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:80 +val_loader:shards pattern=/workspace/parameter-golf/data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:1048576 +val_token_limit:1048576 +model_params:25727104 +world_size:1 grad_accum_steps:8 train_microbatch_tokens:2048 effective_train_batch_tokens:16384 +enable_compile:False +sdp_backends:cudnn=False flash=True mem_efficient=False math=False +attention_mode:gqa num_heads:8 num_kv_heads:4 +quant_scheme:int5_odd compressor:zlib level:9 structured_mlp:True structured_linear_type:btt_mlp btt_rank:256 btt_init:mup compile_structured_mlp:False structured_chunk_ranks:64 structured_eval_chunk_ranks:1 structured_compile_mode:reduce-overhead +rate_lambda:2e-05 scale_lambda:0.0002 rate_warmup_start_fraction:0.1 rate_full_fraction:0.4 rate_target_tensors:mlp_only rate_schedule_mode:auto max_expected_steps:0 +tie_embeddings:True embed_lr:0.03 head_lr:0.0 matrix_lr:0.02 scalar_lr:0.02 +train_batch_tokens_target:16384 effective_train_batch_tokens:16384 train_seq_len:1024 iterations:0 
warmup_steps:0 max_wallclock_seconds:0.000 +lr_scale_method:none lr_reference_batch_tokens:16384 batch_multiplier:1.0000 warmup_scale_method:none warmup_reference_batch_tokens:16384 +seed:1337 +step:0/0 val_loss:6.9379 val_bpb:4.1568 train_time:0ms step_avg:0.02ms +peak memory allocated: 41218 MiB reserved: 45380 MiB +Serialized model: 101939739 bytes +Code size: 81599 bytes +Total submission size: 102021338 bytes +Serialized model int5_odd+zlib: 7256271 bytes (payload:14157417 raw_torch:14260195 payload_ratio:11.71x) +Total submission size int5_odd+zlib: 7337870 bytes +final_int5_odd_zlib_roundtrip val_loss:6.9416 val_bpb:4.1590 eval_time:96988ms +final_int5_odd_zlib_roundtrip_exact val_loss:6.94156265 val_bpb:4.15901457 diff --git a/records/track_non_record_16mb/2026-03-31_EntropyAware_Int5Odd_BTTMLP_3060/logs/eval_bench_r256_l16_seq1024_cache_vb262144.txt b/records/track_non_record_16mb/2026-03-31_EntropyAware_Int5Odd_BTTMLP_3060/logs/eval_bench_r256_l16_seq1024_cache_vb262144.txt new file mode 100644 index 000000000..aa89971c8 --- /dev/null +++ b/records/track_non_record_16mb/2026-03-31_EntropyAware_Int5Odd_BTTMLP_3060/logs/eval_bench_r256_l16_seq1024_cache_vb262144.txt @@ -0,0 +1,1898 @@ +logs/eval_bench_r256_l16_seq1024_cache_vb262144.txt +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=/workspace/parameter-golf/data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:80 +val_loader:shards pattern=/workspace/parameter-golf/data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:1048576 +val_token_limit:1048576 +model_params:25727104 +world_size:1 grad_accum_steps:8 train_microbatch_tokens:2048 effective_train_batch_tokens:16384 +enable_compile:False +sdp_backends:cudnn=False flash=True mem_efficient=False math=False +attention_mode:gqa num_heads:8 num_kv_heads:4 +quant_scheme:int5_odd compressor:zlib level:9 structured_mlp:True structured_linear_type:btt_mlp btt_rank:256 btt_init:mup compile_structured_mlp:False 
structured_chunk_ranks:64 structured_eval_chunk_ranks:1 structured_compile_mode:reduce-overhead +rate_lambda:2e-05 scale_lambda:0.0002 rate_warmup_start_fraction:0.1 rate_full_fraction:0.4 rate_target_tensors:mlp_only rate_schedule_mode:auto max_expected_steps:0 +tie_embeddings:True embed_lr:0.03 head_lr:0.0 matrix_lr:0.02 scalar_lr:0.02 +train_batch_tokens_target:16384 effective_train_batch_tokens:16384 train_seq_len:1024 iterations:0 warmup_steps:0 max_wallclock_seconds:0.000 +lr_scale_method:none lr_reference_batch_tokens:16384 batch_multiplier:1.0000 warmup_scale_method:none warmup_reference_batch_tokens:16384 +seed:1337 +step:0/0 val_loss:6.9379 val_bpb:4.1568 train_time:0ms step_avg:0.02ms +peak memory allocated: 10506 MiB reserved: 11558 MiB +Serialized model: 101939739 bytes +Code size: 81599 bytes +Total submission size: 102021338 bytes +Serialized model int5_odd+zlib: 7256271 bytes (payload:14157417 raw_torch:14260195 payload_ratio:11.71x) +Total submission size int5_odd+zlib: 7337870 bytes +final_int5_odd_zlib_roundtrip val_loss:6.9416 val_bpb:4.1590 eval_time:97322ms +final_int5_odd_zlib_roundtrip_exact val_loss:6.94156384 val_bpb:4.15901528 +environ.get("SEED", 1337)) + + # Validation cadence and batch size. Validation reads the fineweb_val split, optionally truncated via VAL_TOKEN_LIMIT. + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 32_768)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 10)) + val_token_limit = int(os.environ.get("VAL_TOKEN_LIMIT", 1_048_576)) + + # Training length.
+ iterations = int(os.environ.get("ITERATIONS", 40)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 1200)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 1)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 16_384)) + train_microbatch_tokens = int(os.environ.get("TRAIN_MICROBATCH_TOKENS", 0)) + grad_accum_steps = int(os.environ.get("GRAD_ACCUM_STEPS", 0)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 256)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + enable_compile = bool(int(os.environ.get("ENABLE_COMPILE", "0"))) + + # Model shape. + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers = int(os.environ.get("NUM_LAYERS", 9)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = int(os.environ.get("MLP_MULT", 2)) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + + # Optimizer hyperparameters. 
+ embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.03)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.02)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.02)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.0)) + lr_scale_method = os.environ.get("LR_SCALE_METHOD", "none") + lr_reference_batch_tokens = int(os.environ.get("LR_REFERENCE_BATCH_TOKENS", 16_384)) + warmup_scale_method = os.environ.get("WARMUP_SCALE_METHOD", "none") + warmup_reference_batch_tokens = int(os.environ.get("WARMUP_REFERENCE_BATCH_TOKENS", 16_384)) + simulated_world_size = int(os.environ.get("SIMULATED_WORLD_SIZE", 0)) + simulated_rank = int(os.environ.get("SIMULATED_RANK", 0)) + dry_run_init_only = bool(int(os.environ.get("DRY_RUN_INIT_ONLY", "0"))) + debug_data_sharding_steps = int(os.environ.get("DEBUG_DATA_SHARDING_STEPS", 0)) + + # Export / compression. + quant_scheme = os.environ.get("QUANT_SCHEME", "int5_odd") + compressor = os.environ.get("COMPRESSOR", "zlib") + compressor_level = int(os.environ.get("COMPRESSOR_LEVEL", 9)) + + # Entropy-aware rate regularization. 
+ rate_lambda = float(os.environ.get("RATE_LAMBDA", 0.00002)) + scale_lambda = float(os.environ.get("SCALE_LAMBDA", 0.0002)) + rate_temperature = float(os.environ.get("RATE_TEMPERATURE", 6.0)) + rate_warmup_start_fraction = float(os.environ.get("RATE_WARMUP_START_FRACTION", 0.1)) + rate_full_fraction = float(os.environ.get("RATE_FULL_FRACTION", 0.4)) + rate_target_tensors = os.environ.get("RATE_TARGET_TENSORS", "mlp_only") + rate_schedule_mode = os.environ.get("RATE_SCHEDULE_MODE", "auto") + max_expected_steps = int(os.environ.get("MAX_EXPECTED_STEPS", 0)) + + # Structured MLP. + structured_mlp = bool(int(os.environ.get("STRUCTURED_MLP", "1"))) + structured_linear_type = os.environ.get("STRUCTURED_LINEAR_TYPE", "btt_mlp") + btt_rank = int(os.environ.get("BTT_RANK", 32)) + btt_in_shape = os.environ.get("BTT_IN_SHAPE", "") + btt_out_shape = os.environ.get("BTT_OUT_SHAPE", "") + btt_init = os.environ.get("BTT_INIT", "mup") + btt_init_gain = float(os.environ.get("BTT_INIT_GAIN", 1.0)) + compile_structured_mlp = bool(int(os.environ.get("COMPILE_STRUCTURED_MLP", "1"))) + structured_chunk_ranks = int(os.environ.get("STRUCTURED_CHUNK_RANKS", 64)) + structured_eval_chunk_ranks = int(os.environ.get("STRUCTURED_EVAL_CHUNK_RANKS", 4)) + structured_compile_mode = os.environ.get("STRUCTURED_COMPILE_MODE", "reduce-overhead") + +# ----------------------------- +# MUON OPTIMIZER +# ----------------------------- +# +# As borrowed from modded-nanogpt +# Background on Muon: https://kellerjordan.github.io/posts/muon/ + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + # Orthogonalize a 2D update matrix with a fast Newton-Schulz iteration. + # Muon uses this to normalize matrix-shaped gradients before applying them. 
+ a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X + + +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True): + super().__init__( + params, + dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov), + ) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + + for group in self.param_groups: + params = group["params"] + if not params: + continue + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + + total_params = sum(int(p.numel()) for p in params) + updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) + + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: + g = g.add(buf, alpha=momentum) + g = zeropower_via_newtonschulz5(g, steps=backend_steps) + # Scale correction from Muon reference implementations. 
+ g *= max(1, g.size(0) / g.size(1)) ** 0.5 + updates_flat[curr : curr + p.numel()] = g.reshape(-1) + curr += p.numel() + + if distributed: + dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) + + curr = 0 + for p in params: + g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) + p.add_(g, alpha=-lr) + curr += p.numel() + + return loss + + +# ----------------------------- +# TOKENIZER-AGNOSTIC EVALUATION SETUP +# ----------------------------- +# +# It's common for small models to have a large fraction of their parameters in embeddings, since the 2 * d_model * d_vocab embedding parameters can be gigantic. +# Instead of locking the tokenizer, we let you bring your own and calculate our validation metrics on the average compression of the validation set. +# We calculate BPB (bits-per-byte) instead of validation loss, so we need methods to count the number of bytes each token contributes under the tokenizer. +# Note: Submissions that edit the tokenizer will be examined more carefully, since screwing this up might unjustly improve your score.
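For reference, the BPB conversion described above reduces to one line of arithmetic: cross-entropy in nats becomes bits per token, then rescales by tokens per byte. A minimal standalone sketch (the function name `bits_per_byte` is illustrative, not from this repo):

```python
import math

def bits_per_byte(ce_loss_nats: float, token_count: int, byte_count: int) -> float:
    # Mean cross-entropy (natural log) -> bits/token, then rescale by tokens/byte.
    bits_per_token = ce_loss_nats / math.log(2.0)
    return bits_per_token * (token_count / byte_count)

# A model at ce_loss = ln(2), i.e. exactly 1 bit/token, on text averaging
# 4 bytes per token compresses to 0.25 bits per byte.
assert abs(bits_per_byte(math.log(2.0), 1000, 4000) - 0.25) < 1e-9
```

This is why BPB is tokenizer-agnostic: a coarser tokenizer raises bits per token but lowers tokens per byte, and the product stays comparable across vocabularies.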
+ +def build_sentencepiece_luts( + sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device +) -> tuple[Tensor, Tensor, Tensor]: + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("▁"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) + + +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*. 
+ tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] + + +def eval_val( + args: Hyperparameters, + model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + grad_accum_steps: int, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, +) -> tuple[float, float]: + # Validation computes two metrics: + # - val_loss: token cross-entropy (natural log) + # - val_bpb: tokenizer-agnostic compression metric used by the challenge + local_batch_tokens = args.val_batch_size // world_size + if local_batch_tokens < args.train_seq_len: + raise ValueError( + "VAL_BATCH_SIZE must provide at least one sequence per rank; " + f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " + f"TRAIN_SEQ_LEN={args.train_seq_len}" + ) + local_batch_seqs = local_batch_tokens // args.train_seq_len + total_seqs = (val_tokens.numel() - 1) // args.train_seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + + model.eval() + with torch.inference_mode(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * args.train_seq_len + raw_end = batch_seq_end * args.train_seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, args.train_seq_len) + y = local[1:].reshape(-1, args.train_seq_len) + maybe_mark_compile_step_begin(args.enable_compile or 
args.compile_structured_mlp)
+            with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True):
+                batch_loss = model(x, y).detach()
+            batch_token_count = float(y.numel())
+            val_loss_sum += batch_loss.to(torch.float64) * batch_token_count
+            val_token_count += batch_token_count
+            prev_ids = x.reshape(-1)
+            tgt_ids = y.reshape(-1)
+            token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16)
+            token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16)
+            val_byte_count += token_bytes.to(torch.float64).sum()
+
+    if dist.is_available() and dist.is_initialized():
+        dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM)
+        dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM)
+        dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM)
+
+    val_loss = val_loss_sum / val_token_count
+    bits_per_token = val_loss.item() / math.log(2.0)
+    tokens_per_byte = val_token_count.item() / val_byte_count.item()
+    model.train()
+    return float(val_loss.item()), float(bits_per_token * tokens_per_byte)
+
+# -----------------------------
+# POST-TRAINING QUANTIZATION
+# -----------------------------
+#
+# It's silly to export our model, which is trained in bf16 and fp32, at that same precision.
+# Instead, we get approximately the same model (with a small quality hit) by quantizing the
+# weights (int8 or int5_odd here) and zlib-compressing the result.
+# We can then decompress the artifact and run in higher precision for evaluation, once the
+# export comes in under the size limit.
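+#
+# Back-of-envelope rate (illustrative): int5_odd packs 3 base-5 codes per byte, i.e.
+# 8/3 ~ 2.67 bits per weight before compression, versus the log2(5) ~ 2.32-bit ceiling
+# of a uniform 5-level code. zlib claws back part of that gap, and more when the
+# entropy-aware regularizer skews the code histogram toward 0.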
+ +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights", + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 +INT5_ODD_LEVELS = (-2, -1, 0, 1, 2) + +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) + + +def maybe_mark_compile_step_begin(enabled: bool) -> None: + if not enabled: + return + mark_step = getattr(torch.compiler, "cudagraph_mark_step_begin", None) + if mark_step is not None: + mark_step() + + +def scale_from_batch_multiplier(value: float, multiplier: float, method: str) -> float: + if method == "none": + return value + if method == "sqrt": + return value * math.sqrt(multiplier) + if method == "linear": + return value * multiplier + raise ValueError(f"Unsupported scaling method: {method}") + + +def scaled_warmup_steps(base_steps: int, multiplier: float, method: str) -> int: + if base_steps <= 0: + return 0 + scaled = scale_from_batch_multiplier(float(base_steps), multiplier, method) + return max(int(math.ceil(scaled)), 1) + + +def parse_factor_shape(spec: str, dim: int) -> tuple[int, int]: + if spec: + vals = tuple(int(v.strip()) for v in spec.split(",") if v.strip()) + if len(vals) != 2 or vals[0] * vals[1] != dim: + raise ValueError(f"Invalid factor shape '{spec}' for dim={dim}") + return vals + for a in range(int(math.isqrt(dim)), 0, -1): + if dim % a == 0: + return a, dim // a + return 1, dim + +def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> 
Tensor: + if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): + return t.float().contiguous() + if t.dtype in {torch.float32, torch.bfloat16}: + passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") + return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() + return t + +def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + # Matrices get one scale per row, which usually tracks output-channel + # ranges much better than a single tensor-wide scale. + clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) + q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() + return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + + # Vectors / scalars use a simpler per-tensor scale. 
+ clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() + return q, scale + + +def pack_base5_codes(q: Tensor) -> tuple[bytes, int]: + vals = q.detach().to("cpu", dtype=torch.uint8).reshape(-1).numpy() + n_vals = int(vals.size) + pad = (3 - (n_vals % 3)) % 3 + if pad: + vals = np.pad(vals, (0, pad)) + groups = vals.reshape(-1, 3).astype(np.uint8) + packed = groups[:, 0] + 5 * groups[:, 1] + 25 * groups[:, 2] + return packed.tobytes(), n_vals + + +def unpack_base5_codes(data: bytes, n_vals: int) -> Tensor: + packed = np.frombuffer(data, dtype=np.uint8).astype(np.int16) + out = np.zeros((packed.size, 3), dtype=np.uint8) + for i in range(3): + out[:, i] = packed % 5 + packed //= 5 + flat = out.reshape(-1)[:n_vals] + return torch.from_numpy(flat.astype(np.uint8, copy=False)) + + +def quantize_float_tensor_int5_odd(t: Tensor) -> tuple[Tensor, Tensor]: + if t.ndim != 2: + raise ValueError("int5_odd quantization only supports 2D tensors") + t32 = t.float() + clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clip_abs = clip_abs.clamp_min(1e-6) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 2.0).clamp_min(1.0 / 2.0e6) + q = torch.clamp(torch.round(clipped / scale[:, None]), -2, 2).to(torch.int8).contiguous() + return (q + 2).to(torch.uint8).contiguous(), scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + +def quantize_state_dict_int8(state_dict: dict[str, Tensor]): + # Single supported clean-script export format: + # - per-row int8 for 2D float tensors + # - per-tensor int8 for other float tensors + # - exact passthrough for non-floats + # - passthrough for small float tensors, 
stored as fp16 to save bytes + quantized: dict[str, Tensor] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + qmeta: dict[str, dict[str, object]] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + + # Small float tensors are cheap enough to keep directly. We still downcast + # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size. + if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + + stats["num_float_tensors"] += 1 + q, s = quantize_float_tensor(t) + if s.ndim > 0: + qmeta[name] = {"scheme": "per_row", "axis": 0} + quantized[name] = q + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) + + obj: dict[str, object] = { + "__quant_format__": "int8_clean_per_row_v1", + "quantized": quantized, + "scales": scales, + "dtypes": dtypes, + "passthrough": passthrough, + } + if qmeta: + obj["qmeta"] = qmeta + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats + + +def quantize_state_dict_int5_odd(state_dict: dict[str, Tensor]): + packed: dict[str, bytes] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + shapes: dict[str, tuple[int, ...]] = {} + orig_shapes: dict[str, 
tuple[int, ...]] = {} + n_codes: dict[str, int] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + + quantize_this = t.ndim >= 2 and (t.numel() > INT8_KEEP_FLOAT_MAX_NUMEL or ".mlp." in name) + if not quantize_this: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + + stats["num_float_tensors"] += 1 + t2 = t.reshape(t.shape[0], -1) if t.ndim > 2 else t + q, s = quantize_float_tensor_int5_odd(t2) + packed_bytes, count = pack_base5_codes(q) + packed[name] = packed_bytes + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + shapes[name] = tuple(int(v) for v in t2.shape) + orig_shapes[name] = tuple(int(v) for v in t.shape) + n_codes[name] = count + stats["int8_payload_bytes"] += len(packed_bytes) + tensor_nbytes(s) + + obj: dict[str, object] = { + "__quant_format__": "int5_odd_clean_per_row_v1", + "packed": packed, + "scales": scales, + "dtypes": dtypes, + "shapes": shapes, + "orig_shapes": orig_shapes, + "n_codes": n_codes, + "passthrough": passthrough, + } + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats + +def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + qmeta = obj.get("qmeta", {}) + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, q in 
obj["quantized"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + s = obj["scales"][name] + if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: + s = s.to(dtype=torch.float32) + # Broadcast the saved row scale back across trailing dimensions. + out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() + else: + scale = float(s.item()) + out[name] = (q.float() * scale).to(dtype=dtype).contiguous() + for name, t in obj["passthrough"].items(): + # Restore small tensors, undoing the temporary fp16 storage cast if needed. + out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out + + +def dequantize_state_dict_int5_odd(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, packed in obj["packed"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + shape = tuple(int(v) for v in obj["shapes"][name]) + orig_shape = tuple(int(v) for v in obj.get("orig_shapes", {}).get(name, shape)) + q = unpack_base5_codes(packed, int(obj["n_codes"][name])).to(torch.float32).reshape(shape) - 2.0 + s = obj["scales"][name].to(dtype=torch.float32) + deq = (q * s.view(shape[0], *([1] * (len(shape) - 1)))).to(dtype=dtype) + out[name] = deq.reshape(orig_shape).contiguous() + for name, t in obj["passthrough"].items(): + out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out + + +def quantize_state_dict(state_dict: dict[str, Tensor], quant_scheme: str): + if quant_scheme == "int8": + return quantize_state_dict_int8(state_dict) + if quant_scheme == "int5_odd": + return 
quantize_state_dict_int5_odd(state_dict) + raise ValueError(f"Unsupported QUANT_SCHEME={quant_scheme}") + + +def dequantize_state_dict(obj: dict[str, object]) -> dict[str, Tensor]: + quant_format = obj.get("__quant_format__") + if quant_format == "int8_clean_per_row_v1": + return dequantize_state_dict_int8(obj) + if quant_format == "int5_odd_clean_per_row_v1": + return dequantize_state_dict_int5_odd(obj) + raise ValueError(f"Unsupported quant format {quant_format}") + + +def compress_blob(data: bytes, compressor: str, level: int) -> bytes: + if compressor == "zlib": + return zlib.compress(data, level=max(0, min(level, 9))) + if compressor == "zstd": + if not _HAS_ZSTD: + raise RuntimeError("COMPRESSOR=zstd requires zstandard to be installed") + return zstd.ZstdCompressor(level=level).compress(data) + raise ValueError(f"Unsupported COMPRESSOR={compressor}") + + +def decompress_blob(data: bytes, compressor: str) -> bytes: + if compressor == "zlib": + return zlib.decompress(data) + if compressor == "zstd": + if not _HAS_ZSTD: + raise RuntimeError("COMPRESSOR=zstd requires zstandard to be installed") + return zstd.ZstdDecompressor().decompress(data) + raise ValueError(f"Unsupported COMPRESSOR={compressor}") + + +def should_rate_regularize_tensor(name: str, p: Tensor, target: str) -> bool: + if not p.requires_grad or not p.is_floating_point() or p.ndim < 2: + return False + if any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS): + return False + if target == "mlp_only": + return ".mlp." in name + if target == "all_matrices": + return p.numel() > INT8_KEEP_FLOAT_MAX_NUMEL or ".mlp." 
in name
+    raise ValueError(f"Unsupported RATE_TARGET_TENSORS={target}")
+
+
+def estimate_tensor_rate_proxy(t: Tensor, temperature: float) -> tuple[Tensor, Tensor, Tensor]:
+    # Differentiable proxy for the post-quantization bit rate: soft-assign each
+    # normalized weight to the 5 odd levels and measure the histogram entropy.
+    t2 = t.float()
+    if t2.ndim > 2:
+        t2 = t2.reshape(t2.shape[0], -1)
+    elif t2.ndim == 1:
+        t2 = t2.reshape(1, -1)
+    row_scale = t2.abs().mean(dim=1, keepdim=True).clamp_min(1e-6)
+    normalized = torch.clamp(t2 / row_scale, -2.5, 2.5)
+    centers = t2.new_tensor(INT5_ODD_LEVELS)
+    logits = -temperature * (normalized.unsqueeze(-1) - centers) ** 2
+    probs = torch.softmax(logits, dim=-1)
+    hist = probs.mean(dim=(0, 1))
+    entropy_nats = -(hist * (hist + 1e-8).log()).sum()
+    entropy_bits = entropy_nats / math.log(2.0)
+    scale_proxy = row_scale.mean()
+    est_bits = entropy_bits * float(t2.numel())
+    return entropy_bits, scale_proxy, est_bits
+
+
+# -----------------------------
+# DATA LOADING
+# -----------------------------
+
+def load_data_shard(file: Path) -> Tensor:
+    # Shard layout: a 256-entry little-endian int32 header (token count at index 2)
+    # followed by the tokens themselves as little-endian uint16.
+    header_bytes = 256 * np.dtype("<i4").itemsize
+    header = np.fromfile(file, dtype="<i4", count=256)
+    num_tokens = int(header[2])
+    tokens = np.fromfile(file, dtype="<u2", count=num_tokens, offset=header_bytes)
+    return torch.from_numpy(tokens.astype(np.int32, copy=False))
+
+
+class TokenStream:
+    # Sequential reader over the shard files matched by the glob pattern,
+    # wrapping around to the first shard when the last one is exhausted.
+    def __init__(self, pattern: str):
+        self.files = [Path(p) for p in sorted(glob.glob(pattern))]
+        if not self.files:
+            raise FileNotFoundError(f"No files found for pattern: {pattern}")
+        self.file_idx = 0
+        self.tokens = load_data_shard(self.files[self.file_idx])
+        self.pos = 0
+
+    def _advance_file(self) -> None:
+        self.file_idx = (self.file_idx + 1) % len(self.files)
+        self.tokens = load_data_shard(self.files[self.file_idx])
+        self.pos = 0
+
+    def take(self, n: int) -> Tensor:
+        chunks: list[Tensor] = []
+        remaining = n
+        while remaining > 0:
+            avail = self.tokens.numel() - self.pos
+            if avail <= 0:
+                self._advance_file()
+                continue
+            k = min(remaining, avail)
+            chunks.append(self.tokens[self.pos : self.pos + k])
+            self.pos += k
+            remaining -= k
+        return chunks[0] if len(chunks) == 1 else torch.cat(chunks)
+
+
+class DistributedTokenLoader:
+    # Each call consumes a contiguous chunk from the shared token stream, then slices out
+    # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting.
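+    # Worked example (hypothetical config): global_tokens=131072, world_size=8,
+    # grad_accum_steps=2 -> local_tokens = 131072 // 16 = 8192 per rank per microstep,
+    # per_rank_span = 8193, and each call consumes 8 * 8193 = 65544 shared tokens.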
+ def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank = rank + self.world_size = world_size + self.device = device + self.stream = TokenStream(pattern) + + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + chunk = self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start : start + per_rank_span].to(dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + + +def debug_log_data_sharding( + pattern: str, + world_size: int, + global_tokens: int, + seq_len: int, + grad_accum_steps: int, + steps: int, + log0, +) -> None: + stream = TokenStream(pattern) + local_tokens = global_tokens // (world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + global_cursor = 0 + log0( + f"debug_sharding:world_size:{world_size} global_tokens:{global_tokens} " + f"grad_accum_steps:{grad_accum_steps} local_tokens_per_rank:{local_tokens} per_rank_span:{per_rank_span}" + ) + for step_idx in range(steps): + chunk = stream.take(per_rank_span * world_size) + chunk_start = global_cursor + chunk_end = global_cursor + chunk.numel() + log0(f"debug_sharding_step:{step_idx} shared_chunk_token_range:[{chunk_start},{chunk_end})") + for shard_rank in range(world_size): + start = shard_rank * per_rank_span + end = start + per_rank_span + local = chunk[start:end] + preview = ",".join(str(int(t)) for t in local[: min(6, local.numel())].tolist()) + log0( + f"debug_shard rank:{shard_rank} token_range:[{chunk_start + start},{chunk_start + end}) " + f"preview:{preview}" + ) + global_cursor = chunk_end + +# ----------------------------- +# TRANSFORMER MODULES +# ----------------------------- + +class RMSNorm(nn.Module): + def 
__init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + + +class CastedLinear(nn.Linear): + # Keep weights in fp32 for optimizer/state quality, cast at matmul time for bf16 compute. + def __init__(self, in_features: int, out_features: int, bias: bool = True): + super().__init__(in_features, out_features, bias=bias) + self._cached_weight: Tensor | None = None + self._cached_bias: Tensor | None = None + self._cached_key: tuple[torch.dtype, torch.device, int, int | None] | None = None + + def _cast_params(self, x: Tensor) -> tuple[Tensor, Tensor | None]: + bias_version = self.bias._version if self.bias is not None else None + cache_key = (x.dtype, x.device, self.weight._version, bias_version) + if torch.is_inference_mode_enabled() and self._cached_key == cache_key: + return self._cached_weight, self._cached_bias + weight = self.weight.to(dtype=x.dtype) + bias = self.bias.to(x.dtype) if self.bias is not None else None + if torch.is_inference_mode_enabled(): + self._cached_weight = weight + self._cached_bias = bias + self._cached_key = cache_key + return weight, bias + + def forward(self, x: Tensor) -> Tensor: + weight, bias = self._cast_params(x) + return F.linear(x, weight, bias) + + +class StructuredLinear(nn.Module): + # Lightweight TT-style structured matrix used only in the MLP stack. 
+ def __init__( + self, + in_features: int, + out_features: int, + bias: bool, + linear_type: str, + btt_rank: int, + in_shape_spec: str, + out_shape_spec: str, + btt_init: str, + btt_init_gain: float, + chunk_ranks: int, + eval_chunk_ranks: int, + ): + super().__init__() + self.in_features = in_features + self.out_features = out_features + self.linear_type = linear_type + self.rank = max(int(btt_rank), 1) + self.chunk_ranks = max(int(chunk_ranks), 1) + self.eval_chunk_ranks = max(int(eval_chunk_ranks), 1) + self.btt_init = btt_init + self.btt_init_gain = btt_init_gain + self._cached_core1: Tensor | None = None + self._cached_core2: Tensor | None = None + self._cached_bias: Tensor | None = None + self._cached_key: tuple[torch.dtype, torch.device, int, int, int | None] | None = None + if linear_type == "dense": + self.impl = CastedLinear(in_features, out_features, bias=bias) + self.register_parameter("core1", None) + self.register_parameter("core2", None) + return + if linear_type != "btt_mlp": + raise ValueError(f"Unsupported STRUCTURED_LINEAR_TYPE={linear_type}") + self.impl = None + self.in_shape = parse_factor_shape(in_shape_spec, in_features) + self.out_shape = parse_factor_shape(out_shape_spec, out_features) + n1, n2 = self.in_shape + m1, m2 = self.out_shape + self.core1 = nn.Parameter(torch.empty(n1, m1, self.rank)) + self.core2 = nn.Parameter(torch.empty(self.rank, n2, m2)) + if bias: + self.bias = nn.Parameter(torch.zeros(out_features)) + else: + self.register_parameter("bias", None) + self.reset_parameters() + + def reset_parameters(self) -> None: + if self.impl is not None: + self.impl.reset_parameters() + return + n1, n2 = self.in_shape + if self.btt_init == "mup": + # muP-inspired scaling: keep the per-rank branch update stable while + # preserving O(1) output variance under the sum over TT ranks. 
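+            # Sketch of the variance bookkeeping (assuming i.i.d. inputs): with
+            # core1_std = g / sqrt(n1) and core2_std = g / sqrt(rank * n2),
+            #   Var(left) = n1 * core1_std^2 * Var(x)           = g^2 * Var(x)
+            #   Var(y)    = rank * n2 * core2_std^2 * Var(left) = g^4 * Var(x)
+            # so the summed-over-ranks output variance stays O(1) in n1, n2, and rank.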
+ core1_std = self.btt_init_gain / math.sqrt(max(n1, 1)) + core2_std = self.btt_init_gain / math.sqrt(max(self.rank * n2, 1)) + elif self.btt_init == "xavier": + std = self.btt_init_gain / math.sqrt(max(self.in_features, 1)) + core1_std = std / math.sqrt(self.rank) + core2_std = std + else: + raise ValueError(f"Unsupported BTT_INIT={self.btt_init}") + nn.init.normal_(self.core1, mean=0.0, std=core1_std) + nn.init.normal_(self.core2, mean=0.0, std=core2_std) + if self.bias is not None: + nn.init.zeros_(self.bias) + + def _cast_factors(self, x: Tensor) -> tuple[Tensor, Tensor, Tensor | None]: + bias_version = self.bias._version if self.bias is not None else None + cache_key = (x.dtype, x.device, self.core1._version, self.core2._version, bias_version) + if torch.is_inference_mode_enabled() and self._cached_key == cache_key: + return self._cached_core1, self._cached_core2, self._cached_bias + core1 = self.core1.to(dtype=x.dtype).permute(2, 0, 1).contiguous() + core2 = self.core2.to(dtype=x.dtype).contiguous() + bias = self.bias.to(dtype=x.dtype) if self.bias is not None else None + if torch.is_inference_mode_enabled(): + self._cached_core1 = core1 + self._cached_core2 = core2 + self._cached_bias = bias + self._cached_key = cache_key + return core1, core2, bias + + def forward(self, x: Tensor) -> Tensor: + if self.impl is not None: + return self.impl(x) + orig_shape = x.shape[:-1] + x2 = x.reshape(-1, self.in_shape[0], self.in_shape[1]) + x2_t = x2.transpose(1, 2).contiguous() + core1, core2, bias = self._cast_factors(x) + y = x2.new_zeros((x2.shape[0], self.out_shape[0], self.out_shape[1])) + if torch.is_inference_mode_enabled() and self.eval_chunk_ranks > 1: + for start in range(0, self.rank, self.eval_chunk_ranks): + end = min(start + self.eval_chunk_ranks, self.rank) + left = torch.matmul(x2_t.unsqueeze(1), core1[start:end].unsqueeze(0)) + y = y + torch.matmul(left.transpose(-1, -2), core2[start:end].unsqueeze(0)).sum(dim=1) + y = y.reshape(*orig_shape, 
self.out_features) + if bias is not None: + y = y + bias + return y + for start in range(0, self.rank, self.chunk_ranks): + end = min(start + self.chunk_ranks, self.rank) + for offset, core2_r in enumerate(core2[start:end], start=start): + left = torch.matmul(x2_t, core1[offset]) + y = y + torch.matmul(left.transpose(1, 2), core2_r) + y = y.reshape(*orig_shape, self.out_features) + if bias is not None: + y = y + bias + return y + + +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + # Keep small/control parameters in fp32 even when the model body runs in bf16. + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() + + +class Rotary(nn.Module): + # Caches cos/sin tables per sequence length on the current device. + def __init__(self, dim: int, base: float = 10000.0): + super().__init__() + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached != seq_len + or self._cos_cached.device != device + ): + t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) + freqs = torch.outer(t, self.inv_freq.to(device)) + self._cos_cached = freqs.cos()[None, None, :, :] + self._sin_cached = freqs.sin()[None, None, :, :] + self._seq_len_cached = seq_len + cos = self._cos_cached.to(dtype=dtype) + sin = self._sin_cached.to(dtype=dtype) + if not torch.is_inference_mode_enabled(): + cos = cos.clone() + sin = sin.clone() + return cos, sin + + +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> 
Tensor: + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + + +class CausalSelfAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + kv_dim = self.num_kv_heads * self.head_dim + self.c_q = CastedLinear(dim, dim, bias=False) + self.c_k = CastedLinear(dim, kv_dim, bias=False) + self.c_v = CastedLinear(dim, kv_dim, bias=False) + self.proj = CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rotary = Rotary(self.head_dim, base=rope_base) + + def forward(self, x: Tensor) -> Tensor: + bsz, seqlen, dim = x.shape + q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) + k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin) + k = apply_rotary_emb(k, cos, sin) + q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] + y = F.scaled_dot_product_attention( + q, + k, + v, + attn_mask=None, + is_causal=True, + enable_gqa=(self.num_kv_heads != self.num_heads), + ) + y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) + return self.proj(y) + + +class MLP(nn.Module): + # relu^2 MLP from the 
original modded-nanogpt setup + def __init__( + self, + dim: int, + mlp_mult: int, + structured_mlp: bool, + structured_linear_type: str, + btt_rank: int, + btt_in_shape: str, + btt_out_shape: str, + btt_init: str, + btt_init_gain: float, + compile_structured_mlp: bool, + structured_chunk_ranks: int, + structured_eval_chunk_ranks: int, + structured_compile_mode: str, + ): + super().__init__() + hidden = mlp_mult * dim + if structured_mlp: + self.fc = StructuredLinear( + dim, + hidden, + bias=False, + linear_type=structured_linear_type, + btt_rank=btt_rank, + in_shape_spec=btt_in_shape, + out_shape_spec=btt_out_shape, + btt_init=btt_init, + btt_init_gain=btt_init_gain, + chunk_ranks=structured_chunk_ranks, + eval_chunk_ranks=structured_eval_chunk_ranks, + ) + self.proj = StructuredLinear( + hidden, + dim, + bias=False, + linear_type=structured_linear_type, + btt_rank=btt_rank, + in_shape_spec=btt_out_shape, + out_shape_spec=btt_in_shape, + btt_init=btt_init, + btt_init_gain=btt_init_gain, + chunk_ranks=structured_chunk_ranks, + eval_chunk_ranks=structured_eval_chunk_ranks, + ) + if compile_structured_mlp: + compile_kwargs = dict(fullgraph=False, dynamic=False) + if structured_compile_mode == "reduce-overhead": + compile_options = dict(torch._inductor.list_mode_options("reduce-overhead")) + compile_options["triton.cudagraphs"] = False + compile_kwargs["options"] = compile_options + else: + compile_kwargs["mode"] = structured_compile_mode + self.fc = torch.compile( + self.fc, + **compile_kwargs, + ) + self.proj = torch.compile( + self.proj, + **compile_kwargs, + ) + else: + self.fc = CastedLinear(dim, hidden, bias=False) + self.proj = CastedLinear(hidden, dim, bias=False) + self.proj._zero_init = True + + def forward(self, x: Tensor) -> Tensor: + x = torch.relu(self.fc(x)) + return self.proj(x.square()) + + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + rope_base: float, + qk_gain_init: float, + 
structured_mlp: bool, + structured_linear_type: str, + btt_rank: int, + btt_in_shape: str, + btt_out_shape: str, + btt_init: str, + btt_init_gain: float, + compile_structured_mlp: bool, + structured_chunk_ranks: int, + structured_eval_chunk_ranks: int, + structured_compile_mode: str, + ): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) + self.mlp = MLP( + dim, + mlp_mult, + structured_mlp, + structured_linear_type, + btt_rank, + btt_in_shape, + btt_out_shape, + btt_init, + btt_init_gain, + compile_structured_mlp, + structured_chunk_ranks, + structured_eval_chunk_ranks, + structured_compile_mode, + ) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + + def forward(self, x: Tensor, x0: Tensor) -> Tensor: + mix = self.resid_mix.to(dtype=x.dtype) + x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + attn_out = self.attn(self.attn_norm(x)) + x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out + x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) + return x + + +class GPT(nn.Module): + def __init__( + self, + vocab_size: int, + num_layers: int, + model_dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + tie_embeddings: bool, + tied_embed_init_std: float, + logit_softcap: float, + rope_base: float, + qk_gain_init: float, + structured_mlp: bool, + structured_linear_type: str, + btt_rank: int, + btt_in_shape: str, + btt_out_shape: str, + btt_init: str, + btt_init_gain: float, + compile_structured_mlp: bool, + structured_chunk_ranks: int, + structured_eval_chunk_ranks: int, + structured_compile_mode: str, + rate_target_tensors: str, + rate_temperature: float, + ): + super().__init__() + if logit_softcap <= 
0.0: + raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") + self.tie_embeddings = tie_embeddings + self.tied_embed_init_std = tied_embed_init_std + self.logit_softcap = logit_softcap + self.rate_target_tensors = rate_target_tensors + self.rate_temperature = rate_temperature + self.tok_emb = nn.Embedding(vocab_size, model_dim) + self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + self.blocks = nn.ModuleList( + [ + Block( + model_dim, + num_heads, + num_kv_heads, + mlp_mult, + rope_base, + qk_gain_init, + structured_mlp, + structured_linear_type, + btt_rank, + btt_in_shape, + btt_out_shape, + btt_init, + btt_init_gain, + compile_structured_mlp, + structured_chunk_ranks, + structured_eval_chunk_ranks, + structured_compile_mode, + ) + for i in range(num_layers) + ] + ) + self.final_norm = RMSNorm() + self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) + if self.lm_head is not None: + self.lm_head._zero_init = True + self._init_weights() + + def _init_weights(self) -> None: + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + for module in self.modules(): + if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + if isinstance(module, StructuredLinear) and getattr(module, "_zero_init", False): + if module.impl is not None: + nn.init.zeros_(module.impl.weight) + else: + nn.init.zeros_(module.core2) + + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + x = self.tok_emb(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x0 = x + skips: list[Tensor] = [] + + # First half stores skips; second half reuses them in reverse order. 
+ for i in range(self.num_encoder_layers): + x = self.blocks[i](x, x0) + skips.append(x) + for i in range(self.num_decoder_layers): + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + x = self.blocks[self.num_encoder_layers + i](x, x0) + + x = self.final_norm(x).reshape(-1, x.size(-1)) + targets = target_ids.reshape(-1) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(x) + logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + return F.cross_entropy(logits.float(), targets, reduction="mean") + + def rate_regularizer(self) -> tuple[Tensor, Tensor, Tensor]: + total_symbols = 0.0 + weighted_rate = None + weighted_scale = None + estimated_bits = None + for name, param in self.named_parameters(): + if not should_rate_regularize_tensor(name, param, self.rate_target_tensors): + continue + rate_proxy, scale_proxy, tensor_bits = estimate_tensor_rate_proxy(param, self.rate_temperature) + weight = float(param.numel()) + total_symbols += weight + if weighted_rate is None: + weighted_rate = rate_proxy * weight + weighted_scale = scale_proxy * weight + estimated_bits = tensor_bits + else: + weighted_rate = weighted_rate + rate_proxy * weight + weighted_scale = weighted_scale + scale_proxy * weight + estimated_bits = estimated_bits + tensor_bits + if weighted_rate is None or weighted_scale is None or estimated_bits is None: + zero = self.tok_emb.weight.new_zeros(()) + return zero, zero, zero + denom = max(total_symbols, 1.0) + return weighted_rate / denom, weighted_scale / denom, estimated_bits + + +# ----------------------------- +# TRAINING +# ----------------------------- + +def main() -> None: + global zeropower_via_newtonschulz5 + + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + if args.enable_compile: + 
zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) + + # ----------------------------- + # DISTRIBUTED + CUDA SETUP + # ----------------------------- + + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + simulated_world_size = args.simulated_world_size if not distributed and args.simulated_world_size > 0 else 0 + if simulated_world_size > 0: + world_size = simulated_world_size + rank = args.simulated_rank + local_rank = 0 + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if rank < 0 or rank >= world_size: + raise ValueError(f"RANK={rank} must satisfy 0 <= RANK < WORLD_SIZE={world_size}") + if simulated_world_size > 1 and not args.dry_run_init_only: + raise ValueError("SIMULATED_WORLD_SIZE>1 is only supported together with DRY_RUN_INIT_ONLY=1") + if args.train_batch_tokens <= 0: + raise ValueError(f"TRAIN_BATCH_TOKENS must be positive, got {args.train_batch_tokens}") + train_microbatch_tokens = args.train_microbatch_tokens if args.train_microbatch_tokens > 0 else max( + args.train_seq_len, args.train_batch_tokens // 8 + ) + if train_microbatch_tokens <= 0: + raise ValueError(f"TRAIN_MICROBATCH_TOKENS must be positive, got {train_microbatch_tokens}") + if train_microbatch_tokens % args.train_seq_len != 0: + raise ValueError( + f"TRAIN_MICROBATCH_TOKENS={train_microbatch_tokens} must be divisible by TRAIN_SEQ_LEN={args.train_seq_len}" + ) + tokens_per_microstep = world_size * train_microbatch_tokens + if args.grad_accum_steps > 0: + grad_accum_steps = args.grad_accum_steps + effective_train_batch_tokens = tokens_per_microstep * grad_accum_steps + else: + if args.train_batch_tokens % tokens_per_microstep != 0: + raise ValueError( + "TRAIN_BATCH_TOKENS must be divisible by WORLD_SIZE * TRAIN_MICROBATCH_TOKENS " + f"unless GRAD_ACCUM_STEPS is 
set explicitly; got TRAIN_BATCH_TOKENS={args.train_batch_tokens}, " + f"WORLD_SIZE={world_size}, TRAIN_MICROBATCH_TOKENS={train_microbatch_tokens}" + ) + grad_accum_steps = args.train_batch_tokens // tokens_per_microstep + effective_train_batch_tokens = args.train_batch_tokens + if grad_accum_steps <= 0: + raise ValueError(f"GRAD_ACCUM_STEPS must be positive, got {grad_accum_steps}") + grad_scale = 1.0 / grad_accum_steps + batch_multiplier = effective_train_batch_tokens / max(args.lr_reference_batch_tokens, 1) + effective_warmup_steps = scaled_warmup_steps( + args.warmup_steps, + effective_train_batch_tokens / max(args.warmup_reference_batch_tokens, 1), + args.warmup_scale_method, + ) + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 0 + + # Fast math knobs + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + + logfile = None + if master_process: + if args.run_id == "train": + logfile = "train.log" + else: + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" + print(logfile) + + def log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch {torch.__version__}", console=False) + log0( + subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, 
check=False).stdout, + console=False, + ) + log0("=" * 100, console=False) + + # ----------------------------- + # TOKENIZER + VALIDATION METRIC SETUP + # ----------------------------- + + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) + if args.val_token_limit > 0: + usable = min(((args.val_token_limit) // args.train_seq_len) * args.train_seq_len, val_tokens.numel() - 1) + if usable <= 0: + raise ValueError(f"VAL_TOKEN_LIMIT={args.val_token_limit} is too small for TRAIN_SEQ_LEN={args.train_seq_len}") + val_tokens = val_tokens[: usable + 1].contiguous() + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + if args.val_token_limit > 0: + log0(f"val_token_limit:{args.val_token_limit}") + + # ----------------------------- + # MODEL + OPTIMIZER SETUP + # ----------------------------- + + base_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + 
tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, + structured_mlp=args.structured_mlp, + structured_linear_type=args.structured_linear_type if args.structured_mlp else "dense", + btt_rank=args.btt_rank, + btt_in_shape=args.btt_in_shape, + btt_out_shape=args.btt_out_shape, + btt_init=args.btt_init, + btt_init_gain=args.btt_init_gain, + compile_structured_mlp=args.compile_structured_mlp, + structured_chunk_ranks=args.structured_chunk_ranks, + structured_eval_chunk_ranks=args.structured_eval_chunk_ranks, + structured_compile_mode=args.structured_compile_mode, + rate_target_tensors=args.rate_target_tensors, + rate_temperature=args.rate_temperature, + ).to(device).bfloat16() + for module in base_model.modules(): + if isinstance(module, (CastedLinear, StructuredLinear)): + module.float() + restore_low_dim_params_to_fp32(base_model) + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) if args.enable_compile else base_model + model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + + # Optimizer split: + # - token embedding (Adam) uses EMBED_LR + # - untied lm_head (Adam) uses HEAD_LR + # - matrix params in transformer blocks use MATRIX_LR via Muon + # - vectors/scalars use SCALAR_LR via Adam + block_named_params = list(base_model.blocks.named_parameters()) + matrix_params = [ + p + for name, p in block_named_params + if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + scalar_params = [ + p + for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + scaled_token_lr = 
scale_from_batch_multiplier(token_lr, batch_multiplier, args.lr_scale_method) + scaled_head_lr = scale_from_batch_multiplier(args.head_lr, batch_multiplier, args.lr_scale_method) + scaled_matrix_lr = scale_from_batch_multiplier(args.matrix_lr, batch_multiplier, args.lr_scale_method) + scaled_scalar_lr = scale_from_batch_multiplier(args.scalar_lr, batch_multiplier, args.lr_scale_method) + optimizer_tok = torch.optim.Adam( + [{"params": [base_model.tok_emb.weight], "lr": scaled_token_lr, "base_lr": scaled_token_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizer_muon = Muon( + matrix_params, + lr=scaled_matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + ) + for group in optimizer_muon.param_groups: + group["base_lr"] = scaled_matrix_lr + optimizer_scalar = torch.optim.Adam( + [{"params": scalar_params, "lr": scaled_scalar_lr, "base_lr": scaled_scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": scaled_head_lr, "base_lr": scaled_head_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers.insert(1, optimizer_head) + + n_params = sum(p.numel() for p in base_model.parameters()) + log0(f"model_params:{n_params}") + log0( + f"world_size:{world_size} grad_accum_steps:{grad_accum_steps} " + f"train_microbatch_tokens:{train_microbatch_tokens} effective_train_batch_tokens:{effective_train_batch_tokens}" + ) + if simulated_world_size > 0: + log0(f"simulated_world_size:{simulated_world_size} simulated_rank:{rank}") + log0(f"enable_compile:{args.enable_compile}") + log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") + log0(f"attention_mode:gqa num_heads:{args.num_heads} 
num_kv_heads:{args.num_kv_heads}") + log0( + f"quant_scheme:{args.quant_scheme} compressor:{args.compressor} level:{args.compressor_level} " + f"structured_mlp:{args.structured_mlp} structured_linear_type:{args.structured_linear_type} " + f"btt_rank:{args.btt_rank} btt_init:{args.btt_init} compile_structured_mlp:{args.compile_structured_mlp} " + f"structured_chunk_ranks:{args.structured_chunk_ranks} structured_eval_chunk_ranks:{args.structured_eval_chunk_ranks} " + f"structured_compile_mode:{args.structured_compile_mode}" + ) + log0( + f"rate_lambda:{args.rate_lambda} scale_lambda:{args.scale_lambda} " + f"rate_warmup_start_fraction:{args.rate_warmup_start_fraction} rate_full_fraction:{args.rate_full_fraction} " + f"rate_target_tensors:{args.rate_target_tensors} rate_schedule_mode:{args.rate_schedule_mode} " + f"max_expected_steps:{args.max_expected_steps}" + ) + log0( + f"tie_embeddings:{args.tie_embeddings} embed_lr:{scaled_token_lr} " + f"head_lr:{scaled_head_lr if base_model.lm_head is not None else 0.0} " + f"matrix_lr:{scaled_matrix_lr} scalar_lr:{scaled_scalar_lr}" + ) + log0( + f"train_batch_tokens_target:{args.train_batch_tokens} effective_train_batch_tokens:{effective_train_batch_tokens} " + f"train_seq_len:{args.train_seq_len} iterations:{args.iterations} warmup_steps:{effective_warmup_steps} " + f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log0( + f"lr_scale_method:{args.lr_scale_method} lr_reference_batch_tokens:{args.lr_reference_batch_tokens} " + f"batch_multiplier:{batch_multiplier:.4f} warmup_scale_method:{args.warmup_scale_method} " + f"warmup_reference_batch_tokens:{args.warmup_reference_batch_tokens}" + ) + log0(f"seed:{args.seed}") + + if args.debug_data_sharding_steps > 0: + debug_log_data_sharding( + args.train_files, + world_size, + effective_train_batch_tokens, + args.train_seq_len, + grad_accum_steps, + args.debug_data_sharding_steps, + log0, + ) + if args.dry_run_init_only: + log0("dry_run_init_only: exiting after 
init/logging") + return + + # ----------------------------- + # DATA LOADER & MODEL WARMUP + # ----------------------------- + + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + + def lr_mul(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + + def rate_progress(step: int, elapsed_ms: float) -> float: + if args.rate_schedule_mode == "steps": + if args.iterations <= 0: + return 1.0 + return step / max(args.iterations, 1) + if args.rate_schedule_mode == "expected_steps": + if args.max_expected_steps <= 0: + raise ValueError("MAX_EXPECTED_STEPS must be positive when RATE_SCHEDULE_MODE=expected_steps") + return step / max(args.max_expected_steps, 1) + if args.rate_schedule_mode == "wallclock": + if max_wallclock_ms is None: + raise ValueError("RATE_SCHEDULE_MODE=wallclock requires MAX_WALLCLOCK_SECONDS > 0") + return elapsed_ms / max(max_wallclock_ms, 1e-9) + if args.rate_schedule_mode != "auto": + raise ValueError(f"Unsupported RATE_SCHEDULE_MODE={args.rate_schedule_mode}") + if args.max_expected_steps > 0: + return step / max(args.max_expected_steps, 1) + if max_wallclock_ms is not None and step > 0: + step_ms = elapsed_ms / max(step, 1) + expected_steps = max(int(max_wallclock_ms / max(step_ms, 1e-9)), 1) + return step / expected_steps + if args.iterations <= 0: + 
return 1.0 + return step / max(args.iterations, 1) + + def rate_schedule(step: int, elapsed_ms: float) -> float: + if args.rate_lambda <= 0.0 and args.scale_lambda <= 0.0: + return 0.0 + frac = rate_progress(step, elapsed_ms) + start = args.rate_warmup_start_fraction + full = max(args.rate_full_fraction, start + 1e-6) + if frac <= start: + return 0.0 + if frac >= full: + return 1.0 + return (frac - start) / (full - start) + + # Warmup primes the compiled forward/backward/optimizer paths, then we restore the + # initial weights/optimizer state so measured training starts from the true init. + if effective_warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(effective_warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(effective_train_batch_tokens, args.train_seq_len, grad_accum_steps) + maybe_mark_compile_step_begin(args.enable_compile or args.compile_structured_mlp) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + for opt in optimizers: + opt.step() + zero_grad_all() + if ( + effective_warmup_steps <= 20 + or (warmup_step + 1) % 10 == 0 + or warmup_step + 1 == effective_warmup_steps + ): + log0(f"warmup_step:{warmup_step + 1}/{effective_warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + if distributed: + model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + # ----------------------------- + # 
MAIN TRAINING LOOP + # ----------------------------- + + training_time_ms = 0.0 + stop_after_step: int | None = None + torch.cuda.synchronize() + t0 = time.perf_counter() + + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + log0( + f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" + ) + break + + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + zero_grad_all() + train_loss = torch.zeros((), device=device) + train_ce_loss = torch.zeros((), device=device) + train_rate_proxy = torch.zeros((), device=device) + train_scale_proxy = torch.zeros((), device=device) + train_est_bits = torch.zeros((), device=device) + rate_mul = rate_schedule(step, elapsed_ms) + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(effective_train_batch_tokens, args.train_seq_len, grad_accum_steps) + maybe_mark_compile_step_begin(args.enable_compile or args.compile_structured_mlp) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + ce_loss = model(x, y) + total_loss = ce_loss + 
if rate_mul > 0.0: + rate_proxy, scale_proxy, est_bits = base_model.rate_regularizer() + total_loss = total_loss + rate_mul * (args.rate_lambda * rate_proxy + args.scale_lambda * scale_proxy) + train_rate_proxy += rate_proxy.detach() + train_scale_proxy += scale_proxy.detach() + train_est_bits += est_bits.detach() + train_ce_loss += ce_loss.detach() + train_loss += total_loss.detach() + (total_loss * grad_scale).backward() + train_loss /= grad_accum_steps + train_ce_loss /= grad_accum_steps + train_rate_proxy /= grad_accum_steps + train_scale_proxy /= grad_accum_steps + train_est_bits /= grad_accum_steps + + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + for opt in optimizers: + opt.step() + zero_grad_all() + + step += 1 + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + should_log_train = ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + log0( + f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} ce_loss:{train_ce_loss.item():.4f} " + f"rate_mul:{rate_mul:.3f} rate_proxy_bits:{train_rate_proxy.item():.4f} " + f"scale_proxy:{train_scale_proxy.item():.6f} est_model_bits:{train_est_bits.item():.0f} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" + ) + + # Needed to sync whether we've reached the wallclock cap. 
+ reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + + log0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + ) + + # ----------------------------- + # SERIALIZATION + ROUNDTRIP VALIDATION + # ----------------------------- + # Save the raw state (useful for debugging/loading in PyTorch directly), then always produce + # the compressed artifact for the configured QUANT_SCHEME + COMPRESSOR pair and validate the + # round-tripped weights. + + if master_process: + torch.save(base_model.state_dict(), "final_model.pt") + model_bytes = os.path.getsize("final_model.pt") + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model: {model_bytes} bytes") + log0(f"Code size: {code_bytes} bytes") + log0(f"Total submission size: {model_bytes + code_bytes} bytes") + + quant_obj, quant_stats = quantize_state_dict(base_model.state_dict(), args.quant_scheme) + quant_buf = io.BytesIO() + torch.save(quant_obj, quant_buf) + quant_raw = quant_buf.getvalue() + quant_blob = compress_blob(quant_raw, args.compressor, args.compressor_level) + quant_raw_bytes = len(quant_raw) + quant_filename = f"final_model.{args.quant_scheme}.{args.compressor}.ptc" + if master_process: + with open(quant_filename, "wb") as f: + f.write(quant_blob) + quant_file_bytes = os.path.getsize(quant_filename) + code_bytes = len(code.encode("utf-8")) + ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1) + log0( + f"Serialized model {args.quant_scheme}+{args.compressor}: {quant_file_bytes} bytes " + f"(payload:{quant_stats['int8_payload_bytes']} raw_torch:{quant_raw_bytes} 
payload_ratio:{ratio:.2f}x)" + ) + log0(f"Total submission size {args.quant_scheme}+{args.compressor}: {quant_file_bytes + code_bytes} bytes") + + if distributed: + dist.barrier() + with open(quant_filename, "rb") as f: + quant_blob_disk = f.read() + quant_state = torch.load(io.BytesIO(decompress_blob(quant_blob_disk, args.compressor)), map_location="cpu") + base_model.load_state_dict(dequantize_state_dict(quant_state), strict=True) + torch.cuda.synchronize() + t_qeval = time.perf_counter() + q_val_loss, q_val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"final_{args.quant_scheme}_{args.compressor}_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" + ) + log0(f"final_{args.quant_scheme}_{args.compressor}_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + + if distributed: + dist.destroy_process_group() + + +if __name__ == "__main__": + main() + +==================================================================================================== +Running Python 3.12.3 (main, Nov 6 2025, 13:44:16) [GCC 13.3.0] +Running PyTorch 2.9.1+cu128 +Tue Mar 31 18:30:58 2026 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 580.126.09 Driver Version: 580.126.09 CUDA Version: 13.0 | ++-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. 
| +|=========================================+========================+======================| +| 0 NVIDIA H100 80GB HBM3 On | 00000000:CB:00.0 Off | 0 | +| N/A 40C P0 96W / 700W | 1185MiB / 81559MiB | 3% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| No running processes found | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=/workspace/parameter-golf/data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:80 +val_loader:shards pattern=/workspace/parameter-golf/data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:1048576 +val_token_limit:1048576 +model_params:25727104 +world_size:1 grad_accum_steps:8 train_microbatch_tokens:2048 effective_train_batch_tokens:16384 +enable_compile:False +sdp_backends:cudnn=False flash=True mem_efficient=False math=False +attention_mode:gqa num_heads:8 num_kv_heads:4 +quant_scheme:int5_odd compressor:zlib level:9 structured_mlp:True structured_linear_type:btt_mlp btt_rank:256 btt_init:mup compile_structured_mlp:False structured_chunk_ranks:64 structured_eval_chunk_ranks:1 structured_compile_mode:reduce-overhead +rate_lambda:2e-05 scale_lambda:0.0002 rate_warmup_start_fraction:0.1 rate_full_fraction:0.4 rate_target_tensors:mlp_only rate_schedule_mode:auto max_expected_steps:0 +tie_embeddings:True embed_lr:0.03 head_lr:0.0 matrix_lr:0.02 scalar_lr:0.02 +train_batch_tokens_target:16384 effective_train_batch_tokens:16384 train_seq_len:1024 iterations:0 
warmup_steps:0 max_wallclock_seconds:0.000 +lr_scale_method:none lr_reference_batch_tokens:16384 batch_multiplier:1.0000 warmup_scale_method:none warmup_reference_batch_tokens:16384 +seed:1337 +step:0/0 val_loss:6.9379 val_bpb:4.1568 train_time:0ms step_avg:0.02ms +peak memory allocated: 10506 MiB reserved: 11558 MiB +Serialized model: 101939739 bytes +Code size: 81599 bytes +Total submission size: 102021338 bytes +Serialized model int5_odd+zlib: 7256271 bytes (payload:14157417 raw_torch:14260195 payload_ratio:11.71x) +Total submission size int5_odd+zlib: 7337870 bytes +final_int5_odd_zlib_roundtrip val_loss:6.9416 val_bpb:4.1590 eval_time:97322ms +final_int5_odd_zlib_roundtrip_exact val_loss:6.94156384 val_bpb:4.15901528 diff --git a/records/track_non_record_16mb/2026-03-31_EntropyAware_Int5Odd_BTTMLP_3060/logs/eval_bench_r256_l16_seq1024_cache_vb524288.txt b/records/track_non_record_16mb/2026-03-31_EntropyAware_Int5Odd_BTTMLP_3060/logs/eval_bench_r256_l16_seq1024_cache_vb524288.txt new file mode 100644 index 000000000..47c9e088b --- /dev/null +++ b/records/track_non_record_16mb/2026-03-31_EntropyAware_Int5Odd_BTTMLP_3060/logs/eval_bench_r256_l16_seq1024_cache_vb524288.txt @@ -0,0 +1,1898 @@ +logs/eval_bench_r256_l16_seq1024_cache_vb524288.txt +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=/workspace/parameter-golf/data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:80 +val_loader:shards pattern=/workspace/parameter-golf/data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:1048576 +val_token_limit:1048576 +model_params:25727104 +world_size:1 grad_accum_steps:8 train_microbatch_tokens:2048 effective_train_batch_tokens:16384 +enable_compile:False +sdp_backends:cudnn=False flash=True mem_efficient=False math=False +attention_mode:gqa num_heads:8 num_kv_heads:4 +quant_scheme:int5_odd compressor:zlib level:9 structured_mlp:True structured_linear_type:btt_mlp btt_rank:256 btt_init:mup compile_structured_mlp:False 
structured_chunk_ranks:64 structured_eval_chunk_ranks:1 structured_compile_mode:reduce-overhead +rate_lambda:2e-05 scale_lambda:0.0002 rate_warmup_start_fraction:0.1 rate_full_fraction:0.4 rate_target_tensors:mlp_only rate_schedule_mode:auto max_expected_steps:0 +tie_embeddings:True embed_lr:0.03 head_lr:0.0 matrix_lr:0.02 scalar_lr:0.02 +train_batch_tokens_target:16384 effective_train_batch_tokens:16384 train_seq_len:1024 iterations:0 warmup_steps:0 max_wallclock_seconds:0.000 +lr_scale_method:none lr_reference_batch_tokens:16384 batch_multiplier:1.0000 warmup_scale_method:none warmup_reference_batch_tokens:16384 +seed:1337 +step:0/0 val_loss:6.9379 val_bpb:4.1568 train_time:0ms step_avg:0.02ms +peak memory allocated: 20750 MiB reserved: 22836 MiB +Serialized model: 101939739 bytes +Code size: 81599 bytes +Total submission size: 102021338 bytes +Serialized model int5_odd+zlib: 7256271 bytes (payload:14157417 raw_torch:14260195 payload_ratio:11.71x) +Total submission size int5_odd+zlib: 7337870 bytes +final_int5_odd_zlib_roundtrip val_loss:6.9416 val_bpb:4.1590 eval_time:97090ms +final_int5_odd_zlib_roundtrip_exact val_loss:6.94156289 val_bpb:4.15901471 +environ.get("SEED", 1337)) + + # Validation cadence and batch size. Validation uses the full fineweb_val split unless + # VAL_TOKEN_LIMIT caps the number of validation tokens. + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 32_768)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 10)) + val_token_limit = int(os.environ.get("VAL_TOKEN_LIMIT", 1_048_576)) + + # Training length. 
+ iterations = int(os.environ.get("ITERATIONS", 40)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 1200)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 1)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 16_384)) + train_microbatch_tokens = int(os.environ.get("TRAIN_MICROBATCH_TOKENS", 0)) + grad_accum_steps = int(os.environ.get("GRAD_ACCUM_STEPS", 0)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 256)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + enable_compile = bool(int(os.environ.get("ENABLE_COMPILE", "0"))) + + # Model shape. + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers = int(os.environ.get("NUM_LAYERS", 9)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = int(os.environ.get("MLP_MULT", 2)) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + + # Optimizer hyperparameters. 
+ embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.03)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.02)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.02)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.0)) + lr_scale_method = os.environ.get("LR_SCALE_METHOD", "none") + lr_reference_batch_tokens = int(os.environ.get("LR_REFERENCE_BATCH_TOKENS", 16_384)) + warmup_scale_method = os.environ.get("WARMUP_SCALE_METHOD", "none") + warmup_reference_batch_tokens = int(os.environ.get("WARMUP_REFERENCE_BATCH_TOKENS", 16_384)) + simulated_world_size = int(os.environ.get("SIMULATED_WORLD_SIZE", 0)) + simulated_rank = int(os.environ.get("SIMULATED_RANK", 0)) + dry_run_init_only = bool(int(os.environ.get("DRY_RUN_INIT_ONLY", "0"))) + debug_data_sharding_steps = int(os.environ.get("DEBUG_DATA_SHARDING_STEPS", 0)) + + # Export / compression. + quant_scheme = os.environ.get("QUANT_SCHEME", "int5_odd") + compressor = os.environ.get("COMPRESSOR", "zlib") + compressor_level = int(os.environ.get("COMPRESSOR_LEVEL", 9)) + + # Entropy-aware rate regularization. 
+ rate_lambda = float(os.environ.get("RATE_LAMBDA", 0.00002)) + scale_lambda = float(os.environ.get("SCALE_LAMBDA", 0.0002)) + rate_temperature = float(os.environ.get("RATE_TEMPERATURE", 6.0)) + rate_warmup_start_fraction = float(os.environ.get("RATE_WARMUP_START_FRACTION", 0.1)) + rate_full_fraction = float(os.environ.get("RATE_FULL_FRACTION", 0.4)) + rate_target_tensors = os.environ.get("RATE_TARGET_TENSORS", "mlp_only") + rate_schedule_mode = os.environ.get("RATE_SCHEDULE_MODE", "auto") + max_expected_steps = int(os.environ.get("MAX_EXPECTED_STEPS", 0)) + + # Structured MLP. + structured_mlp = bool(int(os.environ.get("STRUCTURED_MLP", "1"))) + structured_linear_type = os.environ.get("STRUCTURED_LINEAR_TYPE", "btt_mlp") + btt_rank = int(os.environ.get("BTT_RANK", 32)) + btt_in_shape = os.environ.get("BTT_IN_SHAPE", "") + btt_out_shape = os.environ.get("BTT_OUT_SHAPE", "") + btt_init = os.environ.get("BTT_INIT", "mup") + btt_init_gain = float(os.environ.get("BTT_INIT_GAIN", 1.0)) + compile_structured_mlp = bool(int(os.environ.get("COMPILE_STRUCTURED_MLP", "1"))) + structured_chunk_ranks = int(os.environ.get("STRUCTURED_CHUNK_RANKS", 64)) + structured_eval_chunk_ranks = int(os.environ.get("STRUCTURED_EVAL_CHUNK_RANKS", 4)) + structured_compile_mode = os.environ.get("STRUCTURED_COMPILE_MODE", "reduce-overhead") + +# ----------------------------- +# MUON OPTIMIZER +# ----------------------------- +# +# As borrowed from modded-nanogpt +# Background on Muon: https://kellerjordan.github.io/posts/muon/ + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + # Orthogonalize a 2D update matrix with a fast Newton-Schulz iteration. + # Muon uses this to normalize matrix-shaped gradients before applying them. 
+ a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X + + +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True): + super().__init__( + params, + dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov), + ) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + + for group in self.param_groups: + params = group["params"] + if not params: + continue + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + + total_params = sum(int(p.numel()) for p in params) + updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) + + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: + g = g.add(buf, alpha=momentum) + g = zeropower_via_newtonschulz5(g, steps=backend_steps) + # Scale correction from Muon reference implementations. 
+                    g *= max(1, g.size(0) / g.size(1)) ** 0.5
+                    updates_flat[curr : curr + p.numel()] = g.reshape(-1)
+                curr += p.numel()
+
+            if distributed:
+                dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM)
+
+            curr = 0
+            for p in params:
+                g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype)
+                p.add_(g, alpha=-lr)
+                curr += p.numel()
+
+        return loss
+
+
+# -----------------------------
+# TOKENIZER-AGNOSTIC EVALUATION SETUP
+# -----------------------------
+#
+# It's common for small models to have a large fraction of their parameters in embeddings, since the 2 * d_model * d_vocab embedding and unembedding parameters can be gigantic.
+# Instead of locking the tokenizer, we let you bring your own and compute the validation metric from the average compression of the validation set.
+# We report BPB (bits-per-byte) instead of raw validation loss, so we need a way to count how many bytes each token decodes to.
+# Note: Submissions that edit the tokenizer will be examined more carefully, since screwing this up might unjustly improve your score.
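The BPB metric described above reduces to a two-factor formula: mean token cross-entropy in nats becomes bits per token, and the tokenizer's tokens-per-byte ratio rescales that into bits per byte of raw text. A minimal sketch of that conversion (the helper name `bits_per_byte` is illustrative, not part of this script):

```python
import math

def bits_per_byte(mean_loss_nats: float, token_count: float, byte_count: float) -> float:
    # nats -> bits for the per-token cross-entropy, then rescale by how many
    # tokens the tokenizer spends per byte of validation text. A tokenizer
    # that emits fewer tokens per byte gets a proportionally lower BPB.
    bits_per_token = mean_loss_nats / math.log(2.0)
    return bits_per_token * (token_count / byte_count)

# Sanity check: a loss of ln(2) nats/token at one token per byte is exactly 1 bpb.
print(bits_per_byte(math.log(2.0), 1000.0, 1000.0))
```

This is why the evaluation only needs three accumulators (loss sum, token count, byte count) regardless of which tokenizer produced the validation tokens.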
+ +def build_sentencepiece_luts( + sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device +) -> tuple[Tensor, Tensor, Tensor]: + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("▁"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) + + +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*. 
+ tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] + + +def eval_val( + args: Hyperparameters, + model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + grad_accum_steps: int, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, +) -> tuple[float, float]: + # Validation computes two metrics: + # - val_loss: token cross-entropy (natural log) + # - val_bpb: tokenizer-agnostic compression metric used by the challenge + local_batch_tokens = args.val_batch_size // world_size + if local_batch_tokens < args.train_seq_len: + raise ValueError( + "VAL_BATCH_SIZE must provide at least one sequence per rank; " + f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " + f"TRAIN_SEQ_LEN={args.train_seq_len}" + ) + local_batch_seqs = local_batch_tokens // args.train_seq_len + total_seqs = (val_tokens.numel() - 1) // args.train_seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + + model.eval() + with torch.inference_mode(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * args.train_seq_len + raw_end = batch_seq_end * args.train_seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, args.train_seq_len) + y = local[1:].reshape(-1, args.train_seq_len) + maybe_mark_compile_step_begin(args.enable_compile or 
args.compile_structured_mlp)
+            with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True):
+                batch_loss = model(x, y).detach()
+            batch_token_count = float(y.numel())
+            val_loss_sum += batch_loss.to(torch.float64) * batch_token_count
+            val_token_count += batch_token_count
+            prev_ids = x.reshape(-1)
+            tgt_ids = y.reshape(-1)
+            token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16)
+            token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16)
+            val_byte_count += token_bytes.to(torch.float64).sum()
+
+    if dist.is_available() and dist.is_initialized():
+        dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM)
+        dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM)
+        dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM)
+
+    val_loss = val_loss_sum / val_token_count
+    bits_per_token = val_loss.item() / math.log(2.0)
+    tokens_per_byte = val_token_count.item() / val_byte_count.item()
+    model.train()
+    return float(val_loss.item()), float(bits_per_token * tokens_per_byte)
+
+# -----------------------------
+# POST-TRAINING QUANTIZATION
+# -----------------------------
+#
+# It's silly to export our model, which is trained in bf16 and fp32, at that same precision.
+# Instead, we get approximately the same model (with a small quality hit) by quantizing the weights (int8 or int5_odd here) and zlib-compressing the result.
+# We can then decompress the model and run at higher precision for evaluation, after coming in under the size limit.
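Most of the int5_odd exporter's density comes from packing three 5-level codes into one byte, which works because 5^3 = 125 <= 256. A standalone NumPy sketch of that round trip (function names here are illustrative; the script's own versions are `pack_base5_codes` / `unpack_base5_codes`):

```python
import numpy as np

def pack_base5(vals: np.ndarray) -> tuple[bytes, int]:
    # Codes live in {0..4}; three of them fit in one byte as base-5 digits:
    # d0 + 5*d1 + 25*d2 <= 124. Pad to a multiple of 3 before grouping.
    n = int(vals.size)
    pad = (3 - n % 3) % 3
    v = np.pad(vals.astype(np.uint8), (0, pad)).reshape(-1, 3)
    return (v[:, 0] + 5 * v[:, 1] + 25 * v[:, 2]).astype(np.uint8).tobytes(), n

def unpack_base5(data: bytes, n: int) -> np.ndarray:
    # Invert by repeated divmod-by-5, then trim the padding codes.
    packed = np.frombuffer(data, dtype=np.uint8).astype(np.int16)
    digits = np.zeros((packed.size, 3), dtype=np.uint8)
    for i in range(3):
        digits[:, i] = packed % 5
        packed = packed // 5
    return digits.reshape(-1)[:n]
```

So five quantized weights cost two bytes before zlib, versus five bytes at int8, and zlib then squeezes out whatever statistical redundancy the entropy-aware regularizer created in the code stream.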
+ +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights", + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 +INT5_ODD_LEVELS = (-2, -1, 0, 1, 2) + +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) + + +def maybe_mark_compile_step_begin(enabled: bool) -> None: + if not enabled: + return + mark_step = getattr(torch.compiler, "cudagraph_mark_step_begin", None) + if mark_step is not None: + mark_step() + + +def scale_from_batch_multiplier(value: float, multiplier: float, method: str) -> float: + if method == "none": + return value + if method == "sqrt": + return value * math.sqrt(multiplier) + if method == "linear": + return value * multiplier + raise ValueError(f"Unsupported scaling method: {method}") + + +def scaled_warmup_steps(base_steps: int, multiplier: float, method: str) -> int: + if base_steps <= 0: + return 0 + scaled = scale_from_batch_multiplier(float(base_steps), multiplier, method) + return max(int(math.ceil(scaled)), 1) + + +def parse_factor_shape(spec: str, dim: int) -> tuple[int, int]: + if spec: + vals = tuple(int(v.strip()) for v in spec.split(",") if v.strip()) + if len(vals) != 2 or vals[0] * vals[1] != dim: + raise ValueError(f"Invalid factor shape '{spec}' for dim={dim}") + return vals + for a in range(int(math.isqrt(dim)), 0, -1): + if dim % a == 0: + return a, dim // a + return 1, dim + +def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> 
Tensor: + if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): + return t.float().contiguous() + if t.dtype in {torch.float32, torch.bfloat16}: + passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") + return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() + return t + +def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + # Matrices get one scale per row, which usually tracks output-channel + # ranges much better than a single tensor-wide scale. + clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) + q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() + return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + + # Vectors / scalars use a simpler per-tensor scale. 
+ clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() + return q, scale + + +def pack_base5_codes(q: Tensor) -> tuple[bytes, int]: + vals = q.detach().to("cpu", dtype=torch.uint8).reshape(-1).numpy() + n_vals = int(vals.size) + pad = (3 - (n_vals % 3)) % 3 + if pad: + vals = np.pad(vals, (0, pad)) + groups = vals.reshape(-1, 3).astype(np.uint8) + packed = groups[:, 0] + 5 * groups[:, 1] + 25 * groups[:, 2] + return packed.tobytes(), n_vals + + +def unpack_base5_codes(data: bytes, n_vals: int) -> Tensor: + packed = np.frombuffer(data, dtype=np.uint8).astype(np.int16) + out = np.zeros((packed.size, 3), dtype=np.uint8) + for i in range(3): + out[:, i] = packed % 5 + packed //= 5 + flat = out.reshape(-1)[:n_vals] + return torch.from_numpy(flat.astype(np.uint8, copy=False)) + + +def quantize_float_tensor_int5_odd(t: Tensor) -> tuple[Tensor, Tensor]: + if t.ndim != 2: + raise ValueError("int5_odd quantization only supports 2D tensors") + t32 = t.float() + clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clip_abs = clip_abs.clamp_min(1e-6) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 2.0).clamp_min(1.0 / 2.0e6) + q = torch.clamp(torch.round(clipped / scale[:, None]), -2, 2).to(torch.int8).contiguous() + return (q + 2).to(torch.uint8).contiguous(), scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + +def quantize_state_dict_int8(state_dict: dict[str, Tensor]): + # Single supported clean-script export format: + # - per-row int8 for 2D float tensors + # - per-tensor int8 for other float tensors + # - exact passthrough for non-floats + # - passthrough for small float tensors, 
stored as fp16 to save bytes + quantized: dict[str, Tensor] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + qmeta: dict[str, dict[str, object]] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + + # Small float tensors are cheap enough to keep directly. We still downcast + # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size. + if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + + stats["num_float_tensors"] += 1 + q, s = quantize_float_tensor(t) + if s.ndim > 0: + qmeta[name] = {"scheme": "per_row", "axis": 0} + quantized[name] = q + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) + + obj: dict[str, object] = { + "__quant_format__": "int8_clean_per_row_v1", + "quantized": quantized, + "scales": scales, + "dtypes": dtypes, + "passthrough": passthrough, + } + if qmeta: + obj["qmeta"] = qmeta + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats + + +def quantize_state_dict_int5_odd(state_dict: dict[str, Tensor]): + packed: dict[str, bytes] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + shapes: dict[str, tuple[int, ...]] = {} + orig_shapes: dict[str, 
tuple[int, ...]] = {} + n_codes: dict[str, int] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + + quantize_this = t.ndim >= 2 and (t.numel() > INT8_KEEP_FLOAT_MAX_NUMEL or ".mlp." in name) + if not quantize_this: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + + stats["num_float_tensors"] += 1 + t2 = t.reshape(t.shape[0], -1) if t.ndim > 2 else t + q, s = quantize_float_tensor_int5_odd(t2) + packed_bytes, count = pack_base5_codes(q) + packed[name] = packed_bytes + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + shapes[name] = tuple(int(v) for v in t2.shape) + orig_shapes[name] = tuple(int(v) for v in t.shape) + n_codes[name] = count + stats["int8_payload_bytes"] += len(packed_bytes) + tensor_nbytes(s) + + obj: dict[str, object] = { + "__quant_format__": "int5_odd_clean_per_row_v1", + "packed": packed, + "scales": scales, + "dtypes": dtypes, + "shapes": shapes, + "orig_shapes": orig_shapes, + "n_codes": n_codes, + "passthrough": passthrough, + } + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats + +def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + qmeta = obj.get("qmeta", {}) + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, q in 
obj["quantized"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + s = obj["scales"][name] + if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: + s = s.to(dtype=torch.float32) + # Broadcast the saved row scale back across trailing dimensions. + out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() + else: + scale = float(s.item()) + out[name] = (q.float() * scale).to(dtype=dtype).contiguous() + for name, t in obj["passthrough"].items(): + # Restore small tensors, undoing the temporary fp16 storage cast if needed. + out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out + + +def dequantize_state_dict_int5_odd(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, packed in obj["packed"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + shape = tuple(int(v) for v in obj["shapes"][name]) + orig_shape = tuple(int(v) for v in obj.get("orig_shapes", {}).get(name, shape)) + q = unpack_base5_codes(packed, int(obj["n_codes"][name])).to(torch.float32).reshape(shape) - 2.0 + s = obj["scales"][name].to(dtype=torch.float32) + deq = (q * s.view(shape[0], *([1] * (len(shape) - 1)))).to(dtype=dtype) + out[name] = deq.reshape(orig_shape).contiguous() + for name, t in obj["passthrough"].items(): + out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out + + +def quantize_state_dict(state_dict: dict[str, Tensor], quant_scheme: str): + if quant_scheme == "int8": + return quantize_state_dict_int8(state_dict) + if quant_scheme == "int5_odd": + return 
quantize_state_dict_int5_odd(state_dict) + raise ValueError(f"Unsupported QUANT_SCHEME={quant_scheme}") + + +def dequantize_state_dict(obj: dict[str, object]) -> dict[str, Tensor]: + quant_format = obj.get("__quant_format__") + if quant_format == "int8_clean_per_row_v1": + return dequantize_state_dict_int8(obj) + if quant_format == "int5_odd_clean_per_row_v1": + return dequantize_state_dict_int5_odd(obj) + raise ValueError(f"Unsupported quant format {quant_format}") + + +def compress_blob(data: bytes, compressor: str, level: int) -> bytes: + if compressor == "zlib": + return zlib.compress(data, level=max(0, min(level, 9))) + if compressor == "zstd": + if not _HAS_ZSTD: + raise RuntimeError("COMPRESSOR=zstd requires zstandard to be installed") + return zstd.ZstdCompressor(level=level).compress(data) + raise ValueError(f"Unsupported COMPRESSOR={compressor}") + + +def decompress_blob(data: bytes, compressor: str) -> bytes: + if compressor == "zlib": + return zlib.decompress(data) + if compressor == "zstd": + if not _HAS_ZSTD: + raise RuntimeError("COMPRESSOR=zstd requires zstandard to be installed") + return zstd.ZstdDecompressor().decompress(data) + raise ValueError(f"Unsupported COMPRESSOR={compressor}") + + +def should_rate_regularize_tensor(name: str, p: Tensor, target: str) -> bool: + if not p.requires_grad or not p.is_floating_point() or p.ndim < 2: + return False + if any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS): + return False + if target == "mlp_only": + return ".mlp." in name + if target == "all_matrices": + return p.numel() > INT8_KEEP_FLOAT_MAX_NUMEL or ".mlp." 
in name
+    raise ValueError(f"Unsupported RATE_TARGET_TENSORS={target}")
+
+
+def estimate_tensor_rate_proxy(t: Tensor, temperature: float) -> tuple[Tensor, Tensor, Tensor]:
+    t2 = t.float()
+    if t2.ndim > 2:
+        t2 = t2.reshape(t2.shape[0], -1)
+    elif t2.ndim == 1:
+        t2 = t2.reshape(1, -1)
+    row_scale = t2.abs().mean(dim=1, keepdim=True).clamp_min(1e-6)
+    normalized = torch.clamp(t2 / row_scale, -2.5, 2.5)
+    centers = t2.new_tensor(INT5_ODD_LEVELS)
+    logits = -temperature * (normalized.unsqueeze(-1) - centers) ** 2
+    probs = torch.softmax(logits, dim=-1)
+    hist = probs.mean(dim=(0, 1))
+    entropy_nats = -(hist * (hist + 1e-8).log()).sum()
+    entropy_bits = entropy_nats / math.log(2.0)
+    scale_proxy = row_scale.mean()
+    est_bits = entropy_bits * float(t2.numel())
+    return entropy_bits, scale_proxy, est_bits
+
+
+# -----------------------------
+# DATA LOADING
+# -----------------------------
+
+def load_data_shard(file: Path) -> Tensor:
+    # Shards use the modded-nanogpt layout: a 256-entry int32 header
+    # (magic, version, token count) followed by little-endian uint16 tokens.
+    header_bytes = 256 * np.dtype("<i4").itemsize
+    with file.open("rb") as f:
+        header = np.frombuffer(f.read(header_bytes), dtype="<i4")
+        num_tokens = int(header[2])
+        data = np.frombuffer(f.read(num_tokens * np.dtype("<u2").itemsize), dtype="<u2")
+    return torch.from_numpy(data.astype(np.int32, copy=False))
+
+
+class TokenStream:
+    # Sequential reader over the sorted shard files matching `pattern`,
+    # wrapping back to the first shard when the last one is exhausted.
+    def __init__(self, pattern: str):
+        self.files = [Path(p) for p in sorted(glob.glob(pattern))]
+        if not self.files:
+            raise FileNotFoundError(f"No files found for pattern: {pattern}")
+        self.file_idx = 0
+        self.tokens = load_data_shard(self.files[self.file_idx])
+        self.pos = 0
+
+    def _advance_file(self) -> None:
+        self.file_idx = (self.file_idx + 1) % len(self.files)
+        self.tokens = load_data_shard(self.files[self.file_idx])
+        self.pos = 0
+
+    def take(self, n: int) -> Tensor:
+        chunks: list[Tensor] = []
+        remaining = n
+        while remaining > 0:
+            avail = self.tokens.numel() - self.pos
+            if avail <= 0:
+                self._advance_file()
+                continue
+            k = min(remaining, avail)
+            chunks.append(self.tokens[self.pos : self.pos + k])
+            self.pos += k
+            remaining -= k
+        return chunks[0] if len(chunks) == 1 else torch.cat(chunks)
+
+
+class DistributedTokenLoader:
+    # Each call consumes a contiguous chunk from the shared token stream, then slices out
+    # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting.
+ def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank = rank + self.world_size = world_size + self.device = device + self.stream = TokenStream(pattern) + + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + chunk = self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start : start + per_rank_span].to(dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + + +def debug_log_data_sharding( + pattern: str, + world_size: int, + global_tokens: int, + seq_len: int, + grad_accum_steps: int, + steps: int, + log0, +) -> None: + stream = TokenStream(pattern) + local_tokens = global_tokens // (world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + global_cursor = 0 + log0( + f"debug_sharding:world_size:{world_size} global_tokens:{global_tokens} " + f"grad_accum_steps:{grad_accum_steps} local_tokens_per_rank:{local_tokens} per_rank_span:{per_rank_span}" + ) + for step_idx in range(steps): + chunk = stream.take(per_rank_span * world_size) + chunk_start = global_cursor + chunk_end = global_cursor + chunk.numel() + log0(f"debug_sharding_step:{step_idx} shared_chunk_token_range:[{chunk_start},{chunk_end})") + for shard_rank in range(world_size): + start = shard_rank * per_rank_span + end = start + per_rank_span + local = chunk[start:end] + preview = ",".join(str(int(t)) for t in local[: min(6, local.numel())].tolist()) + log0( + f"debug_shard rank:{shard_rank} token_range:[{chunk_start + start},{chunk_start + end}) " + f"preview:{preview}" + ) + global_cursor = chunk_end + +# ----------------------------- +# TRANSFORMER MODULES +# ----------------------------- + +class RMSNorm(nn.Module): + def 
__init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + + +class CastedLinear(nn.Linear): + # Keep weights in fp32 for optimizer/state quality, cast at matmul time for bf16 compute. + def __init__(self, in_features: int, out_features: int, bias: bool = True): + super().__init__(in_features, out_features, bias=bias) + self._cached_weight: Tensor | None = None + self._cached_bias: Tensor | None = None + self._cached_key: tuple[torch.dtype, torch.device, int, int | None] | None = None + + def _cast_params(self, x: Tensor) -> tuple[Tensor, Tensor | None]: + bias_version = self.bias._version if self.bias is not None else None + cache_key = (x.dtype, x.device, self.weight._version, bias_version) + if torch.is_inference_mode_enabled() and self._cached_key == cache_key: + return self._cached_weight, self._cached_bias + weight = self.weight.to(dtype=x.dtype) + bias = self.bias.to(x.dtype) if self.bias is not None else None + if torch.is_inference_mode_enabled(): + self._cached_weight = weight + self._cached_bias = bias + self._cached_key = cache_key + return weight, bias + + def forward(self, x: Tensor) -> Tensor: + weight, bias = self._cast_params(x) + return F.linear(x, weight, bias) + + +class StructuredLinear(nn.Module): + # Lightweight TT-style structured matrix used only in the MLP stack. 
+ def __init__( + self, + in_features: int, + out_features: int, + bias: bool, + linear_type: str, + btt_rank: int, + in_shape_spec: str, + out_shape_spec: str, + btt_init: str, + btt_init_gain: float, + chunk_ranks: int, + eval_chunk_ranks: int, + ): + super().__init__() + self.in_features = in_features + self.out_features = out_features + self.linear_type = linear_type + self.rank = max(int(btt_rank), 1) + self.chunk_ranks = max(int(chunk_ranks), 1) + self.eval_chunk_ranks = max(int(eval_chunk_ranks), 1) + self.btt_init = btt_init + self.btt_init_gain = btt_init_gain + self._cached_core1: Tensor | None = None + self._cached_core2: Tensor | None = None + self._cached_bias: Tensor | None = None + self._cached_key: tuple[torch.dtype, torch.device, int, int, int | None] | None = None + if linear_type == "dense": + self.impl = CastedLinear(in_features, out_features, bias=bias) + self.register_parameter("core1", None) + self.register_parameter("core2", None) + return + if linear_type != "btt_mlp": + raise ValueError(f"Unsupported STRUCTURED_LINEAR_TYPE={linear_type}") + self.impl = None + self.in_shape = parse_factor_shape(in_shape_spec, in_features) + self.out_shape = parse_factor_shape(out_shape_spec, out_features) + n1, n2 = self.in_shape + m1, m2 = self.out_shape + self.core1 = nn.Parameter(torch.empty(n1, m1, self.rank)) + self.core2 = nn.Parameter(torch.empty(self.rank, n2, m2)) + if bias: + self.bias = nn.Parameter(torch.zeros(out_features)) + else: + self.register_parameter("bias", None) + self.reset_parameters() + + def reset_parameters(self) -> None: + if self.impl is not None: + self.impl.reset_parameters() + return + n1, n2 = self.in_shape + if self.btt_init == "mup": + # muP-inspired scaling: keep the per-rank branch update stable while + # preserving O(1) output variance under the sum over TT ranks. 
+ core1_std = self.btt_init_gain / math.sqrt(max(n1, 1)) + core2_std = self.btt_init_gain / math.sqrt(max(self.rank * n2, 1)) + elif self.btt_init == "xavier": + std = self.btt_init_gain / math.sqrt(max(self.in_features, 1)) + core1_std = std / math.sqrt(self.rank) + core2_std = std + else: + raise ValueError(f"Unsupported BTT_INIT={self.btt_init}") + nn.init.normal_(self.core1, mean=0.0, std=core1_std) + nn.init.normal_(self.core2, mean=0.0, std=core2_std) + if self.bias is not None: + nn.init.zeros_(self.bias) + + def _cast_factors(self, x: Tensor) -> tuple[Tensor, Tensor, Tensor | None]: + bias_version = self.bias._version if self.bias is not None else None + cache_key = (x.dtype, x.device, self.core1._version, self.core2._version, bias_version) + if torch.is_inference_mode_enabled() and self._cached_key == cache_key: + return self._cached_core1, self._cached_core2, self._cached_bias + core1 = self.core1.to(dtype=x.dtype).permute(2, 0, 1).contiguous() + core2 = self.core2.to(dtype=x.dtype).contiguous() + bias = self.bias.to(dtype=x.dtype) if self.bias is not None else None + if torch.is_inference_mode_enabled(): + self._cached_core1 = core1 + self._cached_core2 = core2 + self._cached_bias = bias + self._cached_key = cache_key + return core1, core2, bias + + def forward(self, x: Tensor) -> Tensor: + if self.impl is not None: + return self.impl(x) + orig_shape = x.shape[:-1] + x2 = x.reshape(-1, self.in_shape[0], self.in_shape[1]) + x2_t = x2.transpose(1, 2).contiguous() + core1, core2, bias = self._cast_factors(x) + y = x2.new_zeros((x2.shape[0], self.out_shape[0], self.out_shape[1])) + if torch.is_inference_mode_enabled() and self.eval_chunk_ranks > 1: + for start in range(0, self.rank, self.eval_chunk_ranks): + end = min(start + self.eval_chunk_ranks, self.rank) + left = torch.matmul(x2_t.unsqueeze(1), core1[start:end].unsqueeze(0)) + y = y + torch.matmul(left.transpose(-1, -2), core2[start:end].unsqueeze(0)).sum(dim=1) + y = y.reshape(*orig_shape, 
self.out_features) + if bias is not None: + y = y + bias + return y + for start in range(0, self.rank, self.chunk_ranks): + end = min(start + self.chunk_ranks, self.rank) + for offset, core2_r in enumerate(core2[start:end], start=start): + left = torch.matmul(x2_t, core1[offset]) + y = y + torch.matmul(left.transpose(1, 2), core2_r) + y = y.reshape(*orig_shape, self.out_features) + if bias is not None: + y = y + bias + return y + + +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + # Keep small/control parameters in fp32 even when the model body runs in bf16. + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() + + +class Rotary(nn.Module): + # Caches cos/sin tables per sequence length on the current device. + def __init__(self, dim: int, base: float = 10000.0): + super().__init__() + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached != seq_len + or self._cos_cached.device != device + ): + t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) + freqs = torch.outer(t, self.inv_freq.to(device)) + self._cos_cached = freqs.cos()[None, None, :, :] + self._sin_cached = freqs.sin()[None, None, :, :] + self._seq_len_cached = seq_len + cos = self._cos_cached.to(dtype=dtype) + sin = self._sin_cached.to(dtype=dtype) + if not torch.is_inference_mode_enabled(): + cos = cos.clone() + sin = sin.clone() + return cos, sin + + +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> 
Tensor: + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + + +class CausalSelfAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + kv_dim = self.num_kv_heads * self.head_dim + self.c_q = CastedLinear(dim, dim, bias=False) + self.c_k = CastedLinear(dim, kv_dim, bias=False) + self.c_v = CastedLinear(dim, kv_dim, bias=False) + self.proj = CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rotary = Rotary(self.head_dim, base=rope_base) + + def forward(self, x: Tensor) -> Tensor: + bsz, seqlen, dim = x.shape + q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) + k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin) + k = apply_rotary_emb(k, cos, sin) + q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] + y = F.scaled_dot_product_attention( + q, + k, + v, + attn_mask=None, + is_causal=True, + enable_gqa=(self.num_kv_heads != self.num_heads), + ) + y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) + return self.proj(y) + + +class MLP(nn.Module): + # relu^2 MLP from the 
original modded-nanogpt setup + def __init__( + self, + dim: int, + mlp_mult: int, + structured_mlp: bool, + structured_linear_type: str, + btt_rank: int, + btt_in_shape: str, + btt_out_shape: str, + btt_init: str, + btt_init_gain: float, + compile_structured_mlp: bool, + structured_chunk_ranks: int, + structured_eval_chunk_ranks: int, + structured_compile_mode: str, + ): + super().__init__() + hidden = mlp_mult * dim + if structured_mlp: + self.fc = StructuredLinear( + dim, + hidden, + bias=False, + linear_type=structured_linear_type, + btt_rank=btt_rank, + in_shape_spec=btt_in_shape, + out_shape_spec=btt_out_shape, + btt_init=btt_init, + btt_init_gain=btt_init_gain, + chunk_ranks=structured_chunk_ranks, + eval_chunk_ranks=structured_eval_chunk_ranks, + ) + self.proj = StructuredLinear( + hidden, + dim, + bias=False, + linear_type=structured_linear_type, + btt_rank=btt_rank, + in_shape_spec=btt_out_shape, + out_shape_spec=btt_in_shape, + btt_init=btt_init, + btt_init_gain=btt_init_gain, + chunk_ranks=structured_chunk_ranks, + eval_chunk_ranks=structured_eval_chunk_ranks, + ) + if compile_structured_mlp: + compile_kwargs = dict(fullgraph=False, dynamic=False) + if structured_compile_mode == "reduce-overhead": + compile_options = dict(torch._inductor.list_mode_options("reduce-overhead")) + compile_options["triton.cudagraphs"] = False + compile_kwargs["options"] = compile_options + else: + compile_kwargs["mode"] = structured_compile_mode + self.fc = torch.compile( + self.fc, + **compile_kwargs, + ) + self.proj = torch.compile( + self.proj, + **compile_kwargs, + ) + else: + self.fc = CastedLinear(dim, hidden, bias=False) + self.proj = CastedLinear(hidden, dim, bias=False) + self.proj._zero_init = True + + def forward(self, x: Tensor) -> Tensor: + x = torch.relu(self.fc(x)) + return self.proj(x.square()) + + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + rope_base: float, + qk_gain_init: float, + 
structured_mlp: bool, + structured_linear_type: str, + btt_rank: int, + btt_in_shape: str, + btt_out_shape: str, + btt_init: str, + btt_init_gain: float, + compile_structured_mlp: bool, + structured_chunk_ranks: int, + structured_eval_chunk_ranks: int, + structured_compile_mode: str, + ): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) + self.mlp = MLP( + dim, + mlp_mult, + structured_mlp, + structured_linear_type, + btt_rank, + btt_in_shape, + btt_out_shape, + btt_init, + btt_init_gain, + compile_structured_mlp, + structured_chunk_ranks, + structured_eval_chunk_ranks, + structured_compile_mode, + ) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + + def forward(self, x: Tensor, x0: Tensor) -> Tensor: + mix = self.resid_mix.to(dtype=x.dtype) + x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + attn_out = self.attn(self.attn_norm(x)) + x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out + x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) + return x + + +class GPT(nn.Module): + def __init__( + self, + vocab_size: int, + num_layers: int, + model_dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + tie_embeddings: bool, + tied_embed_init_std: float, + logit_softcap: float, + rope_base: float, + qk_gain_init: float, + structured_mlp: bool, + structured_linear_type: str, + btt_rank: int, + btt_in_shape: str, + btt_out_shape: str, + btt_init: str, + btt_init_gain: float, + compile_structured_mlp: bool, + structured_chunk_ranks: int, + structured_eval_chunk_ranks: int, + structured_compile_mode: str, + rate_target_tensors: str, + rate_temperature: float, + ): + super().__init__() + if logit_softcap <= 
0.0: + raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") + self.tie_embeddings = tie_embeddings + self.tied_embed_init_std = tied_embed_init_std + self.logit_softcap = logit_softcap + self.rate_target_tensors = rate_target_tensors + self.rate_temperature = rate_temperature + self.tok_emb = nn.Embedding(vocab_size, model_dim) + self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + self.blocks = nn.ModuleList( + [ + Block( + model_dim, + num_heads, + num_kv_heads, + mlp_mult, + rope_base, + qk_gain_init, + structured_mlp, + structured_linear_type, + btt_rank, + btt_in_shape, + btt_out_shape, + btt_init, + btt_init_gain, + compile_structured_mlp, + structured_chunk_ranks, + structured_eval_chunk_ranks, + structured_compile_mode, + ) + for i in range(num_layers) + ] + ) + self.final_norm = RMSNorm() + self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) + if self.lm_head is not None: + self.lm_head._zero_init = True + self._init_weights() + + def _init_weights(self) -> None: + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + for module in self.modules(): + if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + if isinstance(module, StructuredLinear) and getattr(module, "_zero_init", False): + if module.impl is not None: + nn.init.zeros_(module.impl.weight) + else: + nn.init.zeros_(module.core2) + + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + x = self.tok_emb(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x0 = x + skips: list[Tensor] = [] + + # First half stores skips; second half reuses them in reverse order. 
+ for i in range(self.num_encoder_layers): + x = self.blocks[i](x, x0) + skips.append(x) + for i in range(self.num_decoder_layers): + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + x = self.blocks[self.num_encoder_layers + i](x, x0) + + x = self.final_norm(x).reshape(-1, x.size(-1)) + targets = target_ids.reshape(-1) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(x) + logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + return F.cross_entropy(logits.float(), targets, reduction="mean") + + def rate_regularizer(self) -> tuple[Tensor, Tensor, Tensor]: + total_symbols = 0.0 + weighted_rate = None + weighted_scale = None + estimated_bits = None + for name, param in self.named_parameters(): + if not should_rate_regularize_tensor(name, param, self.rate_target_tensors): + continue + rate_proxy, scale_proxy, tensor_bits = estimate_tensor_rate_proxy(param, self.rate_temperature) + weight = float(param.numel()) + total_symbols += weight + if weighted_rate is None: + weighted_rate = rate_proxy * weight + weighted_scale = scale_proxy * weight + estimated_bits = tensor_bits + else: + weighted_rate = weighted_rate + rate_proxy * weight + weighted_scale = weighted_scale + scale_proxy * weight + estimated_bits = estimated_bits + tensor_bits + if weighted_rate is None or weighted_scale is None or estimated_bits is None: + zero = self.tok_emb.weight.new_zeros(()) + return zero, zero, zero + denom = max(total_symbols, 1.0) + return weighted_rate / denom, weighted_scale / denom, estimated_bits + + +# ----------------------------- +# TRAINING +# ----------------------------- + +def main() -> None: + global zeropower_via_newtonschulz5 + + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + if args.enable_compile: + 
zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) + + # ----------------------------- + # DISTRIBUTED + CUDA SETUP + # ----------------------------- + + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + simulated_world_size = args.simulated_world_size if not distributed and args.simulated_world_size > 0 else 0 + if simulated_world_size > 0: + world_size = simulated_world_size + rank = args.simulated_rank + local_rank = 0 + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if rank < 0 or rank >= world_size: + raise ValueError(f"RANK={rank} must satisfy 0 <= RANK < WORLD_SIZE={world_size}") + if simulated_world_size > 1 and not args.dry_run_init_only: + raise ValueError("SIMULATED_WORLD_SIZE>1 is only supported together with DRY_RUN_INIT_ONLY=1") + if args.train_batch_tokens <= 0: + raise ValueError(f"TRAIN_BATCH_TOKENS must be positive, got {args.train_batch_tokens}") + train_microbatch_tokens = args.train_microbatch_tokens if args.train_microbatch_tokens > 0 else max( + args.train_seq_len, args.train_batch_tokens // 8 + ) + if train_microbatch_tokens <= 0: + raise ValueError(f"TRAIN_MICROBATCH_TOKENS must be positive, got {train_microbatch_tokens}") + if train_microbatch_tokens % args.train_seq_len != 0: + raise ValueError( + f"TRAIN_MICROBATCH_TOKENS={train_microbatch_tokens} must be divisible by TRAIN_SEQ_LEN={args.train_seq_len}" + ) + tokens_per_microstep = world_size * train_microbatch_tokens + if args.grad_accum_steps > 0: + grad_accum_steps = args.grad_accum_steps + effective_train_batch_tokens = tokens_per_microstep * grad_accum_steps + else: + if args.train_batch_tokens % tokens_per_microstep != 0: + raise ValueError( + "TRAIN_BATCH_TOKENS must be divisible by WORLD_SIZE * TRAIN_MICROBATCH_TOKENS " + f"unless GRAD_ACCUM_STEPS is 
set explicitly; got TRAIN_BATCH_TOKENS={args.train_batch_tokens}, " + f"WORLD_SIZE={world_size}, TRAIN_MICROBATCH_TOKENS={train_microbatch_tokens}" + ) + grad_accum_steps = args.train_batch_tokens // tokens_per_microstep + effective_train_batch_tokens = args.train_batch_tokens + if grad_accum_steps <= 0: + raise ValueError(f"GRAD_ACCUM_STEPS must be positive, got {grad_accum_steps}") + grad_scale = 1.0 / grad_accum_steps + batch_multiplier = effective_train_batch_tokens / max(args.lr_reference_batch_tokens, 1) + effective_warmup_steps = scaled_warmup_steps( + args.warmup_steps, + effective_train_batch_tokens / max(args.warmup_reference_batch_tokens, 1), + args.warmup_scale_method, + ) + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 0 + + # Fast math knobs + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + + logfile = None + if master_process: + if args.run_id == "train": + logfile = "train.log" + else: + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" + print(logfile) + + def log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch {torch.__version__}", console=False) + log0( + subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, 
check=False).stdout, + console=False, + ) + log0("=" * 100, console=False) + + # ----------------------------- + # TOKENIZER + VALIDATION METRIC SETUP + # ----------------------------- + + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) + if args.val_token_limit > 0: + usable = min(((args.val_token_limit) // args.train_seq_len) * args.train_seq_len, val_tokens.numel() - 1) + if usable <= 0: + raise ValueError(f"VAL_TOKEN_LIMIT={args.val_token_limit} is too small for TRAIN_SEQ_LEN={args.train_seq_len}") + val_tokens = val_tokens[: usable + 1].contiguous() + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + if args.val_token_limit > 0: + log0(f"val_token_limit:{args.val_token_limit}") + + # ----------------------------- + # MODEL + OPTIMIZER SETUP + # ----------------------------- + + base_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + 
tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, + structured_mlp=args.structured_mlp, + structured_linear_type=args.structured_linear_type if args.structured_mlp else "dense", + btt_rank=args.btt_rank, + btt_in_shape=args.btt_in_shape, + btt_out_shape=args.btt_out_shape, + btt_init=args.btt_init, + btt_init_gain=args.btt_init_gain, + compile_structured_mlp=args.compile_structured_mlp, + structured_chunk_ranks=args.structured_chunk_ranks, + structured_eval_chunk_ranks=args.structured_eval_chunk_ranks, + structured_compile_mode=args.structured_compile_mode, + rate_target_tensors=args.rate_target_tensors, + rate_temperature=args.rate_temperature, + ).to(device).bfloat16() + for module in base_model.modules(): + if isinstance(module, (CastedLinear, StructuredLinear)): + module.float() + restore_low_dim_params_to_fp32(base_model) + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) if args.enable_compile else base_model + model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + + # Optimizer split: + # - token embedding (Adam) uses EMBED_LR + # - untied lm_head (Adam) uses HEAD_LR + # - matrix params in transformer blocks use MATRIX_LR via Muon + # - vectors/scalars use SCALAR_LR via Adam + block_named_params = list(base_model.blocks.named_parameters()) + matrix_params = [ + p + for name, p in block_named_params + if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + scalar_params = [ + p + for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + scaled_token_lr = 
scale_from_batch_multiplier(token_lr, batch_multiplier, args.lr_scale_method) + scaled_head_lr = scale_from_batch_multiplier(args.head_lr, batch_multiplier, args.lr_scale_method) + scaled_matrix_lr = scale_from_batch_multiplier(args.matrix_lr, batch_multiplier, args.lr_scale_method) + scaled_scalar_lr = scale_from_batch_multiplier(args.scalar_lr, batch_multiplier, args.lr_scale_method) + optimizer_tok = torch.optim.Adam( + [{"params": [base_model.tok_emb.weight], "lr": scaled_token_lr, "base_lr": scaled_token_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizer_muon = Muon( + matrix_params, + lr=scaled_matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + ) + for group in optimizer_muon.param_groups: + group["base_lr"] = scaled_matrix_lr + optimizer_scalar = torch.optim.Adam( + [{"params": scalar_params, "lr": scaled_scalar_lr, "base_lr": scaled_scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": scaled_head_lr, "base_lr": scaled_head_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers.insert(1, optimizer_head) + + n_params = sum(p.numel() for p in base_model.parameters()) + log0(f"model_params:{n_params}") + log0( + f"world_size:{world_size} grad_accum_steps:{grad_accum_steps} " + f"train_microbatch_tokens:{train_microbatch_tokens} effective_train_batch_tokens:{effective_train_batch_tokens}" + ) + if simulated_world_size > 0: + log0(f"simulated_world_size:{simulated_world_size} simulated_rank:{rank}") + log0(f"enable_compile:{args.enable_compile}") + log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") + log0(f"attention_mode:gqa num_heads:{args.num_heads} 
num_kv_heads:{args.num_kv_heads}") + log0( + f"quant_scheme:{args.quant_scheme} compressor:{args.compressor} level:{args.compressor_level} " + f"structured_mlp:{args.structured_mlp} structured_linear_type:{args.structured_linear_type} " + f"btt_rank:{args.btt_rank} btt_init:{args.btt_init} compile_structured_mlp:{args.compile_structured_mlp} " + f"structured_chunk_ranks:{args.structured_chunk_ranks} structured_eval_chunk_ranks:{args.structured_eval_chunk_ranks} " + f"structured_compile_mode:{args.structured_compile_mode}" + ) + log0( + f"rate_lambda:{args.rate_lambda} scale_lambda:{args.scale_lambda} " + f"rate_warmup_start_fraction:{args.rate_warmup_start_fraction} rate_full_fraction:{args.rate_full_fraction} " + f"rate_target_tensors:{args.rate_target_tensors} rate_schedule_mode:{args.rate_schedule_mode} " + f"max_expected_steps:{args.max_expected_steps}" + ) + log0( + f"tie_embeddings:{args.tie_embeddings} embed_lr:{scaled_token_lr} " + f"head_lr:{scaled_head_lr if base_model.lm_head is not None else 0.0} " + f"matrix_lr:{scaled_matrix_lr} scalar_lr:{scaled_scalar_lr}" + ) + log0( + f"train_batch_tokens_target:{args.train_batch_tokens} effective_train_batch_tokens:{effective_train_batch_tokens} " + f"train_seq_len:{args.train_seq_len} iterations:{args.iterations} warmup_steps:{effective_warmup_steps} " + f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log0( + f"lr_scale_method:{args.lr_scale_method} lr_reference_batch_tokens:{args.lr_reference_batch_tokens} " + f"batch_multiplier:{batch_multiplier:.4f} warmup_scale_method:{args.warmup_scale_method} " + f"warmup_reference_batch_tokens:{args.warmup_reference_batch_tokens}" + ) + log0(f"seed:{args.seed}") + + if args.debug_data_sharding_steps > 0: + debug_log_data_sharding( + args.train_files, + world_size, + effective_train_batch_tokens, + args.train_seq_len, + grad_accum_steps, + args.debug_data_sharding_steps, + log0, + ) + if args.dry_run_init_only: + log0("dry_run_init_only: exiting after 
init/logging") + return + + # ----------------------------- + # DATA LOADER & MODEL WARMUP + # ----------------------------- + + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + + def lr_mul(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + + def rate_progress(step: int, elapsed_ms: float) -> float: + if args.rate_schedule_mode == "steps": + if args.iterations <= 0: + return 1.0 + return step / max(args.iterations, 1) + if args.rate_schedule_mode == "expected_steps": + if args.max_expected_steps <= 0: + raise ValueError("MAX_EXPECTED_STEPS must be positive when RATE_SCHEDULE_MODE=expected_steps") + return step / max(args.max_expected_steps, 1) + if args.rate_schedule_mode == "wallclock": + if max_wallclock_ms is None: + raise ValueError("RATE_SCHEDULE_MODE=wallclock requires MAX_WALLCLOCK_SECONDS > 0") + return elapsed_ms / max(max_wallclock_ms, 1e-9) + if args.rate_schedule_mode != "auto": + raise ValueError(f"Unsupported RATE_SCHEDULE_MODE={args.rate_schedule_mode}") + if args.max_expected_steps > 0: + return step / max(args.max_expected_steps, 1) + if max_wallclock_ms is not None and step > 0: + step_ms = elapsed_ms / max(step, 1) + expected_steps = max(int(max_wallclock_ms / max(step_ms, 1e-9)), 1) + return step / expected_steps + if args.iterations <= 0: + 
return 1.0 + return step / max(args.iterations, 1) + + def rate_schedule(step: int, elapsed_ms: float) -> float: + if args.rate_lambda <= 0.0 and args.scale_lambda <= 0.0: + return 0.0 + frac = rate_progress(step, elapsed_ms) + start = args.rate_warmup_start_fraction + full = max(args.rate_full_fraction, start + 1e-6) + if frac <= start: + return 0.0 + if frac >= full: + return 1.0 + return (frac - start) / (full - start) + + # Warmup primes the compiled forward/backward/optimizer paths, then we restore the + # initial weights/optimizer state so measured training starts from the true init. + if effective_warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(effective_warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(effective_train_batch_tokens, args.train_seq_len, grad_accum_steps) + maybe_mark_compile_step_begin(args.enable_compile or args.compile_structured_mlp) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + for opt in optimizers: + opt.step() + zero_grad_all() + if ( + effective_warmup_steps <= 20 + or (warmup_step + 1) % 10 == 0 + or warmup_step + 1 == effective_warmup_steps + ): + log0(f"warmup_step:{warmup_step + 1}/{effective_warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + if distributed: + model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + # ----------------------------- + # 
MAIN TRAINING LOOP + # ----------------------------- + + training_time_ms = 0.0 + stop_after_step: int | None = None + torch.cuda.synchronize() + t0 = time.perf_counter() + + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + log0( + f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" + ) + break + + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + zero_grad_all() + train_loss = torch.zeros((), device=device) + train_ce_loss = torch.zeros((), device=device) + train_rate_proxy = torch.zeros((), device=device) + train_scale_proxy = torch.zeros((), device=device) + train_est_bits = torch.zeros((), device=device) + rate_mul = rate_schedule(step, elapsed_ms) + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(effective_train_batch_tokens, args.train_seq_len, grad_accum_steps) + maybe_mark_compile_step_begin(args.enable_compile or args.compile_structured_mlp) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + ce_loss = model(x, y) + total_loss = ce_loss + 
if rate_mul > 0.0: + rate_proxy, scale_proxy, est_bits = base_model.rate_regularizer() + total_loss = total_loss + rate_mul * (args.rate_lambda * rate_proxy + args.scale_lambda * scale_proxy) + train_rate_proxy += rate_proxy.detach() + train_scale_proxy += scale_proxy.detach() + train_est_bits += est_bits.detach() + train_ce_loss += ce_loss.detach() + train_loss += total_loss.detach() + (total_loss * grad_scale).backward() + train_loss /= grad_accum_steps + train_ce_loss /= grad_accum_steps + train_rate_proxy /= grad_accum_steps + train_scale_proxy /= grad_accum_steps + train_est_bits /= grad_accum_steps + + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + for opt in optimizers: + opt.step() + zero_grad_all() + + step += 1 + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + should_log_train = ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + log0( + f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} ce_loss:{train_ce_loss.item():.4f} " + f"rate_mul:{rate_mul:.3f} rate_proxy_bits:{train_rate_proxy.item():.4f} " + f"scale_proxy:{train_scale_proxy.item():.6f} est_model_bits:{train_est_bits.item():.0f} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" + ) + + # Needed to sync whether we've reached the wallclock cap. 
+ reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms
+ if distributed and max_wallclock_ms is not None:
+ reached_cap_tensor = torch.tensor(int(reached_cap), device=device)
+ dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX)
+ reached_cap = bool(reached_cap_tensor.item())
+ if stop_after_step is None and reached_cap:
+ stop_after_step = step
+
+ log0(
+ f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB "
+ f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB"
+ )
+
+ # -----------------------------
+ # SERIALIZATION + ROUNDTRIP VALIDATION
+ # -----------------------------
+ # Save the raw state (useful for debugging/loading in PyTorch directly), then always produce
+ # the compressed quantized artifact (QUANT_SCHEME + COMPRESSOR, e.g. int5_odd+zlib) and validate the round-tripped weights.
+
+ if master_process:
+ torch.save(base_model.state_dict(), "final_model.pt")
+ model_bytes = os.path.getsize("final_model.pt")
+ code_bytes = len(code.encode("utf-8"))
+ log0(f"Serialized model: {model_bytes} bytes")
+ log0(f"Code size: {code_bytes} bytes")
+ log0(f"Total submission size: {model_bytes + code_bytes} bytes")
+
+ quant_obj, quant_stats = quantize_state_dict(base_model.state_dict(), args.quant_scheme)
+ quant_buf = io.BytesIO()
+ torch.save(quant_obj, quant_buf)
+ quant_raw = quant_buf.getvalue()
+ quant_blob = compress_blob(quant_raw, args.compressor, args.compressor_level)
+ quant_raw_bytes = len(quant_raw)
+ quant_filename = f"final_model.{args.quant_scheme}.{args.compressor}.ptc"
+ if master_process:
+ with open(quant_filename, "wb") as f:
+ f.write(quant_blob)
+ quant_file_bytes = os.path.getsize(quant_filename)
+ code_bytes = len(code.encode("utf-8"))
+ ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1)
+ log0(
+ f"Serialized model {args.quant_scheme}+{args.compressor}: {quant_file_bytes} bytes "
+ f"(payload:{quant_stats['int8_payload_bytes']} raw_torch:{quant_raw_bytes}
payload_ratio:{ratio:.2f}x)" + ) + log0(f"Total submission size {args.quant_scheme}+{args.compressor}: {quant_file_bytes + code_bytes} bytes") + + if distributed: + dist.barrier() + with open(quant_filename, "rb") as f: + quant_blob_disk = f.read() + quant_state = torch.load(io.BytesIO(decompress_blob(quant_blob_disk, args.compressor)), map_location="cpu") + base_model.load_state_dict(dequantize_state_dict(quant_state), strict=True) + torch.cuda.synchronize() + t_qeval = time.perf_counter() + q_val_loss, q_val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"final_{args.quant_scheme}_{args.compressor}_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" + ) + log0(f"final_{args.quant_scheme}_{args.compressor}_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + + if distributed: + dist.destroy_process_group() + + +if __name__ == "__main__": + main() + +==================================================================================================== +Running Python 3.12.3 (main, Nov 6 2025, 13:44:16) [GCC 13.3.0] +Running PyTorch 2.9.1+cu128 +Tue Mar 31 18:34:45 2026 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 580.126.09 Driver Version: 580.126.09 CUDA Version: 13.0 | ++-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. 
| +|=========================================+========================+======================| +| 0 NVIDIA H100 80GB HBM3 On | 00000000:CB:00.0 Off | 0 | +| N/A 40C P0 95W / 700W | 1185MiB / 81559MiB | 3% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| No running processes found | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=/workspace/parameter-golf/data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:80 +val_loader:shards pattern=/workspace/parameter-golf/data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:1048576 +val_token_limit:1048576 +model_params:25727104 +world_size:1 grad_accum_steps:8 train_microbatch_tokens:2048 effective_train_batch_tokens:16384 +enable_compile:False +sdp_backends:cudnn=False flash=True mem_efficient=False math=False +attention_mode:gqa num_heads:8 num_kv_heads:4 +quant_scheme:int5_odd compressor:zlib level:9 structured_mlp:True structured_linear_type:btt_mlp btt_rank:256 btt_init:mup compile_structured_mlp:False structured_chunk_ranks:64 structured_eval_chunk_ranks:1 structured_compile_mode:reduce-overhead +rate_lambda:2e-05 scale_lambda:0.0002 rate_warmup_start_fraction:0.1 rate_full_fraction:0.4 rate_target_tensors:mlp_only rate_schedule_mode:auto max_expected_steps:0 +tie_embeddings:True embed_lr:0.03 head_lr:0.0 matrix_lr:0.02 scalar_lr:0.02 +train_batch_tokens_target:16384 effective_train_batch_tokens:16384 train_seq_len:1024 iterations:0 
warmup_steps:0 max_wallclock_seconds:0.000 +lr_scale_method:none lr_reference_batch_tokens:16384 batch_multiplier:1.0000 warmup_scale_method:none warmup_reference_batch_tokens:16384 +seed:1337 +step:0/0 val_loss:6.9379 val_bpb:4.1568 train_time:0ms step_avg:0.02ms +peak memory allocated: 20750 MiB reserved: 22836 MiB +Serialized model: 101939739 bytes +Code size: 81599 bytes +Total submission size: 102021338 bytes +Serialized model int5_odd+zlib: 7256271 bytes (payload:14157417 raw_torch:14260195 payload_ratio:11.71x) +Total submission size int5_odd+zlib: 7337870 bytes +final_int5_odd_zlib_roundtrip val_loss:6.9416 val_bpb:4.1590 eval_time:97090ms +final_int5_odd_zlib_roundtrip_exact val_loss:6.94156289 val_bpb:4.15901471 diff --git a/records/track_non_record_16mb/2026-03-31_EntropyAware_Int5Odd_BTTMLP_3060/logs/eval_bench_r256_l16_seq1024_vb262144.txt b/records/track_non_record_16mb/2026-03-31_EntropyAware_Int5Odd_BTTMLP_3060/logs/eval_bench_r256_l16_seq1024_vb262144.txt new file mode 100644 index 000000000..c27ef9846 --- /dev/null +++ b/records/track_non_record_16mb/2026-03-31_EntropyAware_Int5Odd_BTTMLP_3060/logs/eval_bench_r256_l16_seq1024_vb262144.txt @@ -0,0 +1,1862 @@ +logs/eval_bench_r256_l16_seq1024_vb262144.txt +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=/workspace/parameter-golf/data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:80 +val_loader:shards pattern=/workspace/parameter-golf/data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:1048576 +val_token_limit:1048576 +model_params:25727104 +world_size:1 grad_accum_steps:8 train_microbatch_tokens:2048 effective_train_batch_tokens:16384 +enable_compile:False +sdp_backends:cudnn=False flash=True mem_efficient=False math=False +attention_mode:gqa num_heads:8 num_kv_heads:4 +quant_scheme:int5_odd compressor:zlib level:9 structured_mlp:True structured_linear_type:btt_mlp btt_rank:256 btt_init:mup compile_structured_mlp:False 
structured_chunk_ranks:64 structured_eval_chunk_ranks:1 structured_compile_mode:reduce-overhead +rate_lambda:2e-05 scale_lambda:0.0002 rate_warmup_start_fraction:0.1 rate_full_fraction:0.4 rate_target_tensors:mlp_only rate_schedule_mode:auto max_expected_steps:0 +tie_embeddings:True embed_lr:0.03 head_lr:0.0 matrix_lr:0.02 scalar_lr:0.02 +train_batch_tokens_target:16384 effective_train_batch_tokens:16384 train_seq_len:1024 iterations:0 warmup_steps:0 max_wallclock_seconds:0.000 +lr_scale_method:none lr_reference_batch_tokens:16384 batch_multiplier:1.0000 warmup_scale_method:none warmup_reference_batch_tokens:16384 +seed:1337 +step:0/0 val_loss:6.9379 val_bpb:4.1568 train_time:0ms step_avg:0.02ms +peak memory allocated: 10481 MiB reserved: 11532 MiB +Serialized model: 101929819 bytes +Code size: 79702 bytes +Total submission size: 102009521 bytes +Serialized model int5_odd+zlib: 4207457 bytes (payload:8780523 raw_torch:8851579 payload_ratio:11.60x) +Total submission size int5_odd+zlib: 4287159 bytes +final_int5_odd_zlib_roundtrip val_loss:6.9416 val_bpb:4.1590 eval_time:102007ms +final_int5_odd_zlib_roundtrip_exact val_loss:6.94156384 val_bpb:4.15901528 +.environ.get("SEED", 1337)) + + # Validation cadence and batch size. Validation always uses the full fineweb_val split. + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 32_768)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 10)) + val_token_limit = int(os.environ.get("VAL_TOKEN_LIMIT", 1_048_576)) + + # Training length. 
+ iterations = int(os.environ.get("ITERATIONS", 40)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 1200)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 1)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 16_384)) + train_microbatch_tokens = int(os.environ.get("TRAIN_MICROBATCH_TOKENS", 0)) + grad_accum_steps = int(os.environ.get("GRAD_ACCUM_STEPS", 0)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 256)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + enable_compile = bool(int(os.environ.get("ENABLE_COMPILE", "0"))) + + # Model shape. + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers = int(os.environ.get("NUM_LAYERS", 9)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = int(os.environ.get("MLP_MULT", 2)) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + + # Optimizer hyperparameters. 
+ embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.03)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.02)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.02)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.0)) + lr_scale_method = os.environ.get("LR_SCALE_METHOD", "none") + lr_reference_batch_tokens = int(os.environ.get("LR_REFERENCE_BATCH_TOKENS", 16_384)) + warmup_scale_method = os.environ.get("WARMUP_SCALE_METHOD", "none") + warmup_reference_batch_tokens = int(os.environ.get("WARMUP_REFERENCE_BATCH_TOKENS", 16_384)) + simulated_world_size = int(os.environ.get("SIMULATED_WORLD_SIZE", 0)) + simulated_rank = int(os.environ.get("SIMULATED_RANK", 0)) + dry_run_init_only = bool(int(os.environ.get("DRY_RUN_INIT_ONLY", "0"))) + debug_data_sharding_steps = int(os.environ.get("DEBUG_DATA_SHARDING_STEPS", 0)) + + # Export / compression. + quant_scheme = os.environ.get("QUANT_SCHEME", "int5_odd") + compressor = os.environ.get("COMPRESSOR", "zlib") + compressor_level = int(os.environ.get("COMPRESSOR_LEVEL", 9)) + + # Entropy-aware rate regularization. 
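The `LR_SCALE_METHOD` / `LR_REFERENCE_BATCH_TOKENS` controls parsed above apply standard batch-size scaling rules. A minimal sketch mirroring the intended semantics (not the exact helper in this file), using the launcher's move from the `16384`-token local batch to the `131072`-token 8xH100 batch:

```python
import math

def scale_lr(base_lr: float, batch_tokens: int, reference_tokens: int, method: str) -> float:
    # Batch multiplier relative to the reference batch, applied per LR_SCALE_METHOD.
    multiplier = batch_tokens / reference_tokens
    if method == "none":
        return base_lr
    if method == "sqrt":
        return base_lr * math.sqrt(multiplier)
    if method == "linear":
        return base_lr * multiplier
    raise ValueError(f"Unsupported scaling method: {method}")

# Scaling MATRIX_LR from the 16384-token local batch to the 131072-token global batch (8x):
assert scale_lr(0.02, 131072, 16384, "none") == 0.02
assert abs(scale_lr(0.02, 131072, 16384, "sqrt") - 0.02 * math.sqrt(8.0)) < 1e-15
assert abs(scale_lr(0.02, 131072, 16384, "linear") - 0.16) < 1e-15
```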
+ rate_lambda = float(os.environ.get("RATE_LAMBDA", 0.00002)) + scale_lambda = float(os.environ.get("SCALE_LAMBDA", 0.0002)) + rate_temperature = float(os.environ.get("RATE_TEMPERATURE", 6.0)) + rate_warmup_start_fraction = float(os.environ.get("RATE_WARMUP_START_FRACTION", 0.1)) + rate_full_fraction = float(os.environ.get("RATE_FULL_FRACTION", 0.4)) + rate_target_tensors = os.environ.get("RATE_TARGET_TENSORS", "mlp_only") + rate_schedule_mode = os.environ.get("RATE_SCHEDULE_MODE", "auto") + max_expected_steps = int(os.environ.get("MAX_EXPECTED_STEPS", 0)) + + # Structured MLP. + structured_mlp = bool(int(os.environ.get("STRUCTURED_MLP", "1"))) + structured_linear_type = os.environ.get("STRUCTURED_LINEAR_TYPE", "btt_mlp") + btt_rank = int(os.environ.get("BTT_RANK", 32)) + btt_in_shape = os.environ.get("BTT_IN_SHAPE", "") + btt_out_shape = os.environ.get("BTT_OUT_SHAPE", "") + btt_init = os.environ.get("BTT_INIT", "mup") + btt_init_gain = float(os.environ.get("BTT_INIT_GAIN", 1.0)) + compile_structured_mlp = bool(int(os.environ.get("COMPILE_STRUCTURED_MLP", "1"))) + structured_chunk_ranks = int(os.environ.get("STRUCTURED_CHUNK_RANKS", 64)) + structured_eval_chunk_ranks = int(os.environ.get("STRUCTURED_EVAL_CHUNK_RANKS", 4)) + structured_compile_mode = os.environ.get("STRUCTURED_COMPILE_MODE", "reduce-overhead") + +# ----------------------------- +# MUON OPTIMIZER +# ----------------------------- +# +# As borrowed from modded-nanogpt +# Background on Muon: https://kellerjordan.github.io/posts/muon/ + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + # Orthogonalize a 2D update matrix with a fast Newton-Schulz iteration. + # Muon uses this to normalize matrix-shaped gradients before applying them. 
+ a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X + + +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True): + super().__init__( + params, + dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov), + ) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + + for group in self.param_groups: + params = group["params"] + if not params: + continue + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + + total_params = sum(int(p.numel()) for p in params) + updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) + + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: + g = g.add(buf, alpha=momentum) + g = zeropower_via_newtonschulz5(g, steps=backend_steps) + # Scale correction from Muon reference implementations. 
+                    g *= max(1, g.size(0) / g.size(1)) ** 0.5
+                    updates_flat[curr : curr + p.numel()] = g.reshape(-1)
+                curr += p.numel()
+
+            if distributed:
+                dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM)
+
+            curr = 0
+            for p in params:
+                g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype)
+                p.add_(g, alpha=-lr)
+                curr += p.numel()
+
+        return loss
+
+
+# -----------------------------
+# TOKENIZER-AGNOSTIC EVALUATION SETUP
+# -----------------------------
+#
+# It's common for small models to have a large fraction of their parameters in the embeddings, since the 2 * d_model * d_vocab embedding and unembedding parameters can be gigantic.
+# Instead of locking the tokenizer, we let you bring your own and calculate our validation metrics on the average compression of the validation set.
+# We calculate BPB (bits-per-byte) instead of validation loss, so we need lookup tables for the number of bytes each token decodes to.
+# Note: Submissions that edit the tokenizer will be examined more carefully, since screwing this up might unjustly improve your score.
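The BPB metric described above is just token cross-entropy converted to bits and rescaled by the tokenizer's tokens-per-byte ratio. A small sketch of the arithmetic, checked against the logged step-0 numbers:

```python
import math

def bits_per_byte(mean_nats_per_token: float, total_tokens: int, total_bytes: int) -> float:
    # nats/token -> bits/token, then bits/token * tokens/byte = bits/byte.
    bits_per_token = mean_nats_per_token / math.log(2.0)
    return bits_per_token * (total_tokens / total_bytes)

# A loss of ln(2) nats/token on a 1-token-per-byte tokenizer is exactly 1 bit per byte.
assert abs(bits_per_byte(math.log(2.0), 1000, 1000) - 1.0) < 1e-12
# The logged step-0 pair (val_loss 6.9379, val_bpb 4.1568) implies ~2.41 bytes per token.
assert abs((6.9379 / math.log(2.0)) / 4.1568 - 2.408) < 5e-3
```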
+ +def build_sentencepiece_luts( + sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device +) -> tuple[Tensor, Tensor, Tensor]: + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("▁"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) + + +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*. 
+ tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] + + +def eval_val( + args: Hyperparameters, + model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + grad_accum_steps: int, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, +) -> tuple[float, float]: + # Validation computes two metrics: + # - val_loss: token cross-entropy (natural log) + # - val_bpb: tokenizer-agnostic compression metric used by the challenge + local_batch_tokens = args.val_batch_size // world_size + if local_batch_tokens < args.train_seq_len: + raise ValueError( + "VAL_BATCH_SIZE must provide at least one sequence per rank; " + f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " + f"TRAIN_SEQ_LEN={args.train_seq_len}" + ) + local_batch_seqs = local_batch_tokens // args.train_seq_len + total_seqs = (val_tokens.numel() - 1) // args.train_seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + + model.eval() + with torch.inference_mode(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * args.train_seq_len + raw_end = batch_seq_end * args.train_seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, args.train_seq_len) + y = local[1:].reshape(-1, args.train_seq_len) + maybe_mark_compile_step_begin(args.enable_compile or 
args.compile_structured_mlp) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + batch_loss = model(x, y).detach() + batch_token_count = float(y.numel()) + val_loss_sum += batch_loss.to(torch.float64) * batch_token_count + val_token_count += batch_token_count + prev_ids = x.reshape(-1) + tgt_ids = y.reshape(-1) + token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + + val_loss = val_loss_sum / val_token_count + bits_per_token = val_loss.item() / math.log(2.0) + tokens_per_byte = val_token_count.item() / val_byte_count.item() + model.train() + return float(val_loss.item()), float(bits_per_token * tokens_per_byte) + +# ----------------------------- +# POST-TRAINING QUANTIZATION +# ----------------------------- +# +# It's silly to export our model, which is trained in bf16 and fp32, at that same precision. +# Instead, we get approximately the same model (with a small hit) by quantizing the model to int8 & zlib compressing. +# We can then decompress the model and run in higher precision for evaluation, after closing in under the size limit. 
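The `int5_odd` export leans on the fact that three base-5 codes fit in one byte (5^3 = 125 <= 256), i.e. about 2.67 bits per weight before zlib sees the payload. A self-contained sketch of the packing roundtrip, using the same digit layout as the `pack_base5_codes` / `unpack_base5_codes` pair defined later in this file:

```python
import numpy as np

def pack3(codes: np.ndarray) -> tuple[bytes, int]:
    # Three base-5 digits per byte: b = c0 + 5*c1 + 25*c2 (max 124 < 256).
    n = int(codes.size)
    pad = (3 - n % 3) % 3
    groups = np.pad(codes.astype(np.uint8), (0, pad)).reshape(-1, 3)
    return (groups[:, 0] + 5 * groups[:, 1] + 25 * groups[:, 2]).astype(np.uint8).tobytes(), n

def unpack3(data: bytes, n: int) -> np.ndarray:
    packed = np.frombuffer(data, dtype=np.uint8).astype(np.int16)
    digits = np.stack([packed % 5, (packed // 5) % 5, (packed // 25) % 5], axis=1)
    return digits.reshape(-1)[:n].astype(np.uint8)

codes = np.array([0, 4, 2, 1, 3, 0, 2], dtype=np.uint8)  # shifted levels {-2..2} -> {0..4}
blob, n = pack3(codes)
assert len(blob) == 3  # 7 codes round up to 3 bytes
assert np.array_equal(unpack3(blob, n), codes)
```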
+ +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights", + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 +INT5_ODD_LEVELS = (-2, -1, 0, 1, 2) + +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) + + +def maybe_mark_compile_step_begin(enabled: bool) -> None: + if not enabled: + return + mark_step = getattr(torch.compiler, "cudagraph_mark_step_begin", None) + if mark_step is not None: + mark_step() + + +def scale_from_batch_multiplier(value: float, multiplier: float, method: str) -> float: + if method == "none": + return value + if method == "sqrt": + return value * math.sqrt(multiplier) + if method == "linear": + return value * multiplier + raise ValueError(f"Unsupported scaling method: {method}") + + +def scaled_warmup_steps(base_steps: int, multiplier: float, method: str) -> int: + if base_steps <= 0: + return 0 + scaled = scale_from_batch_multiplier(float(base_steps), multiplier, method) + return max(int(math.ceil(scaled)), 1) + + +def parse_factor_shape(spec: str, dim: int) -> tuple[int, int]: + if spec: + vals = tuple(int(v.strip()) for v in spec.split(",") if v.strip()) + if len(vals) != 2 or vals[0] * vals[1] != dim: + raise ValueError(f"Invalid factor shape '{spec}' for dim={dim}") + return vals + for a in range(int(math.isqrt(dim)), 0, -1): + if dim % a == 0: + return a, dim // a + return 1, dim + +def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> 
Tensor: + if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): + return t.float().contiguous() + if t.dtype in {torch.float32, torch.bfloat16}: + passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") + return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() + return t + +def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + # Matrices get one scale per row, which usually tracks output-channel + # ranges much better than a single tensor-wide scale. + clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) + q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() + return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + + # Vectors / scalars use a simpler per-tensor scale. 
+ clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() + return q, scale + + +def pack_base5_codes(q: Tensor) -> tuple[bytes, int]: + vals = q.detach().to("cpu", dtype=torch.uint8).reshape(-1).numpy() + n_vals = int(vals.size) + pad = (3 - (n_vals % 3)) % 3 + if pad: + vals = np.pad(vals, (0, pad)) + groups = vals.reshape(-1, 3).astype(np.uint8) + packed = groups[:, 0] + 5 * groups[:, 1] + 25 * groups[:, 2] + return packed.tobytes(), n_vals + + +def unpack_base5_codes(data: bytes, n_vals: int) -> Tensor: + packed = np.frombuffer(data, dtype=np.uint8).astype(np.int16) + out = np.zeros((packed.size, 3), dtype=np.uint8) + for i in range(3): + out[:, i] = packed % 5 + packed //= 5 + flat = out.reshape(-1)[:n_vals] + return torch.from_numpy(flat.astype(np.uint8, copy=False)) + + +def quantize_float_tensor_int5_odd(t: Tensor) -> tuple[Tensor, Tensor]: + if t.ndim != 2: + raise ValueError("int5_odd quantization only supports 2D tensors") + t32 = t.float() + clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clip_abs = clip_abs.clamp_min(1e-6) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 2.0).clamp_min(1.0 / 2.0e6) + q = torch.clamp(torch.round(clipped / scale[:, None]), -2, 2).to(torch.int8).contiguous() + return (q + 2).to(torch.uint8).contiguous(), scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + +def quantize_state_dict_int8(state_dict: dict[str, Tensor]): + # Single supported clean-script export format: + # - per-row int8 for 2D float tensors + # - per-tensor int8 for other float tensors + # - exact passthrough for non-floats + # - passthrough for small float tensors, 
stored as fp16 to save bytes + quantized: dict[str, Tensor] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + qmeta: dict[str, dict[str, object]] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + + # Small float tensors are cheap enough to keep directly. We still downcast + # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size. + if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + + stats["num_float_tensors"] += 1 + q, s = quantize_float_tensor(t) + if s.ndim > 0: + qmeta[name] = {"scheme": "per_row", "axis": 0} + quantized[name] = q + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) + + obj: dict[str, object] = { + "__quant_format__": "int8_clean_per_row_v1", + "quantized": quantized, + "scales": scales, + "dtypes": dtypes, + "passthrough": passthrough, + } + if qmeta: + obj["qmeta"] = qmeta + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats + + +def quantize_state_dict_int5_odd(state_dict: dict[str, Tensor]): + packed: dict[str, bytes] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + shapes: dict[str, tuple[int, ...]] = {} + orig_shapes: dict[str, 
tuple[int, ...]] = {} + n_codes: dict[str, int] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + + quantize_this = t.ndim >= 2 and (t.numel() > INT8_KEEP_FLOAT_MAX_NUMEL or ".mlp." in name) + if not quantize_this: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + + stats["num_float_tensors"] += 1 + t2 = t.reshape(t.shape[0], -1) if t.ndim > 2 else t + q, s = quantize_float_tensor_int5_odd(t2) + packed_bytes, count = pack_base5_codes(q) + packed[name] = packed_bytes + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + shapes[name] = tuple(int(v) for v in t2.shape) + orig_shapes[name] = tuple(int(v) for v in t.shape) + n_codes[name] = count + stats["int8_payload_bytes"] += len(packed_bytes) + tensor_nbytes(s) + + obj: dict[str, object] = { + "__quant_format__": "int5_odd_clean_per_row_v1", + "packed": packed, + "scales": scales, + "dtypes": dtypes, + "shapes": shapes, + "orig_shapes": orig_shapes, + "n_codes": n_codes, + "passthrough": passthrough, + } + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats + +def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + qmeta = obj.get("qmeta", {}) + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, q in 
obj["quantized"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + s = obj["scales"][name] + if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: + s = s.to(dtype=torch.float32) + # Broadcast the saved row scale back across trailing dimensions. + out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() + else: + scale = float(s.item()) + out[name] = (q.float() * scale).to(dtype=dtype).contiguous() + for name, t in obj["passthrough"].items(): + # Restore small tensors, undoing the temporary fp16 storage cast if needed. + out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out + + +def dequantize_state_dict_int5_odd(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, packed in obj["packed"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + shape = tuple(int(v) for v in obj["shapes"][name]) + orig_shape = tuple(int(v) for v in obj.get("orig_shapes", {}).get(name, shape)) + q = unpack_base5_codes(packed, int(obj["n_codes"][name])).to(torch.float32).reshape(shape) - 2.0 + s = obj["scales"][name].to(dtype=torch.float32) + deq = (q * s.view(shape[0], *([1] * (len(shape) - 1)))).to(dtype=dtype) + out[name] = deq.reshape(orig_shape).contiguous() + for name, t in obj["passthrough"].items(): + out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out + + +def quantize_state_dict(state_dict: dict[str, Tensor], quant_scheme: str): + if quant_scheme == "int8": + return quantize_state_dict_int8(state_dict) + if quant_scheme == "int5_odd": + return 
quantize_state_dict_int5_odd(state_dict) + raise ValueError(f"Unsupported QUANT_SCHEME={quant_scheme}") + + +def dequantize_state_dict(obj: dict[str, object]) -> dict[str, Tensor]: + quant_format = obj.get("__quant_format__") + if quant_format == "int8_clean_per_row_v1": + return dequantize_state_dict_int8(obj) + if quant_format == "int5_odd_clean_per_row_v1": + return dequantize_state_dict_int5_odd(obj) + raise ValueError(f"Unsupported quant format {quant_format}") + + +def compress_blob(data: bytes, compressor: str, level: int) -> bytes: + if compressor == "zlib": + return zlib.compress(data, level=max(0, min(level, 9))) + if compressor == "zstd": + if not _HAS_ZSTD: + raise RuntimeError("COMPRESSOR=zstd requires zstandard to be installed") + return zstd.ZstdCompressor(level=level).compress(data) + raise ValueError(f"Unsupported COMPRESSOR={compressor}") + + +def decompress_blob(data: bytes, compressor: str) -> bytes: + if compressor == "zlib": + return zlib.decompress(data) + if compressor == "zstd": + if not _HAS_ZSTD: + raise RuntimeError("COMPRESSOR=zstd requires zstandard to be installed") + return zstd.ZstdDecompressor().decompress(data) + raise ValueError(f"Unsupported COMPRESSOR={compressor}") + + +def should_rate_regularize_tensor(name: str, p: Tensor, target: str) -> bool: + if not p.requires_grad or not p.is_floating_point() or p.ndim < 2: + return False + if any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS): + return False + if target == "mlp_only": + return ".mlp." in name + if target == "all_matrices": + return p.numel() > INT8_KEEP_FLOAT_MAX_NUMEL or ".mlp." 
in name + raise ValueError(f"Unsupported RATE_TARGET_TENSORS={target}") + + + def estimate_tensor_rate_proxy(t: Tensor, temperature: float) -> tuple[Tensor, Tensor, Tensor]: + t2 = t.float() + if t2.ndim > 2: + t2 = t2.reshape(t2.shape[0], -1) + elif t2.ndim == 1: + t2 = t2.reshape(1, -1) + row_scale = t2.abs().mean(dim=1, keepdim=True).clamp_min(1e-6) + normalized = torch.clamp(t2 / row_scale, -2.5, 2.5) + centers = t2.new_tensor(INT5_ODD_LEVELS) + logits = -temperature * (normalized.unsqueeze(-1) - centers) ** 2 + probs = torch.softmax(logits, dim=-1) + hist = probs.mean(dim=(0, 1)) + entropy_nats = -(hist * (hist + 1e-8).log()).sum() + entropy_bits = entropy_nats / math.log(2.0) + scale_proxy = row_scale.mean() + est_bits = entropy_bits * float(t2.numel()) + return entropy_bits, scale_proxy, est_bits + + + # ----------------------------- + # DATA LOADING + # ----------------------------- + + def load_data_shard(file: Path) -> Tensor: + # Shard layout (fineweb .bin convention): 256 little-endian int32 header words, with word 2 holding the token count, followed by uint16 token ids. + header_bytes = 256 * np.dtype("<i4").itemsize + with file.open("rb") as f: + header = np.frombuffer(f.read(header_bytes), dtype="<i4") + num_tokens = int(header[2]) + payload = np.frombuffer(f.read(num_tokens * 2), dtype=np.uint16) + if payload.size != num_tokens: + raise ValueError(f"Token shard {file} is truncated") + return torch.from_numpy(payload.astype(np.int32)) + + + class TokenStream: + # Flat token stream over the sorted shard files matching `pattern`; wraps back to the first shard when exhausted. + def __init__(self, pattern: str): + base = Path(pattern) + self.files = sorted(base.parent.glob(base.name)) + if not self.files: + raise FileNotFoundError(f"No token shards match pattern {pattern}") + self.file_idx = 0 + self.tokens = load_data_shard(self.files[0]) + self.pos = 0 + + def _advance_file(self) -> None: + self.file_idx = (self.file_idx + 1) % len(self.files) + self.tokens = load_data_shard(self.files[self.file_idx]) + self.pos = 0 + + def take(self, n: int) -> Tensor: + chunks: list[Tensor] = [] + remaining = n + while remaining > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: + self._advance_file() + continue + k = min(remaining, avail) + chunks.append(self.tokens[self.pos : self.pos + k]) + self.pos += k + remaining -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) + + + class DistributedTokenLoader: + # Each call consumes a contiguous chunk from the shared token stream, then slices out + # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting.
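+ # Worked example with hypothetical numbers: global_tokens=131072, world_size=8, and + # grad_accum_steps=2 give local_tokens = 131072 // (8 * 2) = 8192 and per_rank_span = 8193; + # each call then consumes 8193 * 8 = 65544 tokens from the shared stream, and rank r + # reads the half-open span [r * 8193, (r + 1) * 8193) of that chunk.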
+ def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank = rank + self.world_size = world_size + self.device = device + self.stream = TokenStream(pattern) + + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + chunk = self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start : start + per_rank_span].to(dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + + +def debug_log_data_sharding( + pattern: str, + world_size: int, + global_tokens: int, + seq_len: int, + grad_accum_steps: int, + steps: int, + log0, +) -> None: + stream = TokenStream(pattern) + local_tokens = global_tokens // (world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + global_cursor = 0 + log0( + f"debug_sharding:world_size:{world_size} global_tokens:{global_tokens} " + f"grad_accum_steps:{grad_accum_steps} local_tokens_per_rank:{local_tokens} per_rank_span:{per_rank_span}" + ) + for step_idx in range(steps): + chunk = stream.take(per_rank_span * world_size) + chunk_start = global_cursor + chunk_end = global_cursor + chunk.numel() + log0(f"debug_sharding_step:{step_idx} shared_chunk_token_range:[{chunk_start},{chunk_end})") + for shard_rank in range(world_size): + start = shard_rank * per_rank_span + end = start + per_rank_span + local = chunk[start:end] + preview = ",".join(str(int(t)) for t in local[: min(6, local.numel())].tolist()) + log0( + f"debug_shard rank:{shard_rank} token_range:[{chunk_start + start},{chunk_start + end}) " + f"preview:{preview}" + ) + global_cursor = chunk_end + +# ----------------------------- +# TRANSFORMER MODULES +# ----------------------------- + +class RMSNorm(nn.Module): + def 
__init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + + +class CastedLinear(nn.Linear): + # Keep weights in fp32 for optimizer/state quality, cast at matmul time for bf16 compute. + def forward(self, x: Tensor) -> Tensor: + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, self.weight.to(x.dtype), bias) + + +class StructuredLinear(nn.Module): + # Lightweight TT-style structured matrix used only in the MLP stack. + def __init__( + self, + in_features: int, + out_features: int, + bias: bool, + linear_type: str, + btt_rank: int, + in_shape_spec: str, + out_shape_spec: str, + btt_init: str, + btt_init_gain: float, + chunk_ranks: int, + eval_chunk_ranks: int, + ): + super().__init__() + self.in_features = in_features + self.out_features = out_features + self.linear_type = linear_type + self.rank = max(int(btt_rank), 1) + self.chunk_ranks = max(int(chunk_ranks), 1) + self.eval_chunk_ranks = max(int(eval_chunk_ranks), 1) + self.btt_init = btt_init + self.btt_init_gain = btt_init_gain + if linear_type == "dense": + self.impl = CastedLinear(in_features, out_features, bias=bias) + self.register_parameter("core1", None) + self.register_parameter("core2", None) + return + if linear_type != "btt_mlp": + raise ValueError(f"Unsupported STRUCTURED_LINEAR_TYPE={linear_type}") + self.impl = None + self.in_shape = parse_factor_shape(in_shape_spec, in_features) + self.out_shape = parse_factor_shape(out_shape_spec, out_features) + n1, n2 = self.in_shape + m1, m2 = self.out_shape + self.core1 = nn.Parameter(torch.empty(n1, m1, self.rank)) + self.core2 = nn.Parameter(torch.empty(self.rank, n2, m2)) + if bias: + self.bias = nn.Parameter(torch.zeros(out_features)) + else: + self.register_parameter("bias", None) + self.reset_parameters() + + def reset_parameters(self) -> None: + if self.impl is not None: + self.impl.reset_parameters() 
+ return + n1, n2 = self.in_shape + if self.btt_init == "mup": + # muP-inspired scaling: keep the per-rank branch update stable while + # preserving O(1) output variance under the sum over TT ranks. + core1_std = self.btt_init_gain / math.sqrt(max(n1, 1)) + core2_std = self.btt_init_gain / math.sqrt(max(self.rank * n2, 1)) + elif self.btt_init == "xavier": + std = self.btt_init_gain / math.sqrt(max(self.in_features, 1)) + core1_std = std / math.sqrt(self.rank) + core2_std = std + else: + raise ValueError(f"Unsupported BTT_INIT={self.btt_init}") + nn.init.normal_(self.core1, mean=0.0, std=core1_std) + nn.init.normal_(self.core2, mean=0.0, std=core2_std) + if self.bias is not None: + nn.init.zeros_(self.bias) + + def forward(self, x: Tensor) -> Tensor: + if self.impl is not None: + return self.impl(x) + x_dtype = x.dtype + orig_shape = x.shape[:-1] + x2 = x.reshape(-1, self.in_shape[0], self.in_shape[1]) + x2_t = x2.transpose(1, 2).contiguous() + core1 = self.core1.to(dtype=x_dtype).permute(2, 0, 1).contiguous() + core2 = self.core2.to(dtype=x_dtype).contiguous() + y = x2.new_zeros((x2.shape[0], self.out_shape[0], self.out_shape[1])) + if torch.is_inference_mode_enabled() and self.eval_chunk_ranks > 1: + for start in range(0, self.rank, self.eval_chunk_ranks): + end = min(start + self.eval_chunk_ranks, self.rank) + left = torch.matmul(x2_t.unsqueeze(1), core1[start:end].unsqueeze(0)) + y = y + torch.matmul(left.transpose(-1, -2), core2[start:end].unsqueeze(0)).sum(dim=1) + y = y.reshape(*orig_shape, self.out_features) + if self.bias is not None: + y = y + self.bias.to(dtype=y.dtype) + return y + for start in range(0, self.rank, self.chunk_ranks): + end = min(start + self.chunk_ranks, self.rank) + for offset, core2_r in enumerate(core2[start:end], start=start): + left = torch.matmul(x2_t, core1[offset]) + y = y + torch.matmul(left.transpose(1, 2), core2_r) + y = y.reshape(*orig_shape, self.out_features) + if self.bias is not None: + y = y + self.bias.to(dtype=y.dtype) 
+ return y + + +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + # Keep small/control parameters in fp32 even when the model body runs in bf16. + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() + + +class Rotary(nn.Module): + # Caches cos/sin tables per sequence length on the current device. + def __init__(self, dim: int, base: float = 10000.0): + super().__init__() + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached != seq_len + or self._cos_cached.device != device + ): + t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) + freqs = torch.outer(t, self.inv_freq.to(device)) + self._cos_cached = freqs.cos()[None, None, :, :] + self._sin_cached = freqs.sin()[None, None, :, :] + self._seq_len_cached = seq_len + cos = self._cos_cached.to(dtype=dtype) + sin = self._sin_cached.to(dtype=dtype) + if not torch.is_inference_mode_enabled(): + cos = cos.clone() + sin = sin.clone() + return cos, sin + + +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + + +class CausalSelfAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + 
if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + kv_dim = self.num_kv_heads * self.head_dim + self.c_q = CastedLinear(dim, dim, bias=False) + self.c_k = CastedLinear(dim, kv_dim, bias=False) + self.c_v = CastedLinear(dim, kv_dim, bias=False) + self.proj = CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rotary = Rotary(self.head_dim, base=rope_base) + + def forward(self, x: Tensor) -> Tensor: + bsz, seqlen, dim = x.shape + q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) + k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin) + k = apply_rotary_emb(k, cos, sin) + q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] + y = F.scaled_dot_product_attention( + q, + k, + v, + attn_mask=None, + is_causal=True, + enable_gqa=(self.num_kv_heads != self.num_heads), + ) + y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) + return self.proj(y) + + +class MLP(nn.Module): + # relu^2 MLP from the original modded-nanogpt setup + def __init__( + self, + dim: int, + mlp_mult: int, + structured_mlp: bool, + structured_linear_type: str, + btt_rank: int, + btt_in_shape: str, + btt_out_shape: str, + btt_init: str, + btt_init_gain: float, + compile_structured_mlp: bool, + structured_chunk_ranks: int, + structured_eval_chunk_ranks: int, + structured_compile_mode: str, + ): + super().__init__() + hidden = mlp_mult * dim + if 
structured_mlp: + self.fc = StructuredLinear( + dim, + hidden, + bias=False, + linear_type=structured_linear_type, + btt_rank=btt_rank, + in_shape_spec=btt_in_shape, + out_shape_spec=btt_out_shape, + btt_init=btt_init, + btt_init_gain=btt_init_gain, + chunk_ranks=structured_chunk_ranks, + eval_chunk_ranks=structured_eval_chunk_ranks, + ) + self.proj = StructuredLinear( + hidden, + dim, + bias=False, + linear_type=structured_linear_type, + btt_rank=btt_rank, + in_shape_spec=btt_out_shape, + out_shape_spec=btt_in_shape, + btt_init=btt_init, + btt_init_gain=btt_init_gain, + chunk_ranks=structured_chunk_ranks, + eval_chunk_ranks=structured_eval_chunk_ranks, + ) + if compile_structured_mlp: + compile_kwargs = dict(fullgraph=False, dynamic=False) + if structured_compile_mode == "reduce-overhead": + compile_options = dict(torch._inductor.list_mode_options("reduce-overhead")) + compile_options["triton.cudagraphs"] = False + compile_kwargs["options"] = compile_options + else: + compile_kwargs["mode"] = structured_compile_mode + self.fc = torch.compile( + self.fc, + **compile_kwargs, + ) + self.proj = torch.compile( + self.proj, + **compile_kwargs, + ) + else: + self.fc = CastedLinear(dim, hidden, bias=False) + self.proj = CastedLinear(hidden, dim, bias=False) + self.proj._zero_init = True + + def forward(self, x: Tensor) -> Tensor: + x = torch.relu(self.fc(x)) + return self.proj(x.square()) + + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + rope_base: float, + qk_gain_init: float, + structured_mlp: bool, + structured_linear_type: str, + btt_rank: int, + btt_in_shape: str, + btt_out_shape: str, + btt_init: str, + btt_init_gain: float, + compile_structured_mlp: bool, + structured_chunk_ranks: int, + structured_eval_chunk_ranks: int, + structured_compile_mode: str, + ): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention(dim, num_heads, 
num_kv_heads, rope_base, qk_gain_init) + self.mlp = MLP( + dim, + mlp_mult, + structured_mlp, + structured_linear_type, + btt_rank, + btt_in_shape, + btt_out_shape, + btt_init, + btt_init_gain, + compile_structured_mlp, + structured_chunk_ranks, + structured_eval_chunk_ranks, + structured_compile_mode, + ) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + + def forward(self, x: Tensor, x0: Tensor) -> Tensor: + mix = self.resid_mix.to(dtype=x.dtype) + x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + attn_out = self.attn(self.attn_norm(x)) + x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out + x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) + return x + + +class GPT(nn.Module): + def __init__( + self, + vocab_size: int, + num_layers: int, + model_dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + tie_embeddings: bool, + tied_embed_init_std: float, + logit_softcap: float, + rope_base: float, + qk_gain_init: float, + structured_mlp: bool, + structured_linear_type: str, + btt_rank: int, + btt_in_shape: str, + btt_out_shape: str, + btt_init: str, + btt_init_gain: float, + compile_structured_mlp: bool, + structured_chunk_ranks: int, + structured_eval_chunk_ranks: int, + structured_compile_mode: str, + rate_target_tensors: str, + rate_temperature: float, + ): + super().__init__() + if logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") + self.tie_embeddings = tie_embeddings + self.tied_embed_init_std = tied_embed_init_std + self.logit_softcap = logit_softcap + self.rate_target_tensors = rate_target_tensors + self.rate_temperature = rate_temperature + self.tok_emb = nn.Embedding(vocab_size, model_dim) + self.num_encoder_layers = num_layers // 2 + 
self.num_decoder_layers = num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + self.blocks = nn.ModuleList( + [ + Block( + model_dim, + num_heads, + num_kv_heads, + mlp_mult, + rope_base, + qk_gain_init, + structured_mlp, + structured_linear_type, + btt_rank, + btt_in_shape, + btt_out_shape, + btt_init, + btt_init_gain, + compile_structured_mlp, + structured_chunk_ranks, + structured_eval_chunk_ranks, + structured_compile_mode, + ) + for i in range(num_layers) + ] + ) + self.final_norm = RMSNorm() + self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) + if self.lm_head is not None: + self.lm_head._zero_init = True + self._init_weights() + + def _init_weights(self) -> None: + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + for module in self.modules(): + if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + if isinstance(module, StructuredLinear) and getattr(module, "_zero_init", False): + if module.impl is not None: + nn.init.zeros_(module.impl.weight) + else: + nn.init.zeros_(module.core2) + + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + x = self.tok_emb(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x0 = x + skips: list[Tensor] = [] + + # First half stores skips; second half reuses them in reverse order. 
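+ # Example: with num_layers=16 there are 8 encoder and 8 decoder blocks; decoder step i + # adds skip_weights[i] * skips.pop(), pairing block 8 with block 7's output, block 9 + # with block 6's output, and so on down to block 15 with block 0's output.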
+ for i in range(self.num_encoder_layers): + x = self.blocks[i](x, x0) + skips.append(x) + for i in range(self.num_decoder_layers): + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + x = self.blocks[self.num_encoder_layers + i](x, x0) + + x = self.final_norm(x).reshape(-1, x.size(-1)) + targets = target_ids.reshape(-1) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(x) + logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + return F.cross_entropy(logits.float(), targets, reduction="mean") + + def rate_regularizer(self) -> tuple[Tensor, Tensor, Tensor]: + total_symbols = 0.0 + weighted_rate = None + weighted_scale = None + estimated_bits = None + for name, param in self.named_parameters(): + if not should_rate_regularize_tensor(name, param, self.rate_target_tensors): + continue + rate_proxy, scale_proxy, tensor_bits = estimate_tensor_rate_proxy(param, self.rate_temperature) + weight = float(param.numel()) + total_symbols += weight + if weighted_rate is None: + weighted_rate = rate_proxy * weight + weighted_scale = scale_proxy * weight + estimated_bits = tensor_bits + else: + weighted_rate = weighted_rate + rate_proxy * weight + weighted_scale = weighted_scale + scale_proxy * weight + estimated_bits = estimated_bits + tensor_bits + if weighted_rate is None or weighted_scale is None or estimated_bits is None: + zero = self.tok_emb.weight.new_zeros(()) + return zero, zero, zero + denom = max(total_symbols, 1.0) + return weighted_rate / denom, weighted_scale / denom, estimated_bits + + +# ----------------------------- +# TRAINING +# ----------------------------- + +def main() -> None: + global zeropower_via_newtonschulz5 + + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + if args.enable_compile: + 
zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) + + # ----------------------------- + # DISTRIBUTED + CUDA SETUP + # ----------------------------- + + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + simulated_world_size = args.simulated_world_size if not distributed and args.simulated_world_size > 0 else 0 + if simulated_world_size > 0: + world_size = simulated_world_size + rank = args.simulated_rank + local_rank = 0 + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if rank < 0 or rank >= world_size: + raise ValueError(f"RANK={rank} must satisfy 0 <= RANK < WORLD_SIZE={world_size}") + if simulated_world_size > 1 and not args.dry_run_init_only: + raise ValueError("SIMULATED_WORLD_SIZE>1 is only supported together with DRY_RUN_INIT_ONLY=1") + if args.train_batch_tokens <= 0: + raise ValueError(f"TRAIN_BATCH_TOKENS must be positive, got {args.train_batch_tokens}") + train_microbatch_tokens = args.train_microbatch_tokens if args.train_microbatch_tokens > 0 else max( + args.train_seq_len, args.train_batch_tokens // 8 + ) + if train_microbatch_tokens <= 0: + raise ValueError(f"TRAIN_MICROBATCH_TOKENS must be positive, got {train_microbatch_tokens}") + if train_microbatch_tokens % args.train_seq_len != 0: + raise ValueError( + f"TRAIN_MICROBATCH_TOKENS={train_microbatch_tokens} must be divisible by TRAIN_SEQ_LEN={args.train_seq_len}" + ) + tokens_per_microstep = world_size * train_microbatch_tokens + if args.grad_accum_steps > 0: + grad_accum_steps = args.grad_accum_steps + effective_train_batch_tokens = tokens_per_microstep * grad_accum_steps + else: + if args.train_batch_tokens % tokens_per_microstep != 0: + raise ValueError( + "TRAIN_BATCH_TOKENS must be divisible by WORLD_SIZE * TRAIN_MICROBATCH_TOKENS " + f"unless GRAD_ACCUM_STEPS is 
set explicitly; got TRAIN_BATCH_TOKENS={args.train_batch_tokens}, " + f"WORLD_SIZE={world_size}, TRAIN_MICROBATCH_TOKENS={train_microbatch_tokens}" + ) + grad_accum_steps = args.train_batch_tokens // tokens_per_microstep + effective_train_batch_tokens = args.train_batch_tokens + if grad_accum_steps <= 0: + raise ValueError(f"GRAD_ACCUM_STEPS must be positive, got {grad_accum_steps}") + grad_scale = 1.0 / grad_accum_steps + batch_multiplier = effective_train_batch_tokens / max(args.lr_reference_batch_tokens, 1) + effective_warmup_steps = scaled_warmup_steps( + args.warmup_steps, + effective_train_batch_tokens / max(args.warmup_reference_batch_tokens, 1), + args.warmup_scale_method, + ) + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 0 + + # Fast math knobs + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + + logfile = None + if master_process: + if args.run_id == "train": + logfile = "train.log" + else: + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" + print(logfile) + + def log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch {torch.__version__}", console=False) + log0( + subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, 
check=False).stdout, + console=False, + ) + log0("=" * 100, console=False) + + # ----------------------------- + # TOKENIZER + VALIDATION METRIC SETUP + # ----------------------------- + + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) + if args.val_token_limit > 0: + usable = min(((args.val_token_limit) // args.train_seq_len) * args.train_seq_len, val_tokens.numel() - 1) + if usable <= 0: + raise ValueError(f"VAL_TOKEN_LIMIT={args.val_token_limit} is too small for TRAIN_SEQ_LEN={args.train_seq_len}") + val_tokens = val_tokens[: usable + 1].contiguous() + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + if args.val_token_limit > 0: + log0(f"val_token_limit:{args.val_token_limit}") + + # ----------------------------- + # MODEL + OPTIMIZER SETUP + # ----------------------------- + + base_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + 
tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, + structured_mlp=args.structured_mlp, + structured_linear_type=args.structured_linear_type if args.structured_mlp else "dense", + btt_rank=args.btt_rank, + btt_in_shape=args.btt_in_shape, + btt_out_shape=args.btt_out_shape, + btt_init=args.btt_init, + btt_init_gain=args.btt_init_gain, + compile_structured_mlp=args.compile_structured_mlp, + structured_chunk_ranks=args.structured_chunk_ranks, + structured_eval_chunk_ranks=args.structured_eval_chunk_ranks, + structured_compile_mode=args.structured_compile_mode, + rate_target_tensors=args.rate_target_tensors, + rate_temperature=args.rate_temperature, + ).to(device).bfloat16() + for module in base_model.modules(): + if isinstance(module, (CastedLinear, StructuredLinear)): + module.float() + restore_low_dim_params_to_fp32(base_model) + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) if args.enable_compile else base_model + model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + + # Optimizer split: + # - token embedding (Adam) uses EMBED_LR + # - untied lm_head (Adam) uses HEAD_LR + # - matrix params in transformer blocks use MATRIX_LR via Muon + # - vectors/scalars use SCALAR_LR via Adam + block_named_params = list(base_model.blocks.named_parameters()) + matrix_params = [ + p + for name, p in block_named_params + if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + scalar_params = [ + p + for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + scaled_token_lr = 
scale_from_batch_multiplier(token_lr, batch_multiplier, args.lr_scale_method) + scaled_head_lr = scale_from_batch_multiplier(args.head_lr, batch_multiplier, args.lr_scale_method) + scaled_matrix_lr = scale_from_batch_multiplier(args.matrix_lr, batch_multiplier, args.lr_scale_method) + scaled_scalar_lr = scale_from_batch_multiplier(args.scalar_lr, batch_multiplier, args.lr_scale_method) + optimizer_tok = torch.optim.Adam( + [{"params": [base_model.tok_emb.weight], "lr": scaled_token_lr, "base_lr": scaled_token_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizer_muon = Muon( + matrix_params, + lr=scaled_matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + ) + for group in optimizer_muon.param_groups: + group["base_lr"] = scaled_matrix_lr + optimizer_scalar = torch.optim.Adam( + [{"params": scalar_params, "lr": scaled_scalar_lr, "base_lr": scaled_scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": scaled_head_lr, "base_lr": scaled_head_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers.insert(1, optimizer_head) + + n_params = sum(p.numel() for p in base_model.parameters()) + log0(f"model_params:{n_params}") + log0( + f"world_size:{world_size} grad_accum_steps:{grad_accum_steps} " + f"train_microbatch_tokens:{train_microbatch_tokens} effective_train_batch_tokens:{effective_train_batch_tokens}" + ) + if simulated_world_size > 0: + log0(f"simulated_world_size:{simulated_world_size} simulated_rank:{rank}") + log0(f"enable_compile:{args.enable_compile}") + log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") + log0(f"attention_mode:gqa num_heads:{args.num_heads} 
num_kv_heads:{args.num_kv_heads}") + log0( + f"quant_scheme:{args.quant_scheme} compressor:{args.compressor} level:{args.compressor_level} " + f"structured_mlp:{args.structured_mlp} structured_linear_type:{args.structured_linear_type} " + f"btt_rank:{args.btt_rank} btt_init:{args.btt_init} compile_structured_mlp:{args.compile_structured_mlp} " + f"structured_chunk_ranks:{args.structured_chunk_ranks} structured_eval_chunk_ranks:{args.structured_eval_chunk_ranks} " + f"structured_compile_mode:{args.structured_compile_mode}" + ) + log0( + f"rate_lambda:{args.rate_lambda} scale_lambda:{args.scale_lambda} " + f"rate_warmup_start_fraction:{args.rate_warmup_start_fraction} rate_full_fraction:{args.rate_full_fraction} " + f"rate_target_tensors:{args.rate_target_tensors} rate_schedule_mode:{args.rate_schedule_mode} " + f"max_expected_steps:{args.max_expected_steps}" + ) + log0( + f"tie_embeddings:{args.tie_embeddings} embed_lr:{scaled_token_lr} " + f"head_lr:{scaled_head_lr if base_model.lm_head is not None else 0.0} " + f"matrix_lr:{scaled_matrix_lr} scalar_lr:{scaled_scalar_lr}" + ) + log0( + f"train_batch_tokens_target:{args.train_batch_tokens} effective_train_batch_tokens:{effective_train_batch_tokens} " + f"train_seq_len:{args.train_seq_len} iterations:{args.iterations} warmup_steps:{effective_warmup_steps} " + f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log0( + f"lr_scale_method:{args.lr_scale_method} lr_reference_batch_tokens:{args.lr_reference_batch_tokens} " + f"batch_multiplier:{batch_multiplier:.4f} warmup_scale_method:{args.warmup_scale_method} " + f"warmup_reference_batch_tokens:{args.warmup_reference_batch_tokens}" + ) + log0(f"seed:{args.seed}") + + if args.debug_data_sharding_steps > 0: + debug_log_data_sharding( + args.train_files, + world_size, + effective_train_batch_tokens, + args.train_seq_len, + grad_accum_steps, + args.debug_data_sharding_steps, + log0, + ) + if args.dry_run_init_only: + log0("dry_run_init_only: exiting after 
init/logging") + return + + # ----------------------------- + # DATA LOADER & MODEL WARMUP + # ----------------------------- + + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + + def lr_mul(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + + def rate_progress(step: int, elapsed_ms: float) -> float: + if args.rate_schedule_mode == "steps": + if args.iterations <= 0: + return 1.0 + return step / max(args.iterations, 1) + if args.rate_schedule_mode == "expected_steps": + if args.max_expected_steps <= 0: + raise ValueError("MAX_EXPECTED_STEPS must be positive when RATE_SCHEDULE_MODE=expected_steps") + return step / max(args.max_expected_steps, 1) + if args.rate_schedule_mode == "wallclock": + if max_wallclock_ms is None: + raise ValueError("RATE_SCHEDULE_MODE=wallclock requires MAX_WALLCLOCK_SECONDS > 0") + return elapsed_ms / max(max_wallclock_ms, 1e-9) + if args.rate_schedule_mode != "auto": + raise ValueError(f"Unsupported RATE_SCHEDULE_MODE={args.rate_schedule_mode}") + if args.max_expected_steps > 0: + return step / max(args.max_expected_steps, 1) + if max_wallclock_ms is not None and step > 0: + step_ms = elapsed_ms / max(step, 1) + expected_steps = max(int(max_wallclock_ms / max(step_ms, 1e-9)), 1) + return step / expected_steps + if args.iterations <= 0: + 
return 1.0 + return step / max(args.iterations, 1) + + def rate_schedule(step: int, elapsed_ms: float) -> float: + if args.rate_lambda <= 0.0 and args.scale_lambda <= 0.0: + return 0.0 + frac = rate_progress(step, elapsed_ms) + start = args.rate_warmup_start_fraction + full = max(args.rate_full_fraction, start + 1e-6) + if frac <= start: + return 0.0 + if frac >= full: + return 1.0 + return (frac - start) / (full - start) + + # Warmup primes the compiled forward/backward/optimizer paths, then we restore the + # initial weights/optimizer state so measured training starts from the true init. + if effective_warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(effective_warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(effective_train_batch_tokens, args.train_seq_len, grad_accum_steps) + maybe_mark_compile_step_begin(args.enable_compile or args.compile_structured_mlp) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + for opt in optimizers: + opt.step() + zero_grad_all() + if ( + effective_warmup_steps <= 20 + or (warmup_step + 1) % 10 == 0 + or warmup_step + 1 == effective_warmup_steps + ): + log0(f"warmup_step:{warmup_step + 1}/{effective_warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + if distributed: + model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + # ----------------------------- + # 
MAIN TRAINING LOOP + # ----------------------------- + + training_time_ms = 0.0 + stop_after_step: int | None = None + torch.cuda.synchronize() + t0 = time.perf_counter() + + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + log0( + f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" + ) + break + + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + zero_grad_all() + train_loss = torch.zeros((), device=device) + train_ce_loss = torch.zeros((), device=device) + train_rate_proxy = torch.zeros((), device=device) + train_scale_proxy = torch.zeros((), device=device) + train_est_bits = torch.zeros((), device=device) + rate_mul = rate_schedule(step, elapsed_ms) + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(effective_train_batch_tokens, args.train_seq_len, grad_accum_steps) + maybe_mark_compile_step_begin(args.enable_compile or args.compile_structured_mlp) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + ce_loss = model(x, y) + total_loss = ce_loss + 
if rate_mul > 0.0: + rate_proxy, scale_proxy, est_bits = base_model.rate_regularizer() + total_loss = total_loss + rate_mul * (args.rate_lambda * rate_proxy + args.scale_lambda * scale_proxy) + train_rate_proxy += rate_proxy.detach() + train_scale_proxy += scale_proxy.detach() + train_est_bits += est_bits.detach() + train_ce_loss += ce_loss.detach() + train_loss += total_loss.detach() + (total_loss * grad_scale).backward() + train_loss /= grad_accum_steps + train_ce_loss /= grad_accum_steps + train_rate_proxy /= grad_accum_steps + train_scale_proxy /= grad_accum_steps + train_est_bits /= grad_accum_steps + + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + for opt in optimizers: + opt.step() + zero_grad_all() + + step += 1 + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + should_log_train = ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + log0( + f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} ce_loss:{train_ce_loss.item():.4f} " + f"rate_mul:{rate_mul:.3f} rate_proxy_bits:{train_rate_proxy.item():.4f} " + f"scale_proxy:{train_scale_proxy.item():.6f} est_model_bits:{train_est_bits.item():.0f} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" + ) + + # Needed to sync whether we've reached the wallclock cap. 
+ reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + + log0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + ) + + # ----------------------------- + # SERIALIZATION + ROUNDTRIP VALIDATION + # ----------------------------- + # Save the raw state (useful for debugging/loading in PyTorch directly), then always produce + # the compressed int8+zlib artifact and validate the round-tripped weights. + + if master_process: + torch.save(base_model.state_dict(), "final_model.pt") + model_bytes = os.path.getsize("final_model.pt") + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model: {model_bytes} bytes") + log0(f"Code size: {code_bytes} bytes") + log0(f"Total submission size: {model_bytes + code_bytes} bytes") + + quant_obj, quant_stats = quantize_state_dict(base_model.state_dict(), args.quant_scheme) + quant_buf = io.BytesIO() + torch.save(quant_obj, quant_buf) + quant_raw = quant_buf.getvalue() + quant_blob = compress_blob(quant_raw, args.compressor, args.compressor_level) + quant_raw_bytes = len(quant_raw) + quant_filename = f"final_model.{args.quant_scheme}.{args.compressor}.ptc" + if master_process: + with open(quant_filename, "wb") as f: + f.write(quant_blob) + quant_file_bytes = os.path.getsize(quant_filename) + code_bytes = len(code.encode("utf-8")) + ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1) + log0( + f"Serialized model {args.quant_scheme}+{args.compressor}: {quant_file_bytes} bytes " + f"(payload:{quant_stats['int8_payload_bytes']} raw_torch:{quant_raw_bytes} 
payload_ratio:{ratio:.2f}x)" + ) + log0(f"Total submission size {args.quant_scheme}+{args.compressor}: {quant_file_bytes + code_bytes} bytes") + + if distributed: + dist.barrier() + with open(quant_filename, "rb") as f: + quant_blob_disk = f.read() + quant_state = torch.load(io.BytesIO(decompress_blob(quant_blob_disk, args.compressor)), map_location="cpu") + base_model.load_state_dict(dequantize_state_dict(quant_state), strict=True) + torch.cuda.synchronize() + t_qeval = time.perf_counter() + q_val_loss, q_val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"final_{args.quant_scheme}_{args.compressor}_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" + ) + log0(f"final_{args.quant_scheme}_{args.compressor}_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + + if distributed: + dist.destroy_process_group() + + +if __name__ == "__main__": + main() + +==================================================================================================== +Running Python 3.12.3 (main, Nov 6 2025, 13:44:16) [GCC 13.3.0] +Running PyTorch 2.9.1+cu128 +Tue Mar 31 18:17:57 2026 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 580.126.09 Driver Version: 580.126.09 CUDA Version: 13.0 | ++-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. 
| +|=========================================+========================+======================| +| 0 NVIDIA H100 80GB HBM3 On | 00000000:CB:00.0 Off | 0 | +| N/A 39C P0 95W / 700W | 1185MiB / 81559MiB | 3% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| No running processes found | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=/workspace/parameter-golf/data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:80 +val_loader:shards pattern=/workspace/parameter-golf/data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:1048576 +val_token_limit:1048576 +model_params:25727104 +world_size:1 grad_accum_steps:8 train_microbatch_tokens:2048 effective_train_batch_tokens:16384 +enable_compile:False +sdp_backends:cudnn=False flash=True mem_efficient=False math=False +attention_mode:gqa num_heads:8 num_kv_heads:4 +quant_scheme:int5_odd compressor:zlib level:9 structured_mlp:True structured_linear_type:btt_mlp btt_rank:256 btt_init:mup compile_structured_mlp:False structured_chunk_ranks:64 structured_eval_chunk_ranks:1 structured_compile_mode:reduce-overhead +rate_lambda:2e-05 scale_lambda:0.0002 rate_warmup_start_fraction:0.1 rate_full_fraction:0.4 rate_target_tensors:mlp_only rate_schedule_mode:auto max_expected_steps:0 +tie_embeddings:True embed_lr:0.03 head_lr:0.0 matrix_lr:0.02 scalar_lr:0.02 +train_batch_tokens_target:16384 effective_train_batch_tokens:16384 train_seq_len:1024 iterations:0 
warmup_steps:0 max_wallclock_seconds:0.000 +lr_scale_method:none lr_reference_batch_tokens:16384 batch_multiplier:1.0000 warmup_scale_method:none warmup_reference_batch_tokens:16384 +seed:1337 +step:0/0 val_loss:6.9379 val_bpb:4.1568 train_time:0ms step_avg:0.02ms +peak memory allocated: 10481 MiB reserved: 11532 MiB +Serialized model: 101929819 bytes +Code size: 79702 bytes +Total submission size: 102009521 bytes +Serialized model int5_odd+zlib: 4207457 bytes (payload:8780523 raw_torch:8851579 payload_ratio:11.60x) +Total submission size int5_odd+zlib: 4287159 bytes +final_int5_odd_zlib_roundtrip val_loss:6.9416 val_bpb:4.1590 eval_time:102007ms +final_int5_odd_zlib_roundtrip_exact val_loss:6.94156384 val_bpb:4.15901528 diff --git a/records/track_non_record_16mb/2026-03-31_EntropyAware_Int5Odd_BTTMLP_3060/logs/extreme_accum.txt b/records/track_non_record_16mb/2026-03-31_EntropyAware_Int5Odd_BTTMLP_3060/logs/extreme_accum.txt new file mode 100644 index 000000000..96ba4e1dd --- /dev/null +++ b/records/track_non_record_16mb/2026-03-31_EntropyAware_Int5Odd_BTTMLP_3060/logs/extreme_accum.txt @@ -0,0 +1,3639 @@ +""" +The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good launching-off points for new participants, not SOTA configs. We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder. + +Hard stop: To keep things readable for newcomers, let's make sure `train_gpt.py` and `train_gpt_mlx.py` are never longer than 1500 lines. 
+""" + +from __future__ import annotations + +import copy +import glob +import io +import math +import os +import random +import subprocess +import sys +import time +import uuid +import zlib +from pathlib import Path + +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP + +try: + import zstandard as zstd + + _HAS_ZSTD = True +except ImportError: + zstd = None + _HAS_ZSTD = False + +REPO_ROOT = Path(__file__).resolve().parents[3] + +# ----------------------------- +# HYPERPARAMETERS +# ----------------------------- +# Default Simple Baseline run: +# - 9 transformer blocks at width 512 +# - 8 attention heads with 4 KV heads (GQA) and 2x MLP expansion +# - vocab size 1024, sequence length 1024, tied embeddings +# - 524,288 train tokens per step for 20,000 iterations with a ~10 minute cap + +class Hyperparameters: + # Data paths are shard globs produced by the existing preprocessing pipeline. + data_path = os.environ.get("DATA_PATH", str(REPO_ROOT / "data/datasets/fineweb10B_sp1024")) + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", str(REPO_ROOT / "data/tokenizers/fineweb_1024_bpe.model")) + run_id = os.environ.get("RUN_ID", "train") + seed = int(os.environ.get("SEED", 1337)) + + # Validation cadence and batch size. Validation always uses the full fineweb_val split. + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 262_144)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 10)) + val_token_limit = int(os.environ.get("VAL_TOKEN_LIMIT", 1_048_576)) + + # Training length. 
+ iterations = int(os.environ.get("ITERATIONS", 40)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 1200)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 1)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 16_384)) + train_microbatch_tokens = int(os.environ.get("TRAIN_MICROBATCH_TOKENS", 0)) + grad_accum_steps = int(os.environ.get("GRAD_ACCUM_STEPS", 0)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 256)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + enable_compile = bool(int(os.environ.get("ENABLE_COMPILE", "0"))) + + # Model shape. + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers = int(os.environ.get("NUM_LAYERS", 9)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = int(os.environ.get("MLP_MULT", 2)) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + + # Optimizer hyperparameters. 
+ embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.03)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.02)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.02)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.0)) + lr_scale_method = os.environ.get("LR_SCALE_METHOD", "none") + lr_reference_batch_tokens = int(os.environ.get("LR_REFERENCE_BATCH_TOKENS", 16_384)) + warmup_scale_method = os.environ.get("WARMUP_SCALE_METHOD", "none") + warmup_reference_batch_tokens = int(os.environ.get("WARMUP_REFERENCE_BATCH_TOKENS", 16_384)) + simulated_world_size = int(os.environ.get("SIMULATED_WORLD_SIZE", 0)) + simulated_rank = int(os.environ.get("SIMULATED_RANK", 0)) + dry_run_init_only = bool(int(os.environ.get("DRY_RUN_INIT_ONLY", "0"))) + debug_data_sharding_steps = int(os.environ.get("DEBUG_DATA_SHARDING_STEPS", 0)) + + # Export / compression. + quant_scheme = os.environ.get("QUANT_SCHEME", "int5_odd") + compressor = os.environ.get("COMPRESSOR", "zlib") + compressor_level = int(os.environ.get("COMPRESSOR_LEVEL", 9)) + + # Entropy-aware rate regularization. 
+ rate_lambda = float(os.environ.get("RATE_LAMBDA", 0.00002)) + scale_lambda = float(os.environ.get("SCALE_LAMBDA", 0.0002)) + rate_temperature = float(os.environ.get("RATE_TEMPERATURE", 6.0)) + rate_warmup_start_fraction = float(os.environ.get("RATE_WARMUP_START_FRACTION", 0.1)) + rate_full_fraction = float(os.environ.get("RATE_FULL_FRACTION", 0.4)) + rate_target_tensors = os.environ.get("RATE_TARGET_TENSORS", "mlp_only") + + # Structured MLP. + structured_mlp = bool(int(os.environ.get("STRUCTURED_MLP", "1"))) + structured_linear_type = os.environ.get("STRUCTURED_LINEAR_TYPE", "btt_mlp") + btt_rank = int(os.environ.get("BTT_RANK", 32)) + btt_in_shape = os.environ.get("BTT_IN_SHAPE", "") + btt_out_shape = os.environ.get("BTT_OUT_SHAPE", "") + btt_init = os.environ.get("BTT_INIT", "mup") + btt_init_gain = float(os.environ.get("BTT_INIT_GAIN", 1.0)) + compile_structured_mlp = bool(int(os.environ.get("COMPILE_STRUCTURED_MLP", "1"))) + +# ----------------------------- +# MUON OPTIMIZER +# ----------------------------- +# +# As borrowed from modded-nanogpt +# Background on Muon: https://kellerjordan.github.io/posts/muon/ + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + # Orthogonalize a 2D update matrix with a fast Newton-Schulz iteration. + # Muon uses this to normalize matrix-shaped gradients before applying them. 
+ a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X + + +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True): + super().__init__( + params, + dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov), + ) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + + for group in self.param_groups: + params = group["params"] + if not params: + continue + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + + total_params = sum(int(p.numel()) for p in params) + updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) + + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: + g = g.add(buf, alpha=momentum) + g = zeropower_via_newtonschulz5(g, steps=backend_steps) + # Scale correction from Muon reference implementations. 
+ g *= max(1, g.size(0) / g.size(1)) ** 0.5 + updates_flat[curr : curr + p.numel()] = g.reshape(-1) + curr += p.numel() + + if distributed: + dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) + + curr = 0 + for p in params: + g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) + p.add_(g, alpha=-lr) + curr += p.numel() + + return loss + + +# ----------------------------- +# TOKENIZER-AGNOSTIC EVALUATION SETUP +# ----------------------------- +# +# It's common for small models to have a large fraction of their parameters in embeddings, since the 2 * d_model * d_vocab embedding parameters can be gigantic. +# Instead of locking the tokenizer, we let you bring your own and calculate our validation metrics on the average compression of the validation set. +# We calculate BPB (bits-per-byte) instead of validation loss, so we need methods to count the number of bytes per token produced by the tokenizer. +# Note: Submissions that edit the tokenizer will be examined more carefully, since screwing this up might unjustly improve your score. 
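The BPB conversion described above is a two-step rescaling: mean token cross-entropy in nats becomes bits per token, and the tokenizer's tokens-per-byte ratio turns that into bits per byte. A minimal standalone sketch of that arithmetic (the helper name and the counts are illustrative, not from the script):

```python
import math

def bits_per_byte(mean_loss_nats: float, token_count: float, byte_count: float) -> float:
    # Mean cross-entropy in nats -> bits per token, then rescale to bits per byte.
    bits_per_token = mean_loss_nats / math.log(2.0)
    tokens_per_byte = token_count / byte_count
    return bits_per_token * tokens_per_byte

# Made-up counts: 1,048,576 tokens whose decoded text spans 4,194,304 bytes.
bpb = bits_per_byte(6.9379, 1_048_576, 4_194_304)
```

Because both factors are ratios, the metric is insensitive to how many tokens the tokenizer splits the text into, which is the point of measuring compression per byte rather than per token.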
+ +def build_sentencepiece_luts( + sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device +) -> tuple[Tensor, Tensor, Tensor]: + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("▁"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) + + +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*. 
+ tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] + + +def eval_val( + args: Hyperparameters, + model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + grad_accum_steps: int, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, +) -> tuple[float, float]: + # Validation computes two metrics: + # - val_loss: token cross-entropy (natural log) + # - val_bpb: tokenizer-agnostic compression metric used by the challenge + local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) + if local_batch_tokens < args.train_seq_len: + raise ValueError( + "VAL_BATCH_SIZE must provide at least one sequence per rank; " + f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " + f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}" + ) + local_batch_seqs = local_batch_tokens // args.train_seq_len + total_seqs = (val_tokens.numel() - 1) // args.train_seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + + model.eval() + with torch.inference_mode(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * args.train_seq_len + raw_end = batch_seq_end * args.train_seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, args.train_seq_len) + y = local[1:].reshape(-1, args.train_seq_len) + 
maybe_mark_compile_step_begin(args.enable_compile or args.compile_structured_mlp) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + batch_loss = model(x, y).detach() + batch_token_count = float(y.numel()) + val_loss_sum += batch_loss.to(torch.float64) * batch_token_count + val_token_count += batch_token_count + prev_ids = x.reshape(-1) + tgt_ids = y.reshape(-1) + token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + + val_loss = val_loss_sum / val_token_count + bits_per_token = val_loss.item() / math.log(2.0) + tokens_per_byte = val_token_count.item() / val_byte_count.item() + model.train() + return float(val_loss.item()), float(bits_per_token * tokens_per_byte) + +# ----------------------------- +# POST-TRAINING QUANTIZATION +# ----------------------------- +# +# It's silly to export our model, which is trained in bf16 and fp32, at that same precision. +# Instead, we get approximately the same model (with a small accuracy hit) by quantizing the weights (int8 or int5_odd) and zlib-compressing them. +# We can then decompress the model and run it in higher precision for evaluation, while coming in under the size limit. 
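The int5_odd export defined below stores each quantized weight as a code in {0..4} and packs three codes into one byte, which works because 5**3 = 125 fits in a byte (about 2.67 bits per weight before zlib). A scalar sketch of that packing arithmetic (these helper names are illustrative; the script's `pack_base5_codes`/`unpack_base5_codes` do the same thing vectorized over NumPy arrays):

```python
def pack3(c0: int, c1: int, c2: int) -> int:
    # Three base-5 codes share one byte: c0 + 5*c1 + 25*c2 is at most 124.
    assert all(0 <= v <= 4 for v in (c0, c1, c2))
    return c0 + 5 * c1 + 25 * c2

def unpack3(b: int) -> tuple[int, int, int]:
    # Invert the packing with repeated modulo/division in base 5.
    return b % 5, (b // 5) % 5, (b // 25) % 5
```

Subtracting 2 from each unpacked code recovers the odd grid {-2,-1,0,1,2} that the per-row scales then map back to float weights.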
+ +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights", + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 +INT5_ODD_LEVELS = (-2, -1, 0, 1, 2) + +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) + + +def maybe_mark_compile_step_begin(enabled: bool) -> None: + if not enabled: + return + mark_step = getattr(torch.compiler, "cudagraph_mark_step_begin", None) + if mark_step is not None: + mark_step() + + +def scale_from_batch_multiplier(value: float, multiplier: float, method: str) -> float: + if method == "none": + return value + if method == "sqrt": + return value * math.sqrt(multiplier) + if method == "linear": + return value * multiplier + raise ValueError(f"Unsupported scaling method: {method}") + + +def scaled_warmup_steps(base_steps: int, multiplier: float, method: str) -> int: + if base_steps <= 0: + return 0 + scaled = scale_from_batch_multiplier(float(base_steps), multiplier, method) + return max(int(math.ceil(scaled)), 1) + + +def parse_factor_shape(spec: str, dim: int) -> tuple[int, int]: + if spec: + vals = tuple(int(v.strip()) for v in spec.split(",") if v.strip()) + if len(vals) != 2 or vals[0] * vals[1] != dim: + raise ValueError(f"Invalid factor shape '{spec}' for dim={dim}") + return vals + for a in range(int(math.isqrt(dim)), 0, -1): + if dim % a == 0: + return a, dim // a + return 1, dim + +def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> 
Tensor: + if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): + return t.float().contiguous() + if t.dtype in {torch.float32, torch.bfloat16}: + passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") + return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() + return t + +def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + # Matrices get one scale per row, which usually tracks output-channel + # ranges much better than a single tensor-wide scale. + clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) + q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() + return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + + # Vectors / scalars use a simpler per-tensor scale. 
+ clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() + return q, scale + + +def pack_base5_codes(q: Tensor) -> tuple[bytes, int]: + vals = q.detach().to("cpu", dtype=torch.uint8).reshape(-1).numpy() + n_vals = int(vals.size) + pad = (3 - (n_vals % 3)) % 3 + if pad: + vals = np.pad(vals, (0, pad)) + groups = vals.reshape(-1, 3).astype(np.uint8) + packed = groups[:, 0] + 5 * groups[:, 1] + 25 * groups[:, 2] + return packed.tobytes(), n_vals + + +def unpack_base5_codes(data: bytes, n_vals: int) -> Tensor: + packed = np.frombuffer(data, dtype=np.uint8).astype(np.int16) + out = np.zeros((packed.size, 3), dtype=np.uint8) + for i in range(3): + out[:, i] = packed % 5 + packed //= 5 + flat = out.reshape(-1)[:n_vals] + return torch.from_numpy(flat.astype(np.uint8, copy=False)) + + +def quantize_float_tensor_int5_odd(t: Tensor) -> tuple[Tensor, Tensor]: + if t.ndim != 2: + raise ValueError("int5_odd quantization only supports 2D tensors") + t32 = t.float() + clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clip_abs = clip_abs.clamp_min(1e-6) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 2.0).clamp_min(1.0 / 2.0e6) + q = torch.clamp(torch.round(clipped / scale[:, None]), -2, 2).to(torch.int8).contiguous() + return (q + 2).to(torch.uint8).contiguous(), scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + +def quantize_state_dict_int8(state_dict: dict[str, Tensor]): + # Single supported clean-script export format: + # - per-row int8 for 2D float tensors + # - per-tensor int8 for other float tensors + # - exact passthrough for non-floats + # - passthrough for small float tensors, 
stored as fp16 to save bytes + quantized: dict[str, Tensor] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + qmeta: dict[str, dict[str, object]] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + + # Small float tensors are cheap enough to keep directly. We still downcast + # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size. + if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + + stats["num_float_tensors"] += 1 + q, s = quantize_float_tensor(t) + if s.ndim > 0: + qmeta[name] = {"scheme": "per_row", "axis": 0} + quantized[name] = q + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) + + obj: dict[str, object] = { + "__quant_format__": "int8_clean_per_row_v1", + "quantized": quantized, + "scales": scales, + "dtypes": dtypes, + "passthrough": passthrough, + } + if qmeta: + obj["qmeta"] = qmeta + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats + + +def quantize_state_dict_int5_odd(state_dict: dict[str, Tensor]): + packed: dict[str, bytes] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + shapes: dict[str, tuple[int, ...]] = {} + orig_shapes: dict[str, 
tuple[int, ...]] = {} + n_codes: dict[str, int] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + + quantize_this = t.ndim >= 2 and (t.numel() > INT8_KEEP_FLOAT_MAX_NUMEL or ".mlp." in name) + if not quantize_this: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + + stats["num_float_tensors"] += 1 + t2 = t.reshape(t.shape[0], -1) if t.ndim > 2 else t + q, s = quantize_float_tensor_int5_odd(t2) + packed_bytes, count = pack_base5_codes(q) + packed[name] = packed_bytes + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + shapes[name] = tuple(int(v) for v in t2.shape) + orig_shapes[name] = tuple(int(v) for v in t.shape) + n_codes[name] = count + stats["int8_payload_bytes"] += len(packed_bytes) + tensor_nbytes(s) + + obj: dict[str, object] = { + "__quant_format__": "int5_odd_clean_per_row_v1", + "packed": packed, + "scales": scales, + "dtypes": dtypes, + "shapes": shapes, + "orig_shapes": orig_shapes, + "n_codes": n_codes, + "passthrough": passthrough, + } + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats + +def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + qmeta = obj.get("qmeta", {}) + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, q in 
obj["quantized"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + s = obj["scales"][name] + if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: + s = s.to(dtype=torch.float32) + # Broadcast the saved row scale back across trailing dimensions. + out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() + else: + scale = float(s.item()) + out[name] = (q.float() * scale).to(dtype=dtype).contiguous() + for name, t in obj["passthrough"].items(): + # Restore small tensors, undoing the temporary fp16 storage cast if needed. + out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out + + +def dequantize_state_dict_int5_odd(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, packed in obj["packed"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + shape = tuple(int(v) for v in obj["shapes"][name]) + orig_shape = tuple(int(v) for v in obj.get("orig_shapes", {}).get(name, shape)) + q = unpack_base5_codes(packed, int(obj["n_codes"][name])).to(torch.float32).reshape(shape) - 2.0 + s = obj["scales"][name].to(dtype=torch.float32) + deq = (q * s.view(shape[0], *([1] * (len(shape) - 1)))).to(dtype=dtype) + out[name] = deq.reshape(orig_shape).contiguous() + for name, t in obj["passthrough"].items(): + out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out + + +def quantize_state_dict(state_dict: dict[str, Tensor], quant_scheme: str): + if quant_scheme == "int8": + return quantize_state_dict_int8(state_dict) + if quant_scheme == "int5_odd": + return 
quantize_state_dict_int5_odd(state_dict) + raise ValueError(f"Unsupported QUANT_SCHEME={quant_scheme}") + + +def dequantize_state_dict(obj: dict[str, object]) -> dict[str, Tensor]: + quant_format = obj.get("__quant_format__") + if quant_format == "int8_clean_per_row_v1": + return dequantize_state_dict_int8(obj) + if quant_format == "int5_odd_clean_per_row_v1": + return dequantize_state_dict_int5_odd(obj) + raise ValueError(f"Unsupported quant format {quant_format}") + + +def compress_blob(data: bytes, compressor: str, level: int) -> bytes: + if compressor == "zlib": + return zlib.compress(data, level=max(0, min(level, 9))) + if compressor == "zstd": + if not _HAS_ZSTD: + raise RuntimeError("COMPRESSOR=zstd requires zstandard to be installed") + return zstd.ZstdCompressor(level=level).compress(data) + raise ValueError(f"Unsupported COMPRESSOR={compressor}") + + +def decompress_blob(data: bytes, compressor: str) -> bytes: + if compressor == "zlib": + return zlib.decompress(data) + if compressor == "zstd": + if not _HAS_ZSTD: + raise RuntimeError("COMPRESSOR=zstd requires zstandard to be installed") + return zstd.ZstdDecompressor().decompress(data) + raise ValueError(f"Unsupported COMPRESSOR={compressor}") + + +def should_rate_regularize_tensor(name: str, p: Tensor, target: str) -> bool: + if not p.requires_grad or not p.is_floating_point() or p.ndim < 2: + return False + if any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS): + return False + if target == "mlp_only": + return ".mlp." in name + if target == "all_matrices": + return p.numel() > INT8_KEEP_FLOAT_MAX_NUMEL or ".mlp." 
 in name
+    raise ValueError(f"Unsupported RATE_TARGET_TENSORS={target}")
+
+
+def estimate_tensor_rate_proxy(t: Tensor, temperature: float) -> tuple[Tensor, Tensor, Tensor]:
+    t2 = t.float()
+    if t2.ndim > 2:
+        t2 = t2.reshape(t2.shape[0], -1)
+    elif t2.ndim == 1:
+        t2 = t2.reshape(1, -1)
+    row_scale = t2.abs().mean(dim=1, keepdim=True).clamp_min(1e-6)
+    normalized = torch.clamp(t2 / row_scale, -2.5, 2.5)
+    centers = t2.new_tensor(INT5_ODD_LEVELS)
+    logits = -temperature * (normalized.unsqueeze(-1) - centers) ** 2
+    probs = torch.softmax(logits, dim=-1)
+    hist = probs.mean(dim=(0, 1))
+    entropy_nats = -(hist * (hist + 1e-8).log()).sum()
+    entropy_bits = entropy_nats / math.log(2.0)
+    scale_proxy = row_scale.mean()
+    est_bits = entropy_bits * float(t2.numel())
+    return entropy_bits, scale_proxy, est_bits
+
+
+# -----------------------------
+# DATA LOADING
+# -----------------------------
+
+def load_data_shard(file: Path) -> Tensor:
+    # llm.c/modded-nanogpt shard layout: 256 little-endian int32 header words
+    # (magic, version, token count), followed by the tokens as little-endian uint16.
+    header_bytes = 256 * np.dtype("<i4").itemsize
+    with file.open("rb") as f:
+        header = np.frombuffer(f.read(header_bytes), dtype="<i4")
+        num_tokens = int(header[2])
+        tokens = np.frombuffer(f.read(num_tokens * np.dtype("<u2").itemsize), dtype="<u2")
+    if tokens.size != num_tokens:
+        raise ValueError(f"Token shard {file} is truncated: expected {num_tokens} tokens, found {tokens.size}")
+    return torch.from_numpy(tokens.astype(np.int32, copy=False))
+
+
+class TokenStream:
+    # Sequential reader over the sorted shard files matching `pattern`; wraps back to
+    # the first shard once the stream is exhausted.
+    def __init__(self, pattern: str):
+        p = Path(pattern)
+        self.files = sorted(p.parent.glob(p.name))
+        if not self.files:
+            raise FileNotFoundError(f"No token shards matched pattern {pattern}")
+        self.file_idx = 0
+        self.tokens = load_data_shard(self.files[self.file_idx])
+        self.pos = 0
+
+    def _advance_file(self) -> None:
+        self.file_idx = (self.file_idx + 1) % len(self.files)
+        self.tokens = load_data_shard(self.files[self.file_idx])
+        self.pos = 0
+
+    def take(self, n: int) -> Tensor:
+        chunks: list[Tensor] = []
+        remaining = n
+        while remaining > 0:
+            avail = self.tokens.numel() - self.pos
+            if avail <= 0:
+                self._advance_file()
+                continue
+            k = min(remaining, avail)
+            chunks.append(self.tokens[self.pos : self.pos + k])
+            self.pos += k
+            remaining -= k
+        return chunks[0] if len(chunks) == 1 else torch.cat(chunks)
+
+
+class DistributedTokenLoader:
+    # Each call consumes a contiguous chunk from the shared token stream, then slices out
+    # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting.
+ def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank = rank + self.world_size = world_size + self.device = device + self.stream = TokenStream(pattern) + + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + chunk = self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start : start + per_rank_span].to(dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + + +def debug_log_data_sharding( + pattern: str, + world_size: int, + global_tokens: int, + seq_len: int, + grad_accum_steps: int, + steps: int, + log0, +) -> None: + stream = TokenStream(pattern) + local_tokens = global_tokens // (world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + global_cursor = 0 + log0( + f"debug_sharding:world_size:{world_size} global_tokens:{global_tokens} " + f"grad_accum_steps:{grad_accum_steps} local_tokens_per_rank:{local_tokens} per_rank_span:{per_rank_span}" + ) + for step_idx in range(steps): + chunk = stream.take(per_rank_span * world_size) + chunk_start = global_cursor + chunk_end = global_cursor + chunk.numel() + log0(f"debug_sharding_step:{step_idx} shared_chunk_token_range:[{chunk_start},{chunk_end})") + for shard_rank in range(world_size): + start = shard_rank * per_rank_span + end = start + per_rank_span + local = chunk[start:end] + preview = ",".join(str(int(t)) for t in local[: min(6, local.numel())].tolist()) + log0( + f"debug_shard rank:{shard_rank} token_range:[{chunk_start + start},{chunk_start + end}) " + f"preview:{preview}" + ) + global_cursor = chunk_end + +# ----------------------------- +# TRANSFORMER MODULES +# ----------------------------- + +class RMSNorm(nn.Module): + def 
__init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + + +class CastedLinear(nn.Linear): + # Keep weights in fp32 for optimizer/state quality, cast at matmul time for bf16 compute. + def forward(self, x: Tensor) -> Tensor: + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, self.weight.to(x.dtype), bias) + + +class StructuredLinear(nn.Module): + # Lightweight TT-style structured matrix used only in the MLP stack. + def __init__( + self, + in_features: int, + out_features: int, + bias: bool, + linear_type: str, + btt_rank: int, + in_shape_spec: str, + out_shape_spec: str, + btt_init: str, + btt_init_gain: float, + ): + super().__init__() + self.in_features = in_features + self.out_features = out_features + self.linear_type = linear_type + self.rank = max(int(btt_rank), 1) + self.btt_init = btt_init + self.btt_init_gain = btt_init_gain + if linear_type == "dense": + self.impl = CastedLinear(in_features, out_features, bias=bias) + self.register_parameter("core1", None) + self.register_parameter("core2", None) + return + if linear_type != "btt_mlp": + raise ValueError(f"Unsupported STRUCTURED_LINEAR_TYPE={linear_type}") + self.impl = None + self.in_shape = parse_factor_shape(in_shape_spec, in_features) + self.out_shape = parse_factor_shape(out_shape_spec, out_features) + n1, n2 = self.in_shape + m1, m2 = self.out_shape + self.core1 = nn.Parameter(torch.empty(n1, m1, self.rank)) + self.core2 = nn.Parameter(torch.empty(self.rank, n2, m2)) + if bias: + self.bias = nn.Parameter(torch.zeros(out_features)) + else: + self.register_parameter("bias", None) + self.reset_parameters() + + def reset_parameters(self) -> None: + if self.impl is not None: + self.impl.reset_parameters() + return + n1, n2 = self.in_shape + if self.btt_init == "mup": + # muP-inspired scaling: keep the per-rank branch update stable while + # 
preserving O(1) output variance under the sum over TT ranks. + core1_std = self.btt_init_gain / math.sqrt(max(n1, 1)) + core2_std = self.btt_init_gain / math.sqrt(max(self.rank * n2, 1)) + elif self.btt_init == "xavier": + std = self.btt_init_gain / math.sqrt(max(self.in_features, 1)) + core1_std = std / math.sqrt(self.rank) + core2_std = std + else: + raise ValueError(f"Unsupported BTT_INIT={self.btt_init}") + nn.init.normal_(self.core1, mean=0.0, std=core1_std) + nn.init.normal_(self.core2, mean=0.0, std=core2_std) + if self.bias is not None: + nn.init.zeros_(self.bias) + + def forward(self, x: Tensor) -> Tensor: + if self.impl is not None: + return self.impl(x) + x_dtype = x.dtype + orig_shape = x.shape[:-1] + x2 = x.reshape(-1, self.in_shape[0], self.in_shape[1]) + core1 = self.core1.to(dtype=x_dtype) + core2 = self.core2.to(dtype=x_dtype) + y = x2.new_zeros((x2.shape[0], self.out_shape[0], self.out_shape[1])) + for r in range(self.rank): + left = torch.einsum("buv,um->bmv", x2, core1[:, :, r]) + y = y + torch.einsum("bmv,vn->bmn", left, core2[r]) + y = y.reshape(*orig_shape, self.out_features) + if self.bias is not None: + y = y + self.bias.to(dtype=y.dtype) + return y + + +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + # Keep small/control parameters in fp32 even when the model body runs in bf16. + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() + + +class Rotary(nn.Module): + # Caches cos/sin tables per sequence length on the current device. 
+ def __init__(self, dim: int, base: float = 10000.0): + super().__init__() + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached != seq_len + or self._cos_cached.device != device + ): + t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) + freqs = torch.outer(t, self.inv_freq.to(device)) + self._cos_cached = freqs.cos()[None, None, :, :] + self._sin_cached = freqs.sin()[None, None, :, :] + self._seq_len_cached = seq_len + cos = self._cos_cached.to(dtype=dtype) + sin = self._sin_cached.to(dtype=dtype) + if not torch.is_inference_mode_enabled(): + cos = cos.clone() + sin = sin.clone() + return cos, sin + + +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + + +class CausalSelfAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + kv_dim = self.num_kv_heads * self.head_dim + self.c_q = CastedLinear(dim, dim, bias=False) + self.c_k = CastedLinear(dim, kv_dim, bias=False) + self.c_v = CastedLinear(dim, kv_dim, bias=False) + self.proj 
= CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rotary = Rotary(self.head_dim, base=rope_base) + + def forward(self, x: Tensor) -> Tensor: + bsz, seqlen, dim = x.shape + q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) + k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin) + k = apply_rotary_emb(k, cos, sin) + q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] + y = F.scaled_dot_product_attention( + q, + k, + v, + attn_mask=None, + is_causal=True, + enable_gqa=(self.num_kv_heads != self.num_heads), + ) + y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) + return self.proj(y) + + +class MLP(nn.Module): + # relu^2 MLP from the original modded-nanogpt setup + def __init__( + self, + dim: int, + mlp_mult: int, + structured_mlp: bool, + structured_linear_type: str, + btt_rank: int, + btt_in_shape: str, + btt_out_shape: str, + btt_init: str, + btt_init_gain: float, + compile_structured_mlp: bool, + ): + super().__init__() + hidden = mlp_mult * dim + if structured_mlp: + self.fc = StructuredLinear( + dim, + hidden, + bias=False, + linear_type=structured_linear_type, + btt_rank=btt_rank, + in_shape_spec=btt_in_shape, + out_shape_spec=btt_out_shape, + btt_init=btt_init, + btt_init_gain=btt_init_gain, + ) + self.proj = StructuredLinear( + hidden, + dim, + bias=False, + linear_type=structured_linear_type, + btt_rank=btt_rank, + in_shape_spec=btt_out_shape, + out_shape_spec=btt_in_shape, + btt_init=btt_init, + btt_init_gain=btt_init_gain, + ) + if compile_structured_mlp: + compile_options = 
dict(torch._inductor.list_mode_options("reduce-overhead")) + compile_options["triton.cudagraphs"] = False + self.fc = torch.compile( + self.fc, + options=compile_options, + fullgraph=False, + dynamic=False, + ) + self.proj = torch.compile( + self.proj, + options=compile_options, + fullgraph=False, + dynamic=False, + ) + else: + self.fc = CastedLinear(dim, hidden, bias=False) + self.proj = CastedLinear(hidden, dim, bias=False) + self.proj._zero_init = True + + def forward(self, x: Tensor) -> Tensor: + x = torch.relu(self.fc(x)) + return self.proj(x.square()) + + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + rope_base: float, + qk_gain_init: float, + structured_mlp: bool, + structured_linear_type: str, + btt_rank: int, + btt_in_shape: str, + btt_out_shape: str, + btt_init: str, + btt_init_gain: float, + compile_structured_mlp: bool, + ): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) + self.mlp = MLP( + dim, + mlp_mult, + structured_mlp, + structured_linear_type, + btt_rank, + btt_in_shape, + btt_out_shape, + btt_init, + btt_init_gain, + compile_structured_mlp, + ) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + + def forward(self, x: Tensor, x0: Tensor) -> Tensor: + mix = self.resid_mix.to(dtype=x.dtype) + x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + attn_out = self.attn(self.attn_norm(x)) + x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out + x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) + return x + + +class GPT(nn.Module): + def __init__( + self, + vocab_size: int, + num_layers: int, + model_dim: int, + num_heads: int, + 
num_kv_heads: int, + mlp_mult: int, + tie_embeddings: bool, + tied_embed_init_std: float, + logit_softcap: float, + rope_base: float, + qk_gain_init: float, + structured_mlp: bool, + structured_linear_type: str, + btt_rank: int, + btt_in_shape: str, + btt_out_shape: str, + btt_init: str, + btt_init_gain: float, + compile_structured_mlp: bool, + rate_target_tensors: str, + rate_temperature: float, + ): + super().__init__() + if logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") + self.tie_embeddings = tie_embeddings + self.tied_embed_init_std = tied_embed_init_std + self.logit_softcap = logit_softcap + self.rate_target_tensors = rate_target_tensors + self.rate_temperature = rate_temperature + self.tok_emb = nn.Embedding(vocab_size, model_dim) + self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + self.blocks = nn.ModuleList( + [ + Block( + model_dim, + num_heads, + num_kv_heads, + mlp_mult, + rope_base, + qk_gain_init, + structured_mlp, + structured_linear_type, + btt_rank, + btt_in_shape, + btt_out_shape, + btt_init, + btt_init_gain, + compile_structured_mlp, + ) + for i in range(num_layers) + ] + ) + self.final_norm = RMSNorm() + self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) + if self.lm_head is not None: + self.lm_head._zero_init = True + self._init_weights() + + def _init_weights(self) -> None: + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + for module in self.modules(): + if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + if isinstance(module, StructuredLinear) and getattr(module, "_zero_init", False): + if module.impl is 
not None: + nn.init.zeros_(module.impl.weight) + else: + nn.init.zeros_(module.core2) + + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + x = self.tok_emb(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x0 = x + skips: list[Tensor] = [] + + # First half stores skips; second half reuses them in reverse order. + for i in range(self.num_encoder_layers): + x = self.blocks[i](x, x0) + skips.append(x) + for i in range(self.num_decoder_layers): + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + x = self.blocks[self.num_encoder_layers + i](x, x0) + + x = self.final_norm(x).reshape(-1, x.size(-1)) + targets = target_ids.reshape(-1) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(x) + logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + return F.cross_entropy(logits.float(), targets, reduction="mean") + + def rate_regularizer(self) -> tuple[Tensor, Tensor, Tensor]: + total_symbols = 0.0 + weighted_rate = None + weighted_scale = None + estimated_bits = None + for name, param in self.named_parameters(): + if not should_rate_regularize_tensor(name, param, self.rate_target_tensors): + continue + rate_proxy, scale_proxy, tensor_bits = estimate_tensor_rate_proxy(param, self.rate_temperature) + weight = float(param.numel()) + total_symbols += weight + if weighted_rate is None: + weighted_rate = rate_proxy * weight + weighted_scale = scale_proxy * weight + estimated_bits = tensor_bits + else: + weighted_rate = weighted_rate + rate_proxy * weight + weighted_scale = weighted_scale + scale_proxy * weight + estimated_bits = estimated_bits + tensor_bits + if weighted_rate is None or weighted_scale is None or estimated_bits is None: + zero = self.tok_emb.weight.new_zeros(()) + return zero, zero, zero + denom = max(total_symbols, 1.0) + return 
weighted_rate / denom, weighted_scale / denom, estimated_bits + + +# ----------------------------- +# TRAINING +# ----------------------------- + +def main() -> None: + global zeropower_via_newtonschulz5 + + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + if args.enable_compile: + zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) + + # ----------------------------- + # DISTRIBUTED + CUDA SETUP + # ----------------------------- + + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + simulated_world_size = args.simulated_world_size if not distributed and args.simulated_world_size > 0 else 0 + if simulated_world_size > 0: + world_size = simulated_world_size + rank = args.simulated_rank + local_rank = 0 + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if rank < 0 or rank >= world_size: + raise ValueError(f"RANK={rank} must satisfy 0 <= RANK < WORLD_SIZE={world_size}") + if simulated_world_size > 1 and not args.dry_run_init_only: + raise ValueError("SIMULATED_WORLD_SIZE>1 is only supported together with DRY_RUN_INIT_ONLY=1") + if args.train_batch_tokens <= 0: + raise ValueError(f"TRAIN_BATCH_TOKENS must be positive, got {args.train_batch_tokens}") + train_microbatch_tokens = args.train_microbatch_tokens if args.train_microbatch_tokens > 0 else max( + args.train_seq_len, args.train_batch_tokens // 8 + ) + if train_microbatch_tokens <= 0: + raise ValueError(f"TRAIN_MICROBATCH_TOKENS must be positive, got {train_microbatch_tokens}") + if train_microbatch_tokens % args.train_seq_len != 0: + raise ValueError( + f"TRAIN_MICROBATCH_TOKENS={train_microbatch_tokens} must be divisible by TRAIN_SEQ_LEN={args.train_seq_len}" + ) + tokens_per_microstep = world_size * train_microbatch_tokens + if args.grad_accum_steps > 0: + 
grad_accum_steps = args.grad_accum_steps + effective_train_batch_tokens = tokens_per_microstep * grad_accum_steps + else: + if args.train_batch_tokens % tokens_per_microstep != 0: + raise ValueError( + "TRAIN_BATCH_TOKENS must be divisible by WORLD_SIZE * TRAIN_MICROBATCH_TOKENS " + f"unless GRAD_ACCUM_STEPS is set explicitly; got TRAIN_BATCH_TOKENS={args.train_batch_tokens}, " + f"WORLD_SIZE={world_size}, TRAIN_MICROBATCH_TOKENS={train_microbatch_tokens}" + ) + grad_accum_steps = args.train_batch_tokens // tokens_per_microstep + effective_train_batch_tokens = args.train_batch_tokens + if grad_accum_steps <= 0: + raise ValueError(f"GRAD_ACCUM_STEPS must be positive, got {grad_accum_steps}") + grad_scale = 1.0 / grad_accum_steps + batch_multiplier = effective_train_batch_tokens / max(args.lr_reference_batch_tokens, 1) + effective_warmup_steps = scaled_warmup_steps( + args.warmup_steps, + effective_train_batch_tokens / max(args.warmup_reference_batch_tokens, 1), + args.warmup_scale_method, + ) + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 0 + + # Fast math knobs + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + + logfile = None + if master_process: + if args.run_id == "train": + logfile = "train.log" + else: + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" + print(logfile) + + def log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as 
f: + print(msg, file=f) + + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch {torch.__version__}", console=False) + log0( + subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, + console=False, + ) + log0("=" * 100, console=False) + + # ----------------------------- + # TOKENIZER + VALIDATION METRIC SETUP + # ----------------------------- + + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) + if args.val_token_limit > 0: + usable = min(((args.val_token_limit) // args.train_seq_len) * args.train_seq_len, val_tokens.numel() - 1) + if usable <= 0: + raise ValueError(f"VAL_TOKEN_LIMIT={args.val_token_limit} is too small for TRAIN_SEQ_LEN={args.train_seq_len}") + val_tokens = val_tokens[: usable + 1].contiguous() + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + if args.val_token_limit > 0: + log0(f"val_token_limit:{args.val_token_limit}") + + # 
----------------------------- + # MODEL + OPTIMIZER SETUP + # ----------------------------- + + base_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, + structured_mlp=args.structured_mlp, + structured_linear_type=args.structured_linear_type if args.structured_mlp else "dense", + btt_rank=args.btt_rank, + btt_in_shape=args.btt_in_shape, + btt_out_shape=args.btt_out_shape, + btt_init=args.btt_init, + btt_init_gain=args.btt_init_gain, + compile_structured_mlp=args.compile_structured_mlp, + rate_target_tensors=args.rate_target_tensors, + rate_temperature=args.rate_temperature, + ).to(device).bfloat16() + for module in base_model.modules(): + if isinstance(module, (CastedLinear, StructuredLinear)): + module.float() + restore_low_dim_params_to_fp32(base_model) + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) if args.enable_compile else base_model + model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + + # Optimizer split: + # - token embedding (Adam) uses EMBED_LR + # - untied lm_head (Adam) uses HEAD_LR + # - matrix params in transformer blocks use MATRIX_LR via Muon + # - vectors/scalars use SCALAR_LR via Adam + block_named_params = list(base_model.blocks.named_parameters()) + matrix_params = [ + p + for name, p in block_named_params + if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + scalar_params = [ + p + for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.skip_weights.numel() > 0: + 
scalar_params.append(base_model.skip_weights) + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + scaled_token_lr = scale_from_batch_multiplier(token_lr, batch_multiplier, args.lr_scale_method) + scaled_head_lr = scale_from_batch_multiplier(args.head_lr, batch_multiplier, args.lr_scale_method) + scaled_matrix_lr = scale_from_batch_multiplier(args.matrix_lr, batch_multiplier, args.lr_scale_method) + scaled_scalar_lr = scale_from_batch_multiplier(args.scalar_lr, batch_multiplier, args.lr_scale_method) + optimizer_tok = torch.optim.Adam( + [{"params": [base_model.tok_emb.weight], "lr": scaled_token_lr, "base_lr": scaled_token_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizer_muon = Muon( + matrix_params, + lr=scaled_matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + ) + for group in optimizer_muon.param_groups: + group["base_lr"] = scaled_matrix_lr + optimizer_scalar = torch.optim.Adam( + [{"params": scalar_params, "lr": scaled_scalar_lr, "base_lr": scaled_scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": scaled_head_lr, "base_lr": scaled_head_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers.insert(1, optimizer_head) + + n_params = sum(p.numel() for p in base_model.parameters()) + log0(f"model_params:{n_params}") + log0( + f"world_size:{world_size} grad_accum_steps:{grad_accum_steps} " + f"train_microbatch_tokens:{train_microbatch_tokens} effective_train_batch_tokens:{effective_train_batch_tokens}" + ) + if simulated_world_size > 0: + log0(f"simulated_world_size:{simulated_world_size} simulated_rank:{rank}") + log0(f"enable_compile:{args.enable_compile}") + 
log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") + log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") + log0( + f"quant_scheme:{args.quant_scheme} compressor:{args.compressor} level:{args.compressor_level} " + f"structured_mlp:{args.structured_mlp} structured_linear_type:{args.structured_linear_type} " + f"btt_rank:{args.btt_rank} btt_init:{args.btt_init} compile_structured_mlp:{args.compile_structured_mlp}" + ) + log0( + f"rate_lambda:{args.rate_lambda} scale_lambda:{args.scale_lambda} " + f"rate_warmup_start_fraction:{args.rate_warmup_start_fraction} rate_full_fraction:{args.rate_full_fraction} " + f"rate_target_tensors:{args.rate_target_tensors}" + ) + log0( + f"tie_embeddings:{args.tie_embeddings} embed_lr:{scaled_token_lr} " + f"head_lr:{scaled_head_lr if base_model.lm_head is not None else 0.0} " + f"matrix_lr:{scaled_matrix_lr} scalar_lr:{scaled_scalar_lr}" + ) + log0( + f"train_batch_tokens_target:{args.train_batch_tokens} effective_train_batch_tokens:{effective_train_batch_tokens} " + f"train_seq_len:{args.train_seq_len} iterations:{args.iterations} warmup_steps:{effective_warmup_steps} " + f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log0( + f"lr_scale_method:{args.lr_scale_method} lr_reference_batch_tokens:{args.lr_reference_batch_tokens} " + f"batch_multiplier:{batch_multiplier:.4f} warmup_scale_method:{args.warmup_scale_method} " + f"warmup_reference_batch_tokens:{args.warmup_reference_batch_tokens}" + ) + log0(f"seed:{args.seed}") + + if args.debug_data_sharding_steps > 0: + debug_log_data_sharding( + args.train_files, + world_size, + effective_train_batch_tokens, + args.train_seq_len, + grad_accum_steps, + args.debug_data_sharding_steps, + log0, + ) + if args.dry_run_init_only: + log0("dry_run_init_only: exiting after init/logging") + return + + # ----------------------------- + # DATA LOADER & MODEL WARMUP + # ----------------------------- + + train_loader = 
DistributedTokenLoader(args.train_files, rank, world_size, device) + + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + + def lr_mul(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + + def rate_schedule(step: int) -> float: + if args.rate_lambda <= 0.0 and args.scale_lambda <= 0.0: + return 0.0 + if args.iterations <= 0: + return 1.0 + frac = step / max(args.iterations, 1) + start = args.rate_warmup_start_fraction + full = max(args.rate_full_fraction, start + 1e-6) + if frac <= start: + return 0.0 + if frac >= full: + return 1.0 + return (frac - start) / (full - start) + + # Warmup primes the compiled forward/backward/optimizer paths, then we restore the + # initial weights/optimizer state so measured training starts from the true init. 
+ if effective_warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(effective_warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(effective_train_batch_tokens, args.train_seq_len, grad_accum_steps) + maybe_mark_compile_step_begin(args.enable_compile or args.compile_structured_mlp) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + for opt in optimizers: + opt.step() + zero_grad_all() + if ( + effective_warmup_steps <= 20 + or (warmup_step + 1) % 10 == 0 + or warmup_step + 1 == effective_warmup_steps + ): + log0(f"warmup_step:{warmup_step + 1}/{effective_warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + if distributed: + model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + # ----------------------------- + # MAIN TRAINING LOOP + # ----------------------------- + + training_time_ms = 0.0 + stop_after_step: int | None = None + torch.cuda.synchronize() + t0 = time.perf_counter() + + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + 
grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + log0( + f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" + ) + break + + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + zero_grad_all() + train_loss = torch.zeros((), device=device) + train_ce_loss = torch.zeros((), device=device) + train_rate_proxy = torch.zeros((), device=device) + train_scale_proxy = torch.zeros((), device=device) + train_est_bits = torch.zeros((), device=device) + rate_mul = rate_schedule(step) + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(effective_train_batch_tokens, args.train_seq_len, grad_accum_steps) + maybe_mark_compile_step_begin(args.enable_compile or args.compile_structured_mlp) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + ce_loss = model(x, y) + total_loss = ce_loss + if rate_mul > 0.0: + rate_proxy, scale_proxy, est_bits = base_model.rate_regularizer() + total_loss = total_loss + rate_mul * (args.rate_lambda * rate_proxy + args.scale_lambda * scale_proxy) + train_rate_proxy += rate_proxy.detach() + train_scale_proxy += scale_proxy.detach() + train_est_bits += est_bits.detach() + train_ce_loss += ce_loss.detach() + train_loss += total_loss.detach() + (total_loss * grad_scale).backward() + train_loss /= grad_accum_steps + train_ce_loss /= grad_accum_steps + train_rate_proxy /= grad_accum_steps + train_scale_proxy /= grad_accum_steps + train_est_bits /= 
grad_accum_steps + + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + for opt in optimizers: + opt.step() + zero_grad_all() + + step += 1 + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + should_log_train = ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + log0( + f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} ce_loss:{train_ce_loss.item():.4f} " + f"rate_mul:{rate_mul:.3f} rate_proxy_bits:{train_rate_proxy.item():.4f} " + f"scale_proxy:{train_scale_proxy.item():.6f} est_model_bits:{train_est_bits.item():.0f} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" + ) + + # Needed to sync whether we've reached the wallclock cap. 
+ reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + + log0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + ) + + # ----------------------------- + # SERIALIZATION + ROUNDTRIP VALIDATION + # ----------------------------- + # Save the raw state (useful for debugging/loading in PyTorch directly), then always produce + # the compressed int8+zlib artifact and validate the round-tripped weights. + + if master_process: + torch.save(base_model.state_dict(), "final_model.pt") + model_bytes = os.path.getsize("final_model.pt") + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model: {model_bytes} bytes") + log0(f"Code size: {code_bytes} bytes") + log0(f"Total submission size: {model_bytes + code_bytes} bytes") + + quant_obj, quant_stats = quantize_state_dict(base_model.state_dict(), args.quant_scheme) + quant_buf = io.BytesIO() + torch.save(quant_obj, quant_buf) + quant_raw = quant_buf.getvalue() + quant_blob = compress_blob(quant_raw, args.compressor, args.compressor_level) + quant_raw_bytes = len(quant_raw) + quant_filename = f"final_model.{args.quant_scheme}.{args.compressor}.ptc" + if master_process: + with open(quant_filename, "wb") as f: + f.write(quant_blob) + quant_file_bytes = os.path.getsize(quant_filename) + code_bytes = len(code.encode("utf-8")) + ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1) + log0( + f"Serialized model {args.quant_scheme}+{args.compressor}: {quant_file_bytes} bytes " + f"(payload:{quant_stats['int8_payload_bytes']} raw_torch:{quant_raw_bytes} 
payload_ratio:{ratio:.2f}x)" + ) + log0(f"Total submission size {args.quant_scheme}+{args.compressor}: {quant_file_bytes + code_bytes} bytes") + + if distributed: + dist.barrier() + with open(quant_filename, "rb") as f: + quant_blob_disk = f.read() + quant_state = torch.load(io.BytesIO(decompress_blob(quant_blob_disk, args.compressor)), map_location="cpu") + base_model.load_state_dict(dequantize_state_dict(quant_state), strict=True) + torch.cuda.synchronize() + t_qeval = time.perf_counter() + q_val_loss, q_val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"final_{args.quant_scheme}_{args.compressor}_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" + ) + log0(f"final_{args.quant_scheme}_{args.compressor}_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + + if distributed: + dist.destroy_process_group() + + +if __name__ == "__main__": + main() + +==================================================================================================== +Running Python 3.10.12 (main, Mar 3 2026, 11:56:32) [GCC 11.4.0] +Running PyTorch 2.11.0+cu130 +Tue Mar 31 13:23:56 2026 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 590.48.01 Driver Version: 590.48.01 CUDA Version: 13.1 | ++-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. 
| +|=========================================+========================+======================| +| 0 NVIDIA GeForce RTX 3060 Off | 00000000:01:00.0 On | N/A | +| 0% 43C P5 21W / 170W | 252MiB / 12288MiB | 21% Default | +| | | N/A | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| 0 N/A N/A 985 G /usr/lib/xorg/Xorg 127MiB | +| 0 N/A N/A 6226 C python3 104MiB | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=/home/hz/Desktop/parameter-golf-main/data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:1 +val_loader:shards pattern=/home/hz/Desktop/parameter-golf-main/data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:2048 +val_token_limit:2048 +model_params:8507464 +world_size:1 grad_accum_steps:64 train_microbatch_tokens:4096 effective_train_batch_tokens:262144 +enable_compile:False +sdp_backends:cudnn=False flash=True mem_efficient=False math=False +attention_mode:gqa num_heads:8 num_kv_heads:4 +quant_scheme:int5_odd compressor:zlib level:9 structured_mlp:True structured_linear_type:btt_mlp btt_rank:32 btt_init:mup compile_structured_mlp:True +rate_lambda:2e-05 scale_lambda:0.0002 rate_warmup_start_fraction:0.1 rate_full_fraction:0.4 rate_target_tensors:mlp_only +tie_embeddings:True embed_lr:0.03 head_lr:0.0 matrix_lr:0.02 scalar_lr:0.02 +train_batch_tokens_target:262144 effective_train_batch_tokens:262144 train_seq_len:256 iterations:1 warmup_steps:0 max_wallclock_seconds:600.000 +lr_scale_method:none 
lr_reference_batch_tokens:16384 batch_multiplier:16.0000 warmup_scale_method:none warmup_reference_batch_tokens:16384
+seed:1337
+step:1/1 train_loss:6.9353 ce_loss:6.9353 rate_mul:0.000 rate_proxy_bits:0.0000 scale_proxy:0.000000 est_model_bits:0 train_time:88399ms step_avg:88399.25ms
+"""
+The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good jumping-off points for new participants, not SOTA configs. We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder.
+
+Hard stop: to keep things readable for newcomers, let's make sure `train_gpt.py` and `train_gpt_mlx.py` never grow longer than 1500 lines.
+"""
+
+from __future__ import annotations
+
+import copy
+import glob
+import io
+import math
+import os
+import random
+import subprocess
+import sys
+import time
+import uuid
+import zlib
+from pathlib import Path
+
+import numpy as np
+import sentencepiece as spm
+import torch
+import torch.distributed as dist
+import torch.nn.functional as F
+from torch import Tensor, nn
+from torch.nn.parallel import DistributedDataParallel as DDP
+
+try:
+    import zstandard as zstd
+
+    _HAS_ZSTD = True
+except ImportError:
+    zstd = None
+    _HAS_ZSTD = False
+
+REPO_ROOT = Path(__file__).resolve().parents[3]
+
+# -----------------------------
+# HYPERPARAMETERS
+# -----------------------------
+# Default Simple Baseline run:
+# - 9 transformer blocks at width 512
+# - 8 attention heads with 4 KV heads (GQA) and 2x MLP expansion
+# - vocab size 1024, sequence length 1024, tied embeddings
+# - 524,288 train tokens per step for 20,000 iterations with a ~10 minute cap
+
+class Hyperparameters:
+    # Data paths are shard globs produced by the existing preprocessing pipeline.
+    data_path = os.environ.get("DATA_PATH", str(REPO_ROOT / "data/datasets/fineweb10B_sp1024"))
+    train_files = os.path.join(data_path, "fineweb_train_*.bin")
+    val_files = os.path.join(data_path, "fineweb_val_*.bin")
+    tokenizer_path = os.environ.get("TOKENIZER_PATH", str(REPO_ROOT / "data/tokenizers/fineweb_1024_bpe.model"))
+    run_id = os.environ.get("RUN_ID", "train")
+    seed = int(os.environ.get("SEED", 1337))
+
+    # Validation cadence and batch size. Validation uses the full fineweb_val split
+    # unless VAL_TOKEN_LIMIT caps the number of evaluated tokens.
+    val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 262_144))
+    val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000))
+    train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 10))
+    val_token_limit = int(os.environ.get("VAL_TOKEN_LIMIT", 1_048_576))
+
+    # Training length.
+    iterations = int(os.environ.get("ITERATIONS", 40))
+    warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 1200))
+    warmup_steps = int(os.environ.get("WARMUP_STEPS", 1))
+    train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 16_384))
+    train_microbatch_tokens = int(os.environ.get("TRAIN_MICROBATCH_TOKENS", 0))
+    grad_accum_steps = int(os.environ.get("GRAD_ACCUM_STEPS", 0))
+    train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 256))
+    max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0))
+    qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5))
+    enable_compile = bool(int(os.environ.get("ENABLE_COMPILE", "0")))
+
+    # Model shape.
+ vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers = int(os.environ.get("NUM_LAYERS", 9)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = int(os.environ.get("MLP_MULT", 2)) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + + # Optimizer hyperparameters. + embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.03)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.02)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.02)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.0)) + lr_scale_method = os.environ.get("LR_SCALE_METHOD", "none") + lr_reference_batch_tokens = int(os.environ.get("LR_REFERENCE_BATCH_TOKENS", 16_384)) + warmup_scale_method = os.environ.get("WARMUP_SCALE_METHOD", "none") + warmup_reference_batch_tokens = int(os.environ.get("WARMUP_REFERENCE_BATCH_TOKENS", 16_384)) + simulated_world_size = int(os.environ.get("SIMULATED_WORLD_SIZE", 0)) + simulated_rank = int(os.environ.get("SIMULATED_RANK", 0)) + dry_run_init_only = bool(int(os.environ.get("DRY_RUN_INIT_ONLY", "0"))) + debug_data_sharding_steps = 
int(os.environ.get("DEBUG_DATA_SHARDING_STEPS", 0)) + + # Export / compression. + quant_scheme = os.environ.get("QUANT_SCHEME", "int5_odd") + compressor = os.environ.get("COMPRESSOR", "zlib") + compressor_level = int(os.environ.get("COMPRESSOR_LEVEL", 9)) + + # Entropy-aware rate regularization. + rate_lambda = float(os.environ.get("RATE_LAMBDA", 0.00002)) + scale_lambda = float(os.environ.get("SCALE_LAMBDA", 0.0002)) + rate_temperature = float(os.environ.get("RATE_TEMPERATURE", 6.0)) + rate_warmup_start_fraction = float(os.environ.get("RATE_WARMUP_START_FRACTION", 0.1)) + rate_full_fraction = float(os.environ.get("RATE_FULL_FRACTION", 0.4)) + rate_target_tensors = os.environ.get("RATE_TARGET_TENSORS", "mlp_only") + + # Structured MLP. + structured_mlp = bool(int(os.environ.get("STRUCTURED_MLP", "1"))) + structured_linear_type = os.environ.get("STRUCTURED_LINEAR_TYPE", "btt_mlp") + btt_rank = int(os.environ.get("BTT_RANK", 32)) + btt_in_shape = os.environ.get("BTT_IN_SHAPE", "") + btt_out_shape = os.environ.get("BTT_OUT_SHAPE", "") + btt_init = os.environ.get("BTT_INIT", "mup") + btt_init_gain = float(os.environ.get("BTT_INIT_GAIN", 1.0)) + compile_structured_mlp = bool(int(os.environ.get("COMPILE_STRUCTURED_MLP", "1"))) + +# ----------------------------- +# MUON OPTIMIZER +# ----------------------------- +# +# As borrowed from modded-nanogpt +# Background on Muon: https://kellerjordan.github.io/posts/muon/ + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + # Orthogonalize a 2D update matrix with a fast Newton-Schulz iteration. + # Muon uses this to normalize matrix-shaped gradients before applying them. 
+ a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X + + +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True): + super().__init__( + params, + dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov), + ) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + + for group in self.param_groups: + params = group["params"] + if not params: + continue + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + + total_params = sum(int(p.numel()) for p in params) + updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) + + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: + g = g.add(buf, alpha=momentum) + g = zeropower_via_newtonschulz5(g, steps=backend_steps) + # Scale correction from Muon reference implementations. 
+                    g *= max(1, g.size(0) / g.size(1)) ** 0.5
+                    updates_flat[curr : curr + p.numel()] = g.reshape(-1)
+                curr += p.numel()
+
+            if distributed:
+                dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM)
+
+            curr = 0
+            for p in params:
+                g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype)
+                p.add_(g, alpha=-lr)
+                curr += p.numel()
+
+        return loss
+
+
+# -----------------------------
+# TOKENIZER-AGNOSTIC EVALUATION SETUP
+# -----------------------------
+#
+# It's common for small models to have a large fraction of their parameters in the embeddings, since the 2 * d_model * d_vocab embedding parameters can be gigantic.
+# Instead of locking the tokenizer, we let you bring your own and compute our validation metric as the average compression of the validation set.
+# We report BPB (bits-per-byte) instead of raw validation loss, so we need methods to count the number of bytes each token contributes.
+# Note: Submissions that edit the tokenizer will be examined more carefully, since screwing this up might unjustly improve your score.
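As a minimal standalone sketch (hypothetical token and byte counts, independent of the loaders here), the bits-per-byte metric is just bits-per-token rescaled by the corpus tokens-to-bytes ratio:

```python
import math

def bits_per_byte(mean_ce_loss_nats: float, num_tokens: int, num_bytes: int) -> float:
    # Cross-entropy in nats -> bits per token, then rescale by tokens per byte.
    bits_per_token = mean_ce_loss_nats / math.log(2.0)
    return bits_per_token * (num_tokens / num_bytes)

# A loss of ln(2) nats is exactly 1 bit/token; at 4 bytes/token that is 0.25 bpb.
print(bits_per_byte(math.log(2.0), 1000, 4000))  # -> 0.25
```

This is why the eval path accumulates both a token count and a byte count: a model can lower its token loss by switching to a coarser tokenizer, but bpb stays comparable across tokenizers.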
+ +def build_sentencepiece_luts( + sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device +) -> tuple[Tensor, Tensor, Tensor]: + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("▁"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) + + +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*. 
+ tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] + + +def eval_val( + args: Hyperparameters, + model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + grad_accum_steps: int, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, +) -> tuple[float, float]: + # Validation computes two metrics: + # - val_loss: token cross-entropy (natural log) + # - val_bpb: tokenizer-agnostic compression metric used by the challenge + local_batch_tokens = args.val_batch_size // world_size + if local_batch_tokens < args.train_seq_len: + raise ValueError( + "VAL_BATCH_SIZE must provide at least one sequence per rank; " + f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " + f"TRAIN_SEQ_LEN={args.train_seq_len}" + ) + local_batch_seqs = local_batch_tokens // args.train_seq_len + total_seqs = (val_tokens.numel() - 1) // args.train_seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + + model.eval() + with torch.inference_mode(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * args.train_seq_len + raw_end = batch_seq_end * args.train_seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, args.train_seq_len) + y = local[1:].reshape(-1, args.train_seq_len) + maybe_mark_compile_step_begin(args.enable_compile or 
                                             args.compile_structured_mlp)
+                with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True):
+                    batch_loss = model(x, y).detach()
+                batch_token_count = float(y.numel())
+                val_loss_sum += batch_loss.to(torch.float64) * batch_token_count
+                val_token_count += batch_token_count
+                prev_ids = x.reshape(-1)
+                tgt_ids = y.reshape(-1)
+                token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16)
+                token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16)
+                val_byte_count += token_bytes.to(torch.float64).sum()
+
+    if dist.is_available() and dist.is_initialized():
+        dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM)
+        dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM)
+        dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM)
+
+    val_loss = val_loss_sum / val_token_count
+    bits_per_token = val_loss.item() / math.log(2.0)
+    tokens_per_byte = val_token_count.item() / val_byte_count.item()
+    model.train()
+    return float(val_loss.item()), float(bits_per_token * tokens_per_byte)
+
+# -----------------------------
+# POST-TRAINING QUANTIZATION
+# -----------------------------
+#
+# It's silly to export our model, which is trained in bf16 and fp32, at that same precision.
+# Instead, we get approximately the same model (with a small quality hit) by quantizing it to int8 and zlib-compressing the result.
+# We can then decompress the model and run it at higher precision for evaluation, after coming in under the size limit.
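To make the idea concrete, here is a minimal standalone sketch of symmetric per-row int8 quantization and dequantization. This is illustrative only: it uses plain max-abs row scales, whereas the exporter below also percentile-clips outliers and stores the scales as fp16.

```python
import torch

def int8_roundtrip(w: torch.Tensor) -> torch.Tensor:
    # One scale per output row: q = round(w / scale), reconstruction = q * scale.
    scale = w.abs().amax(dim=1, keepdim=True).clamp_min(1e-8) / 127.0
    q = torch.clamp(torch.round(w / scale), -127, 127).to(torch.int8)
    return q.float() * scale

torch.manual_seed(0)
w = torch.randn(4, 16)
err = (int8_roundtrip(w) - w).abs().max().item()
# Rounding error is bounded by half a quantization step (scale / 2) in each row.
assert err <= (w.abs().amax().item() / 127.0) / 2 + 1e-6
```

Per-row scales matter because different output channels can have very different ranges; a single tensor-wide scale would waste most of the int8 grid on the quiet rows.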
+ +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights", + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 +INT5_ODD_LEVELS = (-2, -1, 0, 1, 2) + +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) + + +def maybe_mark_compile_step_begin(enabled: bool) -> None: + if not enabled: + return + mark_step = getattr(torch.compiler, "cudagraph_mark_step_begin", None) + if mark_step is not None: + mark_step() + + +def scale_from_batch_multiplier(value: float, multiplier: float, method: str) -> float: + if method == "none": + return value + if method == "sqrt": + return value * math.sqrt(multiplier) + if method == "linear": + return value * multiplier + raise ValueError(f"Unsupported scaling method: {method}") + + +def scaled_warmup_steps(base_steps: int, multiplier: float, method: str) -> int: + if base_steps <= 0: + return 0 + scaled = scale_from_batch_multiplier(float(base_steps), multiplier, method) + return max(int(math.ceil(scaled)), 1) + + +def parse_factor_shape(spec: str, dim: int) -> tuple[int, int]: + if spec: + vals = tuple(int(v.strip()) for v in spec.split(",") if v.strip()) + if len(vals) != 2 or vals[0] * vals[1] != dim: + raise ValueError(f"Invalid factor shape '{spec}' for dim={dim}") + return vals + for a in range(int(math.isqrt(dim)), 0, -1): + if dim % a == 0: + return a, dim // a + return 1, dim + +def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> 
Tensor: + if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): + return t.float().contiguous() + if t.dtype in {torch.float32, torch.bfloat16}: + passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") + return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() + return t + +def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + # Matrices get one scale per row, which usually tracks output-channel + # ranges much better than a single tensor-wide scale. + clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) + q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() + return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + + # Vectors / scalars use a simpler per-tensor scale. 
+    clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0
+    scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32)
+    q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous()
+    return q, scale
+
+
+def pack_base5_codes(q: Tensor) -> tuple[bytes, int]:
+    vals = q.detach().to("cpu", dtype=torch.uint8).reshape(-1).numpy()
+    n_vals = int(vals.size)
+    pad = (3 - (n_vals % 3)) % 3
+    if pad:
+        vals = np.pad(vals, (0, pad))
+    groups = vals.reshape(-1, 3).astype(np.uint8)
+    packed = groups[:, 0] + 5 * groups[:, 1] + 25 * groups[:, 2]
+    return packed.tobytes(), n_vals
+
+
+def unpack_base5_codes(data: bytes, n_vals: int) -> Tensor:
+    packed = np.frombuffer(data, dtype=np.uint8).astype(np.int16)
+    out = np.zeros((packed.size, 3), dtype=np.uint8)
+    for i in range(3):
+        out[:, i] = packed % 5
+        packed //= 5
+    flat = out.reshape(-1)[:n_vals]
+    return torch.from_numpy(flat.astype(np.uint8, copy=False))
+
+
+def quantize_float_tensor_int5_odd(t: Tensor) -> tuple[Tensor, Tensor]:
+    if t.ndim != 2:
+        raise ValueError("int5_odd quantization only supports 2D tensors")
+    t32 = t.float()
+    clip_abs = (
+        torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1)
+        if t32.numel()
+        else torch.empty((t32.shape[0],), dtype=torch.float32)
+    )
+    clip_abs = clip_abs.clamp_min(1e-6)
+    clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None])
+    scale = (clip_abs / 2.0).clamp_min(1.0 / 2.0e6)
+    q = torch.clamp(torch.round(clipped / scale[:, None]), -2, 2).to(torch.int8).contiguous()
+    return (q + 2).to(torch.uint8).contiguous(), scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous()
+
+
+def quantize_state_dict_int8(state_dict: dict[str, Tensor]):
+    # Single supported clean-script export format:
+    #   - per-row int8 for 2D float tensors
+    #   - per-tensor int8 for other float tensors
+    #   - exact passthrough for non-floats
+    #   - passthrough for small float tensors, stored as fp16 to save bytes
+    quantized: dict[str, Tensor] = {}
+    scales: dict[str, Tensor] = {}
+    dtypes: dict[str, str] = {}
+    passthrough: dict[str, Tensor] = {}
+    passthrough_orig_dtypes: dict[str, str] = {}
+    qmeta: dict[str, dict[str, object]] = {}
+    stats = dict.fromkeys(
+        ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"),
+        0,
+    )
+
+    for name, tensor in state_dict.items():
+        t = tensor.detach().to("cpu").contiguous()
+        stats["param_count"] += int(t.numel())
+        stats["num_tensors"] += 1
+        stats["baseline_tensor_bytes"] += tensor_nbytes(t)
+
+        if not t.is_floating_point():
+            stats["num_nonfloat_tensors"] += 1
+            passthrough[name] = t
+            stats["int8_payload_bytes"] += tensor_nbytes(t)
+            continue
+
+        # Small float tensors are cheap enough to keep directly. We still downcast
+        # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size.
+        if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL:
+            kept = keep_float_tensor(name, t, passthrough_orig_dtypes)
+            passthrough[name] = kept
+            stats["int8_payload_bytes"] += tensor_nbytes(kept)
+            continue
+
+        stats["num_float_tensors"] += 1
+        q, s = quantize_float_tensor(t)
+        if s.ndim > 0:
+            qmeta[name] = {"scheme": "per_row", "axis": 0}
+        quantized[name] = q
+        scales[name] = s
+        dtypes[name] = str(t.dtype).removeprefix("torch.")
+        stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s)
+
+    obj: dict[str, object] = {
+        "__quant_format__": "int8_clean_per_row_v1",
+        "quantized": quantized,
+        "scales": scales,
+        "dtypes": dtypes,
+        "passthrough": passthrough,
+    }
+    if qmeta:
+        obj["qmeta"] = qmeta
+    if passthrough_orig_dtypes:
+        obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes
+    return obj, stats
+
+
+def quantize_state_dict_int5_odd(state_dict: dict[str, Tensor]):
+    packed: dict[str, bytes] = {}
+    scales: dict[str, Tensor] = {}
+    dtypes: dict[str, str] = {}
+    shapes: dict[str, tuple[int, ...]] = {}
+    orig_shapes: dict[str, tuple[int, ...]] = {}
+    n_codes: dict[str, int] = {}
+    passthrough: dict[str, Tensor] = {}
+    passthrough_orig_dtypes: dict[str, str] = {}
+    stats = dict.fromkeys(
+        ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"),
+        0,
+    )
+
+    for name, tensor in state_dict.items():
+        t = tensor.detach().to("cpu").contiguous()
+        stats["param_count"] += int(t.numel())
+        stats["num_tensors"] += 1
+        stats["baseline_tensor_bytes"] += tensor_nbytes(t)
+
+        if not t.is_floating_point():
+            stats["num_nonfloat_tensors"] += 1
+            passthrough[name] = t
+            stats["int8_payload_bytes"] += tensor_nbytes(t)
+            continue
+
+        quantize_this = t.ndim >= 2 and (t.numel() > INT8_KEEP_FLOAT_MAX_NUMEL or ".mlp." in name)
+        if not quantize_this:
+            kept = keep_float_tensor(name, t, passthrough_orig_dtypes)
+            passthrough[name] = kept
+            stats["int8_payload_bytes"] += tensor_nbytes(kept)
+            continue
+
+        stats["num_float_tensors"] += 1
+        t2 = t.reshape(t.shape[0], -1) if t.ndim > 2 else t
+        q, s = quantize_float_tensor_int5_odd(t2)
+        packed_bytes, count = pack_base5_codes(q)
+        packed[name] = packed_bytes
+        scales[name] = s
+        dtypes[name] = str(t.dtype).removeprefix("torch.")
+        shapes[name] = tuple(int(v) for v in t2.shape)
+        orig_shapes[name] = tuple(int(v) for v in t.shape)
+        n_codes[name] = count
+        stats["int8_payload_bytes"] += len(packed_bytes) + tensor_nbytes(s)
+
+    obj: dict[str, object] = {
+        "__quant_format__": "int5_odd_clean_per_row_v1",
+        "packed": packed,
+        "scales": scales,
+        "dtypes": dtypes,
+        "shapes": shapes,
+        "orig_shapes": orig_shapes,
+        "n_codes": n_codes,
+        "passthrough": passthrough,
+    }
+    if passthrough_orig_dtypes:
+        obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes
+    return obj, stats
+
+
+def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]:
+    out: dict[str, Tensor] = {}
+    qmeta = obj.get("qmeta", {})
+    passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {})
+    for name, q in obj["quantized"].items():
+        dtype = getattr(torch, obj["dtypes"][name])
+        s = obj["scales"][name]
+        if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0:
+            s = s.to(dtype=torch.float32)
+            # Broadcast the saved row scale back across trailing dimensions.
+            out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous()
+        else:
+            scale = float(s.item())
+            out[name] = (q.float() * scale).to(dtype=dtype).contiguous()
+    for name, t in obj["passthrough"].items():
+        # Restore small tensors, undoing the temporary fp16 storage cast if needed.
+        out_t = t.detach().to("cpu").contiguous()
+        orig_dtype = passthrough_orig_dtypes.get(name)
+        if isinstance(orig_dtype, str):
+            out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous()
+        out[name] = out_t
+    return out
+
+
+def dequantize_state_dict_int5_odd(obj: dict[str, object]) -> dict[str, Tensor]:
+    out: dict[str, Tensor] = {}
+    passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {})
+    for name, packed in obj["packed"].items():
+        dtype = getattr(torch, obj["dtypes"][name])
+        shape = tuple(int(v) for v in obj["shapes"][name])
+        orig_shape = tuple(int(v) for v in obj.get("orig_shapes", {}).get(name, shape))
+        q = unpack_base5_codes(packed, int(obj["n_codes"][name])).to(torch.float32).reshape(shape) - 2.0
+        s = obj["scales"][name].to(dtype=torch.float32)
+        deq = (q * s.view(shape[0], *([1] * (len(shape) - 1)))).to(dtype=dtype)
+        out[name] = deq.reshape(orig_shape).contiguous()
+    for name, t in obj["passthrough"].items():
+        out_t = t.detach().to("cpu").contiguous()
+        orig_dtype = passthrough_orig_dtypes.get(name)
+        if isinstance(orig_dtype, str):
+            out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous()
+        out[name] = out_t
+    return out
+
+
+def quantize_state_dict(state_dict: dict[str, Tensor], quant_scheme: str):
+    if quant_scheme == "int8":
+        return quantize_state_dict_int8(state_dict)
+    if quant_scheme == "int5_odd":
+        return quantize_state_dict_int5_odd(state_dict)
+    raise ValueError(f"Unsupported QUANT_SCHEME={quant_scheme}")
+
+
+def dequantize_state_dict(obj: dict[str, object]) -> dict[str, Tensor]:
+    quant_format = obj.get("__quant_format__")
+    if quant_format == "int8_clean_per_row_v1":
+        return dequantize_state_dict_int8(obj)
+    if quant_format == "int5_odd_clean_per_row_v1":
+        return dequantize_state_dict_int5_odd(obj)
+    raise ValueError(f"Unsupported quant format {quant_format}")
+
+
+def compress_blob(data: bytes, compressor: str, level: int) -> bytes:
+    if compressor == "zlib":
+        return zlib.compress(data, level=max(0, min(level, 9)))
+    if compressor == "zstd":
+        if not _HAS_ZSTD:
+            raise RuntimeError("COMPRESSOR=zstd requires zstandard to be installed")
+        return zstd.ZstdCompressor(level=level).compress(data)
+    raise ValueError(f"Unsupported COMPRESSOR={compressor}")
+
+
+def decompress_blob(data: bytes, compressor: str) -> bytes:
+    if compressor == "zlib":
+        return zlib.decompress(data)
+    if compressor == "zstd":
+        if not _HAS_ZSTD:
+            raise RuntimeError("COMPRESSOR=zstd requires zstandard to be installed")
+        return zstd.ZstdDecompressor().decompress(data)
+    raise ValueError(f"Unsupported COMPRESSOR={compressor}")
+
+
+def should_rate_regularize_tensor(name: str, p: Tensor, target: str) -> bool:
+    if not p.requires_grad or not p.is_floating_point() or p.ndim < 2:
+        return False
+    if any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS):
+        return False
+    if target == "mlp_only":
+        return ".mlp." in name
+    if target == "all_matrices":
+        return p.numel() > INT8_KEEP_FLOAT_MAX_NUMEL or ".mlp." in name
+    raise ValueError(f"Unsupported RATE_TARGET_TENSORS={target}")
+
+
+def estimate_tensor_rate_proxy(t: Tensor, temperature: float) -> tuple[Tensor, Tensor, Tensor]:
+    t2 = t.float()
+    if t2.ndim > 2:
+        t2 = t2.reshape(t2.shape[0], -1)
+    elif t2.ndim == 1:
+        t2 = t2.reshape(1, -1)
+    row_scale = t2.abs().mean(dim=1, keepdim=True).clamp_min(1e-6)
+    normalized = torch.clamp(t2 / row_scale, -2.5, 2.5)
+    centers = t2.new_tensor(INT5_ODD_LEVELS)
+    logits = -temperature * (normalized.unsqueeze(-1) - centers) ** 2
+    probs = torch.softmax(logits, dim=-1)
+    hist = probs.mean(dim=(0, 1))
+    entropy_nats = -(hist * (hist + 1e-8).log()).sum()
+    entropy_bits = entropy_nats / math.log(2.0)
+    scale_proxy = row_scale.mean()
+    est_bits = entropy_bits * float(t2.numel())
+    return entropy_bits, scale_proxy, est_bits
+
+
+# -----------------------------
+# DATA LOADING
+# -----------------------------
+
+def load_data_shard(file: Path) -> Tensor:
+    header_bytes = 256 * np.dtype(" None:
+        self.file_idx = (self.file_idx + 1) % len(self.files)
+        self.tokens = load_data_shard(self.files[self.file_idx])
+        self.pos = 0
+
+    def take(self, n: int) -> Tensor:
+        chunks: list[Tensor] = []
+        remaining = n
+        while remaining > 0:
+            avail = self.tokens.numel() - self.pos
+            if avail <= 0:
+                self._advance_file()
+                continue
+            k = min(remaining, avail)
+            chunks.append(self.tokens[self.pos : self.pos + k])
+            self.pos += k
+            remaining -= k
+        return chunks[0] if len(chunks) == 1 else torch.cat(chunks)
+
+
+class DistributedTokenLoader:
+    # Each call consumes a contiguous chunk from the shared token stream, then slices out
+    # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting.
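+    # Illustrative example (numbers assumed here, not taken from a logged run):
+    # with global_tokens=131072, world_size=8, grad_accum_steps=1, each rank gets
+    # local_tokens = 131072 // 8 = 16384 and a per-rank span of 16385 tokens, so
+    # rank 0 slices [0, 16385) and rank 1 slices [16385, 32770) out of the
+    # 131080-token shared chunk; x/y pairs then come from local[:-1] / local[1:].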
+    def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device):
+        self.rank = rank
+        self.world_size = world_size
+        self.device = device
+        self.stream = TokenStream(pattern)
+
+    def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]:
+        local_tokens = global_tokens // (self.world_size * grad_accum_steps)
+        per_rank_span = local_tokens + 1
+        chunk = self.stream.take(per_rank_span * self.world_size)
+        start = self.rank * per_rank_span
+        local = chunk[start : start + per_rank_span].to(dtype=torch.int64)
+        x = local[:-1].reshape(-1, seq_len)
+        y = local[1:].reshape(-1, seq_len)
+        return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True)
+
+
+def debug_log_data_sharding(
+    pattern: str,
+    world_size: int,
+    global_tokens: int,
+    seq_len: int,
+    grad_accum_steps: int,
+    steps: int,
+    log0,
+) -> None:
+    stream = TokenStream(pattern)
+    local_tokens = global_tokens // (world_size * grad_accum_steps)
+    per_rank_span = local_tokens + 1
+    global_cursor = 0
+    log0(
+        f"debug_sharding:world_size:{world_size} global_tokens:{global_tokens} "
+        f"grad_accum_steps:{grad_accum_steps} local_tokens_per_rank:{local_tokens} per_rank_span:{per_rank_span}"
+    )
+    for step_idx in range(steps):
+        chunk = stream.take(per_rank_span * world_size)
+        chunk_start = global_cursor
+        chunk_end = global_cursor + chunk.numel()
+        log0(f"debug_sharding_step:{step_idx} shared_chunk_token_range:[{chunk_start},{chunk_end})")
+        for shard_rank in range(world_size):
+            start = shard_rank * per_rank_span
+            end = start + per_rank_span
+            local = chunk[start:end]
+            preview = ",".join(str(int(t)) for t in local[: min(6, local.numel())].tolist())
+            log0(
+                f"debug_shard rank:{shard_rank} token_range:[{chunk_start + start},{chunk_start + end}) "
+                f"preview:{preview}"
+            )
+        global_cursor = chunk_end
+
+
+# -----------------------------
+# TRANSFORMER MODULES
+# -----------------------------
+
+class RMSNorm(nn.Module):
+    def __init__(self, eps: float | None = None):
+        super().__init__()
+        self.eps = eps
+
+    def forward(self, x: Tensor) -> Tensor:
+        return F.rms_norm(x, (x.size(-1),), eps=self.eps)
+
+
+class CastedLinear(nn.Linear):
+    # Keep weights in fp32 for optimizer/state quality, cast at matmul time for bf16 compute.
+    def forward(self, x: Tensor) -> Tensor:
+        bias = self.bias.to(x.dtype) if self.bias is not None else None
+        return F.linear(x, self.weight.to(x.dtype), bias)
+
+
+class StructuredLinear(nn.Module):
+    # Lightweight TT-style structured matrix used only in the MLP stack.
+    def __init__(
+        self,
+        in_features: int,
+        out_features: int,
+        bias: bool,
+        linear_type: str,
+        btt_rank: int,
+        in_shape_spec: str,
+        out_shape_spec: str,
+        btt_init: str,
+        btt_init_gain: float,
+    ):
+        super().__init__()
+        self.in_features = in_features
+        self.out_features = out_features
+        self.linear_type = linear_type
+        self.rank = max(int(btt_rank), 1)
+        self.btt_init = btt_init
+        self.btt_init_gain = btt_init_gain
+        if linear_type == "dense":
+            self.impl = CastedLinear(in_features, out_features, bias=bias)
+            self.register_parameter("core1", None)
+            self.register_parameter("core2", None)
+            return
+        if linear_type != "btt_mlp":
+            raise ValueError(f"Unsupported STRUCTURED_LINEAR_TYPE={linear_type}")
+        self.impl = None
+        self.in_shape = parse_factor_shape(in_shape_spec, in_features)
+        self.out_shape = parse_factor_shape(out_shape_spec, out_features)
+        n1, n2 = self.in_shape
+        m1, m2 = self.out_shape
+        self.core1 = nn.Parameter(torch.empty(n1, m1, self.rank))
+        self.core2 = nn.Parameter(torch.empty(self.rank, n2, m2))
+        if bias:
+            self.bias = nn.Parameter(torch.zeros(out_features))
+        else:
+            self.register_parameter("bias", None)
+        self.reset_parameters()
+
+    def reset_parameters(self) -> None:
+        if self.impl is not None:
+            self.impl.reset_parameters()
+            return
+        n1, n2 = self.in_shape
+        if self.btt_init == "mup":
+            # muP-inspired scaling: keep the per-rank branch update stable while
+            # preserving O(1) output variance under the sum over TT ranks.
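+    # Illustrative shapes (values assumed for this comment, not fixed by the
+    # config): for head_dim=64 the inv_freq buffer holds 32 inverse frequencies
+    # 10000**(-2i/64) for i in 0..31, and the cached cos/sin tables have shape
+    # (1, 1, seq_len, 32); they are re-derived whenever seq_len or device changes.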
+    def __init__(self, dim: int, base: float = 10000.0):
+        super().__init__()
+        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
+        self.register_buffer("inv_freq", inv_freq, persistent=False)
+        self._seq_len_cached = 0
+        self._cos_cached: Tensor | None = None
+        self._sin_cached: Tensor | None = None
+
+    def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]:
+        if (
+            self._cos_cached is None
+            or self._sin_cached is None
+            or self._seq_len_cached != seq_len
+            or self._cos_cached.device != device
+        ):
+            t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype)
+            freqs = torch.outer(t, self.inv_freq.to(device))
+            self._cos_cached = freqs.cos()[None, None, :, :]
+            self._sin_cached = freqs.sin()[None, None, :, :]
+            self._seq_len_cached = seq_len
+        cos = self._cos_cached.to(dtype=dtype)
+        sin = self._sin_cached.to(dtype=dtype)
+        if not torch.is_inference_mode_enabled():
+            cos = cos.clone()
+            sin = sin.clone()
+        return cos, sin
+
+
+def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor:
+    half = x.size(-1) // 2
+    x1, x2 = x[..., :half], x[..., half:]
+    return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1)
+
+
+class CausalSelfAttention(nn.Module):
+    def __init__(
+        self,
+        dim: int,
+        num_heads: int,
+        num_kv_heads: int,
+        rope_base: float,
+        qk_gain_init: float,
+    ):
+        super().__init__()
+        if dim % num_heads != 0:
+            raise ValueError("model_dim must be divisible by num_heads")
+        if num_heads % num_kv_heads != 0:
+            raise ValueError("num_heads must be divisible by num_kv_heads")
+        self.num_heads = num_heads
+        self.num_kv_heads = num_kv_heads
+        self.head_dim = dim // num_heads
+        if self.head_dim % 2 != 0:
+            raise ValueError("head_dim must be even for RoPE")
+        kv_dim = self.num_kv_heads * self.head_dim
+        self.c_q = CastedLinear(dim, dim, bias=False)
+        self.c_k = CastedLinear(dim, kv_dim, bias=False)
+        self.c_v = CastedLinear(dim, kv_dim, bias=False)
+        self.proj = CastedLinear(dim, dim, bias=False)
+        self.proj._zero_init = True
+        self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32))
+        self.rotary = Rotary(self.head_dim, base=rope_base)
+
+    def forward(self, x: Tensor) -> Tensor:
+        bsz, seqlen, dim = x.shape
+        q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2)
+        k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2)
+        v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2)
+        q = F.rms_norm(q, (q.size(-1),))
+        k = F.rms_norm(k, (k.size(-1),))
+        cos, sin = self.rotary(seqlen, x.device, q.dtype)
+        q = apply_rotary_emb(q, cos, sin)
+        k = apply_rotary_emb(k, cos, sin)
+        q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None]
+        y = F.scaled_dot_product_attention(
+            q,
+            k,
+            v,
+            attn_mask=None,
+            is_causal=True,
+            enable_gqa=(self.num_kv_heads != self.num_heads),
+        )
+        y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim)
+        return self.proj(y)
+
+
+class MLP(nn.Module):
+    # relu^2 MLP from the original modded-nanogpt setup
+    def __init__(
+        self,
+        dim: int,
+        mlp_mult: int,
+        structured_mlp: bool,
+        structured_linear_type: str,
+        btt_rank: int,
+        btt_in_shape: str,
+        btt_out_shape: str,
+        btt_init: str,
+        btt_init_gain: float,
+        compile_structured_mlp: bool,
+    ):
+        super().__init__()
+        hidden = mlp_mult * dim
+        if structured_mlp:
+            self.fc = StructuredLinear(
+                dim,
+                hidden,
+                bias=False,
+                linear_type=structured_linear_type,
+                btt_rank=btt_rank,
+                in_shape_spec=btt_in_shape,
+                out_shape_spec=btt_out_shape,
+                btt_init=btt_init,
+                btt_init_gain=btt_init_gain,
+            )
+            self.proj = StructuredLinear(
+                hidden,
+                dim,
+                bias=False,
+                linear_type=structured_linear_type,
+                btt_rank=btt_rank,
+                in_shape_spec=btt_out_shape,
+                out_shape_spec=btt_in_shape,
+                btt_init=btt_init,
+                btt_init_gain=btt_init_gain,
+            )
+            if compile_structured_mlp:
+                compile_options = dict(torch._inductor.list_mode_options("reduce-overhead"))
+                compile_options["triton.cudagraphs"] = False
+                self.fc = torch.compile(
+                    self.fc,
+                    options=compile_options,
+                    fullgraph=False,
+                    dynamic=False,
+                )
+                self.proj = torch.compile(
+                    self.proj,
+                    options=compile_options,
+                    fullgraph=False,
+                    dynamic=False,
+                )
+        else:
+            self.fc = CastedLinear(dim, hidden, bias=False)
+            self.proj = CastedLinear(hidden, dim, bias=False)
+        self.proj._zero_init = True
+
+    def forward(self, x: Tensor) -> Tensor:
+        x = torch.relu(self.fc(x))
+        return self.proj(x.square())
+
+
+class Block(nn.Module):
+    def __init__(
+        self,
+        dim: int,
+        num_heads: int,
+        num_kv_heads: int,
+        mlp_mult: int,
+        rope_base: float,
+        qk_gain_init: float,
+        structured_mlp: bool,
+        structured_linear_type: str,
+        btt_rank: int,
+        btt_in_shape: str,
+        btt_out_shape: str,
+        btt_init: str,
+        btt_init_gain: float,
+        compile_structured_mlp: bool,
+    ):
+        super().__init__()
+        self.attn_norm = RMSNorm()
+        self.mlp_norm = RMSNorm()
+        self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init)
+        self.mlp = MLP(
+            dim,
+            mlp_mult,
+            structured_mlp,
+            structured_linear_type,
+            btt_rank,
+            btt_in_shape,
+            btt_out_shape,
+            btt_init,
+            btt_init_gain,
+            compile_structured_mlp,
+        )
+        self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32))
+        self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32))
+        self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float())
+
+    def forward(self, x: Tensor, x0: Tensor) -> Tensor:
+        mix = self.resid_mix.to(dtype=x.dtype)
+        x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0
+        attn_out = self.attn(self.attn_norm(x))
+        x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out
+        x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x))
+        return x
+
+
+class GPT(nn.Module):
+    def __init__(
+        self,
+        vocab_size: int,
+        num_layers: int,
+        model_dim: int,
+        num_heads: int,
+        num_kv_heads: int,
+        mlp_mult: int,
+        tie_embeddings: bool,
+        tied_embed_init_std: float,
+        logit_softcap: float,
+        rope_base: float,
+        qk_gain_init: float,
+        structured_mlp: bool,
+        structured_linear_type: str,
+        btt_rank: int,
+        btt_in_shape: str,
+        btt_out_shape: str,
+        btt_init: str,
+        btt_init_gain: float,
+        compile_structured_mlp: bool,
+        rate_target_tensors: str,
+        rate_temperature: float,
+    ):
+        super().__init__()
+        if logit_softcap <= 0.0:
+            raise ValueError(f"logit_softcap must be positive, got {logit_softcap}")
+        self.tie_embeddings = tie_embeddings
+        self.tied_embed_init_std = tied_embed_init_std
+        self.logit_softcap = logit_softcap
+        self.rate_target_tensors = rate_target_tensors
+        self.rate_temperature = rate_temperature
+        self.tok_emb = nn.Embedding(vocab_size, model_dim)
+        self.num_encoder_layers = num_layers // 2
+        self.num_decoder_layers = num_layers - self.num_encoder_layers
+        self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers)
+        self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32))
+        self.blocks = nn.ModuleList(
+            [
+                Block(
+                    model_dim,
+                    num_heads,
+                    num_kv_heads,
+                    mlp_mult,
+                    rope_base,
+                    qk_gain_init,
+                    structured_mlp,
+                    structured_linear_type,
+                    btt_rank,
+                    btt_in_shape,
+                    btt_out_shape,
+                    btt_init,
+                    btt_init_gain,
+                    compile_structured_mlp,
+                )
+                for i in range(num_layers)
+            ]
+        )
+        self.final_norm = RMSNorm()
+        self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False)
+        if self.lm_head is not None:
+            self.lm_head._zero_init = True
+        self._init_weights()
+
+    def _init_weights(self) -> None:
+        if self.tie_embeddings:
+            nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std)
+        for module in self.modules():
+            if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False):
+                nn.init.zeros_(module.weight)
+            if isinstance(module, StructuredLinear) and getattr(module, "_zero_init", False):
+                if module.impl is not None:
+                    nn.init.zeros_(module.impl.weight)
+                else:
+                    nn.init.zeros_(module.core2)
+
+    def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor:
+        x = self.tok_emb(input_ids)
+        x = F.rms_norm(x, (x.size(-1),))
+        x0 = x
+        skips: list[Tensor] = []
+
+        # First half stores skips; second half reuses them in reverse order.
+        for i in range(self.num_encoder_layers):
+            x = self.blocks[i](x, x0)
+            skips.append(x)
+        for i in range(self.num_decoder_layers):
+            if skips:
+                x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop()
+            x = self.blocks[self.num_encoder_layers + i](x, x0)
+
+        x = self.final_norm(x).reshape(-1, x.size(-1))
+        targets = target_ids.reshape(-1)
+        if self.tie_embeddings:
+            logits_proj = F.linear(x, self.tok_emb.weight)
+        else:
+            if self.lm_head is None:
+                raise RuntimeError("lm_head is required when tie_embeddings=False")
+            logits_proj = self.lm_head(x)
+        logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap)
+        return F.cross_entropy(logits.float(), targets, reduction="mean")
+
+    def rate_regularizer(self) -> tuple[Tensor, Tensor, Tensor]:
+        total_symbols = 0.0
+        weighted_rate = None
+        weighted_scale = None
+        estimated_bits = None
+        for name, param in self.named_parameters():
+            if not should_rate_regularize_tensor(name, param, self.rate_target_tensors):
+                continue
+            rate_proxy, scale_proxy, tensor_bits = estimate_tensor_rate_proxy(param, self.rate_temperature)
+            weight = float(param.numel())
+            total_symbols += weight
+            if weighted_rate is None:
+                weighted_rate = rate_proxy * weight
+                weighted_scale = scale_proxy * weight
+                estimated_bits = tensor_bits
+            else:
+                weighted_rate = weighted_rate + rate_proxy * weight
+                weighted_scale = weighted_scale + scale_proxy * weight
+                estimated_bits = estimated_bits + tensor_bits
+        if weighted_rate is None or weighted_scale is None or estimated_bits is None:
+            zero = self.tok_emb.weight.new_zeros(())
+            return zero, zero, zero
+        denom = max(total_symbols, 1.0)
+        return weighted_rate / denom, weighted_scale / denom, estimated_bits
+
+
+# -----------------------------
+# TRAINING
+# -----------------------------
+
+def main() -> None:
+    global zeropower_via_newtonschulz5
+
+    code = Path(__file__).read_text(encoding="utf-8")
+    args = Hyperparameters()
+    if args.enable_compile:
+        zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5)
+
+    # -----------------------------
+    # DISTRIBUTED + CUDA SETUP
+    # -----------------------------
+
+    distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ
+    rank = int(os.environ.get("RANK", "0"))
+    world_size = int(os.environ.get("WORLD_SIZE", "1"))
+    local_rank = int(os.environ.get("LOCAL_RANK", "0"))
+    simulated_world_size = args.simulated_world_size if not distributed and args.simulated_world_size > 0 else 0
+    if simulated_world_size > 0:
+        world_size = simulated_world_size
+        rank = args.simulated_rank
+        local_rank = 0
+    if world_size <= 0:
+        raise ValueError(f"WORLD_SIZE must be positive, got {world_size}")
+    if rank < 0 or rank >= world_size:
+        raise ValueError(f"RANK={rank} must satisfy 0 <= RANK < WORLD_SIZE={world_size}")
+    if simulated_world_size > 1 and not args.dry_run_init_only:
+        raise ValueError("SIMULATED_WORLD_SIZE>1 is only supported together with DRY_RUN_INIT_ONLY=1")
+    if args.train_batch_tokens <= 0:
+        raise ValueError(f"TRAIN_BATCH_TOKENS must be positive, got {args.train_batch_tokens}")
+    train_microbatch_tokens = args.train_microbatch_tokens if args.train_microbatch_tokens > 0 else max(
+        args.train_seq_len, args.train_batch_tokens // 8
+    )
+    if train_microbatch_tokens <= 0:
+        raise ValueError(f"TRAIN_MICROBATCH_TOKENS must be positive, got {train_microbatch_tokens}")
+    if train_microbatch_tokens % args.train_seq_len != 0:
+        raise ValueError(
+            f"TRAIN_MICROBATCH_TOKENS={train_microbatch_tokens} must be divisible by TRAIN_SEQ_LEN={args.train_seq_len}"
+        )
+    tokens_per_microstep = world_size * train_microbatch_tokens
+    if args.grad_accum_steps > 0:
+        grad_accum_steps = args.grad_accum_steps
+        effective_train_batch_tokens = tokens_per_microstep * grad_accum_steps
+    else:
+        if args.train_batch_tokens % tokens_per_microstep != 0:
+            raise ValueError(
+                "TRAIN_BATCH_TOKENS must be divisible by WORLD_SIZE * TRAIN_MICROBATCH_TOKENS "
+                f"unless GRAD_ACCUM_STEPS is set explicitly; got TRAIN_BATCH_TOKENS={args.train_batch_tokens}, "
+                f"WORLD_SIZE={world_size}, TRAIN_MICROBATCH_TOKENS={train_microbatch_tokens}"
+            )
+        grad_accum_steps = args.train_batch_tokens // tokens_per_microstep
+        effective_train_batch_tokens = args.train_batch_tokens
+    if grad_accum_steps <= 0:
+        raise ValueError(f"GRAD_ACCUM_STEPS must be positive, got {grad_accum_steps}")
+    grad_scale = 1.0 / grad_accum_steps
+    batch_multiplier = effective_train_batch_tokens / max(args.lr_reference_batch_tokens, 1)
+    effective_warmup_steps = scaled_warmup_steps(
+        args.warmup_steps,
+        effective_train_batch_tokens / max(args.warmup_reference_batch_tokens, 1),
+        args.warmup_scale_method,
+    )
+    if not torch.cuda.is_available():
+        raise RuntimeError("CUDA is required")
+    device = torch.device("cuda", local_rank)
+    torch.cuda.set_device(device)
+    if distributed:
+        dist.init_process_group(backend="nccl", device_id=device)
+        dist.barrier()
+    master_process = rank == 0
+
+    # Fast math knobs
+    torch.backends.cuda.matmul.allow_tf32 = True
+    torch.backends.cudnn.allow_tf32 = True
+    from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp
+
+    enable_cudnn_sdp(False)
+    enable_flash_sdp(True)
+    enable_mem_efficient_sdp(False)
+    enable_math_sdp(False)
+
+    logfile = None
+    if master_process:
+        if args.run_id == "train":
+            logfile = "train.log"
+        else:
+            os.makedirs("logs", exist_ok=True)
+            logfile = f"logs/{args.run_id}.txt"
+        print(logfile)
+
+    def log0(msg: str, console: bool = True) -> None:
+        if not master_process:
+            return
+        if console:
+            print(msg)
+        if logfile is not None:
+            with open(logfile, "a", encoding="utf-8") as f:
+                print(msg, file=f)
+
+    log0(code, console=False)
+    log0("=" * 100, console=False)
+    log0(f"Running Python {sys.version}", console=False)
+    log0(f"Running PyTorch {torch.__version__}", console=False)
+    log0(
+        subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout,
+        console=False,
+    )
+    log0("=" * 100, console=False)
+
+    # -----------------------------
+    # TOKENIZER + VALIDATION METRIC SETUP
+    # -----------------------------
+
+    random.seed(args.seed)
+    np.random.seed(args.seed)
+    torch.manual_seed(args.seed)
+    torch.cuda.manual_seed_all(args.seed)
+
+    if not args.tokenizer_path.endswith(".model"):
+        raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}")
+    sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path)
+    if int(sp.vocab_size()) != args.vocab_size:
+        raise ValueError(
+            f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}"
+        )
+    dataset_dir = Path(args.data_path).resolve()
+    actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin")))
+    val_tokens = load_validation_tokens(args.val_files, args.train_seq_len)
+    if args.val_token_limit > 0:
+        usable = min((args.val_token_limit // args.train_seq_len) * args.train_seq_len, val_tokens.numel() - 1)
+        if usable <= 0:
+            raise ValueError(f"VAL_TOKEN_LIMIT={args.val_token_limit} is too small for TRAIN_SEQ_LEN={args.train_seq_len}")
+        val_tokens = val_tokens[: usable + 1].contiguous()
+    base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts(
+        sp, args.vocab_size, device
+    )
+    log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}")
+    log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}")
+    log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}")
+    if args.val_token_limit > 0:
+        log0(f"val_token_limit:{args.val_token_limit}")
+
+    # -----------------------------
+    # MODEL + OPTIMIZER SETUP
+    # -----------------------------
+
+    base_model = GPT(
+        vocab_size=args.vocab_size,
+        num_layers=args.num_layers,
+        model_dim=args.model_dim,
+        num_heads=args.num_heads,
+        num_kv_heads=args.num_kv_heads,
+        mlp_mult=args.mlp_mult,
+        tie_embeddings=args.tie_embeddings,
+        tied_embed_init_std=args.tied_embed_init_std,
+        logit_softcap=args.logit_softcap,
+        rope_base=args.rope_base,
+        qk_gain_init=args.qk_gain_init,
+        structured_mlp=args.structured_mlp,
+        structured_linear_type=args.structured_linear_type if args.structured_mlp else "dense",
+        btt_rank=args.btt_rank,
+        btt_in_shape=args.btt_in_shape,
+        btt_out_shape=args.btt_out_shape,
+        btt_init=args.btt_init,
+        btt_init_gain=args.btt_init_gain,
+        compile_structured_mlp=args.compile_structured_mlp,
+        rate_target_tensors=args.rate_target_tensors,
+        rate_temperature=args.rate_temperature,
+    ).to(device).bfloat16()
+    for module in base_model.modules():
+        if isinstance(module, (CastedLinear, StructuredLinear)):
+            module.float()
+    restore_low_dim_params_to_fp32(base_model)
+    compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) if args.enable_compile else base_model
+    model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model
+
+    # Optimizer split:
+    #   - token embedding (Adam) uses EMBED_LR
+    #   - untied lm_head (Adam) uses HEAD_LR
+    #   - matrix params in transformer blocks use MATRIX_LR via Muon
+    #   - vectors/scalars use SCALAR_LR via Adam
+    block_named_params = list(base_model.blocks.named_parameters())
+    matrix_params = [
+        p
+        for name, p in block_named_params
+        if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)
+    ]
+    scalar_params = [
+        p
+        for name, p in block_named_params
+        if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)
+    ]
+    if base_model.skip_weights.numel() > 0:
+        scalar_params.append(base_model.skip_weights)
+    token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr
+    scaled_token_lr = scale_from_batch_multiplier(token_lr, batch_multiplier, args.lr_scale_method)
+    scaled_head_lr = scale_from_batch_multiplier(args.head_lr, batch_multiplier, args.lr_scale_method)
+    scaled_matrix_lr = scale_from_batch_multiplier(args.matrix_lr, batch_multiplier, args.lr_scale_method)
+    scaled_scalar_lr = scale_from_batch_multiplier(args.scalar_lr, batch_multiplier, args.lr_scale_method)
+    optimizer_tok = torch.optim.Adam(
+        [{"params": [base_model.tok_emb.weight], "lr": scaled_token_lr, "base_lr": scaled_token_lr}],
+        betas=(args.beta1, args.beta2),
+        eps=args.adam_eps,
+        fused=True,
+    )
+    optimizer_muon = Muon(
+        matrix_params,
+        lr=scaled_matrix_lr,
+        momentum=args.muon_momentum,
+        backend_steps=args.muon_backend_steps,
+    )
+    for group in optimizer_muon.param_groups:
+        group["base_lr"] = scaled_matrix_lr
+    optimizer_scalar = torch.optim.Adam(
+        [{"params": scalar_params, "lr": scaled_scalar_lr, "base_lr": scaled_scalar_lr}],
+        betas=(args.beta1, args.beta2),
+        eps=args.adam_eps,
+        fused=True,
+    )
+    optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar]
+    if base_model.lm_head is not None:
+        optimizer_head = torch.optim.Adam(
+            [{"params": [base_model.lm_head.weight], "lr": scaled_head_lr, "base_lr": scaled_head_lr}],
+            betas=(args.beta1, args.beta2),
+            eps=args.adam_eps,
+            fused=True,
+        )
+        optimizers.insert(1, optimizer_head)
+
+    n_params = sum(p.numel() for p in base_model.parameters())
+    log0(f"model_params:{n_params}")
+    log0(
+        f"world_size:{world_size} grad_accum_steps:{grad_accum_steps} "
+        f"train_microbatch_tokens:{train_microbatch_tokens} effective_train_batch_tokens:{effective_train_batch_tokens}"
+    )
+    if simulated_world_size > 0:
+        log0(f"simulated_world_size:{simulated_world_size} simulated_rank:{rank}")
+    log0(f"enable_compile:{args.enable_compile}")
+    log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False")
+    log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}")
+    log0(
+        f"quant_scheme:{args.quant_scheme} compressor:{args.compressor} level:{args.compressor_level} "
+        f"structured_mlp:{args.structured_mlp} structured_linear_type:{args.structured_linear_type} "
+        f"btt_rank:{args.btt_rank} btt_init:{args.btt_init} compile_structured_mlp:{args.compile_structured_mlp}"
+    )
+    log0(
+        f"rate_lambda:{args.rate_lambda} scale_lambda:{args.scale_lambda} "
+        f"rate_warmup_start_fraction:{args.rate_warmup_start_fraction} rate_full_fraction:{args.rate_full_fraction} "
+        f"rate_target_tensors:{args.rate_target_tensors}"
+    )
+    log0(
+        f"tie_embeddings:{args.tie_embeddings} embed_lr:{scaled_token_lr} "
+        f"head_lr:{scaled_head_lr if base_model.lm_head is not None else 0.0} "
+        f"matrix_lr:{scaled_matrix_lr} scalar_lr:{scaled_scalar_lr}"
+    )
+    log0(
+        f"train_batch_tokens_target:{args.train_batch_tokens} effective_train_batch_tokens:{effective_train_batch_tokens} "
+        f"train_seq_len:{args.train_seq_len} iterations:{args.iterations} warmup_steps:{effective_warmup_steps} "
+        f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}"
+    )
+    log0(
+        f"lr_scale_method:{args.lr_scale_method} lr_reference_batch_tokens:{args.lr_reference_batch_tokens} "
+        f"batch_multiplier:{batch_multiplier:.4f} warmup_scale_method:{args.warmup_scale_method} "
+        f"warmup_reference_batch_tokens:{args.warmup_reference_batch_tokens}"
+    )
+    log0(f"seed:{args.seed}")
+
+    if args.debug_data_sharding_steps > 0:
+        debug_log_data_sharding(
+            args.train_files,
+            world_size,
+            effective_train_batch_tokens,
+            args.train_seq_len,
+            grad_accum_steps,
+            args.debug_data_sharding_steps,
+            log0,
+        )
+    if args.dry_run_init_only:
+        log0("dry_run_init_only: exiting after init/logging")
+        return
+
+    # -----------------------------
+    # DATA LOADER & MODEL WARMUP
+    # -----------------------------
+
+    train_loader = 
DistributedTokenLoader(args.train_files, rank, world_size, device) + + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + + def lr_mul(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + + def rate_schedule(step: int) -> float: + if args.rate_lambda <= 0.0 and args.scale_lambda <= 0.0: + return 0.0 + if args.iterations <= 0: + return 1.0 + frac = step / max(args.iterations, 1) + start = args.rate_warmup_start_fraction + full = max(args.rate_full_fraction, start + 1e-6) + if frac <= start: + return 0.0 + if frac >= full: + return 1.0 + return (frac - start) / (full - start) + + # Warmup primes the compiled forward/backward/optimizer paths, then we restore the + # initial weights/optimizer state so measured training starts from the true init. 
+ if effective_warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(effective_warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(effective_train_batch_tokens, args.train_seq_len, grad_accum_steps) + maybe_mark_compile_step_begin(args.enable_compile or args.compile_structured_mlp) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + for opt in optimizers: + opt.step() + zero_grad_all() + if ( + effective_warmup_steps <= 20 + or (warmup_step + 1) % 10 == 0 + or warmup_step + 1 == effective_warmup_steps + ): + log0(f"warmup_step:{warmup_step + 1}/{effective_warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + if distributed: + model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + # ----------------------------- + # MAIN TRAINING LOOP + # ----------------------------- + + training_time_ms = 0.0 + stop_after_step: int | None = None + torch.cuda.synchronize() + t0 = time.perf_counter() + + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + 
grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + log0( + f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" + ) + break + + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + zero_grad_all() + train_loss = torch.zeros((), device=device) + train_ce_loss = torch.zeros((), device=device) + train_rate_proxy = torch.zeros((), device=device) + train_scale_proxy = torch.zeros((), device=device) + train_est_bits = torch.zeros((), device=device) + rate_mul = rate_schedule(step) + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(effective_train_batch_tokens, args.train_seq_len, grad_accum_steps) + maybe_mark_compile_step_begin(args.enable_compile or args.compile_structured_mlp) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + ce_loss = model(x, y) + total_loss = ce_loss + if rate_mul > 0.0: + rate_proxy, scale_proxy, est_bits = base_model.rate_regularizer() + total_loss = total_loss + rate_mul * (args.rate_lambda * rate_proxy + args.scale_lambda * scale_proxy) + train_rate_proxy += rate_proxy.detach() + train_scale_proxy += scale_proxy.detach() + train_est_bits += est_bits.detach() + train_ce_loss += ce_loss.detach() + train_loss += total_loss.detach() + (total_loss * grad_scale).backward() + train_loss /= grad_accum_steps + train_ce_loss /= grad_accum_steps + train_rate_proxy /= grad_accum_steps + train_scale_proxy /= grad_accum_steps + train_est_bits /= 
grad_accum_steps + + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + for opt in optimizers: + opt.step() + zero_grad_all() + + step += 1 + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + should_log_train = ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + log0( + f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} ce_loss:{train_ce_loss.item():.4f} " + f"rate_mul:{rate_mul:.3f} rate_proxy_bits:{train_rate_proxy.item():.4f} " + f"scale_proxy:{train_scale_proxy.item():.6f} est_model_bits:{train_est_bits.item():.0f} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" + ) + + # Needed to sync whether we've reached the wallclock cap. 
+ reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms
+ if distributed and max_wallclock_ms is not None:
+ reached_cap_tensor = torch.tensor(int(reached_cap), device=device)
+ dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX)
+ reached_cap = bool(reached_cap_tensor.item())
+ if stop_after_step is None and reached_cap:
+ stop_after_step = step
+
+ log0(
+ f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB "
+ f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB"
+ )
+
+ # -----------------------------
+ # SERIALIZATION + ROUNDTRIP VALIDATION
+ # -----------------------------
+ # Save the raw state (useful for debugging/loading in PyTorch directly), then always produce
+ # the compressed quantized artifact (int5_odd+zlib by default) and validate the round-tripped weights.
+
+ if master_process:
+ torch.save(base_model.state_dict(), "final_model.pt")
+ model_bytes = os.path.getsize("final_model.pt")
+ code_bytes = len(code.encode("utf-8"))
+ log0(f"Serialized model: {model_bytes} bytes")
+ log0(f"Code size: {code_bytes} bytes")
+ log0(f"Total submission size: {model_bytes + code_bytes} bytes")
+
+ quant_obj, quant_stats = quantize_state_dict(base_model.state_dict(), args.quant_scheme)
+ quant_buf = io.BytesIO()
+ torch.save(quant_obj, quant_buf)
+ quant_raw = quant_buf.getvalue()
+ quant_blob = compress_blob(quant_raw, args.compressor, args.compressor_level)
+ quant_raw_bytes = len(quant_raw)
+ quant_filename = f"final_model.{args.quant_scheme}.{args.compressor}.ptc"
+ if master_process:
+ with open(quant_filename, "wb") as f:
+ f.write(quant_blob)
+ quant_file_bytes = os.path.getsize(quant_filename)
+ code_bytes = len(code.encode("utf-8"))
+ ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1)
+ log0(
+ f"Serialized model {args.quant_scheme}+{args.compressor}: {quant_file_bytes} bytes "
+ f"(payload:{quant_stats['int8_payload_bytes']} raw_torch:{quant_raw_bytes} 
payload_ratio:{ratio:.2f}x)" + ) + log0(f"Total submission size {args.quant_scheme}+{args.compressor}: {quant_file_bytes + code_bytes} bytes") + + if distributed: + dist.barrier() + with open(quant_filename, "rb") as f: + quant_blob_disk = f.read() + quant_state = torch.load(io.BytesIO(decompress_blob(quant_blob_disk, args.compressor)), map_location="cpu") + base_model.load_state_dict(dequantize_state_dict(quant_state), strict=True) + torch.cuda.synchronize() + t_qeval = time.perf_counter() + q_val_loss, q_val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"final_{args.quant_scheme}_{args.compressor}_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" + ) + log0(f"final_{args.quant_scheme}_{args.compressor}_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + + if distributed: + dist.destroy_process_group() + + +if __name__ == "__main__": + main() + +==================================================================================================== +Running Python 3.10.12 (main, Mar 3 2026, 11:56:32) [GCC 11.4.0] +Running PyTorch 2.11.0+cu130 +Tue Mar 31 13:26:03 2026 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 590.48.01 Driver Version: 590.48.01 CUDA Version: 13.1 | ++-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. 
| +|=========================================+========================+======================| +| 0 NVIDIA GeForce RTX 3060 Off | 00000000:01:00.0 On | N/A | +| 30% 40C P8 21W / 170W | 252MiB / 12288MiB | 33% Default | +| | | N/A | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| 0 N/A N/A 985 G /usr/lib/xorg/Xorg 127MiB | +| 0 N/A N/A 6347 C python3 104MiB | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=/home/hz/Desktop/parameter-golf-main/data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:1 +val_loader:shards pattern=/home/hz/Desktop/parameter-golf-main/data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:2048 +val_token_limit:2048 +model_params:8507464 +world_size:1 grad_accum_steps:64 train_microbatch_tokens:4096 effective_train_batch_tokens:262144 +enable_compile:False +sdp_backends:cudnn=False flash=True mem_efficient=False math=False +attention_mode:gqa num_heads:8 num_kv_heads:4 +quant_scheme:int5_odd compressor:zlib level:9 structured_mlp:True structured_linear_type:btt_mlp btt_rank:32 btt_init:mup compile_structured_mlp:True +rate_lambda:2e-05 scale_lambda:0.0002 rate_warmup_start_fraction:0.1 rate_full_fraction:0.4 rate_target_tensors:mlp_only +tie_embeddings:True embed_lr:0.03 head_lr:0.0 matrix_lr:0.02 scalar_lr:0.02 +train_batch_tokens_target:262144 effective_train_batch_tokens:262144 train_seq_len:256 iterations:1 warmup_steps:0 max_wallclock_seconds:600.000 +lr_scale_method:none 
lr_reference_batch_tokens:16384 batch_multiplier:16.0000 warmup_scale_method:none warmup_reference_batch_tokens:16384 +seed:1337 +step:1/1 train_loss:6.9353 ce_loss:6.9353 rate_mul:0.000 rate_proxy_bits:0.0000 scale_proxy:0.000000 est_model_bits:0 train_time:27580ms step_avg:27580.10ms +step:1/1 val_loss:12.2633 val_bpb:6.6411 train_time:27582ms step_avg:27581.51ms +peak memory allocated: 4617 MiB reserved: 4876 MiB +Serialized model: 33022259 bytes +Code size: 75577 bytes +Total submission size: 33097836 bytes +Serialized model int5_odd+zlib: 2135631 bytes (payload:2942927 raw_torch:2983011 payload_ratio:11.21x) +Total submission size int5_odd+zlib: 2211208 bytes +final_int5_odd_zlib_roundtrip val_loss:11.7690 val_bpb:6.3734 eval_time:86ms +final_int5_odd_zlib_roundtrip_exact val_loss:11.76902580 val_bpb:6.37339220 diff --git a/records/track_non_record_16mb/2026-03-31_EntropyAware_Int5Odd_BTTMLP_3060/logs/h100_real_r256_l16_seq1024_mb2048_bmm_auto.txt b/records/track_non_record_16mb/2026-03-31_EntropyAware_Int5Odd_BTTMLP_3060/logs/h100_real_r256_l16_seq1024_mb2048_bmm_auto.txt new file mode 100644 index 000000000..1ee1731e7 --- /dev/null +++ b/records/track_non_record_16mb/2026-03-31_EntropyAware_Int5Odd_BTTMLP_3060/logs/h100_real_r256_l16_seq1024_mb2048_bmm_auto.txt @@ -0,0 +1,1831 @@ +logs/h100_real_r256_l16_seq1024_mb2048_bmm_auto.txt +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=/workspace/parameter-golf/data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:80 +val_loader:shards pattern=/workspace/parameter-golf/data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:1048576 +val_token_limit:1048576 +model_params:25727104 +world_size:1 grad_accum_steps:8 train_microbatch_tokens:2048 effective_train_batch_tokens:16384 +enable_compile:False +sdp_backends:cudnn=False flash=True mem_efficient=False math=False +attention_mode:gqa num_heads:8 num_kv_heads:4 +quant_scheme:int5_odd compressor:zlib level:9 
structured_mlp:True structured_linear_type:btt_mlp btt_rank:256 btt_init:mup compile_structured_mlp:False structured_chunk_ranks:64 structured_compile_mode:reduce-overhead +rate_lambda:2e-05 scale_lambda:0.0002 rate_warmup_start_fraction:0.1 rate_full_fraction:0.4 rate_target_tensors:mlp_only rate_schedule_mode:auto max_expected_steps:0 +tie_embeddings:True embed_lr:0.03 head_lr:0.0 matrix_lr:0.02 scalar_lr:0.02 +train_batch_tokens_target:16384 effective_train_batch_tokens:16384 train_seq_len:1024 iterations:20000 warmup_steps:1 max_wallclock_seconds:600.000 +lr_scale_method:none lr_reference_batch_tokens:16384 batch_multiplier:1.0000 warmup_scale_method:none warmup_reference_batch_tokens:16384 +seed:1337 +warmup_step:1/1 +step:1/20000 train_loss:6.9404 ce_loss:6.9404 rate_mul:0.000 rate_proxy_bits:0.0000 scale_proxy:0.000000 est_model_bits:0 train_time:19008ms step_avg:19008.35ms +step:2/20000 train_loss:12.7072 ce_loss:12.7072 rate_mul:0.000 rate_proxy_bits:0.0000 scale_proxy:0.000000 est_model_bits:0 train_time:38408ms step_avg:19203.92ms +step:3/20000 train_loss:12.5926 ce_loss:12.5926 rate_mul:0.000 rate_proxy_bits:0.0000 scale_proxy:0.000000 est_model_bits:0 train_time:58075ms step_avg:19358.32ms +step:4/20000 train_loss:12.3761 ce_loss:12.3761 rate_mul:0.000 rate_proxy_bits:0.0000 scale_proxy:0.000000 est_model_bits:0 train_time:77664ms step_avg:19415.98ms +step:5/20000 train_loss:12.2059 ce_loss:12.2059 rate_mul:0.111 rate_proxy_bits:1.4982 scale_proxy:0.059731 est_model_bits:18851774 train_time:97700ms step_avg:19540.07ms +step:6/20000 train_loss:11.9154 ce_loss:11.9154 rate_mul:0.222 rate_proxy_bits:1.4982 scale_proxy:0.059731 est_model_bits:18851774 train_time:117914ms step_avg:19652.39ms +step:7/20000 train_loss:11.7523 ce_loss:11.7523 rate_mul:0.333 rate_proxy_bits:1.4982 scale_proxy:0.059731 est_model_bits:18851774 train_time:138027ms step_avg:19718.08ms +step:8/20000 train_loss:11.5979 ce_loss:11.5979 rate_mul:0.444 rate_proxy_bits:1.4982 
scale_proxy:0.059731 est_model_bits:18851774 train_time:158475ms step_avg:19809.35ms +step:9/20000 train_loss:11.1889 ce_loss:11.1889 rate_mul:0.556 rate_proxy_bits:1.4982 scale_proxy:0.059731 est_model_bits:18851774 train_time:178761ms step_avg:19862.31ms +step:10/20000 train_loss:11.1120 ce_loss:11.1120 rate_mul:0.667 rate_proxy_bits:1.4982 scale_proxy:0.059731 est_model_bits:18851774 train_time:198797ms step_avg:19879.72ms +step:20/20000 train_loss:9.5722 ce_loss:9.5722 rate_mul:1.000 rate_proxy_bits:1.4982 scale_proxy:0.059731 est_model_bits:18851774 train_time:398781ms step_avg:19939.03ms +step:30/20000 train_loss:9.0810 ce_loss:9.0809 rate_mul:1.000 rate_proxy_bits:1.4982 scale_proxy:0.059731 est_model_bits:18851774 train_time:598799ms step_avg:19959.97ms +step:31/20000 val_loss:9.0928 val_bpb:5.4479 train_time:618760ms step_avg:19960.00ms +stopping_early: wallclock_cap train_time:618760ms step:31/20000 +peak memory allocated: 50388 MiB reserved: 53230 MiB +Serialized model: 101929819 bytes +Code size: 78465 bytes +Total submission size: 102008284 bytes +Serialized model int5_odd+zlib: 5184839 bytes (payload:8780523 raw_torch:8851579 payload_ratio:11.60x) +Total submission size int5_odd+zlib: 5263304 bytes +final_int5_odd_zlib_roundtrip val_loss:9.8767 val_bpb:5.9176 eval_time:103984ms +final_int5_odd_zlib_roundtrip_exact val_loss:9.87668240 val_bpb:5.91758197 +0.95)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.0)) + lr_scale_method = os.environ.get("LR_SCALE_METHOD", "none") + lr_reference_batch_tokens = int(os.environ.get("LR_REFERENCE_BATCH_TOKENS", 16_384)) 
+ warmup_scale_method = os.environ.get("WARMUP_SCALE_METHOD", "none") + warmup_reference_batch_tokens = int(os.environ.get("WARMUP_REFERENCE_BATCH_TOKENS", 16_384)) + simulated_world_size = int(os.environ.get("SIMULATED_WORLD_SIZE", 0)) + simulated_rank = int(os.environ.get("SIMULATED_RANK", 0)) + dry_run_init_only = bool(int(os.environ.get("DRY_RUN_INIT_ONLY", "0"))) + debug_data_sharding_steps = int(os.environ.get("DEBUG_DATA_SHARDING_STEPS", 0)) + + # Export / compression. + quant_scheme = os.environ.get("QUANT_SCHEME", "int5_odd") + compressor = os.environ.get("COMPRESSOR", "zlib") + compressor_level = int(os.environ.get("COMPRESSOR_LEVEL", 9)) + + # Entropy-aware rate regularization. + rate_lambda = float(os.environ.get("RATE_LAMBDA", 0.00002)) + scale_lambda = float(os.environ.get("SCALE_LAMBDA", 0.0002)) + rate_temperature = float(os.environ.get("RATE_TEMPERATURE", 6.0)) + rate_warmup_start_fraction = float(os.environ.get("RATE_WARMUP_START_FRACTION", 0.1)) + rate_full_fraction = float(os.environ.get("RATE_FULL_FRACTION", 0.4)) + rate_target_tensors = os.environ.get("RATE_TARGET_TENSORS", "mlp_only") + rate_schedule_mode = os.environ.get("RATE_SCHEDULE_MODE", "auto") + max_expected_steps = int(os.environ.get("MAX_EXPECTED_STEPS", 0)) + + # Structured MLP. 
+ structured_mlp = bool(int(os.environ.get("STRUCTURED_MLP", "1"))) + structured_linear_type = os.environ.get("STRUCTURED_LINEAR_TYPE", "btt_mlp") + btt_rank = int(os.environ.get("BTT_RANK", 32)) + btt_in_shape = os.environ.get("BTT_IN_SHAPE", "") + btt_out_shape = os.environ.get("BTT_OUT_SHAPE", "") + btt_init = os.environ.get("BTT_INIT", "mup") + btt_init_gain = float(os.environ.get("BTT_INIT_GAIN", 1.0)) + compile_structured_mlp = bool(int(os.environ.get("COMPILE_STRUCTURED_MLP", "1"))) + structured_chunk_ranks = int(os.environ.get("STRUCTURED_CHUNK_RANKS", 64)) + structured_compile_mode = os.environ.get("STRUCTURED_COMPILE_MODE", "reduce-overhead") + +# ----------------------------- +# MUON OPTIMIZER +# ----------------------------- +# +# As borrowed from modded-nanogpt +# Background on Muon: https://kellerjordan.github.io/posts/muon/ + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + # Orthogonalize a 2D update matrix with a fast Newton-Schulz iteration. + # Muon uses this to normalize matrix-shaped gradients before applying them. 
+ a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X + + +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True): + super().__init__( + params, + dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov), + ) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + + for group in self.param_groups: + params = group["params"] + if not params: + continue + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + + total_params = sum(int(p.numel()) for p in params) + updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) + + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: + g = g.add(buf, alpha=momentum) + g = zeropower_via_newtonschulz5(g, steps=backend_steps) + # Scale correction from Muon reference implementations. 
+ g *= max(1, g.size(0) / g.size(1)) ** 0.5
+ updates_flat[curr : curr + p.numel()] = g.reshape(-1)
+ curr += p.numel()
+
+ if distributed:
+ dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM)
+
+ curr = 0
+ for p in params:
+ g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype)
+ p.add_(g, alpha=-lr)
+ curr += p.numel()
+
+ return loss
+
+
+# -----------------------------
+# TOKENIZER-AGNOSTIC EVALUATION SETUP
+# -----------------------------
+#
+# It's common for small models to have a large fraction of their parameters in embeddings, since the 2 * d_model * d_vocab embedding/unembedding vectors can be gigantic.
+# Instead of locking the tokenizer, we let you bring your own and compute our validation metric from the average compression of the validation set.
+# We report BPB (bits-per-byte) instead of validation loss, so we need lookup tables that count how many bytes each token decodes to under the tokenizer.
+# Note: Submissions that edit the tokenizer will be examined more carefully, since screwing this up might unjustly improve your score.
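The nats-to-BPB conversion described above can be sketched as a standalone helper (illustrative only; `bits_per_byte` is not a function in `train_gpt.py` — the script computes the same quantity inline in `eval_val`):

```python
import math

def bits_per_byte(mean_nats_per_token: float, token_count: int, byte_count: int) -> float:
    # nats -> bits, then (bits / token) * (tokens / byte) = bits / byte.
    bits_per_token = mean_nats_per_token / math.log(2.0)
    return bits_per_token * (token_count / byte_count)

# A loss of ln(2) nats/token is exactly 1 bit/token; at 4 bytes/token on
# average, that works out to 0.25 bits/byte.
print(bits_per_byte(math.log(2.0), 1000, 4000))  # -> 0.25
```

Because the byte count comes from the tokenizer lookup tables rather than the token count, two submissions with different tokenizers remain comparable on this metric.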
+ +def build_sentencepiece_luts( + sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device +) -> tuple[Tensor, Tensor, Tensor]: + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("▁"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) + + +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*. 
+ tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] + + +def eval_val( + args: Hyperparameters, + model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + grad_accum_steps: int, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, +) -> tuple[float, float]: + # Validation computes two metrics: + # - val_loss: token cross-entropy (natural log) + # - val_bpb: tokenizer-agnostic compression metric used by the challenge + local_batch_tokens = args.val_batch_size // world_size + if local_batch_tokens < args.train_seq_len: + raise ValueError( + "VAL_BATCH_SIZE must provide at least one sequence per rank; " + f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " + f"TRAIN_SEQ_LEN={args.train_seq_len}" + ) + local_batch_seqs = local_batch_tokens // args.train_seq_len + total_seqs = (val_tokens.numel() - 1) // args.train_seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + + model.eval() + with torch.inference_mode(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * args.train_seq_len + raw_end = batch_seq_end * args.train_seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, args.train_seq_len) + y = local[1:].reshape(-1, args.train_seq_len) + maybe_mark_compile_step_begin(args.enable_compile or 
args.compile_structured_mlp)
+ with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True):
+ batch_loss = model(x, y).detach()
+ batch_token_count = float(y.numel())
+ val_loss_sum += batch_loss.to(torch.float64) * batch_token_count
+ val_token_count += batch_token_count
+ prev_ids = x.reshape(-1)
+ tgt_ids = y.reshape(-1)
+ token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16)
+ token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16)
+ val_byte_count += token_bytes.to(torch.float64).sum()
+
+ if dist.is_available() and dist.is_initialized():
+ dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM)
+ dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM)
+ dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM)
+
+ val_loss = val_loss_sum / val_token_count
+ bits_per_token = val_loss.item() / math.log(2.0)
+ tokens_per_byte = val_token_count.item() / val_byte_count.item()
+ model.train()
+ return float(val_loss.item()), float(bits_per_token * tokens_per_byte)
+
+# -----------------------------
+# POST-TRAINING QUANTIZATION
+# -----------------------------
+#
+# It's silly to export our model, which is trained in bf16 and fp32, at that same precision.
+# Instead, we get approximately the same model (with a small quality hit) by quantizing the weights to a low-bit integer grid (int5_odd by default) and zlib-compressing the payload.
+# We can then decompress the model and run it at higher precision for evaluation, after coming in under the size limit.
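As a minimal sketch of the idea behind the int5_odd export — per-row scaling onto the 5-level grid `{-2,-1,0,1,2}`, then zlib over the integer payload — here is a toy NumPy version. The function name, the per-row quantile clipping, and the single float32 scale per row are illustrative assumptions; this is not the repo's `quantize_state_dict`/`compress_blob` implementation:

```python
import zlib
import numpy as np

def int5_odd_quantize(w: np.ndarray, clip_q: float = 0.9999) -> tuple[np.ndarray, np.ndarray]:
    # Per-row scale chosen so almost all weights land in [-2, 2] after
    # division, then snap to the nearest level of {-2, -1, 0, 1, 2}.
    scale = np.quantile(np.abs(w), clip_q, axis=1, keepdims=True) / 2.0
    q = np.clip(np.round(w / scale), -2, 2).astype(np.int8)
    return q, scale

rng = np.random.default_rng(0)
w = rng.normal(scale=0.05, size=(128, 128)).astype(np.float32)
q, scale = int5_odd_quantize(w)
w_hat = q * scale                     # dequantized weights used at eval time
blob = zlib.compress(q.tobytes(), 9)  # 5-symbol payload compresses well
print(q.nbytes, len(blob))
```

Because the payload only ever contains five distinct byte values, zlib shrinks it well below the raw int8 size, which is why the compressed artifact in the logs above is much smaller than the reported payload bytes.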
+ +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights", + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 +INT5_ODD_LEVELS = (-2, -1, 0, 1, 2) + +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) + + +def maybe_mark_compile_step_begin(enabled: bool) -> None: + if not enabled: + return + mark_step = getattr(torch.compiler, "cudagraph_mark_step_begin", None) + if mark_step is not None: + mark_step() + + +def scale_from_batch_multiplier(value: float, multiplier: float, method: str) -> float: + if method == "none": + return value + if method == "sqrt": + return value * math.sqrt(multiplier) + if method == "linear": + return value * multiplier + raise ValueError(f"Unsupported scaling method: {method}") + + +def scaled_warmup_steps(base_steps: int, multiplier: float, method: str) -> int: + if base_steps <= 0: + return 0 + scaled = scale_from_batch_multiplier(float(base_steps), multiplier, method) + return max(int(math.ceil(scaled)), 1) + + +def parse_factor_shape(spec: str, dim: int) -> tuple[int, int]: + if spec: + vals = tuple(int(v.strip()) for v in spec.split(",") if v.strip()) + if len(vals) != 2 or vals[0] * vals[1] != dim: + raise ValueError(f"Invalid factor shape '{spec}' for dim={dim}") + return vals + for a in range(int(math.isqrt(dim)), 0, -1): + if dim % a == 0: + return a, dim // a + return 1, dim + +def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> 
Tensor: + if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): + return t.float().contiguous() + if t.dtype in {torch.float32, torch.bfloat16}: + passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") + return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() + return t + +def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + # Matrices get one scale per row, which usually tracks output-channel + # ranges much better than a single tensor-wide scale. + clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) + q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() + return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + + # Vectors / scalars use a simpler per-tensor scale. 
+ clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() + return q, scale + + +def pack_base5_codes(q: Tensor) -> tuple[bytes, int]: + vals = q.detach().to("cpu", dtype=torch.uint8).reshape(-1).numpy() + n_vals = int(vals.size) + pad = (3 - (n_vals % 3)) % 3 + if pad: + vals = np.pad(vals, (0, pad)) + groups = vals.reshape(-1, 3).astype(np.uint8) + packed = groups[:, 0] + 5 * groups[:, 1] + 25 * groups[:, 2] + return packed.tobytes(), n_vals + + +def unpack_base5_codes(data: bytes, n_vals: int) -> Tensor: + packed = np.frombuffer(data, dtype=np.uint8).astype(np.int16) + out = np.zeros((packed.size, 3), dtype=np.uint8) + for i in range(3): + out[:, i] = packed % 5 + packed //= 5 + flat = out.reshape(-1)[:n_vals] + return torch.from_numpy(flat.astype(np.uint8, copy=False)) + + +def quantize_float_tensor_int5_odd(t: Tensor) -> tuple[Tensor, Tensor]: + if t.ndim != 2: + raise ValueError("int5_odd quantization only supports 2D tensors") + t32 = t.float() + clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clip_abs = clip_abs.clamp_min(1e-6) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 2.0).clamp_min(1.0 / 2.0e6) + q = torch.clamp(torch.round(clipped / scale[:, None]), -2, 2).to(torch.int8).contiguous() + return (q + 2).to(torch.uint8).contiguous(), scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + +def quantize_state_dict_int8(state_dict: dict[str, Tensor]): + # Single supported clean-script export format: + # - per-row int8 for 2D float tensors + # - per-tensor int8 for other float tensors + # - exact passthrough for non-floats + # - passthrough for small float tensors, 
stored as fp16 to save bytes + quantized: dict[str, Tensor] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + qmeta: dict[str, dict[str, object]] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + + # Small float tensors are cheap enough to keep directly. We still downcast + # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size. + if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + + stats["num_float_tensors"] += 1 + q, s = quantize_float_tensor(t) + if s.ndim > 0: + qmeta[name] = {"scheme": "per_row", "axis": 0} + quantized[name] = q + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) + + obj: dict[str, object] = { + "__quant_format__": "int8_clean_per_row_v1", + "quantized": quantized, + "scales": scales, + "dtypes": dtypes, + "passthrough": passthrough, + } + if qmeta: + obj["qmeta"] = qmeta + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats + + +def quantize_state_dict_int5_odd(state_dict: dict[str, Tensor]): + packed: dict[str, bytes] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + shapes: dict[str, tuple[int, ...]] = {} + orig_shapes: dict[str, 
tuple[int, ...]] = {} + n_codes: dict[str, int] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + + quantize_this = t.ndim >= 2 and (t.numel() > INT8_KEEP_FLOAT_MAX_NUMEL or ".mlp." in name) + if not quantize_this: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + + stats["num_float_tensors"] += 1 + t2 = t.reshape(t.shape[0], -1) if t.ndim > 2 else t + q, s = quantize_float_tensor_int5_odd(t2) + packed_bytes, count = pack_base5_codes(q) + packed[name] = packed_bytes + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + shapes[name] = tuple(int(v) for v in t2.shape) + orig_shapes[name] = tuple(int(v) for v in t.shape) + n_codes[name] = count + stats["int8_payload_bytes"] += len(packed_bytes) + tensor_nbytes(s) + + obj: dict[str, object] = { + "__quant_format__": "int5_odd_clean_per_row_v1", + "packed": packed, + "scales": scales, + "dtypes": dtypes, + "shapes": shapes, + "orig_shapes": orig_shapes, + "n_codes": n_codes, + "passthrough": passthrough, + } + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats + +def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + qmeta = obj.get("qmeta", {}) + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, q in 
obj["quantized"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + s = obj["scales"][name] + if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: + s = s.to(dtype=torch.float32) + # Broadcast the saved row scale back across trailing dimensions. + out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() + else: + scale = float(s.item()) + out[name] = (q.float() * scale).to(dtype=dtype).contiguous() + for name, t in obj["passthrough"].items(): + # Restore small tensors, undoing the temporary fp16 storage cast if needed. + out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out + + +def dequantize_state_dict_int5_odd(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, packed in obj["packed"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + shape = tuple(int(v) for v in obj["shapes"][name]) + orig_shape = tuple(int(v) for v in obj.get("orig_shapes", {}).get(name, shape)) + q = unpack_base5_codes(packed, int(obj["n_codes"][name])).to(torch.float32).reshape(shape) - 2.0 + s = obj["scales"][name].to(dtype=torch.float32) + deq = (q * s.view(shape[0], *([1] * (len(shape) - 1)))).to(dtype=dtype) + out[name] = deq.reshape(orig_shape).contiguous() + for name, t in obj["passthrough"].items(): + out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out + + +def quantize_state_dict(state_dict: dict[str, Tensor], quant_scheme: str): + if quant_scheme == "int8": + return quantize_state_dict_int8(state_dict) + if quant_scheme == "int5_odd": + return 
quantize_state_dict_int5_odd(state_dict) + raise ValueError(f"Unsupported QUANT_SCHEME={quant_scheme}") + + +def dequantize_state_dict(obj: dict[str, object]) -> dict[str, Tensor]: + quant_format = obj.get("__quant_format__") + if quant_format == "int8_clean_per_row_v1": + return dequantize_state_dict_int8(obj) + if quant_format == "int5_odd_clean_per_row_v1": + return dequantize_state_dict_int5_odd(obj) + raise ValueError(f"Unsupported quant format {quant_format}") + + +def compress_blob(data: bytes, compressor: str, level: int) -> bytes: + if compressor == "zlib": + return zlib.compress(data, level=max(0, min(level, 9))) + if compressor == "zstd": + if not _HAS_ZSTD: + raise RuntimeError("COMPRESSOR=zstd requires zstandard to be installed") + return zstd.ZstdCompressor(level=level).compress(data) + raise ValueError(f"Unsupported COMPRESSOR={compressor}") + + +def decompress_blob(data: bytes, compressor: str) -> bytes: + if compressor == "zlib": + return zlib.decompress(data) + if compressor == "zstd": + if not _HAS_ZSTD: + raise RuntimeError("COMPRESSOR=zstd requires zstandard to be installed") + return zstd.ZstdDecompressor().decompress(data) + raise ValueError(f"Unsupported COMPRESSOR={compressor}") + + +def should_rate_regularize_tensor(name: str, p: Tensor, target: str) -> bool: + if not p.requires_grad or not p.is_floating_point() or p.ndim < 2: + return False + if any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS): + return False + if target == "mlp_only": + return ".mlp." in name + if target == "all_matrices": + return p.numel() > INT8_KEEP_FLOAT_MAX_NUMEL or ".mlp." 
in name
+    raise ValueError(f"Unsupported RATE_TARGET_TENSORS={target}")
+
+
+def estimate_tensor_rate_proxy(t: Tensor, temperature: float) -> tuple[Tensor, Tensor, Tensor]:
+    t2 = t.float()
+    if t2.ndim > 2:
+        t2 = t2.reshape(t2.shape[0], -1)
+    elif t2.ndim == 1:
+        t2 = t2.reshape(1, -1)
+    row_scale = t2.abs().mean(dim=1, keepdim=True).clamp_min(1e-6)
+    normalized = torch.clamp(t2 / row_scale, -2.5, 2.5)
+    centers = t2.new_tensor(INT5_ODD_LEVELS)
+    logits = -temperature * (normalized.unsqueeze(-1) - centers) ** 2
+    probs = torch.softmax(logits, dim=-1)
+    hist = probs.mean(dim=(0, 1))
+    entropy_nats = -(hist * (hist + 1e-8).log()).sum()
+    entropy_bits = entropy_nats / math.log(2.0)
+    scale_proxy = row_scale.mean()
+    est_bits = entropy_bits * float(t2.numel())
+    return entropy_bits, scale_proxy, est_bits
+
+
+# -----------------------------
+# DATA LOADING
+# -----------------------------
+
+def load_data_shard(file: Path) -> Tensor:
+    # Shard layout (standard fineweb .bin convention, assumed here): a 256-entry
+    # little-endian int32 header whose third field is the token count, followed
+    # by uint16 token ids.
+    header_bytes = 256 * np.dtype("<i4").itemsize
+    with file.open("rb") as f:
+        header = np.frombuffer(f.read(header_bytes), dtype="<i4")
+        num_tokens = int(header[2])
+        tokens = np.frombuffer(f.read(num_tokens * np.dtype("<u2").itemsize), dtype="<u2")
+    if tokens.size != num_tokens:
+        raise ValueError(f"Shard {file} is truncated: expected {num_tokens} tokens, got {tokens.size}")
+    # Stored as int32 so the downstream .to(dtype=torch.int64) casts work everywhere.
+    return torch.from_numpy(tokens.astype(np.int32))
+
+
+class TokenStream:
+    # Endless stream over the training shards matched by `pattern`, in sorted order.
+    def __init__(self, pattern: str):
+        pattern_path = Path(pattern)
+        self.files = sorted(pattern_path.parent.glob(pattern_path.name))
+        if not self.files:
+            raise FileNotFoundError(f"No data shards match pattern: {pattern}")
+        self.file_idx = 0
+        self.tokens = load_data_shard(self.files[self.file_idx])
+        self.pos = 0
+
+    def _advance_file(self) -> None:
+        self.file_idx = (self.file_idx + 1) % len(self.files)
+        self.tokens = load_data_shard(self.files[self.file_idx])
+        self.pos = 0
+
+    def take(self, n: int) -> Tensor:
+        chunks: list[Tensor] = []
+        remaining = n
+        while remaining > 0:
+            avail = self.tokens.numel() - self.pos
+            if avail <= 0:
+                self._advance_file()
+                continue
+            k = min(remaining, avail)
+            chunks.append(self.tokens[self.pos : self.pos + k])
+            self.pos += k
+            remaining -= k
+        return chunks[0] if len(chunks) == 1 else torch.cat(chunks)
+
+
+class DistributedTokenLoader:
+    # Each call consumes a contiguous chunk from the shared token stream, then slices out
+    # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting.
+ def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank = rank + self.world_size = world_size + self.device = device + self.stream = TokenStream(pattern) + + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + chunk = self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start : start + per_rank_span].to(dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + + +def debug_log_data_sharding( + pattern: str, + world_size: int, + global_tokens: int, + seq_len: int, + grad_accum_steps: int, + steps: int, + log0, +) -> None: + stream = TokenStream(pattern) + local_tokens = global_tokens // (world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + global_cursor = 0 + log0( + f"debug_sharding:world_size:{world_size} global_tokens:{global_tokens} " + f"grad_accum_steps:{grad_accum_steps} local_tokens_per_rank:{local_tokens} per_rank_span:{per_rank_span}" + ) + for step_idx in range(steps): + chunk = stream.take(per_rank_span * world_size) + chunk_start = global_cursor + chunk_end = global_cursor + chunk.numel() + log0(f"debug_sharding_step:{step_idx} shared_chunk_token_range:[{chunk_start},{chunk_end})") + for shard_rank in range(world_size): + start = shard_rank * per_rank_span + end = start + per_rank_span + local = chunk[start:end] + preview = ",".join(str(int(t)) for t in local[: min(6, local.numel())].tolist()) + log0( + f"debug_shard rank:{shard_rank} token_range:[{chunk_start + start},{chunk_start + end}) " + f"preview:{preview}" + ) + global_cursor = chunk_end + +# ----------------------------- +# TRANSFORMER MODULES +# ----------------------------- + +class RMSNorm(nn.Module): + def 
__init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + + +class CastedLinear(nn.Linear): + # Keep weights in fp32 for optimizer/state quality, cast at matmul time for bf16 compute. + def forward(self, x: Tensor) -> Tensor: + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, self.weight.to(x.dtype), bias) + + +class StructuredLinear(nn.Module): + # Lightweight TT-style structured matrix used only in the MLP stack. + def __init__( + self, + in_features: int, + out_features: int, + bias: bool, + linear_type: str, + btt_rank: int, + in_shape_spec: str, + out_shape_spec: str, + btt_init: str, + btt_init_gain: float, + chunk_ranks: int, + ): + super().__init__() + self.in_features = in_features + self.out_features = out_features + self.linear_type = linear_type + self.rank = max(int(btt_rank), 1) + self.chunk_ranks = max(int(chunk_ranks), 1) + self.btt_init = btt_init + self.btt_init_gain = btt_init_gain + if linear_type == "dense": + self.impl = CastedLinear(in_features, out_features, bias=bias) + self.register_parameter("core1", None) + self.register_parameter("core2", None) + return + if linear_type != "btt_mlp": + raise ValueError(f"Unsupported STRUCTURED_LINEAR_TYPE={linear_type}") + self.impl = None + self.in_shape = parse_factor_shape(in_shape_spec, in_features) + self.out_shape = parse_factor_shape(out_shape_spec, out_features) + n1, n2 = self.in_shape + m1, m2 = self.out_shape + self.core1 = nn.Parameter(torch.empty(n1, m1, self.rank)) + self.core2 = nn.Parameter(torch.empty(self.rank, n2, m2)) + if bias: + self.bias = nn.Parameter(torch.zeros(out_features)) + else: + self.register_parameter("bias", None) + self.reset_parameters() + + def reset_parameters(self) -> None: + if self.impl is not None: + self.impl.reset_parameters() + return + n1, n2 = self.in_shape + if self.btt_init == "mup": + # muP-inspired 
scaling: keep the per-rank branch update stable while + # preserving O(1) output variance under the sum over TT ranks. + core1_std = self.btt_init_gain / math.sqrt(max(n1, 1)) + core2_std = self.btt_init_gain / math.sqrt(max(self.rank * n2, 1)) + elif self.btt_init == "xavier": + std = self.btt_init_gain / math.sqrt(max(self.in_features, 1)) + core1_std = std / math.sqrt(self.rank) + core2_std = std + else: + raise ValueError(f"Unsupported BTT_INIT={self.btt_init}") + nn.init.normal_(self.core1, mean=0.0, std=core1_std) + nn.init.normal_(self.core2, mean=0.0, std=core2_std) + if self.bias is not None: + nn.init.zeros_(self.bias) + + def forward(self, x: Tensor) -> Tensor: + if self.impl is not None: + return self.impl(x) + x_dtype = x.dtype + orig_shape = x.shape[:-1] + x2 = x.reshape(-1, self.in_shape[0], self.in_shape[1]) + x2_t = x2.transpose(1, 2).contiguous() + core1 = self.core1.to(dtype=x_dtype).permute(2, 0, 1).contiguous() + core2 = self.core2.to(dtype=x_dtype).contiguous() + y = x2.new_zeros((x2.shape[0], self.out_shape[0], self.out_shape[1])) + for start in range(0, self.rank, self.chunk_ranks): + end = min(start + self.chunk_ranks, self.rank) + for offset, core2_r in enumerate(core2[start:end], start=start): + left = torch.matmul(x2_t, core1[offset]) + y = y + torch.matmul(left.transpose(1, 2), core2_r) + y = y.reshape(*orig_shape, self.out_features) + if self.bias is not None: + y = y + self.bias.to(dtype=y.dtype) + return y + + +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + # Keep small/control parameters in fp32 even when the model body runs in bf16. + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() + + +class Rotary(nn.Module): + # Caches cos/sin tables per sequence length on the current device. 
+ def __init__(self, dim: int, base: float = 10000.0): + super().__init__() + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached != seq_len + or self._cos_cached.device != device + ): + t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) + freqs = torch.outer(t, self.inv_freq.to(device)) + self._cos_cached = freqs.cos()[None, None, :, :] + self._sin_cached = freqs.sin()[None, None, :, :] + self._seq_len_cached = seq_len + cos = self._cos_cached.to(dtype=dtype) + sin = self._sin_cached.to(dtype=dtype) + if not torch.is_inference_mode_enabled(): + cos = cos.clone() + sin = sin.clone() + return cos, sin + + +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + + +class CausalSelfAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + kv_dim = self.num_kv_heads * self.head_dim + self.c_q = CastedLinear(dim, dim, bias=False) + self.c_k = CastedLinear(dim, kv_dim, bias=False) + self.c_v = CastedLinear(dim, kv_dim, bias=False) + self.proj 
= CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rotary = Rotary(self.head_dim, base=rope_base) + + def forward(self, x: Tensor) -> Tensor: + bsz, seqlen, dim = x.shape + q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) + k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin) + k = apply_rotary_emb(k, cos, sin) + q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] + y = F.scaled_dot_product_attention( + q, + k, + v, + attn_mask=None, + is_causal=True, + enable_gqa=(self.num_kv_heads != self.num_heads), + ) + y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) + return self.proj(y) + + +class MLP(nn.Module): + # relu^2 MLP from the original modded-nanogpt setup + def __init__( + self, + dim: int, + mlp_mult: int, + structured_mlp: bool, + structured_linear_type: str, + btt_rank: int, + btt_in_shape: str, + btt_out_shape: str, + btt_init: str, + btt_init_gain: float, + compile_structured_mlp: bool, + structured_chunk_ranks: int, + structured_compile_mode: str, + ): + super().__init__() + hidden = mlp_mult * dim + if structured_mlp: + self.fc = StructuredLinear( + dim, + hidden, + bias=False, + linear_type=structured_linear_type, + btt_rank=btt_rank, + in_shape_spec=btt_in_shape, + out_shape_spec=btt_out_shape, + btt_init=btt_init, + btt_init_gain=btt_init_gain, + chunk_ranks=structured_chunk_ranks, + ) + self.proj = StructuredLinear( + hidden, + dim, + bias=False, + linear_type=structured_linear_type, + btt_rank=btt_rank, + in_shape_spec=btt_out_shape, + out_shape_spec=btt_in_shape, + btt_init=btt_init, + 
btt_init_gain=btt_init_gain, + chunk_ranks=structured_chunk_ranks, + ) + if compile_structured_mlp: + compile_kwargs = dict(fullgraph=False, dynamic=False) + if structured_compile_mode == "reduce-overhead": + compile_options = dict(torch._inductor.list_mode_options("reduce-overhead")) + compile_options["triton.cudagraphs"] = False + compile_kwargs["options"] = compile_options + else: + compile_kwargs["mode"] = structured_compile_mode + self.fc = torch.compile( + self.fc, + **compile_kwargs, + ) + self.proj = torch.compile( + self.proj, + **compile_kwargs, + ) + else: + self.fc = CastedLinear(dim, hidden, bias=False) + self.proj = CastedLinear(hidden, dim, bias=False) + self.proj._zero_init = True + + def forward(self, x: Tensor) -> Tensor: + x = torch.relu(self.fc(x)) + return self.proj(x.square()) + + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + rope_base: float, + qk_gain_init: float, + structured_mlp: bool, + structured_linear_type: str, + btt_rank: int, + btt_in_shape: str, + btt_out_shape: str, + btt_init: str, + btt_init_gain: float, + compile_structured_mlp: bool, + structured_chunk_ranks: int, + structured_compile_mode: str, + ): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) + self.mlp = MLP( + dim, + mlp_mult, + structured_mlp, + structured_linear_type, + btt_rank, + btt_in_shape, + btt_out_shape, + btt_init, + btt_init_gain, + compile_structured_mlp, + structured_chunk_ranks, + structured_compile_mode, + ) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + + def forward(self, x: Tensor, x0: Tensor) -> Tensor: + mix = self.resid_mix.to(dtype=x.dtype) + x = mix[0][None, None, :] * x + 
mix[1][None, None, :] * x0 + attn_out = self.attn(self.attn_norm(x)) + x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out + x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) + return x + + +class GPT(nn.Module): + def __init__( + self, + vocab_size: int, + num_layers: int, + model_dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + tie_embeddings: bool, + tied_embed_init_std: float, + logit_softcap: float, + rope_base: float, + qk_gain_init: float, + structured_mlp: bool, + structured_linear_type: str, + btt_rank: int, + btt_in_shape: str, + btt_out_shape: str, + btt_init: str, + btt_init_gain: float, + compile_structured_mlp: bool, + structured_chunk_ranks: int, + structured_compile_mode: str, + rate_target_tensors: str, + rate_temperature: float, + ): + super().__init__() + if logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") + self.tie_embeddings = tie_embeddings + self.tied_embed_init_std = tied_embed_init_std + self.logit_softcap = logit_softcap + self.rate_target_tensors = rate_target_tensors + self.rate_temperature = rate_temperature + self.tok_emb = nn.Embedding(vocab_size, model_dim) + self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + self.blocks = nn.ModuleList( + [ + Block( + model_dim, + num_heads, + num_kv_heads, + mlp_mult, + rope_base, + qk_gain_init, + structured_mlp, + structured_linear_type, + btt_rank, + btt_in_shape, + btt_out_shape, + btt_init, + btt_init_gain, + compile_structured_mlp, + structured_chunk_ranks, + structured_compile_mode, + ) + for i in range(num_layers) + ] + ) + self.final_norm = RMSNorm() + self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, 
bias=False)
+        if self.lm_head is not None:
+            self.lm_head._zero_init = True
+        self._init_weights()
+
+    def _init_weights(self) -> None:
+        if self.tie_embeddings:
+            nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std)
+        for module in self.modules():
+            if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False):
+                nn.init.zeros_(module.weight)
+            if isinstance(module, StructuredLinear) and getattr(module, "_zero_init", False):
+                if module.impl is not None:
+                    nn.init.zeros_(module.impl.weight)
+                else:
+                    nn.init.zeros_(module.core2)
+
+    def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor:
+        x = self.tok_emb(input_ids)
+        x = F.rms_norm(x, (x.size(-1),))
+        x0 = x
+        skips: list[Tensor] = []
+
+        # First half stores skips; second half reuses them in reverse order.
+        for i in range(self.num_encoder_layers):
+            x = self.blocks[i](x, x0)
+            skips.append(x)
+        for i in range(self.num_decoder_layers):
+            if skips:
+                x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop()
+            x = self.blocks[self.num_encoder_layers + i](x, x0)
+
+        x = self.final_norm(x).reshape(-1, x.size(-1))
+        targets = target_ids.reshape(-1)
+        if self.tie_embeddings:
+            logits_proj = F.linear(x, self.tok_emb.weight)
+        else:
+            if self.lm_head is None:
+                raise RuntimeError("lm_head is required when tie_embeddings=False")
+            logits_proj = self.lm_head(x)
+        logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap)
+        return F.cross_entropy(logits.float(), targets, reduction="mean")
+
+    def rate_regularizer(self) -> tuple[Tensor, Tensor, Tensor]:
+        total_symbols = 0.0
+        weighted_rate = None
+        weighted_scale = None
+        estimated_bits = None
+        for name, param in self.named_parameters():
+            if not should_rate_regularize_tensor(name, param, self.rate_target_tensors):
+                continue
+            rate_proxy, scale_proxy, tensor_bits = estimate_tensor_rate_proxy(param, self.rate_temperature)
+            weight = float(param.numel())
+            total_symbols += weight
+            if weighted_rate is None:
+                weighted_rate = rate_proxy * weight
+                weighted_scale = scale_proxy * weight
+                estimated_bits = tensor_bits
+            else:
+                weighted_rate = weighted_rate + rate_proxy * weight
+                weighted_scale = weighted_scale + scale_proxy * weight
+                estimated_bits = estimated_bits + tensor_bits
+        if weighted_rate is None or weighted_scale is None or estimated_bits is None:
+            zero = self.tok_emb.weight.new_zeros(())
+            return zero, zero, zero
+        denom = max(total_symbols, 1.0)
+        return weighted_rate / denom, weighted_scale / denom, estimated_bits
+
+
+# -----------------------------
+# TRAINING
+# -----------------------------
+
+def main() -> None:
+    global zeropower_via_newtonschulz5
+
+    code = Path(__file__).read_text(encoding="utf-8")
+    args = Hyperparameters()
+    if args.enable_compile:
+        zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5)
+
+    # -----------------------------
+    # DISTRIBUTED + CUDA SETUP
+    # -----------------------------
+
+    distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ
+    rank = int(os.environ.get("RANK", "0"))
+    world_size = int(os.environ.get("WORLD_SIZE", "1"))
+    local_rank = int(os.environ.get("LOCAL_RANK", "0"))
+    simulated_world_size = args.simulated_world_size if not distributed and args.simulated_world_size > 0 else 0
+    if simulated_world_size > 0:
+        world_size = simulated_world_size
+        rank = args.simulated_rank
+        local_rank = 0
+    if world_size <= 0:
+        raise ValueError(f"WORLD_SIZE must be positive, got {world_size}")
+    if rank < 0 or rank >= world_size:
+        raise ValueError(f"RANK={rank} must satisfy 0 <= RANK < WORLD_SIZE={world_size}")
+    if simulated_world_size > 1 and not args.dry_run_init_only:
+        raise ValueError("SIMULATED_WORLD_SIZE>1 is only supported together with DRY_RUN_INIT_ONLY=1")
+    if args.train_batch_tokens <= 0:
+        raise ValueError(f"TRAIN_BATCH_TOKENS must be positive, got {args.train_batch_tokens}")
+    train_microbatch_tokens = args.train_microbatch_tokens if args.train_microbatch_tokens > 0 else max(
+        args.train_seq_len, args.train_batch_tokens // 8
+    )
+    if train_microbatch_tokens <= 0:
+        raise ValueError(f"TRAIN_MICROBATCH_TOKENS must be positive, got {train_microbatch_tokens}")
+    if train_microbatch_tokens % args.train_seq_len != 0:
+        raise ValueError(
+            f"TRAIN_MICROBATCH_TOKENS={train_microbatch_tokens} must be divisible by TRAIN_SEQ_LEN={args.train_seq_len}"
+        )
+    tokens_per_microstep = world_size * train_microbatch_tokens
+    if args.grad_accum_steps > 0:
+        grad_accum_steps = args.grad_accum_steps
+        effective_train_batch_tokens = tokens_per_microstep * grad_accum_steps
+    else:
+        if args.train_batch_tokens % tokens_per_microstep != 0:
+            raise ValueError(
+                "TRAIN_BATCH_TOKENS must be divisible by WORLD_SIZE * TRAIN_MICROBATCH_TOKENS "
+                f"unless GRAD_ACCUM_STEPS is set explicitly; got TRAIN_BATCH_TOKENS={args.train_batch_tokens}, "
+                f"WORLD_SIZE={world_size}, TRAIN_MICROBATCH_TOKENS={train_microbatch_tokens}"
+            )
+        grad_accum_steps = args.train_batch_tokens // tokens_per_microstep
+        effective_train_batch_tokens = args.train_batch_tokens
+    if grad_accum_steps <= 0:
+        raise ValueError(f"GRAD_ACCUM_STEPS must be positive, got {grad_accum_steps}")
+    grad_scale = 1.0 / grad_accum_steps
+    batch_multiplier = effective_train_batch_tokens / max(args.lr_reference_batch_tokens, 1)
+    effective_warmup_steps = scaled_warmup_steps(
+        args.warmup_steps,
+        effective_train_batch_tokens / max(args.warmup_reference_batch_tokens, 1),
+        args.warmup_scale_method,
+    )
+    if not torch.cuda.is_available():
+        raise RuntimeError("CUDA is required")
+    device = torch.device("cuda", local_rank)
+    torch.cuda.set_device(device)
+    if distributed:
+        dist.init_process_group(backend="nccl", device_id=device)
+        dist.barrier()
+    master_process = rank == 0
+
+    # Fast math knobs
+    torch.backends.cuda.matmul.allow_tf32 = True
+    torch.backends.cudnn.allow_tf32 = True
+    from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp
+
+    enable_cudnn_sdp(False)
+    enable_flash_sdp(True)
+    enable_mem_efficient_sdp(False)
+    enable_math_sdp(False)
+
+    logfile = None
+    if master_process:
+        if args.run_id == "train":
+            logfile = "train.log"
+        else:
+            os.makedirs("logs", exist_ok=True)
+            logfile = f"logs/{args.run_id}.txt"
+        print(logfile)
+
+    def log0(msg: str, console: bool = True) -> None:
+        if not master_process:
+            return
+        if console:
+            print(msg)
+        if logfile is not None:
+            with open(logfile, "a", encoding="utf-8") as f:
+                print(msg, file=f)
+
+    log0(code, console=False)
+    log0("=" * 100, console=False)
+    log0(f"Running Python {sys.version}", console=False)
+    log0(f"Running PyTorch {torch.__version__}", console=False)
+    log0(
+        subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout,
+        console=False,
+    )
+    log0("=" * 100, console=False)
+
+    # -----------------------------
+    # TOKENIZER + VALIDATION METRIC SETUP
+    # -----------------------------
+
+    random.seed(args.seed)
+    np.random.seed(args.seed)
+    torch.manual_seed(args.seed)
+    torch.cuda.manual_seed_all(args.seed)
+
+    if not args.tokenizer_path.endswith(".model"):
+        raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}")
+    sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path)
+    if int(sp.vocab_size()) != args.vocab_size:
+        raise ValueError(
+            f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}"
+        )
+    dataset_dir = Path(args.data_path).resolve()
+    actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin")))
+    val_tokens = load_validation_tokens(args.val_files, args.train_seq_len)
+    if args.val_token_limit > 0:
+        usable = min((args.val_token_limit // args.train_seq_len) * args.train_seq_len, val_tokens.numel() - 1)
+        if usable <= 0:
+            raise ValueError(f"VAL_TOKEN_LIMIT={args.val_token_limit} is too small for TRAIN_SEQ_LEN={args.train_seq_len}")
+        val_tokens = val_tokens[: usable + 1].contiguous()
+    base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts(
+        sp, args.vocab_size, device
+    )
+    log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}")
+    log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}")
+    log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}")
+    if args.val_token_limit > 0:
+        log0(f"val_token_limit:{args.val_token_limit}")
+
+    # -----------------------------
+    # MODEL + OPTIMIZER SETUP
+    # -----------------------------
+
+    base_model = GPT(
+        vocab_size=args.vocab_size,
+        num_layers=args.num_layers,
+        model_dim=args.model_dim,
+        num_heads=args.num_heads,
+        num_kv_heads=args.num_kv_heads,
+        mlp_mult=args.mlp_mult,
+        tie_embeddings=args.tie_embeddings,
+        tied_embed_init_std=args.tied_embed_init_std,
+        logit_softcap=args.logit_softcap,
+        rope_base=args.rope_base,
+        qk_gain_init=args.qk_gain_init,
+        structured_mlp=args.structured_mlp,
+        structured_linear_type=args.structured_linear_type if args.structured_mlp else "dense",
+        btt_rank=args.btt_rank,
+        btt_in_shape=args.btt_in_shape,
+        btt_out_shape=args.btt_out_shape,
+        btt_init=args.btt_init,
+        btt_init_gain=args.btt_init_gain,
+        compile_structured_mlp=args.compile_structured_mlp,
+        structured_chunk_ranks=args.structured_chunk_ranks,
+        structured_compile_mode=args.structured_compile_mode,
+        rate_target_tensors=args.rate_target_tensors,
+        rate_temperature=args.rate_temperature,
+    ).to(device).bfloat16()
+    for module in base_model.modules():
+        if isinstance(module, (CastedLinear, StructuredLinear)):
+            module.float()
+    restore_low_dim_params_to_fp32(base_model)
+    compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) if args.enable_compile else base_model
+    model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model
+
+    # Optimizer split:
+    #   - token embedding (Adam) uses EMBED_LR
+    #   - untied lm_head (Adam) uses HEAD_LR
+    #   - matrix params in transformer blocks use MATRIX_LR via Muon
+    #   - vectors/scalars use SCALAR_LR via Adam
+    block_named_params = list(base_model.blocks.named_parameters())
+    matrix_params = [
+        p
+        for name, p in block_named_params
+        if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)
+    ]
+    scalar_params = [
+        p
+        for name, p in block_named_params
+        if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)
+    ]
+    if base_model.skip_weights.numel() > 0:
+        scalar_params.append(base_model.skip_weights)
+    token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr
+    scaled_token_lr = scale_from_batch_multiplier(token_lr, batch_multiplier, args.lr_scale_method)
+    scaled_head_lr = scale_from_batch_multiplier(args.head_lr, batch_multiplier, args.lr_scale_method)
+    scaled_matrix_lr = scale_from_batch_multiplier(args.matrix_lr, batch_multiplier, args.lr_scale_method)
+    scaled_scalar_lr = scale_from_batch_multiplier(args.scalar_lr, batch_multiplier, args.lr_scale_method)
+    optimizer_tok = torch.optim.Adam(
+        [{"params": [base_model.tok_emb.weight], "lr": scaled_token_lr, "base_lr": scaled_token_lr}],
+        betas=(args.beta1, args.beta2),
+        eps=args.adam_eps,
+        fused=True,
+    )
+    optimizer_muon = Muon(
+        matrix_params,
+        lr=scaled_matrix_lr,
+        momentum=args.muon_momentum,
+        backend_steps=args.muon_backend_steps,
+    )
+    for group in optimizer_muon.param_groups:
+        group["base_lr"] = scaled_matrix_lr
+    optimizer_scalar = torch.optim.Adam(
+        [{"params": scalar_params, "lr": scaled_scalar_lr, "base_lr": scaled_scalar_lr}],
+        betas=(args.beta1, args.beta2),
+        eps=args.adam_eps,
+        fused=True,
+    )
+    optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar]
+    if base_model.lm_head is not None:
+        optimizer_head = torch.optim.Adam(
+            [{"params": [base_model.lm_head.weight], "lr": scaled_head_lr, "base_lr": scaled_head_lr}],
+            betas=(args.beta1, args.beta2),
+            eps=args.adam_eps,
+            fused=True,
+        )
+        optimizers.insert(1, optimizer_head)
+
+    n_params = sum(p.numel() for p in base_model.parameters())
+    log0(f"model_params:{n_params}")
+    log0(
+        f"world_size:{world_size} grad_accum_steps:{grad_accum_steps} "
+        f"train_microbatch_tokens:{train_microbatch_tokens} effective_train_batch_tokens:{effective_train_batch_tokens}"
+    )
+    if simulated_world_size > 0:
+        log0(f"simulated_world_size:{simulated_world_size} simulated_rank:{rank}")
+    log0(f"enable_compile:{args.enable_compile}")
+    log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False")
+    log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}")
+    log0(
+        f"quant_scheme:{args.quant_scheme} compressor:{args.compressor} level:{args.compressor_level} "
+        f"structured_mlp:{args.structured_mlp} structured_linear_type:{args.structured_linear_type} "
+        f"btt_rank:{args.btt_rank} btt_init:{args.btt_init} compile_structured_mlp:{args.compile_structured_mlp} "
+        f"structured_chunk_ranks:{args.structured_chunk_ranks} structured_compile_mode:{args.structured_compile_mode}"
+    )
+    log0(
+        f"rate_lambda:{args.rate_lambda} scale_lambda:{args.scale_lambda} "
+        f"rate_warmup_start_fraction:{args.rate_warmup_start_fraction} rate_full_fraction:{args.rate_full_fraction} "
+        f"rate_target_tensors:{args.rate_target_tensors} rate_schedule_mode:{args.rate_schedule_mode} "
+        f"max_expected_steps:{args.max_expected_steps}"
+    )
+    log0(
+        f"tie_embeddings:{args.tie_embeddings} embed_lr:{scaled_token_lr} "
+        f"head_lr:{scaled_head_lr if base_model.lm_head is not None else 0.0} "
+        f"matrix_lr:{scaled_matrix_lr} scalar_lr:{scaled_scalar_lr}"
+    )
+    log0(
+        f"train_batch_tokens_target:{args.train_batch_tokens} effective_train_batch_tokens:{effective_train_batch_tokens} "
+        f"train_seq_len:{args.train_seq_len} iterations:{args.iterations} warmup_steps:{effective_warmup_steps} "
+        f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}"
+    )
+    log0(
+        f"lr_scale_method:{args.lr_scale_method} lr_reference_batch_tokens:{args.lr_reference_batch_tokens} "
+        f"batch_multiplier:{batch_multiplier:.4f} warmup_scale_method:{args.warmup_scale_method} "
+        f"warmup_reference_batch_tokens:{args.warmup_reference_batch_tokens}"
+    )
+    log0(f"seed:{args.seed}")
+
+    if args.debug_data_sharding_steps > 0:
+        debug_log_data_sharding(
+            args.train_files,
+            world_size,
+            effective_train_batch_tokens,
+            args.train_seq_len,
+            grad_accum_steps,
+            args.debug_data_sharding_steps,
+            log0,
+        )
+    if args.dry_run_init_only:
+        log0("dry_run_init_only: exiting after init/logging")
+        return
+
+    # -----------------------------
+    # DATA LOADER & MODEL WARMUP
+    # -----------------------------
+
+    train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device)
+
+    def zero_grad_all() -> None:
+        for opt in optimizers:
+            opt.zero_grad(set_to_none=True)
+
+    max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None
+
+    def lr_mul(step: int, elapsed_ms: float) -> float:
+        if args.warmdown_iters <= 0:
+            return 1.0
+        if max_wallclock_ms is None:
+            warmdown_start = max(args.iterations - args.warmdown_iters, 0)
+            return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0
+        step_ms = elapsed_ms / max(step, 1)
+        warmdown_ms = args.warmdown_iters * step_ms
+        remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0)
+        return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0
+
+    def rate_progress(step: int, elapsed_ms: float) -> float:
+        if args.rate_schedule_mode == "steps":
+            if args.iterations <= 0:
+                return 1.0
+            return step / max(args.iterations, 1)
+        if args.rate_schedule_mode == "expected_steps":
+            if args.max_expected_steps <= 0:
+                raise ValueError("MAX_EXPECTED_STEPS must be positive when RATE_SCHEDULE_MODE=expected_steps")
+            return step / max(args.max_expected_steps, 1)
+        if args.rate_schedule_mode == "wallclock":
+            if max_wallclock_ms is None:
+                raise ValueError("RATE_SCHEDULE_MODE=wallclock requires MAX_WALLCLOCK_SECONDS > 0")
+            return elapsed_ms / max(max_wallclock_ms, 1e-9)
+        if args.rate_schedule_mode != "auto":
+            raise ValueError(f"Unsupported RATE_SCHEDULE_MODE={args.rate_schedule_mode}")
+        if args.max_expected_steps > 0:
+            return step / max(args.max_expected_steps, 1)
+        if max_wallclock_ms is not None and step > 0:
+            step_ms = elapsed_ms / max(step, 1)
+            expected_steps = max(int(max_wallclock_ms / max(step_ms, 1e-9)), 1)
+            return step / expected_steps
+        if args.iterations <= 0:
+            return 1.0
+        return step / max(args.iterations, 1)
+
+    def rate_schedule(step: int, elapsed_ms: float) -> float:
+        if args.rate_lambda <= 0.0 and args.scale_lambda <= 0.0:
+            return 0.0
+        frac = rate_progress(step, elapsed_ms)
+        start = args.rate_warmup_start_fraction
+        full = max(args.rate_full_fraction, start + 1e-6)
+        if frac <= start:
+            return 0.0
+        if frac >= full:
+            return 1.0
+        return (frac - start) / (full - start)
+
+    # Warmup primes the compiled forward/backward/optimizer paths, then we restore the
+    # initial weights/optimizer state so measured training starts from the true init.
+    if effective_warmup_steps > 0:
+        initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()}
+        initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers]
+        model.train()
+        for warmup_step in range(effective_warmup_steps):
+            zero_grad_all()
+            for micro_step in range(grad_accum_steps):
+                if distributed:
+                    model.require_backward_grad_sync = micro_step == grad_accum_steps - 1
+                x, y = train_loader.next_batch(effective_train_batch_tokens, args.train_seq_len, grad_accum_steps)
+                maybe_mark_compile_step_begin(args.enable_compile or args.compile_structured_mlp)
+                with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True):
+                    warmup_loss = model(x, y)
+                (warmup_loss * grad_scale).backward()
+            for opt in optimizers:
+                opt.step()
+            zero_grad_all()
+            if (
+                effective_warmup_steps <= 20
+                or (warmup_step + 1) % 10 == 0
+                or warmup_step + 1 == effective_warmup_steps
+            ):
+                log0(f"warmup_step:{warmup_step + 1}/{effective_warmup_steps}")
+        base_model.load_state_dict(initial_model_state, strict=True)
+        for opt, state in zip(optimizers, initial_optimizer_states, strict=True):
+            opt.load_state_dict(state)
+        zero_grad_all()
+        if distributed:
+            model.require_backward_grad_sync = True
+        train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device)
+
+    # -----------------------------
+    # MAIN TRAINING LOOP
+    # -----------------------------
+
+    training_time_ms = 0.0
+    stop_after_step: int | None = None
+    torch.cuda.synchronize()
+    t0 = time.perf_counter()
+
+    step = 0
+    while True:
+        last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step)
+
+        should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0)
+        if should_validate:
+            torch.cuda.synchronize()
+            training_time_ms += 1000.0 * (time.perf_counter() - t0)
+            val_loss, val_bpb = eval_val(
+                args,
+                model,
+                rank,
+                world_size,
+                device,
+                grad_accum_steps,
+                val_tokens,
+                base_bytes_lut,
+                has_leading_space_lut,
+                is_boundary_token_lut,
+            )
+            log0(
+                f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} "
+                f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms"
+            )
+            torch.cuda.synchronize()
+            t0 = time.perf_counter()
+
+        if last_step:
+            if stop_after_step is not None and step < args.iterations:
+                log0(
+                    f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms "
+                    f"step:{step}/{args.iterations}"
+                )
+            break
+
+        elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0)
+        scale = lr_mul(step, elapsed_ms)
+        zero_grad_all()
+        train_loss = torch.zeros((), device=device)
+        train_ce_loss = torch.zeros((), device=device)
+        train_rate_proxy = torch.zeros((), device=device)
+        train_scale_proxy = torch.zeros((), device=device)
+        train_est_bits = torch.zeros((), device=device)
+        rate_mul = rate_schedule(step, elapsed_ms)
+        for micro_step in range(grad_accum_steps):
+            if distributed:
+                model.require_backward_grad_sync = micro_step == grad_accum_steps - 1
+            x, y = train_loader.next_batch(effective_train_batch_tokens, args.train_seq_len, grad_accum_steps)
+            maybe_mark_compile_step_begin(args.enable_compile or args.compile_structured_mlp)
+            with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True):
+                ce_loss = model(x, y)
+            total_loss = ce_loss
+            if rate_mul > 0.0:
+                rate_proxy, scale_proxy, est_bits = base_model.rate_regularizer()
+                total_loss = total_loss + rate_mul * (args.rate_lambda * rate_proxy + args.scale_lambda * scale_proxy)
+                train_rate_proxy += rate_proxy.detach()
+                train_scale_proxy += scale_proxy.detach()
+                train_est_bits += est_bits.detach()
+            train_ce_loss += ce_loss.detach()
+            train_loss += total_loss.detach()
+            (total_loss * grad_scale).backward()
+        train_loss /= grad_accum_steps
+        train_ce_loss /= grad_accum_steps
+        train_rate_proxy /= grad_accum_steps
+        train_scale_proxy /= grad_accum_steps
+        train_est_bits /= grad_accum_steps
+
+        frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0
+        muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum
+        for group in optimizer_muon.param_groups:
+            group["momentum"] = muon_momentum
+
+        for opt in optimizers:
+            for group in opt.param_groups:
+                group["lr"] = group["base_lr"] * scale
+
+        if args.grad_clip_norm > 0:
+            torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm)
+        for opt in optimizers:
+            opt.step()
+        zero_grad_all()
+
+        step += 1
+        approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0)
+        should_log_train = (
+            args.train_log_every > 0
+            and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None)
+        )
+        if should_log_train:
+            log0(
+                f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} ce_loss:{train_ce_loss.item():.4f} "
+                f"rate_mul:{rate_mul:.3f} rate_proxy_bits:{train_rate_proxy.item():.4f} "
+                f"scale_proxy:{train_scale_proxy.item():.6f} est_model_bits:{train_est_bits.item():.0f} "
+                f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms"
+            )
+
+        # Needed to sync whether we've reached the wallclock cap.
+        reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms
+        if distributed and max_wallclock_ms is not None:
+            reached_cap_tensor = torch.tensor(int(reached_cap), device=device)
+            dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX)
+            reached_cap = bool(reached_cap_tensor.item())
+        if stop_after_step is None and reached_cap:
+            stop_after_step = step
+
+    log0(
+        f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB "
+        f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB"
+    )
+
+    # -----------------------------
+    # SERIALIZATION + ROUNDTRIP VALIDATION
+    # -----------------------------
+    # Save the raw state (useful for debugging/loading in PyTorch directly), then always produce
+    # the compressed quantized artifact (QUANT_SCHEME + COMPRESSOR, e.g. int5_odd+zlib) and
+    # validate the round-tripped weights.
+
+    if master_process:
+        torch.save(base_model.state_dict(), "final_model.pt")
+        model_bytes = os.path.getsize("final_model.pt")
+        code_bytes = len(code.encode("utf-8"))
+        log0(f"Serialized model: {model_bytes} bytes")
+        log0(f"Code size: {code_bytes} bytes")
+        log0(f"Total submission size: {model_bytes + code_bytes} bytes")
+
+    quant_obj, quant_stats = quantize_state_dict(base_model.state_dict(), args.quant_scheme)
+    quant_buf = io.BytesIO()
+    torch.save(quant_obj, quant_buf)
+    quant_raw = quant_buf.getvalue()
+    quant_blob = compress_blob(quant_raw, args.compressor, args.compressor_level)
+    quant_raw_bytes = len(quant_raw)
+    quant_filename = f"final_model.{args.quant_scheme}.{args.compressor}.ptc"
+    if master_process:
+        with open(quant_filename, "wb") as f:
+            f.write(quant_blob)
+        quant_file_bytes = os.path.getsize(quant_filename)
+        code_bytes = len(code.encode("utf-8"))
+        ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1)
+        log0(
+            f"Serialized model {args.quant_scheme}+{args.compressor}: {quant_file_bytes} bytes "
+            f"(payload:{quant_stats['int8_payload_bytes']} raw_torch:{quant_raw_bytes} payload_ratio:{ratio:.2f}x)"
+        )
+        log0(f"Total submission size {args.quant_scheme}+{args.compressor}: {quant_file_bytes + code_bytes} bytes")
+
+    if distributed:
+        dist.barrier()
+    with open(quant_filename, "rb") as f:
+        quant_blob_disk = f.read()
+    quant_state = torch.load(io.BytesIO(decompress_blob(quant_blob_disk, args.compressor)), map_location="cpu")
+    base_model.load_state_dict(dequantize_state_dict(quant_state), strict=True)
+    torch.cuda.synchronize()
+    t_qeval = time.perf_counter()
+    q_val_loss, q_val_bpb = eval_val(
+        args,
+        model,
+        rank,
+        world_size,
+        device,
+        grad_accum_steps,
+        val_tokens,
+        base_bytes_lut,
+        has_leading_space_lut,
+        is_boundary_token_lut,
+    )
+    torch.cuda.synchronize()
+    log0(
+        f"final_{args.quant_scheme}_{args.compressor}_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} "
+        f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms"
+    )
+    log0(f"final_{args.quant_scheme}_{args.compressor}_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}")
+
+    if distributed:
+        dist.destroy_process_group()
+
+
+if __name__ == "__main__":
+    main()
+
+====================================================================================================
+Running Python 3.12.3 (main, Nov  6 2025, 13:44:16) [GCC 13.3.0]
+Running PyTorch 2.9.1+cu128
+Tue Mar 31 17:12:30 2026
++-----------------------------------------------------------------------------------------+
+| NVIDIA-SMI 580.126.09             Driver Version: 580.126.09     CUDA Version: 13.0     |
++-----------------------------------------+------------------------+----------------------+
+| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
+| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
+|                                         |                        |               MIG M. |
+|=========================================+========================+======================|
+|   0  NVIDIA H100 80GB HBM3          On  |   00000000:CB:00.0 Off |                    0 |
+| N/A   35C    P0             90W /  700W |    1185MiB /  81559MiB |      6%      Default |
+|                                         |                        |             Disabled |
++-----------------------------------------+------------------------+----------------------+
+
++-----------------------------------------------------------------------------------------+
+| Processes:                                                                              |
+|  GPU   GI   CI              PID   Type   Process name                        GPU Memory |
+|        ID   ID                                                               Usage      |
+|=========================================================================================|
+|  No running processes found                                                             |
++-----------------------------------------------------------------------------------------+
+
+====================================================================================================
+val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=/workspace/parameter-golf/data/tokenizers/fineweb_1024_bpe.model
+train_loader:dataset:fineweb10B_sp1024 train_shards:80
+val_loader:shards pattern=/workspace/parameter-golf/data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:1048576
+val_token_limit:1048576
+model_params:25727104
+world_size:1 grad_accum_steps:8 train_microbatch_tokens:2048 effective_train_batch_tokens:16384
+enable_compile:False
+sdp_backends:cudnn=False flash=True mem_efficient=False math=False
+attention_mode:gqa num_heads:8 num_kv_heads:4
+quant_scheme:int5_odd compressor:zlib level:9 structured_mlp:True structured_linear_type:btt_mlp btt_rank:256 btt_init:mup compile_structured_mlp:False structured_chunk_ranks:64 structured_compile_mode:reduce-overhead
+rate_lambda:2e-05 scale_lambda:0.0002 rate_warmup_start_fraction:0.1 rate_full_fraction:0.4 rate_target_tensors:mlp_only rate_schedule_mode:auto max_expected_steps:0
+tie_embeddings:True embed_lr:0.03 head_lr:0.0 matrix_lr:0.02 scalar_lr:0.02
+train_batch_tokens_target:16384 effective_train_batch_tokens:16384 train_seq_len:1024 iterations:20000 warmup_steps:1 max_wallclock_seconds:600.000
+lr_scale_method:none lr_reference_batch_tokens:16384 batch_multiplier:1.0000 warmup_scale_method:none warmup_reference_batch_tokens:16384
+seed:1337
+warmup_step:1/1
+step:1/20000 train_loss:6.9404 ce_loss:6.9404 rate_mul:0.000 rate_proxy_bits:0.0000 scale_proxy:0.000000 est_model_bits:0 train_time:19008ms step_avg:19008.35ms
+step:2/20000 train_loss:12.7072 ce_loss:12.7072 rate_mul:0.000 rate_proxy_bits:0.0000 scale_proxy:0.000000 est_model_bits:0 train_time:38408ms step_avg:19203.92ms
+step:3/20000 train_loss:12.5926 ce_loss:12.5926 rate_mul:0.000 rate_proxy_bits:0.0000 scale_proxy:0.000000 est_model_bits:0 train_time:58075ms step_avg:19358.32ms
+step:4/20000 train_loss:12.3761 ce_loss:12.3761 rate_mul:0.000 rate_proxy_bits:0.0000 scale_proxy:0.000000 est_model_bits:0 train_time:77664ms step_avg:19415.98ms
+step:5/20000 train_loss:12.2059 ce_loss:12.2059 rate_mul:0.111 rate_proxy_bits:1.4982 scale_proxy:0.059731 est_model_bits:18851774 train_time:97700ms step_avg:19540.07ms
+step:6/20000 train_loss:11.9154 ce_loss:11.9154 rate_mul:0.222 rate_proxy_bits:1.4982 scale_proxy:0.059731 est_model_bits:18851774 train_time:117914ms step_avg:19652.39ms
+step:7/20000 train_loss:11.7523 ce_loss:11.7523 rate_mul:0.333 rate_proxy_bits:1.4982 scale_proxy:0.059731 est_model_bits:18851774 train_time:138027ms step_avg:19718.08ms
+step:8/20000 train_loss:11.5979 ce_loss:11.5979 rate_mul:0.444 rate_proxy_bits:1.4982 scale_proxy:0.059731 est_model_bits:18851774 train_time:158475ms step_avg:19809.35ms
+step:9/20000 train_loss:11.1889 ce_loss:11.1889 rate_mul:0.556 rate_proxy_bits:1.4982 scale_proxy:0.059731 est_model_bits:18851774 train_time:178761ms step_avg:19862.31ms
+step:10/20000 train_loss:11.1120 ce_loss:11.1120 rate_mul:0.667 rate_proxy_bits:1.4982 scale_proxy:0.059731 est_model_bits:18851774 train_time:198797ms step_avg:19879.72ms
+step:20/20000 train_loss:9.5722 ce_loss:9.5722 rate_mul:1.000 rate_proxy_bits:1.4982 scale_proxy:0.059731 est_model_bits:18851774 train_time:398781ms step_avg:19939.03ms
+step:30/20000 train_loss:9.0810 ce_loss:9.0809 rate_mul:1.000 rate_proxy_bits:1.4982 scale_proxy:0.059731 est_model_bits:18851774 train_time:598799ms step_avg:19959.97ms
+step:31/20000 val_loss:9.0928 val_bpb:5.4479 train_time:618760ms step_avg:19960.00ms
+stopping_early: wallclock_cap train_time:618760ms step:31/20000
+peak memory allocated: 50388 MiB reserved: 53230 MiB
+Serialized model: 101929819 bytes
+Code size: 78465 bytes
+Total submission size: 102008284 bytes
+Serialized model int5_odd+zlib: 5184839 bytes (payload:8780523 raw_torch:8851579 payload_ratio:11.60x)
+Total submission size int5_odd+zlib: 5263304 bytes
+final_int5_odd_zlib_roundtrip val_loss:9.8767 val_bpb:5.9176 eval_time:103984ms
+final_int5_odd_zlib_roundtrip_exact val_loss:9.87668240 val_bpb:5.91758197
diff --git a/records/track_non_record_16mb/2026-03-31_EntropyAware_Int5Odd_BTTMLP_3060/logs/h100_real_r256_l16_seq1024_mb2048_materialized.txt b/records/track_non_record_16mb/2026-03-31_EntropyAware_Int5Odd_BTTMLP_3060/logs/h100_real_r256_l16_seq1024_mb2048_materialized.txt
new file mode 100644
index 000000000..2d659424c
--- /dev/null
+++ b/records/track_non_record_16mb/2026-03-31_EntropyAware_Int5Odd_BTTMLP_3060/logs/h100_real_r256_l16_seq1024_mb2048_materialized.txt
@@ -0,0 +1,1923 @@
+logs/h100_real_r256_l16_seq1024_mb2048_materialized.txt
+val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=/workspace/parameter-golf/data/tokenizers/fineweb_1024_bpe.model
+train_loader:dataset:fineweb10B_sp1024 train_shards:80
+val_loader:shards pattern=/workspace/parameter-golf/data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:1048576
+val_token_limit:1048576
+model_params:25727104
+world_size:1 grad_accum_steps:8 train_microbatch_tokens:2048 effective_train_batch_tokens:16384
+enable_compile:False
+sdp_backends:cudnn=False flash=True mem_efficient=False math=False
+attention_mode:gqa num_heads:8 num_kv_heads:4
+quant_scheme:int5_odd compressor:zlib level:9 structured_mlp:True structured_linear_type:btt_mlp btt_rank:256 btt_init:mup compile_structured_mlp:False structured_chunk_ranks:64 structured_eval_chunk_ranks:1 structured_compile_mode:reduce-overhead
+rate_lambda:2e-05 scale_lambda:0.0002 rate_warmup_start_fraction:0.1 rate_full_fraction:0.4 rate_target_tensors:mlp_only rate_schedule_mode:auto max_expected_steps:0
+tie_embeddings:True embed_lr:0.03 head_lr:0.0 matrix_lr:0.02 scalar_lr:0.02
+train_batch_tokens_target:16384 effective_train_batch_tokens:16384 train_seq_len:1024 iterations:20000 warmup_steps:1 max_wallclock_seconds:600.000
+lr_scale_method:none lr_reference_batch_tokens:16384 batch_multiplier:1.0000 warmup_scale_method:none warmup_reference_batch_tokens:16384
+seed:1337
+warmup_step:1/1
+step:0/20000 val_loss:6.9379 val_bpb:4.1568 train_time:0ms step_avg:0.03ms
+step:1/20000 train_loss:6.9404 ce_loss:6.9404 rate_mul:0.000 rate_proxy_bits:0.0000 scale_proxy:0.000000 est_model_bits:0 train_time:18556ms step_avg:18556.11ms
+step:2/20000 train_loss:12.7072 ce_loss:12.7072 rate_mul:0.000 rate_proxy_bits:0.0000 scale_proxy:0.000000 est_model_bits:0 train_time:37651ms step_avg:18825.71ms
+step:3/20000 train_loss:12.5900 ce_loss:12.5900 rate_mul:0.000 rate_proxy_bits:0.0000 scale_proxy:0.000000 est_model_bits:0 train_time:56956ms step_avg:18985.42ms
+step:4/20000 train_loss:12.3686 ce_loss:12.3686 rate_mul:0.000 rate_proxy_bits:0.0000 scale_proxy:0.000000 est_model_bits:0 train_time:76316ms step_avg:19079.05ms
+step:5/20000 train_loss:12.1946 ce_loss:12.1946 rate_mul:0.097 rate_proxy_bits:1.4982 scale_proxy:0.059731 est_model_bits:18851774 train_time:96008ms step_avg:19201.52ms
+step:6/20000 train_loss:11.8978 ce_loss:11.8978 rate_mul:0.204 rate_proxy_bits:1.4982 scale_proxy:0.059731 est_model_bits:18851774 train_time:115743ms step_avg:19290.47ms
+step:7/20000 train_loss:11.7299 ce_loss:11.7299 rate_mul:0.312 rate_proxy_bits:1.4982 scale_proxy:0.059731 est_model_bits:18851774 train_time:135345ms step_avg:19334.94ms
+step:8/20000 train_loss:11.5707 ce_loss:11.5707 rate_mul:0.419 rate_proxy_bits:1.4982 scale_proxy:0.059731 est_model_bits:18851774 train_time:154822ms step_avg:19352.71ms
+step:9/20000 train_loss:11.1523 ce_loss:11.1523 rate_mul:0.527 rate_proxy_bits:1.4982 scale_proxy:0.059731 est_model_bits:18851774 train_time:174289ms step_avg:19365.47ms
+step:10/20000 train_loss:11.0726 ce_loss:11.0726 rate_mul:0.667 rate_proxy_bits:1.4982 scale_proxy:0.059731 est_model_bits:18851774 train_time:194115ms step_avg:19411.48ms
+step:20/20000 train_loss:9.4721 ce_loss:9.4721 rate_mul:1.000 rate_proxy_bits:1.4982 scale_proxy:0.059731 est_model_bits:18851774 train_time:390448ms step_avg:19522.42ms
+step:30/20000 train_loss:8.9153 ce_loss:8.9152 rate_mul:1.000 rate_proxy_bits:1.4982 scale_proxy:0.059731 est_model_bits:18851774 train_time:584228ms step_avg:19474.25ms
+step:31/20000 val_loss:8.9222 val_bpb:5.3457 train_time:603522ms step_avg:19468.45ms
+stopping_early: wallclock_cap train_time:603522ms step:31/20000
+peak memory allocated: 50290 MiB reserved: 78558 MiB
+Serialized model: 101929819 bytes
+Code size: 83124 bytes
+Total submission size: 102012943 bytes
+Serialized model int5_odd+zlib: 5184543 bytes (payload:8780523 raw_torch:8851579 payload_ratio:11.60x)
+Total submission size int5_odd+zlib: 5267667 bytes
+final_int5_odd_zlib_roundtrip val_loss:9.8273 val_bpb:5.8880 eval_time:854ms
+final_int5_odd_zlib_roundtrip_exact val_loss:9.82734489 val_bpb:5.88802157
+", 0.85))
+    muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500))
+    beta1 = float(os.environ.get("BETA1", 0.9))
+    beta2 = float(os.environ.get("BETA2", 0.95))
+    adam_eps = float(os.environ.get("ADAM_EPS", 1e-8))
+    grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.0))
+    lr_scale_method = os.environ.get("LR_SCALE_METHOD", "none")
+    lr_reference_batch_tokens = int(os.environ.get("LR_REFERENCE_BATCH_TOKENS", 16_384))
+    warmup_scale_method = os.environ.get("WARMUP_SCALE_METHOD", "none")
+    warmup_reference_batch_tokens = int(os.environ.get("WARMUP_REFERENCE_BATCH_TOKENS", 16_384))
+    simulated_world_size = int(os.environ.get("SIMULATED_WORLD_SIZE", 0))
+    simulated_rank = int(os.environ.get("SIMULATED_RANK", 0))
+    dry_run_init_only = bool(int(os.environ.get("DRY_RUN_INIT_ONLY", "0")))
+    debug_data_sharding_steps = int(os.environ.get("DEBUG_DATA_SHARDING_STEPS", 0))
+
+    # Export / compression.
+    quant_scheme = os.environ.get("QUANT_SCHEME", "int5_odd")
+    compressor = os.environ.get("COMPRESSOR", "zlib")
+    compressor_level = int(os.environ.get("COMPRESSOR_LEVEL", 9))
+
+    # Entropy-aware rate regularization.
+    rate_lambda = float(os.environ.get("RATE_LAMBDA", 0.00002))
+    scale_lambda = float(os.environ.get("SCALE_LAMBDA", 0.0002))
+    rate_temperature = float(os.environ.get("RATE_TEMPERATURE", 6.0))
+    rate_warmup_start_fraction = float(os.environ.get("RATE_WARMUP_START_FRACTION", 0.1))
+    rate_full_fraction = float(os.environ.get("RATE_FULL_FRACTION", 0.4))
+    rate_target_tensors = os.environ.get("RATE_TARGET_TENSORS", "mlp_only")
+    rate_schedule_mode = os.environ.get("RATE_SCHEDULE_MODE", "auto")
+    max_expected_steps = int(os.environ.get("MAX_EXPECTED_STEPS", 0))
+
+    # Structured MLP.
+ structured_mlp = bool(int(os.environ.get("STRUCTURED_MLP", "1"))) + structured_linear_type = os.environ.get("STRUCTURED_LINEAR_TYPE", "btt_mlp") + btt_rank = int(os.environ.get("BTT_RANK", 32)) + btt_in_shape = os.environ.get("BTT_IN_SHAPE", "") + btt_out_shape = os.environ.get("BTT_OUT_SHAPE", "") + btt_init = os.environ.get("BTT_INIT", "mup") + btt_init_gain = float(os.environ.get("BTT_INIT_GAIN", 1.0)) + compile_structured_mlp = bool(int(os.environ.get("COMPILE_STRUCTURED_MLP", "1"))) + structured_chunk_ranks = int(os.environ.get("STRUCTURED_CHUNK_RANKS", 64)) + structured_eval_chunk_ranks = int(os.environ.get("STRUCTURED_EVAL_CHUNK_RANKS", 1)) + structured_compile_mode = os.environ.get("STRUCTURED_COMPILE_MODE", "reduce-overhead") + +# ----------------------------- +# MUON OPTIMIZER +# ----------------------------- +# +# As borrowed from modded-nanogpt +# Background on Muon: https://kellerjordan.github.io/posts/muon/ + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + # Orthogonalize a 2D update matrix with a fast Newton-Schulz iteration. + # Muon uses this to normalize matrix-shaped gradients before applying them. 
+ a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X + + +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True): + super().__init__( + params, + dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov), + ) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + + for group in self.param_groups: + params = group["params"] + if not params: + continue + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + + total_params = sum(int(p.numel()) for p in params) + updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) + + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: + g = g.add(buf, alpha=momentum) + g = zeropower_via_newtonschulz5(g, steps=backend_steps) + # Scale correction from Muon reference implementations. 
+ g *= max(1, g.size(0) / g.size(1)) ** 0.5 + updates_flat[curr : curr + p.numel()] = g.reshape(-1) + curr += p.numel() + + if distributed: + dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) + + curr = 0 + for p in params: + g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) + p.add_(g, alpha=-lr) + curr += p.numel() + + return loss + + +# ----------------------------- +# TOKENIZER-AGNOSTIC EVALUATION SETUP +# ----------------------------- +# +# It's common for small models to have a large fraction of their parameters be embeddings, since the 2 * d_model * d_vocab embedding parameters can be gigantic. +# Instead of locking the tokenizer, we let you bring your own and calculate our validation metrics on the average compression of the validation set. +# We calculate BPB (bits-per-byte) instead of validation loss, so we need methods to count the number of bytes per token in the tokenizer. +# Note: Submissions that edit the tokenizer will be examined more carefully, since screwing this up might unjustly improve your score.
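The BPB conversion that the evaluation code performs reduces to two lines of arithmetic: mean token cross-entropy in nats becomes bits per token, then gets rescaled by the tokens-to-bytes ratio measured with the byte LUTs. A minimal standalone sketch (the helper name is hypothetical, not part of the harness):

```python
import math

def bits_per_byte(ce_nats_sum: float, token_count: int, byte_count: int) -> float:
    # bits/byte = (mean CE in nats / ln 2) bits/token * (tokens / bytes)
    bits_per_token = (ce_nats_sum / token_count) / math.log(2.0)
    return bits_per_token * (token_count / byte_count)

# Rough consistency check against the step-0 log line above:
# val_loss:6.9379 with val_bpb:4.1568 implies ~2.41 bytes/token on this tokenizer.
bpb = bits_per_byte(6.9379 * 1_000, 1_000, 2_408)
```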
+ +def build_sentencepiece_luts( + sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device +) -> tuple[Tensor, Tensor, Tensor]: + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("▁"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) + + +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*. 
+ tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] + + +def eval_val( + args: Hyperparameters, + model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + grad_accum_steps: int, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, +) -> tuple[float, float]: + # Validation computes two metrics: + # - val_loss: token cross-entropy (natural log) + # - val_bpb: tokenizer-agnostic compression metric used by the challenge + local_batch_tokens = args.val_batch_size // world_size + if local_batch_tokens < args.train_seq_len: + raise ValueError( + "VAL_BATCH_SIZE must provide at least one sequence per rank; " + f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " + f"TRAIN_SEQ_LEN={args.train_seq_len}" + ) + local_batch_seqs = local_batch_tokens // args.train_seq_len + total_seqs = (val_tokens.numel() - 1) // args.train_seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + + model.eval() + eval_dtype = torch.bfloat16 if device.type == "cuda" else torch.float32 + eval_module = model.module if isinstance(model, DDP) else model + for module in eval_module.modules(): + if isinstance(module, StructuredLinear): + module.materialize(eval_dtype, device) + with torch.inference_mode(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * args.train_seq_len + raw_end = batch_seq_end * args.train_seq_len + 1 
+ local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, args.train_seq_len) + y = local[1:].reshape(-1, args.train_seq_len) + maybe_mark_compile_step_begin(args.enable_compile or args.compile_structured_mlp) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + batch_loss = model(x, y).detach() + batch_token_count = float(y.numel()) + val_loss_sum += batch_loss.to(torch.float64) * batch_token_count + val_token_count += batch_token_count + prev_ids = x.reshape(-1) + tgt_ids = y.reshape(-1) + token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + + val_loss = val_loss_sum / val_token_count + bits_per_token = val_loss.item() / math.log(2.0) + tokens_per_byte = val_token_count.item() / val_byte_count.item() + model.train() + return float(val_loss.item()), float(bits_per_token * tokens_per_byte) + +# ----------------------------- +# POST-TRAINING QUANTIZATION +# ----------------------------- +# +# It's silly to export our model, which is trained in bf16 and fp32, at that same precision. +# Instead, we get approximately the same model (with a small hit) by quantizing the model to int8 and zlib-compressing the result. +# We can then decompress the model and run it at higher precision for evaluation, after coming in under the size limit.
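The quantize-then-compress idea described above can be seen end to end in a toy form. This is a simplified sketch, not the script's exact export path (the real code adds percentile clipping, fp16 scale storage, and passthrough handling for small tensors): per-row symmetric int8 codes, zlib over the raw codes, and a dequantized copy whose elementwise error is bounded by half a quantization step.

```python
import zlib
import numpy as np

def int8_roundtrip(w: np.ndarray) -> tuple[np.ndarray, int]:
    # One scale per row (simplified: row max instead of a percentile clip).
    scale = np.abs(w).max(axis=1, keepdims=True) / 127.0
    scale = np.maximum(scale, 1.0 / 127.0)
    q = np.clip(np.round(w / scale), -127, 127).astype(np.int8)
    payload = zlib.compress(q.tobytes(), 9)  # entropy-code the int8 codes
    return q.astype(np.float32) * scale, len(payload)

rng = np.random.default_rng(0)
w = rng.normal(size=(256, 256)).astype(np.float32)
w_hat, nbytes = int8_roundtrip(w)
# Compressed payload is far smaller than the 256*256*4-byte fp32 tensor,
# and the reconstruction error stays within half a quantization step per row.
assert nbytes < w.nbytes
assert float(np.max(np.abs(w - w_hat))) <= 0.5 * float(np.abs(w).max()) / 127.0 + 1e-6
```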
+ +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights", + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 +INT5_ODD_LEVELS = (-2, -1, 0, 1, 2) + +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) + + +def maybe_mark_compile_step_begin(enabled: bool) -> None: + if not enabled: + return + mark_step = getattr(torch.compiler, "cudagraph_mark_step_begin", None) + if mark_step is not None: + mark_step() + + +def scale_from_batch_multiplier(value: float, multiplier: float, method: str) -> float: + if method == "none": + return value + if method == "sqrt": + return value * math.sqrt(multiplier) + if method == "linear": + return value * multiplier + raise ValueError(f"Unsupported scaling method: {method}") + + +def scaled_warmup_steps(base_steps: int, multiplier: float, method: str) -> int: + if base_steps <= 0: + return 0 + scaled = scale_from_batch_multiplier(float(base_steps), multiplier, method) + return max(int(math.ceil(scaled)), 1) + + +def parse_factor_shape(spec: str, dim: int) -> tuple[int, int]: + if spec: + vals = tuple(int(v.strip()) for v in spec.split(",") if v.strip()) + if len(vals) != 2 or vals[0] * vals[1] != dim: + raise ValueError(f"Invalid factor shape '{spec}' for dim={dim}") + return vals + for a in range(int(math.isqrt(dim)), 0, -1): + if dim % a == 0: + return a, dim // a + return 1, dim + +def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> 
Tensor: + if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): + return t.float().contiguous() + if t.dtype in {torch.float32, torch.bfloat16}: + passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") + return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() + return t + +def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + # Matrices get one scale per row, which usually tracks output-channel + # ranges much better than a single tensor-wide scale. + clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) + q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() + return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + + # Vectors / scalars use a simpler per-tensor scale. 
+ clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() + return q, scale + + +def pack_base5_codes(q: Tensor) -> tuple[bytes, int]: + vals = q.detach().to("cpu", dtype=torch.uint8).reshape(-1).numpy() + n_vals = int(vals.size) + pad = (3 - (n_vals % 3)) % 3 + if pad: + vals = np.pad(vals, (0, pad)) + groups = vals.reshape(-1, 3).astype(np.uint8) + packed = groups[:, 0] + 5 * groups[:, 1] + 25 * groups[:, 2] + return packed.tobytes(), n_vals + + +def unpack_base5_codes(data: bytes, n_vals: int) -> Tensor: + packed = np.frombuffer(data, dtype=np.uint8).astype(np.int16) + out = np.zeros((packed.size, 3), dtype=np.uint8) + for i in range(3): + out[:, i] = packed % 5 + packed //= 5 + flat = out.reshape(-1)[:n_vals] + return torch.from_numpy(flat.astype(np.uint8, copy=False)) + + +def quantize_float_tensor_int5_odd(t: Tensor) -> tuple[Tensor, Tensor]: + if t.ndim != 2: + raise ValueError("int5_odd quantization only supports 2D tensors") + t32 = t.float() + clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clip_abs = clip_abs.clamp_min(1e-6) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 2.0).clamp_min(1.0 / 2.0e6) + q = torch.clamp(torch.round(clipped / scale[:, None]), -2, 2).to(torch.int8).contiguous() + return (q + 2).to(torch.uint8).contiguous(), scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + +def quantize_state_dict_int8(state_dict: dict[str, Tensor]): + # Single supported clean-script export format: + # - per-row int8 for 2D float tensors + # - per-tensor int8 for other float tensors + # - exact passthrough for non-floats + # - passthrough for small float tensors, 
stored as fp16 to save bytes + quantized: dict[str, Tensor] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + qmeta: dict[str, dict[str, object]] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + + # Small float tensors are cheap enough to keep directly. We still downcast + # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size. + if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + + stats["num_float_tensors"] += 1 + q, s = quantize_float_tensor(t) + if s.ndim > 0: + qmeta[name] = {"scheme": "per_row", "axis": 0} + quantized[name] = q + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) + + obj: dict[str, object] = { + "__quant_format__": "int8_clean_per_row_v1", + "quantized": quantized, + "scales": scales, + "dtypes": dtypes, + "passthrough": passthrough, + } + if qmeta: + obj["qmeta"] = qmeta + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats + + +def quantize_state_dict_int5_odd(state_dict: dict[str, Tensor]): + packed: dict[str, bytes] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + shapes: dict[str, tuple[int, ...]] = {} + orig_shapes: dict[str, 
tuple[int, ...]] = {} + n_codes: dict[str, int] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + + quantize_this = t.ndim >= 2 and (t.numel() > INT8_KEEP_FLOAT_MAX_NUMEL or ".mlp." in name) + if not quantize_this: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + + stats["num_float_tensors"] += 1 + t2 = t.reshape(t.shape[0], -1) if t.ndim > 2 else t + q, s = quantize_float_tensor_int5_odd(t2) + packed_bytes, count = pack_base5_codes(q) + packed[name] = packed_bytes + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + shapes[name] = tuple(int(v) for v in t2.shape) + orig_shapes[name] = tuple(int(v) for v in t.shape) + n_codes[name] = count + stats["int8_payload_bytes"] += len(packed_bytes) + tensor_nbytes(s) + + obj: dict[str, object] = { + "__quant_format__": "int5_odd_clean_per_row_v1", + "packed": packed, + "scales": scales, + "dtypes": dtypes, + "shapes": shapes, + "orig_shapes": orig_shapes, + "n_codes": n_codes, + "passthrough": passthrough, + } + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats + +def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + qmeta = obj.get("qmeta", {}) + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, q in 
obj["quantized"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + s = obj["scales"][name] + if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: + s = s.to(dtype=torch.float32) + # Broadcast the saved row scale back across trailing dimensions. + out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() + else: + scale = float(s.item()) + out[name] = (q.float() * scale).to(dtype=dtype).contiguous() + for name, t in obj["passthrough"].items(): + # Restore small tensors, undoing the temporary fp16 storage cast if needed. + out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out + + +def dequantize_state_dict_int5_odd(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, packed in obj["packed"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + shape = tuple(int(v) for v in obj["shapes"][name]) + orig_shape = tuple(int(v) for v in obj.get("orig_shapes", {}).get(name, shape)) + q = unpack_base5_codes(packed, int(obj["n_codes"][name])).to(torch.float32).reshape(shape) - 2.0 + s = obj["scales"][name].to(dtype=torch.float32) + deq = (q * s.view(shape[0], *([1] * (len(shape) - 1)))).to(dtype=dtype) + out[name] = deq.reshape(orig_shape).contiguous() + for name, t in obj["passthrough"].items(): + out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out + + +def quantize_state_dict(state_dict: dict[str, Tensor], quant_scheme: str): + if quant_scheme == "int8": + return quantize_state_dict_int8(state_dict) + if quant_scheme == "int5_odd": + return 
quantize_state_dict_int5_odd(state_dict) + raise ValueError(f"Unsupported QUANT_SCHEME={quant_scheme}") + + +def dequantize_state_dict(obj: dict[str, object]) -> dict[str, Tensor]: + quant_format = obj.get("__quant_format__") + if quant_format == "int8_clean_per_row_v1": + return dequantize_state_dict_int8(obj) + if quant_format == "int5_odd_clean_per_row_v1": + return dequantize_state_dict_int5_odd(obj) + raise ValueError(f"Unsupported quant format {quant_format}") + + +def compress_blob(data: bytes, compressor: str, level: int) -> bytes: + if compressor == "zlib": + return zlib.compress(data, level=max(0, min(level, 9))) + if compressor == "zstd": + if not _HAS_ZSTD: + raise RuntimeError("COMPRESSOR=zstd requires zstandard to be installed") + return zstd.ZstdCompressor(level=level).compress(data) + raise ValueError(f"Unsupported COMPRESSOR={compressor}") + + +def decompress_blob(data: bytes, compressor: str) -> bytes: + if compressor == "zlib": + return zlib.decompress(data) + if compressor == "zstd": + if not _HAS_ZSTD: + raise RuntimeError("COMPRESSOR=zstd requires zstandard to be installed") + return zstd.ZstdDecompressor().decompress(data) + raise ValueError(f"Unsupported COMPRESSOR={compressor}") + + +def should_rate_regularize_tensor(name: str, p: Tensor, target: str) -> bool: + if not p.requires_grad or not p.is_floating_point() or p.ndim < 2: + return False + if any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS): + return False + if target == "mlp_only": + return ".mlp." in name + if target == "all_matrices": + return p.numel() > INT8_KEEP_FLOAT_MAX_NUMEL or ".mlp." 
in name + raise ValueError(f"Unsupported RATE_TARGET_TENSORS={target}") + + +def estimate_tensor_rate_proxy(t: Tensor, temperature: float) -> tuple[Tensor, Tensor, Tensor]: + t2 = t.float() + if t2.ndim > 2: + t2 = t2.reshape(t2.shape[0], -1) + elif t2.ndim == 1: + t2 = t2.reshape(1, -1) + row_scale = t2.abs().mean(dim=1, keepdim=True).clamp_min(1e-6) + normalized = torch.clamp(t2 / row_scale, -2.5, 2.5) + centers = t2.new_tensor(INT5_ODD_LEVELS) + logits = -temperature * (normalized.unsqueeze(-1) - centers) ** 2 + probs = torch.softmax(logits, dim=-1) + hist = probs.mean(dim=(0, 1)) + entropy_nats = -(hist * (hist + 1e-8).log()).sum() + entropy_bits = entropy_nats / math.log(2.0) + scale_proxy = row_scale.mean() + est_bits = entropy_bits * float(t2.numel()) + return entropy_bits, scale_proxy, est_bits + + +# ----------------------------- +# DATA LOADING +# ----------------------------- + +def load_data_shard(file: Path) -> Tensor: + header_bytes = 256 * np.dtype(" None: + self.file_idx = (self.file_idx + 1) % len(self.files) + self.tokens = load_data_shard(self.files[self.file_idx]) + self.pos = 0 + + def take(self, n: int) -> Tensor: + chunks: list[Tensor] = [] + remaining = n + while remaining > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: + self._advance_file() + continue + k = min(remaining, avail) + chunks.append(self.tokens[self.pos : self.pos + k]) + self.pos += k + remaining -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) + + +class DistributedTokenLoader: + # Each call consumes a contiguous chunk from the shared token stream, then slices out + # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting. 
+ def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank = rank + self.world_size = world_size + self.device = device + self.stream = TokenStream(pattern) + + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + chunk = self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start : start + per_rank_span].to(dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + + +def debug_log_data_sharding( + pattern: str, + world_size: int, + global_tokens: int, + seq_len: int, + grad_accum_steps: int, + steps: int, + log0, +) -> None: + stream = TokenStream(pattern) + local_tokens = global_tokens // (world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + global_cursor = 0 + log0( + f"debug_sharding:world_size:{world_size} global_tokens:{global_tokens} " + f"grad_accum_steps:{grad_accum_steps} local_tokens_per_rank:{local_tokens} per_rank_span:{per_rank_span}" + ) + for step_idx in range(steps): + chunk = stream.take(per_rank_span * world_size) + chunk_start = global_cursor + chunk_end = global_cursor + chunk.numel() + log0(f"debug_sharding_step:{step_idx} shared_chunk_token_range:[{chunk_start},{chunk_end})") + for shard_rank in range(world_size): + start = shard_rank * per_rank_span + end = start + per_rank_span + local = chunk[start:end] + preview = ",".join(str(int(t)) for t in local[: min(6, local.numel())].tolist()) + log0( + f"debug_shard rank:{shard_rank} token_range:[{chunk_start + start},{chunk_start + end}) " + f"preview:{preview}" + ) + global_cursor = chunk_end + +# ----------------------------- +# TRANSFORMER MODULES +# ----------------------------- + +class RMSNorm(nn.Module): + def 
__init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + + +class CastedLinear(nn.Linear): + # Keep weights in fp32 for optimizer/state quality, cast at matmul time for bf16 compute. + def __init__(self, in_features: int, out_features: int, bias: bool = True): + super().__init__(in_features, out_features, bias=bias) + self._cached_weight: Tensor | None = None + self._cached_bias: Tensor | None = None + self._cached_key: tuple[torch.dtype, torch.device, int, int | None] | None = None + + def _cast_params(self, x: Tensor) -> tuple[Tensor, Tensor | None]: + bias_version = self.bias._version if self.bias is not None else None + cache_key = (x.dtype, x.device, self.weight._version, bias_version) + if torch.is_inference_mode_enabled() and self._cached_key == cache_key: + return self._cached_weight, self._cached_bias + weight = self.weight.to(dtype=x.dtype) + bias = self.bias.to(x.dtype) if self.bias is not None else None + if torch.is_inference_mode_enabled(): + self._cached_weight = weight + self._cached_bias = bias + self._cached_key = cache_key + return weight, bias + + def train(self, mode: bool = True): + if mode: + self._cached_weight = None + self._cached_bias = None + self._cached_key = None + return super().train(mode) + + def forward(self, x: Tensor) -> Tensor: + weight, bias = self._cast_params(x) + return F.linear(x, weight, bias) + + +class StructuredLinear(nn.Module): + # Lightweight TT-style structured matrix used only in the MLP stack. 
+ def __init__( + self, + in_features: int, + out_features: int, + bias: bool, + linear_type: str, + btt_rank: int, + in_shape_spec: str, + out_shape_spec: str, + btt_init: str, + btt_init_gain: float, + chunk_ranks: int, + eval_chunk_ranks: int, + ): + super().__init__() + self.in_features = in_features + self.out_features = out_features + self.linear_type = linear_type + self.rank = max(int(btt_rank), 1) + self.chunk_ranks = max(int(chunk_ranks), 1) + self.eval_chunk_ranks = max(int(eval_chunk_ranks), 1) + self.btt_init = btt_init + self.btt_init_gain = btt_init_gain + self._cached_core1: Tensor | None = None + self._cached_core2: Tensor | None = None + self._cached_bias: Tensor | None = None + self._cached_key: tuple[torch.dtype, torch.device, int, int, int | None] | None = None + self._dense_eval_weight: Tensor | None = None + self._dense_eval_bias: Tensor | None = None + self._dense_eval_key: tuple[torch.dtype, torch.device, int, int, int | None] | None = None + if linear_type == "dense": + self.impl = CastedLinear(in_features, out_features, bias=bias) + self.register_parameter("core1", None) + self.register_parameter("core2", None) + return + if linear_type != "btt_mlp": + raise ValueError(f"Unsupported STRUCTURED_LINEAR_TYPE={linear_type}") + self.impl = None + self.in_shape = parse_factor_shape(in_shape_spec, in_features) + self.out_shape = parse_factor_shape(out_shape_spec, out_features) + n1, n2 = self.in_shape + m1, m2 = self.out_shape + self.core1 = nn.Parameter(torch.empty(n1, m1, self.rank)) + self.core2 = nn.Parameter(torch.empty(self.rank, n2, m2)) + if bias: + self.bias = nn.Parameter(torch.zeros(out_features)) + else: + self.register_parameter("bias", None) + self.reset_parameters() + + def reset_parameters(self) -> None: + if self.impl is not None: + self.impl.reset_parameters() + return + n1, n2 = self.in_shape + if self.btt_init == "mup": + # muP-inspired scaling: keep the per-rank branch update stable while + # preserving O(1) output variance 
under the sum over TT ranks.
+            core1_std = self.btt_init_gain / math.sqrt(max(n1, 1))
+            core2_std = self.btt_init_gain / math.sqrt(max(self.rank * n2, 1))
+        elif self.btt_init == "xavier":
+            std = self.btt_init_gain / math.sqrt(max(self.in_features, 1))
+            core1_std = std / math.sqrt(self.rank)
+            core2_std = std
+        else:
+            raise ValueError(f"Unsupported BTT_INIT={self.btt_init}")
+        nn.init.normal_(self.core1, mean=0.0, std=core1_std)
+        nn.init.normal_(self.core2, mean=0.0, std=core2_std)
+        if self.bias is not None:
+            nn.init.zeros_(self.bias)
+
+    def _cast_factors(self, x: Tensor) -> tuple[Tensor, Tensor, Tensor | None]:
+        bias_version = self.bias._version if self.bias is not None else None
+        cache_key = (x.dtype, x.device, self.core1._version, self.core2._version, bias_version)
+        if torch.is_inference_mode_enabled() and self._cached_key == cache_key:
+            return self._cached_core1, self._cached_core2, self._cached_bias
+        core1 = self.core1.to(dtype=x.dtype).permute(2, 0, 1).contiguous()
+        core2 = self.core2.to(dtype=x.dtype).contiguous()
+        bias = self.bias.to(dtype=x.dtype) if self.bias is not None else None
+        if torch.is_inference_mode_enabled():
+            self._cached_core1 = core1
+            self._cached_core2 = core2
+            self._cached_bias = bias
+            self._cached_key = cache_key
+        return core1, core2, bias
+
+    def materialize(self, dtype: torch.dtype, device: torch.device) -> None:
+        if self.impl is not None:
+            return
+        bias_version = self.bias._version if self.bias is not None else None
+        cache_key = (dtype, device, self.core1._version, self.core2._version, bias_version)
+        if self._dense_eval_key == cache_key and self._dense_eval_weight is not None:
+            return
+        core1 = self.core1.to(device=device, dtype=dtype)
+        core2 = self.core2.to(device=device, dtype=dtype)
+        dense = torch.einsum("umr,rvn->mnuv", core1, core2).reshape(self.out_features, self.in_features).contiguous()
+        self._dense_eval_weight = dense
+        self._dense_eval_bias = self.bias.to(device=device, dtype=dtype) if self.bias is not None else None
+        self._dense_eval_key = cache_key
+
+    def train(self, mode: bool = True):
+        if mode:
+            self._cached_core1 = None
+            self._cached_core2 = None
+            self._cached_bias = None
+            self._cached_key = None
+            self._dense_eval_weight = None
+            self._dense_eval_bias = None
+            self._dense_eval_key = None
+        return super().train(mode)
+
+    def forward(self, x: Tensor) -> Tensor:
+        if self.impl is not None:
+            return self.impl(x)
+        if not self.training:
+            self.materialize(x.dtype, x.device)
+            return F.linear(x, self._dense_eval_weight, self._dense_eval_bias)
+        orig_shape = x.shape[:-1]
+        x2 = x.reshape(-1, self.in_shape[0], self.in_shape[1])
+        x2_t = x2.transpose(1, 2).contiguous()
+        core1, core2, bias = self._cast_factors(x)
+        y = x2.new_zeros((x2.shape[0], self.out_shape[0], self.out_shape[1]))
+        for start in range(0, self.rank, self.chunk_ranks):
+            end = min(start + self.chunk_ranks, self.rank)
+            for offset, core2_r in enumerate(core2[start:end], start=start):
+                left = torch.matmul(x2_t, core1[offset])
+                y = y + torch.matmul(left.transpose(1, 2), core2_r)
+        y = y.reshape(*orig_shape, self.out_features)
+        if bias is not None:
+            y = y + bias
+        return y
+
+
+def restore_low_dim_params_to_fp32(module: nn.Module) -> None:
+    # Keep small/control parameters in fp32 even when the model body runs in bf16.
+    with torch.no_grad():
+        for name, param in module.named_parameters():
+            if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32:
+                param.data = param.data.float()
+
+
+class Rotary(nn.Module):
+    # Caches cos/sin tables per sequence length on the current device.
+    def __init__(self, dim: int, base: float = 10000.0):
+        super().__init__()
+        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
+        self.register_buffer("inv_freq", inv_freq, persistent=False)
+        self._seq_len_cached = 0
+        self._cos_cached: Tensor | None = None
+        self._sin_cached: Tensor | None = None
+
+    def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]:
+        if (
+            self._cos_cached is None
+            or self._sin_cached is None
+            or self._seq_len_cached != seq_len
+            or self._cos_cached.device != device
+        ):
+            t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype)
+            freqs = torch.outer(t, self.inv_freq.to(device))
+            self._cos_cached = freqs.cos()[None, None, :, :]
+            self._sin_cached = freqs.sin()[None, None, :, :]
+            self._seq_len_cached = seq_len
+        cos = self._cos_cached.to(dtype=dtype)
+        sin = self._sin_cached.to(dtype=dtype)
+        if not torch.is_inference_mode_enabled():
+            cos = cos.clone()
+            sin = sin.clone()
+        return cos, sin
+
+
+def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor:
+    half = x.size(-1) // 2
+    x1, x2 = x[..., :half], x[..., half:]
+    return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1)
+
+
+class CausalSelfAttention(nn.Module):
+    def __init__(
+        self,
+        dim: int,
+        num_heads: int,
+        num_kv_heads: int,
+        rope_base: float,
+        qk_gain_init: float,
+    ):
+        super().__init__()
+        if dim % num_heads != 0:
+            raise ValueError("model_dim must be divisible by num_heads")
+        if num_heads % num_kv_heads != 0:
+            raise ValueError("num_heads must be divisible by num_kv_heads")
+        self.num_heads = num_heads
+        self.num_kv_heads = num_kv_heads
+        self.head_dim = dim // num_heads
+        if self.head_dim % 2 != 0:
+            raise ValueError("head_dim must be even for RoPE")
+        kv_dim = self.num_kv_heads * self.head_dim
+        self.c_q = CastedLinear(dim, dim, bias=False)
+        self.c_k = CastedLinear(dim, kv_dim, bias=False)
+        self.c_v = CastedLinear(dim, kv_dim, bias=False)
+        self.proj = CastedLinear(dim, dim, bias=False)
+        self.proj._zero_init = True
+        self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32))
+        self.rotary = Rotary(self.head_dim, base=rope_base)
+
+    def forward(self, x: Tensor) -> Tensor:
+        bsz, seqlen, dim = x.shape
+        q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2)
+        k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2)
+        v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2)
+        q = F.rms_norm(q, (q.size(-1),))
+        k = F.rms_norm(k, (k.size(-1),))
+        cos, sin = self.rotary(seqlen, x.device, q.dtype)
+        q = apply_rotary_emb(q, cos, sin)
+        k = apply_rotary_emb(k, cos, sin)
+        q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None]
+        y = F.scaled_dot_product_attention(
+            q,
+            k,
+            v,
+            attn_mask=None,
+            is_causal=True,
+            enable_gqa=(self.num_kv_heads != self.num_heads),
+        )
+        y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim)
+        return self.proj(y)
+
+
+class MLP(nn.Module):
+    # relu^2 MLP from the original modded-nanogpt setup
+    def __init__(
+        self,
+        dim: int,
+        mlp_mult: int,
+        structured_mlp: bool,
+        structured_linear_type: str,
+        btt_rank: int,
+        btt_in_shape: str,
+        btt_out_shape: str,
+        btt_init: str,
+        btt_init_gain: float,
+        compile_structured_mlp: bool,
+        structured_chunk_ranks: int,
+        structured_eval_chunk_ranks: int,
+        structured_compile_mode: str,
+    ):
+        super().__init__()
+        hidden = mlp_mult * dim
+        if structured_mlp:
+            self.fc = StructuredLinear(
+                dim,
+                hidden,
+                bias=False,
+                linear_type=structured_linear_type,
+                btt_rank=btt_rank,
+                in_shape_spec=btt_in_shape,
+                out_shape_spec=btt_out_shape,
+                btt_init=btt_init,
+                btt_init_gain=btt_init_gain,
+                chunk_ranks=structured_chunk_ranks,
+                eval_chunk_ranks=structured_eval_chunk_ranks,
+            )
+            self.proj = StructuredLinear(
+                hidden,
+                dim,
+                bias=False,
+                linear_type=structured_linear_type,
+                btt_rank=btt_rank,
+                in_shape_spec=btt_out_shape,
+                out_shape_spec=btt_in_shape,
+                btt_init=btt_init,
+                btt_init_gain=btt_init_gain,
+                chunk_ranks=structured_chunk_ranks,
+                eval_chunk_ranks=structured_eval_chunk_ranks,
+            )
+            if compile_structured_mlp:
+                compile_kwargs = dict(fullgraph=False, dynamic=False)
+                if structured_compile_mode == "reduce-overhead":
+                    compile_options = dict(torch._inductor.list_mode_options("reduce-overhead"))
+                    compile_options["triton.cudagraphs"] = False
+                    compile_kwargs["options"] = compile_options
+                else:
+                    compile_kwargs["mode"] = structured_compile_mode
+                self.fc = torch.compile(
+                    self.fc,
+                    **compile_kwargs,
+                )
+                self.proj = torch.compile(
+                    self.proj,
+                    **compile_kwargs,
+                )
+        else:
+            self.fc = CastedLinear(dim, hidden, bias=False)
+            self.proj = CastedLinear(hidden, dim, bias=False)
+            self.proj._zero_init = True
+
+    def forward(self, x: Tensor) -> Tensor:
+        x = torch.relu(self.fc(x))
+        return self.proj(x.square())
+
+
+class Block(nn.Module):
+    def __init__(
+        self,
+        dim: int,
+        num_heads: int,
+        num_kv_heads: int,
+        mlp_mult: int,
+        rope_base: float,
+        qk_gain_init: float,
+        structured_mlp: bool,
+        structured_linear_type: str,
+        btt_rank: int,
+        btt_in_shape: str,
+        btt_out_shape: str,
+        btt_init: str,
+        btt_init_gain: float,
+        compile_structured_mlp: bool,
+        structured_chunk_ranks: int,
+        structured_eval_chunk_ranks: int,
+        structured_compile_mode: str,
+    ):
+        super().__init__()
+        self.attn_norm = RMSNorm()
+        self.mlp_norm = RMSNorm()
+        self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init)
+        self.mlp = MLP(
+            dim,
+            mlp_mult,
+            structured_mlp,
+            structured_linear_type,
+            btt_rank,
+            btt_in_shape,
+            btt_out_shape,
+            btt_init,
+            btt_init_gain,
+            compile_structured_mlp,
+            structured_chunk_ranks,
+            structured_eval_chunk_ranks,
+            structured_compile_mode,
+        )
+        self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32))
+        self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32))
+        self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float())
+
+    def forward(self, x: Tensor, x0: Tensor) -> Tensor:
+        mix = self.resid_mix.to(dtype=x.dtype)
+        x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0
+        attn_out = self.attn(self.attn_norm(x))
+        x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out
+        x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x))
+        return x
+
+
+class GPT(nn.Module):
+    def __init__(
+        self,
+        vocab_size: int,
+        num_layers: int,
+        model_dim: int,
+        num_heads: int,
+        num_kv_heads: int,
+        mlp_mult: int,
+        tie_embeddings: bool,
+        tied_embed_init_std: float,
+        logit_softcap: float,
+        rope_base: float,
+        qk_gain_init: float,
+        structured_mlp: bool,
+        structured_linear_type: str,
+        btt_rank: int,
+        btt_in_shape: str,
+        btt_out_shape: str,
+        btt_init: str,
+        btt_init_gain: float,
+        compile_structured_mlp: bool,
+        structured_chunk_ranks: int,
+        structured_eval_chunk_ranks: int,
+        structured_compile_mode: str,
+        rate_target_tensors: str,
+        rate_temperature: float,
+    ):
+        super().__init__()
+        if logit_softcap <= 0.0:
+            raise ValueError(f"logit_softcap must be positive, got {logit_softcap}")
+        self.tie_embeddings = tie_embeddings
+        self.tied_embed_init_std = tied_embed_init_std
+        self.logit_softcap = logit_softcap
+        self.rate_target_tensors = rate_target_tensors
+        self.rate_temperature = rate_temperature
+        self.tok_emb = nn.Embedding(vocab_size, model_dim)
+        self.num_encoder_layers = num_layers // 2
+        self.num_decoder_layers = num_layers - self.num_encoder_layers
+        self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers)
+        self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32))
+        self.blocks = nn.ModuleList(
+            [
+                Block(
+                    model_dim,
+                    num_heads,
+                    num_kv_heads,
+                    mlp_mult,
+                    rope_base,
+                    qk_gain_init,
+                    structured_mlp,
+                    structured_linear_type,
+                    btt_rank,
+                    btt_in_shape,
+                    btt_out_shape,
+                    btt_init,
+                    btt_init_gain,
+                    compile_structured_mlp,
+                    structured_chunk_ranks,
+                    structured_eval_chunk_ranks,
+                    structured_compile_mode,
+                )
+                for i in range(num_layers)
+            ]
+        )
+        self.final_norm = RMSNorm()
+        self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False)
+        if self.lm_head is not None:
+            self.lm_head._zero_init = True
+        self._init_weights()
+
+    def _init_weights(self) -> None:
+        if self.tie_embeddings:
+            nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std)
+        for module in self.modules():
+            if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False):
+                nn.init.zeros_(module.weight)
+            if isinstance(module, StructuredLinear) and getattr(module, "_zero_init", False):
+                if module.impl is not None:
+                    nn.init.zeros_(module.impl.weight)
+                else:
+                    nn.init.zeros_(module.core2)
+
+    def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor:
+        x = self.tok_emb(input_ids)
+        x = F.rms_norm(x, (x.size(-1),))
+        x0 = x
+        skips: list[Tensor] = []
+
+        # First half stores skips; second half reuses them in reverse order.
+        for i in range(self.num_encoder_layers):
+            x = self.blocks[i](x, x0)
+            skips.append(x)
+        for i in range(self.num_decoder_layers):
+            if skips:
+                x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop()
+            x = self.blocks[self.num_encoder_layers + i](x, x0)
+
+        x = self.final_norm(x).reshape(-1, x.size(-1))
+        targets = target_ids.reshape(-1)
+        if self.tie_embeddings:
+            logits_proj = F.linear(x, self.tok_emb.weight)
+        else:
+            if self.lm_head is None:
+                raise RuntimeError("lm_head is required when tie_embeddings=False")
+            logits_proj = self.lm_head(x)
+        logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap)
+        return F.cross_entropy(logits.float(), targets, reduction="mean")
+
+    def rate_regularizer(self) -> tuple[Tensor, Tensor, Tensor]:
+        total_symbols = 0.0
+        weighted_rate = None
+        weighted_scale = None
+        estimated_bits = None
+        for name, param in self.named_parameters():
+            if not should_rate_regularize_tensor(name, param, self.rate_target_tensors):
+                continue
+            rate_proxy, scale_proxy, tensor_bits = estimate_tensor_rate_proxy(param, self.rate_temperature)
+            weight = float(param.numel())
+            total_symbols += weight
+            if weighted_rate is None:
+                weighted_rate = rate_proxy * weight
+                weighted_scale = scale_proxy * weight
+                estimated_bits = tensor_bits
+            else:
+                weighted_rate = weighted_rate + rate_proxy * weight
+                weighted_scale = weighted_scale + scale_proxy * weight
+                estimated_bits = estimated_bits + tensor_bits
+        if weighted_rate is None or weighted_scale is None or estimated_bits is None:
+            zero = self.tok_emb.weight.new_zeros(())
+            return zero, zero, zero
+        denom = max(total_symbols, 1.0)
+        return weighted_rate / denom, weighted_scale / denom, estimated_bits
+
+
+# -----------------------------
+# TRAINING
+# -----------------------------
+
+def main() -> None:
+    global zeropower_via_newtonschulz5
+
+    code = Path(__file__).read_text(encoding="utf-8")
+    args = Hyperparameters()
+    if args.enable_compile:
+        zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5)
+
+    # -----------------------------
+    # DISTRIBUTED + CUDA SETUP
+    # -----------------------------
+
+    distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ
+    rank = int(os.environ.get("RANK", "0"))
+    world_size = int(os.environ.get("WORLD_SIZE", "1"))
+    local_rank = int(os.environ.get("LOCAL_RANK", "0"))
+    simulated_world_size = args.simulated_world_size if not distributed and args.simulated_world_size > 0 else 0
+    if simulated_world_size > 0:
+        world_size = simulated_world_size
+        rank = args.simulated_rank
+        local_rank = 0
+    if world_size <= 0:
+        raise ValueError(f"WORLD_SIZE must be positive, got {world_size}")
+    if rank < 0 or rank >= world_size:
+        raise ValueError(f"RANK={rank} must satisfy 0 <= RANK < WORLD_SIZE={world_size}")
+    if simulated_world_size > 1 and not args.dry_run_init_only:
+        raise ValueError("SIMULATED_WORLD_SIZE>1 is only supported together with DRY_RUN_INIT_ONLY=1")
+    if args.train_batch_tokens <= 0:
+        raise ValueError(f"TRAIN_BATCH_TOKENS must be positive, got {args.train_batch_tokens}")
+    train_microbatch_tokens = args.train_microbatch_tokens if args.train_microbatch_tokens > 0 else max(
+        args.train_seq_len, args.train_batch_tokens // 8
+    )
+    if train_microbatch_tokens <= 0:
+        raise ValueError(f"TRAIN_MICROBATCH_TOKENS must be positive, got {train_microbatch_tokens}")
+    if train_microbatch_tokens % args.train_seq_len != 0:
+        raise ValueError(
+            f"TRAIN_MICROBATCH_TOKENS={train_microbatch_tokens} must be divisible by TRAIN_SEQ_LEN={args.train_seq_len}"
+        )
+    tokens_per_microstep = world_size * train_microbatch_tokens
+    if args.grad_accum_steps > 0:
+        grad_accum_steps = args.grad_accum_steps
+        effective_train_batch_tokens = tokens_per_microstep * grad_accum_steps
+    else:
+        if args.train_batch_tokens % tokens_per_microstep != 0:
+            raise ValueError(
+                "TRAIN_BATCH_TOKENS must be divisible by WORLD_SIZE * TRAIN_MICROBATCH_TOKENS "
+                f"unless GRAD_ACCUM_STEPS is set explicitly; got TRAIN_BATCH_TOKENS={args.train_batch_tokens}, "
+                f"WORLD_SIZE={world_size}, TRAIN_MICROBATCH_TOKENS={train_microbatch_tokens}"
+            )
+        grad_accum_steps = args.train_batch_tokens // tokens_per_microstep
+        effective_train_batch_tokens = args.train_batch_tokens
+    if grad_accum_steps <= 0:
+        raise ValueError(f"GRAD_ACCUM_STEPS must be positive, got {grad_accum_steps}")
+    grad_scale = 1.0 / grad_accum_steps
+    batch_multiplier = effective_train_batch_tokens / max(args.lr_reference_batch_tokens, 1)
+    effective_warmup_steps = scaled_warmup_steps(
+        args.warmup_steps,
+        effective_train_batch_tokens / max(args.warmup_reference_batch_tokens, 1),
+        args.warmup_scale_method,
+    )
+    if not torch.cuda.is_available():
+        raise RuntimeError("CUDA is required")
+    device = torch.device("cuda", local_rank)
+    torch.cuda.set_device(device)
+    if distributed:
+        dist.init_process_group(backend="nccl", device_id=device)
+        dist.barrier()
+    master_process = rank == 0
+
+    # Fast math knobs
+    torch.backends.cuda.matmul.allow_tf32 = True
+    torch.backends.cudnn.allow_tf32 = True
+    from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp
+
+    enable_cudnn_sdp(False)
+    enable_flash_sdp(True)
+    enable_mem_efficient_sdp(False)
+    enable_math_sdp(False)
+
+    logfile = None
+    if master_process:
+        if args.run_id == "train":
+            logfile = "train.log"
+        else:
+            os.makedirs("logs", exist_ok=True)
+            logfile = f"logs/{args.run_id}.txt"
+        print(logfile)
+
+    def log0(msg: str, console: bool = True) -> None:
+        if not master_process:
+            return
+        if console:
+            print(msg)
+        if logfile is not None:
+            with open(logfile, "a", encoding="utf-8") as f:
+                print(msg, file=f)
+
+    log0(code, console=False)
+    log0("=" * 100, console=False)
+    log0(f"Running Python {sys.version}", console=False)
+    log0(f"Running PyTorch {torch.__version__}", console=False)
+    log0(
+        subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout,
+        console=False,
+    )
+    log0("=" * 100, console=False)
+
+    # -----------------------------
+    # TOKENIZER + VALIDATION METRIC SETUP
+    # -----------------------------
+
+    random.seed(args.seed)
+    np.random.seed(args.seed)
+    torch.manual_seed(args.seed)
+    torch.cuda.manual_seed_all(args.seed)
+
+    if not args.tokenizer_path.endswith(".model"):
+        raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}")
+    sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path)
+    if int(sp.vocab_size()) != args.vocab_size:
+        raise ValueError(
+            f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}"
+        )
+    dataset_dir = Path(args.data_path).resolve()
+    actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin")))
+    val_tokens = load_validation_tokens(args.val_files, args.train_seq_len)
+    if args.val_token_limit > 0:
+        usable = min((args.val_token_limit // args.train_seq_len) * args.train_seq_len, val_tokens.numel() - 1)
+        if usable <= 0:
+            raise ValueError(f"VAL_TOKEN_LIMIT={args.val_token_limit} is too small for TRAIN_SEQ_LEN={args.train_seq_len}")
+        val_tokens = val_tokens[: usable + 1].contiguous()
+    base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts(
+        sp, args.vocab_size, device
+    )
+    log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}")
+    log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}")
+    log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}")
+    if args.val_token_limit > 0:
+        log0(f"val_token_limit:{args.val_token_limit}")
+
+    # -----------------------------
+    # MODEL + OPTIMIZER SETUP
+    # -----------------------------
+
+    base_model = GPT(
+        vocab_size=args.vocab_size,
+        num_layers=args.num_layers,
+        model_dim=args.model_dim,
+        num_heads=args.num_heads,
+        num_kv_heads=args.num_kv_heads,
+        mlp_mult=args.mlp_mult,
+        tie_embeddings=args.tie_embeddings,
+        tied_embed_init_std=args.tied_embed_init_std,
+        logit_softcap=args.logit_softcap,
+        rope_base=args.rope_base,
+        qk_gain_init=args.qk_gain_init,
+        structured_mlp=args.structured_mlp,
+        structured_linear_type=args.structured_linear_type if args.structured_mlp else "dense",
+        btt_rank=args.btt_rank,
+        btt_in_shape=args.btt_in_shape,
+        btt_out_shape=args.btt_out_shape,
+        btt_init=args.btt_init,
+        btt_init_gain=args.btt_init_gain,
+        compile_structured_mlp=args.compile_structured_mlp,
+        structured_chunk_ranks=args.structured_chunk_ranks,
+        structured_eval_chunk_ranks=args.structured_eval_chunk_ranks,
+        structured_compile_mode=args.structured_compile_mode,
+        rate_target_tensors=args.rate_target_tensors,
+        rate_temperature=args.rate_temperature,
+    ).to(device).bfloat16()
+    for module in base_model.modules():
+        if isinstance(module, (CastedLinear, StructuredLinear)):
+            module.float()
+    restore_low_dim_params_to_fp32(base_model)
+    compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) if args.enable_compile else base_model
+    model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model
+
+    # Optimizer split:
+    # - token embedding (Adam) uses EMBED_LR
+    # - untied lm_head (Adam) uses HEAD_LR
+    # - matrix params in transformer blocks use MATRIX_LR via Muon
+    # - vectors/scalars use SCALAR_LR via Adam
+    block_named_params = list(base_model.blocks.named_parameters())
+    matrix_params = [
+        p
+        for name, p in block_named_params
+        if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)
+    ]
+    scalar_params = [
+        p
+        for name, p in block_named_params
+        if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)
+    ]
+    if base_model.skip_weights.numel() > 0:
+        scalar_params.append(base_model.skip_weights)
+    token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr
+    scaled_token_lr = scale_from_batch_multiplier(token_lr, batch_multiplier, args.lr_scale_method)
+    scaled_head_lr = scale_from_batch_multiplier(args.head_lr, batch_multiplier, args.lr_scale_method)
+    scaled_matrix_lr = scale_from_batch_multiplier(args.matrix_lr, batch_multiplier, args.lr_scale_method)
+    scaled_scalar_lr = scale_from_batch_multiplier(args.scalar_lr, batch_multiplier, args.lr_scale_method)
+    optimizer_tok = torch.optim.Adam(
+        [{"params": [base_model.tok_emb.weight], "lr": scaled_token_lr, "base_lr": scaled_token_lr}],
+        betas=(args.beta1, args.beta2),
+        eps=args.adam_eps,
+        fused=True,
+    )
+    optimizer_muon = Muon(
+        matrix_params,
+        lr=scaled_matrix_lr,
+        momentum=args.muon_momentum,
+        backend_steps=args.muon_backend_steps,
+    )
+    for group in optimizer_muon.param_groups:
+        group["base_lr"] = scaled_matrix_lr
+    optimizer_scalar = torch.optim.Adam(
+        [{"params": scalar_params, "lr": scaled_scalar_lr, "base_lr": scaled_scalar_lr}],
+        betas=(args.beta1, args.beta2),
+        eps=args.adam_eps,
+        fused=True,
+    )
+    optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar]
+    if base_model.lm_head is not None:
+        optimizer_head = torch.optim.Adam(
+            [{"params": [base_model.lm_head.weight], "lr": scaled_head_lr, "base_lr": scaled_head_lr}],
+            betas=(args.beta1, args.beta2),
+            eps=args.adam_eps,
+            fused=True,
+        )
+        optimizers.insert(1, optimizer_head)
+
+    n_params = sum(p.numel() for p in base_model.parameters())
+    log0(f"model_params:{n_params}")
+    log0(
+        f"world_size:{world_size} grad_accum_steps:{grad_accum_steps} "
+        f"train_microbatch_tokens:{train_microbatch_tokens} effective_train_batch_tokens:{effective_train_batch_tokens}"
+    )
+    if simulated_world_size > 0:
+        log0(f"simulated_world_size:{simulated_world_size} simulated_rank:{rank}")
+    log0(f"enable_compile:{args.enable_compile}")
+    log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False")
+    log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}")
+    log0(
+        f"quant_scheme:{args.quant_scheme} compressor:{args.compressor} level:{args.compressor_level} "
+        f"structured_mlp:{args.structured_mlp} structured_linear_type:{args.structured_linear_type} "
+        f"btt_rank:{args.btt_rank} btt_init:{args.btt_init} compile_structured_mlp:{args.compile_structured_mlp} "
+        f"structured_chunk_ranks:{args.structured_chunk_ranks} structured_eval_chunk_ranks:{args.structured_eval_chunk_ranks} "
+        f"structured_compile_mode:{args.structured_compile_mode}"
+    )
+    log0(
+        f"rate_lambda:{args.rate_lambda} scale_lambda:{args.scale_lambda} "
+        f"rate_warmup_start_fraction:{args.rate_warmup_start_fraction} rate_full_fraction:{args.rate_full_fraction} "
+        f"rate_target_tensors:{args.rate_target_tensors} rate_schedule_mode:{args.rate_schedule_mode} "
+        f"max_expected_steps:{args.max_expected_steps}"
+    )
+    log0(
+        f"tie_embeddings:{args.tie_embeddings} embed_lr:{scaled_token_lr} "
+        f"head_lr:{scaled_head_lr if base_model.lm_head is not None else 0.0} "
+        f"matrix_lr:{scaled_matrix_lr} scalar_lr:{scaled_scalar_lr}"
+    )
+    log0(
+        f"train_batch_tokens_target:{args.train_batch_tokens} effective_train_batch_tokens:{effective_train_batch_tokens} "
+        f"train_seq_len:{args.train_seq_len} iterations:{args.iterations} warmup_steps:{effective_warmup_steps} "
+        f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}"
+    )
+    log0(
+        f"lr_scale_method:{args.lr_scale_method} lr_reference_batch_tokens:{args.lr_reference_batch_tokens} "
+        f"batch_multiplier:{batch_multiplier:.4f} warmup_scale_method:{args.warmup_scale_method} "
+        f"warmup_reference_batch_tokens:{args.warmup_reference_batch_tokens}"
+    )
+    log0(f"seed:{args.seed}")
+
+    if args.debug_data_sharding_steps > 0:
+        debug_log_data_sharding(
+            args.train_files,
+            world_size,
+            effective_train_batch_tokens,
+            args.train_seq_len,
+            grad_accum_steps,
+            args.debug_data_sharding_steps,
+            log0,
+        )
+    if args.dry_run_init_only:
+        log0("dry_run_init_only: exiting after init/logging")
+        return
+
+    # -----------------------------
+    # DATA LOADER & MODEL WARMUP
+    # -----------------------------
+
+    train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device)
+
+    def zero_grad_all() -> None:
+        for opt in optimizers:
+            opt.zero_grad(set_to_none=True)
+
+    max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None
+
+    def lr_mul(step: int, elapsed_ms: float) -> float:
+        if args.warmdown_iters <= 0:
+            return 1.0
+        if max_wallclock_ms is None:
+            warmdown_start = max(args.iterations - args.warmdown_iters, 0)
+            return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0
+        step_ms = elapsed_ms / max(step, 1)
+        warmdown_ms = args.warmdown_iters * step_ms
+        remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0)
+        return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0
+
+    def rate_progress(step: int, elapsed_ms: float) -> float:
+        if args.rate_schedule_mode == "steps":
+            if args.iterations <= 0:
+                return 1.0
+            return step / max(args.iterations, 1)
+        if args.rate_schedule_mode == "expected_steps":
+            if args.max_expected_steps <= 0:
+                raise ValueError("MAX_EXPECTED_STEPS must be positive when RATE_SCHEDULE_MODE=expected_steps")
+            return step / max(args.max_expected_steps, 1)
+        if args.rate_schedule_mode == "wallclock":
+            if max_wallclock_ms is None:
+                raise ValueError("RATE_SCHEDULE_MODE=wallclock requires MAX_WALLCLOCK_SECONDS > 0")
+            return elapsed_ms / max(max_wallclock_ms, 1e-9)
+        if args.rate_schedule_mode != "auto":
+            raise ValueError(f"Unsupported RATE_SCHEDULE_MODE={args.rate_schedule_mode}")
+        if args.max_expected_steps > 0:
+            return step / max(args.max_expected_steps, 1)
+        if max_wallclock_ms is not None and step > 0:
+            step_ms = elapsed_ms / max(step, 1)
+            expected_steps = max(int(max_wallclock_ms / max(step_ms, 1e-9)), 1)
+            return step / expected_steps
+        if args.iterations <= 0:
+            return 1.0
+        return step / max(args.iterations, 1)
+
+    def rate_schedule(step: int, elapsed_ms: float) -> float:
+        if args.rate_lambda <= 0.0 and args.scale_lambda <= 0.0:
+            return 0.0
+        frac = rate_progress(step, elapsed_ms)
+        start = args.rate_warmup_start_fraction
+        full = max(args.rate_full_fraction, start + 1e-6)
+        if frac <= start:
+            return 0.0
+        if frac >= full:
+            return 1.0
+        return (frac - start) / (full - start)
+
+    # Warmup primes the compiled forward/backward/optimizer paths, then we restore the
+    # initial weights/optimizer state so measured training starts from the true init.
+    if effective_warmup_steps > 0:
+        initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()}
+        initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers]
+        model.train()
+        for warmup_step in range(effective_warmup_steps):
+            zero_grad_all()
+            for micro_step in range(grad_accum_steps):
+                if distributed:
+                    model.require_backward_grad_sync = micro_step == grad_accum_steps - 1
+                x, y = train_loader.next_batch(effective_train_batch_tokens, args.train_seq_len, grad_accum_steps)
+                maybe_mark_compile_step_begin(args.enable_compile or args.compile_structured_mlp)
+                with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True):
+                    warmup_loss = model(x, y)
+                (warmup_loss * grad_scale).backward()
+            for opt in optimizers:
+                opt.step()
+            zero_grad_all()
+            if (
+                effective_warmup_steps <= 20
+                or (warmup_step + 1) % 10 == 0
+                or warmup_step + 1 == effective_warmup_steps
+            ):
+                log0(f"warmup_step:{warmup_step + 1}/{effective_warmup_steps}")
+        base_model.load_state_dict(initial_model_state, strict=True)
+        for opt, state in zip(optimizers, initial_optimizer_states, strict=True):
+            opt.load_state_dict(state)
+        zero_grad_all()
+        if distributed:
+            model.require_backward_grad_sync = True
+        train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device)
+
+    # -----------------------------
+    # MAIN TRAINING LOOP
+    # -----------------------------
+
+    training_time_ms = 0.0
+    stop_after_step: int | None = None
+    torch.cuda.synchronize()
+    t0 = time.perf_counter()
+
+    step = 0
+    while True:
+        last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step)
+
+        should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0)
+        if should_validate:
+            torch.cuda.synchronize()
+            training_time_ms += 1000.0 * (time.perf_counter() - t0)
+            val_loss, val_bpb = eval_val(
+                args,
+                model,
+                rank,
+                world_size,
+                device,
+                grad_accum_steps,
+                val_tokens,
+                base_bytes_lut,
+                has_leading_space_lut,
+                is_boundary_token_lut,
+            )
+            log0(
+                f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} "
+                f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms"
+            )
+            torch.cuda.synchronize()
+            t0 = time.perf_counter()
+
+        if last_step:
+            if stop_after_step is not None and step < args.iterations:
+                log0(
+                    f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms "
+                    f"step:{step}/{args.iterations}"
+                )
+            break
+
+        model.train()
+        elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0)
+        scale = lr_mul(step, elapsed_ms)
+        zero_grad_all()
+        train_loss = torch.zeros((), device=device)
+        train_ce_loss = torch.zeros((), device=device)
+        train_rate_proxy = torch.zeros((), device=device)
+        train_scale_proxy = torch.zeros((), device=device)
+        train_est_bits = torch.zeros((), device=device)
+        rate_mul = rate_schedule(step, elapsed_ms)
+        for micro_step in range(grad_accum_steps):
+            if distributed:
+                model.require_backward_grad_sync = micro_step == grad_accum_steps - 1
+            x, y = train_loader.next_batch(effective_train_batch_tokens, args.train_seq_len, grad_accum_steps)
+            maybe_mark_compile_step_begin(args.enable_compile or args.compile_structured_mlp)
+            with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True):
+                ce_loss = model(x, y)
+            total_loss = ce_loss
+            if rate_mul > 0.0:
+                rate_proxy, scale_proxy, est_bits = base_model.rate_regularizer()
+                total_loss = total_loss + rate_mul * (args.rate_lambda * rate_proxy + args.scale_lambda * scale_proxy)
+                train_rate_proxy += rate_proxy.detach()
+                train_scale_proxy += scale_proxy.detach()
+                train_est_bits += est_bits.detach()
+            train_ce_loss += ce_loss.detach()
+            train_loss += total_loss.detach()
+            (total_loss * grad_scale).backward()
+        train_loss /= grad_accum_steps
+        train_ce_loss /= grad_accum_steps
+        train_rate_proxy /= grad_accum_steps
+        train_scale_proxy /= grad_accum_steps
+        train_est_bits /= grad_accum_steps
+
+        frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0
+        muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum
+        for group in optimizer_muon.param_groups:
+            group["momentum"] = muon_momentum
+
+        for opt in optimizers:
+            for group in opt.param_groups:
+                group["lr"] = group["base_lr"] * scale
+
+        if args.grad_clip_norm > 0:
+            torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm)
+        for opt in optimizers:
+            opt.step()
+        zero_grad_all()
+
+        step += 1
+        approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0)
+        should_log_train = (
+            args.train_log_every > 0
+            and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None)
+        )
+        if should_log_train:
+            log0(
+                f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} ce_loss:{train_ce_loss.item():.4f} "
+                f"rate_mul:{rate_mul:.3f} rate_proxy_bits:{train_rate_proxy.item():.4f} "
+                f"scale_proxy:{train_scale_proxy.item():.6f} est_model_bits:{train_est_bits.item():.0f} "
+                f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms"
+            )
+
+        # Needed to sync whether we've reached the wallclock cap.
+ reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + + log0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + ) + + # ----------------------------- + # SERIALIZATION + ROUNDTRIP VALIDATION + # ----------------------------- + # Save the raw state (useful for debugging/loading in PyTorch directly), then always produce + # the compressed int8+zlib artifact and validate the round-tripped weights. + + if master_process: + torch.save(base_model.state_dict(), "final_model.pt") + model_bytes = os.path.getsize("final_model.pt") + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model: {model_bytes} bytes") + log0(f"Code size: {code_bytes} bytes") + log0(f"Total submission size: {model_bytes + code_bytes} bytes") + + quant_obj, quant_stats = quantize_state_dict(base_model.state_dict(), args.quant_scheme) + quant_buf = io.BytesIO() + torch.save(quant_obj, quant_buf) + quant_raw = quant_buf.getvalue() + quant_blob = compress_blob(quant_raw, args.compressor, args.compressor_level) + quant_raw_bytes = len(quant_raw) + quant_filename = f"final_model.{args.quant_scheme}.{args.compressor}.ptc" + if master_process: + with open(quant_filename, "wb") as f: + f.write(quant_blob) + quant_file_bytes = os.path.getsize(quant_filename) + code_bytes = len(code.encode("utf-8")) + ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1) + log0( + f"Serialized model {args.quant_scheme}+{args.compressor}: {quant_file_bytes} bytes " + f"(payload:{quant_stats['int8_payload_bytes']} raw_torch:{quant_raw_bytes} 
payload_ratio:{ratio:.2f}x)" + ) + log0(f"Total submission size {args.quant_scheme}+{args.compressor}: {quant_file_bytes + code_bytes} bytes") + + if distributed: + dist.barrier() + with open(quant_filename, "rb") as f: + quant_blob_disk = f.read() + quant_state = torch.load(io.BytesIO(decompress_blob(quant_blob_disk, args.compressor)), map_location="cpu") + base_model.load_state_dict(dequantize_state_dict(quant_state), strict=True) + torch.cuda.synchronize() + t_qeval = time.perf_counter() + q_val_loss, q_val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"final_{args.quant_scheme}_{args.compressor}_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" + ) + log0(f"final_{args.quant_scheme}_{args.compressor}_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + + if distributed: + dist.destroy_process_group() + + +if __name__ == "__main__": + main() + +==================================================================================================== +Running Python 3.12.3 (main, Nov 6 2025, 13:44:16) [GCC 13.3.0] +Running PyTorch 2.9.1+cu128 +Tue Mar 31 19:03:36 2026 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 580.126.09 Driver Version: 580.126.09 CUDA Version: 13.0 | ++-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. 
| +|=========================================+========================+======================| +| 0 NVIDIA H100 80GB HBM3 On | 00000000:CB:00.0 Off | 0 | +| N/A 33C P0 76W / 700W | 527MiB / 81559MiB | 2% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| No running processes found | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=/workspace/parameter-golf/data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:80 +val_loader:shards pattern=/workspace/parameter-golf/data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:1048576 +val_token_limit:1048576 +model_params:25727104 +world_size:1 grad_accum_steps:8 train_microbatch_tokens:2048 effective_train_batch_tokens:16384 +enable_compile:False +sdp_backends:cudnn=False flash=True mem_efficient=False math=False +attention_mode:gqa num_heads:8 num_kv_heads:4 +quant_scheme:int5_odd compressor:zlib level:9 structured_mlp:True structured_linear_type:btt_mlp btt_rank:256 btt_init:mup compile_structured_mlp:False structured_chunk_ranks:64 structured_eval_chunk_ranks:1 structured_compile_mode:reduce-overhead +rate_lambda:2e-05 scale_lambda:0.0002 rate_warmup_start_fraction:0.1 rate_full_fraction:0.4 rate_target_tensors:mlp_only rate_schedule_mode:auto max_expected_steps:0 +tie_embeddings:True embed_lr:0.03 head_lr:0.0 matrix_lr:0.02 scalar_lr:0.02 +train_batch_tokens_target:16384 effective_train_batch_tokens:16384 train_seq_len:1024 
iterations:20000 warmup_steps:1 max_wallclock_seconds:600.000 +lr_scale_method:none lr_reference_batch_tokens:16384 batch_multiplier:1.0000 warmup_scale_method:none warmup_reference_batch_tokens:16384 +seed:1337 +warmup_step:1/1 +step:0/20000 val_loss:6.9379 val_bpb:4.1568 train_time:0ms step_avg:0.03ms +step:1/20000 train_loss:6.9404 ce_loss:6.9404 rate_mul:0.000 rate_proxy_bits:0.0000 scale_proxy:0.000000 est_model_bits:0 train_time:18556ms step_avg:18556.11ms +step:2/20000 train_loss:12.7072 ce_loss:12.7072 rate_mul:0.000 rate_proxy_bits:0.0000 scale_proxy:0.000000 est_model_bits:0 train_time:37651ms step_avg:18825.71ms +step:3/20000 train_loss:12.5900 ce_loss:12.5900 rate_mul:0.000 rate_proxy_bits:0.0000 scale_proxy:0.000000 est_model_bits:0 train_time:56956ms step_avg:18985.42ms +step:4/20000 train_loss:12.3686 ce_loss:12.3686 rate_mul:0.000 rate_proxy_bits:0.0000 scale_proxy:0.000000 est_model_bits:0 train_time:76316ms step_avg:19079.05ms +step:5/20000 train_loss:12.1946 ce_loss:12.1946 rate_mul:0.097 rate_proxy_bits:1.4982 scale_proxy:0.059731 est_model_bits:18851774 train_time:96008ms step_avg:19201.52ms +step:6/20000 train_loss:11.8978 ce_loss:11.8978 rate_mul:0.204 rate_proxy_bits:1.4982 scale_proxy:0.059731 est_model_bits:18851774 train_time:115743ms step_avg:19290.47ms +step:7/20000 train_loss:11.7299 ce_loss:11.7299 rate_mul:0.312 rate_proxy_bits:1.4982 scale_proxy:0.059731 est_model_bits:18851774 train_time:135345ms step_avg:19334.94ms +step:8/20000 train_loss:11.5707 ce_loss:11.5707 rate_mul:0.419 rate_proxy_bits:1.4982 scale_proxy:0.059731 est_model_bits:18851774 train_time:154822ms step_avg:19352.71ms +step:9/20000 train_loss:11.1523 ce_loss:11.1523 rate_mul:0.527 rate_proxy_bits:1.4982 scale_proxy:0.059731 est_model_bits:18851774 train_time:174289ms step_avg:19365.47ms +step:10/20000 train_loss:11.0726 ce_loss:11.0726 rate_mul:0.667 rate_proxy_bits:1.4982 scale_proxy:0.059731 est_model_bits:18851774 train_time:194115ms step_avg:19411.48ms 
+step:20/20000 train_loss:9.4721 ce_loss:9.4721 rate_mul:1.000 rate_proxy_bits:1.4982 scale_proxy:0.059731 est_model_bits:18851774 train_time:390448ms step_avg:19522.42ms
+step:30/20000 train_loss:8.9153 ce_loss:8.9152 rate_mul:1.000 rate_proxy_bits:1.4982 scale_proxy:0.059731 est_model_bits:18851774 train_time:584228ms step_avg:19474.25ms
+step:31/20000 val_loss:8.9222 val_bpb:5.3457 train_time:603522ms step_avg:19468.45ms
+stopping_early: wallclock_cap train_time:603522ms step:31/20000
+peak memory allocated: 50290 MiB reserved: 78558 MiB
+Serialized model: 101929819 bytes
+Code size: 83124 bytes
+Total submission size: 102012943 bytes
+Serialized model int5_odd+zlib: 5184543 bytes (payload:8780523 raw_torch:8851579 payload_ratio:11.60x)
+Total submission size int5_odd+zlib: 5267667 bytes
+final_int5_odd_zlib_roundtrip val_loss:9.8273 val_bpb:5.8880 eval_time:854ms
+final_int5_odd_zlib_roundtrip_exact val_loss:9.82734489 val_bpb:5.88802157
diff --git a/records/track_non_record_16mb/2026-03-31_EntropyAware_Int5Odd_BTTMLP_3060/logs/init_mup_bench.txt b/records/track_non_record_16mb/2026-03-31_EntropyAware_Int5Odd_BTTMLP_3060/logs/init_mup_bench.txt
new file mode 100644
index 000000000..707cc4763
--- /dev/null
+++ b/records/track_non_record_16mb/2026-03-31_EntropyAware_Int5Odd_BTTMLP_3060/logs/init_mup_bench.txt
@@ -0,0 +1,1702 @@
+"""
+The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good jumping-off points for new participants, not SOTA configs. We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder.
+
+Hard stop: To keep things readable for newcomers, let's make sure `train_gpt.py` and `train_gpt_mlx.py` are never longer than 1500 lines.
+""" + +from __future__ import annotations + +import copy +import glob +import io +import math +import os +import random +import subprocess +import sys +import time +import uuid +import zlib +from pathlib import Path + +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP + +try: + import zstandard as zstd + + _HAS_ZSTD = True +except ImportError: + zstd = None + _HAS_ZSTD = False + +REPO_ROOT = Path(__file__).resolve().parents[3] + +# ----------------------------- +# HYPERPARAMETERS +# ----------------------------- +# Default Simple Baseline run: +# - 9 transformer blocks at width 512 +# - 8 attention heads with 4 KV heads (GQA) and 2x MLP expansion +# - vocab size 1024, sequence length 1024, tied embeddings +# - 524,288 train tokens per step for 20,000 iterations with a ~10 minute cap + +class Hyperparameters: + # Data paths are shard globs produced by the existing preprocessing pipeline. + data_path = os.environ.get("DATA_PATH", str(REPO_ROOT / "data/datasets/fineweb10B_sp1024")) + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", str(REPO_ROOT / "data/tokenizers/fineweb_1024_bpe.model")) + run_id = os.environ.get("RUN_ID", "train") + seed = int(os.environ.get("SEED", 1337)) + + # Validation cadence and batch size. Validation always uses the full fineweb_val split. + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 262_144)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 10)) + val_token_limit = int(os.environ.get("VAL_TOKEN_LIMIT", 1_048_576)) + + # Training length. 
+ iterations = int(os.environ.get("ITERATIONS", 40)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 1200)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 0)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 16_384)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 256)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + enable_compile = bool(int(os.environ.get("ENABLE_COMPILE", "0"))) + + # Model shape. + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers = int(os.environ.get("NUM_LAYERS", 9)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = int(os.environ.get("MLP_MULT", 2)) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + + # Optimizer hyperparameters. + embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.03)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.02)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.02)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.0)) + + # Export / compression. 
+ quant_scheme = os.environ.get("QUANT_SCHEME", "int5_odd") + compressor = os.environ.get("COMPRESSOR", "zlib") + compressor_level = int(os.environ.get("COMPRESSOR_LEVEL", 9)) + + # Entropy-aware rate regularization. + rate_lambda = float(os.environ.get("RATE_LAMBDA", 0.002)) + scale_lambda = float(os.environ.get("SCALE_LAMBDA", 0.0002)) + rate_temperature = float(os.environ.get("RATE_TEMPERATURE", 6.0)) + rate_warmup_start_fraction = float(os.environ.get("RATE_WARMUP_START_FRACTION", 0.1)) + rate_full_fraction = float(os.environ.get("RATE_FULL_FRACTION", 0.4)) + rate_target_tensors = os.environ.get("RATE_TARGET_TENSORS", "mlp_only") + + # Structured MLP. + structured_mlp = bool(int(os.environ.get("STRUCTURED_MLP", "1"))) + structured_linear_type = os.environ.get("STRUCTURED_LINEAR_TYPE", "btt_mlp") + btt_rank = int(os.environ.get("BTT_RANK", 32)) + btt_in_shape = os.environ.get("BTT_IN_SHAPE", "") + btt_out_shape = os.environ.get("BTT_OUT_SHAPE", "") + btt_init = os.environ.get("BTT_INIT", "mup") + btt_init_gain = float(os.environ.get("BTT_INIT_GAIN", 1.0)) + compile_structured_mlp = bool(int(os.environ.get("COMPILE_STRUCTURED_MLP", "1"))) + +# ----------------------------- +# MUON OPTIMIZER +# ----------------------------- +# +# As borrowed from modded-nanogpt +# Background on Muon: https://kellerjordan.github.io/posts/muon/ + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + # Orthogonalize a 2D update matrix with a fast Newton-Schulz iteration. + # Muon uses this to normalize matrix-shaped gradients before applying them. 
+ a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X + + +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True): + super().__init__( + params, + dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov), + ) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + + for group in self.param_groups: + params = group["params"] + if not params: + continue + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + + total_params = sum(int(p.numel()) for p in params) + updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) + + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: + g = g.add(buf, alpha=momentum) + g = zeropower_via_newtonschulz5(g, steps=backend_steps) + # Scale correction from Muon reference implementations. 
+                    g *= max(1, g.size(0) / g.size(1)) ** 0.5
+                    updates_flat[curr : curr + p.numel()] = g.reshape(-1)
+                curr += p.numel()
+
+            if distributed:
+                dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM)
+
+            curr = 0
+            for p in params:
+                g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype)
+                p.add_(g, alpha=-lr)
+                curr += p.numel()
+
+        return loss
+
+
+# -----------------------------
+# TOKENIZER-AGNOSTIC EVALUATION SETUP
+# -----------------------------
+#
+# It's common for small models to have a large fraction of their parameters be embeddings, since the 2 * d_model * d_vocab embedding parameters can be gigantic.
+# Instead of locking the tokenizer, we let you bring your own and calculate our validation metrics on the average compression of the validation set.
+# We calculate BPB (bits-per-byte) instead of validation loss, so we need methods to count the number of bytes each token decodes to.
+# Note: Submissions that edit the tokenizer will be examined more carefully, since screwing this up might unjustly improve your score.
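The byte-aware accounting above turns mean token cross-entropy into bits-per-byte: bpb = (loss / ln 2) * (tokens / bytes). A self-contained sketch with toy aggregate numbers (illustrative only, not values from any run):

```python
import math

# Toy aggregates standing in for the val_loss_sum / val_token_count /
# val_byte_count accumulators built up during evaluation.
val_loss_sum = 6.93 * 1000.0   # summed token cross-entropy (nats)
val_token_count = 1000.0       # tokens scored
val_byte_count = 4200.0        # UTF-8 bytes those tokens decode to

val_loss = val_loss_sum / val_token_count          # mean nats per token
bits_per_token = val_loss / math.log(2.0)          # nats -> bits
tokens_per_byte = val_token_count / val_byte_count
val_bpb = bits_per_token * tokens_per_byte         # bits per byte

print(round(val_loss, 4), round(val_bpb, 4))
```

A tokenizer with longer average pieces raises bytes-per-token and so lowers bpb for the same token loss, which is exactly why the metric is tokenizer-agnostic.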
+ +def build_sentencepiece_luts( + sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device +) -> tuple[Tensor, Tensor, Tensor]: + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("▁"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) + + +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*. 
+ tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] + + +def eval_val( + args: Hyperparameters, + model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + grad_accum_steps: int, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, +) -> tuple[float, float]: + # Validation computes two metrics: + # - val_loss: token cross-entropy (natural log) + # - val_bpb: tokenizer-agnostic compression metric used by the challenge + local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) + if local_batch_tokens < args.train_seq_len: + raise ValueError( + "VAL_BATCH_SIZE must provide at least one sequence per rank; " + f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " + f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}" + ) + local_batch_seqs = local_batch_tokens // args.train_seq_len + total_seqs = (val_tokens.numel() - 1) // args.train_seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + + model.eval() + with torch.inference_mode(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * args.train_seq_len + raw_end = batch_seq_end * args.train_seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, args.train_seq_len) + y = local[1:].reshape(-1, args.train_seq_len) + 
maybe_mark_compile_step_begin(args.enable_compile or args.compile_structured_mlp)
+            with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True):
+                batch_loss = model(x, y).detach()
+            batch_token_count = float(y.numel())
+            val_loss_sum += batch_loss.to(torch.float64) * batch_token_count
+            val_token_count += batch_token_count
+            prev_ids = x.reshape(-1)
+            tgt_ids = y.reshape(-1)
+            token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16)
+            token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16)
+            val_byte_count += token_bytes.to(torch.float64).sum()
+
+    if dist.is_available() and dist.is_initialized():
+        dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM)
+        dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM)
+        dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM)
+
+    val_loss = val_loss_sum / val_token_count
+    bits_per_token = val_loss.item() / math.log(2.0)
+    tokens_per_byte = val_token_count.item() / val_byte_count.item()
+    model.train()
+    return float(val_loss.item()), float(bits_per_token * tokens_per_byte)
+
+# -----------------------------
+# POST-TRAINING QUANTIZATION
+# -----------------------------
+#
+# It's silly to export our model, which is trained in bf16 and fp32, at that same precision.
+# Instead, we get approximately the same model (with a small hit) by quantizing it to int8 (or the int5_odd grid) and zlib-compressing the result.
+# We can then decompress the model and run in higher precision for evaluation, while coming in under the size limit.
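As a standalone illustration of that export recipe, here is a toy numpy sketch (not the helpers defined in this file): per-row absmax scales, int8 rounding, zlib on the payload, then a dequantized roundtrip whose error is bounded by half a quantization step.

```python
import zlib
import numpy as np

# Toy weight matrix standing in for a model tensor (illustrative only).
rng = np.random.default_rng(0)
w = rng.standard_normal((4, 8)).astype(np.float32)

scale = np.abs(w).max(axis=1, keepdims=True) / 127.0   # one scale per row
q = np.clip(np.round(w / scale), -127, 127).astype(np.int8)
blob = zlib.compress(q.tobytes(), 9)                   # compressed payload

w_hat = q.astype(np.float32) * scale                   # dequantize
err = np.abs(w - w_hat).max()
assert err <= scale.max() / 2 + 1e-6                   # within half a step
print(q.dtype, w_hat.shape)
```

Random int8 codes barely compress, which is why the real exporter pairs the narrower 5-level grid (and an entropy-aware regularizer) with zlib: a peakier code histogram is what actually shrinks the artifact.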
+ +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights", + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 +INT5_ODD_LEVELS = (-2, -1, 0, 1, 2) + +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) + + +def maybe_mark_compile_step_begin(enabled: bool) -> None: + if not enabled: + return + mark_step = getattr(torch.compiler, "cudagraph_mark_step_begin", None) + if mark_step is not None: + mark_step() + + +def parse_factor_shape(spec: str, dim: int) -> tuple[int, int]: + if spec: + vals = tuple(int(v.strip()) for v in spec.split(",") if v.strip()) + if len(vals) != 2 or vals[0] * vals[1] != dim: + raise ValueError(f"Invalid factor shape '{spec}' for dim={dim}") + return vals + for a in range(int(math.isqrt(dim)), 0, -1): + if dim % a == 0: + return a, dim // a + return 1, dim + +def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: + if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): + return t.float().contiguous() + if t.dtype in {torch.float32, torch.bfloat16}: + passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") + return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() + return t + +def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + # Matrices get one scale per row, which usually tracks output-channel + # ranges much better than a single tensor-wide scale. 
+ clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) + q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() + return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + + # Vectors / scalars use a simpler per-tensor scale. + clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() + return q, scale + + +def pack_base5_codes(q: Tensor) -> tuple[bytes, int]: + vals = q.detach().to("cpu", dtype=torch.uint8).reshape(-1).numpy() + n_vals = int(vals.size) + pad = (3 - (n_vals % 3)) % 3 + if pad: + vals = np.pad(vals, (0, pad)) + groups = vals.reshape(-1, 3).astype(np.uint8) + packed = groups[:, 0] + 5 * groups[:, 1] + 25 * groups[:, 2] + return packed.tobytes(), n_vals + + +def unpack_base5_codes(data: bytes, n_vals: int) -> Tensor: + packed = np.frombuffer(data, dtype=np.uint8).astype(np.int16) + out = np.zeros((packed.size, 3), dtype=np.uint8) + for i in range(3): + out[:, i] = packed % 5 + packed //= 5 + flat = out.reshape(-1)[:n_vals] + return torch.from_numpy(flat.astype(np.uint8, copy=False)) + + +def quantize_float_tensor_int5_odd(t: Tensor) -> tuple[Tensor, Tensor]: + if t.ndim != 2: + raise ValueError("int5_odd quantization only supports 2D tensors") + t32 = t.float() + clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clip_abs = clip_abs.clamp_min(1e-6) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 
2.0).clamp_min(1.0 / 2.0e6) + q = torch.clamp(torch.round(clipped / scale[:, None]), -2, 2).to(torch.int8).contiguous() + return (q + 2).to(torch.uint8).contiguous(), scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + +def quantize_state_dict_int8(state_dict: dict[str, Tensor]): + # Single supported clean-script export format: + # - per-row int8 for 2D float tensors + # - per-tensor int8 for other float tensors + # - exact passthrough for non-floats + # - passthrough for small float tensors, stored as fp16 to save bytes + quantized: dict[str, Tensor] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + qmeta: dict[str, dict[str, object]] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + + # Small float tensors are cheap enough to keep directly. We still downcast + # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size. 
+ if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + + stats["num_float_tensors"] += 1 + q, s = quantize_float_tensor(t) + if s.ndim > 0: + qmeta[name] = {"scheme": "per_row", "axis": 0} + quantized[name] = q + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) + + obj: dict[str, object] = { + "__quant_format__": "int8_clean_per_row_v1", + "quantized": quantized, + "scales": scales, + "dtypes": dtypes, + "passthrough": passthrough, + } + if qmeta: + obj["qmeta"] = qmeta + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats + + +def quantize_state_dict_int5_odd(state_dict: dict[str, Tensor]): + packed: dict[str, bytes] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + shapes: dict[str, tuple[int, ...]] = {} + orig_shapes: dict[str, tuple[int, ...]] = {} + n_codes: dict[str, int] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + + quantize_this = t.ndim >= 2 and (t.numel() > INT8_KEEP_FLOAT_MAX_NUMEL or ".mlp." 
in name) + if not quantize_this: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + + stats["num_float_tensors"] += 1 + t2 = t.reshape(t.shape[0], -1) if t.ndim > 2 else t + q, s = quantize_float_tensor_int5_odd(t2) + packed_bytes, count = pack_base5_codes(q) + packed[name] = packed_bytes + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + shapes[name] = tuple(int(v) for v in t2.shape) + orig_shapes[name] = tuple(int(v) for v in t.shape) + n_codes[name] = count + stats["int8_payload_bytes"] += len(packed_bytes) + tensor_nbytes(s) + + obj: dict[str, object] = { + "__quant_format__": "int5_odd_clean_per_row_v1", + "packed": packed, + "scales": scales, + "dtypes": dtypes, + "shapes": shapes, + "orig_shapes": orig_shapes, + "n_codes": n_codes, + "passthrough": passthrough, + } + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats + +def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + qmeta = obj.get("qmeta", {}) + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, q in obj["quantized"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + s = obj["scales"][name] + if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: + s = s.to(dtype=torch.float32) + # Broadcast the saved row scale back across trailing dimensions. + out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() + else: + scale = float(s.item()) + out[name] = (q.float() * scale).to(dtype=dtype).contiguous() + for name, t in obj["passthrough"].items(): + # Restore small tensors, undoing the temporary fp16 storage cast if needed. 
+ out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out + + +def dequantize_state_dict_int5_odd(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, packed in obj["packed"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + shape = tuple(int(v) for v in obj["shapes"][name]) + orig_shape = tuple(int(v) for v in obj.get("orig_shapes", {}).get(name, shape)) + q = unpack_base5_codes(packed, int(obj["n_codes"][name])).to(torch.float32).reshape(shape) - 2.0 + s = obj["scales"][name].to(dtype=torch.float32) + deq = (q * s.view(shape[0], *([1] * (len(shape) - 1)))).to(dtype=dtype) + out[name] = deq.reshape(orig_shape).contiguous() + for name, t in obj["passthrough"].items(): + out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out + + +def quantize_state_dict(state_dict: dict[str, Tensor], quant_scheme: str): + if quant_scheme == "int8": + return quantize_state_dict_int8(state_dict) + if quant_scheme == "int5_odd": + return quantize_state_dict_int5_odd(state_dict) + raise ValueError(f"Unsupported QUANT_SCHEME={quant_scheme}") + + +def dequantize_state_dict(obj: dict[str, object]) -> dict[str, Tensor]: + quant_format = obj.get("__quant_format__") + if quant_format == "int8_clean_per_row_v1": + return dequantize_state_dict_int8(obj) + if quant_format == "int5_odd_clean_per_row_v1": + return dequantize_state_dict_int5_odd(obj) + raise ValueError(f"Unsupported quant format {quant_format}") + + +def compress_blob(data: bytes, compressor: str, level: int) -> bytes: + if compressor == "zlib": + return zlib.compress(data, level=max(0, 
min(level, 9))) + if compressor == "zstd": + if not _HAS_ZSTD: + raise RuntimeError("COMPRESSOR=zstd requires zstandard to be installed") + return zstd.ZstdCompressor(level=level).compress(data) + raise ValueError(f"Unsupported COMPRESSOR={compressor}") + + +def decompress_blob(data: bytes, compressor: str) -> bytes: + if compressor == "zlib": + return zlib.decompress(data) + if compressor == "zstd": + if not _HAS_ZSTD: + raise RuntimeError("COMPRESSOR=zstd requires zstandard to be installed") + return zstd.ZstdDecompressor().decompress(data) + raise ValueError(f"Unsupported COMPRESSOR={compressor}") + + +def should_rate_regularize_tensor(name: str, p: Tensor, target: str) -> bool: + if not p.requires_grad or not p.is_floating_point() or p.ndim < 2: + return False + if any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS): + return False + if target == "mlp_only": + return ".mlp." in name + if target == "all_matrices": + return p.numel() > INT8_KEEP_FLOAT_MAX_NUMEL or ".mlp." in name + raise ValueError(f"Unsupported RATE_TARGET_TENSORS={target}") + + +def estimate_tensor_rate_proxy(t: Tensor, temperature: float) -> tuple[Tensor, Tensor, Tensor]: + t2 = t.float() + if t2.ndim > 2: + t2 = t2.reshape(t2.shape[0], -1) + elif t2.ndim == 1: + t2 = t2.reshape(1, -1) + row_scale = t2.abs().mean(dim=1, keepdim=True).clamp_min(1e-6) + normalized = torch.clamp(t2 / row_scale, -2.5, 2.5) + centers = t2.new_tensor(INT5_ODD_LEVELS) + logits = -temperature * (normalized.unsqueeze(-1) - centers) ** 2 + probs = torch.softmax(logits, dim=-1) + hist = probs.mean(dim=(0, 1)) + entropy_nats = -(hist * (hist + 1e-8).log()).sum() + entropy_bits = entropy_nats / math.log(2.0) + scale_proxy = row_scale.mean() + est_bits = entropy_bits * float(t2.numel()) + return entropy_bits, scale_proxy, est_bits + + +# ----------------------------- +# DATA LOADING +# ----------------------------- + +def load_data_shard(file: Path) -> Tensor: + header_bytes = 256 * np.dtype(" None: + 
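`estimate_tensor_rate_proxy` above builds a differentiable surrogate for the post-quantization code entropy: rows are normalized by their mean absolute value, soft-assigned to the five odd levels with a temperature-scaled softmax, and the entropy of the averaged assignment histogram is the bits-per-symbol proxy. A NumPy sketch of the same computation for a 2-D weight (illustrative only, not the training path):

```python
import numpy as np


def soft_histogram_entropy_bits(w: np.ndarray, temperature: float = 10.0) -> float:
    # Soft-assign row-normalized weights to the 5 odd levels, then measure
    # the entropy (bits/symbol) of the averaged assignment distribution.
    levels = np.array([-2.0, -1.0, 0.0, 1.0, 2.0])
    row_scale = np.maximum(np.abs(w).mean(axis=1, keepdims=True), 1e-6)
    normalized = np.clip(w / row_scale, -2.5, 2.5)
    logits = -temperature * (normalized[..., None] - levels) ** 2
    logits -= logits.max(axis=-1, keepdims=True)  # stable softmax
    probs = np.exp(logits)
    probs /= probs.sum(axis=-1, keepdims=True)
    hist = probs.mean(axis=(0, 1))
    return float(-(hist * np.log2(hist + 1e-8)).sum())


rng = np.random.default_rng(0)
h = soft_histogram_entropy_bits(rng.standard_normal((64, 64)))
# A 5-symbol histogram bounds the proxy by log2(5) ~ 2.32 bits/symbol.
assert 0.0 <= h <= np.log2(5)
```

Multiplying the per-symbol proxy by `numel()` gives the `est_bits` term the training log reports as `est_model_bits`.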
self.file_idx = (self.file_idx + 1) % len(self.files) + self.tokens = load_data_shard(self.files[self.file_idx]) + self.pos = 0 + + def take(self, n: int) -> Tensor: + chunks: list[Tensor] = [] + remaining = n + while remaining > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: + self._advance_file() + continue + k = min(remaining, avail) + chunks.append(self.tokens[self.pos : self.pos + k]) + self.pos += k + remaining -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) + + +class DistributedTokenLoader: + # Each call consumes a contiguous chunk from the shared token stream, then slices out + # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting. + def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank = rank + self.world_size = world_size + self.device = device + self.stream = TokenStream(pattern) + + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + chunk = self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start : start + per_rank_span].to(dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + +# ----------------------------- +# TRANSFORMER MODULES +# ----------------------------- + +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + + +class CastedLinear(nn.Linear): + # Keep weights in fp32 for optimizer/state quality, cast at matmul time for bf16 compute. 
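`DistributedTokenLoader.next_batch` above hands every rank the same contiguous chunk and slices out disjoint `local_tokens + 1` spans. A quick standalone check of that span arithmetic (mirroring the slicing, with made-up batch numbers):

```python
def rank_span(rank: int, world_size: int, global_tokens: int, grad_accum_steps: int) -> tuple[int, int]:
    # Mirrors next_batch: local_tokens + 1 contiguous tokens per rank,
    # where the extra token allows the shifted (x, y) construction.
    local_tokens = global_tokens // (world_size * grad_accum_steps)
    per_rank_span = local_tokens + 1
    start = rank * per_rank_span
    return start, start + per_rank_span


world_size, tokens, accum = 8, 131072, 2
spans = [rank_span(r, world_size, tokens, accum) for r in range(world_size)]
# Spans tile the shared chunk: contiguous, disjoint, identical length.
assert all(spans[r][1] == spans[r + 1][0] for r in range(world_size - 1))
assert spans[-1][1] == world_size * (tokens // (world_size * accum) + 1)
```

Because the spans are disjoint, DDP's gradient all-reduce averages over different token ranges rather than duplicating work, which is exactly what `run_mock_8xh100_math.sh` prints for ranks `0..7`.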
+ def forward(self, x: Tensor) -> Tensor: + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, self.weight.to(x.dtype), bias) + + +class StructuredLinear(nn.Module): + # Lightweight TT-style structured matrix used only in the MLP stack. + def __init__( + self, + in_features: int, + out_features: int, + bias: bool, + linear_type: str, + btt_rank: int, + in_shape_spec: str, + out_shape_spec: str, + btt_init: str, + btt_init_gain: float, + ): + super().__init__() + self.in_features = in_features + self.out_features = out_features + self.linear_type = linear_type + self.rank = max(int(btt_rank), 1) + self.btt_init = btt_init + self.btt_init_gain = btt_init_gain + if linear_type == "dense": + self.impl = CastedLinear(in_features, out_features, bias=bias) + self.register_parameter("core1", None) + self.register_parameter("core2", None) + return + if linear_type != "btt_mlp": + raise ValueError(f"Unsupported STRUCTURED_LINEAR_TYPE={linear_type}") + self.impl = None + self.in_shape = parse_factor_shape(in_shape_spec, in_features) + self.out_shape = parse_factor_shape(out_shape_spec, out_features) + n1, n2 = self.in_shape + m1, m2 = self.out_shape + self.core1 = nn.Parameter(torch.empty(n1, m1, self.rank)) + self.core2 = nn.Parameter(torch.empty(self.rank, n2, m2)) + if bias: + self.bias = nn.Parameter(torch.zeros(out_features)) + else: + self.register_parameter("bias", None) + self.reset_parameters() + + def reset_parameters(self) -> None: + if self.impl is not None: + self.impl.reset_parameters() + return + n1, n2 = self.in_shape + if self.btt_init == "mup": + # muP-inspired scaling: keep the per-rank branch update stable while + # preserving O(1) output variance under the sum over TT ranks. 
+ core1_std = self.btt_init_gain / math.sqrt(max(n1, 1)) + core2_std = self.btt_init_gain / math.sqrt(max(self.rank * n2, 1)) + elif self.btt_init == "xavier": + std = self.btt_init_gain / math.sqrt(max(self.in_features, 1)) + core1_std = std / math.sqrt(self.rank) + core2_std = std + else: + raise ValueError(f"Unsupported BTT_INIT={self.btt_init}") + nn.init.normal_(self.core1, mean=0.0, std=core1_std) + nn.init.normal_(self.core2, mean=0.0, std=core2_std) + if self.bias is not None: + nn.init.zeros_(self.bias) + + def forward(self, x: Tensor) -> Tensor: + if self.impl is not None: + return self.impl(x) + x_dtype = x.dtype + orig_shape = x.shape[:-1] + x2 = x.reshape(-1, self.in_shape[0], self.in_shape[1]) + core1 = self.core1.to(dtype=x_dtype) + core2 = self.core2.to(dtype=x_dtype) + y = x2.new_zeros((x2.shape[0], self.out_shape[0], self.out_shape[1])) + for r in range(self.rank): + left = torch.einsum("buv,um->bmv", x2, core1[:, :, r]) + y = y + torch.einsum("bmv,vn->bmn", left, core2[r]) + y = y.reshape(*orig_shape, self.out_features) + if self.bias is not None: + y = y + self.bias.to(dtype=y.dtype) + return y + + +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + # Keep small/control parameters in fp32 even when the model body runs in bf16. + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() + + +class Rotary(nn.Module): + # Caches cos/sin tables per sequence length on the current device. 
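The per-rank einsum loop in `StructuredLinear.forward` computes `y[b,m,n] = Σ_{u,v,r} x[b,u,v]·core1[u,m,r]·core2[r,v,n]`, which is exactly a dense `(n1·n2) × (m1·m2)` matmul once the two cores are contracted together. A NumPy check of that equivalence, the identity the evaluation-time dense materialization relies on (shapes are illustrative):

```python
import numpy as np

rng = np.random.default_rng(0)
n1, n2, m1, m2, rank = 4, 8, 4, 8, 2
core1 = rng.standard_normal((n1, m1, rank))
core2 = rng.standard_normal((rank, n2, m2))
x = rng.standard_normal((3, n1, n2))

# Rank loop, mirroring the training-time forward.
y = np.zeros((3, m1, m2))
for r in range(rank):
    left = np.einsum("buv,um->bmv", x, core1[:, :, r])
    y = y + np.einsum("bmv,vn->bmn", left, core2[r])

# Dense materialization: W[(u,v),(m,n)] = sum_r core1[u,m,r] * core2[r,v,n].
dense = np.einsum("umr,rvn->uvmn", core1, core2).reshape(n1 * n2, m1 * m2)
y_dense = (x.reshape(3, -1) @ dense).reshape(3, m1, m2)
assert np.allclose(y, y_dense)
```

This is why validation can run through a single `F.linear` against the materialized matrix instead of the slower per-rank loop, with no change in the computed function (up to floating-point rounding).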
+ def __init__(self, dim: int, base: float = 10000.0): + super().__init__() + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached != seq_len + or self._cos_cached.device != device + ): + t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) + freqs = torch.outer(t, self.inv_freq.to(device)) + self._cos_cached = freqs.cos()[None, None, :, :] + self._sin_cached = freqs.sin()[None, None, :, :] + self._seq_len_cached = seq_len + cos = self._cos_cached.to(dtype=dtype) + sin = self._sin_cached.to(dtype=dtype) + if not torch.is_inference_mode_enabled(): + cos = cos.clone() + sin = sin.clone() + return cos, sin + + +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + + +class CausalSelfAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + kv_dim = self.num_kv_heads * self.head_dim + self.c_q = CastedLinear(dim, dim, bias=False) + self.c_k = CastedLinear(dim, kv_dim, bias=False) + self.c_v = CastedLinear(dim, kv_dim, bias=False) + self.proj 
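`apply_rotary_emb` above rotates each `(x1[i], x2[i])` coordinate pair by a position-dependent angle, so token vectors keep their norm under RoPE. A small NumPy mirror of the rotation (single position, hypothetical dims) checking that invariant:

```python
import numpy as np


def apply_rotary(x: np.ndarray, cos: np.ndarray, sin: np.ndarray) -> np.ndarray:
    # Same pairing as apply_rotary_emb: first half rotates with the second.
    half = x.shape[-1] // 2
    x1, x2 = x[..., :half], x[..., half:]
    return np.concatenate((x1 * cos + x2 * sin, -x1 * sin + x2 * cos), axis=-1)


rng = np.random.default_rng(0)
dim = 8
x = rng.standard_normal((2, dim))
theta = rng.standard_normal(dim // 2)  # stand-in for the cached freq table
out = apply_rotary(x, np.cos(theta), np.sin(theta))
# Rotations are orthogonal, so per-vector norms are unchanged.
assert np.allclose(np.linalg.norm(out, axis=-1), np.linalg.norm(x, axis=-1))
```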
= CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rotary = Rotary(self.head_dim, base=rope_base) + + def forward(self, x: Tensor) -> Tensor: + bsz, seqlen, dim = x.shape + q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) + k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin) + k = apply_rotary_emb(k, cos, sin) + q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] + y = F.scaled_dot_product_attention( + q, + k, + v, + attn_mask=None, + is_causal=True, + enable_gqa=(self.num_kv_heads != self.num_heads), + ) + y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) + return self.proj(y) + + +class MLP(nn.Module): + # relu^2 MLP from the original modded-nanogpt setup + def __init__( + self, + dim: int, + mlp_mult: int, + structured_mlp: bool, + structured_linear_type: str, + btt_rank: int, + btt_in_shape: str, + btt_out_shape: str, + btt_init: str, + btt_init_gain: float, + compile_structured_mlp: bool, + ): + super().__init__() + hidden = mlp_mult * dim + if structured_mlp: + self.fc = StructuredLinear( + dim, + hidden, + bias=False, + linear_type=structured_linear_type, + btt_rank=btt_rank, + in_shape_spec=btt_in_shape, + out_shape_spec=btt_out_shape, + btt_init=btt_init, + btt_init_gain=btt_init_gain, + ) + self.proj = StructuredLinear( + hidden, + dim, + bias=False, + linear_type=structured_linear_type, + btt_rank=btt_rank, + in_shape_spec=btt_out_shape, + out_shape_spec=btt_in_shape, + btt_init=btt_init, + btt_init_gain=btt_init_gain, + ) + if compile_structured_mlp: + compile_options = 
dict(torch._inductor.list_mode_options("reduce-overhead")) + compile_options["triton.cudagraphs"] = False + self.fc = torch.compile( + self.fc, + options=compile_options, + fullgraph=False, + dynamic=False, + ) + self.proj = torch.compile( + self.proj, + options=compile_options, + fullgraph=False, + dynamic=False, + ) + else: + self.fc = CastedLinear(dim, hidden, bias=False) + self.proj = CastedLinear(hidden, dim, bias=False) + self.proj._zero_init = True + + def forward(self, x: Tensor) -> Tensor: + x = torch.relu(self.fc(x)) + return self.proj(x.square()) + + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + rope_base: float, + qk_gain_init: float, + structured_mlp: bool, + structured_linear_type: str, + btt_rank: int, + btt_in_shape: str, + btt_out_shape: str, + btt_init: str, + btt_init_gain: float, + compile_structured_mlp: bool, + ): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) + self.mlp = MLP( + dim, + mlp_mult, + structured_mlp, + structured_linear_type, + btt_rank, + btt_in_shape, + btt_out_shape, + btt_init, + btt_init_gain, + compile_structured_mlp, + ) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + + def forward(self, x: Tensor, x0: Tensor) -> Tensor: + mix = self.resid_mix.to(dtype=x.dtype) + x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + attn_out = self.attn(self.attn_norm(x)) + x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out + x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) + return x + + +class GPT(nn.Module): + def __init__( + self, + vocab_size: int, + num_layers: int, + model_dim: int, + num_heads: int, + 
num_kv_heads: int, + mlp_mult: int, + tie_embeddings: bool, + tied_embed_init_std: float, + logit_softcap: float, + rope_base: float, + qk_gain_init: float, + structured_mlp: bool, + structured_linear_type: str, + btt_rank: int, + btt_in_shape: str, + btt_out_shape: str, + btt_init: str, + btt_init_gain: float, + compile_structured_mlp: bool, + rate_target_tensors: str, + rate_temperature: float, + ): + super().__init__() + if logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") + self.tie_embeddings = tie_embeddings + self.tied_embed_init_std = tied_embed_init_std + self.logit_softcap = logit_softcap + self.rate_target_tensors = rate_target_tensors + self.rate_temperature = rate_temperature + self.tok_emb = nn.Embedding(vocab_size, model_dim) + self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + self.blocks = nn.ModuleList( + [ + Block( + model_dim, + num_heads, + num_kv_heads, + mlp_mult, + rope_base, + qk_gain_init, + structured_mlp, + structured_linear_type, + btt_rank, + btt_in_shape, + btt_out_shape, + btt_init, + btt_init_gain, + compile_structured_mlp, + ) + for i in range(num_layers) + ] + ) + self.final_norm = RMSNorm() + self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) + if self.lm_head is not None: + self.lm_head._zero_init = True + self._init_weights() + + def _init_weights(self) -> None: + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + for module in self.modules(): + if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + if isinstance(module, StructuredLinear) and getattr(module, "_zero_init", False): + if module.impl is 
not None: + nn.init.zeros_(module.impl.weight) + else: + nn.init.zeros_(module.core2) + + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + x = self.tok_emb(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x0 = x + skips: list[Tensor] = [] + + # First half stores skips; second half reuses them in reverse order. + for i in range(self.num_encoder_layers): + x = self.blocks[i](x, x0) + skips.append(x) + for i in range(self.num_decoder_layers): + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + x = self.blocks[self.num_encoder_layers + i](x, x0) + + x = self.final_norm(x).reshape(-1, x.size(-1)) + targets = target_ids.reshape(-1) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(x) + logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + return F.cross_entropy(logits.float(), targets, reduction="mean") + + def rate_regularizer(self) -> tuple[Tensor, Tensor, Tensor]: + total_symbols = 0.0 + weighted_rate = None + weighted_scale = None + estimated_bits = None + for name, param in self.named_parameters(): + if not should_rate_regularize_tensor(name, param, self.rate_target_tensors): + continue + rate_proxy, scale_proxy, tensor_bits = estimate_tensor_rate_proxy(param, self.rate_temperature) + weight = float(param.numel()) + total_symbols += weight + if weighted_rate is None: + weighted_rate = rate_proxy * weight + weighted_scale = scale_proxy * weight + estimated_bits = tensor_bits + else: + weighted_rate = weighted_rate + rate_proxy * weight + weighted_scale = weighted_scale + scale_proxy * weight + estimated_bits = estimated_bits + tensor_bits + if weighted_rate is None or weighted_scale is None or estimated_bits is None: + zero = self.tok_emb.weight.new_zeros(()) + return zero, zero, zero + denom = max(total_symbols, 1.0) + return 
weighted_rate / denom, weighted_scale / denom, estimated_bits + + +# ----------------------------- +# TRAINING +# ----------------------------- + +def main() -> None: + global zeropower_via_newtonschulz5 + + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + if args.enable_compile: + zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) + + # ----------------------------- + # DISTRIBUTED + CUDA SETUP + # ----------------------------- + + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if 8 % world_size != 0: + raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") + grad_accum_steps = 8 // world_size + grad_scale = 1.0 / grad_accum_steps + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 0 + + # Fast math knobs + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + + logfile = None + if master_process: + if args.run_id == "train": + logfile = "train.log" + else: + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" + print(logfile) + + def log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + 
print(msg, file=f) + + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch {torch.__version__}", console=False) + log0( + subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, + console=False, + ) + log0("=" * 100, console=False) + + # ----------------------------- + # TOKENIZER + VALIDATION METRIC SETUP + # ----------------------------- + + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) + if args.val_token_limit > 0: + usable = min(((args.val_token_limit) // args.train_seq_len) * args.train_seq_len, val_tokens.numel() - 1) + if usable <= 0: + raise ValueError(f"VAL_TOKEN_LIMIT={args.val_token_limit} is too small for TRAIN_SEQ_LEN={args.train_seq_len}") + val_tokens = val_tokens[: usable + 1].contiguous() + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + if args.val_token_limit > 0: + log0(f"val_token_limit:{args.val_token_limit}") + + # 
----------------------------- + # MODEL + OPTIMIZER SETUP + # ----------------------------- + + base_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, + structured_mlp=args.structured_mlp, + structured_linear_type=args.structured_linear_type if args.structured_mlp else "dense", + btt_rank=args.btt_rank, + btt_in_shape=args.btt_in_shape, + btt_out_shape=args.btt_out_shape, + btt_init=args.btt_init, + btt_init_gain=args.btt_init_gain, + compile_structured_mlp=args.compile_structured_mlp, + rate_target_tensors=args.rate_target_tensors, + rate_temperature=args.rate_temperature, + ).to(device).bfloat16() + for module in base_model.modules(): + if isinstance(module, (CastedLinear, StructuredLinear)): + module.float() + restore_low_dim_params_to_fp32(base_model) + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) if args.enable_compile else base_model + model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + + # Optimizer split: + # - token embedding (Adam) uses EMBED_LR + # - untied lm_head (Adam) uses HEAD_LR + # - matrix params in transformer blocks use MATRIX_LR via Muon + # - vectors/scalars use SCALAR_LR via Adam + block_named_params = list(base_model.blocks.named_parameters()) + matrix_params = [ + p + for name, p in block_named_params + if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + scalar_params = [ + p + for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.skip_weights.numel() > 0: + 
scalar_params.append(base_model.skip_weights) + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + optimizer_tok = torch.optim.Adam( + [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizer_muon = Muon( + matrix_params, + lr=args.matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + ) + for group in optimizer_muon.param_groups: + group["base_lr"] = args.matrix_lr + optimizer_scalar = torch.optim.Adam( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers.insert(1, optimizer_head) + + n_params = sum(p.numel() for p in base_model.parameters()) + log0(f"model_params:{n_params}") + log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") + log0(f"enable_compile:{args.enable_compile}") + log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") + log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") + log0( + f"quant_scheme:{args.quant_scheme} compressor:{args.compressor} level:{args.compressor_level} " + f"structured_mlp:{args.structured_mlp} structured_linear_type:{args.structured_linear_type} " + f"btt_rank:{args.btt_rank} btt_init:{args.btt_init} compile_structured_mlp:{args.compile_structured_mlp}" + ) + log0( + f"rate_lambda:{args.rate_lambda} scale_lambda:{args.scale_lambda} " + f"rate_warmup_start_fraction:{args.rate_warmup_start_fraction} rate_full_fraction:{args.rate_full_fraction} " + 
f"rate_target_tensors:{args.rate_target_tensors}" + ) + log0( + f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " + f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " + f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" + ) + log0( + f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " + f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " + f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log0(f"seed:{args.seed}") + + # ----------------------------- + # DATA LOADER & MODEL WARMUP + # ----------------------------- + + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + + def lr_mul(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + + def rate_schedule(step: int) -> float: + if args.rate_lambda <= 0.0 and args.scale_lambda <= 0.0: + return 0.0 + if args.iterations <= 0: + return 1.0 + frac = step / max(args.iterations, 1) + start = args.rate_warmup_start_fraction + full = max(args.rate_full_fraction, start + 1e-6) + if frac <= start: + return 0.0 + if frac >= full: + return 1.0 + return (frac - start) / (full - start) + + # Warmup primes the compiled forward/backward/optimizer paths, then we restore the + # initial weights/optimizer state so measured training starts from 
the true init. + if args.warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(args.warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + maybe_mark_compile_step_begin(args.enable_compile or args.compile_structured_mlp) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + for opt in optimizers: + opt.step() + zero_grad_all() + if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: + log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + if distributed: + model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + # ----------------------------- + # MAIN TRAINING LOOP + # ----------------------------- + + training_time_ms = 0.0 + stop_after_step: int | None = None + torch.cuda.synchronize() + t0 = time.perf_counter() + + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + 
base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + log0( + f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" + ) + break + + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + zero_grad_all() + train_loss = torch.zeros((), device=device) + train_ce_loss = torch.zeros((), device=device) + train_rate_proxy = torch.zeros((), device=device) + train_scale_proxy = torch.zeros((), device=device) + train_est_bits = torch.zeros((), device=device) + rate_mul = rate_schedule(step) + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + maybe_mark_compile_step_begin(args.enable_compile or args.compile_structured_mlp) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + ce_loss = model(x, y) + total_loss = ce_loss + if rate_mul > 0.0: + rate_proxy, scale_proxy, est_bits = base_model.rate_regularizer() + total_loss = total_loss + rate_mul * (args.rate_lambda * rate_proxy + args.scale_lambda * scale_proxy) + train_rate_proxy += rate_proxy.detach() + train_scale_proxy += scale_proxy.detach() + train_est_bits += est_bits.detach() + train_ce_loss += ce_loss.detach() + train_loss += total_loss.detach() + (total_loss * grad_scale).backward() + train_loss /= grad_accum_steps + train_ce_loss /= grad_accum_steps + train_rate_proxy /= grad_accum_steps + train_scale_proxy /= grad_accum_steps + train_est_bits /= grad_accum_steps + + frac = min(step / 
args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + for opt in optimizers: + opt.step() + zero_grad_all() + + step += 1 + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + should_log_train = ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + log0( + f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} ce_loss:{train_ce_loss.item():.4f} " + f"rate_mul:{rate_mul:.3f} rate_proxy_bits:{train_rate_proxy.item():.4f} " + f"scale_proxy:{train_scale_proxy.item():.6f} est_model_bits:{train_est_bits.item():.0f} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" + ) + + # Needed to sync whether we've reached the wallclock cap. 
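The Muon momentum update above is a plain linear interpolation over the first `muon_momentum_warmup_steps` optimizer steps. Isolated, with hypothetical endpoint values (the real ones come from `Hyperparameters`):

```python
def muon_momentum(step: int, warmup_steps: int, start: float = 0.85, final: float = 0.95) -> float:
    # Linearly blend from the warmup start momentum to the final momentum.
    frac = min(step / warmup_steps, 1.0) if warmup_steps > 0 else 1.0
    return (1 - frac) * start + frac * final


assert muon_momentum(0, 100) == 0.85
assert muon_momentum(100, 100) == 0.95
```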
+ reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + + log0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + ) + + # ----------------------------- + # SERIALIZATION + ROUNDTRIP VALIDATION + # ----------------------------- + # Save the raw state (useful for debugging/loading in PyTorch directly), then always produce + # the compressed int8+zlib artifact and validate the round-tripped weights. + + if master_process: + torch.save(base_model.state_dict(), "final_model.pt") + model_bytes = os.path.getsize("final_model.pt") + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model: {model_bytes} bytes") + log0(f"Code size: {code_bytes} bytes") + log0(f"Total submission size: {model_bytes + code_bytes} bytes") + + quant_obj, quant_stats = quantize_state_dict(base_model.state_dict(), args.quant_scheme) + quant_buf = io.BytesIO() + torch.save(quant_obj, quant_buf) + quant_raw = quant_buf.getvalue() + quant_blob = compress_blob(quant_raw, args.compressor, args.compressor_level) + quant_raw_bytes = len(quant_raw) + quant_filename = f"final_model.{args.quant_scheme}.{args.compressor}.ptc" + if master_process: + with open(quant_filename, "wb") as f: + f.write(quant_blob) + quant_file_bytes = os.path.getsize(quant_filename) + code_bytes = len(code.encode("utf-8")) + ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1) + log0( + f"Serialized model {args.quant_scheme}+{args.compressor}: {quant_file_bytes} bytes " + f"(payload:{quant_stats['int8_payload_bytes']} raw_torch:{quant_raw_bytes} 
payload_ratio:{ratio:.2f}x)" + ) + log0(f"Total submission size {args.quant_scheme}+{args.compressor}: {quant_file_bytes + code_bytes} bytes") + + if distributed: + dist.barrier() + with open(quant_filename, "rb") as f: + quant_blob_disk = f.read() + quant_state = torch.load(io.BytesIO(decompress_blob(quant_blob_disk, args.compressor)), map_location="cpu") + base_model.load_state_dict(dequantize_state_dict(quant_state), strict=True) + torch.cuda.synchronize() + t_qeval = time.perf_counter() + q_val_loss, q_val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"final_{args.quant_scheme}_{args.compressor}_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" + ) + log0(f"final_{args.quant_scheme}_{args.compressor}_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + + if distributed: + dist.destroy_process_group() + + +if __name__ == "__main__": + main() + +==================================================================================================== +Running Python 3.10.12 (main, Mar 3 2026, 11:56:32) [GCC 11.4.0] +Running PyTorch 2.11.0+cu130 +Tue Mar 31 12:17:56 2026 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 590.48.01 Driver Version: 590.48.01 CUDA Version: 13.1 | ++-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. 
| +|=========================================+========================+======================| +| 0 NVIDIA GeForce RTX 3060 Off | 00000000:01:00.0 On | N/A | +| 30% 37C P8 20W / 170W | 242MiB / 12288MiB | 9% Default | +| | | N/A | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| 0 N/A N/A 985 G /usr/lib/xorg/Xorg 117MiB | +| 0 N/A N/A 4736 C python3 104MiB | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=/home/hz/Desktop/parameter-golf-main/data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:1 +val_loader:shards pattern=/home/hz/Desktop/parameter-golf-main/data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:65536 +val_token_limit:65536 +model_params:8507464 +world_size:1 grad_accum_steps:8 +enable_compile:False +sdp_backends:cudnn=False flash=True mem_efficient=False math=False +attention_mode:gqa num_heads:8 num_kv_heads:4 +quant_scheme:int5_odd compressor:zlib level:9 structured_mlp:True structured_linear_type:btt_mlp btt_rank:32 btt_init:mup compile_structured_mlp:True +rate_lambda:2e-05 scale_lambda:0.0002 rate_warmup_start_fraction:0.1 rate_full_fraction:0.4 rate_target_tensors:mlp_only +tie_embeddings:True embed_lr:0.03 head_lr:0.0 matrix_lr:0.02 scalar_lr:0.02 +train_batch_tokens:4096 train_seq_len:256 iterations:12 warmup_steps:1 max_wallclock_seconds:600.000 +seed:1337 +warmup_step:1/1 +step:1/12 train_loss:6.9357 ce_loss:6.9357 rate_mul:0.000 rate_proxy_bits:0.0000 scale_proxy:0.000000 
est_model_bits:0 train_time:1630ms step_avg:1630.49ms +step:2/12 train_loss:14.0267 ce_loss:14.0267 rate_mul:0.000 rate_proxy_bits:0.0000 scale_proxy:0.000000 est_model_bits:0 train_time:3244ms step_avg:1621.75ms +step:3/12 train_loss:13.2023 ce_loss:13.2023 rate_mul:0.222 rate_proxy_bits:1.4980 scale_proxy:0.065182 est_model_bits:1325360 train_time:5327ms step_avg:1775.80ms +step:4/12 train_loss:11.5930 ce_loss:11.5930 rate_mul:0.500 rate_proxy_bits:1.4980 scale_proxy:0.065182 est_model_bits:1325360 train_time:7378ms step_avg:1844.48ms +step:5/12 train_loss:9.8433 ce_loss:9.8433 rate_mul:0.778 rate_proxy_bits:1.4980 scale_proxy:0.065182 est_model_bits:1325360 train_time:9399ms step_avg:1879.81ms +step:6/12 train_loss:9.0779 ce_loss:9.0778 rate_mul:1.000 rate_proxy_bits:1.4980 scale_proxy:0.065182 est_model_bits:1325360 train_time:11429ms step_avg:1904.79ms +step:7/12 train_loss:8.0961 ce_loss:8.0961 rate_mul:1.000 rate_proxy_bits:1.4980 scale_proxy:0.065182 est_model_bits:1325360 train_time:13451ms step_avg:1921.58ms +step:8/12 train_loss:7.3761 ce_loss:7.3761 rate_mul:1.000 rate_proxy_bits:1.4980 scale_proxy:0.065182 est_model_bits:1325360 train_time:15470ms step_avg:1933.71ms +step:9/12 train_loss:6.8965 ce_loss:6.8964 rate_mul:1.000 rate_proxy_bits:1.4980 scale_proxy:0.065182 est_model_bits:1325360 train_time:17489ms step_avg:1943.24ms +step:10/12 train_loss:6.8692 ce_loss:6.8692 rate_mul:1.000 rate_proxy_bits:1.4980 scale_proxy:0.065182 est_model_bits:1325360 train_time:19518ms step_avg:1951.79ms +step:12/12 val_loss:6.3551 val_bpb:3.7575 train_time:23567ms step_avg:1963.91ms +peak memory allocated: 736 MiB reserved: 1422 MiB +Serialized model: 33022259 bytes +Code size: 69179 bytes +Total submission size: 33091438 bytes +Serialized model int5_odd+zlib: 2137083 bytes (payload:2942927 raw_torch:2983011 payload_ratio:11.21x) +Total submission size int5_odd+zlib: 2206262 bytes +final_int5_odd_zlib_roundtrip val_loss:6.7529 val_bpb:3.9927 eval_time:1852ms 
+final_int5_odd_zlib_roundtrip_exact val_loss:6.75291294 val_bpb:3.99273050
diff --git a/records/track_non_record_16mb/2026-03-31_EntropyAware_Int5Odd_BTTMLP_3060/logs/init_xavier_bench.txt b/records/track_non_record_16mb/2026-03-31_EntropyAware_Int5Odd_BTTMLP_3060/logs/init_xavier_bench.txt
new file mode 100644
index 000000000..a2fde341e
--- /dev/null
+++ b/records/track_non_record_16mb/2026-03-31_EntropyAware_Int5Odd_BTTMLP_3060/logs/init_xavier_bench.txt
@@ -0,0 +1,1702 @@
+"""
+The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good starting points for new participants, not SOTA configs. We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder.
+
+Hard stop: to keep things readable for newcomers, let's make sure `train_gpt.py` and `train_gpt_mlx.py` are never longer than 1500 lines.
+"""
+
+from __future__ import annotations
+
+import copy
+import glob
+import io
+import math
+import os
+import random
+import subprocess
+import sys
+import time
+import uuid
+import zlib
+from pathlib import Path
+
+import numpy as np
+import sentencepiece as spm
+import torch
+import torch.distributed as dist
+import torch.nn.functional as F
+from torch import Tensor, nn
+from torch.nn.parallel import DistributedDataParallel as DDP
+
+try:
+    import zstandard as zstd
+
+    _HAS_ZSTD = True
+except ImportError:
+    zstd = None
+    _HAS_ZSTD = False
+
+REPO_ROOT = Path(__file__).resolve().parents[3]
+
+# -----------------------------
+# HYPERPARAMETERS
+# -----------------------------
+# Default Simple Baseline run:
+# - 9 transformer blocks at width 512
+# - 8 attention heads with 4 KV heads (GQA) and 2x MLP expansion
+# - vocab size 1024, sequence length 1024, tied embeddings
+# - 524,288 train tokens per step for 20,000 iterations with a ~10 minute cap
+
+class Hyperparameters:
+    # Data paths are shard globs produced by the existing preprocessing pipeline.
+ data_path = os.environ.get("DATA_PATH", str(REPO_ROOT / "data/datasets/fineweb10B_sp1024")) + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", str(REPO_ROOT / "data/tokenizers/fineweb_1024_bpe.model")) + run_id = os.environ.get("RUN_ID", "train") + seed = int(os.environ.get("SEED", 1337)) + + # Validation cadence and batch size. Validation always uses the full fineweb_val split. + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 262_144)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 10)) + val_token_limit = int(os.environ.get("VAL_TOKEN_LIMIT", 1_048_576)) + + # Training length. + iterations = int(os.environ.get("ITERATIONS", 40)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 1200)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 0)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 16_384)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 256)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + enable_compile = bool(int(os.environ.get("ENABLE_COMPILE", "0"))) + + # Model shape. + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers = int(os.environ.get("NUM_LAYERS", 9)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = int(os.environ.get("MLP_MULT", 2)) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + + # Optimizer hyperparameters. 
+ embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.03)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.02)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.02)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.0)) + + # Export / compression. + quant_scheme = os.environ.get("QUANT_SCHEME", "int5_odd") + compressor = os.environ.get("COMPRESSOR", "zlib") + compressor_level = int(os.environ.get("COMPRESSOR_LEVEL", 9)) + + # Entropy-aware rate regularization. + rate_lambda = float(os.environ.get("RATE_LAMBDA", 0.002)) + scale_lambda = float(os.environ.get("SCALE_LAMBDA", 0.0002)) + rate_temperature = float(os.environ.get("RATE_TEMPERATURE", 6.0)) + rate_warmup_start_fraction = float(os.environ.get("RATE_WARMUP_START_FRACTION", 0.1)) + rate_full_fraction = float(os.environ.get("RATE_FULL_FRACTION", 0.4)) + rate_target_tensors = os.environ.get("RATE_TARGET_TENSORS", "mlp_only") + + # Structured MLP. 
+ structured_mlp = bool(int(os.environ.get("STRUCTURED_MLP", "1"))) + structured_linear_type = os.environ.get("STRUCTURED_LINEAR_TYPE", "btt_mlp") + btt_rank = int(os.environ.get("BTT_RANK", 32)) + btt_in_shape = os.environ.get("BTT_IN_SHAPE", "") + btt_out_shape = os.environ.get("BTT_OUT_SHAPE", "") + btt_init = os.environ.get("BTT_INIT", "mup") + btt_init_gain = float(os.environ.get("BTT_INIT_GAIN", 1.0)) + compile_structured_mlp = bool(int(os.environ.get("COMPILE_STRUCTURED_MLP", "1"))) + +# ----------------------------- +# MUON OPTIMIZER +# ----------------------------- +# +# As borrowed from modded-nanogpt +# Background on Muon: https://kellerjordan.github.io/posts/muon/ + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + # Orthogonalize a 2D update matrix with a fast Newton-Schulz iteration. + # Muon uses this to normalize matrix-shaped gradients before applying them. + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X + + +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True): + super().__init__( + params, + dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov), + ) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + + for group in self.param_groups: + params = group["params"] + if not params: + continue + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + + total_params = 
sum(int(p.numel()) for p in params)
+            updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16)
+
+            curr = 0
+            for i, p in enumerate(params):
+                if i % world_size == rank and p.grad is not None:
+                    g = p.grad
+                    state = self.state[p]
+                    if "momentum_buffer" not in state:
+                        state["momentum_buffer"] = torch.zeros_like(g)
+                    buf = state["momentum_buffer"]
+                    buf.mul_(momentum).add_(g)
+                    if nesterov:
+                        g = g.add(buf, alpha=momentum)
+                    g = zeropower_via_newtonschulz5(g, steps=backend_steps)
+                    # Scale correction from Muon reference implementations.
+                    g *= max(1, g.size(0) / g.size(1)) ** 0.5
+                    updates_flat[curr : curr + p.numel()] = g.reshape(-1)
+                curr += p.numel()
+
+            if distributed:
+                dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM)
+
+            curr = 0
+            for p in params:
+                g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype)
+                p.add_(g, alpha=-lr)
+                curr += p.numel()
+
+        return loss
+
+
+# -----------------------------
+# TOKENIZER-AGNOSTIC EVALUATION SETUP
+# -----------------------------
+#
+# It's common for small models to have a large fraction of their parameters be embeddings, since the 2 * d_model * d_vocab vectors can be gigantic.
+# Instead of locking the tokenizer, we let you bring your own and calculate our validation metrics on the average compression of the validation set.
+# We calculate BPB (bits-per-byte) instead of validation loss, so we need methods to count the number of bytes per token in the tokenizer.
+# Note: Submissions that edit the tokenizer will be examined more carefully, since screwing this up might unjustly improve your score.
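The BPB metric described above reduces to a small formula: convert the natural-log token cross-entropy to bits per token, then rescale by the tokenizer's tokens-to-bytes ratio. A minimal sketch (the function name and the example numbers are illustrative, not taken from the script):

```python
import math

def bits_per_byte(val_loss_nats: float, token_count: float, byte_count: float) -> float:
    # Cross-entropy in nats -> bits per token, then scale by how many
    # bytes each token covers on average (tokenizer-agnostic).
    bits_per_token = val_loss_nats / math.log(2.0)
    tokens_per_byte = token_count / byte_count
    return bits_per_token * tokens_per_byte

# Hypothetical numbers: loss of 6.3551 nats over 65536 tokens spanning ~159000 bytes.
example_bpb = bits_per_byte(6.3551, 65536, 159000)
```

A tokenizer with longer average pieces lowers `tokens_per_byte` and therefore BPB, which is exactly why the harness counts bytes per token instead of trusting the tokenizer's own vocabulary.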
+ +def build_sentencepiece_luts( + sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device +) -> tuple[Tensor, Tensor, Tensor]: + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("▁"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) + + +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*. 
+ tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] + + +def eval_val( + args: Hyperparameters, + model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + grad_accum_steps: int, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, +) -> tuple[float, float]: + # Validation computes two metrics: + # - val_loss: token cross-entropy (natural log) + # - val_bpb: tokenizer-agnostic compression metric used by the challenge + local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) + if local_batch_tokens < args.train_seq_len: + raise ValueError( + "VAL_BATCH_SIZE must provide at least one sequence per rank; " + f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " + f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}" + ) + local_batch_seqs = local_batch_tokens // args.train_seq_len + total_seqs = (val_tokens.numel() - 1) // args.train_seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + + model.eval() + with torch.inference_mode(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * args.train_seq_len + raw_end = batch_seq_end * args.train_seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, args.train_seq_len) + y = local[1:].reshape(-1, args.train_seq_len) + 
maybe_mark_compile_step_begin(args.enable_compile or args.compile_structured_mlp)
+            with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True):
+                batch_loss = model(x, y).detach()
+            batch_token_count = float(y.numel())
+            val_loss_sum += batch_loss.to(torch.float64) * batch_token_count
+            val_token_count += batch_token_count
+            prev_ids = x.reshape(-1)
+            tgt_ids = y.reshape(-1)
+            token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16)
+            token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16)
+            val_byte_count += token_bytes.to(torch.float64).sum()
+
+    if dist.is_available() and dist.is_initialized():
+        dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM)
+        dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM)
+        dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM)
+
+    val_loss = val_loss_sum / val_token_count
+    bits_per_token = val_loss.item() / math.log(2.0)
+    tokens_per_byte = val_token_count.item() / val_byte_count.item()
+    model.train()
+    return float(val_loss.item()), float(bits_per_token * tokens_per_byte)
+
+# -----------------------------
+# POST-TRAINING QUANTIZATION
+# -----------------------------
+#
+# It's silly to export our model, which is trained in bf16 and fp32, at that same precision.
+# Instead, we get approximately the same model (with a small quality hit) by quantizing the model to int8 and zlib-compressing it.
+# We can then decompress the model and run in higher precision for evaluation, after coming in under the size limit.
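The quantize-compress-decompress-evaluate loop described above can be sketched in a few lines. This is a simplified per-row absmax int8 roundtrip, not the script's exact scheme (it omits the percentile clipping and fp16 scale storage used below; the function name is illustrative):

```python
import zlib
import numpy as np

def int8_roundtrip_per_row(w: np.ndarray) -> tuple[bytes, np.ndarray]:
    # One scale per row (absmax / 127): rows with different magnitudes
    # get independent quantization grids.
    scale = np.maximum(np.abs(w).max(axis=1, keepdims=True), 1e-8) / 127.0
    q = np.clip(np.round(w / scale), -127, 127).astype(np.int8)
    # Compress the int8 payload; this blob is what counts toward artifact size.
    blob = zlib.compress(q.tobytes(), level=9)
    # Decompress and dequantize back to float: evaluation runs on these weights.
    deq = np.frombuffer(zlib.decompress(blob), dtype=np.int8).reshape(w.shape).astype(np.float32) * scale
    return blob, deq.astype(np.float32)

rng = np.random.default_rng(0)
w = rng.standard_normal((64, 128)).astype(np.float32)
blob, deq = int8_roundtrip_per_row(w)
max_err = float(np.abs(w - deq).max())
```

Without clipping, the worst-case per-weight error is half a quantization step, i.e. about `row_absmax / 254`; the percentile clipping in the real exporter trades a little outlier accuracy for a smaller step on the bulk of the distribution.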
+ +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights", + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 +INT5_ODD_LEVELS = (-2, -1, 0, 1, 2) + +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) + + +def maybe_mark_compile_step_begin(enabled: bool) -> None: + if not enabled: + return + mark_step = getattr(torch.compiler, "cudagraph_mark_step_begin", None) + if mark_step is not None: + mark_step() + + +def parse_factor_shape(spec: str, dim: int) -> tuple[int, int]: + if spec: + vals = tuple(int(v.strip()) for v in spec.split(",") if v.strip()) + if len(vals) != 2 or vals[0] * vals[1] != dim: + raise ValueError(f"Invalid factor shape '{spec}' for dim={dim}") + return vals + for a in range(int(math.isqrt(dim)), 0, -1): + if dim % a == 0: + return a, dim // a + return 1, dim + +def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: + if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): + return t.float().contiguous() + if t.dtype in {torch.float32, torch.bfloat16}: + passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") + return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() + return t + +def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + # Matrices get one scale per row, which usually tracks output-channel + # ranges much better than a single tensor-wide scale. 
+ clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) + q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() + return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + + # Vectors / scalars use a simpler per-tensor scale. + clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() + return q, scale + + +def pack_base5_codes(q: Tensor) -> tuple[bytes, int]: + vals = q.detach().to("cpu", dtype=torch.uint8).reshape(-1).numpy() + n_vals = int(vals.size) + pad = (3 - (n_vals % 3)) % 3 + if pad: + vals = np.pad(vals, (0, pad)) + groups = vals.reshape(-1, 3).astype(np.uint8) + packed = groups[:, 0] + 5 * groups[:, 1] + 25 * groups[:, 2] + return packed.tobytes(), n_vals + + +def unpack_base5_codes(data: bytes, n_vals: int) -> Tensor: + packed = np.frombuffer(data, dtype=np.uint8).astype(np.int16) + out = np.zeros((packed.size, 3), dtype=np.uint8) + for i in range(3): + out[:, i] = packed % 5 + packed //= 5 + flat = out.reshape(-1)[:n_vals] + return torch.from_numpy(flat.astype(np.uint8, copy=False)) + + +def quantize_float_tensor_int5_odd(t: Tensor) -> tuple[Tensor, Tensor]: + if t.ndim != 2: + raise ValueError("int5_odd quantization only supports 2D tensors") + t32 = t.float() + clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clip_abs = clip_abs.clamp_min(1e-6) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 
2.0).clamp_min(1.0 / 2.0e6) + q = torch.clamp(torch.round(clipped / scale[:, None]), -2, 2).to(torch.int8).contiguous() + return (q + 2).to(torch.uint8).contiguous(), scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + +def quantize_state_dict_int8(state_dict: dict[str, Tensor]): + # Single supported clean-script export format: + # - per-row int8 for 2D float tensors + # - per-tensor int8 for other float tensors + # - exact passthrough for non-floats + # - passthrough for small float tensors, stored as fp16 to save bytes + quantized: dict[str, Tensor] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + qmeta: dict[str, dict[str, object]] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + + # Small float tensors are cheap enough to keep directly. We still downcast + # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size. 
+ if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + + stats["num_float_tensors"] += 1 + q, s = quantize_float_tensor(t) + if s.ndim > 0: + qmeta[name] = {"scheme": "per_row", "axis": 0} + quantized[name] = q + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) + + obj: dict[str, object] = { + "__quant_format__": "int8_clean_per_row_v1", + "quantized": quantized, + "scales": scales, + "dtypes": dtypes, + "passthrough": passthrough, + } + if qmeta: + obj["qmeta"] = qmeta + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats + + +def quantize_state_dict_int5_odd(state_dict: dict[str, Tensor]): + packed: dict[str, bytes] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + shapes: dict[str, tuple[int, ...]] = {} + orig_shapes: dict[str, tuple[int, ...]] = {} + n_codes: dict[str, int] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + + quantize_this = t.ndim >= 2 and (t.numel() > INT8_KEEP_FLOAT_MAX_NUMEL or ".mlp." 
in name) + if not quantize_this: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + + stats["num_float_tensors"] += 1 + t2 = t.reshape(t.shape[0], -1) if t.ndim > 2 else t + q, s = quantize_float_tensor_int5_odd(t2) + packed_bytes, count = pack_base5_codes(q) + packed[name] = packed_bytes + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + shapes[name] = tuple(int(v) for v in t2.shape) + orig_shapes[name] = tuple(int(v) for v in t.shape) + n_codes[name] = count + stats["int8_payload_bytes"] += len(packed_bytes) + tensor_nbytes(s) + + obj: dict[str, object] = { + "__quant_format__": "int5_odd_clean_per_row_v1", + "packed": packed, + "scales": scales, + "dtypes": dtypes, + "shapes": shapes, + "orig_shapes": orig_shapes, + "n_codes": n_codes, + "passthrough": passthrough, + } + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats + +def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + qmeta = obj.get("qmeta", {}) + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, q in obj["quantized"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + s = obj["scales"][name] + if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: + s = s.to(dtype=torch.float32) + # Broadcast the saved row scale back across trailing dimensions. + out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() + else: + scale = float(s.item()) + out[name] = (q.float() * scale).to(dtype=dtype).contiguous() + for name, t in obj["passthrough"].items(): + # Restore small tensors, undoing the temporary fp16 storage cast if needed. 
+ out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out + + +def dequantize_state_dict_int5_odd(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, packed in obj["packed"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + shape = tuple(int(v) for v in obj["shapes"][name]) + orig_shape = tuple(int(v) for v in obj.get("orig_shapes", {}).get(name, shape)) + q = unpack_base5_codes(packed, int(obj["n_codes"][name])).to(torch.float32).reshape(shape) - 2.0 + s = obj["scales"][name].to(dtype=torch.float32) + deq = (q * s.view(shape[0], *([1] * (len(shape) - 1)))).to(dtype=dtype) + out[name] = deq.reshape(orig_shape).contiguous() + for name, t in obj["passthrough"].items(): + out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out + + +def quantize_state_dict(state_dict: dict[str, Tensor], quant_scheme: str): + if quant_scheme == "int8": + return quantize_state_dict_int8(state_dict) + if quant_scheme == "int5_odd": + return quantize_state_dict_int5_odd(state_dict) + raise ValueError(f"Unsupported QUANT_SCHEME={quant_scheme}") + + +def dequantize_state_dict(obj: dict[str, object]) -> dict[str, Tensor]: + quant_format = obj.get("__quant_format__") + if quant_format == "int8_clean_per_row_v1": + return dequantize_state_dict_int8(obj) + if quant_format == "int5_odd_clean_per_row_v1": + return dequantize_state_dict_int5_odd(obj) + raise ValueError(f"Unsupported quant format {quant_format}") + + +def compress_blob(data: bytes, compressor: str, level: int) -> bytes: + if compressor == "zlib": + return zlib.compress(data, level=max(0, 
min(level, 9))) + if compressor == "zstd": + if not _HAS_ZSTD: + raise RuntimeError("COMPRESSOR=zstd requires zstandard to be installed") + return zstd.ZstdCompressor(level=level).compress(data) + raise ValueError(f"Unsupported COMPRESSOR={compressor}") + + +def decompress_blob(data: bytes, compressor: str) -> bytes: + if compressor == "zlib": + return zlib.decompress(data) + if compressor == "zstd": + if not _HAS_ZSTD: + raise RuntimeError("COMPRESSOR=zstd requires zstandard to be installed") + return zstd.ZstdDecompressor().decompress(data) + raise ValueError(f"Unsupported COMPRESSOR={compressor}") + + +def should_rate_regularize_tensor(name: str, p: Tensor, target: str) -> bool: + if not p.requires_grad or not p.is_floating_point() or p.ndim < 2: + return False + if any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS): + return False + if target == "mlp_only": + return ".mlp." in name + if target == "all_matrices": + return p.numel() > INT8_KEEP_FLOAT_MAX_NUMEL or ".mlp." in name + raise ValueError(f"Unsupported RATE_TARGET_TENSORS={target}") + + +def estimate_tensor_rate_proxy(t: Tensor, temperature: float) -> tuple[Tensor, Tensor, Tensor]: + t2 = t.float() + if t2.ndim > 2: + t2 = t2.reshape(t2.shape[0], -1) + elif t2.ndim == 1: + t2 = t2.reshape(1, -1) + row_scale = t2.abs().mean(dim=1, keepdim=True).clamp_min(1e-6) + normalized = torch.clamp(t2 / row_scale, -2.5, 2.5) + centers = t2.new_tensor(INT5_ODD_LEVELS) + logits = -temperature * (normalized.unsqueeze(-1) - centers) ** 2 + probs = torch.softmax(logits, dim=-1) + hist = probs.mean(dim=(0, 1)) + entropy_nats = -(hist * (hist + 1e-8).log()).sum() + entropy_bits = entropy_nats / math.log(2.0) + scale_proxy = row_scale.mean() + est_bits = entropy_bits * float(t2.numel()) + return entropy_bits, scale_proxy, est_bits + + +# ----------------------------- +# DATA LOADING +# ----------------------------- + +def load_data_shard(file: Path) -> Tensor: + header_bytes = 256 * np.dtype(" None: + 
self.file_idx = (self.file_idx + 1) % len(self.files) + self.tokens = load_data_shard(self.files[self.file_idx]) + self.pos = 0 + + def take(self, n: int) -> Tensor: + chunks: list[Tensor] = [] + remaining = n + while remaining > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: + self._advance_file() + continue + k = min(remaining, avail) + chunks.append(self.tokens[self.pos : self.pos + k]) + self.pos += k + remaining -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) + + +class DistributedTokenLoader: + # Each call consumes a contiguous chunk from the shared token stream, then slices out + # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting. + def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank = rank + self.world_size = world_size + self.device = device + self.stream = TokenStream(pattern) + + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + chunk = self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start : start + per_rank_span].to(dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + +# ----------------------------- +# TRANSFORMER MODULES +# ----------------------------- + +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + + +class CastedLinear(nn.Linear): + # Keep weights in fp32 for optimizer/state quality, cast at matmul time for bf16 compute. 
+ def forward(self, x: Tensor) -> Tensor: + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, self.weight.to(x.dtype), bias) + + +class StructuredLinear(nn.Module): + # Lightweight TT-style structured matrix used only in the MLP stack. + def __init__( + self, + in_features: int, + out_features: int, + bias: bool, + linear_type: str, + btt_rank: int, + in_shape_spec: str, + out_shape_spec: str, + btt_init: str, + btt_init_gain: float, + ): + super().__init__() + self.in_features = in_features + self.out_features = out_features + self.linear_type = linear_type + self.rank = max(int(btt_rank), 1) + self.btt_init = btt_init + self.btt_init_gain = btt_init_gain + if linear_type == "dense": + self.impl = CastedLinear(in_features, out_features, bias=bias) + self.register_parameter("core1", None) + self.register_parameter("core2", None) + return + if linear_type != "btt_mlp": + raise ValueError(f"Unsupported STRUCTURED_LINEAR_TYPE={linear_type}") + self.impl = None + self.in_shape = parse_factor_shape(in_shape_spec, in_features) + self.out_shape = parse_factor_shape(out_shape_spec, out_features) + n1, n2 = self.in_shape + m1, m2 = self.out_shape + self.core1 = nn.Parameter(torch.empty(n1, m1, self.rank)) + self.core2 = nn.Parameter(torch.empty(self.rank, n2, m2)) + if bias: + self.bias = nn.Parameter(torch.zeros(out_features)) + else: + self.register_parameter("bias", None) + self.reset_parameters() + + def reset_parameters(self) -> None: + if self.impl is not None: + self.impl.reset_parameters() + return + n1, n2 = self.in_shape + if self.btt_init == "mup": + # muP-inspired scaling: keep the per-rank branch update stable while + # preserving O(1) output variance under the sum over TT ranks. 
+ core1_std = self.btt_init_gain / math.sqrt(max(n1, 1)) + core2_std = self.btt_init_gain / math.sqrt(max(self.rank * n2, 1)) + elif self.btt_init == "xavier": + std = self.btt_init_gain / math.sqrt(max(self.in_features, 1)) + core1_std = std / math.sqrt(self.rank) + core2_std = std + else: + raise ValueError(f"Unsupported BTT_INIT={self.btt_init}") + nn.init.normal_(self.core1, mean=0.0, std=core1_std) + nn.init.normal_(self.core2, mean=0.0, std=core2_std) + if self.bias is not None: + nn.init.zeros_(self.bias) + + def forward(self, x: Tensor) -> Tensor: + if self.impl is not None: + return self.impl(x) + x_dtype = x.dtype + orig_shape = x.shape[:-1] + x2 = x.reshape(-1, self.in_shape[0], self.in_shape[1]) + core1 = self.core1.to(dtype=x_dtype) + core2 = self.core2.to(dtype=x_dtype) + y = x2.new_zeros((x2.shape[0], self.out_shape[0], self.out_shape[1])) + for r in range(self.rank): + left = torch.einsum("buv,um->bmv", x2, core1[:, :, r]) + y = y + torch.einsum("bmv,vn->bmn", left, core2[r]) + y = y.reshape(*orig_shape, self.out_features) + if self.bias is not None: + y = y + self.bias.to(dtype=y.dtype) + return y + + +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + # Keep small/control parameters in fp32 even when the model body runs in bf16. + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() + + +class Rotary(nn.Module): + # Caches cos/sin tables per sequence length on the current device. 
+ def __init__(self, dim: int, base: float = 10000.0): + super().__init__() + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached != seq_len + or self._cos_cached.device != device + ): + t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) + freqs = torch.outer(t, self.inv_freq.to(device)) + self._cos_cached = freqs.cos()[None, None, :, :] + self._sin_cached = freqs.sin()[None, None, :, :] + self._seq_len_cached = seq_len + cos = self._cos_cached.to(dtype=dtype) + sin = self._sin_cached.to(dtype=dtype) + if not torch.is_inference_mode_enabled(): + cos = cos.clone() + sin = sin.clone() + return cos, sin + + +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + + +class CausalSelfAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + kv_dim = self.num_kv_heads * self.head_dim + self.c_q = CastedLinear(dim, dim, bias=False) + self.c_k = CastedLinear(dim, kv_dim, bias=False) + self.c_v = CastedLinear(dim, kv_dim, bias=False) + self.proj 
= CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rotary = Rotary(self.head_dim, base=rope_base) + + def forward(self, x: Tensor) -> Tensor: + bsz, seqlen, dim = x.shape + q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) + k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin) + k = apply_rotary_emb(k, cos, sin) + q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] + y = F.scaled_dot_product_attention( + q, + k, + v, + attn_mask=None, + is_causal=True, + enable_gqa=(self.num_kv_heads != self.num_heads), + ) + y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) + return self.proj(y) + + +class MLP(nn.Module): + # relu^2 MLP from the original modded-nanogpt setup + def __init__( + self, + dim: int, + mlp_mult: int, + structured_mlp: bool, + structured_linear_type: str, + btt_rank: int, + btt_in_shape: str, + btt_out_shape: str, + btt_init: str, + btt_init_gain: float, + compile_structured_mlp: bool, + ): + super().__init__() + hidden = mlp_mult * dim + if structured_mlp: + self.fc = StructuredLinear( + dim, + hidden, + bias=False, + linear_type=structured_linear_type, + btt_rank=btt_rank, + in_shape_spec=btt_in_shape, + out_shape_spec=btt_out_shape, + btt_init=btt_init, + btt_init_gain=btt_init_gain, + ) + self.proj = StructuredLinear( + hidden, + dim, + bias=False, + linear_type=structured_linear_type, + btt_rank=btt_rank, + in_shape_spec=btt_out_shape, + out_shape_spec=btt_in_shape, + btt_init=btt_init, + btt_init_gain=btt_init_gain, + ) + if compile_structured_mlp: + compile_options = 
dict(torch._inductor.list_mode_options("reduce-overhead")) + compile_options["triton.cudagraphs"] = False + self.fc = torch.compile( + self.fc, + options=compile_options, + fullgraph=False, + dynamic=False, + ) + self.proj = torch.compile( + self.proj, + options=compile_options, + fullgraph=False, + dynamic=False, + ) + else: + self.fc = CastedLinear(dim, hidden, bias=False) + self.proj = CastedLinear(hidden, dim, bias=False) + self.proj._zero_init = True + + def forward(self, x: Tensor) -> Tensor: + x = torch.relu(self.fc(x)) + return self.proj(x.square()) + + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + rope_base: float, + qk_gain_init: float, + structured_mlp: bool, + structured_linear_type: str, + btt_rank: int, + btt_in_shape: str, + btt_out_shape: str, + btt_init: str, + btt_init_gain: float, + compile_structured_mlp: bool, + ): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) + self.mlp = MLP( + dim, + mlp_mult, + structured_mlp, + structured_linear_type, + btt_rank, + btt_in_shape, + btt_out_shape, + btt_init, + btt_init_gain, + compile_structured_mlp, + ) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + + def forward(self, x: Tensor, x0: Tensor) -> Tensor: + mix = self.resid_mix.to(dtype=x.dtype) + x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + attn_out = self.attn(self.attn_norm(x)) + x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out + x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) + return x + + +class GPT(nn.Module): + def __init__( + self, + vocab_size: int, + num_layers: int, + model_dim: int, + num_heads: int, + 
num_kv_heads: int, + mlp_mult: int, + tie_embeddings: bool, + tied_embed_init_std: float, + logit_softcap: float, + rope_base: float, + qk_gain_init: float, + structured_mlp: bool, + structured_linear_type: str, + btt_rank: int, + btt_in_shape: str, + btt_out_shape: str, + btt_init: str, + btt_init_gain: float, + compile_structured_mlp: bool, + rate_target_tensors: str, + rate_temperature: float, + ): + super().__init__() + if logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") + self.tie_embeddings = tie_embeddings + self.tied_embed_init_std = tied_embed_init_std + self.logit_softcap = logit_softcap + self.rate_target_tensors = rate_target_tensors + self.rate_temperature = rate_temperature + self.tok_emb = nn.Embedding(vocab_size, model_dim) + self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + self.blocks = nn.ModuleList( + [ + Block( + model_dim, + num_heads, + num_kv_heads, + mlp_mult, + rope_base, + qk_gain_init, + structured_mlp, + structured_linear_type, + btt_rank, + btt_in_shape, + btt_out_shape, + btt_init, + btt_init_gain, + compile_structured_mlp, + ) + for i in range(num_layers) + ] + ) + self.final_norm = RMSNorm() + self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) + if self.lm_head is not None: + self.lm_head._zero_init = True + self._init_weights() + + def _init_weights(self) -> None: + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + for module in self.modules(): + if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + if isinstance(module, StructuredLinear) and getattr(module, "_zero_init", False): + if module.impl is 
not None: + nn.init.zeros_(module.impl.weight) + else: + nn.init.zeros_(module.core2) + + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + x = self.tok_emb(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x0 = x + skips: list[Tensor] = [] + + # First half stores skips; second half reuses them in reverse order. + for i in range(self.num_encoder_layers): + x = self.blocks[i](x, x0) + skips.append(x) + for i in range(self.num_decoder_layers): + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + x = self.blocks[self.num_encoder_layers + i](x, x0) + + x = self.final_norm(x).reshape(-1, x.size(-1)) + targets = target_ids.reshape(-1) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(x) + logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + return F.cross_entropy(logits.float(), targets, reduction="mean") + + def rate_regularizer(self) -> tuple[Tensor, Tensor, Tensor]: + total_symbols = 0.0 + weighted_rate = None + weighted_scale = None + estimated_bits = None + for name, param in self.named_parameters(): + if not should_rate_regularize_tensor(name, param, self.rate_target_tensors): + continue + rate_proxy, scale_proxy, tensor_bits = estimate_tensor_rate_proxy(param, self.rate_temperature) + weight = float(param.numel()) + total_symbols += weight + if weighted_rate is None: + weighted_rate = rate_proxy * weight + weighted_scale = scale_proxy * weight + estimated_bits = tensor_bits + else: + weighted_rate = weighted_rate + rate_proxy * weight + weighted_scale = weighted_scale + scale_proxy * weight + estimated_bits = estimated_bits + tensor_bits + if weighted_rate is None or weighted_scale is None or estimated_bits is None: + zero = self.tok_emb.weight.new_zeros(()) + return zero, zero, zero + denom = max(total_symbols, 1.0) + return 
weighted_rate / denom, weighted_scale / denom, estimated_bits + + +# ----------------------------- +# TRAINING +# ----------------------------- + +def main() -> None: + global zeropower_via_newtonschulz5 + + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + if args.enable_compile: + zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) + + # ----------------------------- + # DISTRIBUTED + CUDA SETUP + # ----------------------------- + + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if 8 % world_size != 0: + raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") + grad_accum_steps = 8 // world_size + grad_scale = 1.0 / grad_accum_steps + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 0 + + # Fast math knobs + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + + logfile = None + if master_process: + if args.run_id == "train": + logfile = "train.log" + else: + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" + print(logfile) + + def log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + 
print(msg, file=f) + + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch {torch.__version__}", console=False) + log0( + subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, + console=False, + ) + log0("=" * 100, console=False) + + # ----------------------------- + # TOKENIZER + VALIDATION METRIC SETUP + # ----------------------------- + + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) + if args.val_token_limit > 0: + usable = min(((args.val_token_limit) // args.train_seq_len) * args.train_seq_len, val_tokens.numel() - 1) + if usable <= 0: + raise ValueError(f"VAL_TOKEN_LIMIT={args.val_token_limit} is too small for TRAIN_SEQ_LEN={args.train_seq_len}") + val_tokens = val_tokens[: usable + 1].contiguous() + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + if args.val_token_limit > 0: + log0(f"val_token_limit:{args.val_token_limit}") + + # 
----------------------------- + # MODEL + OPTIMIZER SETUP + # ----------------------------- + + base_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, + structured_mlp=args.structured_mlp, + structured_linear_type=args.structured_linear_type if args.structured_mlp else "dense", + btt_rank=args.btt_rank, + btt_in_shape=args.btt_in_shape, + btt_out_shape=args.btt_out_shape, + btt_init=args.btt_init, + btt_init_gain=args.btt_init_gain, + compile_structured_mlp=args.compile_structured_mlp, + rate_target_tensors=args.rate_target_tensors, + rate_temperature=args.rate_temperature, + ).to(device).bfloat16() + for module in base_model.modules(): + if isinstance(module, (CastedLinear, StructuredLinear)): + module.float() + restore_low_dim_params_to_fp32(base_model) + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) if args.enable_compile else base_model + model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + + # Optimizer split: + # - token embedding (Adam) uses EMBED_LR + # - untied lm_head (Adam) uses HEAD_LR + # - matrix params in transformer blocks use MATRIX_LR via Muon + # - vectors/scalars use SCALAR_LR via Adam + block_named_params = list(base_model.blocks.named_parameters()) + matrix_params = [ + p + for name, p in block_named_params + if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + scalar_params = [ + p + for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.skip_weights.numel() > 0: + 
scalar_params.append(base_model.skip_weights) + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + optimizer_tok = torch.optim.Adam( + [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizer_muon = Muon( + matrix_params, + lr=args.matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + ) + for group in optimizer_muon.param_groups: + group["base_lr"] = args.matrix_lr + optimizer_scalar = torch.optim.Adam( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers.insert(1, optimizer_head) + + n_params = sum(p.numel() for p in base_model.parameters()) + log0(f"model_params:{n_params}") + log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") + log0(f"enable_compile:{args.enable_compile}") + log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") + log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") + log0( + f"quant_scheme:{args.quant_scheme} compressor:{args.compressor} level:{args.compressor_level} " + f"structured_mlp:{args.structured_mlp} structured_linear_type:{args.structured_linear_type} " + f"btt_rank:{args.btt_rank} btt_init:{args.btt_init} compile_structured_mlp:{args.compile_structured_mlp}" + ) + log0( + f"rate_lambda:{args.rate_lambda} scale_lambda:{args.scale_lambda} " + f"rate_warmup_start_fraction:{args.rate_warmup_start_fraction} rate_full_fraction:{args.rate_full_fraction} " + 
f"rate_target_tensors:{args.rate_target_tensors}" + ) + log0( + f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " + f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " + f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" + ) + log0( + f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " + f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " + f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log0(f"seed:{args.seed}") + + # ----------------------------- + # DATA LOADER & MODEL WARMUP + # ----------------------------- + + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + + def lr_mul(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + + def rate_schedule(step: int) -> float: + if args.rate_lambda <= 0.0 and args.scale_lambda <= 0.0: + return 0.0 + if args.iterations <= 0: + return 1.0 + frac = step / max(args.iterations, 1) + start = args.rate_warmup_start_fraction + full = max(args.rate_full_fraction, start + 1e-6) + if frac <= start: + return 0.0 + if frac >= full: + return 1.0 + return (frac - start) / (full - start) + + # Warmup primes the compiled forward/backward/optimizer paths, then we restore the + # initial weights/optimizer state so measured training starts from 
the true init. + if args.warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(args.warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + maybe_mark_compile_step_begin(args.enable_compile or args.compile_structured_mlp) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + for opt in optimizers: + opt.step() + zero_grad_all() + if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: + log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + if distributed: + model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + # ----------------------------- + # MAIN TRAINING LOOP + # ----------------------------- + + training_time_ms = 0.0 + stop_after_step: int | None = None + torch.cuda.synchronize() + t0 = time.perf_counter() + + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + 
base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + log0( + f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" + ) + break + + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + zero_grad_all() + train_loss = torch.zeros((), device=device) + train_ce_loss = torch.zeros((), device=device) + train_rate_proxy = torch.zeros((), device=device) + train_scale_proxy = torch.zeros((), device=device) + train_est_bits = torch.zeros((), device=device) + rate_mul = rate_schedule(step) + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + maybe_mark_compile_step_begin(args.enable_compile or args.compile_structured_mlp) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + ce_loss = model(x, y) + total_loss = ce_loss + if rate_mul > 0.0: + rate_proxy, scale_proxy, est_bits = base_model.rate_regularizer() + total_loss = total_loss + rate_mul * (args.rate_lambda * rate_proxy + args.scale_lambda * scale_proxy) + train_rate_proxy += rate_proxy.detach() + train_scale_proxy += scale_proxy.detach() + train_est_bits += est_bits.detach() + train_ce_loss += ce_loss.detach() + train_loss += total_loss.detach() + (total_loss * grad_scale).backward() + train_loss /= grad_accum_steps + train_ce_loss /= grad_accum_steps + train_rate_proxy /= grad_accum_steps + train_scale_proxy /= grad_accum_steps + train_est_bits /= grad_accum_steps + + frac = min(step / 
args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + for opt in optimizers: + opt.step() + zero_grad_all() + + step += 1 + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + should_log_train = ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + log0( + f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} ce_loss:{train_ce_loss.item():.4f} " + f"rate_mul:{rate_mul:.3f} rate_proxy_bits:{train_rate_proxy.item():.4f} " + f"scale_proxy:{train_scale_proxy.item():.6f} est_model_bits:{train_est_bits.item():.0f} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" + ) + + # Needed to sync whether we've reached the wallclock cap. 
+ reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + + log0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + ) + + # ----------------------------- + # SERIALIZATION + ROUNDTRIP VALIDATION + # ----------------------------- + # Save the raw state (useful for debugging/loading in PyTorch directly), then always produce + # the compressed int8+zlib artifact and validate the round-tripped weights. + + if master_process: + torch.save(base_model.state_dict(), "final_model.pt") + model_bytes = os.path.getsize("final_model.pt") + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model: {model_bytes} bytes") + log0(f"Code size: {code_bytes} bytes") + log0(f"Total submission size: {model_bytes + code_bytes} bytes") + + quant_obj, quant_stats = quantize_state_dict(base_model.state_dict(), args.quant_scheme) + quant_buf = io.BytesIO() + torch.save(quant_obj, quant_buf) + quant_raw = quant_buf.getvalue() + quant_blob = compress_blob(quant_raw, args.compressor, args.compressor_level) + quant_raw_bytes = len(quant_raw) + quant_filename = f"final_model.{args.quant_scheme}.{args.compressor}.ptc" + if master_process: + with open(quant_filename, "wb") as f: + f.write(quant_blob) + quant_file_bytes = os.path.getsize(quant_filename) + code_bytes = len(code.encode("utf-8")) + ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1) + log0( + f"Serialized model {args.quant_scheme}+{args.compressor}: {quant_file_bytes} bytes " + f"(payload:{quant_stats['int8_payload_bytes']} raw_torch:{quant_raw_bytes} 
payload_ratio:{ratio:.2f}x)" + ) + log0(f"Total submission size {args.quant_scheme}+{args.compressor}: {quant_file_bytes + code_bytes} bytes") + + if distributed: + dist.barrier() + with open(quant_filename, "rb") as f: + quant_blob_disk = f.read() + quant_state = torch.load(io.BytesIO(decompress_blob(quant_blob_disk, args.compressor)), map_location="cpu") + base_model.load_state_dict(dequantize_state_dict(quant_state), strict=True) + torch.cuda.synchronize() + t_qeval = time.perf_counter() + q_val_loss, q_val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"final_{args.quant_scheme}_{args.compressor}_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" + ) + log0(f"final_{args.quant_scheme}_{args.compressor}_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + + if distributed: + dist.destroy_process_group() + + +if __name__ == "__main__": + main() + +==================================================================================================== +Running Python 3.10.12 (main, Mar 3 2026, 11:56:32) [GCC 11.4.0] +Running PyTorch 2.11.0+cu130 +Tue Mar 31 12:18:46 2026 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 590.48.01 Driver Version: 590.48.01 CUDA Version: 13.1 | ++-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. 
| +|=========================================+========================+======================| +| 0 NVIDIA GeForce RTX 3060 Off | 00000000:01:00.0 On | N/A | +| 30% 36C P8 20W / 170W | 242MiB / 12288MiB | 24% Default | +| | | N/A | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| 0 N/A N/A 985 G /usr/lib/xorg/Xorg 117MiB | +| 0 N/A N/A 4780 C python3 104MiB | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=/home/hz/Desktop/parameter-golf-main/data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:1 +val_loader:shards pattern=/home/hz/Desktop/parameter-golf-main/data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:65536 +val_token_limit:65536 +model_params:8507464 +world_size:1 grad_accum_steps:8 +enable_compile:False +sdp_backends:cudnn=False flash=True mem_efficient=False math=False +attention_mode:gqa num_heads:8 num_kv_heads:4 +quant_scheme:int5_odd compressor:zlib level:9 structured_mlp:True structured_linear_type:btt_mlp btt_rank:32 btt_init:xavier compile_structured_mlp:True +rate_lambda:2e-05 scale_lambda:0.0002 rate_warmup_start_fraction:0.1 rate_full_fraction:0.4 rate_target_tensors:mlp_only +tie_embeddings:True embed_lr:0.03 head_lr:0.0 matrix_lr:0.02 scalar_lr:0.02 +train_batch_tokens:4096 train_seq_len:256 iterations:12 warmup_steps:1 max_wallclock_seconds:600.000 +seed:1337 +warmup_step:1/1 +step:1/12 train_loss:6.9357 ce_loss:6.9357 rate_mul:0.000 rate_proxy_bits:0.0000 scale_proxy:0.000000 
est_model_bits:0 train_time:1589ms step_avg:1588.73ms +step:2/12 train_loss:14.0267 ce_loss:14.0267 rate_mul:0.000 rate_proxy_bits:0.0000 scale_proxy:0.000000 est_model_bits:0 train_time:3180ms step_avg:1589.78ms +step:3/12 train_loss:13.1840 ce_loss:13.1840 rate_mul:0.222 rate_proxy_bits:1.4980 scale_proxy:0.013530 est_model_bits:1325359 train_time:5254ms step_avg:1751.37ms +step:4/12 train_loss:11.5516 ce_loss:11.5516 rate_mul:0.500 rate_proxy_bits:1.4980 scale_proxy:0.013530 est_model_bits:1325359 train_time:7239ms step_avg:1809.81ms +step:5/12 train_loss:9.8005 ce_loss:9.8005 rate_mul:0.778 rate_proxy_bits:1.4980 scale_proxy:0.013530 est_model_bits:1325359 train_time:9200ms step_avg:1839.91ms +step:6/12 train_loss:9.0405 ce_loss:9.0405 rate_mul:1.000 rate_proxy_bits:1.4980 scale_proxy:0.013530 est_model_bits:1325359 train_time:11167ms step_avg:1861.10ms +step:7/12 train_loss:8.0618 ce_loss:8.0618 rate_mul:1.000 rate_proxy_bits:1.4980 scale_proxy:0.013530 est_model_bits:1325359 train_time:13158ms step_avg:1879.72ms +step:8/12 train_loss:7.3487 ce_loss:7.3487 rate_mul:1.000 rate_proxy_bits:1.4980 scale_proxy:0.013530 est_model_bits:1325359 train_time:15139ms step_avg:1892.31ms +step:9/12 train_loss:6.8810 ce_loss:6.8809 rate_mul:1.000 rate_proxy_bits:1.4980 scale_proxy:0.013530 est_model_bits:1325359 train_time:17121ms step_avg:1902.34ms +step:10/12 train_loss:6.8680 ce_loss:6.8680 rate_mul:1.000 rate_proxy_bits:1.4980 scale_proxy:0.013530 est_model_bits:1325359 train_time:19105ms step_avg:1910.46ms +step:12/12 val_loss:6.3586 val_bpb:3.7596 train_time:23050ms step_avg:1920.86ms +peak memory allocated: 736 MiB reserved: 1422 MiB +Serialized model: 33022259 bytes +Code size: 69179 bytes +Total submission size: 33091438 bytes +Serialized model int5_odd+zlib: 2135864 bytes (payload:2942927 raw_torch:2983011 payload_ratio:11.21x) +Total submission size int5_odd+zlib: 2205043 bytes +final_int5_odd_zlib_roundtrip val_loss:6.7733 val_bpb:4.0048 eval_time:1864ms 
+final_int5_odd_zlib_roundtrip_exact val_loss:6.77326751 val_bpb:4.00476535 diff --git a/records/track_non_record_16mb/2026-03-31_EntropyAware_Int5Odd_BTTMLP_3060/logs/lambda_0p00002.txt b/records/track_non_record_16mb/2026-03-31_EntropyAware_Int5Odd_BTTMLP_3060/logs/lambda_0p00002.txt new file mode 100644 index 000000000..36e55e21a --- /dev/null +++ b/records/track_non_record_16mb/2026-03-31_EntropyAware_Int5Odd_BTTMLP_3060/logs/lambda_0p00002.txt @@ -0,0 +1,1679 @@ +""" +The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good launching-off points for new participants, not SOTA configs. We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder. + +Hard stop: To keep things readable for newcomers, let's make sure `train_gpt.py` and `train_gpt_mlx.py` are never longer than 1500 lines. +""" + +from __future__ import annotations + +import copy +import glob +import io +import math +import os +import random +import subprocess +import sys +import time +import uuid +import zlib +from pathlib import Path + +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP + +try: + import zstandard as zstd + + _HAS_ZSTD = True +except ImportError: + zstd = None + _HAS_ZSTD = False + +REPO_ROOT = Path(__file__).resolve().parents[3] + +# ----------------------------- +# HYPERPARAMETERS +# ----------------------------- +# Default Simple Baseline run: +# - 9 transformer blocks at width 512 +# - 8 attention heads with 4 KV heads (GQA) and 2x MLP expansion +# - vocab size 1024, sequence length 1024, tied embeddings +# - 524,288 train tokens per step for 20,000 iterations with a ~10 minute cap + +class Hyperparameters: + # Data paths are shard globs produced by the existing preprocessing pipeline. 
+ data_path = os.environ.get("DATA_PATH", str(REPO_ROOT / "data/datasets/fineweb10B_sp1024")) + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", str(REPO_ROOT / "data/tokenizers/fineweb_1024_bpe.model")) + run_id = os.environ.get("RUN_ID", "train") + seed = int(os.environ.get("SEED", 1337)) + + # Validation cadence and batch size. Validation always uses the full fineweb_val split. + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 262_144)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 10)) + val_token_limit = int(os.environ.get("VAL_TOKEN_LIMIT", 1_048_576)) + + # Training length. + iterations = int(os.environ.get("ITERATIONS", 40)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 1200)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 0)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 16_384)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 256)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + enable_compile = bool(int(os.environ.get("ENABLE_COMPILE", "0"))) + + # Model shape. + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers = int(os.environ.get("NUM_LAYERS", 9)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = int(os.environ.get("MLP_MULT", 2)) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + + # Optimizer hyperparameters. 
+ embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.03)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.02)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.02)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.0)) + + # Export / compression. + quant_scheme = os.environ.get("QUANT_SCHEME", "int5_odd") + compressor = os.environ.get("COMPRESSOR", "zlib") + compressor_level = int(os.environ.get("COMPRESSOR_LEVEL", 9)) + + # Entropy-aware rate regularization. + rate_lambda = float(os.environ.get("RATE_LAMBDA", 0.002)) + scale_lambda = float(os.environ.get("SCALE_LAMBDA", 0.0002)) + rate_temperature = float(os.environ.get("RATE_TEMPERATURE", 6.0)) + rate_warmup_start_fraction = float(os.environ.get("RATE_WARMUP_START_FRACTION", 0.1)) + rate_full_fraction = float(os.environ.get("RATE_FULL_FRACTION", 0.4)) + rate_target_tensors = os.environ.get("RATE_TARGET_TENSORS", "mlp_only") + + # Structured MLP. 
+ structured_mlp = bool(int(os.environ.get("STRUCTURED_MLP", "1"))) + structured_linear_type = os.environ.get("STRUCTURED_LINEAR_TYPE", "btt_mlp") + btt_rank = int(os.environ.get("BTT_RANK", 32)) + btt_in_shape = os.environ.get("BTT_IN_SHAPE", "") + btt_out_shape = os.environ.get("BTT_OUT_SHAPE", "") + btt_init = os.environ.get("BTT_INIT", "mup") + btt_init_gain = float(os.environ.get("BTT_INIT_GAIN", 1.0)) + compile_structured_mlp = bool(int(os.environ.get("COMPILE_STRUCTURED_MLP", "1"))) + +# ----------------------------- +# MUON OPTIMIZER +# ----------------------------- +# +# As borrowed from modded-nanogpt +# Background on Muon: https://kellerjordan.github.io/posts/muon/ + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + # Orthogonalize a 2D update matrix with a fast Newton-Schulz iteration. + # Muon uses this to normalize matrix-shaped gradients before applying them. + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X + + +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True): + super().__init__( + params, + dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov), + ) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + + for group in self.param_groups: + params = group["params"] + if not params: + continue + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + + total_params = 
sum(int(p.numel()) for p in params) + updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) + + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: + g = g.add(buf, alpha=momentum) + g = zeropower_via_newtonschulz5(g, steps=backend_steps) + # Scale correction from Muon reference implementations. + g *= max(1, g.size(0) / g.size(1)) ** 0.5 + updates_flat[curr : curr + p.numel()] = g.reshape(-1) + curr += p.numel() + + if distributed: + dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) + + curr = 0 + for p in params: + g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) + p.add_(g, alpha=-lr) + curr += p.numel() + + return loss + + + # ----------------------------- + # TOKENIZER-AGNOSTIC EVALUATION SETUP + # ----------------------------- + # + # It's common for small models to have a large fraction of their parameters be embeddings, since the 2 * d_model * d_vocab embedding parameters can be gigantic. + # Instead of locking the tokenizer, we let you bring your own and calculate our validation metrics on the average compression of the validation set. + # We calculate BPB (bits-per-byte) instead of validation loss, so we need methods to count the number of bytes per token in the tokenizer. + # Note: Submissions that edit the tokenizer will be examined more carefully, since screwing this up might unjustly improve your score. 
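The BPB bookkeeping described in the comment block above reduces to a simple identity, sketched here with toy numbers: the loss matches the run log, but the byte total is invented for illustration (the real run derives it from the SentencePiece lookup tables).

```python
import math

# Identity used by the evaluation path:
#   val_bpb = (val_loss / ln 2) * (total_tokens / total_bytes)
val_loss = 6.3586          # mean token cross-entropy in nats (from the run log)
total_tokens = 65536.0     # validation tokens scored
total_bytes = 110800.0     # hypothetical UTF-8 byte count, for illustration only

bits_per_token = val_loss / math.log(2.0)
val_bpb = bits_per_token * (total_tokens / total_bytes)
print(f"bits/token={bits_per_token:.4f} bpb={val_bpb:.4f}")
```

Because the byte count depends only on the tokenizer's decoded text, two submissions with different tokenizers remain comparable on `val_bpb` even when their token-level losses are not.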
+ +def build_sentencepiece_luts( + sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device +) -> tuple[Tensor, Tensor, Tensor]: + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("▁"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) + + +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*. 
+ tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] + + +def eval_val( + args: Hyperparameters, + model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + grad_accum_steps: int, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, +) -> tuple[float, float]: + # Validation computes two metrics: + # - val_loss: token cross-entropy (natural log) + # - val_bpb: tokenizer-agnostic compression metric used by the challenge + local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) + if local_batch_tokens < args.train_seq_len: + raise ValueError( + "VAL_BATCH_SIZE must provide at least one sequence per rank; " + f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " + f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}" + ) + local_batch_seqs = local_batch_tokens // args.train_seq_len + total_seqs = (val_tokens.numel() - 1) // args.train_seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + + model.eval() + with torch.inference_mode(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * args.train_seq_len + raw_end = batch_seq_end * args.train_seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, args.train_seq_len) + y = local[1:].reshape(-1, args.train_seq_len) + with 
torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + batch_loss = model(x, y).detach() + batch_token_count = float(y.numel()) + val_loss_sum += batch_loss.to(torch.float64) * batch_token_count + val_token_count += batch_token_count + prev_ids = x.reshape(-1) + tgt_ids = y.reshape(-1) + token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + + val_loss = val_loss_sum / val_token_count + bits_per_token = val_loss.item() / math.log(2.0) + tokens_per_byte = val_token_count.item() / val_byte_count.item() + model.train() + return float(val_loss.item()), float(bits_per_token * tokens_per_byte) + +# ----------------------------- +# POST-TRAINING QUANTIZATION +# ----------------------------- +# +# It's silly to export our model, which is trained in bf16 and fp32, at that same precision. +# Instead, we get approximately the same model (with a small hit) by quantizing the model to int8 & zlib compressing. +# We can then decompress the model and run in higher precision for evaluation, after closing in under the size limit. 
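The quantize-compress-dequantize loop described above can be sketched in a few NumPy lines. This is a simplified illustration using a plain per-row max-abs scale, not the percentile-clipped int8/int5 schemes implemented below:

```python
import zlib
import numpy as np

# Sketch of the export idea: symmetric per-row int8 quantization, zlib
# compression of the int8 payload, then dequantization for evaluation.
rng = np.random.default_rng(0)
w = rng.normal(scale=0.02, size=(4, 256)).astype(np.float32)  # synthetic weights

scale = np.abs(w).max(axis=1, keepdims=True) / 127.0  # one scale per row
q = np.clip(np.round(w / scale), -127, 127).astype(np.int8)
blob = zlib.compress(q.tobytes(), level=9)            # compressed payload

w_hat = q.astype(np.float32) * scale                  # dequantized weights
max_err = float(np.abs(w - w_hat).max())
print(max_err <= float(scale.max()) / 2 + 1e-8)       # error within half a step
```

Per-row scales matter because rows (output channels) of a weight matrix can have very different magnitudes; a single tensor-wide scale would waste quantization levels on the quiet rows.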
+ +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights", + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 +INT5_ODD_LEVELS = (-2, -1, 0, 1, 2) + +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) + + +def parse_factor_shape(spec: str, dim: int) -> tuple[int, int]: + if spec: + vals = tuple(int(v.strip()) for v in spec.split(",") if v.strip()) + if len(vals) != 2 or vals[0] * vals[1] != dim: + raise ValueError(f"Invalid factor shape '{spec}' for dim={dim}") + return vals + for a in range(int(math.isqrt(dim)), 0, -1): + if dim % a == 0: + return a, dim // a + return 1, dim + +def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: + if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): + return t.float().contiguous() + if t.dtype in {torch.float32, torch.bfloat16}: + passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") + return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() + return t + +def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + # Matrices get one scale per row, which usually tracks output-channel + # ranges much better than a single tensor-wide scale. 
+ clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) + q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() + return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + + # Vectors / scalars use a simpler per-tensor scale. + clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() + return q, scale + + +def pack_base5_codes(q: Tensor) -> tuple[bytes, int]: + vals = q.detach().to("cpu", dtype=torch.uint8).reshape(-1).numpy() + n_vals = int(vals.size) + pad = (3 - (n_vals % 3)) % 3 + if pad: + vals = np.pad(vals, (0, pad)) + groups = vals.reshape(-1, 3).astype(np.uint8) + packed = groups[:, 0] + 5 * groups[:, 1] + 25 * groups[:, 2] + return packed.tobytes(), n_vals + + +def unpack_base5_codes(data: bytes, n_vals: int) -> Tensor: + packed = np.frombuffer(data, dtype=np.uint8).astype(np.int16) + out = np.zeros((packed.size, 3), dtype=np.uint8) + for i in range(3): + out[:, i] = packed % 5 + packed //= 5 + flat = out.reshape(-1)[:n_vals] + return torch.from_numpy(flat.astype(np.uint8, copy=False)) + + +def quantize_float_tensor_int5_odd(t: Tensor) -> tuple[Tensor, Tensor]: + if t.ndim != 2: + raise ValueError("int5_odd quantization only supports 2D tensors") + t32 = t.float() + clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clip_abs = clip_abs.clamp_min(1e-6) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 
2.0).clamp_min(1.0 / 2.0e6) + q = torch.clamp(torch.round(clipped / scale[:, None]), -2, 2).to(torch.int8).contiguous() + return (q + 2).to(torch.uint8).contiguous(), scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + +def quantize_state_dict_int8(state_dict: dict[str, Tensor]): + # Single supported clean-script export format: + # - per-row int8 for 2D float tensors + # - per-tensor int8 for other float tensors + # - exact passthrough for non-floats + # - passthrough for small float tensors, stored as fp16 to save bytes + quantized: dict[str, Tensor] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + qmeta: dict[str, dict[str, object]] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + + # Small float tensors are cheap enough to keep directly. We still downcast + # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size. 
+ if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + + stats["num_float_tensors"] += 1 + q, s = quantize_float_tensor(t) + if s.ndim > 0: + qmeta[name] = {"scheme": "per_row", "axis": 0} + quantized[name] = q + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) + + obj: dict[str, object] = { + "__quant_format__": "int8_clean_per_row_v1", + "quantized": quantized, + "scales": scales, + "dtypes": dtypes, + "passthrough": passthrough, + } + if qmeta: + obj["qmeta"] = qmeta + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats + + +def quantize_state_dict_int5_odd(state_dict: dict[str, Tensor]): + packed: dict[str, bytes] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + shapes: dict[str, tuple[int, ...]] = {} + orig_shapes: dict[str, tuple[int, ...]] = {} + n_codes: dict[str, int] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + + quantize_this = t.ndim >= 2 and (t.numel() > INT8_KEEP_FLOAT_MAX_NUMEL or ".mlp." 
in name) + if not quantize_this: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + + stats["num_float_tensors"] += 1 + t2 = t.reshape(t.shape[0], -1) if t.ndim > 2 else t + q, s = quantize_float_tensor_int5_odd(t2) + packed_bytes, count = pack_base5_codes(q) + packed[name] = packed_bytes + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + shapes[name] = tuple(int(v) for v in t2.shape) + orig_shapes[name] = tuple(int(v) for v in t.shape) + n_codes[name] = count + stats["int8_payload_bytes"] += len(packed_bytes) + tensor_nbytes(s) + + obj: dict[str, object] = { + "__quant_format__": "int5_odd_clean_per_row_v1", + "packed": packed, + "scales": scales, + "dtypes": dtypes, + "shapes": shapes, + "orig_shapes": orig_shapes, + "n_codes": n_codes, + "passthrough": passthrough, + } + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats + +def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + qmeta = obj.get("qmeta", {}) + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, q in obj["quantized"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + s = obj["scales"][name] + if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: + s = s.to(dtype=torch.float32) + # Broadcast the saved row scale back across trailing dimensions. + out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() + else: + scale = float(s.item()) + out[name] = (q.float() * scale).to(dtype=dtype).contiguous() + for name, t in obj["passthrough"].items(): + # Restore small tensors, undoing the temporary fp16 storage cast if needed. 
+ out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out + + +def dequantize_state_dict_int5_odd(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, packed in obj["packed"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + shape = tuple(int(v) for v in obj["shapes"][name]) + orig_shape = tuple(int(v) for v in obj.get("orig_shapes", {}).get(name, shape)) + q = unpack_base5_codes(packed, int(obj["n_codes"][name])).to(torch.float32).reshape(shape) - 2.0 + s = obj["scales"][name].to(dtype=torch.float32) + deq = (q * s.view(shape[0], *([1] * (len(shape) - 1)))).to(dtype=dtype) + out[name] = deq.reshape(orig_shape).contiguous() + for name, t in obj["passthrough"].items(): + out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out + + +def quantize_state_dict(state_dict: dict[str, Tensor], quant_scheme: str): + if quant_scheme == "int8": + return quantize_state_dict_int8(state_dict) + if quant_scheme == "int5_odd": + return quantize_state_dict_int5_odd(state_dict) + raise ValueError(f"Unsupported QUANT_SCHEME={quant_scheme}") + + +def dequantize_state_dict(obj: dict[str, object]) -> dict[str, Tensor]: + quant_format = obj.get("__quant_format__") + if quant_format == "int8_clean_per_row_v1": + return dequantize_state_dict_int8(obj) + if quant_format == "int5_odd_clean_per_row_v1": + return dequantize_state_dict_int5_odd(obj) + raise ValueError(f"Unsupported quant format {quant_format}") + + +def compress_blob(data: bytes, compressor: str, level: int) -> bytes: + if compressor == "zlib": + return zlib.compress(data, level=max(0, 
min(level, 9))) + if compressor == "zstd": + if not _HAS_ZSTD: + raise RuntimeError("COMPRESSOR=zstd requires zstandard to be installed") + return zstd.ZstdCompressor(level=level).compress(data) + raise ValueError(f"Unsupported COMPRESSOR={compressor}") + + + def decompress_blob(data: bytes, compressor: str) -> bytes: + if compressor == "zlib": + return zlib.decompress(data) + if compressor == "zstd": + if not _HAS_ZSTD: + raise RuntimeError("COMPRESSOR=zstd requires zstandard to be installed") + return zstd.ZstdDecompressor().decompress(data) + raise ValueError(f"Unsupported COMPRESSOR={compressor}") + + + def should_rate_regularize_tensor(name: str, p: Tensor, target: str) -> bool: + if not p.requires_grad or not p.is_floating_point() or p.ndim < 2: + return False + if any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS): + return False + if target == "mlp_only": + return ".mlp." in name + if target == "all_matrices": + return p.numel() > INT8_KEEP_FLOAT_MAX_NUMEL or ".mlp." in name + raise ValueError(f"Unsupported RATE_TARGET_TENSORS={target}") + + + def estimate_tensor_rate_proxy(t: Tensor, temperature: float) -> tuple[Tensor, Tensor, Tensor]: + # Differentiable bit-rate proxy: soft-assign row-normalized weights to the int5-odd + # levels and take the entropy of the mean assignment histogram (bits per symbol). + t2 = t.float() + if t2.ndim > 2: + t2 = t2.reshape(t2.shape[0], -1) + elif t2.ndim == 1: + t2 = t2.reshape(1, -1) + row_scale = t2.abs().mean(dim=1, keepdim=True).clamp_min(1e-6) + normalized = torch.clamp(t2 / row_scale, -2.5, 2.5) + centers = t2.new_tensor(INT5_ODD_LEVELS) + logits = -temperature * (normalized.unsqueeze(-1) - centers) ** 2 + probs = torch.softmax(logits, dim=-1) + hist = probs.mean(dim=(0, 1)) + entropy_nats = -(hist * (hist + 1e-8).log()).sum() + entropy_bits = entropy_nats / math.log(2.0) + scale_proxy = row_scale.mean() + est_bits = entropy_bits * float(t2.numel()) + return entropy_bits, scale_proxy, est_bits + + + # ----------------------------- + # DATA LOADING + # ----------------------------- + + def load_data_shard(file: Path) -> Tensor: + header_bytes = 256 * np.dtype("<i4").itemsize + with file.open("rb") as f: + header = np.frombuffer(f.read(header_bytes), dtype="<i4") + num_tokens = int(header[2]) + tokens = np.frombuffer(f.read(num_tokens * np.dtype("<u2").itemsize), dtype="<u2") + if tokens.size != num_tokens: + raise ValueError(f"Token shard {file} is truncated") + return torch.from_numpy(tokens.astype(np.int32)) + + + class TokenStream: + # Streams tokens from sharded .bin files, cycling to the next shard when one runs out. + def __init__(self, pattern: str): + self.files = sorted(Path(pattern).parent.glob(Path(pattern).name)) + if not self.files: + raise FileNotFoundError(f"No token shards match {pattern}") + self.file_idx = 0 + self.tokens = load_data_shard(self.files[self.file_idx]) + self.pos = 0 + + def _advance_file(self) -> None: + 
self.file_idx = (self.file_idx + 1) % len(self.files) + self.tokens = load_data_shard(self.files[self.file_idx]) + self.pos = 0 + + def take(self, n: int) -> Tensor: + chunks: list[Tensor] = [] + remaining = n + while remaining > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: + self._advance_file() + continue + k = min(remaining, avail) + chunks.append(self.tokens[self.pos : self.pos + k]) + self.pos += k + remaining -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) + + +class DistributedTokenLoader: + # Each call consumes a contiguous chunk from the shared token stream, then slices out + # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting. + def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank = rank + self.world_size = world_size + self.device = device + self.stream = TokenStream(pattern) + + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + chunk = self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start : start + per_rank_span].to(dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + +# ----------------------------- +# TRANSFORMER MODULES +# ----------------------------- + +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + + +class CastedLinear(nn.Linear): + # Keep weights in fp32 for optimizer/state quality, cast at matmul time for bf16 compute. 
+ def forward(self, x: Tensor) -> Tensor: + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, self.weight.to(x.dtype), bias) + + +class StructuredLinear(nn.Module): + # Lightweight TT-style structured matrix used only in the MLP stack. + def __init__( + self, + in_features: int, + out_features: int, + bias: bool, + linear_type: str, + btt_rank: int, + in_shape_spec: str, + out_shape_spec: str, + btt_init: str, + btt_init_gain: float, + ): + super().__init__() + self.in_features = in_features + self.out_features = out_features + self.linear_type = linear_type + self.rank = max(int(btt_rank), 1) + self.btt_init = btt_init + self.btt_init_gain = btt_init_gain + if linear_type == "dense": + self.impl = CastedLinear(in_features, out_features, bias=bias) + self.register_parameter("core1", None) + self.register_parameter("core2", None) + return + if linear_type != "btt_mlp": + raise ValueError(f"Unsupported STRUCTURED_LINEAR_TYPE={linear_type}") + self.impl = None + self.in_shape = parse_factor_shape(in_shape_spec, in_features) + self.out_shape = parse_factor_shape(out_shape_spec, out_features) + n1, n2 = self.in_shape + m1, m2 = self.out_shape + self.core1 = nn.Parameter(torch.empty(n1, m1, self.rank)) + self.core2 = nn.Parameter(torch.empty(self.rank, n2, m2)) + if bias: + self.bias = nn.Parameter(torch.zeros(out_features)) + else: + self.register_parameter("bias", None) + self.reset_parameters() + + def reset_parameters(self) -> None: + if self.impl is not None: + self.impl.reset_parameters() + return + n1, n2 = self.in_shape + if self.btt_init == "mup": + # muP-inspired scaling: keep the per-rank branch update stable while + # preserving O(1) output variance under the sum over TT ranks. 
+ core1_std = self.btt_init_gain / math.sqrt(max(n1, 1)) + core2_std = self.btt_init_gain / math.sqrt(max(self.rank * n2, 1)) + elif self.btt_init == "xavier": + std = self.btt_init_gain / math.sqrt(max(self.in_features, 1)) + core1_std = std / math.sqrt(self.rank) + core2_std = std + else: + raise ValueError(f"Unsupported BTT_INIT={self.btt_init}") + nn.init.normal_(self.core1, mean=0.0, std=core1_std) + nn.init.normal_(self.core2, mean=0.0, std=core2_std) + if self.bias is not None: + nn.init.zeros_(self.bias) + + def forward(self, x: Tensor) -> Tensor: + if self.impl is not None: + return self.impl(x) + x_dtype = x.dtype + orig_shape = x.shape[:-1] + x2 = x.reshape(-1, self.in_shape[0], self.in_shape[1]) + core1 = self.core1.to(dtype=x_dtype) + core2 = self.core2.to(dtype=x_dtype) + y = x2.new_zeros((x2.shape[0], self.out_shape[0], self.out_shape[1])) + for r in range(self.rank): + left = torch.einsum("buv,um->bmv", x2, core1[:, :, r]) + y = y + torch.einsum("bmv,vn->bmn", left, core2[r]) + y = y.reshape(*orig_shape, self.out_features) + if self.bias is not None: + y = y + self.bias.to(dtype=y.dtype) + return y + + +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + # Keep small/control parameters in fp32 even when the model body runs in bf16. + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() + + +class Rotary(nn.Module): + # Caches cos/sin tables per sequence length on the current device. 
+ def __init__(self, dim: int, base: float = 10000.0): + super().__init__() + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached != seq_len + or self._cos_cached.device != device + ): + t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) + freqs = torch.outer(t, self.inv_freq.to(device)) + self._cos_cached = freqs.cos()[None, None, :, :] + self._sin_cached = freqs.sin()[None, None, :, :] + self._seq_len_cached = seq_len + cos = self._cos_cached.to(dtype=dtype) + sin = self._sin_cached.to(dtype=dtype) + if not torch.is_inference_mode_enabled(): + cos = cos.clone() + sin = sin.clone() + return cos, sin + + +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + + +class CausalSelfAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + kv_dim = self.num_kv_heads * self.head_dim + self.c_q = CastedLinear(dim, dim, bias=False) + self.c_k = CastedLinear(dim, kv_dim, bias=False) + self.c_v = CastedLinear(dim, kv_dim, bias=False) + self.proj 
= CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rotary = Rotary(self.head_dim, base=rope_base) + + def forward(self, x: Tensor) -> Tensor: + bsz, seqlen, dim = x.shape + q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) + k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin) + k = apply_rotary_emb(k, cos, sin) + q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] + y = F.scaled_dot_product_attention( + q, + k, + v, + attn_mask=None, + is_causal=True, + enable_gqa=(self.num_kv_heads != self.num_heads), + ) + y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) + return self.proj(y) + + +class MLP(nn.Module): + # relu^2 MLP from the original modded-nanogpt setup + def __init__( + self, + dim: int, + mlp_mult: int, + structured_mlp: bool, + structured_linear_type: str, + btt_rank: int, + btt_in_shape: str, + btt_out_shape: str, + btt_init: str, + btt_init_gain: float, + compile_structured_mlp: bool, + ): + super().__init__() + hidden = mlp_mult * dim + if structured_mlp: + self.fc = StructuredLinear( + dim, + hidden, + bias=False, + linear_type=structured_linear_type, + btt_rank=btt_rank, + in_shape_spec=btt_in_shape, + out_shape_spec=btt_out_shape, + btt_init=btt_init, + btt_init_gain=btt_init_gain, + ) + self.proj = StructuredLinear( + hidden, + dim, + bias=False, + linear_type=structured_linear_type, + btt_rank=btt_rank, + in_shape_spec=btt_out_shape, + out_shape_spec=btt_in_shape, + btt_init=btt_init, + btt_init_gain=btt_init_gain, + ) + if compile_structured_mlp: + self.fc = torch.compile(self.fc, mode="reduce-overhead", 
fullgraph=False, dynamic=False) + self.proj = torch.compile(self.proj, mode="reduce-overhead", fullgraph=False, dynamic=False) + else: + self.fc = CastedLinear(dim, hidden, bias=False) + self.proj = CastedLinear(hidden, dim, bias=False) + self.proj._zero_init = True + + def forward(self, x: Tensor) -> Tensor: + x = torch.relu(self.fc(x)) + return self.proj(x.square()) + + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + rope_base: float, + qk_gain_init: float, + structured_mlp: bool, + structured_linear_type: str, + btt_rank: int, + btt_in_shape: str, + btt_out_shape: str, + btt_init: str, + btt_init_gain: float, + compile_structured_mlp: bool, + ): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) + self.mlp = MLP( + dim, + mlp_mult, + structured_mlp, + structured_linear_type, + btt_rank, + btt_in_shape, + btt_out_shape, + btt_init, + btt_init_gain, + compile_structured_mlp, + ) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + + def forward(self, x: Tensor, x0: Tensor) -> Tensor: + mix = self.resid_mix.to(dtype=x.dtype) + x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + attn_out = self.attn(self.attn_norm(x)) + x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out + x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) + return x + + +class GPT(nn.Module): + def __init__( + self, + vocab_size: int, + num_layers: int, + model_dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + tie_embeddings: bool, + tied_embed_init_std: float, + logit_softcap: float, + rope_base: float, + qk_gain_init: float, + structured_mlp: bool, + 
structured_linear_type: str, + btt_rank: int, + btt_in_shape: str, + btt_out_shape: str, + btt_init: str, + btt_init_gain: float, + compile_structured_mlp: bool, + rate_target_tensors: str, + rate_temperature: float, + ): + super().__init__() + if logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") + self.tie_embeddings = tie_embeddings + self.tied_embed_init_std = tied_embed_init_std + self.logit_softcap = logit_softcap + self.rate_target_tensors = rate_target_tensors + self.rate_temperature = rate_temperature + self.tok_emb = nn.Embedding(vocab_size, model_dim) + self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + self.blocks = nn.ModuleList( + [ + Block( + model_dim, + num_heads, + num_kv_heads, + mlp_mult, + rope_base, + qk_gain_init, + structured_mlp, + structured_linear_type, + btt_rank, + btt_in_shape, + btt_out_shape, + btt_init, + btt_init_gain, + compile_structured_mlp, + ) + for i in range(num_layers) + ] + ) + self.final_norm = RMSNorm() + self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) + if self.lm_head is not None: + self.lm_head._zero_init = True + self._init_weights() + + def _init_weights(self) -> None: + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + for module in self.modules(): + if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + if isinstance(module, StructuredLinear) and getattr(module, "_zero_init", False): + if module.impl is not None: + nn.init.zeros_(module.impl.weight) + else: + nn.init.zeros_(module.core2) + + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + x = 
self.tok_emb(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x0 = x + skips: list[Tensor] = [] + + # First half stores skips; second half reuses them in reverse order. + for i in range(self.num_encoder_layers): + x = self.blocks[i](x, x0) + skips.append(x) + for i in range(self.num_decoder_layers): + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + x = self.blocks[self.num_encoder_layers + i](x, x0) + + x = self.final_norm(x).reshape(-1, x.size(-1)) + targets = target_ids.reshape(-1) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(x) + logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + return F.cross_entropy(logits.float(), targets, reduction="mean") + + def rate_regularizer(self) -> tuple[Tensor, Tensor, Tensor]: + total_symbols = 0.0 + weighted_rate = None + weighted_scale = None + estimated_bits = None + for name, param in self.named_parameters(): + if not should_rate_regularize_tensor(name, param, self.rate_target_tensors): + continue + rate_proxy, scale_proxy, tensor_bits = estimate_tensor_rate_proxy(param, self.rate_temperature) + weight = float(param.numel()) + total_symbols += weight + if weighted_rate is None: + weighted_rate = rate_proxy * weight + weighted_scale = scale_proxy * weight + estimated_bits = tensor_bits + else: + weighted_rate = weighted_rate + rate_proxy * weight + weighted_scale = weighted_scale + scale_proxy * weight + estimated_bits = estimated_bits + tensor_bits + if weighted_rate is None or weighted_scale is None or estimated_bits is None: + zero = self.tok_emb.weight.new_zeros(()) + return zero, zero, zero + denom = max(total_symbols, 1.0) + return weighted_rate / denom, weighted_scale / denom, estimated_bits + + +# ----------------------------- +# TRAINING +# ----------------------------- + +def main() -> None: 
+ global zeropower_via_newtonschulz5 + + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + if args.enable_compile: + zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) + + # ----------------------------- + # DISTRIBUTED + CUDA SETUP + # ----------------------------- + + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if 8 % world_size != 0: + raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") + grad_accum_steps = 8 // world_size + grad_scale = 1.0 / grad_accum_steps + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 0 + + # Fast math knobs + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + + logfile = None + if master_process: + if args.run_id == "train": + logfile = "train.log" + else: + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" + print(logfile) + + def log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch 
{torch.__version__}", console=False) + log0( + subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, + console=False, + ) + log0("=" * 100, console=False) + + # ----------------------------- + # TOKENIZER + VALIDATION METRIC SETUP + # ----------------------------- + + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) + if args.val_token_limit > 0: + usable = min(((args.val_token_limit) // args.train_seq_len) * args.train_seq_len, val_tokens.numel() - 1) + if usable <= 0: + raise ValueError(f"VAL_TOKEN_LIMIT={args.val_token_limit} is too small for TRAIN_SEQ_LEN={args.train_seq_len}") + val_tokens = val_tokens[: usable + 1].contiguous() + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + if args.val_token_limit > 0: + log0(f"val_token_limit:{args.val_token_limit}") + + # ----------------------------- + # MODEL + OPTIMIZER SETUP + # ----------------------------- + + base_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + 
model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, + structured_mlp=args.structured_mlp, + structured_linear_type=args.structured_linear_type if args.structured_mlp else "dense", + btt_rank=args.btt_rank, + btt_in_shape=args.btt_in_shape, + btt_out_shape=args.btt_out_shape, + btt_init=args.btt_init, + btt_init_gain=args.btt_init_gain, + compile_structured_mlp=args.compile_structured_mlp, + rate_target_tensors=args.rate_target_tensors, + rate_temperature=args.rate_temperature, + ).to(device).bfloat16() + for module in base_model.modules(): + if isinstance(module, (CastedLinear, StructuredLinear)): + module.float() + restore_low_dim_params_to_fp32(base_model) + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) if args.enable_compile else base_model + model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + + # Optimizer split: + # - token embedding (Adam) uses EMBED_LR + # - untied lm_head (Adam) uses HEAD_LR + # - matrix params in transformer blocks use MATRIX_LR via Muon + # - vectors/scalars use SCALAR_LR via Adam + block_named_params = list(base_model.blocks.named_parameters()) + matrix_params = [ + p + for name, p in block_named_params + if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + scalar_params = [ + p + for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + optimizer_tok = torch.optim.Adam( + [{"params": [base_model.tok_emb.weight], "lr": token_lr, 
"base_lr": token_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizer_muon = Muon( + matrix_params, + lr=args.matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + ) + for group in optimizer_muon.param_groups: + group["base_lr"] = args.matrix_lr + optimizer_scalar = torch.optim.Adam( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers.insert(1, optimizer_head) + + n_params = sum(p.numel() for p in base_model.parameters()) + log0(f"model_params:{n_params}") + log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") + log0(f"enable_compile:{args.enable_compile}") + log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") + log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") + log0( + f"quant_scheme:{args.quant_scheme} compressor:{args.compressor} level:{args.compressor_level} " + f"structured_mlp:{args.structured_mlp} structured_linear_type:{args.structured_linear_type} " + f"btt_rank:{args.btt_rank} btt_init:{args.btt_init} compile_structured_mlp:{args.compile_structured_mlp}" + ) + log0( + f"rate_lambda:{args.rate_lambda} scale_lambda:{args.scale_lambda} " + f"rate_warmup_start_fraction:{args.rate_warmup_start_fraction} rate_full_fraction:{args.rate_full_fraction} " + f"rate_target_tensors:{args.rate_target_tensors}" + ) + log0( + f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " + f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " + f"matrix_lr:{args.matrix_lr} 
scalar_lr:{args.scalar_lr}" + ) + log0( + f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " + f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " + f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log0(f"seed:{args.seed}") + + # ----------------------------- + # DATA LOADER & MODEL WARMUP + # ----------------------------- + + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + + def lr_mul(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + + def rate_schedule(step: int) -> float: + if args.rate_lambda <= 0.0 and args.scale_lambda <= 0.0: + return 0.0 + if args.iterations <= 0: + return 1.0 + frac = step / max(args.iterations, 1) + start = args.rate_warmup_start_fraction + full = max(args.rate_full_fraction, start + 1e-6) + if frac <= start: + return 0.0 + if frac >= full: + return 1.0 + return (frac - start) / (full - start) + + # Warmup primes the compiled forward/backward/optimizer paths, then we restore the + # initial weights/optimizer state so measured training starts from the true init. 
+ if args.warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(args.warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + for opt in optimizers: + opt.step() + zero_grad_all() + if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: + log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + if distributed: + model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + # ----------------------------- + # MAIN TRAINING LOOP + # ----------------------------- + + training_time_ms = 0.0 + stop_after_step: int | None = None + torch.cuda.synchronize() + t0 = time.perf_counter() + + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + log0( + 
f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" + ) + break + + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + zero_grad_all() + train_loss = torch.zeros((), device=device) + train_ce_loss = torch.zeros((), device=device) + train_rate_proxy = torch.zeros((), device=device) + train_scale_proxy = torch.zeros((), device=device) + train_est_bits = torch.zeros((), device=device) + rate_mul = rate_schedule(step) + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + ce_loss = model(x, y) + total_loss = ce_loss + if rate_mul > 0.0: + rate_proxy, scale_proxy, est_bits = base_model.rate_regularizer() + total_loss = total_loss + rate_mul * (args.rate_lambda * rate_proxy + args.scale_lambda * scale_proxy) + train_rate_proxy += rate_proxy.detach() + train_scale_proxy += scale_proxy.detach() + train_est_bits += est_bits.detach() + train_ce_loss += ce_loss.detach() + train_loss += total_loss.detach() + (total_loss * grad_scale).backward() + train_loss /= grad_accum_steps + train_ce_loss /= grad_accum_steps + train_rate_proxy /= grad_accum_steps + train_scale_proxy /= grad_accum_steps + train_est_bits /= grad_accum_steps + + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * 
args.muon_momentum + for group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + for opt in optimizers: + opt.step() + zero_grad_all() + + step += 1 + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + should_log_train = ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + log0( + f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} ce_loss:{train_ce_loss.item():.4f} " + f"rate_mul:{rate_mul:.3f} rate_proxy_bits:{train_rate_proxy.item():.4f} " + f"scale_proxy:{train_scale_proxy.item():.6f} est_model_bits:{train_est_bits.item():.0f} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" + ) + + # All ranks must agree on stopping, so reduce the wallclock-cap flag with MAX across ranks. + reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + + log0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + ) + + # ----------------------------- + # SERIALIZATION + ROUNDTRIP VALIDATION + # ----------------------------- + # Save the raw state (useful for debugging/loading in PyTorch directly), then always produce + # the compressed quantized artifact (QUANT_SCHEME + COMPRESSOR, int5_odd+zlib here) and validate the round-tripped weights. 
+ + if master_process: + torch.save(base_model.state_dict(), "final_model.pt") + model_bytes = os.path.getsize("final_model.pt") + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model: {model_bytes} bytes") + log0(f"Code size: {code_bytes} bytes") + log0(f"Total submission size: {model_bytes + code_bytes} bytes") + + quant_obj, quant_stats = quantize_state_dict(base_model.state_dict(), args.quant_scheme) + quant_buf = io.BytesIO() + torch.save(quant_obj, quant_buf) + quant_raw = quant_buf.getvalue() + quant_blob = compress_blob(quant_raw, args.compressor, args.compressor_level) + quant_raw_bytes = len(quant_raw) + quant_filename = f"final_model.{args.quant_scheme}.{args.compressor}.ptc" + if master_process: + with open(quant_filename, "wb") as f: + f.write(quant_blob) + quant_file_bytes = os.path.getsize(quant_filename) + code_bytes = len(code.encode("utf-8")) + ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1) + log0( + f"Serialized model {args.quant_scheme}+{args.compressor}: {quant_file_bytes} bytes " + f"(payload:{quant_stats['int8_payload_bytes']} raw_torch:{quant_raw_bytes} payload_ratio:{ratio:.2f}x)" + ) + log0(f"Total submission size {args.quant_scheme}+{args.compressor}: {quant_file_bytes + code_bytes} bytes") + + if distributed: + dist.barrier() + with open(quant_filename, "rb") as f: + quant_blob_disk = f.read() + quant_state = torch.load(io.BytesIO(decompress_blob(quant_blob_disk, args.compressor)), map_location="cpu") + base_model.load_state_dict(dequantize_state_dict(quant_state), strict=True) + torch.cuda.synchronize() + t_qeval = time.perf_counter() + q_val_loss, q_val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"final_{args.quant_scheme}_{args.compressor}_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + 
f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" + ) + log0(f"final_{args.quant_scheme}_{args.compressor}_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + + if distributed: + dist.destroy_process_group() + + +if __name__ == "__main__": + main() + +==================================================================================================== +Running Python 3.10.12 (main, Mar 3 2026, 11:56:32) [GCC 11.4.0] +Running PyTorch 2.11.0+cu130 +Tue Mar 31 12:07:38 2026 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 590.48.01 Driver Version: 590.48.01 CUDA Version: 13.1 | ++-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. | +|=========================================+========================+======================| +| 0 NVIDIA GeForce RTX 3060 Off | 00000000:01:00.0 On | N/A | +| 37% 43C P0 36W / 170W | 242MiB / 12288MiB | 11% Default | +| | | N/A | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| 0 N/A N/A 985 G /usr/lib/xorg/Xorg 117MiB | +| 0 N/A N/A 3772 C python3 104MiB | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=/home/hz/Desktop/parameter-golf-main/data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:1 +val_loader:shards 
pattern=/home/hz/Desktop/parameter-golf-main/data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:262144 +val_token_limit:262144 +model_params:8507464 +world_size:1 grad_accum_steps:8 +enable_compile:False +sdp_backends:cudnn=False flash=True mem_efficient=False math=False +attention_mode:gqa num_heads:8 num_kv_heads:4 +quant_scheme:int5_odd compressor:zlib level:9 structured_mlp:True structured_linear_type:btt_mlp btt_rank:32 btt_init:mup compile_structured_mlp:False +rate_lambda:2e-05 scale_lambda:0.0002 rate_warmup_start_fraction:0.1 rate_full_fraction:0.4 rate_target_tensors:mlp_only +tie_embeddings:True embed_lr:0.03 head_lr:0.0 matrix_lr:0.02 scalar_lr:0.02 +train_batch_tokens:8192 train_seq_len:256 iterations:12 warmup_steps:0 max_wallclock_seconds:600.000 +seed:1337 +step:0/12 val_loss:6.9348 val_bpb:4.1066 train_time:0ms step_avg:0.02ms +step:1/12 train_loss:6.9402 ce_loss:6.9402 rate_mul:0.000 rate_proxy_bits:0.0000 scale_proxy:0.000000 est_model_bits:0 train_time:4693ms step_avg:4692.85ms +step:2/12 train_loss:13.5286 ce_loss:13.5286 rate_mul:0.000 rate_proxy_bits:0.0000 scale_proxy:0.000000 est_model_bits:0 train_time:9317ms step_avg:4658.66ms +step:3/12 train_loss:12.9297 ce_loss:12.9297 rate_mul:0.222 rate_proxy_bits:1.4980 scale_proxy:0.065182 est_model_bits:1325360 train_time:14420ms step_avg:4806.73ms +step:4/12 train_loss:12.2689 ce_loss:12.2689 rate_mul:0.500 rate_proxy_bits:1.4980 scale_proxy:0.065182 est_model_bits:1325360 train_time:19499ms step_avg:4874.74ms +step:5/12 train_loss:11.4860 ce_loss:11.4860 rate_mul:0.778 rate_proxy_bits:1.4980 scale_proxy:0.065182 est_model_bits:1325360 train_time:24563ms step_avg:4912.60ms +step:6/12 train_loss:10.5960 ce_loss:10.5959 rate_mul:1.000 rate_proxy_bits:1.4980 scale_proxy:0.065182 est_model_bits:1325360 train_time:29612ms step_avg:4935.34ms +step:7/12 train_loss:9.6664 ce_loss:9.6663 rate_mul:1.000 rate_proxy_bits:1.4980 scale_proxy:0.065182 est_model_bits:1325360 train_time:34666ms 
step_avg:4952.32ms
+step:8/12 train_loss:8.9344 ce_loss:8.9344 rate_mul:1.000 rate_proxy_bits:1.4980 scale_proxy:0.065182 est_model_bits:1325360 train_time:39716ms step_avg:4964.48ms
+step:9/12 train_loss:8.3407 ce_loss:8.3407 rate_mul:1.000 rate_proxy_bits:1.4980 scale_proxy:0.065182 est_model_bits:1325360 train_time:44757ms step_avg:4973.03ms
+step:10/12 train_loss:7.6963 ce_loss:7.6963 rate_mul:1.000 rate_proxy_bits:1.4980 scale_proxy:0.065182 est_model_bits:1325360 train_time:49759ms step_avg:4975.87ms
+step:12/12 val_loss:6.6366 val_bpb:3.9300 train_time:59886ms step_avg:4990.52ms
+peak memory allocated: 2103 MiB reserved: 2126 MiB
+Serialized model: 33020979 bytes
+Code size: 68319 bytes
+Total submission size: 33089298 bytes
+Serialized model int5_odd+zlib: 2161240 bytes (payload:2942927 raw_torch:2982627 payload_ratio:11.21x)
+Total submission size int5_odd+zlib: 2229559 bytes
+final_int5_odd_zlib_roundtrip val_loss:6.8893 val_bpb:4.0797 eval_time:13143ms
+final_int5_odd_zlib_roundtrip_exact val_loss:6.88928160 val_bpb:4.07966142
diff --git a/records/track_non_record_16mb/2026-03-31_EntropyAware_Int5Odd_BTTMLP_3060/logs/lambda_0p0002.txt b/records/track_non_record_16mb/2026-03-31_EntropyAware_Int5Odd_BTTMLP_3060/logs/lambda_0p0002.txt
new file mode 100644
index 000000000..fb0b16d94
--- /dev/null
+++ b/records/track_non_record_16mb/2026-03-31_EntropyAware_Int5Odd_BTTMLP_3060/logs/lambda_0p0002.txt
@@ -0,0 +1,1679 @@
+"""
+The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good jumping-off points for new participants, not SOTA configs. We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder.
+
+Hard stop: to keep these scripts readable for newcomers, let's make sure `train_gpt.py` and `train_gpt_mlx.py` are never longer than 1500 lines.
+""" + +from __future__ import annotations + +import copy +import glob +import io +import math +import os +import random +import subprocess +import sys +import time +import uuid +import zlib +from pathlib import Path + +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP + +try: + import zstandard as zstd + + _HAS_ZSTD = True +except ImportError: + zstd = None + _HAS_ZSTD = False + +REPO_ROOT = Path(__file__).resolve().parents[3] + +# ----------------------------- +# HYPERPARAMETERS +# ----------------------------- +# Default Simple Baseline run: +# - 9 transformer blocks at width 512 +# - 8 attention heads with 4 KV heads (GQA) and 2x MLP expansion +# - vocab size 1024, sequence length 1024, tied embeddings +# - 524,288 train tokens per step for 20,000 iterations with a ~10 minute cap + +class Hyperparameters: + # Data paths are shard globs produced by the existing preprocessing pipeline. + data_path = os.environ.get("DATA_PATH", str(REPO_ROOT / "data/datasets/fineweb10B_sp1024")) + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", str(REPO_ROOT / "data/tokenizers/fineweb_1024_bpe.model")) + run_id = os.environ.get("RUN_ID", "train") + seed = int(os.environ.get("SEED", 1337)) + + # Validation cadence and batch size. Validation always uses the full fineweb_val split. + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 262_144)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 10)) + val_token_limit = int(os.environ.get("VAL_TOKEN_LIMIT", 1_048_576)) + + # Training length. 
+ iterations = int(os.environ.get("ITERATIONS", 40)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 1200)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 0)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 16_384)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 256)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + enable_compile = bool(int(os.environ.get("ENABLE_COMPILE", "0"))) + + # Model shape. + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers = int(os.environ.get("NUM_LAYERS", 9)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = int(os.environ.get("MLP_MULT", 2)) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + + # Optimizer hyperparameters. + embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.03)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.02)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.02)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.0)) + + # Export / compression. 
+ quant_scheme = os.environ.get("QUANT_SCHEME", "int5_odd") + compressor = os.environ.get("COMPRESSOR", "zlib") + compressor_level = int(os.environ.get("COMPRESSOR_LEVEL", 9)) + + # Entropy-aware rate regularization. + rate_lambda = float(os.environ.get("RATE_LAMBDA", 0.002)) + scale_lambda = float(os.environ.get("SCALE_LAMBDA", 0.0002)) + rate_temperature = float(os.environ.get("RATE_TEMPERATURE", 6.0)) + rate_warmup_start_fraction = float(os.environ.get("RATE_WARMUP_START_FRACTION", 0.1)) + rate_full_fraction = float(os.environ.get("RATE_FULL_FRACTION", 0.4)) + rate_target_tensors = os.environ.get("RATE_TARGET_TENSORS", "mlp_only") + + # Structured MLP. + structured_mlp = bool(int(os.environ.get("STRUCTURED_MLP", "1"))) + structured_linear_type = os.environ.get("STRUCTURED_LINEAR_TYPE", "btt_mlp") + btt_rank = int(os.environ.get("BTT_RANK", 32)) + btt_in_shape = os.environ.get("BTT_IN_SHAPE", "") + btt_out_shape = os.environ.get("BTT_OUT_SHAPE", "") + btt_init = os.environ.get("BTT_INIT", "mup") + btt_init_gain = float(os.environ.get("BTT_INIT_GAIN", 1.0)) + compile_structured_mlp = bool(int(os.environ.get("COMPILE_STRUCTURED_MLP", "1"))) + +# ----------------------------- +# MUON OPTIMIZER +# ----------------------------- +# +# As borrowed from modded-nanogpt +# Background on Muon: https://kellerjordan.github.io/posts/muon/ + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + # Orthogonalize a 2D update matrix with a fast Newton-Schulz iteration. + # Muon uses this to normalize matrix-shaped gradients before applying them. 
+ a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X + + +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True): + super().__init__( + params, + dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov), + ) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + + for group in self.param_groups: + params = group["params"] + if not params: + continue + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + + total_params = sum(int(p.numel()) for p in params) + updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) + + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: + g = g.add(buf, alpha=momentum) + g = zeropower_via_newtonschulz5(g, steps=backend_steps) + # Scale correction from Muon reference implementations. 
+                    g *= max(1, g.size(0) / g.size(1)) ** 0.5
+                    updates_flat[curr : curr + p.numel()] = g.reshape(-1)
+                curr += p.numel()
+
+            if distributed:
+                dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM)
+
+            curr = 0
+            for p in params:
+                g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype)
+                p.add_(g, alpha=-lr)
+                curr += p.numel()
+
+        return loss
+
+
+# -----------------------------
+# TOKENIZER-AGNOSTIC EVALUATION SETUP
+# -----------------------------
+#
+# It's common for small models to have a large fraction of their parameters in the embeddings, since the 2 * d_model * d_vocab embedding and unembedding entries can be gigantic.
+# Instead of locking the tokenizer, we let you bring your own, and we compute our validation metric as the average compression of the validation set.
+# We calculate BPB (bits-per-byte) instead of raw validation loss, so we need lookup tables that count how many bytes of raw text each tokenizer token covers.
+# Note: Submissions that edit the tokenizer will be examined more carefully, since screwing this up might unjustly improve your score.
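The byte accounting above feeds a simple conversion: bits-per-byte is token cross-entropy converted from nats to bits, rescaled by the average tokens-per-byte ratio. A minimal standalone sketch (the helper name and arguments are illustrative, not part of this script's API):

```python
import math

def bits_per_byte(mean_ce_nats: float, total_tokens: int, total_bytes: int) -> float:
    # Natural-log cross-entropy -> bits per token, then rescale by how many
    # tokens the model spends per byte of raw validation text.
    bits_per_token = mean_ce_nats / math.log(2.0)
    tokens_per_byte = total_tokens / total_bytes
    return bits_per_token * tokens_per_byte

# A model at ln(4) nats/token (2 bits/token) over 4 bytes/token averages ~0.5 bpb.
print(bits_per_byte(math.log(4.0), 1000, 4000))
```

This mirrors the tail of `eval_val`, which multiplies `bits_per_token` by `tokens_per_byte` after the distributed reductions.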
+ +def build_sentencepiece_luts( + sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device +) -> tuple[Tensor, Tensor, Tensor]: + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("▁"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) + + +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*. 
+ tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] + + +def eval_val( + args: Hyperparameters, + model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + grad_accum_steps: int, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, +) -> tuple[float, float]: + # Validation computes two metrics: + # - val_loss: token cross-entropy (natural log) + # - val_bpb: tokenizer-agnostic compression metric used by the challenge + local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) + if local_batch_tokens < args.train_seq_len: + raise ValueError( + "VAL_BATCH_SIZE must provide at least one sequence per rank; " + f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " + f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}" + ) + local_batch_seqs = local_batch_tokens // args.train_seq_len + total_seqs = (val_tokens.numel() - 1) // args.train_seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + + model.eval() + with torch.inference_mode(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * args.train_seq_len + raw_end = batch_seq_end * args.train_seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, args.train_seq_len) + y = local[1:].reshape(-1, args.train_seq_len) + with 
torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True):
+                batch_loss = model(x, y).detach()
+            batch_token_count = float(y.numel())
+            val_loss_sum += batch_loss.to(torch.float64) * batch_token_count
+            val_token_count += batch_token_count
+            prev_ids = x.reshape(-1)
+            tgt_ids = y.reshape(-1)
+            token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16)
+            token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16)
+            val_byte_count += token_bytes.to(torch.float64).sum()
+
+    if dist.is_available() and dist.is_initialized():
+        dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM)
+        dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM)
+        dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM)
+
+    val_loss = val_loss_sum / val_token_count
+    bits_per_token = val_loss.item() / math.log(2.0)
+    tokens_per_byte = val_token_count.item() / val_byte_count.item()
+    model.train()
+    return float(val_loss.item()), float(bits_per_token * tokens_per_byte)
+
+# -----------------------------
+# POST-TRAINING QUANTIZATION
+# -----------------------------
+#
+# It's silly to export our model, which is trained in bf16 and fp32, at that same precision.
+# Instead, we get approximately the same model (with a small quality hit) by quantizing it to a low-bit grid (int8, or the int5_odd scheme used here) and zlib-compressing the result.
+# We can then decompress the model and run it at higher precision for evaluation, after coming in under the size limit.
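As a hedged sketch of the per-row symmetric quantization idea described above (simplified to max-abs scaling; the script's `quantize_float_tensor` additionally clips at a high percentile before computing the scale, and `quantize_per_row` here is an illustrative helper, not this file's API):

```python
import torch

def quantize_per_row(w: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    # One float scale per output row, codes in [-127, 127]; dequantize as q * scale.
    scale = (w.abs().amax(dim=1) / 127.0).clamp_min(1e-8)
    q = torch.clamp(torch.round(w / scale[:, None]), -127, 127).to(torch.int8)
    return q, scale

w = torch.randn(4, 16)
q, scale = quantize_per_row(w)
w_hat = q.float() * scale[:, None]
# With max-abs scaling, rounding bounds the per-element error by half a step.
assert (w - w_hat).abs().max() <= scale.max() / 2 + 1e-6
```

The exported payload then stores only the int8 codes plus the (fp16) per-row scales, which is where most of the savings over fp32 tensors come from.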
+ +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights", + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 +INT5_ODD_LEVELS = (-2, -1, 0, 1, 2) + +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) + + +def parse_factor_shape(spec: str, dim: int) -> tuple[int, int]: + if spec: + vals = tuple(int(v.strip()) for v in spec.split(",") if v.strip()) + if len(vals) != 2 or vals[0] * vals[1] != dim: + raise ValueError(f"Invalid factor shape '{spec}' for dim={dim}") + return vals + for a in range(int(math.isqrt(dim)), 0, -1): + if dim % a == 0: + return a, dim // a + return 1, dim + +def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: + if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): + return t.float().contiguous() + if t.dtype in {torch.float32, torch.bfloat16}: + passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") + return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() + return t + +def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + # Matrices get one scale per row, which usually tracks output-channel + # ranges much better than a single tensor-wide scale. 
+ clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) + q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() + return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + + # Vectors / scalars use a simpler per-tensor scale. + clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() + return q, scale + + +def pack_base5_codes(q: Tensor) -> tuple[bytes, int]: + vals = q.detach().to("cpu", dtype=torch.uint8).reshape(-1).numpy() + n_vals = int(vals.size) + pad = (3 - (n_vals % 3)) % 3 + if pad: + vals = np.pad(vals, (0, pad)) + groups = vals.reshape(-1, 3).astype(np.uint8) + packed = groups[:, 0] + 5 * groups[:, 1] + 25 * groups[:, 2] + return packed.tobytes(), n_vals + + +def unpack_base5_codes(data: bytes, n_vals: int) -> Tensor: + packed = np.frombuffer(data, dtype=np.uint8).astype(np.int16) + out = np.zeros((packed.size, 3), dtype=np.uint8) + for i in range(3): + out[:, i] = packed % 5 + packed //= 5 + flat = out.reshape(-1)[:n_vals] + return torch.from_numpy(flat.astype(np.uint8, copy=False)) + + +def quantize_float_tensor_int5_odd(t: Tensor) -> tuple[Tensor, Tensor]: + if t.ndim != 2: + raise ValueError("int5_odd quantization only supports 2D tensors") + t32 = t.float() + clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clip_abs = clip_abs.clamp_min(1e-6) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 
2.0).clamp_min(1.0 / 2.0e6) + q = torch.clamp(torch.round(clipped / scale[:, None]), -2, 2).to(torch.int8).contiguous() + return (q + 2).to(torch.uint8).contiguous(), scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + +def quantize_state_dict_int8(state_dict: dict[str, Tensor]): + # Single supported clean-script export format: + # - per-row int8 for 2D float tensors + # - per-tensor int8 for other float tensors + # - exact passthrough for non-floats + # - passthrough for small float tensors, stored as fp16 to save bytes + quantized: dict[str, Tensor] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + qmeta: dict[str, dict[str, object]] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + + # Small float tensors are cheap enough to keep directly. We still downcast + # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size. 
+ if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + + stats["num_float_tensors"] += 1 + q, s = quantize_float_tensor(t) + if s.ndim > 0: + qmeta[name] = {"scheme": "per_row", "axis": 0} + quantized[name] = q + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) + + obj: dict[str, object] = { + "__quant_format__": "int8_clean_per_row_v1", + "quantized": quantized, + "scales": scales, + "dtypes": dtypes, + "passthrough": passthrough, + } + if qmeta: + obj["qmeta"] = qmeta + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats + + +def quantize_state_dict_int5_odd(state_dict: dict[str, Tensor]): + packed: dict[str, bytes] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + shapes: dict[str, tuple[int, ...]] = {} + orig_shapes: dict[str, tuple[int, ...]] = {} + n_codes: dict[str, int] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + + quantize_this = t.ndim >= 2 and (t.numel() > INT8_KEEP_FLOAT_MAX_NUMEL or ".mlp." 
in name) + if not quantize_this: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + + stats["num_float_tensors"] += 1 + t2 = t.reshape(t.shape[0], -1) if t.ndim > 2 else t + q, s = quantize_float_tensor_int5_odd(t2) + packed_bytes, count = pack_base5_codes(q) + packed[name] = packed_bytes + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + shapes[name] = tuple(int(v) for v in t2.shape) + orig_shapes[name] = tuple(int(v) for v in t.shape) + n_codes[name] = count + stats["int8_payload_bytes"] += len(packed_bytes) + tensor_nbytes(s) + + obj: dict[str, object] = { + "__quant_format__": "int5_odd_clean_per_row_v1", + "packed": packed, + "scales": scales, + "dtypes": dtypes, + "shapes": shapes, + "orig_shapes": orig_shapes, + "n_codes": n_codes, + "passthrough": passthrough, + } + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats + +def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + qmeta = obj.get("qmeta", {}) + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, q in obj["quantized"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + s = obj["scales"][name] + if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: + s = s.to(dtype=torch.float32) + # Broadcast the saved row scale back across trailing dimensions. + out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() + else: + scale = float(s.item()) + out[name] = (q.float() * scale).to(dtype=dtype).contiguous() + for name, t in obj["passthrough"].items(): + # Restore small tensors, undoing the temporary fp16 storage cast if needed. 
+ out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out + + +def dequantize_state_dict_int5_odd(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, packed in obj["packed"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + shape = tuple(int(v) for v in obj["shapes"][name]) + orig_shape = tuple(int(v) for v in obj.get("orig_shapes", {}).get(name, shape)) + q = unpack_base5_codes(packed, int(obj["n_codes"][name])).to(torch.float32).reshape(shape) - 2.0 + s = obj["scales"][name].to(dtype=torch.float32) + deq = (q * s.view(shape[0], *([1] * (len(shape) - 1)))).to(dtype=dtype) + out[name] = deq.reshape(orig_shape).contiguous() + for name, t in obj["passthrough"].items(): + out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out + + +def quantize_state_dict(state_dict: dict[str, Tensor], quant_scheme: str): + if quant_scheme == "int8": + return quantize_state_dict_int8(state_dict) + if quant_scheme == "int5_odd": + return quantize_state_dict_int5_odd(state_dict) + raise ValueError(f"Unsupported QUANT_SCHEME={quant_scheme}") + + +def dequantize_state_dict(obj: dict[str, object]) -> dict[str, Tensor]: + quant_format = obj.get("__quant_format__") + if quant_format == "int8_clean_per_row_v1": + return dequantize_state_dict_int8(obj) + if quant_format == "int5_odd_clean_per_row_v1": + return dequantize_state_dict_int5_odd(obj) + raise ValueError(f"Unsupported quant format {quant_format}") + + +def compress_blob(data: bytes, compressor: str, level: int) -> bytes: + if compressor == "zlib": + return zlib.compress(data, level=max(0, 
min(level, 9)))
+    if compressor == "zstd":
+        if not _HAS_ZSTD:
+            raise RuntimeError("COMPRESSOR=zstd requires zstandard to be installed")
+        return zstd.ZstdCompressor(level=level).compress(data)
+    raise ValueError(f"Unsupported COMPRESSOR={compressor}")
+
+
+def decompress_blob(data: bytes, compressor: str) -> bytes:
+    if compressor == "zlib":
+        return zlib.decompress(data)
+    if compressor == "zstd":
+        if not _HAS_ZSTD:
+            raise RuntimeError("COMPRESSOR=zstd requires zstandard to be installed")
+        return zstd.ZstdDecompressor().decompress(data)
+    raise ValueError(f"Unsupported COMPRESSOR={compressor}")
+
+
+def should_rate_regularize_tensor(name: str, p: Tensor, target: str) -> bool:
+    if not p.requires_grad or not p.is_floating_point() or p.ndim < 2:
+        return False
+    if any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS):
+        return False
+    if target == "mlp_only":
+        return ".mlp." in name
+    if target == "all_matrices":
+        return p.numel() > INT8_KEEP_FLOAT_MAX_NUMEL or ".mlp." in name
+    raise ValueError(f"Unsupported RATE_TARGET_TENSORS={target}")
+
+
+def estimate_tensor_rate_proxy(t: Tensor, temperature: float) -> tuple[Tensor, Tensor, Tensor]:
+    t2 = t.float()
+    if t2.ndim > 2:
+        t2 = t2.reshape(t2.shape[0], -1)
+    elif t2.ndim == 1:
+        t2 = t2.reshape(1, -1)
+    row_scale = t2.abs().mean(dim=1, keepdim=True).clamp_min(1e-6)
+    normalized = torch.clamp(t2 / row_scale, -2.5, 2.5)
+    centers = t2.new_tensor(INT5_ODD_LEVELS)
+    logits = -temperature * (normalized.unsqueeze(-1) - centers) ** 2
+    probs = torch.softmax(logits, dim=-1)
+    hist = probs.mean(dim=(0, 1))
+    entropy_nats = -(hist * (hist + 1e-8).log()).sum()
+    entropy_bits = entropy_nats / math.log(2.0)
+    scale_proxy = row_scale.mean()
+    est_bits = entropy_bits * float(t2.numel())
+    return entropy_bits, scale_proxy, est_bits
+
+
+# -----------------------------
+# DATA LOADING
+# -----------------------------
+
+def load_data_shard(file: Path) -> Tensor:
+    # Shards carry a 256-entry int32 header (token count at index 2) followed by uint16 tokens.
+    header_bytes = 256 * np.dtype("<i4").itemsize
+    with file.open("rb") as f:
+        header = np.frombuffer(f.read(header_bytes), dtype="<i4")
+        num_tokens = int(header[2])
+        tokens = np.frombuffer(f.read(), dtype=np.uint16, count=num_tokens)
+    return torch.from_numpy(tokens.astype(np.int32, copy=False))
+
+
+class TokenStream:
+    # Sequential stream over the shard files matched by the glob pattern, wrapping around.
+    def __init__(self, pattern: str):
+        self.files = [Path(p) for p in sorted(glob.glob(pattern))]
+        if not self.files:
+            raise FileNotFoundError(f"No files found for pattern: {pattern}")
+        self.file_idx = 0
+        self.tokens = load_data_shard(self.files[self.file_idx])
+        self.pos = 0
+
+    def _advance_file(self) -> None:
self.file_idx = (self.file_idx + 1) % len(self.files) + self.tokens = load_data_shard(self.files[self.file_idx]) + self.pos = 0 + + def take(self, n: int) -> Tensor: + chunks: list[Tensor] = [] + remaining = n + while remaining > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: + self._advance_file() + continue + k = min(remaining, avail) + chunks.append(self.tokens[self.pos : self.pos + k]) + self.pos += k + remaining -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) + + +class DistributedTokenLoader: + # Each call consumes a contiguous chunk from the shared token stream, then slices out + # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting. + def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank = rank + self.world_size = world_size + self.device = device + self.stream = TokenStream(pattern) + + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + chunk = self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start : start + per_rank_span].to(dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + +# ----------------------------- +# TRANSFORMER MODULES +# ----------------------------- + +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + + +class CastedLinear(nn.Linear): + # Keep weights in fp32 for optimizer/state quality, cast at matmul time for bf16 compute. 
+ def forward(self, x: Tensor) -> Tensor: + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, self.weight.to(x.dtype), bias) + + +class StructuredLinear(nn.Module): + # Lightweight TT-style structured matrix used only in the MLP stack. + def __init__( + self, + in_features: int, + out_features: int, + bias: bool, + linear_type: str, + btt_rank: int, + in_shape_spec: str, + out_shape_spec: str, + btt_init: str, + btt_init_gain: float, + ): + super().__init__() + self.in_features = in_features + self.out_features = out_features + self.linear_type = linear_type + self.rank = max(int(btt_rank), 1) + self.btt_init = btt_init + self.btt_init_gain = btt_init_gain + if linear_type == "dense": + self.impl = CastedLinear(in_features, out_features, bias=bias) + self.register_parameter("core1", None) + self.register_parameter("core2", None) + return + if linear_type != "btt_mlp": + raise ValueError(f"Unsupported STRUCTURED_LINEAR_TYPE={linear_type}") + self.impl = None + self.in_shape = parse_factor_shape(in_shape_spec, in_features) + self.out_shape = parse_factor_shape(out_shape_spec, out_features) + n1, n2 = self.in_shape + m1, m2 = self.out_shape + self.core1 = nn.Parameter(torch.empty(n1, m1, self.rank)) + self.core2 = nn.Parameter(torch.empty(self.rank, n2, m2)) + if bias: + self.bias = nn.Parameter(torch.zeros(out_features)) + else: + self.register_parameter("bias", None) + self.reset_parameters() + + def reset_parameters(self) -> None: + if self.impl is not None: + self.impl.reset_parameters() + return + n1, n2 = self.in_shape + if self.btt_init == "mup": + # muP-inspired scaling: keep the per-rank branch update stable while + # preserving O(1) output variance under the sum over TT ranks. 
+ core1_std = self.btt_init_gain / math.sqrt(max(n1, 1)) + core2_std = self.btt_init_gain / math.sqrt(max(self.rank * n2, 1)) + elif self.btt_init == "xavier": + std = self.btt_init_gain / math.sqrt(max(self.in_features, 1)) + core1_std = std / math.sqrt(self.rank) + core2_std = std + else: + raise ValueError(f"Unsupported BTT_INIT={self.btt_init}") + nn.init.normal_(self.core1, mean=0.0, std=core1_std) + nn.init.normal_(self.core2, mean=0.0, std=core2_std) + if self.bias is not None: + nn.init.zeros_(self.bias) + + def forward(self, x: Tensor) -> Tensor: + if self.impl is not None: + return self.impl(x) + x_dtype = x.dtype + orig_shape = x.shape[:-1] + x2 = x.reshape(-1, self.in_shape[0], self.in_shape[1]) + core1 = self.core1.to(dtype=x_dtype) + core2 = self.core2.to(dtype=x_dtype) + y = x2.new_zeros((x2.shape[0], self.out_shape[0], self.out_shape[1])) + for r in range(self.rank): + left = torch.einsum("buv,um->bmv", x2, core1[:, :, r]) + y = y + torch.einsum("bmv,vn->bmn", left, core2[r]) + y = y.reshape(*orig_shape, self.out_features) + if self.bias is not None: + y = y + self.bias.to(dtype=y.dtype) + return y + + +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + # Keep small/control parameters in fp32 even when the model body runs in bf16. + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() + + +class Rotary(nn.Module): + # Caches cos/sin tables per sequence length on the current device. 
+ def __init__(self, dim: int, base: float = 10000.0): + super().__init__() + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached != seq_len + or self._cos_cached.device != device + ): + t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) + freqs = torch.outer(t, self.inv_freq.to(device)) + self._cos_cached = freqs.cos()[None, None, :, :] + self._sin_cached = freqs.sin()[None, None, :, :] + self._seq_len_cached = seq_len + cos = self._cos_cached.to(dtype=dtype) + sin = self._sin_cached.to(dtype=dtype) + if not torch.is_inference_mode_enabled(): + cos = cos.clone() + sin = sin.clone() + return cos, sin + + +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + + +class CausalSelfAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + kv_dim = self.num_kv_heads * self.head_dim + self.c_q = CastedLinear(dim, dim, bias=False) + self.c_k = CastedLinear(dim, kv_dim, bias=False) + self.c_v = CastedLinear(dim, kv_dim, bias=False) + self.proj 
= CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rotary = Rotary(self.head_dim, base=rope_base) + + def forward(self, x: Tensor) -> Tensor: + bsz, seqlen, dim = x.shape + q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) + k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin) + k = apply_rotary_emb(k, cos, sin) + q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] + y = F.scaled_dot_product_attention( + q, + k, + v, + attn_mask=None, + is_causal=True, + enable_gqa=(self.num_kv_heads != self.num_heads), + ) + y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) + return self.proj(y) + + +class MLP(nn.Module): + # relu^2 MLP from the original modded-nanogpt setup + def __init__( + self, + dim: int, + mlp_mult: int, + structured_mlp: bool, + structured_linear_type: str, + btt_rank: int, + btt_in_shape: str, + btt_out_shape: str, + btt_init: str, + btt_init_gain: float, + compile_structured_mlp: bool, + ): + super().__init__() + hidden = mlp_mult * dim + if structured_mlp: + self.fc = StructuredLinear( + dim, + hidden, + bias=False, + linear_type=structured_linear_type, + btt_rank=btt_rank, + in_shape_spec=btt_in_shape, + out_shape_spec=btt_out_shape, + btt_init=btt_init, + btt_init_gain=btt_init_gain, + ) + self.proj = StructuredLinear( + hidden, + dim, + bias=False, + linear_type=structured_linear_type, + btt_rank=btt_rank, + in_shape_spec=btt_out_shape, + out_shape_spec=btt_in_shape, + btt_init=btt_init, + btt_init_gain=btt_init_gain, + ) + if compile_structured_mlp: + self.fc = torch.compile(self.fc, mode="reduce-overhead", 
fullgraph=False, dynamic=False) + self.proj = torch.compile(self.proj, mode="reduce-overhead", fullgraph=False, dynamic=False) + else: + self.fc = CastedLinear(dim, hidden, bias=False) + self.proj = CastedLinear(hidden, dim, bias=False) + self.proj._zero_init = True + + def forward(self, x: Tensor) -> Tensor: + x = torch.relu(self.fc(x)) + return self.proj(x.square()) + + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + rope_base: float, + qk_gain_init: float, + structured_mlp: bool, + structured_linear_type: str, + btt_rank: int, + btt_in_shape: str, + btt_out_shape: str, + btt_init: str, + btt_init_gain: float, + compile_structured_mlp: bool, + ): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) + self.mlp = MLP( + dim, + mlp_mult, + structured_mlp, + structured_linear_type, + btt_rank, + btt_in_shape, + btt_out_shape, + btt_init, + btt_init_gain, + compile_structured_mlp, + ) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + + def forward(self, x: Tensor, x0: Tensor) -> Tensor: + mix = self.resid_mix.to(dtype=x.dtype) + x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + attn_out = self.attn(self.attn_norm(x)) + x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out + x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) + return x + + +class GPT(nn.Module): + def __init__( + self, + vocab_size: int, + num_layers: int, + model_dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + tie_embeddings: bool, + tied_embed_init_std: float, + logit_softcap: float, + rope_base: float, + qk_gain_init: float, + structured_mlp: bool, + 
structured_linear_type: str, + btt_rank: int, + btt_in_shape: str, + btt_out_shape: str, + btt_init: str, + btt_init_gain: float, + compile_structured_mlp: bool, + rate_target_tensors: str, + rate_temperature: float, + ): + super().__init__() + if logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") + self.tie_embeddings = tie_embeddings + self.tied_embed_init_std = tied_embed_init_std + self.logit_softcap = logit_softcap + self.rate_target_tensors = rate_target_tensors + self.rate_temperature = rate_temperature + self.tok_emb = nn.Embedding(vocab_size, model_dim) + self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + self.blocks = nn.ModuleList( + [ + Block( + model_dim, + num_heads, + num_kv_heads, + mlp_mult, + rope_base, + qk_gain_init, + structured_mlp, + structured_linear_type, + btt_rank, + btt_in_shape, + btt_out_shape, + btt_init, + btt_init_gain, + compile_structured_mlp, + ) + for i in range(num_layers) + ] + ) + self.final_norm = RMSNorm() + self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) + if self.lm_head is not None: + self.lm_head._zero_init = True + self._init_weights() + + def _init_weights(self) -> None: + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + for module in self.modules(): + if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + if isinstance(module, StructuredLinear) and getattr(module, "_zero_init", False): + if module.impl is not None: + nn.init.zeros_(module.impl.weight) + else: + nn.init.zeros_(module.core2) + + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + x = 
self.tok_emb(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x0 = x + skips: list[Tensor] = [] + + # First half stores skips; second half reuses them in reverse order. + for i in range(self.num_encoder_layers): + x = self.blocks[i](x, x0) + skips.append(x) + for i in range(self.num_decoder_layers): + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + x = self.blocks[self.num_encoder_layers + i](x, x0) + + x = self.final_norm(x).reshape(-1, x.size(-1)) + targets = target_ids.reshape(-1) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(x) + logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + return F.cross_entropy(logits.float(), targets, reduction="mean") + + def rate_regularizer(self) -> tuple[Tensor, Tensor, Tensor]: + total_symbols = 0.0 + weighted_rate = None + weighted_scale = None + estimated_bits = None + for name, param in self.named_parameters(): + if not should_rate_regularize_tensor(name, param, self.rate_target_tensors): + continue + rate_proxy, scale_proxy, tensor_bits = estimate_tensor_rate_proxy(param, self.rate_temperature) + weight = float(param.numel()) + total_symbols += weight + if weighted_rate is None: + weighted_rate = rate_proxy * weight + weighted_scale = scale_proxy * weight + estimated_bits = tensor_bits + else: + weighted_rate = weighted_rate + rate_proxy * weight + weighted_scale = weighted_scale + scale_proxy * weight + estimated_bits = estimated_bits + tensor_bits + if weighted_rate is None or weighted_scale is None or estimated_bits is None: + zero = self.tok_emb.weight.new_zeros(()) + return zero, zero, zero + denom = max(total_symbols, 1.0) + return weighted_rate / denom, weighted_scale / denom, estimated_bits + + +# ----------------------------- +# TRAINING +# ----------------------------- + +def main() -> None: 
+ global zeropower_via_newtonschulz5 + + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + if args.enable_compile: + zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) + + # ----------------------------- + # DISTRIBUTED + CUDA SETUP + # ----------------------------- + + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if 8 % world_size != 0: + raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") + grad_accum_steps = 8 // world_size + grad_scale = 1.0 / grad_accum_steps + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 0 + + # Fast math knobs + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + + logfile = None + if master_process: + if args.run_id == "train": + logfile = "train.log" + else: + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" + print(logfile) + + def log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch 
{torch.__version__}", console=False) + log0( + subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, + console=False, + ) + log0("=" * 100, console=False) + + # ----------------------------- + # TOKENIZER + VALIDATION METRIC SETUP + # ----------------------------- + + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) + if args.val_token_limit > 0: + usable = min(((args.val_token_limit) // args.train_seq_len) * args.train_seq_len, val_tokens.numel() - 1) + if usable <= 0: + raise ValueError(f"VAL_TOKEN_LIMIT={args.val_token_limit} is too small for TRAIN_SEQ_LEN={args.train_seq_len}") + val_tokens = val_tokens[: usable + 1].contiguous() + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + if args.val_token_limit > 0: + log0(f"val_token_limit:{args.val_token_limit}") + + # ----------------------------- + # MODEL + OPTIMIZER SETUP + # ----------------------------- + + base_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + 
model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, + structured_mlp=args.structured_mlp, + structured_linear_type=args.structured_linear_type if args.structured_mlp else "dense", + btt_rank=args.btt_rank, + btt_in_shape=args.btt_in_shape, + btt_out_shape=args.btt_out_shape, + btt_init=args.btt_init, + btt_init_gain=args.btt_init_gain, + compile_structured_mlp=args.compile_structured_mlp, + rate_target_tensors=args.rate_target_tensors, + rate_temperature=args.rate_temperature, + ).to(device).bfloat16() + for module in base_model.modules(): + if isinstance(module, (CastedLinear, StructuredLinear)): + module.float() + restore_low_dim_params_to_fp32(base_model) + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) if args.enable_compile else base_model + model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + + # Optimizer split: + # - token embedding (Adam) uses EMBED_LR + # - untied lm_head (Adam) uses HEAD_LR + # - matrix params in transformer blocks use MATRIX_LR via Muon + # - vectors/scalars use SCALAR_LR via Adam + block_named_params = list(base_model.blocks.named_parameters()) + matrix_params = [ + p + for name, p in block_named_params + if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + scalar_params = [ + p + for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + optimizer_tok = torch.optim.Adam( + [{"params": [base_model.tok_emb.weight], "lr": token_lr, 
"base_lr": token_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizer_muon = Muon( + matrix_params, + lr=args.matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + ) + for group in optimizer_muon.param_groups: + group["base_lr"] = args.matrix_lr + optimizer_scalar = torch.optim.Adam( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers.insert(1, optimizer_head) + + n_params = sum(p.numel() for p in base_model.parameters()) + log0(f"model_params:{n_params}") + log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") + log0(f"enable_compile:{args.enable_compile}") + log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") + log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") + log0( + f"quant_scheme:{args.quant_scheme} compressor:{args.compressor} level:{args.compressor_level} " + f"structured_mlp:{args.structured_mlp} structured_linear_type:{args.structured_linear_type} " + f"btt_rank:{args.btt_rank} btt_init:{args.btt_init} compile_structured_mlp:{args.compile_structured_mlp}" + ) + log0( + f"rate_lambda:{args.rate_lambda} scale_lambda:{args.scale_lambda} " + f"rate_warmup_start_fraction:{args.rate_warmup_start_fraction} rate_full_fraction:{args.rate_full_fraction} " + f"rate_target_tensors:{args.rate_target_tensors}" + ) + log0( + f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " + f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " + f"matrix_lr:{args.matrix_lr} 
scalar_lr:{args.scalar_lr}" + ) + log0( + f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " + f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " + f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log0(f"seed:{args.seed}") + + # ----------------------------- + # DATA LOADER & MODEL WARMUP + # ----------------------------- + + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + + def lr_mul(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + + def rate_schedule(step: int) -> float: + if args.rate_lambda <= 0.0 and args.scale_lambda <= 0.0: + return 0.0 + if args.iterations <= 0: + return 1.0 + frac = step / max(args.iterations, 1) + start = args.rate_warmup_start_fraction + full = max(args.rate_full_fraction, start + 1e-6) + if frac <= start: + return 0.0 + if frac >= full: + return 1.0 + return (frac - start) / (full - start) + + # Warmup primes the compiled forward/backward/optimizer paths, then we restore the + # initial weights/optimizer state so measured training starts from the true init. 
+ if args.warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(args.warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + for opt in optimizers: + opt.step() + zero_grad_all() + if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: + log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + if distributed: + model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + # ----------------------------- + # MAIN TRAINING LOOP + # ----------------------------- + + training_time_ms = 0.0 + stop_after_step: int | None = None + torch.cuda.synchronize() + t0 = time.perf_counter() + + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + log0( + 
f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" + ) + break + + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + zero_grad_all() + train_loss = torch.zeros((), device=device) + train_ce_loss = torch.zeros((), device=device) + train_rate_proxy = torch.zeros((), device=device) + train_scale_proxy = torch.zeros((), device=device) + train_est_bits = torch.zeros((), device=device) + rate_mul = rate_schedule(step) + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + ce_loss = model(x, y) + total_loss = ce_loss + if rate_mul > 0.0: + rate_proxy, scale_proxy, est_bits = base_model.rate_regularizer() + total_loss = total_loss + rate_mul * (args.rate_lambda * rate_proxy + args.scale_lambda * scale_proxy) + train_rate_proxy += rate_proxy.detach() + train_scale_proxy += scale_proxy.detach() + train_est_bits += est_bits.detach() + train_ce_loss += ce_loss.detach() + train_loss += total_loss.detach() + (total_loss * grad_scale).backward() + train_loss /= grad_accum_steps + train_ce_loss /= grad_accum_steps + train_rate_proxy /= grad_accum_steps + train_scale_proxy /= grad_accum_steps + train_est_bits /= grad_accum_steps + + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * 
args.muon_momentum
+        for group in optimizer_muon.param_groups:
+            group["momentum"] = muon_momentum
+
+        for opt in optimizers:
+            for group in opt.param_groups:
+                group["lr"] = group["base_lr"] * scale
+
+        if args.grad_clip_norm > 0:
+            torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm)
+        for opt in optimizers:
+            opt.step()
+        zero_grad_all()
+
+        step += 1
+        approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0)
+        should_log_train = (
+            args.train_log_every > 0
+            and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None)
+        )
+        if should_log_train:
+            log0(
+                f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} ce_loss:{train_ce_loss.item():.4f} "
+                f"rate_mul:{rate_mul:.3f} rate_proxy_bits:{train_rate_proxy.item():.4f} "
+                f"scale_proxy:{train_scale_proxy.item():.6f} est_model_bits:{train_est_bits.item():.0f} "
+                f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms"
+            )
+
+        # Needed to sync whether we've reached the wallclock cap.
+        reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms
+        if distributed and max_wallclock_ms is not None:
+            reached_cap_tensor = torch.tensor(int(reached_cap), device=device)
+            dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX)
+            reached_cap = bool(reached_cap_tensor.item())
+        if stop_after_step is None and reached_cap:
+            stop_after_step = step
+
+    log0(
+        f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB "
+        f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB"
+    )
+
+    # -----------------------------
+    # SERIALIZATION + ROUNDTRIP VALIDATION
+    # -----------------------------
+    # Save the raw state (useful for debugging/loading in PyTorch directly), then always produce
+    # the compressed quantized artifact (QUANT_SCHEME+COMPRESSOR, e.g. int5_odd+zlib) and
+    # validate the round-tripped weights.
+ + if master_process: + torch.save(base_model.state_dict(), "final_model.pt") + model_bytes = os.path.getsize("final_model.pt") + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model: {model_bytes} bytes") + log0(f"Code size: {code_bytes} bytes") + log0(f"Total submission size: {model_bytes + code_bytes} bytes") + + quant_obj, quant_stats = quantize_state_dict(base_model.state_dict(), args.quant_scheme) + quant_buf = io.BytesIO() + torch.save(quant_obj, quant_buf) + quant_raw = quant_buf.getvalue() + quant_blob = compress_blob(quant_raw, args.compressor, args.compressor_level) + quant_raw_bytes = len(quant_raw) + quant_filename = f"final_model.{args.quant_scheme}.{args.compressor}.ptc" + if master_process: + with open(quant_filename, "wb") as f: + f.write(quant_blob) + quant_file_bytes = os.path.getsize(quant_filename) + code_bytes = len(code.encode("utf-8")) + ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1) + log0( + f"Serialized model {args.quant_scheme}+{args.compressor}: {quant_file_bytes} bytes " + f"(payload:{quant_stats['int8_payload_bytes']} raw_torch:{quant_raw_bytes} payload_ratio:{ratio:.2f}x)" + ) + log0(f"Total submission size {args.quant_scheme}+{args.compressor}: {quant_file_bytes + code_bytes} bytes") + + if distributed: + dist.barrier() + with open(quant_filename, "rb") as f: + quant_blob_disk = f.read() + quant_state = torch.load(io.BytesIO(decompress_blob(quant_blob_disk, args.compressor)), map_location="cpu") + base_model.load_state_dict(dequantize_state_dict(quant_state), strict=True) + torch.cuda.synchronize() + t_qeval = time.perf_counter() + q_val_loss, q_val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"final_{args.quant_scheme}_{args.compressor}_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + 
f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" + ) + log0(f"final_{args.quant_scheme}_{args.compressor}_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + + if distributed: + dist.destroy_process_group() + + +if __name__ == "__main__": + main() + +==================================================================================================== +Running Python 3.10.12 (main, Mar 3 2026, 11:56:32) [GCC 11.4.0] +Running PyTorch 2.11.0+cu130 +Tue Mar 31 12:05:52 2026 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 590.48.01 Driver Version: 590.48.01 CUDA Version: 13.1 | ++-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. | +|=========================================+========================+======================| +| 0 NVIDIA GeForce RTX 3060 Off | 00000000:01:00.0 On | N/A | +| 38% 45C P0 37W / 170W | 242MiB / 12288MiB | 5% Default | +| | | N/A | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| 0 N/A N/A 985 G /usr/lib/xorg/Xorg 117MiB | +| 0 N/A N/A 3755 C python3 104MiB | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=/home/hz/Desktop/parameter-golf-main/data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:1 +val_loader:shards 
pattern=/home/hz/Desktop/parameter-golf-main/data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:262144 +val_token_limit:262144 +model_params:8507464 +world_size:1 grad_accum_steps:8 +enable_compile:False +sdp_backends:cudnn=False flash=True mem_efficient=False math=False +attention_mode:gqa num_heads:8 num_kv_heads:4 +quant_scheme:int5_odd compressor:zlib level:9 structured_mlp:True structured_linear_type:btt_mlp btt_rank:32 btt_init:mup compile_structured_mlp:False +rate_lambda:0.0002 scale_lambda:0.0002 rate_warmup_start_fraction:0.1 rate_full_fraction:0.4 rate_target_tensors:mlp_only +tie_embeddings:True embed_lr:0.03 head_lr:0.0 matrix_lr:0.02 scalar_lr:0.02 +train_batch_tokens:8192 train_seq_len:256 iterations:12 warmup_steps:0 max_wallclock_seconds:600.000 +seed:1337 +step:0/12 val_loss:6.9348 val_bpb:4.1066 train_time:0ms step_avg:0.02ms +step:1/12 train_loss:6.9402 ce_loss:6.9402 rate_mul:0.000 rate_proxy_bits:0.0000 scale_proxy:0.000000 est_model_bits:0 train_time:4830ms step_avg:4829.51ms +step:2/12 train_loss:13.5286 ce_loss:13.5286 rate_mul:0.000 rate_proxy_bits:0.0000 scale_proxy:0.000000 est_model_bits:0 train_time:9546ms step_avg:4773.13ms +step:3/12 train_loss:12.9434 ce_loss:12.9434 rate_mul:0.222 rate_proxy_bits:1.4980 scale_proxy:0.065182 est_model_bits:1325360 train_time:14813ms step_avg:4937.64ms +step:4/12 train_loss:12.2947 ce_loss:12.2945 rate_mul:0.500 rate_proxy_bits:1.4980 scale_proxy:0.065182 est_model_bits:1325360 train_time:19991ms step_avg:4997.78ms +step:5/12 train_loss:11.5389 ce_loss:11.5386 rate_mul:0.778 rate_proxy_bits:1.4980 scale_proxy:0.065182 est_model_bits:1325360 train_time:25208ms step_avg:5041.65ms +step:6/12 train_loss:10.6648 ce_loss:10.6645 rate_mul:1.000 rate_proxy_bits:1.4980 scale_proxy:0.065182 est_model_bits:1325360 train_time:30397ms step_avg:5066.10ms +step:7/12 train_loss:9.7527 ce_loss:9.7523 rate_mul:1.000 rate_proxy_bits:1.4980 scale_proxy:0.065182 est_model_bits:1325360 train_time:35642ms 
step_avg:5091.73ms
+step:8/12 train_loss:9.0302 ce_loss:9.0299 rate_mul:1.000 rate_proxy_bits:1.4980 scale_proxy:0.065182 est_model_bits:1325360 train_time:40820ms step_avg:5102.45ms
+step:9/12 train_loss:8.4368 ce_loss:8.4365 rate_mul:1.000 rate_proxy_bits:1.4980 scale_proxy:0.065182 est_model_bits:1325360 train_time:46008ms step_avg:5112.03ms
+step:10/12 train_loss:7.7901 ce_loss:7.7898 rate_mul:1.000 rate_proxy_bits:1.4980 scale_proxy:0.065182 est_model_bits:1325360 train_time:51200ms step_avg:5120.03ms
+step:12/12 val_loss:6.7022 val_bpb:3.9689 train_time:61594ms step_avg:5132.85ms
+peak memory allocated: 2103 MiB reserved: 2126 MiB
+Serialized model: 33020979 bytes
+Code size: 68319 bytes
+Total submission size: 33089298 bytes
+Serialized model int5_odd+zlib: 2163132 bytes (payload:2942927 raw_torch:2982627 payload_ratio:11.21x)
+Total submission size int5_odd+zlib: 2231451 bytes
+final_int5_odd_zlib_roundtrip val_loss:6.9639 val_bpb:4.1239 eval_time:12911ms
+final_int5_odd_zlib_roundtrip_exact val_loss:6.96392745 val_bpb:4.12386483
diff --git a/records/track_non_record_16mb/2026-03-31_EntropyAware_Int5Odd_BTTMLP_3060/logs/lambda_0p002.txt b/records/track_non_record_16mb/2026-03-31_EntropyAware_Int5Odd_BTTMLP_3060/logs/lambda_0p002.txt
new file mode 100644
index 000000000..df6cbea9f
--- /dev/null
+++ b/records/track_non_record_16mb/2026-03-31_EntropyAware_Int5Odd_BTTMLP_3060/logs/lambda_0p002.txt
@@ -0,0 +1,1679 @@
+"""
+The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good starting points for new participants, not SOTA configs. We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder.
+
+Hard stop: To keep things readable for newcomers, let's make sure `train_gpt.py` and `train_gpt_mlx.py` are never longer than 1500 lines.
+"""
+
+from __future__ import annotations
+
+import copy
+import glob
+import io
+import math
+import os
+import random
+import subprocess
+import sys
+import time
+import uuid
+import zlib
+from pathlib import Path
+
+import numpy as np
+import sentencepiece as spm
+import torch
+import torch.distributed as dist
+import torch.nn.functional as F
+from torch import Tensor, nn
+from torch.nn.parallel import DistributedDataParallel as DDP
+
+try:
+    import zstandard as zstd
+
+    _HAS_ZSTD = True
+except ImportError:
+    zstd = None
+    _HAS_ZSTD = False
+
+REPO_ROOT = Path(__file__).resolve().parents[3]
+
+# -----------------------------
+# HYPERPARAMETERS
+# -----------------------------
+# Default Simple Baseline run:
+# - 9 transformer blocks at width 512
+# - 8 attention heads with 4 KV heads (GQA) and 2x MLP expansion
+# - vocab size 1024, sequence length 1024, tied embeddings
+# - 524,288 train tokens per step for 20,000 iterations with a ~10 minute cap
+
+class Hyperparameters:
+    # Data paths are shard globs produced by the existing preprocessing pipeline.
+    data_path = os.environ.get("DATA_PATH", str(REPO_ROOT / "data/datasets/fineweb10B_sp1024"))
+    train_files = os.path.join(data_path, "fineweb_train_*.bin")
+    val_files = os.path.join(data_path, "fineweb_val_*.bin")
+    tokenizer_path = os.environ.get("TOKENIZER_PATH", str(REPO_ROOT / "data/tokenizers/fineweb_1024_bpe.model"))
+    run_id = os.environ.get("RUN_ID", "train")
+    seed = int(os.environ.get("SEED", 1337))
+
+    # Validation cadence and batch size. Validation runs over the fineweb_val split,
+    # truncated to VAL_TOKEN_LIMIT tokens when that limit is nonzero.
+    val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 262_144))
+    val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000))
+    train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 10))
+    val_token_limit = int(os.environ.get("VAL_TOKEN_LIMIT", 1_048_576))
+
+    # Training length.
+ iterations = int(os.environ.get("ITERATIONS", 40)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 1200)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 0)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 16_384)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 256)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + enable_compile = bool(int(os.environ.get("ENABLE_COMPILE", "0"))) + + # Model shape. + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers = int(os.environ.get("NUM_LAYERS", 9)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = int(os.environ.get("MLP_MULT", 2)) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + + # Optimizer hyperparameters. + embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.03)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.02)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.02)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.0)) + + # Export / compression. 
+ quant_scheme = os.environ.get("QUANT_SCHEME", "int5_odd") + compressor = os.environ.get("COMPRESSOR", "zlib") + compressor_level = int(os.environ.get("COMPRESSOR_LEVEL", 9)) + + # Entropy-aware rate regularization. + rate_lambda = float(os.environ.get("RATE_LAMBDA", 0.002)) + scale_lambda = float(os.environ.get("SCALE_LAMBDA", 0.0002)) + rate_temperature = float(os.environ.get("RATE_TEMPERATURE", 6.0)) + rate_warmup_start_fraction = float(os.environ.get("RATE_WARMUP_START_FRACTION", 0.1)) + rate_full_fraction = float(os.environ.get("RATE_FULL_FRACTION", 0.4)) + rate_target_tensors = os.environ.get("RATE_TARGET_TENSORS", "mlp_only") + + # Structured MLP. + structured_mlp = bool(int(os.environ.get("STRUCTURED_MLP", "1"))) + structured_linear_type = os.environ.get("STRUCTURED_LINEAR_TYPE", "btt_mlp") + btt_rank = int(os.environ.get("BTT_RANK", 32)) + btt_in_shape = os.environ.get("BTT_IN_SHAPE", "") + btt_out_shape = os.environ.get("BTT_OUT_SHAPE", "") + btt_init = os.environ.get("BTT_INIT", "mup") + btt_init_gain = float(os.environ.get("BTT_INIT_GAIN", 1.0)) + compile_structured_mlp = bool(int(os.environ.get("COMPILE_STRUCTURED_MLP", "1"))) + +# ----------------------------- +# MUON OPTIMIZER +# ----------------------------- +# +# As borrowed from modded-nanogpt +# Background on Muon: https://kellerjordan.github.io/posts/muon/ + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + # Orthogonalize a 2D update matrix with a fast Newton-Schulz iteration. + # Muon uses this to normalize matrix-shaped gradients before applying them. 
+ a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X + + +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True): + super().__init__( + params, + dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov), + ) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + + for group in self.param_groups: + params = group["params"] + if not params: + continue + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + + total_params = sum(int(p.numel()) for p in params) + updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) + + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: + g = g.add(buf, alpha=momentum) + g = zeropower_via_newtonschulz5(g, steps=backend_steps) + # Scale correction from Muon reference implementations. 
+                    g *= max(1, g.size(0) / g.size(1)) ** 0.5
+                    updates_flat[curr : curr + p.numel()] = g.reshape(-1)
+                curr += p.numel()
+
+            if distributed:
+                dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM)
+
+            curr = 0
+            for p in params:
+                g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype)
+                p.add_(g, alpha=-lr)
+                curr += p.numel()
+
+        return loss
+
+
+# -----------------------------
+# TOKENIZER-AGNOSTIC EVALUATION SETUP
+# -----------------------------
+#
+# It's common for small models to have a large fraction of their parameters be embeddings, since the 2 * d_model * d_vocab vectors can be gigantic.
+# Instead of locking the tokenizer, we let you bring your own and calculate our validation metrics based on the average compression of the validation set.
+# We calculate BPB (bits-per-byte) instead of validation loss, so we need methods to count the number of bytes per token in the tokenizer.
+# Note: Submissions that edit the tokenizer will be examined more carefully, since screwing this up might unjustly improve your score.
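The arithmetic behind the BPB metric described above is compact enough to state on its own: convert mean token cross-entropy from nats to bits, then rescale by the validation split's tokens-per-byte ratio. A minimal standalone sketch (the `bits_per_byte` helper is ours, not part of the script):

```python
import math

def bits_per_byte(val_loss_nats: float, token_count: float, byte_count: float) -> float:
    # nats/token -> bits/token, then bits/token * tokens/byte -> bits/byte.
    bits_per_token = val_loss_nats / math.log(2.0)
    return bits_per_token * (token_count / byte_count)

# A loss of ln(2) nats/token on a split averaging one byte per token is exactly 1 bpb.
assert bits_per_byte(math.log(2.0), 1.0, 1.0) == 1.0
```

This is also why a tokenizer with longer pieces can report a lower bpb at the same per-token loss: the tokens-per-byte factor shrinks.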
+ +def build_sentencepiece_luts( + sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device +) -> tuple[Tensor, Tensor, Tensor]: + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("▁"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) + + +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*. 
+ tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] + + +def eval_val( + args: Hyperparameters, + model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + grad_accum_steps: int, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, +) -> tuple[float, float]: + # Validation computes two metrics: + # - val_loss: token cross-entropy (natural log) + # - val_bpb: tokenizer-agnostic compression metric used by the challenge + local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) + if local_batch_tokens < args.train_seq_len: + raise ValueError( + "VAL_BATCH_SIZE must provide at least one sequence per rank; " + f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " + f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}" + ) + local_batch_seqs = local_batch_tokens // args.train_seq_len + total_seqs = (val_tokens.numel() - 1) // args.train_seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + + model.eval() + with torch.inference_mode(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * args.train_seq_len + raw_end = batch_seq_end * args.train_seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, args.train_seq_len) + y = local[1:].reshape(-1, args.train_seq_len) + with 
torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True):
+                batch_loss = model(x, y).detach()
+            batch_token_count = float(y.numel())
+            val_loss_sum += batch_loss.to(torch.float64) * batch_token_count
+            val_token_count += batch_token_count
+            prev_ids = x.reshape(-1)
+            tgt_ids = y.reshape(-1)
+            token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16)
+            token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16)
+            val_byte_count += token_bytes.to(torch.float64).sum()
+
+    if dist.is_available() and dist.is_initialized():
+        dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM)
+        dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM)
+        dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM)
+
+    val_loss = val_loss_sum / val_token_count
+    bits_per_token = val_loss.item() / math.log(2.0)
+    tokens_per_byte = val_token_count.item() / val_byte_count.item()
+    model.train()
+    return float(val_loss.item()), float(bits_per_token * tokens_per_byte)
+
+# -----------------------------
+# POST-TRAINING QUANTIZATION
+# -----------------------------
+#
+# It's silly to export our model, which is trained in bf16 and fp32, at that same precision.
+# Instead, we get approximately the same model (with a small quality hit) by quantizing the weights
+# (int8 or int5_odd, per QUANT_SCHEME) and compressing the result with zlib or zstd.
+# We can then decompress the model and run it in higher precision for evaluation, after coming in under the size limit.
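Before the exporter code, it helps to see the size arithmetic that motivates a 5-level grid: three base-5 digits fit in one byte (5^3 = 125 ≤ 256), so packed weights cost 8/3 ≈ 2.67 bits each before entropy coding, versus 8 bits for int8 and 32 for fp32. A back-of-envelope sketch (the `packed_bytes` helper is ours; the payload of 2,942,927 bytes reported in the run log is a bit larger than the packed weights alone because it also carries per-row scales and passthrough tensors):

```python
def packed_bytes(n_weights: int) -> int:
    # Three codes in {0..4} per byte: v0 + 5*v1 + 25*v2 <= 124, i.e. ceil(n / 3) bytes.
    return (n_weights + 2) // 3

n = 8_507_464  # model_params reported in the recovered run log
print(f"fp32: {4 * n:,} bytes")
print(f"int8: {n:,} bytes")
print(f"int5_odd packed: {packed_bytes(n):,} bytes (~{8 * packed_bytes(n) / n:.2f} bits/weight)")
```

zlib then squeezes this payload further; the run log reports a 2,163,132-byte artifact for this model.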
+ +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights", + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 +INT5_ODD_LEVELS = (-2, -1, 0, 1, 2) + +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) + + +def parse_factor_shape(spec: str, dim: int) -> tuple[int, int]: + if spec: + vals = tuple(int(v.strip()) for v in spec.split(",") if v.strip()) + if len(vals) != 2 or vals[0] * vals[1] != dim: + raise ValueError(f"Invalid factor shape '{spec}' for dim={dim}") + return vals + for a in range(int(math.isqrt(dim)), 0, -1): + if dim % a == 0: + return a, dim // a + return 1, dim + +def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: + if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): + return t.float().contiguous() + if t.dtype in {torch.float32, torch.bfloat16}: + passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") + return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() + return t + +def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + # Matrices get one scale per row, which usually tracks output-channel + # ranges much better than a single tensor-wide scale. 
+ clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) + q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() + return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + + # Vectors / scalars use a simpler per-tensor scale. + clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() + return q, scale + + +def pack_base5_codes(q: Tensor) -> tuple[bytes, int]: + vals = q.detach().to("cpu", dtype=torch.uint8).reshape(-1).numpy() + n_vals = int(vals.size) + pad = (3 - (n_vals % 3)) % 3 + if pad: + vals = np.pad(vals, (0, pad)) + groups = vals.reshape(-1, 3).astype(np.uint8) + packed = groups[:, 0] + 5 * groups[:, 1] + 25 * groups[:, 2] + return packed.tobytes(), n_vals + + +def unpack_base5_codes(data: bytes, n_vals: int) -> Tensor: + packed = np.frombuffer(data, dtype=np.uint8).astype(np.int16) + out = np.zeros((packed.size, 3), dtype=np.uint8) + for i in range(3): + out[:, i] = packed % 5 + packed //= 5 + flat = out.reshape(-1)[:n_vals] + return torch.from_numpy(flat.astype(np.uint8, copy=False)) + + +def quantize_float_tensor_int5_odd(t: Tensor) -> tuple[Tensor, Tensor]: + if t.ndim != 2: + raise ValueError("int5_odd quantization only supports 2D tensors") + t32 = t.float() + clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clip_abs = clip_abs.clamp_min(1e-6) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 
2.0).clamp_min(1.0 / 2.0e6) + q = torch.clamp(torch.round(clipped / scale[:, None]), -2, 2).to(torch.int8).contiguous() + return (q + 2).to(torch.uint8).contiguous(), scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + +def quantize_state_dict_int8(state_dict: dict[str, Tensor]): + # Single supported clean-script export format: + # - per-row int8 for 2D float tensors + # - per-tensor int8 for other float tensors + # - exact passthrough for non-floats + # - passthrough for small float tensors, stored as fp16 to save bytes + quantized: dict[str, Tensor] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + qmeta: dict[str, dict[str, object]] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + + # Small float tensors are cheap enough to keep directly. We still downcast + # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size. 
+ if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + + stats["num_float_tensors"] += 1 + q, s = quantize_float_tensor(t) + if s.ndim > 0: + qmeta[name] = {"scheme": "per_row", "axis": 0} + quantized[name] = q + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) + + obj: dict[str, object] = { + "__quant_format__": "int8_clean_per_row_v1", + "quantized": quantized, + "scales": scales, + "dtypes": dtypes, + "passthrough": passthrough, + } + if qmeta: + obj["qmeta"] = qmeta + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats + + +def quantize_state_dict_int5_odd(state_dict: dict[str, Tensor]): + packed: dict[str, bytes] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + shapes: dict[str, tuple[int, ...]] = {} + orig_shapes: dict[str, tuple[int, ...]] = {} + n_codes: dict[str, int] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + + quantize_this = t.ndim >= 2 and (t.numel() > INT8_KEEP_FLOAT_MAX_NUMEL or ".mlp." 
in name) + if not quantize_this: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + + stats["num_float_tensors"] += 1 + t2 = t.reshape(t.shape[0], -1) if t.ndim > 2 else t + q, s = quantize_float_tensor_int5_odd(t2) + packed_bytes, count = pack_base5_codes(q) + packed[name] = packed_bytes + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + shapes[name] = tuple(int(v) for v in t2.shape) + orig_shapes[name] = tuple(int(v) for v in t.shape) + n_codes[name] = count + stats["int8_payload_bytes"] += len(packed_bytes) + tensor_nbytes(s) + + obj: dict[str, object] = { + "__quant_format__": "int5_odd_clean_per_row_v1", + "packed": packed, + "scales": scales, + "dtypes": dtypes, + "shapes": shapes, + "orig_shapes": orig_shapes, + "n_codes": n_codes, + "passthrough": passthrough, + } + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats + +def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + qmeta = obj.get("qmeta", {}) + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, q in obj["quantized"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + s = obj["scales"][name] + if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: + s = s.to(dtype=torch.float32) + # Broadcast the saved row scale back across trailing dimensions. + out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() + else: + scale = float(s.item()) + out[name] = (q.float() * scale).to(dtype=dtype).contiguous() + for name, t in obj["passthrough"].items(): + # Restore small tensors, undoing the temporary fp16 storage cast if needed. 
+ out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out + + +def dequantize_state_dict_int5_odd(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, packed in obj["packed"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + shape = tuple(int(v) for v in obj["shapes"][name]) + orig_shape = tuple(int(v) for v in obj.get("orig_shapes", {}).get(name, shape)) + q = unpack_base5_codes(packed, int(obj["n_codes"][name])).to(torch.float32).reshape(shape) - 2.0 + s = obj["scales"][name].to(dtype=torch.float32) + deq = (q * s.view(shape[0], *([1] * (len(shape) - 1)))).to(dtype=dtype) + out[name] = deq.reshape(orig_shape).contiguous() + for name, t in obj["passthrough"].items(): + out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out + + +def quantize_state_dict(state_dict: dict[str, Tensor], quant_scheme: str): + if quant_scheme == "int8": + return quantize_state_dict_int8(state_dict) + if quant_scheme == "int5_odd": + return quantize_state_dict_int5_odd(state_dict) + raise ValueError(f"Unsupported QUANT_SCHEME={quant_scheme}") + + +def dequantize_state_dict(obj: dict[str, object]) -> dict[str, Tensor]: + quant_format = obj.get("__quant_format__") + if quant_format == "int8_clean_per_row_v1": + return dequantize_state_dict_int8(obj) + if quant_format == "int5_odd_clean_per_row_v1": + return dequantize_state_dict_int5_odd(obj) + raise ValueError(f"Unsupported quant format {quant_format}") + + +def compress_blob(data: bytes, compressor: str, level: int) -> bytes: + if compressor == "zlib": + return zlib.compress(data, level=max(0, 
min(level, 9)))
+    if compressor == "zstd":
+        if not _HAS_ZSTD:
+            raise RuntimeError("COMPRESSOR=zstd requires zstandard to be installed")
+        return zstd.ZstdCompressor(level=level).compress(data)
+    raise ValueError(f"Unsupported COMPRESSOR={compressor}")
+
+
+def decompress_blob(data: bytes, compressor: str) -> bytes:
+    if compressor == "zlib":
+        return zlib.decompress(data)
+    if compressor == "zstd":
+        if not _HAS_ZSTD:
+            raise RuntimeError("COMPRESSOR=zstd requires zstandard to be installed")
+        return zstd.ZstdDecompressor().decompress(data)
+    raise ValueError(f"Unsupported COMPRESSOR={compressor}")
+
+
+def should_rate_regularize_tensor(name: str, p: Tensor, target: str) -> bool:
+    if not p.requires_grad or not p.is_floating_point() or p.ndim < 2:
+        return False
+    if any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS):
+        return False
+    if target == "mlp_only":
+        return ".mlp." in name
+    if target == "all_matrices":
+        return p.numel() > INT8_KEEP_FLOAT_MAX_NUMEL or ".mlp." in name
+    raise ValueError(f"Unsupported RATE_TARGET_TENSORS={target}")
+
+
+def estimate_tensor_rate_proxy(t: Tensor, temperature: float) -> tuple[Tensor, Tensor, Tensor]:
+    t2 = t.float()
+    if t2.ndim > 2:
+        t2 = t2.reshape(t2.shape[0], -1)
+    elif t2.ndim == 1:
+        t2 = t2.reshape(1, -1)
+    row_scale = t2.abs().mean(dim=1, keepdim=True).clamp_min(1e-6)
+    normalized = torch.clamp(t2 / row_scale, -2.5, 2.5)
+    centers = t2.new_tensor(INT5_ODD_LEVELS)
+    logits = -temperature * (normalized.unsqueeze(-1) - centers) ** 2
+    probs = torch.softmax(logits, dim=-1)
+    hist = probs.mean(dim=(0, 1))
+    entropy_nats = -(hist * (hist + 1e-8).log()).sum()
+    entropy_bits = entropy_nats / math.log(2.0)
+    scale_proxy = row_scale.mean()
+    est_bits = entropy_bits * float(t2.numel())
+    return entropy_bits, scale_proxy, est_bits
+
+
+# -----------------------------
+# DATA LOADING
+# -----------------------------
+
+def load_data_shard(file: Path) -> Tensor:
+    header_bytes = 256 * np.dtype("<i4").itemsize
+    with file.open("rb") as f:
+        header = np.frombuffer(f.read(header_bytes), dtype="<i4")
+        num_tokens = int(header[2])
+        tokens = np.frombuffer(f.read(), dtype="<u2", count=num_tokens)
+    return torch.from_numpy(tokens.astype(np.int32, copy=False))
+
+
+class TokenStream:
+    # Streams tokens sequentially across sorted shard files, cycling when a shard is exhausted.
+    def __init__(self, pattern: str):
+        self.files = [Path(p) for p in sorted(glob.glob(pattern))]
+        if not self.files:
+            raise FileNotFoundError(f"No files found for pattern: {pattern}")
+        self.file_idx = 0
+        self.tokens = load_data_shard(self.files[self.file_idx])
+        self.pos = 0
+
+    def _advance_file(self) -> None:
+
self.file_idx = (self.file_idx + 1) % len(self.files) + self.tokens = load_data_shard(self.files[self.file_idx]) + self.pos = 0 + + def take(self, n: int) -> Tensor: + chunks: list[Tensor] = [] + remaining = n + while remaining > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: + self._advance_file() + continue + k = min(remaining, avail) + chunks.append(self.tokens[self.pos : self.pos + k]) + self.pos += k + remaining -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) + + +class DistributedTokenLoader: + # Each call consumes a contiguous chunk from the shared token stream, then slices out + # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting. + def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank = rank + self.world_size = world_size + self.device = device + self.stream = TokenStream(pattern) + + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + chunk = self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start : start + per_rank_span].to(dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + +# ----------------------------- +# TRANSFORMER MODULES +# ----------------------------- + +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + + +class CastedLinear(nn.Linear): + # Keep weights in fp32 for optimizer/state quality, cast at matmul time for bf16 compute. 
+ def forward(self, x: Tensor) -> Tensor: + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, self.weight.to(x.dtype), bias) + + +class StructuredLinear(nn.Module): + # Lightweight TT-style structured matrix used only in the MLP stack. + def __init__( + self, + in_features: int, + out_features: int, + bias: bool, + linear_type: str, + btt_rank: int, + in_shape_spec: str, + out_shape_spec: str, + btt_init: str, + btt_init_gain: float, + ): + super().__init__() + self.in_features = in_features + self.out_features = out_features + self.linear_type = linear_type + self.rank = max(int(btt_rank), 1) + self.btt_init = btt_init + self.btt_init_gain = btt_init_gain + if linear_type == "dense": + self.impl = CastedLinear(in_features, out_features, bias=bias) + self.register_parameter("core1", None) + self.register_parameter("core2", None) + return + if linear_type != "btt_mlp": + raise ValueError(f"Unsupported STRUCTURED_LINEAR_TYPE={linear_type}") + self.impl = None + self.in_shape = parse_factor_shape(in_shape_spec, in_features) + self.out_shape = parse_factor_shape(out_shape_spec, out_features) + n1, n2 = self.in_shape + m1, m2 = self.out_shape + self.core1 = nn.Parameter(torch.empty(n1, m1, self.rank)) + self.core2 = nn.Parameter(torch.empty(self.rank, n2, m2)) + if bias: + self.bias = nn.Parameter(torch.zeros(out_features)) + else: + self.register_parameter("bias", None) + self.reset_parameters() + + def reset_parameters(self) -> None: + if self.impl is not None: + self.impl.reset_parameters() + return + n1, n2 = self.in_shape + if self.btt_init == "mup": + # muP-inspired scaling: keep the per-rank branch update stable while + # preserving O(1) output variance under the sum over TT ranks. 
+ core1_std = self.btt_init_gain / math.sqrt(max(n1, 1)) + core2_std = self.btt_init_gain / math.sqrt(max(self.rank * n2, 1)) + elif self.btt_init == "xavier": + std = self.btt_init_gain / math.sqrt(max(self.in_features, 1)) + core1_std = std / math.sqrt(self.rank) + core2_std = std + else: + raise ValueError(f"Unsupported BTT_INIT={self.btt_init}") + nn.init.normal_(self.core1, mean=0.0, std=core1_std) + nn.init.normal_(self.core2, mean=0.0, std=core2_std) + if self.bias is not None: + nn.init.zeros_(self.bias) + + def forward(self, x: Tensor) -> Tensor: + if self.impl is not None: + return self.impl(x) + x_dtype = x.dtype + orig_shape = x.shape[:-1] + x2 = x.reshape(-1, self.in_shape[0], self.in_shape[1]) + core1 = self.core1.to(dtype=x_dtype) + core2 = self.core2.to(dtype=x_dtype) + y = x2.new_zeros((x2.shape[0], self.out_shape[0], self.out_shape[1])) + for r in range(self.rank): + left = torch.einsum("buv,um->bmv", x2, core1[:, :, r]) + y = y + torch.einsum("bmv,vn->bmn", left, core2[r]) + y = y.reshape(*orig_shape, self.out_features) + if self.bias is not None: + y = y + self.bias.to(dtype=y.dtype) + return y + + +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + # Keep small/control parameters in fp32 even when the model body runs in bf16. + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() + + +class Rotary(nn.Module): + # Caches cos/sin tables per sequence length on the current device. 
+ def __init__(self, dim: int, base: float = 10000.0): + super().__init__() + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached != seq_len + or self._cos_cached.device != device + ): + t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) + freqs = torch.outer(t, self.inv_freq.to(device)) + self._cos_cached = freqs.cos()[None, None, :, :] + self._sin_cached = freqs.sin()[None, None, :, :] + self._seq_len_cached = seq_len + cos = self._cos_cached.to(dtype=dtype) + sin = self._sin_cached.to(dtype=dtype) + if not torch.is_inference_mode_enabled(): + cos = cos.clone() + sin = sin.clone() + return cos, sin + + +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + + +class CausalSelfAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + kv_dim = self.num_kv_heads * self.head_dim + self.c_q = CastedLinear(dim, dim, bias=False) + self.c_k = CastedLinear(dim, kv_dim, bias=False) + self.c_v = CastedLinear(dim, kv_dim, bias=False) + self.proj 
= CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rotary = Rotary(self.head_dim, base=rope_base) + + def forward(self, x: Tensor) -> Tensor: + bsz, seqlen, dim = x.shape + q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) + k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin) + k = apply_rotary_emb(k, cos, sin) + q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] + y = F.scaled_dot_product_attention( + q, + k, + v, + attn_mask=None, + is_causal=True, + enable_gqa=(self.num_kv_heads != self.num_heads), + ) + y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) + return self.proj(y) + + +class MLP(nn.Module): + # relu^2 MLP from the original modded-nanogpt setup + def __init__( + self, + dim: int, + mlp_mult: int, + structured_mlp: bool, + structured_linear_type: str, + btt_rank: int, + btt_in_shape: str, + btt_out_shape: str, + btt_init: str, + btt_init_gain: float, + compile_structured_mlp: bool, + ): + super().__init__() + hidden = mlp_mult * dim + if structured_mlp: + self.fc = StructuredLinear( + dim, + hidden, + bias=False, + linear_type=structured_linear_type, + btt_rank=btt_rank, + in_shape_spec=btt_in_shape, + out_shape_spec=btt_out_shape, + btt_init=btt_init, + btt_init_gain=btt_init_gain, + ) + self.proj = StructuredLinear( + hidden, + dim, + bias=False, + linear_type=structured_linear_type, + btt_rank=btt_rank, + in_shape_spec=btt_out_shape, + out_shape_spec=btt_in_shape, + btt_init=btt_init, + btt_init_gain=btt_init_gain, + ) + if compile_structured_mlp: + self.fc = torch.compile(self.fc, mode="reduce-overhead", 
fullgraph=False, dynamic=False) + self.proj = torch.compile(self.proj, mode="reduce-overhead", fullgraph=False, dynamic=False) + else: + self.fc = CastedLinear(dim, hidden, bias=False) + self.proj = CastedLinear(hidden, dim, bias=False) + self.proj._zero_init = True + + def forward(self, x: Tensor) -> Tensor: + x = torch.relu(self.fc(x)) + return self.proj(x.square()) + + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + rope_base: float, + qk_gain_init: float, + structured_mlp: bool, + structured_linear_type: str, + btt_rank: int, + btt_in_shape: str, + btt_out_shape: str, + btt_init: str, + btt_init_gain: float, + compile_structured_mlp: bool, + ): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) + self.mlp = MLP( + dim, + mlp_mult, + structured_mlp, + structured_linear_type, + btt_rank, + btt_in_shape, + btt_out_shape, + btt_init, + btt_init_gain, + compile_structured_mlp, + ) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + + def forward(self, x: Tensor, x0: Tensor) -> Tensor: + mix = self.resid_mix.to(dtype=x.dtype) + x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + attn_out = self.attn(self.attn_norm(x)) + x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out + x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) + return x + + +class GPT(nn.Module): + def __init__( + self, + vocab_size: int, + num_layers: int, + model_dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + tie_embeddings: bool, + tied_embed_init_std: float, + logit_softcap: float, + rope_base: float, + qk_gain_init: float, + structured_mlp: bool, + 
structured_linear_type: str, + btt_rank: int, + btt_in_shape: str, + btt_out_shape: str, + btt_init: str, + btt_init_gain: float, + compile_structured_mlp: bool, + rate_target_tensors: str, + rate_temperature: float, + ): + super().__init__() + if logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") + self.tie_embeddings = tie_embeddings + self.tied_embed_init_std = tied_embed_init_std + self.logit_softcap = logit_softcap + self.rate_target_tensors = rate_target_tensors + self.rate_temperature = rate_temperature + self.tok_emb = nn.Embedding(vocab_size, model_dim) + self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + self.blocks = nn.ModuleList( + [ + Block( + model_dim, + num_heads, + num_kv_heads, + mlp_mult, + rope_base, + qk_gain_init, + structured_mlp, + structured_linear_type, + btt_rank, + btt_in_shape, + btt_out_shape, + btt_init, + btt_init_gain, + compile_structured_mlp, + ) + for i in range(num_layers) + ] + ) + self.final_norm = RMSNorm() + self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) + if self.lm_head is not None: + self.lm_head._zero_init = True + self._init_weights() + + def _init_weights(self) -> None: + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + for module in self.modules(): + if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + if isinstance(module, StructuredLinear) and getattr(module, "_zero_init", False): + if module.impl is not None: + nn.init.zeros_(module.impl.weight) + else: + nn.init.zeros_(module.core2) + + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + x = 
self.tok_emb(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x0 = x + skips: list[Tensor] = [] + + # First half stores skips; second half reuses them in reverse order. + for i in range(self.num_encoder_layers): + x = self.blocks[i](x, x0) + skips.append(x) + for i in range(self.num_decoder_layers): + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + x = self.blocks[self.num_encoder_layers + i](x, x0) + + x = self.final_norm(x).reshape(-1, x.size(-1)) + targets = target_ids.reshape(-1) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(x) + logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + return F.cross_entropy(logits.float(), targets, reduction="mean") + + def rate_regularizer(self) -> tuple[Tensor, Tensor, Tensor]: + total_symbols = 0.0 + weighted_rate = None + weighted_scale = None + estimated_bits = None + for name, param in self.named_parameters(): + if not should_rate_regularize_tensor(name, param, self.rate_target_tensors): + continue + rate_proxy, scale_proxy, tensor_bits = estimate_tensor_rate_proxy(param, self.rate_temperature) + weight = float(param.numel()) + total_symbols += weight + if weighted_rate is None: + weighted_rate = rate_proxy * weight + weighted_scale = scale_proxy * weight + estimated_bits = tensor_bits + else: + weighted_rate = weighted_rate + rate_proxy * weight + weighted_scale = weighted_scale + scale_proxy * weight + estimated_bits = estimated_bits + tensor_bits + if weighted_rate is None or weighted_scale is None or estimated_bits is None: + zero = self.tok_emb.weight.new_zeros(()) + return zero, zero, zero + denom = max(total_symbols, 1.0) + return weighted_rate / denom, weighted_scale / denom, estimated_bits + + +# ----------------------------- +# TRAINING +# ----------------------------- + +def main() -> None: 
+ global zeropower_via_newtonschulz5 + + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + if args.enable_compile: + zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) + + # ----------------------------- + # DISTRIBUTED + CUDA SETUP + # ----------------------------- + + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if 8 % world_size != 0: + raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") + grad_accum_steps = 8 // world_size + grad_scale = 1.0 / grad_accum_steps + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 0 + + # Fast math knobs + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + + logfile = None + if master_process: + if args.run_id == "train": + logfile = "train.log" + else: + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" + print(logfile) + + def log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch 
{torch.__version__}", console=False) + log0( + subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, + console=False, + ) + log0("=" * 100, console=False) + + # ----------------------------- + # TOKENIZER + VALIDATION METRIC SETUP + # ----------------------------- + + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) + if args.val_token_limit > 0: + usable = min(((args.val_token_limit) // args.train_seq_len) * args.train_seq_len, val_tokens.numel() - 1) + if usable <= 0: + raise ValueError(f"VAL_TOKEN_LIMIT={args.val_token_limit} is too small for TRAIN_SEQ_LEN={args.train_seq_len}") + val_tokens = val_tokens[: usable + 1].contiguous() + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + if args.val_token_limit > 0: + log0(f"val_token_limit:{args.val_token_limit}") + + # ----------------------------- + # MODEL + OPTIMIZER SETUP + # ----------------------------- + + base_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + 
model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, + structured_mlp=args.structured_mlp, + structured_linear_type=args.structured_linear_type if args.structured_mlp else "dense", + btt_rank=args.btt_rank, + btt_in_shape=args.btt_in_shape, + btt_out_shape=args.btt_out_shape, + btt_init=args.btt_init, + btt_init_gain=args.btt_init_gain, + compile_structured_mlp=args.compile_structured_mlp, + rate_target_tensors=args.rate_target_tensors, + rate_temperature=args.rate_temperature, + ).to(device).bfloat16() + for module in base_model.modules(): + if isinstance(module, (CastedLinear, StructuredLinear)): + module.float() + restore_low_dim_params_to_fp32(base_model) + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) if args.enable_compile else base_model + model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + + # Optimizer split: + # - token embedding (Adam) uses EMBED_LR + # - untied lm_head (Adam) uses HEAD_LR + # - matrix params in transformer blocks use MATRIX_LR via Muon + # - vectors/scalars use SCALAR_LR via Adam + block_named_params = list(base_model.blocks.named_parameters()) + matrix_params = [ + p + for name, p in block_named_params + if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + scalar_params = [ + p + for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + optimizer_tok = torch.optim.Adam( + [{"params": [base_model.tok_emb.weight], "lr": token_lr, 
"base_lr": token_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizer_muon = Muon( + matrix_params, + lr=args.matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + ) + for group in optimizer_muon.param_groups: + group["base_lr"] = args.matrix_lr + optimizer_scalar = torch.optim.Adam( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers.insert(1, optimizer_head) + + n_params = sum(p.numel() for p in base_model.parameters()) + log0(f"model_params:{n_params}") + log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") + log0(f"enable_compile:{args.enable_compile}") + log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") + log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") + log0( + f"quant_scheme:{args.quant_scheme} compressor:{args.compressor} level:{args.compressor_level} " + f"structured_mlp:{args.structured_mlp} structured_linear_type:{args.structured_linear_type} " + f"btt_rank:{args.btt_rank} btt_init:{args.btt_init} compile_structured_mlp:{args.compile_structured_mlp}" + ) + log0( + f"rate_lambda:{args.rate_lambda} scale_lambda:{args.scale_lambda} " + f"rate_warmup_start_fraction:{args.rate_warmup_start_fraction} rate_full_fraction:{args.rate_full_fraction} " + f"rate_target_tensors:{args.rate_target_tensors}" + ) + log0( + f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " + f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " + f"matrix_lr:{args.matrix_lr} 
scalar_lr:{args.scalar_lr}" + ) + log0( + f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " + f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " + f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log0(f"seed:{args.seed}") + + # ----------------------------- + # DATA LOADER & MODEL WARMUP + # ----------------------------- + + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + + def lr_mul(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + + def rate_schedule(step: int) -> float: + if args.rate_lambda <= 0.0 and args.scale_lambda <= 0.0: + return 0.0 + if args.iterations <= 0: + return 1.0 + frac = step / max(args.iterations, 1) + start = args.rate_warmup_start_fraction + full = max(args.rate_full_fraction, start + 1e-6) + if frac <= start: + return 0.0 + if frac >= full: + return 1.0 + return (frac - start) / (full - start) + + # Warmup primes the compiled forward/backward/optimizer paths, then we restore the + # initial weights/optimizer state so measured training starts from the true init. 
+ if args.warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(args.warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + for opt in optimizers: + opt.step() + zero_grad_all() + if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: + log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + if distributed: + model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + # ----------------------------- + # MAIN TRAINING LOOP + # ----------------------------- + + training_time_ms = 0.0 + stop_after_step: int | None = None + torch.cuda.synchronize() + t0 = time.perf_counter() + + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + log0( + 
f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" + ) + break + + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + zero_grad_all() + train_loss = torch.zeros((), device=device) + train_ce_loss = torch.zeros((), device=device) + train_rate_proxy = torch.zeros((), device=device) + train_scale_proxy = torch.zeros((), device=device) + train_est_bits = torch.zeros((), device=device) + rate_mul = rate_schedule(step) + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + ce_loss = model(x, y) + total_loss = ce_loss + if rate_mul > 0.0: + rate_proxy, scale_proxy, est_bits = base_model.rate_regularizer() + total_loss = total_loss + rate_mul * (args.rate_lambda * rate_proxy + args.scale_lambda * scale_proxy) + train_rate_proxy += rate_proxy.detach() + train_scale_proxy += scale_proxy.detach() + train_est_bits += est_bits.detach() + train_ce_loss += ce_loss.detach() + train_loss += total_loss.detach() + (total_loss * grad_scale).backward() + train_loss /= grad_accum_steps + train_ce_loss /= grad_accum_steps + train_rate_proxy /= grad_accum_steps + train_scale_proxy /= grad_accum_steps + train_est_bits /= grad_accum_steps + + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * 
args.muon_momentum + for group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + for opt in optimizers: + opt.step() + zero_grad_all() + + step += 1 + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + should_log_train = ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + log0( + f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} ce_loss:{train_ce_loss.item():.4f} " + f"rate_mul:{rate_mul:.3f} rate_proxy_bits:{train_rate_proxy.item():.4f} " + f"scale_proxy:{train_scale_proxy.item():.6f} est_model_bits:{train_est_bits.item():.0f} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" + ) + + # Needed to sync whether we've reached the wallclock cap. + reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + + log0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + ) + + # ----------------------------- + # SERIALIZATION + ROUNDTRIP VALIDATION + # ----------------------------- + # Save the raw state (useful for debugging/loading in PyTorch directly), then always produce + # the compressed quantized artifact (QUANT_SCHEME + COMPRESSOR, int5_odd+zlib in this run) and validate the round-tripped weights. 
+ + if master_process: + torch.save(base_model.state_dict(), "final_model.pt") + model_bytes = os.path.getsize("final_model.pt") + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model: {model_bytes} bytes") + log0(f"Code size: {code_bytes} bytes") + log0(f"Total submission size: {model_bytes + code_bytes} bytes") + + quant_obj, quant_stats = quantize_state_dict(base_model.state_dict(), args.quant_scheme) + quant_buf = io.BytesIO() + torch.save(quant_obj, quant_buf) + quant_raw = quant_buf.getvalue() + quant_blob = compress_blob(quant_raw, args.compressor, args.compressor_level) + quant_raw_bytes = len(quant_raw) + quant_filename = f"final_model.{args.quant_scheme}.{args.compressor}.ptc" + if master_process: + with open(quant_filename, "wb") as f: + f.write(quant_blob) + quant_file_bytes = os.path.getsize(quant_filename) + code_bytes = len(code.encode("utf-8")) + ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1) + log0( + f"Serialized model {args.quant_scheme}+{args.compressor}: {quant_file_bytes} bytes " + f"(payload:{quant_stats['int8_payload_bytes']} raw_torch:{quant_raw_bytes} payload_ratio:{ratio:.2f}x)" + ) + log0(f"Total submission size {args.quant_scheme}+{args.compressor}: {quant_file_bytes + code_bytes} bytes") + + if distributed: + dist.barrier() + with open(quant_filename, "rb") as f: + quant_blob_disk = f.read() + quant_state = torch.load(io.BytesIO(decompress_blob(quant_blob_disk, args.compressor)), map_location="cpu") + base_model.load_state_dict(dequantize_state_dict(quant_state), strict=True) + torch.cuda.synchronize() + t_qeval = time.perf_counter() + q_val_loss, q_val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"final_{args.quant_scheme}_{args.compressor}_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + 
f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" + ) + log0(f"final_{args.quant_scheme}_{args.compressor}_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + + if distributed: + dist.destroy_process_group() + + +if __name__ == "__main__": + main() + +==================================================================================================== +Running Python 3.10.12 (main, Mar 3 2026, 11:56:32) [GCC 11.4.0] +Running PyTorch 2.11.0+cu130 +Tue Mar 31 12:04:08 2026 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 590.48.01 Driver Version: 590.48.01 CUDA Version: 13.1 | ++-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. | +|=========================================+========================+======================| +| 0 NVIDIA GeForce RTX 3060 Off | 00000000:01:00.0 On | N/A | +| 43% 50C P0 42W / 170W | 250MiB / 12288MiB | 6% Default | +| | | N/A | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| 0 N/A N/A 985 G /usr/lib/xorg/Xorg 125MiB | +| 0 N/A N/A 3734 C python3 104MiB | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=/home/hz/Desktop/parameter-golf-main/data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:1 +val_loader:shards 
pattern=/home/hz/Desktop/parameter-golf-main/data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:262144 +val_token_limit:262144 +model_params:8507464 +world_size:1 grad_accum_steps:8 +enable_compile:False +sdp_backends:cudnn=False flash=True mem_efficient=False math=False +attention_mode:gqa num_heads:8 num_kv_heads:4 +quant_scheme:int5_odd compressor:zlib level:9 structured_mlp:True structured_linear_type:btt_mlp btt_rank:32 btt_init:mup compile_structured_mlp:False +rate_lambda:0.002 scale_lambda:0.0002 rate_warmup_start_fraction:0.1 rate_full_fraction:0.4 rate_target_tensors:mlp_only +tie_embeddings:True embed_lr:0.03 head_lr:0.0 matrix_lr:0.02 scalar_lr:0.02 +train_batch_tokens:8192 train_seq_len:256 iterations:12 warmup_steps:0 max_wallclock_seconds:600.000 +seed:1337 +step:0/12 val_loss:6.9348 val_bpb:4.1066 train_time:0ms step_avg:0.02ms +step:1/12 train_loss:6.9402 ce_loss:6.9402 rate_mul:0.000 rate_proxy_bits:0.0000 scale_proxy:0.000000 est_model_bits:0 train_time:4810ms step_avg:4810.14ms +step:2/12 train_loss:13.5286 ce_loss:13.5286 rate_mul:0.000 rate_proxy_bits:0.0000 scale_proxy:0.000000 est_model_bits:0 train_time:9460ms step_avg:4730.19ms +step:3/12 train_loss:12.9416 ce_loss:12.9409 rate_mul:0.222 rate_proxy_bits:1.4980 scale_proxy:0.065182 est_model_bits:1325360 train_time:14588ms step_avg:4862.83ms +step:4/12 train_loss:12.2879 ce_loss:12.2864 rate_mul:0.500 rate_proxy_bits:1.4980 scale_proxy:0.065182 est_model_bits:1325360 train_time:19655ms step_avg:4913.77ms +step:5/12 train_loss:11.5199 ce_loss:11.5176 rate_mul:0.778 rate_proxy_bits:1.4980 scale_proxy:0.065182 est_model_bits:1325360 train_time:24717ms step_avg:4943.38ms +step:6/12 train_loss:10.6343 ce_loss:10.6313 rate_mul:1.000 rate_proxy_bits:1.4980 scale_proxy:0.065182 est_model_bits:1325360 train_time:29795ms step_avg:4965.83ms +step:7/12 train_loss:9.7079 ce_loss:9.7049 rate_mul:1.000 rate_proxy_bits:1.4980 scale_proxy:0.065182 est_model_bits:1325360 train_time:34837ms 
step_avg:4976.78ms
+step:8/12 train_loss:8.9768 ce_loss:8.9738 rate_mul:1.000 rate_proxy_bits:1.4980 scale_proxy:0.065182 est_model_bits:1325360 train_time:39922ms step_avg:4990.27ms
+step:9/12 train_loss:8.3782 ce_loss:8.3752 rate_mul:1.000 rate_proxy_bits:1.4980 scale_proxy:0.065182 est_model_bits:1325360 train_time:44986ms step_avg:4998.45ms
+step:10/12 train_loss:7.7292 ce_loss:7.7262 rate_mul:1.000 rate_proxy_bits:1.4980 scale_proxy:0.065182 est_model_bits:1325360 train_time:50016ms step_avg:5001.58ms
+step:12/12 val_loss:6.6548 val_bpb:3.9408 train_time:60139ms step_avg:5011.55ms
+peak memory allocated: 2103 MiB reserved: 2126 MiB
+Serialized model: 33020979 bytes
+Code size: 68319 bytes
+Total submission size: 33089298 bytes
+Serialized model int5_odd+zlib: 2162016 bytes (payload:2942927 raw_torch:2982627 payload_ratio:11.21x)
+Total submission size int5_odd+zlib: 2230335 bytes
+final_int5_odd_zlib_roundtrip val_loss:6.9093 val_bpb:4.0915 eval_time:13125ms
+final_int5_odd_zlib_roundtrip_exact val_loss:6.90930092 val_bpb:4.09151636
diff --git a/records/track_non_record_16mb/2026-03-31_EntropyAware_Int5Odd_BTTMLP_3060/logs/lambda_high_40.txt b/records/track_non_record_16mb/2026-03-31_EntropyAware_Int5Odd_BTTMLP_3060/logs/lambda_high_40.txt
new file mode 100644
index 000000000..0e2387463
--- /dev/null
+++ b/records/track_non_record_16mb/2026-03-31_EntropyAware_Int5Odd_BTTMLP_3060/logs/lambda_high_40.txt
@@ -0,0 +1,1706 @@
+"""
+The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good launching-off points for new participants, not SOTA configs. We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder.
+
+Hard stop: to keep things readable for newcomers, let's make sure `train_gpt.py` and `train_gpt_mlx.py` never grow longer than 1500 lines.
+""" + +from __future__ import annotations + +import copy +import glob +import io +import math +import os +import random +import subprocess +import sys +import time +import uuid +import zlib +from pathlib import Path + +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP + +try: + import zstandard as zstd + + _HAS_ZSTD = True +except ImportError: + zstd = None + _HAS_ZSTD = False + +REPO_ROOT = Path(__file__).resolve().parents[3] + +# ----------------------------- +# HYPERPARAMETERS +# ----------------------------- +# Default Simple Baseline run: +# - 9 transformer blocks at width 512 +# - 8 attention heads with 4 KV heads (GQA) and 2x MLP expansion +# - vocab size 1024, sequence length 1024, tied embeddings +# - 524,288 train tokens per step for 20,000 iterations with a ~10 minute cap + +class Hyperparameters: + # Data paths are shard globs produced by the existing preprocessing pipeline. + data_path = os.environ.get("DATA_PATH", str(REPO_ROOT / "data/datasets/fineweb10B_sp1024")) + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", str(REPO_ROOT / "data/tokenizers/fineweb_1024_bpe.model")) + run_id = os.environ.get("RUN_ID", "train") + seed = int(os.environ.get("SEED", 1337)) + + # Validation cadence and batch size. Validation always uses the full fineweb_val split. + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 262_144)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 10)) + val_token_limit = int(os.environ.get("VAL_TOKEN_LIMIT", 1_048_576)) + + # Training length. 
+ iterations = int(os.environ.get("ITERATIONS", 40)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 1200)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 1)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 16_384)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 256)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + enable_compile = bool(int(os.environ.get("ENABLE_COMPILE", "0"))) + + # Model shape. + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers = int(os.environ.get("NUM_LAYERS", 9)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = int(os.environ.get("MLP_MULT", 2)) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + + # Optimizer hyperparameters. + embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.03)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.02)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.02)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.0)) + + # Export / compression. 
+ quant_scheme = os.environ.get("QUANT_SCHEME", "int5_odd") + compressor = os.environ.get("COMPRESSOR", "zlib") + compressor_level = int(os.environ.get("COMPRESSOR_LEVEL", 9)) + + # Entropy-aware rate regularization. + rate_lambda = float(os.environ.get("RATE_LAMBDA", 0.00002)) + scale_lambda = float(os.environ.get("SCALE_LAMBDA", 0.0002)) + rate_temperature = float(os.environ.get("RATE_TEMPERATURE", 6.0)) + rate_warmup_start_fraction = float(os.environ.get("RATE_WARMUP_START_FRACTION", 0.1)) + rate_full_fraction = float(os.environ.get("RATE_FULL_FRACTION", 0.4)) + rate_target_tensors = os.environ.get("RATE_TARGET_TENSORS", "mlp_only") + + # Structured MLP. + structured_mlp = bool(int(os.environ.get("STRUCTURED_MLP", "1"))) + structured_linear_type = os.environ.get("STRUCTURED_LINEAR_TYPE", "btt_mlp") + btt_rank = int(os.environ.get("BTT_RANK", 32)) + btt_in_shape = os.environ.get("BTT_IN_SHAPE", "") + btt_out_shape = os.environ.get("BTT_OUT_SHAPE", "") + btt_init = os.environ.get("BTT_INIT", "mup") + btt_init_gain = float(os.environ.get("BTT_INIT_GAIN", 1.0)) + compile_structured_mlp = bool(int(os.environ.get("COMPILE_STRUCTURED_MLP", "1"))) + +# ----------------------------- +# MUON OPTIMIZER +# ----------------------------- +# +# As borrowed from modded-nanogpt +# Background on Muon: https://kellerjordan.github.io/posts/muon/ + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + # Orthogonalize a 2D update matrix with a fast Newton-Schulz iteration. + # Muon uses this to normalize matrix-shaped gradients before applying them. 
+ a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X + + +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True): + super().__init__( + params, + dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov), + ) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + + for group in self.param_groups: + params = group["params"] + if not params: + continue + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + + total_params = sum(int(p.numel()) for p in params) + updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) + + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: + g = g.add(buf, alpha=momentum) + g = zeropower_via_newtonschulz5(g, steps=backend_steps) + # Scale correction from Muon reference implementations. 
+                    g *= max(1, g.size(0) / g.size(1)) ** 0.5
+                    updates_flat[curr : curr + p.numel()] = g.reshape(-1)
+                curr += p.numel()
+
+            if distributed:
+                dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM)
+
+            curr = 0
+            for p in params:
+                g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype)
+                p.add_(g, alpha=-lr)
+                curr += p.numel()
+
+        return loss
+
+
+# -----------------------------
+# TOKENIZER-AGNOSTIC EVALUATION SETUP
+# -----------------------------
+#
+# It's common for small models to have a large fraction of their parameters in embeddings, since the 2 * d_model * d_vocab vectors can be gigantic.
+# Instead of locking the tokenizer, we let you bring your own and calculate our validation metrics from the average compression of the validation set.
+# We calculate BPB (bits-per-byte) instead of validation loss, so we need methods to count the number of bytes per token in the tokenizer.
+# Note: Submissions that edit the tokenizer will be examined more carefully, since screwing this up might unjustly improve your score.
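The nats-to-BPB conversion described in the comments above can be sketched in a few lines. This is an illustrative helper with made-up numbers, not the harness's own `eval_val` path:

```python
import math

def bits_per_byte(mean_token_loss_nats: float, token_count: int, byte_count: int) -> float:
    # bits/token = nats / ln(2); multiplying by tokens/byte gives bits/byte,
    # which is tokenizer-agnostic because byte_count comes from the raw text.
    bits_per_token = mean_token_loss_nats / math.log(2.0)
    return bits_per_token * (token_count / byte_count)

# Illustrative: a loss of ln(2) nats is exactly 1 bit/token; at 4 bytes per
# token that compresses to 0.25 bits per byte.
print(bits_per_byte(math.log(2.0), token_count=1, byte_count=4))  # 0.25
```

A tokenizer with longer tokens (more bytes per token) therefore earns a lower BPB at the same per-token loss, which is exactly the trade-off the metric is designed to score.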
+ +def build_sentencepiece_luts( + sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device +) -> tuple[Tensor, Tensor, Tensor]: + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("▁"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) + + +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*. 
+ tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] + + +def eval_val( + args: Hyperparameters, + model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + grad_accum_steps: int, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, +) -> tuple[float, float]: + # Validation computes two metrics: + # - val_loss: token cross-entropy (natural log) + # - val_bpb: tokenizer-agnostic compression metric used by the challenge + local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) + if local_batch_tokens < args.train_seq_len: + raise ValueError( + "VAL_BATCH_SIZE must provide at least one sequence per rank; " + f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " + f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}" + ) + local_batch_seqs = local_batch_tokens // args.train_seq_len + total_seqs = (val_tokens.numel() - 1) // args.train_seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + + model.eval() + with torch.inference_mode(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * args.train_seq_len + raw_end = batch_seq_end * args.train_seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, args.train_seq_len) + y = local[1:].reshape(-1, args.train_seq_len) + 
maybe_mark_compile_step_begin(args.enable_compile or args.compile_structured_mlp)
+            with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True):
+                batch_loss = model(x, y).detach()
+            batch_token_count = float(y.numel())
+            val_loss_sum += batch_loss.to(torch.float64) * batch_token_count
+            val_token_count += batch_token_count
+            prev_ids = x.reshape(-1)
+            tgt_ids = y.reshape(-1)
+            token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16)
+            token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16)
+            val_byte_count += token_bytes.to(torch.float64).sum()
+
+    if dist.is_available() and dist.is_initialized():
+        dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM)
+        dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM)
+        dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM)
+
+    val_loss = val_loss_sum / val_token_count
+    bits_per_token = val_loss.item() / math.log(2.0)
+    tokens_per_byte = val_token_count.item() / val_byte_count.item()
+    model.train()
+    return float(val_loss.item()), float(bits_per_token * tokens_per_byte)
+
+# -----------------------------
+# POST-TRAINING QUANTIZATION
+# -----------------------------
+#
+# It's silly to export our model, which is trained in bf16 and fp32, at that same precision.
+# Instead, we get approximately the same model (with a small accuracy hit) by quantizing it to int8 and zlib-compressing the result.
+# We can then decompress the model and run it in higher precision for evaluation, after coming in under the size limit.
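The per-row int8 idea can be illustrated standalone. This is a simplified NumPy sketch, not the export code in this file: it uses a plain abs-max scale per row rather than the percentile clipping used below, and random weights purely for demonstration.

```python
import zlib

import numpy as np

rng = np.random.default_rng(0)
w = rng.standard_normal((4, 256)).astype(np.float32)

# One scale per row (output channel): symmetric abs-max int8.
scale = np.abs(w).max(axis=1, keepdims=True) / 127.0
q = np.clip(np.round(w / scale), -127, 127).astype(np.int8)

# zlib round trip is lossless, so the int8 codes survive exactly.
restored = np.frombuffer(
    zlib.decompress(zlib.compress(q.tobytes(), 9)), dtype=np.int8
).reshape(w.shape)
deq = restored.astype(np.float32) * scale

assert np.array_equal(restored, q)
# Without clipping, rounding error is at most half a quantization step per row.
assert np.abs(w - deq).max() <= (scale / 2.0).max() + 1e-7
```

The error bound is what makes the "small hit" claim concrete: the worst-case per-weight error shrinks linearly with the row's dynamic range divided by 254.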
+ +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights", + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 +INT5_ODD_LEVELS = (-2, -1, 0, 1, 2) + +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) + + +def maybe_mark_compile_step_begin(enabled: bool) -> None: + if not enabled: + return + mark_step = getattr(torch.compiler, "cudagraph_mark_step_begin", None) + if mark_step is not None: + mark_step() + + +def parse_factor_shape(spec: str, dim: int) -> tuple[int, int]: + if spec: + vals = tuple(int(v.strip()) for v in spec.split(",") if v.strip()) + if len(vals) != 2 or vals[0] * vals[1] != dim: + raise ValueError(f"Invalid factor shape '{spec}' for dim={dim}") + return vals + for a in range(int(math.isqrt(dim)), 0, -1): + if dim % a == 0: + return a, dim // a + return 1, dim + +def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: + if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): + return t.float().contiguous() + if t.dtype in {torch.float32, torch.bfloat16}: + passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") + return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() + return t + +def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + # Matrices get one scale per row, which usually tracks output-channel + # ranges much better than a single tensor-wide scale. 
+ clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) + q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() + return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + + # Vectors / scalars use a simpler per-tensor scale. + clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() + return q, scale + + +def pack_base5_codes(q: Tensor) -> tuple[bytes, int]: + vals = q.detach().to("cpu", dtype=torch.uint8).reshape(-1).numpy() + n_vals = int(vals.size) + pad = (3 - (n_vals % 3)) % 3 + if pad: + vals = np.pad(vals, (0, pad)) + groups = vals.reshape(-1, 3).astype(np.uint8) + packed = groups[:, 0] + 5 * groups[:, 1] + 25 * groups[:, 2] + return packed.tobytes(), n_vals + + +def unpack_base5_codes(data: bytes, n_vals: int) -> Tensor: + packed = np.frombuffer(data, dtype=np.uint8).astype(np.int16) + out = np.zeros((packed.size, 3), dtype=np.uint8) + for i in range(3): + out[:, i] = packed % 5 + packed //= 5 + flat = out.reshape(-1)[:n_vals] + return torch.from_numpy(flat.astype(np.uint8, copy=False)) + + +def quantize_float_tensor_int5_odd(t: Tensor) -> tuple[Tensor, Tensor]: + if t.ndim != 2: + raise ValueError("int5_odd quantization only supports 2D tensors") + t32 = t.float() + clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clip_abs = clip_abs.clamp_min(1e-6) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 
2.0).clamp_min(1.0 / 2.0e6) + q = torch.clamp(torch.round(clipped / scale[:, None]), -2, 2).to(torch.int8).contiguous() + return (q + 2).to(torch.uint8).contiguous(), scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + +def quantize_state_dict_int8(state_dict: dict[str, Tensor]): + # Single supported clean-script export format: + # - per-row int8 for 2D float tensors + # - per-tensor int8 for other float tensors + # - exact passthrough for non-floats + # - passthrough for small float tensors, stored as fp16 to save bytes + quantized: dict[str, Tensor] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + qmeta: dict[str, dict[str, object]] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + + # Small float tensors are cheap enough to keep directly. We still downcast + # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size. 
+ if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + + stats["num_float_tensors"] += 1 + q, s = quantize_float_tensor(t) + if s.ndim > 0: + qmeta[name] = {"scheme": "per_row", "axis": 0} + quantized[name] = q + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) + + obj: dict[str, object] = { + "__quant_format__": "int8_clean_per_row_v1", + "quantized": quantized, + "scales": scales, + "dtypes": dtypes, + "passthrough": passthrough, + } + if qmeta: + obj["qmeta"] = qmeta + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats + + +def quantize_state_dict_int5_odd(state_dict: dict[str, Tensor]): + packed: dict[str, bytes] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + shapes: dict[str, tuple[int, ...]] = {} + orig_shapes: dict[str, tuple[int, ...]] = {} + n_codes: dict[str, int] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + + quantize_this = t.ndim >= 2 and (t.numel() > INT8_KEEP_FLOAT_MAX_NUMEL or ".mlp." 
in name) + if not quantize_this: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + + stats["num_float_tensors"] += 1 + t2 = t.reshape(t.shape[0], -1) if t.ndim > 2 else t + q, s = quantize_float_tensor_int5_odd(t2) + packed_bytes, count = pack_base5_codes(q) + packed[name] = packed_bytes + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + shapes[name] = tuple(int(v) for v in t2.shape) + orig_shapes[name] = tuple(int(v) for v in t.shape) + n_codes[name] = count + stats["int8_payload_bytes"] += len(packed_bytes) + tensor_nbytes(s) + + obj: dict[str, object] = { + "__quant_format__": "int5_odd_clean_per_row_v1", + "packed": packed, + "scales": scales, + "dtypes": dtypes, + "shapes": shapes, + "orig_shapes": orig_shapes, + "n_codes": n_codes, + "passthrough": passthrough, + } + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats + +def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + qmeta = obj.get("qmeta", {}) + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, q in obj["quantized"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + s = obj["scales"][name] + if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: + s = s.to(dtype=torch.float32) + # Broadcast the saved row scale back across trailing dimensions. + out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() + else: + scale = float(s.item()) + out[name] = (q.float() * scale).to(dtype=dtype).contiguous() + for name, t in obj["passthrough"].items(): + # Restore small tensors, undoing the temporary fp16 storage cast if needed. 
+ out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out + + +def dequantize_state_dict_int5_odd(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, packed in obj["packed"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + shape = tuple(int(v) for v in obj["shapes"][name]) + orig_shape = tuple(int(v) for v in obj.get("orig_shapes", {}).get(name, shape)) + q = unpack_base5_codes(packed, int(obj["n_codes"][name])).to(torch.float32).reshape(shape) - 2.0 + s = obj["scales"][name].to(dtype=torch.float32) + deq = (q * s.view(shape[0], *([1] * (len(shape) - 1)))).to(dtype=dtype) + out[name] = deq.reshape(orig_shape).contiguous() + for name, t in obj["passthrough"].items(): + out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out + + +def quantize_state_dict(state_dict: dict[str, Tensor], quant_scheme: str): + if quant_scheme == "int8": + return quantize_state_dict_int8(state_dict) + if quant_scheme == "int5_odd": + return quantize_state_dict_int5_odd(state_dict) + raise ValueError(f"Unsupported QUANT_SCHEME={quant_scheme}") + + +def dequantize_state_dict(obj: dict[str, object]) -> dict[str, Tensor]: + quant_format = obj.get("__quant_format__") + if quant_format == "int8_clean_per_row_v1": + return dequantize_state_dict_int8(obj) + if quant_format == "int5_odd_clean_per_row_v1": + return dequantize_state_dict_int5_odd(obj) + raise ValueError(f"Unsupported quant format {quant_format}") + + +def compress_blob(data: bytes, compressor: str, level: int) -> bytes: + if compressor == "zlib": + return zlib.compress(data, level=max(0, 
min(level, 9))) + if compressor == "zstd": + if not _HAS_ZSTD: + raise RuntimeError("COMPRESSOR=zstd requires zstandard to be installed") + return zstd.ZstdCompressor(level=level).compress(data) + raise ValueError(f"Unsupported COMPRESSOR={compressor}") + + +def decompress_blob(data: bytes, compressor: str) -> bytes: + if compressor == "zlib": + return zlib.decompress(data) + if compressor == "zstd": + if not _HAS_ZSTD: + raise RuntimeError("COMPRESSOR=zstd requires zstandard to be installed") + return zstd.ZstdDecompressor().decompress(data) + raise ValueError(f"Unsupported COMPRESSOR={compressor}") + + +def should_rate_regularize_tensor(name: str, p: Tensor, target: str) -> bool: + if not p.requires_grad or not p.is_floating_point() or p.ndim < 2: + return False + if any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS): + return False + if target == "mlp_only": + return ".mlp." in name + if target == "all_matrices": + return p.numel() > INT8_KEEP_FLOAT_MAX_NUMEL or ".mlp." in name + raise ValueError(f"Unsupported RATE_TARGET_TENSORS={target}") + + +def estimate_tensor_rate_proxy(t: Tensor, temperature: float) -> tuple[Tensor, Tensor, Tensor]: + t2 = t.float() + if t2.ndim > 2: + t2 = t2.reshape(t2.shape[0], -1) + elif t2.ndim == 1: + t2 = t2.reshape(1, -1) + row_scale = t2.abs().mean(dim=1, keepdim=True).clamp_min(1e-6) + normalized = torch.clamp(t2 / row_scale, -2.5, 2.5) + centers = t2.new_tensor(INT5_ODD_LEVELS) + logits = -temperature * (normalized.unsqueeze(-1) - centers) ** 2 + probs = torch.softmax(logits, dim=-1) + hist = probs.mean(dim=(0, 1)) + entropy_nats = -(hist * (hist + 1e-8).log()).sum() + entropy_bits = entropy_nats / math.log(2.0) + scale_proxy = row_scale.mean() + est_bits = entropy_bits * float(t2.numel()) + return entropy_bits, scale_proxy, est_bits + + +# ----------------------------- +# DATA LOADING +# ----------------------------- + +def load_data_shard(file: Path) -> Tensor: + header_bytes = 256 * np.dtype(" None: + 
self.file_idx = (self.file_idx + 1) % len(self.files) + self.tokens = load_data_shard(self.files[self.file_idx]) + self.pos = 0 + + def take(self, n: int) -> Tensor: + chunks: list[Tensor] = [] + remaining = n + while remaining > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: + self._advance_file() + continue + k = min(remaining, avail) + chunks.append(self.tokens[self.pos : self.pos + k]) + self.pos += k + remaining -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) + + +class DistributedTokenLoader: + # Each call consumes a contiguous chunk from the shared token stream, then slices out + # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting. + def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank = rank + self.world_size = world_size + self.device = device + self.stream = TokenStream(pattern) + + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + chunk = self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start : start + per_rank_span].to(dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + +# ----------------------------- +# TRANSFORMER MODULES +# ----------------------------- + +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + + +class CastedLinear(nn.Linear): + # Keep weights in fp32 for optimizer/state quality, cast at matmul time for bf16 compute. 
+ def forward(self, x: Tensor) -> Tensor: + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, self.weight.to(x.dtype), bias) + + +class StructuredLinear(nn.Module): + # Lightweight TT-style structured matrix used only in the MLP stack. + def __init__( + self, + in_features: int, + out_features: int, + bias: bool, + linear_type: str, + btt_rank: int, + in_shape_spec: str, + out_shape_spec: str, + btt_init: str, + btt_init_gain: float, + ): + super().__init__() + self.in_features = in_features + self.out_features = out_features + self.linear_type = linear_type + self.rank = max(int(btt_rank), 1) + self.btt_init = btt_init + self.btt_init_gain = btt_init_gain + if linear_type == "dense": + self.impl = CastedLinear(in_features, out_features, bias=bias) + self.register_parameter("core1", None) + self.register_parameter("core2", None) + return + if linear_type != "btt_mlp": + raise ValueError(f"Unsupported STRUCTURED_LINEAR_TYPE={linear_type}") + self.impl = None + self.in_shape = parse_factor_shape(in_shape_spec, in_features) + self.out_shape = parse_factor_shape(out_shape_spec, out_features) + n1, n2 = self.in_shape + m1, m2 = self.out_shape + self.core1 = nn.Parameter(torch.empty(n1, m1, self.rank)) + self.core2 = nn.Parameter(torch.empty(self.rank, n2, m2)) + if bias: + self.bias = nn.Parameter(torch.zeros(out_features)) + else: + self.register_parameter("bias", None) + self.reset_parameters() + + def reset_parameters(self) -> None: + if self.impl is not None: + self.impl.reset_parameters() + return + n1, n2 = self.in_shape + if self.btt_init == "mup": + # muP-inspired scaling: keep the per-rank branch update stable while + # preserving O(1) output variance under the sum over TT ranks. 
+ core1_std = self.btt_init_gain / math.sqrt(max(n1, 1)) + core2_std = self.btt_init_gain / math.sqrt(max(self.rank * n2, 1)) + elif self.btt_init == "xavier": + std = self.btt_init_gain / math.sqrt(max(self.in_features, 1)) + core1_std = std / math.sqrt(self.rank) + core2_std = std + else: + raise ValueError(f"Unsupported BTT_INIT={self.btt_init}") + nn.init.normal_(self.core1, mean=0.0, std=core1_std) + nn.init.normal_(self.core2, mean=0.0, std=core2_std) + if self.bias is not None: + nn.init.zeros_(self.bias) + + def forward(self, x: Tensor) -> Tensor: + if self.impl is not None: + return self.impl(x) + x_dtype = x.dtype + orig_shape = x.shape[:-1] + x2 = x.reshape(-1, self.in_shape[0], self.in_shape[1]) + core1 = self.core1.to(dtype=x_dtype) + core2 = self.core2.to(dtype=x_dtype) + y = x2.new_zeros((x2.shape[0], self.out_shape[0], self.out_shape[1])) + for r in range(self.rank): + left = torch.einsum("buv,um->bmv", x2, core1[:, :, r]) + y = y + torch.einsum("bmv,vn->bmn", left, core2[r]) + y = y.reshape(*orig_shape, self.out_features) + if self.bias is not None: + y = y + self.bias.to(dtype=y.dtype) + return y + + +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + # Keep small/control parameters in fp32 even when the model body runs in bf16. + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() + + +class Rotary(nn.Module): + # Caches cos/sin tables per sequence length on the current device. 
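The per-rank einsum loop in the structured forward above admits the evaluation-time materialization mentioned in the README: collapse the two cores into one dense `(n1*n2, m1*m2)` matrix and run a standard linear. This equivalence sketch uses hypothetical helper names, not the script's materialization code:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)


def materialize_btt(core1: torch.Tensor, core2: torch.Tensor) -> torch.Tensor:
    # Collapse both cores into a dense weight:
    # W[(u, v), (m, n)] = sum_r core1[u, m, r] * core2[r, v, n]
    n1, m1, _ = core1.shape
    _, n2, m2 = core2.shape
    return torch.einsum("umr,rvn->uvmn", core1, core2).reshape(n1 * n2, m1 * m2)


def btt_forward(x2: torch.Tensor, core1: torch.Tensor, core2: torch.Tensor) -> torch.Tensor:
    # Reference rank loop, mirroring the structured forward above.
    y = torch.zeros(x2.shape[0], core1.shape[1], core2.shape[2])
    for r in range(core1.shape[-1]):
        left = torch.einsum("buv,um->bmv", x2, core1[:, :, r])
        y = y + torch.einsum("bmv,vn->bmn", left, core2[r])
    return y


n1, n2, m1, m2, rank = 4, 4, 8, 8, 3
core1, core2 = torch.randn(n1, m1, rank), torch.randn(rank, n2, m2)
x = torch.randn(5, n1 * n2)
dense_out = F.linear(x, materialize_btt(core1, core2).T)  # standard dense path
loop_out = btt_forward(x.reshape(5, n1, n2), core1, core2).reshape(5, -1)
```

The two paths agree to float32 precision, which is why validation can swap the rank loop for `F.linear` without changing the model.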
+ def __init__(self, dim: int, base: float = 10000.0): + super().__init__() + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached != seq_len + or self._cos_cached.device != device + ): + t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) + freqs = torch.outer(t, self.inv_freq.to(device)) + self._cos_cached = freqs.cos()[None, None, :, :] + self._sin_cached = freqs.sin()[None, None, :, :] + self._seq_len_cached = seq_len + cos = self._cos_cached.to(dtype=dtype) + sin = self._sin_cached.to(dtype=dtype) + if not torch.is_inference_mode_enabled(): + cos = cos.clone() + sin = sin.clone() + return cos, sin + + +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + + +class CausalSelfAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + kv_dim = self.num_kv_heads * self.head_dim + self.c_q = CastedLinear(dim, dim, bias=False) + self.c_k = CastedLinear(dim, kv_dim, bias=False) + self.c_v = CastedLinear(dim, kv_dim, bias=False) + self.proj 
= CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rotary = Rotary(self.head_dim, base=rope_base) + + def forward(self, x: Tensor) -> Tensor: + bsz, seqlen, dim = x.shape + q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) + k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin) + k = apply_rotary_emb(k, cos, sin) + q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] + y = F.scaled_dot_product_attention( + q, + k, + v, + attn_mask=None, + is_causal=True, + enable_gqa=(self.num_kv_heads != self.num_heads), + ) + y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) + return self.proj(y) + + +class MLP(nn.Module): + # relu^2 MLP from the original modded-nanogpt setup + def __init__( + self, + dim: int, + mlp_mult: int, + structured_mlp: bool, + structured_linear_type: str, + btt_rank: int, + btt_in_shape: str, + btt_out_shape: str, + btt_init: str, + btt_init_gain: float, + compile_structured_mlp: bool, + ): + super().__init__() + hidden = mlp_mult * dim + if structured_mlp: + self.fc = StructuredLinear( + dim, + hidden, + bias=False, + linear_type=structured_linear_type, + btt_rank=btt_rank, + in_shape_spec=btt_in_shape, + out_shape_spec=btt_out_shape, + btt_init=btt_init, + btt_init_gain=btt_init_gain, + ) + self.proj = StructuredLinear( + hidden, + dim, + bias=False, + linear_type=structured_linear_type, + btt_rank=btt_rank, + in_shape_spec=btt_out_shape, + out_shape_spec=btt_in_shape, + btt_init=btt_init, + btt_init_gain=btt_init_gain, + ) + if compile_structured_mlp: + compile_options = 
dict(torch._inductor.list_mode_options("reduce-overhead")) + compile_options["triton.cudagraphs"] = False + self.fc = torch.compile( + self.fc, + options=compile_options, + fullgraph=False, + dynamic=False, + ) + self.proj = torch.compile( + self.proj, + options=compile_options, + fullgraph=False, + dynamic=False, + ) + else: + self.fc = CastedLinear(dim, hidden, bias=False) + self.proj = CastedLinear(hidden, dim, bias=False) + self.proj._zero_init = True + + def forward(self, x: Tensor) -> Tensor: + x = torch.relu(self.fc(x)) + return self.proj(x.square()) + + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + rope_base: float, + qk_gain_init: float, + structured_mlp: bool, + structured_linear_type: str, + btt_rank: int, + btt_in_shape: str, + btt_out_shape: str, + btt_init: str, + btt_init_gain: float, + compile_structured_mlp: bool, + ): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) + self.mlp = MLP( + dim, + mlp_mult, + structured_mlp, + structured_linear_type, + btt_rank, + btt_in_shape, + btt_out_shape, + btt_init, + btt_init_gain, + compile_structured_mlp, + ) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + + def forward(self, x: Tensor, x0: Tensor) -> Tensor: + mix = self.resid_mix.to(dtype=x.dtype) + x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + attn_out = self.attn(self.attn_norm(x)) + x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out + x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) + return x + + +class GPT(nn.Module): + def __init__( + self, + vocab_size: int, + num_layers: int, + model_dim: int, + num_heads: int, + 
num_kv_heads: int, + mlp_mult: int, + tie_embeddings: bool, + tied_embed_init_std: float, + logit_softcap: float, + rope_base: float, + qk_gain_init: float, + structured_mlp: bool, + structured_linear_type: str, + btt_rank: int, + btt_in_shape: str, + btt_out_shape: str, + btt_init: str, + btt_init_gain: float, + compile_structured_mlp: bool, + rate_target_tensors: str, + rate_temperature: float, + ): + super().__init__() + if logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") + self.tie_embeddings = tie_embeddings + self.tied_embed_init_std = tied_embed_init_std + self.logit_softcap = logit_softcap + self.rate_target_tensors = rate_target_tensors + self.rate_temperature = rate_temperature + self.tok_emb = nn.Embedding(vocab_size, model_dim) + self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + self.blocks = nn.ModuleList( + [ + Block( + model_dim, + num_heads, + num_kv_heads, + mlp_mult, + rope_base, + qk_gain_init, + structured_mlp, + structured_linear_type, + btt_rank, + btt_in_shape, + btt_out_shape, + btt_init, + btt_init_gain, + compile_structured_mlp, + ) + for i in range(num_layers) + ] + ) + self.final_norm = RMSNorm() + self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) + if self.lm_head is not None: + self.lm_head._zero_init = True + self._init_weights() + + def _init_weights(self) -> None: + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + for module in self.modules(): + if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + if isinstance(module, StructuredLinear) and getattr(module, "_zero_init", False): + if module.impl is 
not None: + nn.init.zeros_(module.impl.weight) + else: + nn.init.zeros_(module.core2) + + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + x = self.tok_emb(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x0 = x + skips: list[Tensor] = [] + + # First half stores skips; second half reuses them in reverse order. + for i in range(self.num_encoder_layers): + x = self.blocks[i](x, x0) + skips.append(x) + for i in range(self.num_decoder_layers): + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + x = self.blocks[self.num_encoder_layers + i](x, x0) + + x = self.final_norm(x).reshape(-1, x.size(-1)) + targets = target_ids.reshape(-1) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(x) + logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + return F.cross_entropy(logits.float(), targets, reduction="mean") + + def rate_regularizer(self) -> tuple[Tensor, Tensor, Tensor]: + total_symbols = 0.0 + weighted_rate = None + weighted_scale = None + estimated_bits = None + for name, param in self.named_parameters(): + if not should_rate_regularize_tensor(name, param, self.rate_target_tensors): + continue + rate_proxy, scale_proxy, tensor_bits = estimate_tensor_rate_proxy(param, self.rate_temperature) + weight = float(param.numel()) + total_symbols += weight + if weighted_rate is None: + weighted_rate = rate_proxy * weight + weighted_scale = scale_proxy * weight + estimated_bits = tensor_bits + else: + weighted_rate = weighted_rate + rate_proxy * weight + weighted_scale = weighted_scale + scale_proxy * weight + estimated_bits = estimated_bits + tensor_bits + if weighted_rate is None or weighted_scale is None or estimated_bits is None: + zero = self.tok_emb.weight.new_zeros(()) + return zero, zero, zero + denom = max(total_symbols, 1.0) + return 
weighted_rate / denom, weighted_scale / denom, estimated_bits + + +# ----------------------------- +# TRAINING +# ----------------------------- + +def main() -> None: + global zeropower_via_newtonschulz5 + + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + if args.enable_compile: + zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) + + # ----------------------------- + # DISTRIBUTED + CUDA SETUP + # ----------------------------- + + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if 8 % world_size != 0: + raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") + grad_accum_steps = 8 // world_size + grad_scale = 1.0 / grad_accum_steps + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 0 + + # Fast math knobs + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + + logfile = None + if master_process: + if args.run_id == "train": + logfile = "train.log" + else: + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" + print(logfile) + + def log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + 
print(msg, file=f) + + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch {torch.__version__}", console=False) + log0( + subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, + console=False, + ) + log0("=" * 100, console=False) + + # ----------------------------- + # TOKENIZER + VALIDATION METRIC SETUP + # ----------------------------- + + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) + if args.val_token_limit > 0: + usable = min(((args.val_token_limit) // args.train_seq_len) * args.train_seq_len, val_tokens.numel() - 1) + if usable <= 0: + raise ValueError(f"VAL_TOKEN_LIMIT={args.val_token_limit} is too small for TRAIN_SEQ_LEN={args.train_seq_len}") + val_tokens = val_tokens[: usable + 1].contiguous() + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + if args.val_token_limit > 0: + log0(f"val_token_limit:{args.val_token_limit}") + + # 
----------------------------- + # MODEL + OPTIMIZER SETUP + # ----------------------------- + + base_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, + structured_mlp=args.structured_mlp, + structured_linear_type=args.structured_linear_type if args.structured_mlp else "dense", + btt_rank=args.btt_rank, + btt_in_shape=args.btt_in_shape, + btt_out_shape=args.btt_out_shape, + btt_init=args.btt_init, + btt_init_gain=args.btt_init_gain, + compile_structured_mlp=args.compile_structured_mlp, + rate_target_tensors=args.rate_target_tensors, + rate_temperature=args.rate_temperature, + ).to(device).bfloat16() + for module in base_model.modules(): + if isinstance(module, (CastedLinear, StructuredLinear)): + module.float() + restore_low_dim_params_to_fp32(base_model) + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) if args.enable_compile else base_model + model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + + # Optimizer split: + # - token embedding (Adam) uses EMBED_LR + # - untied lm_head (Adam) uses HEAD_LR + # - matrix params in transformer blocks use MATRIX_LR via Muon + # - vectors/scalars use SCALAR_LR via Adam + block_named_params = list(base_model.blocks.named_parameters()) + matrix_params = [ + p + for name, p in block_named_params + if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + scalar_params = [ + p + for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.skip_weights.numel() > 0: + 
scalar_params.append(base_model.skip_weights) + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + optimizer_tok = torch.optim.Adam( + [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizer_muon = Muon( + matrix_params, + lr=args.matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + ) + for group in optimizer_muon.param_groups: + group["base_lr"] = args.matrix_lr + optimizer_scalar = torch.optim.Adam( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers.insert(1, optimizer_head) + + n_params = sum(p.numel() for p in base_model.parameters()) + log0(f"model_params:{n_params}") + log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") + log0(f"enable_compile:{args.enable_compile}") + log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") + log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") + log0( + f"quant_scheme:{args.quant_scheme} compressor:{args.compressor} level:{args.compressor_level} " + f"structured_mlp:{args.structured_mlp} structured_linear_type:{args.structured_linear_type} " + f"btt_rank:{args.btt_rank} btt_init:{args.btt_init} compile_structured_mlp:{args.compile_structured_mlp}" + ) + log0( + f"rate_lambda:{args.rate_lambda} scale_lambda:{args.scale_lambda} " + f"rate_warmup_start_fraction:{args.rate_warmup_start_fraction} rate_full_fraction:{args.rate_full_fraction} " + 
f"rate_target_tensors:{args.rate_target_tensors}" + ) + log0( + f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " + f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " + f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" + ) + log0( + f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " + f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " + f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log0(f"seed:{args.seed}") + + # ----------------------------- + # DATA LOADER & MODEL WARMUP + # ----------------------------- + + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + + def lr_mul(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + + def rate_schedule(step: int) -> float: + if args.rate_lambda <= 0.0 and args.scale_lambda <= 0.0: + return 0.0 + if args.iterations <= 0: + return 1.0 + frac = step / max(args.iterations, 1) + start = args.rate_warmup_start_fraction + full = max(args.rate_full_fraction, start + 1e-6) + if frac <= start: + return 0.0 + if frac >= full: + return 1.0 + return (frac - start) / (full - start) + + # Warmup primes the compiled forward/backward/optimizer paths, then we restore the + # initial weights/optimizer state so measured training starts from 
the true init. + if args.warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(args.warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + maybe_mark_compile_step_begin(args.enable_compile or args.compile_structured_mlp) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + for opt in optimizers: + opt.step() + zero_grad_all() + if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: + log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + if distributed: + model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + # ----------------------------- + # MAIN TRAINING LOOP + # ----------------------------- + + training_time_ms = 0.0 + stop_after_step: int | None = None + torch.cuda.synchronize() + t0 = time.perf_counter() + + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + 
base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + log0( + f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" + ) + break + + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + zero_grad_all() + train_loss = torch.zeros((), device=device) + train_ce_loss = torch.zeros((), device=device) + train_rate_proxy = torch.zeros((), device=device) + train_scale_proxy = torch.zeros((), device=device) + train_est_bits = torch.zeros((), device=device) + rate_mul = rate_schedule(step) + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + maybe_mark_compile_step_begin(args.enable_compile or args.compile_structured_mlp) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + ce_loss = model(x, y) + total_loss = ce_loss + if rate_mul > 0.0: + rate_proxy, scale_proxy, est_bits = base_model.rate_regularizer() + total_loss = total_loss + rate_mul * (args.rate_lambda * rate_proxy + args.scale_lambda * scale_proxy) + train_rate_proxy += rate_proxy.detach() + train_scale_proxy += scale_proxy.detach() + train_est_bits += est_bits.detach() + train_ce_loss += ce_loss.detach() + train_loss += total_loss.detach() + (total_loss * grad_scale).backward() + train_loss /= grad_accum_steps + train_ce_loss /= grad_accum_steps + train_rate_proxy /= grad_accum_steps + train_scale_proxy /= grad_accum_steps + train_est_bits /= grad_accum_steps + + frac = min(step / 
args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + for opt in optimizers: + opt.step() + zero_grad_all() + + step += 1 + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + should_log_train = ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + log0( + f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} ce_loss:{train_ce_loss.item():.4f} " + f"rate_mul:{rate_mul:.3f} rate_proxy_bits:{train_rate_proxy.item():.4f} " + f"scale_proxy:{train_scale_proxy.item():.6f} est_model_bits:{train_est_bits.item():.0f} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" + ) + + # Needed to sync whether we've reached the wallclock cap. 
+ reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + + log0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + ) + + # ----------------------------- + # SERIALIZATION + ROUNDTRIP VALIDATION + # ----------------------------- + # Save the raw state (useful for debugging/loading in PyTorch directly), then always produce + # the compressed artifact for the configured quant_scheme/compressor (int5_odd+zlib here) and validate the round-tripped weights. + + if master_process: + torch.save(base_model.state_dict(), "final_model.pt") + model_bytes = os.path.getsize("final_model.pt") + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model: {model_bytes} bytes") + log0(f"Code size: {code_bytes} bytes") + log0(f"Total submission size: {model_bytes + code_bytes} bytes") + + quant_obj, quant_stats = quantize_state_dict(base_model.state_dict(), args.quant_scheme) + quant_buf = io.BytesIO() + torch.save(quant_obj, quant_buf) + quant_raw = quant_buf.getvalue() + quant_blob = compress_blob(quant_raw, args.compressor, args.compressor_level) + quant_raw_bytes = len(quant_raw) + quant_filename = f"final_model.{args.quant_scheme}.{args.compressor}.ptc" + if master_process: + with open(quant_filename, "wb") as f: + f.write(quant_blob) + quant_file_bytes = os.path.getsize(quant_filename) + code_bytes = len(code.encode("utf-8")) + ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1) + log0( + f"Serialized model {args.quant_scheme}+{args.compressor}: {quant_file_bytes} bytes " + f"(payload:{quant_stats['int8_payload_bytes']} raw_torch:{quant_raw_bytes}
payload_ratio:{ratio:.2f}x)" + ) + log0(f"Total submission size {args.quant_scheme}+{args.compressor}: {quant_file_bytes + code_bytes} bytes") + + if distributed: + dist.barrier() + with open(quant_filename, "rb") as f: + quant_blob_disk = f.read() + quant_state = torch.load(io.BytesIO(decompress_blob(quant_blob_disk, args.compressor)), map_location="cpu") + base_model.load_state_dict(dequantize_state_dict(quant_state), strict=True) + torch.cuda.synchronize() + t_qeval = time.perf_counter() + q_val_loss, q_val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"final_{args.quant_scheme}_{args.compressor}_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" + ) + log0(f"final_{args.quant_scheme}_{args.compressor}_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + + if distributed: + dist.destroy_process_group() + + +if __name__ == "__main__": + main() + +==================================================================================================== +Running Python 3.10.12 (main, Mar 3 2026, 11:56:32) [GCC 11.4.0] +Running PyTorch 2.11.0+cu130 +Tue Mar 31 12:37:17 2026 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 590.48.01 Driver Version: 590.48.01 CUDA Version: 13.1 | ++-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. 
| +|=========================================+========================+======================| +| 0 NVIDIA GeForce RTX 3060 Off | 00000000:01:00.0 On | N/A | +| 30% 39C P8 19W / 170W | 242MiB / 12288MiB | 41% Default | +| | | N/A | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| 0 N/A N/A 985 G /usr/lib/xorg/Xorg 117MiB | +| 0 N/A N/A 5135 C python3 104MiB | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=/home/hz/Desktop/parameter-golf-main/data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:1 +val_loader:shards pattern=/home/hz/Desktop/parameter-golf-main/data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:1048576 +val_token_limit:1048576 +model_params:8507464 +world_size:1 grad_accum_steps:8 +enable_compile:False +sdp_backends:cudnn=False flash=True mem_efficient=False math=False +attention_mode:gqa num_heads:8 num_kv_heads:4 +quant_scheme:int5_odd compressor:zlib level:9 structured_mlp:True structured_linear_type:btt_mlp btt_rank:32 btt_init:mup compile_structured_mlp:True +rate_lambda:0.002 scale_lambda:0.0002 rate_warmup_start_fraction:0.1 rate_full_fraction:0.4 rate_target_tensors:mlp_only +tie_embeddings:True embed_lr:0.03 head_lr:0.0 matrix_lr:0.02 scalar_lr:0.02 +train_batch_tokens:16384 train_seq_len:256 iterations:40 warmup_steps:1 max_wallclock_seconds:600.000 +seed:1337 +warmup_step:1/1 +step:0/40 val_loss:6.9336 val_bpb:4.1543 train_time:0ms step_avg:0.02ms +step:1/40 train_loss:6.9388 
ce_loss:6.9388 rate_mul:0.000 rate_proxy_bits:0.0000 scale_proxy:0.000000 est_model_bits:0 train_time:1603ms step_avg:1603.20ms +step:2/40 train_loss:12.9960 ce_loss:12.9960 rate_mul:0.000 rate_proxy_bits:0.0000 scale_proxy:0.000000 est_model_bits:0 train_time:3216ms step_avg:1608.13ms +step:3/40 train_loss:11.4317 ce_loss:11.4317 rate_mul:0.000 rate_proxy_bits:0.0000 scale_proxy:0.000000 est_model_bits:0 train_time:4815ms step_avg:1604.96ms +step:4/40 train_loss:9.1479 ce_loss:9.1479 rate_mul:0.000 rate_proxy_bits:0.0000 scale_proxy:0.000000 est_model_bits:0 train_time:6413ms step_avg:1603.24ms +step:5/40 train_loss:7.5144 ce_loss:7.5144 rate_mul:0.000 rate_proxy_bits:0.0000 scale_proxy:0.000000 est_model_bits:0 train_time:8016ms step_avg:1603.10ms +step:6/40 train_loss:6.5124 ce_loss:6.5122 rate_mul:0.083 rate_proxy_bits:1.4980 scale_proxy:0.065182 est_model_bits:1325360 train_time:10072ms step_avg:1678.70ms +step:7/40 train_loss:6.1390 ce_loss:6.1385 rate_mul:0.167 rate_proxy_bits:1.4980 scale_proxy:0.065182 est_model_bits:1325360 train_time:12072ms step_avg:1724.50ms +step:8/40 train_loss:6.0774 ce_loss:6.0766 rate_mul:0.250 rate_proxy_bits:1.4980 scale_proxy:0.065182 est_model_bits:1325360 train_time:14058ms step_avg:1757.24ms +step:9/40 train_loss:5.8621 ce_loss:5.8611 rate_mul:0.333 rate_proxy_bits:1.4980 scale_proxy:0.065182 est_model_bits:1325360 train_time:16048ms step_avg:1783.14ms +step:10/40 train_loss:5.9833 ce_loss:5.9820 rate_mul:0.417 rate_proxy_bits:1.4980 scale_proxy:0.065182 est_model_bits:1325360 train_time:18036ms step_avg:1803.58ms +step:20/40 train_loss:5.6802 ce_loss:5.6772 rate_mul:1.000 rate_proxy_bits:1.4980 scale_proxy:0.065182 est_model_bits:1325360 train_time:37980ms step_avg:1898.98ms +step:30/40 train_loss:5.5847 ce_loss:5.5817 rate_mul:1.000 rate_proxy_bits:1.4980 scale_proxy:0.065182 est_model_bits:1325360 train_time:57850ms step_avg:1928.34ms +step:40/40 train_loss:5.4463 ce_loss:5.4433 rate_mul:1.000 rate_proxy_bits:1.4980 
scale_proxy:0.065182 est_model_bits:1325360 train_time:77923ms step_avg:1948.08ms +step:40/40 val_loss:5.4503 val_bpb:3.2655 train_time:77923ms step_avg:1948.09ms +peak memory allocated: 2705 MiB reserved: 5590 MiB +Serialized model: 33022259 bytes +Code size: 69181 bytes +Total submission size: 33091440 bytes +Serialized model int5_odd+zlib: 2083571 bytes (payload:2942927 raw_torch:2983011 payload_ratio:11.21x) +Total submission size int5_odd+zlib: 2152752 bytes +final_int5_odd_zlib_roundtrip val_loss:5.7165 val_bpb:3.4250 eval_time:29615ms +final_int5_odd_zlib_roundtrip_exact val_loss:5.71650304 val_bpb:3.42502411 diff --git a/records/track_non_record_16mb/2026-03-31_EntropyAware_Int5Odd_BTTMLP_3060/logs/materialize_sanity.txt b/records/track_non_record_16mb/2026-03-31_EntropyAware_Int5Odd_BTTMLP_3060/logs/materialize_sanity.txt new file mode 100644 index 000000000..012b0f08b --- /dev/null +++ b/records/track_non_record_16mb/2026-03-31_EntropyAware_Int5Odd_BTTMLP_3060/logs/materialize_sanity.txt @@ -0,0 +1,1956 @@ +""" +The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good launching-off points for new participants, not SOTA configs. We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder. + +Hard stop: To keep these scripts readable for newcomers, let's make sure `train_gpt.py` and `train_gpt_mlx.py` are never longer than 1500 lines.
+""" + +from __future__ import annotations + +import copy +import glob +import io +import math +import os +import random +import subprocess +import sys +import time +import uuid +import zlib +from pathlib import Path + +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP + +try: + import zstandard as zstd + + _HAS_ZSTD = True +except ImportError: + zstd = None + _HAS_ZSTD = False + +REPO_ROOT = Path(__file__).resolve().parents[3] + +# ----------------------------- +# HYPERPARAMETERS +# ----------------------------- +# Default Simple Baseline run: +# - 9 transformer blocks at width 512 +# - 8 attention heads with 4 KV heads (GQA) and 2x MLP expansion +# - vocab size 1024, sequence length 1024, tied embeddings +# - 524,288 train tokens per step for 20,000 iterations with a ~10 minute cap + +class Hyperparameters: + # Data paths are shard globs produced by the existing preprocessing pipeline. + data_path = os.environ.get("DATA_PATH", str(REPO_ROOT / "data/datasets/fineweb10B_sp1024")) + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", str(REPO_ROOT / "data/tokenizers/fineweb_1024_bpe.model")) + run_id = os.environ.get("RUN_ID", "train") + seed = int(os.environ.get("SEED", 1337)) + + # Validation cadence and batch size. Validation always uses the full fineweb_val split. + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 32_768)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 10)) + val_token_limit = int(os.environ.get("VAL_TOKEN_LIMIT", 1_048_576)) + + # Training length. 
+ iterations = int(os.environ.get("ITERATIONS", 40)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 1200)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 1)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 16_384)) + train_microbatch_tokens = int(os.environ.get("TRAIN_MICROBATCH_TOKENS", 0)) + grad_accum_steps = int(os.environ.get("GRAD_ACCUM_STEPS", 0)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 256)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + enable_compile = bool(int(os.environ.get("ENABLE_COMPILE", "0"))) + + # Model shape. + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers = int(os.environ.get("NUM_LAYERS", 9)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = int(os.environ.get("MLP_MULT", 2)) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + + # Optimizer hyperparameters. 
+ embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.03)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.02)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.02)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.0)) + lr_scale_method = os.environ.get("LR_SCALE_METHOD", "none") + lr_reference_batch_tokens = int(os.environ.get("LR_REFERENCE_BATCH_TOKENS", 16_384)) + warmup_scale_method = os.environ.get("WARMUP_SCALE_METHOD", "none") + warmup_reference_batch_tokens = int(os.environ.get("WARMUP_REFERENCE_BATCH_TOKENS", 16_384)) + simulated_world_size = int(os.environ.get("SIMULATED_WORLD_SIZE", 0)) + simulated_rank = int(os.environ.get("SIMULATED_RANK", 0)) + dry_run_init_only = bool(int(os.environ.get("DRY_RUN_INIT_ONLY", "0"))) + debug_data_sharding_steps = int(os.environ.get("DEBUG_DATA_SHARDING_STEPS", 0)) + + # Export / compression. + quant_scheme = os.environ.get("QUANT_SCHEME", "int5_odd") + compressor = os.environ.get("COMPRESSOR", "zlib") + compressor_level = int(os.environ.get("COMPRESSOR_LEVEL", 9)) + + # Entropy-aware rate regularization. 
+ rate_lambda = float(os.environ.get("RATE_LAMBDA", 0.00002)) + scale_lambda = float(os.environ.get("SCALE_LAMBDA", 0.0002)) + rate_temperature = float(os.environ.get("RATE_TEMPERATURE", 6.0)) + rate_warmup_start_fraction = float(os.environ.get("RATE_WARMUP_START_FRACTION", 0.1)) + rate_full_fraction = float(os.environ.get("RATE_FULL_FRACTION", 0.4)) + rate_target_tensors = os.environ.get("RATE_TARGET_TENSORS", "mlp_only") + rate_schedule_mode = os.environ.get("RATE_SCHEDULE_MODE", "auto") + max_expected_steps = int(os.environ.get("MAX_EXPECTED_STEPS", 0)) + + # Structured MLP. + structured_mlp = bool(int(os.environ.get("STRUCTURED_MLP", "1"))) + structured_linear_type = os.environ.get("STRUCTURED_LINEAR_TYPE", "btt_mlp") + btt_rank = int(os.environ.get("BTT_RANK", 32)) + btt_in_shape = os.environ.get("BTT_IN_SHAPE", "") + btt_out_shape = os.environ.get("BTT_OUT_SHAPE", "") + btt_init = os.environ.get("BTT_INIT", "mup") + btt_init_gain = float(os.environ.get("BTT_INIT_GAIN", 1.0)) + compile_structured_mlp = bool(int(os.environ.get("COMPILE_STRUCTURED_MLP", "1"))) + structured_chunk_ranks = int(os.environ.get("STRUCTURED_CHUNK_RANKS", 64)) + structured_eval_chunk_ranks = int(os.environ.get("STRUCTURED_EVAL_CHUNK_RANKS", 1)) + structured_compile_mode = os.environ.get("STRUCTURED_COMPILE_MODE", "reduce-overhead") + +# ----------------------------- +# MUON OPTIMIZER +# ----------------------------- +# +# As borrowed from modded-nanogpt +# Background on Muon: https://kellerjordan.github.io/posts/muon/ + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + # Orthogonalize a 2D update matrix with a fast Newton-Schulz iteration. + # Muon uses this to normalize matrix-shaped gradients before applying them. 
+ a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X + + +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True): + super().__init__( + params, + dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov), + ) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + + for group in self.param_groups: + params = group["params"] + if not params: + continue + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + + total_params = sum(int(p.numel()) for p in params) + updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) + + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: + g = g.add(buf, alpha=momentum) + g = zeropower_via_newtonschulz5(g, steps=backend_steps) + # Scale correction from Muon reference implementations. 
+ g *= max(1, g.size(0) / g.size(1)) ** 0.5 + updates_flat[curr : curr + p.numel()] = g.reshape(-1) + curr += p.numel() + + if distributed: + dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) + + curr = 0 + for p in params: + g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) + p.add_(g, alpha=-lr) + curr += p.numel() + + return loss + + +# ----------------------------- +# TOKENIZER-AGNOSTIC EVALUATION SETUP +# ----------------------------- +# +# It's common for small models to have a large fraction of their parameters be embeddings, since the 2 * d_model * d_vocab embedding parameters can be gigantic. +# Instead of locking the tokenizer, we let you bring your own and calculate our validation metrics on the average compression of the validation set. +# We calculate BPB (bits-per-byte) instead of validation loss, so we need methods to count the number of bytes per token in the tokenizer. +# Note: Submissions that edit the tokenizer will be examined more carefully, since screwing this up might unjustly improve your score.
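The bits-per-byte metric described in the comments above reduces to a few lines: convert mean token cross-entropy from nats to bits, then rescale by tokens-per-byte. A minimal standalone sketch (the helper name is illustrative; the example numbers are taken approximately from the logged step-40 metrics):

```python
import math

def bits_per_byte(mean_ce_nats: float, token_count: float, byte_count: float) -> float:
    # Convert mean token cross-entropy (natural log) to bits per token,
    # then rescale by tokens-per-byte so the metric is tokenizer-agnostic.
    bits_per_token = mean_ce_nats / math.log(2.0)
    return bits_per_token * (token_count / byte_count)

# With val_loss ~5.4503 and ~2.41 bytes/token this lands near the logged 3.2655 bpb.
```

A tokenizer with longer average pieces lowers tokens-per-byte, so a model can trade a higher per-token loss for fewer tokens and still improve bpb.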
+ +def build_sentencepiece_luts( + sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device +) -> tuple[Tensor, Tensor, Tensor]: + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("▁"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) + + +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*. 
+ tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] + + +def eval_val( + args: Hyperparameters, + model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + grad_accum_steps: int, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, +) -> tuple[float, float]: + # Validation computes two metrics: + # - val_loss: token cross-entropy (natural log) + # - val_bpb: tokenizer-agnostic compression metric used by the challenge + local_batch_tokens = args.val_batch_size // world_size + if local_batch_tokens < args.train_seq_len: + raise ValueError( + "VAL_BATCH_SIZE must provide at least one sequence per rank; " + f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " + f"TRAIN_SEQ_LEN={args.train_seq_len}" + ) + local_batch_seqs = local_batch_tokens // args.train_seq_len + total_seqs = (val_tokens.numel() - 1) // args.train_seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + + model.eval() + eval_dtype = torch.bfloat16 if device.type == "cuda" else torch.float32 + eval_module = model.module if isinstance(model, DDP) else model + for module in eval_module.modules(): + if isinstance(module, StructuredLinear): + module.materialize(eval_dtype, device) + with torch.inference_mode(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * args.train_seq_len + raw_end = batch_seq_end * args.train_seq_len + 1 
+ local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, args.train_seq_len) + y = local[1:].reshape(-1, args.train_seq_len) + maybe_mark_compile_step_begin(args.enable_compile or args.compile_structured_mlp) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + batch_loss = model(x, y).detach() + batch_token_count = float(y.numel()) + val_loss_sum += batch_loss.to(torch.float64) * batch_token_count + val_token_count += batch_token_count + prev_ids = x.reshape(-1) + tgt_ids = y.reshape(-1) + token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + + val_loss = val_loss_sum / val_token_count + bits_per_token = val_loss.item() / math.log(2.0) + tokens_per_byte = val_token_count.item() / val_byte_count.item() + model.train() + return float(val_loss.item()), float(bits_per_token * tokens_per_byte) + +# ----------------------------- +# POST-TRAINING QUANTIZATION +# ----------------------------- +# +# It's silly to export our model, which is trained in bf16 and fp32, at that same precision. +# Instead, we get approximately the same model (with a small hit) by quantizing the model to int8 and zlib-compressing it. +# We can then decompress the model and run in higher precision for evaluation, after coming in under the size limit.
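As a toy illustration of the quantize/dequantize roundtrip described above, here is a per-row quantizer onto the `{-2,-1,0,1,2}` grid. This sketch simplifies to max-abs row scaling; the script itself uses percentile clipping, fp16 scales, and base-5 bit packing:

```python
import numpy as np

def int5_odd_roundtrip(w: np.ndarray) -> np.ndarray:
    # One scale per row: map the largest |entry| to the outermost level (2).
    scale = np.maximum(np.abs(w).max(axis=1, keepdims=True) / 2.0, 1e-12)
    q = np.clip(np.round(w / scale), -2, 2).astype(np.int8)
    codes = (q + 2).astype(np.uint8)  # shifted to [0, 4], ready for base-5 packing
    return (codes.astype(np.float32) - 2.0) * scale  # dequantize

w = np.array([[0.5, -1.0, 1.0], [2.0, 0.0, -4.0]], dtype=np.float32)
w_hat = int5_odd_roundtrip(w)  # exact here, since every entry sits on its row's grid
```

Per-row scaling is what lets one 5-level grid serve rows with very different magnitudes; the entropy-aware rate penalty then pushes weights toward these grid points so the packed codes compress well.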
+ +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights", + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 +INT5_ODD_LEVELS = (-2, -1, 0, 1, 2) + +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) + + +def maybe_mark_compile_step_begin(enabled: bool) -> None: + if not enabled: + return + mark_step = getattr(torch.compiler, "cudagraph_mark_step_begin", None) + if mark_step is not None: + mark_step() + + +def scale_from_batch_multiplier(value: float, multiplier: float, method: str) -> float: + if method == "none": + return value + if method == "sqrt": + return value * math.sqrt(multiplier) + if method == "linear": + return value * multiplier + raise ValueError(f"Unsupported scaling method: {method}") + + +def scaled_warmup_steps(base_steps: int, multiplier: float, method: str) -> int: + if base_steps <= 0: + return 0 + scaled = scale_from_batch_multiplier(float(base_steps), multiplier, method) + return max(int(math.ceil(scaled)), 1) + + +def parse_factor_shape(spec: str, dim: int) -> tuple[int, int]: + if spec: + vals = tuple(int(v.strip()) for v in spec.split(",") if v.strip()) + if len(vals) != 2 or vals[0] * vals[1] != dim: + raise ValueError(f"Invalid factor shape '{spec}' for dim={dim}") + return vals + for a in range(int(math.isqrt(dim)), 0, -1): + if dim % a == 0: + return a, dim // a + return 1, dim + +def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> 
Tensor: + if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): + return t.float().contiguous() + if t.dtype in {torch.float32, torch.bfloat16}: + passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") + return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() + return t + +def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + # Matrices get one scale per row, which usually tracks output-channel + # ranges much better than a single tensor-wide scale. + clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) + q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() + return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + + # Vectors / scalars use a simpler per-tensor scale. 
+ clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() + return q, scale + + +def pack_base5_codes(q: Tensor) -> tuple[bytes, int]: + vals = q.detach().to("cpu", dtype=torch.uint8).reshape(-1).numpy() + n_vals = int(vals.size) + pad = (3 - (n_vals % 3)) % 3 + if pad: + vals = np.pad(vals, (0, pad)) + groups = vals.reshape(-1, 3).astype(np.uint8) + packed = groups[:, 0] + 5 * groups[:, 1] + 25 * groups[:, 2] + return packed.tobytes(), n_vals + + +def unpack_base5_codes(data: bytes, n_vals: int) -> Tensor: + packed = np.frombuffer(data, dtype=np.uint8).astype(np.int16) + out = np.zeros((packed.size, 3), dtype=np.uint8) + for i in range(3): + out[:, i] = packed % 5 + packed //= 5 + flat = out.reshape(-1)[:n_vals] + return torch.from_numpy(flat.astype(np.uint8, copy=False)) + + +def quantize_float_tensor_int5_odd(t: Tensor) -> tuple[Tensor, Tensor]: + if t.ndim != 2: + raise ValueError("int5_odd quantization only supports 2D tensors") + t32 = t.float() + clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clip_abs = clip_abs.clamp_min(1e-6) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 2.0).clamp_min(1.0 / 2.0e6) + q = torch.clamp(torch.round(clipped / scale[:, None]), -2, 2).to(torch.int8).contiguous() + return (q + 2).to(torch.uint8).contiguous(), scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + +def quantize_state_dict_int8(state_dict: dict[str, Tensor]): + # Single supported clean-script export format: + # - per-row int8 for 2D float tensors + # - per-tensor int8 for other float tensors + # - exact passthrough for non-floats + # - passthrough for small float tensors, 
stored as fp16 to save bytes + quantized: dict[str, Tensor] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + qmeta: dict[str, dict[str, object]] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + + # Small float tensors are cheap enough to keep directly. We still downcast + # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size. + if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + + stats["num_float_tensors"] += 1 + q, s = quantize_float_tensor(t) + if s.ndim > 0: + qmeta[name] = {"scheme": "per_row", "axis": 0} + quantized[name] = q + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) + + obj: dict[str, object] = { + "__quant_format__": "int8_clean_per_row_v1", + "quantized": quantized, + "scales": scales, + "dtypes": dtypes, + "passthrough": passthrough, + } + if qmeta: + obj["qmeta"] = qmeta + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats + + +def quantize_state_dict_int5_odd(state_dict: dict[str, Tensor]): + packed: dict[str, bytes] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + shapes: dict[str, tuple[int, ...]] = {} + orig_shapes: dict[str, 
tuple[int, ...]] = {} + n_codes: dict[str, int] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + + quantize_this = t.ndim >= 2 and (t.numel() > INT8_KEEP_FLOAT_MAX_NUMEL or ".mlp." in name) + if not quantize_this: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + + stats["num_float_tensors"] += 1 + t2 = t.reshape(t.shape[0], -1) if t.ndim > 2 else t + q, s = quantize_float_tensor_int5_odd(t2) + packed_bytes, count = pack_base5_codes(q) + packed[name] = packed_bytes + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + shapes[name] = tuple(int(v) for v in t2.shape) + orig_shapes[name] = tuple(int(v) for v in t.shape) + n_codes[name] = count + stats["int8_payload_bytes"] += len(packed_bytes) + tensor_nbytes(s) + + obj: dict[str, object] = { + "__quant_format__": "int5_odd_clean_per_row_v1", + "packed": packed, + "scales": scales, + "dtypes": dtypes, + "shapes": shapes, + "orig_shapes": orig_shapes, + "n_codes": n_codes, + "passthrough": passthrough, + } + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats + +def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + qmeta = obj.get("qmeta", {}) + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, q in 
obj["quantized"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + s = obj["scales"][name] + if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: + s = s.to(dtype=torch.float32) + # Broadcast the saved row scale back across trailing dimensions. + out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() + else: + scale = float(s.item()) + out[name] = (q.float() * scale).to(dtype=dtype).contiguous() + for name, t in obj["passthrough"].items(): + # Restore small tensors, undoing the temporary fp16 storage cast if needed. + out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out + + +def dequantize_state_dict_int5_odd(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, packed in obj["packed"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + shape = tuple(int(v) for v in obj["shapes"][name]) + orig_shape = tuple(int(v) for v in obj.get("orig_shapes", {}).get(name, shape)) + q = unpack_base5_codes(packed, int(obj["n_codes"][name])).to(torch.float32).reshape(shape) - 2.0 + s = obj["scales"][name].to(dtype=torch.float32) + deq = (q * s.view(shape[0], *([1] * (len(shape) - 1)))).to(dtype=dtype) + out[name] = deq.reshape(orig_shape).contiguous() + for name, t in obj["passthrough"].items(): + out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out + + +def quantize_state_dict(state_dict: dict[str, Tensor], quant_scheme: str): + if quant_scheme == "int8": + return quantize_state_dict_int8(state_dict) + if quant_scheme == "int5_odd": + return 
quantize_state_dict_int5_odd(state_dict) + raise ValueError(f"Unsupported QUANT_SCHEME={quant_scheme}") + + +def dequantize_state_dict(obj: dict[str, object]) -> dict[str, Tensor]: + quant_format = obj.get("__quant_format__") + if quant_format == "int8_clean_per_row_v1": + return dequantize_state_dict_int8(obj) + if quant_format == "int5_odd_clean_per_row_v1": + return dequantize_state_dict_int5_odd(obj) + raise ValueError(f"Unsupported quant format {quant_format}") + + +def compress_blob(data: bytes, compressor: str, level: int) -> bytes: + if compressor == "zlib": + return zlib.compress(data, level=max(0, min(level, 9))) + if compressor == "zstd": + if not _HAS_ZSTD: + raise RuntimeError("COMPRESSOR=zstd requires zstandard to be installed") + return zstd.ZstdCompressor(level=level).compress(data) + raise ValueError(f"Unsupported COMPRESSOR={compressor}") + + +def decompress_blob(data: bytes, compressor: str) -> bytes: + if compressor == "zlib": + return zlib.decompress(data) + if compressor == "zstd": + if not _HAS_ZSTD: + raise RuntimeError("COMPRESSOR=zstd requires zstandard to be installed") + return zstd.ZstdDecompressor().decompress(data) + raise ValueError(f"Unsupported COMPRESSOR={compressor}") + + +def should_rate_regularize_tensor(name: str, p: Tensor, target: str) -> bool: + if not p.requires_grad or not p.is_floating_point() or p.ndim < 2: + return False + if any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS): + return False + if target == "mlp_only": + return ".mlp." in name + if target == "all_matrices": + return p.numel() > INT8_KEEP_FLOAT_MAX_NUMEL or ".mlp." 
in name + raise ValueError(f"Unsupported RATE_TARGET_TENSORS={target}") + + +def estimate_tensor_rate_proxy(t: Tensor, temperature: float) -> tuple[Tensor, Tensor, Tensor]: + t2 = t.float() + if t2.ndim > 2: + t2 = t2.reshape(t2.shape[0], -1) + elif t2.ndim == 1: + t2 = t2.reshape(1, -1) + row_scale = t2.abs().mean(dim=1, keepdim=True).clamp_min(1e-6) + normalized = torch.clamp(t2 / row_scale, -2.5, 2.5) + centers = t2.new_tensor(INT5_ODD_LEVELS) + logits = -temperature * (normalized.unsqueeze(-1) - centers) ** 2 + probs = torch.softmax(logits, dim=-1) + hist = probs.mean(dim=(0, 1)) + entropy_nats = -(hist * (hist + 1e-8).log()).sum() + entropy_bits = entropy_nats / math.log(2.0) + scale_proxy = row_scale.mean() + est_bits = entropy_bits * float(t2.numel()) + return entropy_bits, scale_proxy, est_bits + + +# ----------------------------- +# DATA LOADING +# ----------------------------- + +def load_data_shard(file: Path) -> Tensor: + # Shard layout: a 256-word int32 header (magic, version, token count) followed by raw uint16 tokens. + header_bytes = 256 * np.dtype("<i4").itemsize + with file.open("rb") as f: + header = np.frombuffer(f.read(header_bytes), dtype="<i4") + num_tokens = int(header[2]) + tokens = np.frombuffer(f.read(2 * num_tokens), dtype="<u2") + return torch.from_numpy(tokens.astype(np.int32)) + + +class TokenStream: + # Streams tokens sequentially from the sorted shard files, wrapping back to the first shard. + def __init__(self, pattern: str): + self.files = [Path(p) for p in sorted(glob.glob(pattern))] + if not self.files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + self.file_idx = 0 + self.tokens = load_data_shard(self.files[self.file_idx]) + self.pos = 0 + + def _advance_file(self) -> None: + self.file_idx = (self.file_idx + 1) % len(self.files) + self.tokens = load_data_shard(self.files[self.file_idx]) + self.pos = 0 + + def take(self, n: int) -> Tensor: + chunks: list[Tensor] = [] + remaining = n + while remaining > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: + self._advance_file() + continue + k = min(remaining, avail) + chunks.append(self.tokens[self.pos : self.pos + k]) + self.pos += k + remaining -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) + + +class DistributedTokenLoader: + # Each call consumes a contiguous chunk from the shared token stream, then slices out + # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting.
+ def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank = rank + self.world_size = world_size + self.device = device + self.stream = TokenStream(pattern) + + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + chunk = self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start : start + per_rank_span].to(dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + + +def debug_log_data_sharding( + pattern: str, + world_size: int, + global_tokens: int, + seq_len: int, + grad_accum_steps: int, + steps: int, + log0, +) -> None: + stream = TokenStream(pattern) + local_tokens = global_tokens // (world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + global_cursor = 0 + log0( + f"debug_sharding:world_size:{world_size} global_tokens:{global_tokens} " + f"grad_accum_steps:{grad_accum_steps} local_tokens_per_rank:{local_tokens} per_rank_span:{per_rank_span}" + ) + for step_idx in range(steps): + chunk = stream.take(per_rank_span * world_size) + chunk_start = global_cursor + chunk_end = global_cursor + chunk.numel() + log0(f"debug_sharding_step:{step_idx} shared_chunk_token_range:[{chunk_start},{chunk_end})") + for shard_rank in range(world_size): + start = shard_rank * per_rank_span + end = start + per_rank_span + local = chunk[start:end] + preview = ",".join(str(int(t)) for t in local[: min(6, local.numel())].tolist()) + log0( + f"debug_shard rank:{shard_rank} token_range:[{chunk_start + start},{chunk_start + end}) " + f"preview:{preview}" + ) + global_cursor = chunk_end + +# ----------------------------- +# TRANSFORMER MODULES +# ----------------------------- + +class RMSNorm(nn.Module): + def 
__init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + + +class CastedLinear(nn.Linear): + # Keep weights in fp32 for optimizer/state quality, cast at matmul time for bf16 compute. + def __init__(self, in_features: int, out_features: int, bias: bool = True): + super().__init__(in_features, out_features, bias=bias) + self._cached_weight: Tensor | None = None + self._cached_bias: Tensor | None = None + self._cached_key: tuple[torch.dtype, torch.device, int, int | None] | None = None + + def _cast_params(self, x: Tensor) -> tuple[Tensor, Tensor | None]: + bias_version = self.bias._version if self.bias is not None else None + cache_key = (x.dtype, x.device, self.weight._version, bias_version) + if torch.is_inference_mode_enabled() and self._cached_key == cache_key: + return self._cached_weight, self._cached_bias + weight = self.weight.to(dtype=x.dtype) + bias = self.bias.to(x.dtype) if self.bias is not None else None + if torch.is_inference_mode_enabled(): + self._cached_weight = weight + self._cached_bias = bias + self._cached_key = cache_key + return weight, bias + + def train(self, mode: bool = True): + if mode: + self._cached_weight = None + self._cached_bias = None + self._cached_key = None + return super().train(mode) + + def forward(self, x: Tensor) -> Tensor: + weight, bias = self._cast_params(x) + return F.linear(x, weight, bias) + + +class StructuredLinear(nn.Module): + # Lightweight TT-style structured matrix used only in the MLP stack. 
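+ # With in_features = n1*n2 and out_features = m1*m2, the dense weight is + # replaced by core1 (n1, m1, rank) and core2 (rank, n2, m2), cutting the + # parameter count from in_features*out_features to rank*(n1*m1 + n2*m2). + # Illustrative (assumed) shapes: a 1024 -> 4096 map factored as (32, 32) -> + # (64, 64) at rank 8 stores 8 * (32*64 + 32*64) = 32768 parameters instead + # of 4194304.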
+ def __init__( + self, + in_features: int, + out_features: int, + bias: bool, + linear_type: str, + btt_rank: int, + in_shape_spec: str, + out_shape_spec: str, + btt_init: str, + btt_init_gain: float, + chunk_ranks: int, + eval_chunk_ranks: int, + ): + super().__init__() + self.in_features = in_features + self.out_features = out_features + self.linear_type = linear_type + self.rank = max(int(btt_rank), 1) + self.chunk_ranks = max(int(chunk_ranks), 1) + self.eval_chunk_ranks = max(int(eval_chunk_ranks), 1) + self.btt_init = btt_init + self.btt_init_gain = btt_init_gain + self._cached_core1: Tensor | None = None + self._cached_core2: Tensor | None = None + self._cached_bias: Tensor | None = None + self._cached_key: tuple[torch.dtype, torch.device, int, int, int | None] | None = None + self._dense_eval_weight: Tensor | None = None + self._dense_eval_bias: Tensor | None = None + self._dense_eval_key: tuple[torch.dtype, torch.device, int, int, int | None] | None = None + if linear_type == "dense": + self.impl = CastedLinear(in_features, out_features, bias=bias) + self.register_parameter("core1", None) + self.register_parameter("core2", None) + return + if linear_type != "btt_mlp": + raise ValueError(f"Unsupported STRUCTURED_LINEAR_TYPE={linear_type}") + self.impl = None + self.in_shape = parse_factor_shape(in_shape_spec, in_features) + self.out_shape = parse_factor_shape(out_shape_spec, out_features) + n1, n2 = self.in_shape + m1, m2 = self.out_shape + self.core1 = nn.Parameter(torch.empty(n1, m1, self.rank)) + self.core2 = nn.Parameter(torch.empty(self.rank, n2, m2)) + if bias: + self.bias = nn.Parameter(torch.zeros(out_features)) + else: + self.register_parameter("bias", None) + self.reset_parameters() + + def reset_parameters(self) -> None: + if self.impl is not None: + self.impl.reset_parameters() + return + n1, n2 = self.in_shape + if self.btt_init == "mup": + # muP-inspired scaling: keep the per-rank branch update stable while + # preserving O(1) output variance 
under the sum over TT ranks. + core1_std = self.btt_init_gain / math.sqrt(max(n1, 1)) + core2_std = self.btt_init_gain / math.sqrt(max(self.rank * n2, 1)) + elif self.btt_init == "xavier": + std = self.btt_init_gain / math.sqrt(max(self.in_features, 1)) + core1_std = std / math.sqrt(self.rank) + core2_std = std + else: + raise ValueError(f"Unsupported BTT_INIT={self.btt_init}") + nn.init.normal_(self.core1, mean=0.0, std=core1_std) + nn.init.normal_(self.core2, mean=0.0, std=core2_std) + if self.bias is not None: + nn.init.zeros_(self.bias) + + def _cast_factors(self, x: Tensor) -> tuple[Tensor, Tensor, Tensor | None]: + bias_version = self.bias._version if self.bias is not None else None + cache_key = (x.dtype, x.device, self.core1._version, self.core2._version, bias_version) + if torch.is_inference_mode_enabled() and self._cached_key == cache_key: + return self._cached_core1, self._cached_core2, self._cached_bias + core1 = self.core1.to(dtype=x.dtype).permute(2, 0, 1).contiguous() + core2 = self.core2.to(dtype=x.dtype).contiguous() + bias = self.bias.to(dtype=x.dtype) if self.bias is not None else None + if torch.is_inference_mode_enabled(): + self._cached_core1 = core1 + self._cached_core2 = core2 + self._cached_bias = bias + self._cached_key = cache_key + return core1, core2, bias + + def materialize(self, dtype: torch.dtype, device: torch.device) -> None: + if self.impl is not None: + return + bias_version = self.bias._version if self.bias is not None else None + cache_key = (dtype, device, self.core1._version, self.core2._version, bias_version) + if self._dense_eval_key == cache_key and self._dense_eval_weight is not None: + return + core1 = self.core1.to(device=device, dtype=dtype) + core2 = self.core2.to(device=device, dtype=dtype) + dense = torch.einsum("umr,rvn->mnuv", core1, core2).reshape(self.out_features, self.in_features).contiguous() + self._dense_eval_weight = dense + self._dense_eval_bias = self.bias.to(device=device, dtype=dtype) if self.bias is 
not None else None + self._dense_eval_key = cache_key + + def train(self, mode: bool = True): + if mode: + self._cached_core1 = None + self._cached_core2 = None + self._cached_bias = None + self._cached_key = None + self._dense_eval_weight = None + self._dense_eval_bias = None + self._dense_eval_key = None + return super().train(mode) + + def forward(self, x: Tensor) -> Tensor: + if self.impl is not None: + return self.impl(x) + if not self.training: + self.materialize(x.dtype, x.device) + return F.linear(x, self._dense_eval_weight, self._dense_eval_bias) + orig_shape = x.shape[:-1] + x2 = x.reshape(-1, self.in_shape[0], self.in_shape[1]) + x2_t = x2.transpose(1, 2).contiguous() + core1, core2, bias = self._cast_factors(x) + y = x2.new_zeros((x2.shape[0], self.out_shape[0], self.out_shape[1])) + for start in range(0, self.rank, self.chunk_ranks): + end = min(start + self.chunk_ranks, self.rank) + for offset, core2_r in enumerate(core2[start:end], start=start): + left = torch.matmul(x2_t, core1[offset]) + y = y + torch.matmul(left.transpose(1, 2), core2_r) + y = y.reshape(*orig_shape, self.out_features) + if bias is not None: + y = y + bias + return y + + +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + # Keep small/control parameters in fp32 even when the model body runs in bf16. + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() + + +class Rotary(nn.Module): + # Caches cos/sin tables per sequence length on the current device. 
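+ # apply_rotary_emb pairs each feature x1 with the feature x2 a half-dim away + # and maps them to (x1*cos + x2*sin, -x1*sin + x2*cos), i.e. a rotation by + # -t*theta at position t. Because rotations compose, the rotated q/k dot + # product depends only on the offset between the two positions.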
+ def __init__(self, dim: int, base: float = 10000.0): + super().__init__() + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached != seq_len + or self._cos_cached.device != device + ): + t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) + freqs = torch.outer(t, self.inv_freq.to(device)) + self._cos_cached = freqs.cos()[None, None, :, :] + self._sin_cached = freqs.sin()[None, None, :, :] + self._seq_len_cached = seq_len + cos = self._cos_cached.to(dtype=dtype) + sin = self._sin_cached.to(dtype=dtype) + if not torch.is_inference_mode_enabled(): + cos = cos.clone() + sin = sin.clone() + return cos, sin + + +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + + +class CausalSelfAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + kv_dim = self.num_kv_heads * self.head_dim + self.c_q = CastedLinear(dim, dim, bias=False) + self.c_k = CastedLinear(dim, kv_dim, bias=False) + self.c_v = CastedLinear(dim, kv_dim, bias=False) + self.proj 
= CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rotary = Rotary(self.head_dim, base=rope_base) + + def forward(self, x: Tensor) -> Tensor: + bsz, seqlen, dim = x.shape + q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) + k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin) + k = apply_rotary_emb(k, cos, sin) + q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] + y = F.scaled_dot_product_attention( + q, + k, + v, + attn_mask=None, + is_causal=True, + enable_gqa=(self.num_kv_heads != self.num_heads), + ) + y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) + return self.proj(y) + + +class MLP(nn.Module): + # relu^2 MLP from the original modded-nanogpt setup + def __init__( + self, + dim: int, + mlp_mult: int, + structured_mlp: bool, + structured_linear_type: str, + btt_rank: int, + btt_in_shape: str, + btt_out_shape: str, + btt_init: str, + btt_init_gain: float, + compile_structured_mlp: bool, + structured_chunk_ranks: int, + structured_eval_chunk_ranks: int, + structured_compile_mode: str, + ): + super().__init__() + hidden = mlp_mult * dim + if structured_mlp: + self.fc = StructuredLinear( + dim, + hidden, + bias=False, + linear_type=structured_linear_type, + btt_rank=btt_rank, + in_shape_spec=btt_in_shape, + out_shape_spec=btt_out_shape, + btt_init=btt_init, + btt_init_gain=btt_init_gain, + chunk_ranks=structured_chunk_ranks, + eval_chunk_ranks=structured_eval_chunk_ranks, + ) + self.proj = StructuredLinear( + hidden, + dim, + bias=False, + linear_type=structured_linear_type, + btt_rank=btt_rank, + 
in_shape_spec=btt_out_shape, + out_shape_spec=btt_in_shape, + btt_init=btt_init, + btt_init_gain=btt_init_gain, + chunk_ranks=structured_chunk_ranks, + eval_chunk_ranks=structured_eval_chunk_ranks, + ) + if compile_structured_mlp: + compile_kwargs = dict(fullgraph=False, dynamic=False) + if structured_compile_mode == "reduce-overhead": + compile_options = dict(torch._inductor.list_mode_options("reduce-overhead")) + compile_options["triton.cudagraphs"] = False + compile_kwargs["options"] = compile_options + else: + compile_kwargs["mode"] = structured_compile_mode + self.fc = torch.compile( + self.fc, + **compile_kwargs, + ) + self.proj = torch.compile( + self.proj, + **compile_kwargs, + ) + else: + self.fc = CastedLinear(dim, hidden, bias=False) + self.proj = CastedLinear(hidden, dim, bias=False) + self.proj._zero_init = True + + def forward(self, x: Tensor) -> Tensor: + x = torch.relu(self.fc(x)) + return self.proj(x.square()) + + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + rope_base: float, + qk_gain_init: float, + structured_mlp: bool, + structured_linear_type: str, + btt_rank: int, + btt_in_shape: str, + btt_out_shape: str, + btt_init: str, + btt_init_gain: float, + compile_structured_mlp: bool, + structured_chunk_ranks: int, + structured_eval_chunk_ranks: int, + structured_compile_mode: str, + ): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) + self.mlp = MLP( + dim, + mlp_mult, + structured_mlp, + structured_linear_type, + btt_rank, + btt_in_shape, + btt_out_shape, + btt_init, + btt_init_gain, + compile_structured_mlp, + structured_chunk_ranks, + structured_eval_chunk_ranks, + structured_compile_mode, + ) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = 
nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + + def forward(self, x: Tensor, x0: Tensor) -> Tensor: + mix = self.resid_mix.to(dtype=x.dtype) + x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + attn_out = self.attn(self.attn_norm(x)) + x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out + x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) + return x + + +class GPT(nn.Module): + def __init__( + self, + vocab_size: int, + num_layers: int, + model_dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + tie_embeddings: bool, + tied_embed_init_std: float, + logit_softcap: float, + rope_base: float, + qk_gain_init: float, + structured_mlp: bool, + structured_linear_type: str, + btt_rank: int, + btt_in_shape: str, + btt_out_shape: str, + btt_init: str, + btt_init_gain: float, + compile_structured_mlp: bool, + structured_chunk_ranks: int, + structured_eval_chunk_ranks: int, + structured_compile_mode: str, + rate_target_tensors: str, + rate_temperature: float, + ): + super().__init__() + if logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") + self.tie_embeddings = tie_embeddings + self.tied_embed_init_std = tied_embed_init_std + self.logit_softcap = logit_softcap + self.rate_target_tensors = rate_target_tensors + self.rate_temperature = rate_temperature + self.tok_emb = nn.Embedding(vocab_size, model_dim) + self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + self.blocks = nn.ModuleList( + [ + Block( + model_dim, + num_heads, + num_kv_heads, + mlp_mult, + rope_base, + qk_gain_init, + structured_mlp, + structured_linear_type, + btt_rank, + btt_in_shape, + btt_out_shape, + btt_init, + btt_init_gain, 
+ compile_structured_mlp, + structured_chunk_ranks, + structured_eval_chunk_ranks, + structured_compile_mode, + ) + for i in range(num_layers) + ] + ) + self.final_norm = RMSNorm() + self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) + if self.lm_head is not None: + self.lm_head._zero_init = True + self._init_weights() + + def _init_weights(self) -> None: + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + for module in self.modules(): + if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + if isinstance(module, StructuredLinear) and getattr(module, "_zero_init", False): + if module.impl is not None: + nn.init.zeros_(module.impl.weight) + else: + nn.init.zeros_(module.core2) + + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + x = self.tok_emb(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x0 = x + skips: list[Tensor] = [] + + # First half stores skips; second half reuses them in reverse order. 
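+ # Example at num_layers=16: blocks 0..7 push their outputs; decoder block 8 + # pops block 7's output (scaled by skip_weights[0]), block 9 pops block 6's + # (skip_weights[1]), and so on, forming a U-Net-style pairing.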
+ for i in range(self.num_encoder_layers): + x = self.blocks[i](x, x0) + skips.append(x) + for i in range(self.num_decoder_layers): + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + x = self.blocks[self.num_encoder_layers + i](x, x0) + + x = self.final_norm(x).reshape(-1, x.size(-1)) + targets = target_ids.reshape(-1) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(x) + logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + return F.cross_entropy(logits.float(), targets, reduction="mean") + + def rate_regularizer(self) -> tuple[Tensor, Tensor, Tensor]: + total_symbols = 0.0 + weighted_rate = None + weighted_scale = None + estimated_bits = None + for name, param in self.named_parameters(): + if not should_rate_regularize_tensor(name, param, self.rate_target_tensors): + continue + rate_proxy, scale_proxy, tensor_bits = estimate_tensor_rate_proxy(param, self.rate_temperature) + weight = float(param.numel()) + total_symbols += weight + if weighted_rate is None: + weighted_rate = rate_proxy * weight + weighted_scale = scale_proxy * weight + estimated_bits = tensor_bits + else: + weighted_rate = weighted_rate + rate_proxy * weight + weighted_scale = weighted_scale + scale_proxy * weight + estimated_bits = estimated_bits + tensor_bits + if weighted_rate is None or weighted_scale is None or estimated_bits is None: + zero = self.tok_emb.weight.new_zeros(()) + return zero, zero, zero + denom = max(total_symbols, 1.0) + return weighted_rate / denom, weighted_scale / denom, estimated_bits + + +# ----------------------------- +# TRAINING +# ----------------------------- + +def main() -> None: + global zeropower_via_newtonschulz5 + + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + if args.enable_compile: + 
zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) + + # ----------------------------- + # DISTRIBUTED + CUDA SETUP + # ----------------------------- + + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + simulated_world_size = args.simulated_world_size if not distributed and args.simulated_world_size > 0 else 0 + if simulated_world_size > 0: + world_size = simulated_world_size + rank = args.simulated_rank + local_rank = 0 + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if rank < 0 or rank >= world_size: + raise ValueError(f"RANK={rank} must satisfy 0 <= RANK < WORLD_SIZE={world_size}") + if simulated_world_size > 1 and not args.dry_run_init_only: + raise ValueError("SIMULATED_WORLD_SIZE>1 is only supported together with DRY_RUN_INIT_ONLY=1") + if args.train_batch_tokens <= 0: + raise ValueError(f"TRAIN_BATCH_TOKENS must be positive, got {args.train_batch_tokens}") + train_microbatch_tokens = args.train_microbatch_tokens if args.train_microbatch_tokens > 0 else max( + args.train_seq_len, args.train_batch_tokens // 8 + ) + if train_microbatch_tokens <= 0: + raise ValueError(f"TRAIN_MICROBATCH_TOKENS must be positive, got {train_microbatch_tokens}") + if train_microbatch_tokens % args.train_seq_len != 0: + raise ValueError( + f"TRAIN_MICROBATCH_TOKENS={train_microbatch_tokens} must be divisible by TRAIN_SEQ_LEN={args.train_seq_len}" + ) + tokens_per_microstep = world_size * train_microbatch_tokens + if args.grad_accum_steps > 0: + grad_accum_steps = args.grad_accum_steps + effective_train_batch_tokens = tokens_per_microstep * grad_accum_steps + else: + if args.train_batch_tokens % tokens_per_microstep != 0: + raise ValueError( + "TRAIN_BATCH_TOKENS must be divisible by WORLD_SIZE * TRAIN_MICROBATCH_TOKENS " + f"unless GRAD_ACCUM_STEPS is 
set explicitly; got TRAIN_BATCH_TOKENS={args.train_batch_tokens}, " + f"WORLD_SIZE={world_size}, TRAIN_MICROBATCH_TOKENS={train_microbatch_tokens}" + ) + grad_accum_steps = args.train_batch_tokens // tokens_per_microstep + effective_train_batch_tokens = args.train_batch_tokens + if grad_accum_steps <= 0: + raise ValueError(f"GRAD_ACCUM_STEPS must be positive, got {grad_accum_steps}") + grad_scale = 1.0 / grad_accum_steps + batch_multiplier = effective_train_batch_tokens / max(args.lr_reference_batch_tokens, 1) + effective_warmup_steps = scaled_warmup_steps( + args.warmup_steps, + effective_train_batch_tokens / max(args.warmup_reference_batch_tokens, 1), + args.warmup_scale_method, + ) + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 0 + + # Fast math knobs + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + + logfile = None + if master_process: + if args.run_id == "train": + logfile = "train.log" + else: + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" + print(logfile) + + def log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch {torch.__version__}", console=False) + log0( + subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, 
check=False).stdout, + console=False, + ) + log0("=" * 100, console=False) + + # ----------------------------- + # TOKENIZER + VALIDATION METRIC SETUP + # ----------------------------- + + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) + if args.val_token_limit > 0: + usable = min(((args.val_token_limit) // args.train_seq_len) * args.train_seq_len, val_tokens.numel() - 1) + if usable <= 0: + raise ValueError(f"VAL_TOKEN_LIMIT={args.val_token_limit} is too small for TRAIN_SEQ_LEN={args.train_seq_len}") + val_tokens = val_tokens[: usable + 1].contiguous() + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + if args.val_token_limit > 0: + log0(f"val_token_limit:{args.val_token_limit}") + + # ----------------------------- + # MODEL + OPTIMIZER SETUP + # ----------------------------- + + base_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + 
tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, + structured_mlp=args.structured_mlp, + structured_linear_type=args.structured_linear_type if args.structured_mlp else "dense", + btt_rank=args.btt_rank, + btt_in_shape=args.btt_in_shape, + btt_out_shape=args.btt_out_shape, + btt_init=args.btt_init, + btt_init_gain=args.btt_init_gain, + compile_structured_mlp=args.compile_structured_mlp, + structured_chunk_ranks=args.structured_chunk_ranks, + structured_eval_chunk_ranks=args.structured_eval_chunk_ranks, + structured_compile_mode=args.structured_compile_mode, + rate_target_tensors=args.rate_target_tensors, + rate_temperature=args.rate_temperature, + ).to(device).bfloat16() + for module in base_model.modules(): + if isinstance(module, (CastedLinear, StructuredLinear)): + module.float() + restore_low_dim_params_to_fp32(base_model) + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) if args.enable_compile else base_model + model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + + # Optimizer split: + # - token embedding (Adam) uses EMBED_LR + # - untied lm_head (Adam) uses HEAD_LR + # - matrix params in transformer blocks use MATRIX_LR via Muon + # - vectors/scalars use SCALAR_LR via Adam + block_named_params = list(base_model.blocks.named_parameters()) + matrix_params = [ + p + for name, p in block_named_params + if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + scalar_params = [ + p + for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + scaled_token_lr = 
scale_from_batch_multiplier(token_lr, batch_multiplier, args.lr_scale_method) + scaled_head_lr = scale_from_batch_multiplier(args.head_lr, batch_multiplier, args.lr_scale_method) + scaled_matrix_lr = scale_from_batch_multiplier(args.matrix_lr, batch_multiplier, args.lr_scale_method) + scaled_scalar_lr = scale_from_batch_multiplier(args.scalar_lr, batch_multiplier, args.lr_scale_method) + optimizer_tok = torch.optim.Adam( + [{"params": [base_model.tok_emb.weight], "lr": scaled_token_lr, "base_lr": scaled_token_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizer_muon = Muon( + matrix_params, + lr=scaled_matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + ) + for group in optimizer_muon.param_groups: + group["base_lr"] = scaled_matrix_lr + optimizer_scalar = torch.optim.Adam( + [{"params": scalar_params, "lr": scaled_scalar_lr, "base_lr": scaled_scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": scaled_head_lr, "base_lr": scaled_head_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers.insert(1, optimizer_head) + + n_params = sum(p.numel() for p in base_model.parameters()) + log0(f"model_params:{n_params}") + log0( + f"world_size:{world_size} grad_accum_steps:{grad_accum_steps} " + f"train_microbatch_tokens:{train_microbatch_tokens} effective_train_batch_tokens:{effective_train_batch_tokens}" + ) + if simulated_world_size > 0: + log0(f"simulated_world_size:{simulated_world_size} simulated_rank:{rank}") + log0(f"enable_compile:{args.enable_compile}") + log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") + log0(f"attention_mode:gqa num_heads:{args.num_heads} 
num_kv_heads:{args.num_kv_heads}") + log0( + f"quant_scheme:{args.quant_scheme} compressor:{args.compressor} level:{args.compressor_level} " + f"structured_mlp:{args.structured_mlp} structured_linear_type:{args.structured_linear_type} " + f"btt_rank:{args.btt_rank} btt_init:{args.btt_init} compile_structured_mlp:{args.compile_structured_mlp} " + f"structured_chunk_ranks:{args.structured_chunk_ranks} structured_eval_chunk_ranks:{args.structured_eval_chunk_ranks} " + f"structured_compile_mode:{args.structured_compile_mode}" + ) + log0( + f"rate_lambda:{args.rate_lambda} scale_lambda:{args.scale_lambda} " + f"rate_warmup_start_fraction:{args.rate_warmup_start_fraction} rate_full_fraction:{args.rate_full_fraction} " + f"rate_target_tensors:{args.rate_target_tensors} rate_schedule_mode:{args.rate_schedule_mode} " + f"max_expected_steps:{args.max_expected_steps}" + ) + log0( + f"tie_embeddings:{args.tie_embeddings} embed_lr:{scaled_token_lr} " + f"head_lr:{scaled_head_lr if base_model.lm_head is not None else 0.0} " + f"matrix_lr:{scaled_matrix_lr} scalar_lr:{scaled_scalar_lr}" + ) + log0( + f"train_batch_tokens_target:{args.train_batch_tokens} effective_train_batch_tokens:{effective_train_batch_tokens} " + f"train_seq_len:{args.train_seq_len} iterations:{args.iterations} warmup_steps:{effective_warmup_steps} " + f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log0( + f"lr_scale_method:{args.lr_scale_method} lr_reference_batch_tokens:{args.lr_reference_batch_tokens} " + f"batch_multiplier:{batch_multiplier:.4f} warmup_scale_method:{args.warmup_scale_method} " + f"warmup_reference_batch_tokens:{args.warmup_reference_batch_tokens}" + ) + log0(f"seed:{args.seed}") + + if args.debug_data_sharding_steps > 0: + debug_log_data_sharding( + args.train_files, + world_size, + effective_train_batch_tokens, + args.train_seq_len, + grad_accum_steps, + args.debug_data_sharding_steps, + log0, + ) + if args.dry_run_init_only: + log0("dry_run_init_only: exiting after 
init/logging") + return + + # ----------------------------- + # DATA LOADER & MODEL WARMUP + # ----------------------------- + + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + + def lr_mul(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + + def rate_progress(step: int, elapsed_ms: float) -> float: + if args.rate_schedule_mode == "steps": + if args.iterations <= 0: + return 1.0 + return step / max(args.iterations, 1) + if args.rate_schedule_mode == "expected_steps": + if args.max_expected_steps <= 0: + raise ValueError("MAX_EXPECTED_STEPS must be positive when RATE_SCHEDULE_MODE=expected_steps") + return step / max(args.max_expected_steps, 1) + if args.rate_schedule_mode == "wallclock": + if max_wallclock_ms is None: + raise ValueError("RATE_SCHEDULE_MODE=wallclock requires MAX_WALLCLOCK_SECONDS > 0") + return elapsed_ms / max(max_wallclock_ms, 1e-9) + if args.rate_schedule_mode != "auto": + raise ValueError(f"Unsupported RATE_SCHEDULE_MODE={args.rate_schedule_mode}") + if args.max_expected_steps > 0: + return step / max(args.max_expected_steps, 1) + if max_wallclock_ms is not None and step > 0: + step_ms = elapsed_ms / max(step, 1) + expected_steps = max(int(max_wallclock_ms / max(step_ms, 1e-9)), 1) + return step / expected_steps + if args.iterations <= 0: + 
return 1.0 + return step / max(args.iterations, 1) + + def rate_schedule(step: int, elapsed_ms: float) -> float: + if args.rate_lambda <= 0.0 and args.scale_lambda <= 0.0: + return 0.0 + frac = rate_progress(step, elapsed_ms) + start = args.rate_warmup_start_fraction + full = max(args.rate_full_fraction, start + 1e-6) + if frac <= start: + return 0.0 + if frac >= full: + return 1.0 + return (frac - start) / (full - start) + + # Warmup primes the compiled forward/backward/optimizer paths, then we restore the + # initial weights/optimizer state so measured training starts from the true init. + if effective_warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(effective_warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(effective_train_batch_tokens, args.train_seq_len, grad_accum_steps) + maybe_mark_compile_step_begin(args.enable_compile or args.compile_structured_mlp) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + for opt in optimizers: + opt.step() + zero_grad_all() + if ( + effective_warmup_steps <= 20 + or (warmup_step + 1) % 10 == 0 + or warmup_step + 1 == effective_warmup_steps + ): + log0(f"warmup_step:{warmup_step + 1}/{effective_warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + if distributed: + model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + # ----------------------------- + # 
MAIN TRAINING LOOP + # ----------------------------- + + training_time_ms = 0.0 + stop_after_step: int | None = None + torch.cuda.synchronize() + t0 = time.perf_counter() + + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + log0( + f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" + ) + break + + model.train() + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + zero_grad_all() + train_loss = torch.zeros((), device=device) + train_ce_loss = torch.zeros((), device=device) + train_rate_proxy = torch.zeros((), device=device) + train_scale_proxy = torch.zeros((), device=device) + train_est_bits = torch.zeros((), device=device) + rate_mul = rate_schedule(step, elapsed_ms) + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(effective_train_batch_tokens, args.train_seq_len, grad_accum_steps) + maybe_mark_compile_step_begin(args.enable_compile or args.compile_structured_mlp) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + ce_loss = model(x, y) + 
total_loss = ce_loss + if rate_mul > 0.0: + rate_proxy, scale_proxy, est_bits = base_model.rate_regularizer() + total_loss = total_loss + rate_mul * (args.rate_lambda * rate_proxy + args.scale_lambda * scale_proxy) + train_rate_proxy += rate_proxy.detach() + train_scale_proxy += scale_proxy.detach() + train_est_bits += est_bits.detach() + train_ce_loss += ce_loss.detach() + train_loss += total_loss.detach() + (total_loss * grad_scale).backward() + train_loss /= grad_accum_steps + train_ce_loss /= grad_accum_steps + train_rate_proxy /= grad_accum_steps + train_scale_proxy /= grad_accum_steps + train_est_bits /= grad_accum_steps + + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + for opt in optimizers: + opt.step() + zero_grad_all() + + step += 1 + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + should_log_train = ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + log0( + f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} ce_loss:{train_ce_loss.item():.4f} " + f"rate_mul:{rate_mul:.3f} rate_proxy_bits:{train_rate_proxy.item():.4f} " + f"scale_proxy:{train_scale_proxy.item():.6f} est_model_bits:{train_est_bits.item():.0f} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" + ) + + # Needed to sync whether we've reached the wallclock cap. 
+ reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + + log0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + ) + + # ----------------------------- + # SERIALIZATION + ROUNDTRIP VALIDATION + # ----------------------------- + # Save the raw state (useful for debugging/loading in PyTorch directly), then always produce + # the compressed int8+zlib artifact and validate the round-tripped weights. + + if master_process: + torch.save(base_model.state_dict(), "final_model.pt") + model_bytes = os.path.getsize("final_model.pt") + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model: {model_bytes} bytes") + log0(f"Code size: {code_bytes} bytes") + log0(f"Total submission size: {model_bytes + code_bytes} bytes") + + quant_obj, quant_stats = quantize_state_dict(base_model.state_dict(), args.quant_scheme) + quant_buf = io.BytesIO() + torch.save(quant_obj, quant_buf) + quant_raw = quant_buf.getvalue() + quant_blob = compress_blob(quant_raw, args.compressor, args.compressor_level) + quant_raw_bytes = len(quant_raw) + quant_filename = f"final_model.{args.quant_scheme}.{args.compressor}.ptc" + if master_process: + with open(quant_filename, "wb") as f: + f.write(quant_blob) + quant_file_bytes = os.path.getsize(quant_filename) + code_bytes = len(code.encode("utf-8")) + ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1) + log0( + f"Serialized model {args.quant_scheme}+{args.compressor}: {quant_file_bytes} bytes " + f"(payload:{quant_stats['int8_payload_bytes']} raw_torch:{quant_raw_bytes} 
payload_ratio:{ratio:.2f}x)" + ) + log0(f"Total submission size {args.quant_scheme}+{args.compressor}: {quant_file_bytes + code_bytes} bytes") + + if distributed: + dist.barrier() + with open(quant_filename, "rb") as f: + quant_blob_disk = f.read() + quant_state = torch.load(io.BytesIO(decompress_blob(quant_blob_disk, args.compressor)), map_location="cpu") + base_model.load_state_dict(dequantize_state_dict(quant_state), strict=True) + torch.cuda.synchronize() + t_qeval = time.perf_counter() + q_val_loss, q_val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"final_{args.quant_scheme}_{args.compressor}_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" + ) + log0(f"final_{args.quant_scheme}_{args.compressor}_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + + if distributed: + dist.destroy_process_group() + + +if __name__ == "__main__": + main() + +==================================================================================================== +Running Python 3.10.12 (main, Mar 3 2026, 11:56:32) [GCC 11.4.0] +Running PyTorch 2.11.0+cu130 +Tue Mar 31 18:59:08 2026 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 590.48.01 Driver Version: 590.48.01 CUDA Version: 13.1 | ++-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. 
| +|=========================================+========================+======================| +| 0 NVIDIA GeForce RTX 3060 Off | 00000000:01:00.0 On | N/A | +| 0% 49C P0 22W / 170W | 224MiB / 12288MiB | 10% Default | +| | | N/A | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| 0 N/A N/A 985 G /usr/lib/xorg/Xorg 99MiB | +| 0 N/A N/A 11001 C python3 104MiB | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=/home/hz/Desktop/parameter-golf-main/data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:1 +val_loader:shards pattern=/home/hz/Desktop/parameter-golf-main/data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:1048576 +val_token_limit:1048576 +model_params:8507464 +world_size:1 grad_accum_steps:8 train_microbatch_tokens:2048 effective_train_batch_tokens:16384 +enable_compile:False +sdp_backends:cudnn=False flash=True mem_efficient=False math=False +attention_mode:gqa num_heads:8 num_kv_heads:4 +quant_scheme:int5_odd compressor:zlib level:9 structured_mlp:True structured_linear_type:btt_mlp btt_rank:32 btt_init:mup compile_structured_mlp:False structured_chunk_ranks:64 structured_eval_chunk_ranks:1 structured_compile_mode:reduce-overhead +rate_lambda:2e-05 scale_lambda:0.0002 rate_warmup_start_fraction:0.1 rate_full_fraction:0.4 rate_target_tensors:mlp_only rate_schedule_mode:auto max_expected_steps:0 +tie_embeddings:True embed_lr:0.03 head_lr:0.0 matrix_lr:0.02 scalar_lr:0.02 
+train_batch_tokens_target:16384 effective_train_batch_tokens:16384 train_seq_len:256 iterations:40 warmup_steps:1 max_wallclock_seconds:600.000
+lr_scale_method:none lr_reference_batch_tokens:16384 batch_multiplier:1.0000 warmup_scale_method:none warmup_reference_batch_tokens:16384
+seed:1337
+dry_run_init_only: exiting after init/logging
diff --git a/records/track_non_record_16mb/2026-03-31_EntropyAware_Int5Odd_BTTMLP_3060/logs/mock_8xh100_math.txt b/records/track_non_record_16mb/2026-03-31_EntropyAware_Int5Odd_BTTMLP_3060/logs/mock_8xh100_math.txt
new file mode 100644
index 000000000..699f65480
--- /dev/null
+++ b/records/track_non_record_16mb/2026-03-31_EntropyAware_Int5Odd_BTTMLP_3060/logs/mock_8xh100_math.txt
@@ -0,0 +1,1826 @@
+"""
+The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good launching-off points for new participants, not SOTA configs. We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder.
+
+Hard stop: To keep things readable for newcomers, let's make sure `train_gpt.py` and `train_gpt_mlx.py` are never longer than 1500 lines.
+""" + +from __future__ import annotations + +import copy +import glob +import io +import math +import os +import random +import subprocess +import sys +import time +import uuid +import zlib +from pathlib import Path + +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP + +try: + import zstandard as zstd + + _HAS_ZSTD = True +except ImportError: + zstd = None + _HAS_ZSTD = False + +REPO_ROOT = Path(__file__).resolve().parents[3] + +# ----------------------------- +# HYPERPARAMETERS +# ----------------------------- +# Default Simple Baseline run: +# - 9 transformer blocks at width 512 +# - 8 attention heads with 4 KV heads (GQA) and 2x MLP expansion +# - vocab size 1024, sequence length 1024, tied embeddings +# - 524,288 train tokens per step for 20,000 iterations with a ~10 minute cap + +class Hyperparameters: + # Data paths are shard globs produced by the existing preprocessing pipeline. + data_path = os.environ.get("DATA_PATH", str(REPO_ROOT / "data/datasets/fineweb10B_sp1024")) + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", str(REPO_ROOT / "data/tokenizers/fineweb_1024_bpe.model")) + run_id = os.environ.get("RUN_ID", "train") + seed = int(os.environ.get("SEED", 1337)) + + # Validation cadence and batch size. Validation always uses the full fineweb_val split. + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 262_144)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 10)) + val_token_limit = int(os.environ.get("VAL_TOKEN_LIMIT", 1_048_576)) + + # Training length. 
+ iterations = int(os.environ.get("ITERATIONS", 40)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 1200)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 1)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 16_384)) + train_microbatch_tokens = int(os.environ.get("TRAIN_MICROBATCH_TOKENS", 0)) + grad_accum_steps = int(os.environ.get("GRAD_ACCUM_STEPS", 0)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 256)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + enable_compile = bool(int(os.environ.get("ENABLE_COMPILE", "0"))) + + # Model shape. + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers = int(os.environ.get("NUM_LAYERS", 9)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = int(os.environ.get("MLP_MULT", 2)) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + + # Optimizer hyperparameters. 
+ embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.03)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.02)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.02)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.0)) + lr_scale_method = os.environ.get("LR_SCALE_METHOD", "none") + lr_reference_batch_tokens = int(os.environ.get("LR_REFERENCE_BATCH_TOKENS", 16_384)) + warmup_scale_method = os.environ.get("WARMUP_SCALE_METHOD", "none") + warmup_reference_batch_tokens = int(os.environ.get("WARMUP_REFERENCE_BATCH_TOKENS", 16_384)) + simulated_world_size = int(os.environ.get("SIMULATED_WORLD_SIZE", 0)) + simulated_rank = int(os.environ.get("SIMULATED_RANK", 0)) + dry_run_init_only = bool(int(os.environ.get("DRY_RUN_INIT_ONLY", "0"))) + debug_data_sharding_steps = int(os.environ.get("DEBUG_DATA_SHARDING_STEPS", 0)) + + # Export / compression. + quant_scheme = os.environ.get("QUANT_SCHEME", "int5_odd") + compressor = os.environ.get("COMPRESSOR", "zlib") + compressor_level = int(os.environ.get("COMPRESSOR_LEVEL", 9)) + + # Entropy-aware rate regularization. 
+ rate_lambda = float(os.environ.get("RATE_LAMBDA", 0.00002)) + scale_lambda = float(os.environ.get("SCALE_LAMBDA", 0.0002)) + rate_temperature = float(os.environ.get("RATE_TEMPERATURE", 6.0)) + rate_warmup_start_fraction = float(os.environ.get("RATE_WARMUP_START_FRACTION", 0.1)) + rate_full_fraction = float(os.environ.get("RATE_FULL_FRACTION", 0.4)) + rate_target_tensors = os.environ.get("RATE_TARGET_TENSORS", "mlp_only") + + # Structured MLP. + structured_mlp = bool(int(os.environ.get("STRUCTURED_MLP", "1"))) + structured_linear_type = os.environ.get("STRUCTURED_LINEAR_TYPE", "btt_mlp") + btt_rank = int(os.environ.get("BTT_RANK", 32)) + btt_in_shape = os.environ.get("BTT_IN_SHAPE", "") + btt_out_shape = os.environ.get("BTT_OUT_SHAPE", "") + btt_init = os.environ.get("BTT_INIT", "mup") + btt_init_gain = float(os.environ.get("BTT_INIT_GAIN", 1.0)) + compile_structured_mlp = bool(int(os.environ.get("COMPILE_STRUCTURED_MLP", "1"))) + +# ----------------------------- +# MUON OPTIMIZER +# ----------------------------- +# +# As borrowed from modded-nanogpt +# Background on Muon: https://kellerjordan.github.io/posts/muon/ + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + # Orthogonalize a 2D update matrix with a fast Newton-Schulz iteration. + # Muon uses this to normalize matrix-shaped gradients before applying them. 
+ a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X + + +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True): + super().__init__( + params, + dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov), + ) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + + for group in self.param_groups: + params = group["params"] + if not params: + continue + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + + total_params = sum(int(p.numel()) for p in params) + updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) + + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: + g = g.add(buf, alpha=momentum) + g = zeropower_via_newtonschulz5(g, steps=backend_steps) + # Scale correction from Muon reference implementations. 
+                    g *= max(1, g.size(0) / g.size(1)) ** 0.5
+                    updates_flat[curr : curr + p.numel()] = g.reshape(-1)
+                curr += p.numel()
+
+            if distributed:
+                dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM)
+
+            curr = 0
+            for p in params:
+                g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype)
+                p.add_(g, alpha=-lr)
+                curr += p.numel()
+
+        return loss
+
+
+# -----------------------------
+# TOKENIZER-AGNOSTIC EVALUATION SETUP
+# -----------------------------
+#
+# It's common for small models to have a large fraction of their parameters in embeddings, since the 2 * d_model * d_vocab embedding weights can be gigantic.
+# Instead of locking the tokenizer, we let you bring your own and compute our validation metrics from the average compression of the validation set.
+# We calculate BPB (bits-per-byte) instead of validation loss, so we need a way to count the number of bytes each token covers under the tokenizer.
+# Note: Submissions that edit the tokenizer will be examined more carefully, since screwing this up might unjustly improve your score.
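The comment block above boils down to a two-step conversion; a minimal standalone sketch (not part of the diffed script; the function name is illustrative) of how token cross-entropy plus tokenizer byte counts turn into bits-per-byte:

```python
import math

def bits_per_byte(val_loss_nats: float, total_tokens: int, total_bytes: int) -> float:
    # Cross-entropy in nats -> bits per token, then rescale by the
    # corpus-level tokens-per-byte ratio so the metric no longer
    # depends on how the tokenizer carves up the text.
    bits_per_token = val_loss_nats / math.log(2.0)
    return bits_per_token * (total_tokens / total_bytes)

# A loss of 2.0 nats at 4 bytes per token is ~0.72 bits per byte.
bpb = bits_per_byte(2.0, total_tokens=1000, total_bytes=4000)
```

A tokenizer that emits fewer, longer tokens shrinks `total_tokens / total_bytes` but tends to raise per-token loss, which is exactly the trade-off this normalization makes comparable.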
+ +def build_sentencepiece_luts( + sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device +) -> tuple[Tensor, Tensor, Tensor]: + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("▁"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) + + +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*. 
+ tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] + + +def eval_val( + args: Hyperparameters, + model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + grad_accum_steps: int, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, +) -> tuple[float, float]: + # Validation computes two metrics: + # - val_loss: token cross-entropy (natural log) + # - val_bpb: tokenizer-agnostic compression metric used by the challenge + local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) + if local_batch_tokens < args.train_seq_len: + raise ValueError( + "VAL_BATCH_SIZE must provide at least one sequence per rank; " + f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " + f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}" + ) + local_batch_seqs = local_batch_tokens // args.train_seq_len + total_seqs = (val_tokens.numel() - 1) // args.train_seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + + model.eval() + with torch.inference_mode(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * args.train_seq_len + raw_end = batch_seq_end * args.train_seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, args.train_seq_len) + y = local[1:].reshape(-1, args.train_seq_len) + 
maybe_mark_compile_step_begin(args.enable_compile or args.compile_structured_mlp)
+            with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True):
+                batch_loss = model(x, y).detach()
+            batch_token_count = float(y.numel())
+            val_loss_sum += batch_loss.to(torch.float64) * batch_token_count
+            val_token_count += batch_token_count
+            prev_ids = x.reshape(-1)
+            tgt_ids = y.reshape(-1)
+            token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16)
+            token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16)
+            val_byte_count += token_bytes.to(torch.float64).sum()
+
+    if dist.is_available() and dist.is_initialized():
+        dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM)
+        dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM)
+        dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM)
+
+    val_loss = val_loss_sum / val_token_count
+    bits_per_token = val_loss.item() / math.log(2.0)
+    tokens_per_byte = val_token_count.item() / val_byte_count.item()
+    model.train()
+    return float(val_loss.item()), float(bits_per_token * tokens_per_byte)
+
+# -----------------------------
+# POST-TRAINING QUANTIZATION
+# -----------------------------
+#
+# It's silly to export our model, which is trained in bf16 and fp32, at that same precision.
+# Instead, we get approximately the same model (with a small quality hit) by quantizing it to int8 and zlib-compressing the result.
+# We can then decompress the model and run it at higher precision for evaluation, once the artifact comes in under the size limit.
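The per-row quantization idea described above can be sketched in a few lines of standalone NumPy. This is a simplified illustration, not the exact export path: it uses plain max-abs scaling instead of the percentile clipping and fp16 scale storage used in the real `quantize_float_tensor`:

```python
import numpy as np

def int8_roundtrip(w: np.ndarray) -> np.ndarray:
    # One scale per output row: each row is mapped onto [-127, 127],
    # rounded, then mapped back. Per-row scales track output-channel
    # ranges better than a single tensor-wide scale.
    scale = np.maximum(np.abs(w).max(axis=1, keepdims=True) / 127.0, 1e-8)
    q = np.clip(np.round(w / scale), -127, 127).astype(np.int8)
    return q.astype(np.float32) * scale

rng = np.random.default_rng(0)
w = rng.standard_normal((4, 8)).astype(np.float32)
max_err = float(np.abs(int8_roundtrip(w) - w).max())  # bounded by scale / 2 per row
```

The worst-case reconstruction error per row is half a quantization step, i.e. `row_max_abs / 254`, which is why a per-row scale beats a per-tensor one whenever row magnitudes differ widely.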
+ +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights", + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 +INT5_ODD_LEVELS = (-2, -1, 0, 1, 2) + +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) + + +def maybe_mark_compile_step_begin(enabled: bool) -> None: + if not enabled: + return + mark_step = getattr(torch.compiler, "cudagraph_mark_step_begin", None) + if mark_step is not None: + mark_step() + + +def scale_from_batch_multiplier(value: float, multiplier: float, method: str) -> float: + if method == "none": + return value + if method == "sqrt": + return value * math.sqrt(multiplier) + if method == "linear": + return value * multiplier + raise ValueError(f"Unsupported scaling method: {method}") + + +def scaled_warmup_steps(base_steps: int, multiplier: float, method: str) -> int: + if base_steps <= 0: + return 0 + scaled = scale_from_batch_multiplier(float(base_steps), multiplier, method) + return max(int(math.ceil(scaled)), 1) + + +def parse_factor_shape(spec: str, dim: int) -> tuple[int, int]: + if spec: + vals = tuple(int(v.strip()) for v in spec.split(",") if v.strip()) + if len(vals) != 2 or vals[0] * vals[1] != dim: + raise ValueError(f"Invalid factor shape '{spec}' for dim={dim}") + return vals + for a in range(int(math.isqrt(dim)), 0, -1): + if dim % a == 0: + return a, dim // a + return 1, dim + +def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> 
Tensor: + if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): + return t.float().contiguous() + if t.dtype in {torch.float32, torch.bfloat16}: + passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") + return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() + return t + +def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + # Matrices get one scale per row, which usually tracks output-channel + # ranges much better than a single tensor-wide scale. + clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) + q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() + return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + + # Vectors / scalars use a simpler per-tensor scale. 
+ clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() + return q, scale + + +def pack_base5_codes(q: Tensor) -> tuple[bytes, int]: + vals = q.detach().to("cpu", dtype=torch.uint8).reshape(-1).numpy() + n_vals = int(vals.size) + pad = (3 - (n_vals % 3)) % 3 + if pad: + vals = np.pad(vals, (0, pad)) + groups = vals.reshape(-1, 3).astype(np.uint8) + packed = groups[:, 0] + 5 * groups[:, 1] + 25 * groups[:, 2] + return packed.tobytes(), n_vals + + +def unpack_base5_codes(data: bytes, n_vals: int) -> Tensor: + packed = np.frombuffer(data, dtype=np.uint8).astype(np.int16) + out = np.zeros((packed.size, 3), dtype=np.uint8) + for i in range(3): + out[:, i] = packed % 5 + packed //= 5 + flat = out.reshape(-1)[:n_vals] + return torch.from_numpy(flat.astype(np.uint8, copy=False)) + + +def quantize_float_tensor_int5_odd(t: Tensor) -> tuple[Tensor, Tensor]: + if t.ndim != 2: + raise ValueError("int5_odd quantization only supports 2D tensors") + t32 = t.float() + clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clip_abs = clip_abs.clamp_min(1e-6) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 2.0).clamp_min(1.0 / 2.0e6) + q = torch.clamp(torch.round(clipped / scale[:, None]), -2, 2).to(torch.int8).contiguous() + return (q + 2).to(torch.uint8).contiguous(), scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + +def quantize_state_dict_int8(state_dict: dict[str, Tensor]): + # Single supported clean-script export format: + # - per-row int8 for 2D float tensors + # - per-tensor int8 for other float tensors + # - exact passthrough for non-floats + # - passthrough for small float tensors, 
stored as fp16 to save bytes + quantized: dict[str, Tensor] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + qmeta: dict[str, dict[str, object]] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + + # Small float tensors are cheap enough to keep directly. We still downcast + # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size. + if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + + stats["num_float_tensors"] += 1 + q, s = quantize_float_tensor(t) + if s.ndim > 0: + qmeta[name] = {"scheme": "per_row", "axis": 0} + quantized[name] = q + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) + + obj: dict[str, object] = { + "__quant_format__": "int8_clean_per_row_v1", + "quantized": quantized, + "scales": scales, + "dtypes": dtypes, + "passthrough": passthrough, + } + if qmeta: + obj["qmeta"] = qmeta + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats + + +def quantize_state_dict_int5_odd(state_dict: dict[str, Tensor]): + packed: dict[str, bytes] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + shapes: dict[str, tuple[int, ...]] = {} + orig_shapes: dict[str, 
tuple[int, ...]] = {} + n_codes: dict[str, int] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + + quantize_this = t.ndim >= 2 and (t.numel() > INT8_KEEP_FLOAT_MAX_NUMEL or ".mlp." in name) + if not quantize_this: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + + stats["num_float_tensors"] += 1 + t2 = t.reshape(t.shape[0], -1) if t.ndim > 2 else t + q, s = quantize_float_tensor_int5_odd(t2) + packed_bytes, count = pack_base5_codes(q) + packed[name] = packed_bytes + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + shapes[name] = tuple(int(v) for v in t2.shape) + orig_shapes[name] = tuple(int(v) for v in t.shape) + n_codes[name] = count + stats["int8_payload_bytes"] += len(packed_bytes) + tensor_nbytes(s) + + obj: dict[str, object] = { + "__quant_format__": "int5_odd_clean_per_row_v1", + "packed": packed, + "scales": scales, + "dtypes": dtypes, + "shapes": shapes, + "orig_shapes": orig_shapes, + "n_codes": n_codes, + "passthrough": passthrough, + } + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats + +def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + qmeta = obj.get("qmeta", {}) + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, q in 
obj["quantized"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + s = obj["scales"][name] + if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: + s = s.to(dtype=torch.float32) + # Broadcast the saved row scale back across trailing dimensions. + out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() + else: + scale = float(s.item()) + out[name] = (q.float() * scale).to(dtype=dtype).contiguous() + for name, t in obj["passthrough"].items(): + # Restore small tensors, undoing the temporary fp16 storage cast if needed. + out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out + + +def dequantize_state_dict_int5_odd(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, packed in obj["packed"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + shape = tuple(int(v) for v in obj["shapes"][name]) + orig_shape = tuple(int(v) for v in obj.get("orig_shapes", {}).get(name, shape)) + q = unpack_base5_codes(packed, int(obj["n_codes"][name])).to(torch.float32).reshape(shape) - 2.0 + s = obj["scales"][name].to(dtype=torch.float32) + deq = (q * s.view(shape[0], *([1] * (len(shape) - 1)))).to(dtype=dtype) + out[name] = deq.reshape(orig_shape).contiguous() + for name, t in obj["passthrough"].items(): + out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out + + +def quantize_state_dict(state_dict: dict[str, Tensor], quant_scheme: str): + if quant_scheme == "int8": + return quantize_state_dict_int8(state_dict) + if quant_scheme == "int5_odd": + return 
quantize_state_dict_int5_odd(state_dict) + raise ValueError(f"Unsupported QUANT_SCHEME={quant_scheme}") + + +def dequantize_state_dict(obj: dict[str, object]) -> dict[str, Tensor]: + quant_format = obj.get("__quant_format__") + if quant_format == "int8_clean_per_row_v1": + return dequantize_state_dict_int8(obj) + if quant_format == "int5_odd_clean_per_row_v1": + return dequantize_state_dict_int5_odd(obj) + raise ValueError(f"Unsupported quant format {quant_format}") + + +def compress_blob(data: bytes, compressor: str, level: int) -> bytes: + if compressor == "zlib": + return zlib.compress(data, level=max(0, min(level, 9))) + if compressor == "zstd": + if not _HAS_ZSTD: + raise RuntimeError("COMPRESSOR=zstd requires zstandard to be installed") + return zstd.ZstdCompressor(level=level).compress(data) + raise ValueError(f"Unsupported COMPRESSOR={compressor}") + + +def decompress_blob(data: bytes, compressor: str) -> bytes: + if compressor == "zlib": + return zlib.decompress(data) + if compressor == "zstd": + if not _HAS_ZSTD: + raise RuntimeError("COMPRESSOR=zstd requires zstandard to be installed") + return zstd.ZstdDecompressor().decompress(data) + raise ValueError(f"Unsupported COMPRESSOR={compressor}") + + +def should_rate_regularize_tensor(name: str, p: Tensor, target: str) -> bool: + if not p.requires_grad or not p.is_floating_point() or p.ndim < 2: + return False + if any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS): + return False + if target == "mlp_only": + return ".mlp." in name + if target == "all_matrices": + return p.numel() > INT8_KEEP_FLOAT_MAX_NUMEL or ".mlp." 
in name
+    raise ValueError(f"Unsupported RATE_TARGET_TENSORS={target}")
+
+
+def estimate_tensor_rate_proxy(t: Tensor, temperature: float) -> tuple[Tensor, Tensor, Tensor]:
+    t2 = t.float()
+    if t2.ndim > 2:
+        t2 = t2.reshape(t2.shape[0], -1)
+    elif t2.ndim == 1:
+        t2 = t2.reshape(1, -1)
+    row_scale = t2.abs().mean(dim=1, keepdim=True).clamp_min(1e-6)
+    normalized = torch.clamp(t2 / row_scale, -2.5, 2.5)
+    centers = t2.new_tensor(INT5_ODD_LEVELS)
+    logits = -temperature * (normalized.unsqueeze(-1) - centers) ** 2
+    probs = torch.softmax(logits, dim=-1)
+    hist = probs.mean(dim=(0, 1))
+    entropy_nats = -(hist * (hist + 1e-8).log()).sum()
+    entropy_bits = entropy_nats / math.log(2.0)
+    scale_proxy = row_scale.mean()
+    est_bits = entropy_bits * float(t2.numel())
+    return entropy_bits, scale_proxy, est_bits
+
+
+# -----------------------------
+# DATA LOADING
+# -----------------------------
+
+def load_data_shard(file: Path) -> Tensor:
+    # 256-entry little-endian uint32 header; header[2] holds the token count.
+    header_bytes = 256 * np.dtype("<u4").itemsize
+    header = np.fromfile(file, dtype="<u4", count=256)
+    num_tokens = int(header[2])
+    tokens = np.fromfile(file, dtype="<u2", count=num_tokens, offset=header_bytes)
+    return torch.from_numpy(tokens.astype(np.int32, copy=False))
+
+
+class TokenStream:
+    # Cycles through the sorted shard files matching `pattern` as one contiguous token stream.
+    def __init__(self, pattern: str):
+        path = Path(pattern)
+        self.files = sorted(path.parent.glob(path.name))
+        if not self.files:
+            raise FileNotFoundError(f"no token shards match {pattern}")
+        self.file_idx = 0
+        self.tokens = load_data_shard(self.files[self.file_idx])
+        self.pos = 0
+
+    def _advance_file(self) -> None:
+        self.file_idx = (self.file_idx + 1) % len(self.files)
+        self.tokens = load_data_shard(self.files[self.file_idx])
+        self.pos = 0
+
+    def take(self, n: int) -> Tensor:
+        chunks: list[Tensor] = []
+        remaining = n
+        while remaining > 0:
+            avail = self.tokens.numel() - self.pos
+            if avail <= 0:
+                self._advance_file()
+                continue
+            k = min(remaining, avail)
+            chunks.append(self.tokens[self.pos : self.pos + k])
+            self.pos += k
+            remaining -= k
+        return chunks[0] if len(chunks) == 1 else torch.cat(chunks)
+
+
+class DistributedTokenLoader:
+    # Each call consumes a contiguous chunk from the shared token stream, then slices out
+    # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting.
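+    # Worked example of the sharding arithmetic below (hypothetical numbers):
+    # with global_tokens=131072, world_size=8, grad_accum_steps=1, next_batch
+    # computes local_tokens = 131072 // (8 * 1) = 16384 and per_rank_span = 16385,
+    # takes a shared chunk of 8 * 16385 = 131080 tokens, and rank r keeps
+    # chunk[r * 16385 : (r + 1) * 16385], so the per-rank spans are disjoint
+    # except for the single shift token at each boundary.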
+ def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank = rank + self.world_size = world_size + self.device = device + self.stream = TokenStream(pattern) + + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + chunk = self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start : start + per_rank_span].to(dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + + +def debug_log_data_sharding( + pattern: str, + world_size: int, + global_tokens: int, + seq_len: int, + grad_accum_steps: int, + steps: int, + log0, +) -> None: + stream = TokenStream(pattern) + local_tokens = global_tokens // (world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + global_cursor = 0 + log0( + f"debug_sharding:world_size:{world_size} global_tokens:{global_tokens} " + f"grad_accum_steps:{grad_accum_steps} local_tokens_per_rank:{local_tokens} per_rank_span:{per_rank_span}" + ) + for step_idx in range(steps): + chunk = stream.take(per_rank_span * world_size) + chunk_start = global_cursor + chunk_end = global_cursor + chunk.numel() + log0(f"debug_sharding_step:{step_idx} shared_chunk_token_range:[{chunk_start},{chunk_end})") + for shard_rank in range(world_size): + start = shard_rank * per_rank_span + end = start + per_rank_span + local = chunk[start:end] + preview = ",".join(str(int(t)) for t in local[: min(6, local.numel())].tolist()) + log0( + f"debug_shard rank:{shard_rank} token_range:[{chunk_start + start},{chunk_start + end}) " + f"preview:{preview}" + ) + global_cursor = chunk_end + +# ----------------------------- +# TRANSFORMER MODULES +# ----------------------------- + +class RMSNorm(nn.Module): + def 
__init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + + +class CastedLinear(nn.Linear): + # Keep weights in fp32 for optimizer/state quality, cast at matmul time for bf16 compute. + def forward(self, x: Tensor) -> Tensor: + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, self.weight.to(x.dtype), bias) + + +class StructuredLinear(nn.Module): + # Lightweight TT-style structured matrix used only in the MLP stack. + def __init__( + self, + in_features: int, + out_features: int, + bias: bool, + linear_type: str, + btt_rank: int, + in_shape_spec: str, + out_shape_spec: str, + btt_init: str, + btt_init_gain: float, + ): + super().__init__() + self.in_features = in_features + self.out_features = out_features + self.linear_type = linear_type + self.rank = max(int(btt_rank), 1) + self.btt_init = btt_init + self.btt_init_gain = btt_init_gain + if linear_type == "dense": + self.impl = CastedLinear(in_features, out_features, bias=bias) + self.register_parameter("core1", None) + self.register_parameter("core2", None) + return + if linear_type != "btt_mlp": + raise ValueError(f"Unsupported STRUCTURED_LINEAR_TYPE={linear_type}") + self.impl = None + self.in_shape = parse_factor_shape(in_shape_spec, in_features) + self.out_shape = parse_factor_shape(out_shape_spec, out_features) + n1, n2 = self.in_shape + m1, m2 = self.out_shape + self.core1 = nn.Parameter(torch.empty(n1, m1, self.rank)) + self.core2 = nn.Parameter(torch.empty(self.rank, n2, m2)) + if bias: + self.bias = nn.Parameter(torch.zeros(out_features)) + else: + self.register_parameter("bias", None) + self.reset_parameters() + + def reset_parameters(self) -> None: + if self.impl is not None: + self.impl.reset_parameters() + return + n1, n2 = self.in_shape + if self.btt_init == "mup": + # muP-inspired scaling: keep the per-rank branch update stable while + # 
preserving O(1) output variance under the sum over TT ranks. + core1_std = self.btt_init_gain / math.sqrt(max(n1, 1)) + core2_std = self.btt_init_gain / math.sqrt(max(self.rank * n2, 1)) + elif self.btt_init == "xavier": + std = self.btt_init_gain / math.sqrt(max(self.in_features, 1)) + core1_std = std / math.sqrt(self.rank) + core2_std = std + else: + raise ValueError(f"Unsupported BTT_INIT={self.btt_init}") + nn.init.normal_(self.core1, mean=0.0, std=core1_std) + nn.init.normal_(self.core2, mean=0.0, std=core2_std) + if self.bias is not None: + nn.init.zeros_(self.bias) + + def forward(self, x: Tensor) -> Tensor: + if self.impl is not None: + return self.impl(x) + x_dtype = x.dtype + orig_shape = x.shape[:-1] + x2 = x.reshape(-1, self.in_shape[0], self.in_shape[1]) + core1 = self.core1.to(dtype=x_dtype) + core2 = self.core2.to(dtype=x_dtype) + y = x2.new_zeros((x2.shape[0], self.out_shape[0], self.out_shape[1])) + for r in range(self.rank): + left = torch.einsum("buv,um->bmv", x2, core1[:, :, r]) + y = y + torch.einsum("bmv,vn->bmn", left, core2[r]) + y = y.reshape(*orig_shape, self.out_features) + if self.bias is not None: + y = y + self.bias.to(dtype=y.dtype) + return y + + +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + # Keep small/control parameters in fp32 even when the model body runs in bf16. + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() + + +class Rotary(nn.Module): + # Caches cos/sin tables per sequence length on the current device. 
+ def __init__(self, dim: int, base: float = 10000.0): + super().__init__() + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached != seq_len + or self._cos_cached.device != device + ): + t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) + freqs = torch.outer(t, self.inv_freq.to(device)) + self._cos_cached = freqs.cos()[None, None, :, :] + self._sin_cached = freqs.sin()[None, None, :, :] + self._seq_len_cached = seq_len + cos = self._cos_cached.to(dtype=dtype) + sin = self._sin_cached.to(dtype=dtype) + if not torch.is_inference_mode_enabled(): + cos = cos.clone() + sin = sin.clone() + return cos, sin + + +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + + +class CausalSelfAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + kv_dim = self.num_kv_heads * self.head_dim + self.c_q = CastedLinear(dim, dim, bias=False) + self.c_k = CastedLinear(dim, kv_dim, bias=False) + self.c_v = CastedLinear(dim, kv_dim, bias=False) + self.proj 
= CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rotary = Rotary(self.head_dim, base=rope_base) + + def forward(self, x: Tensor) -> Tensor: + bsz, seqlen, dim = x.shape + q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) + k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin) + k = apply_rotary_emb(k, cos, sin) + q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] + y = F.scaled_dot_product_attention( + q, + k, + v, + attn_mask=None, + is_causal=True, + enable_gqa=(self.num_kv_heads != self.num_heads), + ) + y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) + return self.proj(y) + + +class MLP(nn.Module): + # relu^2 MLP from the original modded-nanogpt setup + def __init__( + self, + dim: int, + mlp_mult: int, + structured_mlp: bool, + structured_linear_type: str, + btt_rank: int, + btt_in_shape: str, + btt_out_shape: str, + btt_init: str, + btt_init_gain: float, + compile_structured_mlp: bool, + ): + super().__init__() + hidden = mlp_mult * dim + if structured_mlp: + self.fc = StructuredLinear( + dim, + hidden, + bias=False, + linear_type=structured_linear_type, + btt_rank=btt_rank, + in_shape_spec=btt_in_shape, + out_shape_spec=btt_out_shape, + btt_init=btt_init, + btt_init_gain=btt_init_gain, + ) + self.proj = StructuredLinear( + hidden, + dim, + bias=False, + linear_type=structured_linear_type, + btt_rank=btt_rank, + in_shape_spec=btt_out_shape, + out_shape_spec=btt_in_shape, + btt_init=btt_init, + btt_init_gain=btt_init_gain, + ) + if compile_structured_mlp: + compile_options = 
dict(torch._inductor.list_mode_options("reduce-overhead")) + compile_options["triton.cudagraphs"] = False + self.fc = torch.compile( + self.fc, + options=compile_options, + fullgraph=False, + dynamic=False, + ) + self.proj = torch.compile( + self.proj, + options=compile_options, + fullgraph=False, + dynamic=False, + ) + else: + self.fc = CastedLinear(dim, hidden, bias=False) + self.proj = CastedLinear(hidden, dim, bias=False) + self.proj._zero_init = True + + def forward(self, x: Tensor) -> Tensor: + x = torch.relu(self.fc(x)) + return self.proj(x.square()) + + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + rope_base: float, + qk_gain_init: float, + structured_mlp: bool, + structured_linear_type: str, + btt_rank: int, + btt_in_shape: str, + btt_out_shape: str, + btt_init: str, + btt_init_gain: float, + compile_structured_mlp: bool, + ): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) + self.mlp = MLP( + dim, + mlp_mult, + structured_mlp, + structured_linear_type, + btt_rank, + btt_in_shape, + btt_out_shape, + btt_init, + btt_init_gain, + compile_structured_mlp, + ) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + + def forward(self, x: Tensor, x0: Tensor) -> Tensor: + mix = self.resid_mix.to(dtype=x.dtype) + x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + attn_out = self.attn(self.attn_norm(x)) + x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out + x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) + return x + + +class GPT(nn.Module): + def __init__( + self, + vocab_size: int, + num_layers: int, + model_dim: int, + num_heads: int, + 
num_kv_heads: int, + mlp_mult: int, + tie_embeddings: bool, + tied_embed_init_std: float, + logit_softcap: float, + rope_base: float, + qk_gain_init: float, + structured_mlp: bool, + structured_linear_type: str, + btt_rank: int, + btt_in_shape: str, + btt_out_shape: str, + btt_init: str, + btt_init_gain: float, + compile_structured_mlp: bool, + rate_target_tensors: str, + rate_temperature: float, + ): + super().__init__() + if logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") + self.tie_embeddings = tie_embeddings + self.tied_embed_init_std = tied_embed_init_std + self.logit_softcap = logit_softcap + self.rate_target_tensors = rate_target_tensors + self.rate_temperature = rate_temperature + self.tok_emb = nn.Embedding(vocab_size, model_dim) + self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + self.blocks = nn.ModuleList( + [ + Block( + model_dim, + num_heads, + num_kv_heads, + mlp_mult, + rope_base, + qk_gain_init, + structured_mlp, + structured_linear_type, + btt_rank, + btt_in_shape, + btt_out_shape, + btt_init, + btt_init_gain, + compile_structured_mlp, + ) + for i in range(num_layers) + ] + ) + self.final_norm = RMSNorm() + self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) + if self.lm_head is not None: + self.lm_head._zero_init = True + self._init_weights() + + def _init_weights(self) -> None: + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + for module in self.modules(): + if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + if isinstance(module, StructuredLinear) and getattr(module, "_zero_init", False): + if module.impl is 
not None: + nn.init.zeros_(module.impl.weight) + else: + nn.init.zeros_(module.core2) + + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + x = self.tok_emb(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x0 = x + skips: list[Tensor] = [] + + # First half stores skips; second half reuses them in reverse order. + for i in range(self.num_encoder_layers): + x = self.blocks[i](x, x0) + skips.append(x) + for i in range(self.num_decoder_layers): + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + x = self.blocks[self.num_encoder_layers + i](x, x0) + + x = self.final_norm(x).reshape(-1, x.size(-1)) + targets = target_ids.reshape(-1) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(x) + logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + return F.cross_entropy(logits.float(), targets, reduction="mean") + + def rate_regularizer(self) -> tuple[Tensor, Tensor, Tensor]: + total_symbols = 0.0 + weighted_rate = None + weighted_scale = None + estimated_bits = None + for name, param in self.named_parameters(): + if not should_rate_regularize_tensor(name, param, self.rate_target_tensors): + continue + rate_proxy, scale_proxy, tensor_bits = estimate_tensor_rate_proxy(param, self.rate_temperature) + weight = float(param.numel()) + total_symbols += weight + if weighted_rate is None: + weighted_rate = rate_proxy * weight + weighted_scale = scale_proxy * weight + estimated_bits = tensor_bits + else: + weighted_rate = weighted_rate + rate_proxy * weight + weighted_scale = weighted_scale + scale_proxy * weight + estimated_bits = estimated_bits + tensor_bits + if weighted_rate is None or weighted_scale is None or estimated_bits is None: + zero = self.tok_emb.weight.new_zeros(()) + return zero, zero, zero + denom = max(total_symbols, 1.0) + return 
weighted_rate / denom, weighted_scale / denom, estimated_bits + + +# ----------------------------- +# TRAINING +# ----------------------------- + +def main() -> None: + global zeropower_via_newtonschulz5 + + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + if args.enable_compile: + zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) + + # ----------------------------- + # DISTRIBUTED + CUDA SETUP + # ----------------------------- + + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + simulated_world_size = args.simulated_world_size if not distributed and args.simulated_world_size > 0 else 0 + if simulated_world_size > 0: + world_size = simulated_world_size + rank = args.simulated_rank + local_rank = 0 + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if rank < 0 or rank >= world_size: + raise ValueError(f"RANK={rank} must satisfy 0 <= RANK < WORLD_SIZE={world_size}") + if simulated_world_size > 1 and not args.dry_run_init_only: + raise ValueError("SIMULATED_WORLD_SIZE>1 is only supported together with DRY_RUN_INIT_ONLY=1") + if args.train_batch_tokens <= 0: + raise ValueError(f"TRAIN_BATCH_TOKENS must be positive, got {args.train_batch_tokens}") + train_microbatch_tokens = args.train_microbatch_tokens if args.train_microbatch_tokens > 0 else max( + args.train_seq_len, args.train_batch_tokens // 8 + ) + if train_microbatch_tokens <= 0: + raise ValueError(f"TRAIN_MICROBATCH_TOKENS must be positive, got {train_microbatch_tokens}") + if train_microbatch_tokens % args.train_seq_len != 0: + raise ValueError( + f"TRAIN_MICROBATCH_TOKENS={train_microbatch_tokens} must be divisible by TRAIN_SEQ_LEN={args.train_seq_len}" + ) + tokens_per_microstep = world_size * train_microbatch_tokens + if args.grad_accum_steps > 0: + 
grad_accum_steps = args.grad_accum_steps + effective_train_batch_tokens = tokens_per_microstep * grad_accum_steps + else: + if args.train_batch_tokens % tokens_per_microstep != 0: + raise ValueError( + "TRAIN_BATCH_TOKENS must be divisible by WORLD_SIZE * TRAIN_MICROBATCH_TOKENS " + f"unless GRAD_ACCUM_STEPS is set explicitly; got TRAIN_BATCH_TOKENS={args.train_batch_tokens}, " + f"WORLD_SIZE={world_size}, TRAIN_MICROBATCH_TOKENS={train_microbatch_tokens}" + ) + grad_accum_steps = args.train_batch_tokens // tokens_per_microstep + effective_train_batch_tokens = args.train_batch_tokens + if grad_accum_steps <= 0: + raise ValueError(f"GRAD_ACCUM_STEPS must be positive, got {grad_accum_steps}") + grad_scale = 1.0 / grad_accum_steps + batch_multiplier = effective_train_batch_tokens / max(args.lr_reference_batch_tokens, 1) + effective_warmup_steps = scaled_warmup_steps( + args.warmup_steps, + effective_train_batch_tokens / max(args.warmup_reference_batch_tokens, 1), + args.warmup_scale_method, + ) + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 0 + + # Fast math knobs + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + + logfile = None + if master_process: + if args.run_id == "train": + logfile = "train.log" + else: + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" + print(logfile) + + def log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as 
f: + print(msg, file=f) + + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch {torch.__version__}", console=False) + log0( + subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, + console=False, + ) + log0("=" * 100, console=False) + + # ----------------------------- + # TOKENIZER + VALIDATION METRIC SETUP + # ----------------------------- + + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) + if args.val_token_limit > 0: + usable = min(((args.val_token_limit) // args.train_seq_len) * args.train_seq_len, val_tokens.numel() - 1) + if usable <= 0: + raise ValueError(f"VAL_TOKEN_LIMIT={args.val_token_limit} is too small for TRAIN_SEQ_LEN={args.train_seq_len}") + val_tokens = val_tokens[: usable + 1].contiguous() + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + if args.val_token_limit > 0: + log0(f"val_token_limit:{args.val_token_limit}") + + # 
----------------------------- + # MODEL + OPTIMIZER SETUP + # ----------------------------- + + base_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, + structured_mlp=args.structured_mlp, + structured_linear_type=args.structured_linear_type if args.structured_mlp else "dense", + btt_rank=args.btt_rank, + btt_in_shape=args.btt_in_shape, + btt_out_shape=args.btt_out_shape, + btt_init=args.btt_init, + btt_init_gain=args.btt_init_gain, + compile_structured_mlp=args.compile_structured_mlp, + rate_target_tensors=args.rate_target_tensors, + rate_temperature=args.rate_temperature, + ).to(device).bfloat16() + for module in base_model.modules(): + if isinstance(module, (CastedLinear, StructuredLinear)): + module.float() + restore_low_dim_params_to_fp32(base_model) + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) if args.enable_compile else base_model + model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + + # Optimizer split: + # - token embedding (Adam) uses EMBED_LR + # - untied lm_head (Adam) uses HEAD_LR + # - matrix params in transformer blocks use MATRIX_LR via Muon + # - vectors/scalars use SCALAR_LR via Adam + block_named_params = list(base_model.blocks.named_parameters()) + matrix_params = [ + p + for name, p in block_named_params + if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + scalar_params = [ + p + for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.skip_weights.numel() > 0: + 
scalar_params.append(base_model.skip_weights) + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + scaled_token_lr = scale_from_batch_multiplier(token_lr, batch_multiplier, args.lr_scale_method) + scaled_head_lr = scale_from_batch_multiplier(args.head_lr, batch_multiplier, args.lr_scale_method) + scaled_matrix_lr = scale_from_batch_multiplier(args.matrix_lr, batch_multiplier, args.lr_scale_method) + scaled_scalar_lr = scale_from_batch_multiplier(args.scalar_lr, batch_multiplier, args.lr_scale_method) + optimizer_tok = torch.optim.Adam( + [{"params": [base_model.tok_emb.weight], "lr": scaled_token_lr, "base_lr": scaled_token_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizer_muon = Muon( + matrix_params, + lr=scaled_matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + ) + for group in optimizer_muon.param_groups: + group["base_lr"] = scaled_matrix_lr + optimizer_scalar = torch.optim.Adam( + [{"params": scalar_params, "lr": scaled_scalar_lr, "base_lr": scaled_scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": scaled_head_lr, "base_lr": scaled_head_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers.insert(1, optimizer_head) + + n_params = sum(p.numel() for p in base_model.parameters()) + log0(f"model_params:{n_params}") + log0( + f"world_size:{world_size} grad_accum_steps:{grad_accum_steps} " + f"train_microbatch_tokens:{train_microbatch_tokens} effective_train_batch_tokens:{effective_train_batch_tokens}" + ) + if simulated_world_size > 0: + log0(f"simulated_world_size:{simulated_world_size} simulated_rank:{rank}") + log0(f"enable_compile:{args.enable_compile}") + 
log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") + log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") + log0( + f"quant_scheme:{args.quant_scheme} compressor:{args.compressor} level:{args.compressor_level} " + f"structured_mlp:{args.structured_mlp} structured_linear_type:{args.structured_linear_type} " + f"btt_rank:{args.btt_rank} btt_init:{args.btt_init} compile_structured_mlp:{args.compile_structured_mlp}" + ) + log0( + f"rate_lambda:{args.rate_lambda} scale_lambda:{args.scale_lambda} " + f"rate_warmup_start_fraction:{args.rate_warmup_start_fraction} rate_full_fraction:{args.rate_full_fraction} " + f"rate_target_tensors:{args.rate_target_tensors}" + ) + log0( + f"tie_embeddings:{args.tie_embeddings} embed_lr:{scaled_token_lr} " + f"head_lr:{scaled_head_lr if base_model.lm_head is not None else 0.0} " + f"matrix_lr:{scaled_matrix_lr} scalar_lr:{scaled_scalar_lr}" + ) + log0( + f"train_batch_tokens_target:{args.train_batch_tokens} effective_train_batch_tokens:{effective_train_batch_tokens} " + f"train_seq_len:{args.train_seq_len} iterations:{args.iterations} warmup_steps:{effective_warmup_steps} " + f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log0( + f"lr_scale_method:{args.lr_scale_method} lr_reference_batch_tokens:{args.lr_reference_batch_tokens} " + f"batch_multiplier:{batch_multiplier:.4f} warmup_scale_method:{args.warmup_scale_method} " + f"warmup_reference_batch_tokens:{args.warmup_reference_batch_tokens}" + ) + log0(f"seed:{args.seed}") + + if args.debug_data_sharding_steps > 0: + debug_log_data_sharding( + args.train_files, + world_size, + effective_train_batch_tokens, + args.train_seq_len, + grad_accum_steps, + args.debug_data_sharding_steps, + log0, + ) + if args.dry_run_init_only: + log0("dry_run_init_only: exiting after init/logging") + return + + # ----------------------------- + # DATA LOADER & MODEL WARMUP + # ----------------------------- + + train_loader = 
DistributedTokenLoader(args.train_files, rank, world_size, device) + + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + + def lr_mul(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + + def rate_schedule(step: int) -> float: + if args.rate_lambda <= 0.0 and args.scale_lambda <= 0.0: + return 0.0 + if args.iterations <= 0: + return 1.0 + frac = step / max(args.iterations, 1) + start = args.rate_warmup_start_fraction + full = max(args.rate_full_fraction, start + 1e-6) + if frac <= start: + return 0.0 + if frac >= full: + return 1.0 + return (frac - start) / (full - start) + + # Warmup primes the compiled forward/backward/optimizer paths, then we restore the + # initial weights/optimizer state so measured training starts from the true init. 
+ if effective_warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(effective_warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(effective_train_batch_tokens, args.train_seq_len, grad_accum_steps) + maybe_mark_compile_step_begin(args.enable_compile or args.compile_structured_mlp) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + for opt in optimizers: + opt.step() + zero_grad_all() + if ( + effective_warmup_steps <= 20 + or (warmup_step + 1) % 10 == 0 + or warmup_step + 1 == effective_warmup_steps + ): + log0(f"warmup_step:{warmup_step + 1}/{effective_warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + if distributed: + model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + # ----------------------------- + # MAIN TRAINING LOOP + # ----------------------------- + + training_time_ms = 0.0 + stop_after_step: int | None = None + torch.cuda.synchronize() + t0 = time.perf_counter() + + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + 
grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + log0( + f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" + ) + break + + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + zero_grad_all() + train_loss = torch.zeros((), device=device) + train_ce_loss = torch.zeros((), device=device) + train_rate_proxy = torch.zeros((), device=device) + train_scale_proxy = torch.zeros((), device=device) + train_est_bits = torch.zeros((), device=device) + rate_mul = rate_schedule(step) + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(effective_train_batch_tokens, args.train_seq_len, grad_accum_steps) + maybe_mark_compile_step_begin(args.enable_compile or args.compile_structured_mlp) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + ce_loss = model(x, y) + total_loss = ce_loss + if rate_mul > 0.0: + rate_proxy, scale_proxy, est_bits = base_model.rate_regularizer() + total_loss = total_loss + rate_mul * (args.rate_lambda * rate_proxy + args.scale_lambda * scale_proxy) + train_rate_proxy += rate_proxy.detach() + train_scale_proxy += scale_proxy.detach() + train_est_bits += est_bits.detach() + train_ce_loss += ce_loss.detach() + train_loss += total_loss.detach() + (total_loss * grad_scale).backward() + train_loss /= grad_accum_steps + train_ce_loss /= grad_accum_steps + train_rate_proxy /= grad_accum_steps + train_scale_proxy /= grad_accum_steps + train_est_bits /= 
grad_accum_steps + + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + for opt in optimizers: + opt.step() + zero_grad_all() + + step += 1 + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + should_log_train = ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + log0( + f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} ce_loss:{train_ce_loss.item():.4f} " + f"rate_mul:{rate_mul:.3f} rate_proxy_bits:{train_rate_proxy.item():.4f} " + f"scale_proxy:{train_scale_proxy.item():.6f} est_model_bits:{train_est_bits.item():.0f} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" + ) + + # Needed to sync whether we've reached the wallclock cap. 
+ reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + + log0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + ) + + # ----------------------------- + # SERIALIZATION + ROUNDTRIP VALIDATION + # ----------------------------- + # Save the raw state (useful for debugging/loading in PyTorch directly), then always produce + # the compressed quantized artifact (QUANT_SCHEME + COMPRESSOR; int5_odd + zlib by default) and validate the round-tripped weights. + + if master_process: + torch.save(base_model.state_dict(), "final_model.pt") + model_bytes = os.path.getsize("final_model.pt") + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model: {model_bytes} bytes") + log0(f"Code size: {code_bytes} bytes") + log0(f"Total submission size: {model_bytes + code_bytes} bytes") + + quant_obj, quant_stats = quantize_state_dict(base_model.state_dict(), args.quant_scheme) + quant_buf = io.BytesIO() + torch.save(quant_obj, quant_buf) + quant_raw = quant_buf.getvalue() + quant_blob = compress_blob(quant_raw, args.compressor, args.compressor_level) + quant_raw_bytes = len(quant_raw) + quant_filename = f"final_model.{args.quant_scheme}.{args.compressor}.ptc" + if master_process: + with open(quant_filename, "wb") as f: + f.write(quant_blob) + quant_file_bytes = os.path.getsize(quant_filename) + code_bytes = len(code.encode("utf-8")) + ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1) + log0( + f"Serialized model {args.quant_scheme}+{args.compressor}: {quant_file_bytes} bytes " + f"(payload:{quant_stats['int8_payload_bytes']} raw_torch:{quant_raw_bytes} 
payload_ratio:{ratio:.2f}x)" + ) + log0(f"Total submission size {args.quant_scheme}+{args.compressor}: {quant_file_bytes + code_bytes} bytes") + + if distributed: + dist.barrier() + with open(quant_filename, "rb") as f: + quant_blob_disk = f.read() + quant_state = torch.load(io.BytesIO(decompress_blob(quant_blob_disk, args.compressor)), map_location="cpu") + base_model.load_state_dict(dequantize_state_dict(quant_state), strict=True) + torch.cuda.synchronize() + t_qeval = time.perf_counter() + q_val_loss, q_val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"final_{args.quant_scheme}_{args.compressor}_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" + ) + log0(f"final_{args.quant_scheme}_{args.compressor}_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + + if distributed: + dist.destroy_process_group() + + +if __name__ == "__main__": + main() + +==================================================================================================== +Running Python 3.10.12 (main, Mar 3 2026, 11:56:32) [GCC 11.4.0] +Running PyTorch 2.11.0+cu130 +Tue Mar 31 13:23:38 2026 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 590.48.01 Driver Version: 590.48.01 CUDA Version: 13.1 | ++-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. 
| +|=========================================+========================+======================| +| 0 NVIDIA GeForce RTX 3060 Off | 00000000:01:00.0 On | N/A | +| 0% 44C P2 22W / 170W | 252MiB / 12288MiB | 19% Default | +| | | N/A | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| 0 N/A N/A 985 G /usr/lib/xorg/Xorg 127MiB | +| 0 N/A N/A 6205 C python3 104MiB | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=/home/hz/Desktop/parameter-golf-main/data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:1 +val_loader:shards pattern=/home/hz/Desktop/parameter-golf-main/data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:2048 +val_token_limit:2048 +model_params:94951616 +world_size:8 grad_accum_steps:2 train_microbatch_tokens:8192 effective_train_batch_tokens:131072 +simulated_world_size:8 simulated_rank:0 +enable_compile:False +sdp_backends:cudnn=False flash=True mem_efficient=False math=False +attention_mode:gqa num_heads:8 num_kv_heads:4 +quant_scheme:int5_odd compressor:zlib level:9 structured_mlp:True structured_linear_type:btt_mlp btt_rank:1024 btt_init:mup compile_structured_mlp:True +rate_lambda:2e-05 scale_lambda:0.0002 rate_warmup_start_fraction:0.1 rate_full_fraction:0.4 rate_target_tensors:mlp_only +tie_embeddings:True embed_lr:0.08485281374238571 head_lr:0.0 matrix_lr:0.05656854249492381 scalar_lr:0.05656854249492381 +train_batch_tokens_target:131072 effective_train_batch_tokens:131072 train_seq_len:256 
iterations:40 warmup_steps:8 max_wallclock_seconds:600.000 +lr_scale_method:sqrt lr_reference_batch_tokens:16384 batch_multiplier:8.0000 warmup_scale_method:linear warmup_reference_batch_tokens:16384 +seed:1337 +debug_sharding:world_size:8 global_tokens:131072 grad_accum_steps:2 local_tokens_per_rank:8192 per_rank_span:8193 +debug_sharding_step:0 shared_chunk_token_range:[0,65544) +debug_shard rank:0 token_range:[0,8193) preview:1,896,319,943,956,482 +debug_shard rank:1 token_range:[8193,16386) preview:828,946,345,319,286,951 +debug_shard rank:2 token_range:[16386,24579) preview:949,944,338,648,960,502 +debug_shard rank:3 token_range:[24579,32772) preview:420,267,571,494,440,269 +debug_shard rank:4 token_range:[32772,40965) preview:557,958,814,308,436,960 +debug_shard rank:5 token_range:[40965,49158) preview:280,268,358,946,948,469 +debug_shard rank:6 token_range:[49158,57351) preview:261,418,275,711,933,312 +debug_shard rank:7 token_range:[57351,65544) preview:288,572,487,294,656,951 +dry_run_init_only: exiting after init/logging diff --git a/records/track_non_record_16mb/2026-03-31_EntropyAware_Int5Odd_BTTMLP_3060/logs/patch_cache_sanity.txt b/records/track_non_record_16mb/2026-03-31_EntropyAware_Int5Odd_BTTMLP_3060/logs/patch_cache_sanity.txt new file mode 100644 index 000000000..1554df88f --- /dev/null +++ b/records/track_non_record_16mb/2026-03-31_EntropyAware_Int5Odd_BTTMLP_3060/logs/patch_cache_sanity.txt @@ -0,0 +1,1921 @@ +""" +The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good launching-off points for new participants, not SOTA configs. We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder. + +Hard stop: To keep them readable for newcomers, let's make sure `train_gpt.py` and `train_gpt_mlx.py` are never longer than 1500 lines. 
+""" + +from __future__ import annotations + +import copy +import glob +import io +import math +import os +import random +import subprocess +import sys +import time +import uuid +import zlib +from pathlib import Path + +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP + +try: + import zstandard as zstd + + _HAS_ZSTD = True +except ImportError: + zstd = None + _HAS_ZSTD = False + +REPO_ROOT = Path(__file__).resolve().parents[3] + +# ----------------------------- +# HYPERPARAMETERS +# ----------------------------- +# Default Simple Baseline run (every value is overridable via environment variables): +# - 9 transformer blocks at width 512 +# - 8 attention heads with 4 KV heads (GQA) and 2x MLP expansion +# - vocab size 1024, sequence length 256, tied embeddings +# - 16,384 train tokens per step for 40 iterations with a ~10 minute cap + +class Hyperparameters: + # Data paths are shard globs produced by the existing preprocessing pipeline. + data_path = os.environ.get("DATA_PATH", str(REPO_ROOT / "data/datasets/fineweb10B_sp1024")) + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", str(REPO_ROOT / "data/tokenizers/fineweb_1024_bpe.model")) + run_id = os.environ.get("RUN_ID", "train") + seed = int(os.environ.get("SEED", 1337)) + + # Validation cadence and batch size. Validation covers the full fineweb_val split unless VAL_TOKEN_LIMIT caps the token count. + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 32_768)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 10)) + val_token_limit = int(os.environ.get("VAL_TOKEN_LIMIT", 1_048_576)) + + # Training length. 
+ iterations = int(os.environ.get("ITERATIONS", 40)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 1200)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 1)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 16_384)) + train_microbatch_tokens = int(os.environ.get("TRAIN_MICROBATCH_TOKENS", 0)) + grad_accum_steps = int(os.environ.get("GRAD_ACCUM_STEPS", 0)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 256)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + enable_compile = bool(int(os.environ.get("ENABLE_COMPILE", "0"))) + + # Model shape. + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers = int(os.environ.get("NUM_LAYERS", 9)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = int(os.environ.get("MLP_MULT", 2)) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + + # Optimizer hyperparameters. 
+ embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.03)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.02)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.02)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.0)) + lr_scale_method = os.environ.get("LR_SCALE_METHOD", "none") + lr_reference_batch_tokens = int(os.environ.get("LR_REFERENCE_BATCH_TOKENS", 16_384)) + warmup_scale_method = os.environ.get("WARMUP_SCALE_METHOD", "none") + warmup_reference_batch_tokens = int(os.environ.get("WARMUP_REFERENCE_BATCH_TOKENS", 16_384)) + simulated_world_size = int(os.environ.get("SIMULATED_WORLD_SIZE", 0)) + simulated_rank = int(os.environ.get("SIMULATED_RANK", 0)) + dry_run_init_only = bool(int(os.environ.get("DRY_RUN_INIT_ONLY", "0"))) + debug_data_sharding_steps = int(os.environ.get("DEBUG_DATA_SHARDING_STEPS", 0)) + + # Export / compression. + quant_scheme = os.environ.get("QUANT_SCHEME", "int5_odd") + compressor = os.environ.get("COMPRESSOR", "zlib") + compressor_level = int(os.environ.get("COMPRESSOR_LEVEL", 9)) + + # Entropy-aware rate regularization. 
+ rate_lambda = float(os.environ.get("RATE_LAMBDA", 0.00002)) + scale_lambda = float(os.environ.get("SCALE_LAMBDA", 0.0002)) + rate_temperature = float(os.environ.get("RATE_TEMPERATURE", 6.0)) + rate_warmup_start_fraction = float(os.environ.get("RATE_WARMUP_START_FRACTION", 0.1)) + rate_full_fraction = float(os.environ.get("RATE_FULL_FRACTION", 0.4)) + rate_target_tensors = os.environ.get("RATE_TARGET_TENSORS", "mlp_only") + rate_schedule_mode = os.environ.get("RATE_SCHEDULE_MODE", "auto") + max_expected_steps = int(os.environ.get("MAX_EXPECTED_STEPS", 0)) + + # Structured MLP. + structured_mlp = bool(int(os.environ.get("STRUCTURED_MLP", "1"))) + structured_linear_type = os.environ.get("STRUCTURED_LINEAR_TYPE", "btt_mlp") + btt_rank = int(os.environ.get("BTT_RANK", 32)) + btt_in_shape = os.environ.get("BTT_IN_SHAPE", "") + btt_out_shape = os.environ.get("BTT_OUT_SHAPE", "") + btt_init = os.environ.get("BTT_INIT", "mup") + btt_init_gain = float(os.environ.get("BTT_INIT_GAIN", 1.0)) + compile_structured_mlp = bool(int(os.environ.get("COMPILE_STRUCTURED_MLP", "1"))) + structured_chunk_ranks = int(os.environ.get("STRUCTURED_CHUNK_RANKS", 64)) + structured_eval_chunk_ranks = int(os.environ.get("STRUCTURED_EVAL_CHUNK_RANKS", 4)) + structured_compile_mode = os.environ.get("STRUCTURED_COMPILE_MODE", "reduce-overhead") + +# ----------------------------- +# MUON OPTIMIZER +# ----------------------------- +# +# As borrowed from modded-nanogpt +# Background on Muon: https://kellerjordan.github.io/posts/muon/ + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + # Orthogonalize a 2D update matrix with a fast Newton-Schulz iteration. + # Muon uses this to normalize matrix-shaped gradients before applying them. 
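+    # A brief intuition sketch (our gloss, not from the Muon reference): writing
+    # G = U S V^T (SVD), each iteration applies the odd polynomial a*s + b*s^3 + c*s^5
+    # to the singular values and drives them toward 1, so X approaches the
+    # orthogonal polar factor U V^T. The quintic coefficients below favor fast
+    # early progress in bf16 over exact convergence.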
+ a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X + + +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True): + super().__init__( + params, + dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov), + ) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + + for group in self.param_groups: + params = group["params"] + if not params: + continue + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + + total_params = sum(int(p.numel()) for p in params) + updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) + + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: + g = g.add(buf, alpha=momentum) + g = zeropower_via_newtonschulz5(g, steps=backend_steps) + # Scale correction from Muon reference implementations. 
+ g *= max(1, g.size(0) / g.size(1)) ** 0.5 + updates_flat[curr : curr + p.numel()] = g.reshape(-1) + curr += p.numel() + + if distributed: + dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) + + curr = 0 + for p in params: + g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) + p.add_(g, alpha=-lr) + curr += p.numel() + + return loss + + +# ----------------------------- +# TOKENIZER-AGNOSTIC EVALUATION SETUP +# ----------------------------- +# +# It's common for small models to have a large fraction of their parameters in embeddings, since the 2 * d_model * d_vocab embedding vectors can be gigantic. +# Instead of locking the tokenizer, we let you bring your own and calculate our validation metrics on the average compression of the validation set. +# We calculate BPB (bits-per-byte) instead of validation loss, so we need methods to count the number of bytes per token in the tokenizer. +# Note: Submissions that edit the tokenizer will be examined more carefully, since screwing this up might unjustly improve your score.
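+# A worked example with illustrative numbers (not from a real run): a val_loss of
+# 2.0794 nats/token gives bits_per_token = 2.0794 / ln(2) ≈ 3.0; if the validation
+# tokens average 4 bytes each (tokens_per_byte = 0.25), then
+# val_bpb ≈ 3.0 * 0.25 = 0.75.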
+ +def build_sentencepiece_luts( + sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device +) -> tuple[Tensor, Tensor, Tensor]: + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("▁"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) + + +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*. 
+ tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] + + +def eval_val( + args: Hyperparameters, + model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + grad_accum_steps: int, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, +) -> tuple[float, float]: + # Validation computes two metrics: + # - val_loss: token cross-entropy (natural log) + # - val_bpb: tokenizer-agnostic compression metric used by the challenge + local_batch_tokens = args.val_batch_size // world_size + if local_batch_tokens < args.train_seq_len: + raise ValueError( + "VAL_BATCH_SIZE must provide at least one sequence per rank; " + f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " + f"TRAIN_SEQ_LEN={args.train_seq_len}" + ) + local_batch_seqs = local_batch_tokens // args.train_seq_len + total_seqs = (val_tokens.numel() - 1) // args.train_seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + + model.eval() + with torch.inference_mode(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * args.train_seq_len + raw_end = batch_seq_end * args.train_seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, args.train_seq_len) + y = local[1:].reshape(-1, args.train_seq_len) + maybe_mark_compile_step_begin(args.enable_compile or 
args.compile_structured_mlp) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + batch_loss = model(x, y).detach() + batch_token_count = float(y.numel()) + val_loss_sum += batch_loss.to(torch.float64) * batch_token_count + val_token_count += batch_token_count + prev_ids = x.reshape(-1) + tgt_ids = y.reshape(-1) + token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + + val_loss = val_loss_sum / val_token_count + bits_per_token = val_loss.item() / math.log(2.0) + tokens_per_byte = val_token_count.item() / val_byte_count.item() + model.train() + return float(val_loss.item()), float(bits_per_token * tokens_per_byte) + +# ----------------------------- +# POST-TRAINING QUANTIZATION +# ----------------------------- +# +# It's silly to export our model, which is trained in bf16 and fp32, at that same precision. +# Instead, we get approximately the same model (with a small quality hit) by quantizing the weights and compressing the result (int5_odd + zlib by default). +# We can then decompress the model and run in higher precision for evaluation, after coming in under the size limit.
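+# A roundtrip sketch for int5_odd over the levels {-2, -1, 0, 1, 2} (the exact
+# scale selection lives in quantize_state_dict; this only shows the shape of the
+# mapping): with scale s, q = clamp(round(w / s), -2, 2) is what gets stored, and
+# the dequantizer reconstructs w_hat = s * q before the final evaluation.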
+ +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights", + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 +INT5_ODD_LEVELS = (-2, -1, 0, 1, 2) + +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) + + +def maybe_mark_compile_step_begin(enabled: bool) -> None: + if not enabled: + return + mark_step = getattr(torch.compiler, "cudagraph_mark_step_begin", None) + if mark_step is not None: + mark_step() + + +def scale_from_batch_multiplier(value: float, multiplier: float, method: str) -> float: + if method == "none": + return value + if method == "sqrt": + return value * math.sqrt(multiplier) + if method == "linear": + return value * multiplier + raise ValueError(f"Unsupported scaling method: {method}") + + +def scaled_warmup_steps(base_steps: int, multiplier: float, method: str) -> int: + if base_steps <= 0: + return 0 + scaled = scale_from_batch_multiplier(float(base_steps), multiplier, method) + return max(int(math.ceil(scaled)), 1) + + +def parse_factor_shape(spec: str, dim: int) -> tuple[int, int]: + if spec: + vals = tuple(int(v.strip()) for v in spec.split(",") if v.strip()) + if len(vals) != 2 or vals[0] * vals[1] != dim: + raise ValueError(f"Invalid factor shape '{spec}' for dim={dim}") + return vals + for a in range(int(math.isqrt(dim)), 0, -1): + if dim % a == 0: + return a, dim // a + return 1, dim + +def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> 
Tensor: + if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): + return t.float().contiguous() + if t.dtype in {torch.float32, torch.bfloat16}: + passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") + return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() + return t + +def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + # Matrices get one scale per row, which usually tracks output-channel + # ranges much better than a single tensor-wide scale. + clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) + q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() + return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + + # Vectors / scalars use a simpler per-tensor scale. 
+ clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() + return q, scale + + +def pack_base5_codes(q: Tensor) -> tuple[bytes, int]: + vals = q.detach().to("cpu", dtype=torch.uint8).reshape(-1).numpy() + n_vals = int(vals.size) + pad = (3 - (n_vals % 3)) % 3 + if pad: + vals = np.pad(vals, (0, pad)) + groups = vals.reshape(-1, 3).astype(np.uint8) + packed = groups[:, 0] + 5 * groups[:, 1] + 25 * groups[:, 2] + return packed.tobytes(), n_vals + + +def unpack_base5_codes(data: bytes, n_vals: int) -> Tensor: + packed = np.frombuffer(data, dtype=np.uint8).astype(np.int16) + out = np.zeros((packed.size, 3), dtype=np.uint8) + for i in range(3): + out[:, i] = packed % 5 + packed //= 5 + flat = out.reshape(-1)[:n_vals] + return torch.from_numpy(flat.astype(np.uint8, copy=False)) + + +def quantize_float_tensor_int5_odd(t: Tensor) -> tuple[Tensor, Tensor]: + if t.ndim != 2: + raise ValueError("int5_odd quantization only supports 2D tensors") + t32 = t.float() + clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clip_abs = clip_abs.clamp_min(1e-6) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 2.0).clamp_min(1.0 / 2.0e6) + q = torch.clamp(torch.round(clipped / scale[:, None]), -2, 2).to(torch.int8).contiguous() + return (q + 2).to(torch.uint8).contiguous(), scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + +def quantize_state_dict_int8(state_dict: dict[str, Tensor]): + # Single supported clean-script export format: + # - per-row int8 for 2D float tensors + # - per-tensor int8 for other float tensors + # - exact passthrough for non-floats + # - passthrough for small float tensors, 
stored as fp16 to save bytes + quantized: dict[str, Tensor] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + qmeta: dict[str, dict[str, object]] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + + # Small float tensors are cheap enough to keep directly. We still downcast + # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size. + if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + + stats["num_float_tensors"] += 1 + q, s = quantize_float_tensor(t) + if s.ndim > 0: + qmeta[name] = {"scheme": "per_row", "axis": 0} + quantized[name] = q + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) + + obj: dict[str, object] = { + "__quant_format__": "int8_clean_per_row_v1", + "quantized": quantized, + "scales": scales, + "dtypes": dtypes, + "passthrough": passthrough, + } + if qmeta: + obj["qmeta"] = qmeta + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats + + +def quantize_state_dict_int5_odd(state_dict: dict[str, Tensor]): + packed: dict[str, bytes] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + shapes: dict[str, tuple[int, ...]] = {} + orig_shapes: dict[str, 
tuple[int, ...]] = {} + n_codes: dict[str, int] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + + quantize_this = t.ndim >= 2 and (t.numel() > INT8_KEEP_FLOAT_MAX_NUMEL or ".mlp." in name) + if not quantize_this: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + + stats["num_float_tensors"] += 1 + t2 = t.reshape(t.shape[0], -1) if t.ndim > 2 else t + q, s = quantize_float_tensor_int5_odd(t2) + packed_bytes, count = pack_base5_codes(q) + packed[name] = packed_bytes + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + shapes[name] = tuple(int(v) for v in t2.shape) + orig_shapes[name] = tuple(int(v) for v in t.shape) + n_codes[name] = count + stats["int8_payload_bytes"] += len(packed_bytes) + tensor_nbytes(s) + + obj: dict[str, object] = { + "__quant_format__": "int5_odd_clean_per_row_v1", + "packed": packed, + "scales": scales, + "dtypes": dtypes, + "shapes": shapes, + "orig_shapes": orig_shapes, + "n_codes": n_codes, + "passthrough": passthrough, + } + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats + +def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + qmeta = obj.get("qmeta", {}) + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, q in 
obj["quantized"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + s = obj["scales"][name] + if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: + s = s.to(dtype=torch.float32) + # Broadcast the saved row scale back across trailing dimensions. + out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() + else: + scale = float(s.item()) + out[name] = (q.float() * scale).to(dtype=dtype).contiguous() + for name, t in obj["passthrough"].items(): + # Restore small tensors, undoing the temporary fp16 storage cast if needed. + out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out + + +def dequantize_state_dict_int5_odd(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, packed in obj["packed"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + shape = tuple(int(v) for v in obj["shapes"][name]) + orig_shape = tuple(int(v) for v in obj.get("orig_shapes", {}).get(name, shape)) + q = unpack_base5_codes(packed, int(obj["n_codes"][name])).to(torch.float32).reshape(shape) - 2.0 + s = obj["scales"][name].to(dtype=torch.float32) + deq = (q * s.view(shape[0], *([1] * (len(shape) - 1)))).to(dtype=dtype) + out[name] = deq.reshape(orig_shape).contiguous() + for name, t in obj["passthrough"].items(): + out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out + + +def quantize_state_dict(state_dict: dict[str, Tensor], quant_scheme: str): + if quant_scheme == "int8": + return quantize_state_dict_int8(state_dict) + if quant_scheme == "int5_odd": + return 
quantize_state_dict_int5_odd(state_dict) + raise ValueError(f"Unsupported QUANT_SCHEME={quant_scheme}") + + +def dequantize_state_dict(obj: dict[str, object]) -> dict[str, Tensor]: + quant_format = obj.get("__quant_format__") + if quant_format == "int8_clean_per_row_v1": + return dequantize_state_dict_int8(obj) + if quant_format == "int5_odd_clean_per_row_v1": + return dequantize_state_dict_int5_odd(obj) + raise ValueError(f"Unsupported quant format {quant_format}") + + +def compress_blob(data: bytes, compressor: str, level: int) -> bytes: + if compressor == "zlib": + return zlib.compress(data, level=max(0, min(level, 9))) + if compressor == "zstd": + if not _HAS_ZSTD: + raise RuntimeError("COMPRESSOR=zstd requires zstandard to be installed") + return zstd.ZstdCompressor(level=level).compress(data) + raise ValueError(f"Unsupported COMPRESSOR={compressor}") + + +def decompress_blob(data: bytes, compressor: str) -> bytes: + if compressor == "zlib": + return zlib.decompress(data) + if compressor == "zstd": + if not _HAS_ZSTD: + raise RuntimeError("COMPRESSOR=zstd requires zstandard to be installed") + return zstd.ZstdDecompressor().decompress(data) + raise ValueError(f"Unsupported COMPRESSOR={compressor}") + + +def should_rate_regularize_tensor(name: str, p: Tensor, target: str) -> bool: + if not p.requires_grad or not p.is_floating_point() or p.ndim < 2: + return False + if any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS): + return False + if target == "mlp_only": + return ".mlp." in name + if target == "all_matrices": + return p.numel() > INT8_KEEP_FLOAT_MAX_NUMEL or ".mlp." 
in name + raise ValueError(f"Unsupported RATE_TARGET_TENSORS={target}") + + +def estimate_tensor_rate_proxy(t: Tensor, temperature: float) -> tuple[Tensor, Tensor, Tensor]: + t2 = t.float() + if t2.ndim > 2: + t2 = t2.reshape(t2.shape[0], -1) + elif t2.ndim == 1: + t2 = t2.reshape(1, -1) + row_scale = t2.abs().mean(dim=1, keepdim=True).clamp_min(1e-6) + normalized = torch.clamp(t2 / row_scale, -2.5, 2.5) + centers = t2.new_tensor(INT5_ODD_LEVELS) + logits = -temperature * (normalized.unsqueeze(-1) - centers) ** 2 + probs = torch.softmax(logits, dim=-1) + hist = probs.mean(dim=(0, 1)) + entropy_nats = -(hist * (hist + 1e-8).log()).sum() + entropy_bits = entropy_nats / math.log(2.0) + scale_proxy = row_scale.mean() + est_bits = entropy_bits * float(t2.numel()) + return entropy_bits, scale_proxy, est_bits + + +# ----------------------------- +# DATA LOADING +# ----------------------------- + +def load_data_shard(file: Path) -> Tensor: + header_bytes = 256 * np.dtype(" None: + self.file_idx = (self.file_idx + 1) % len(self.files) + self.tokens = load_data_shard(self.files[self.file_idx]) + self.pos = 0 + + def take(self, n: int) -> Tensor: + chunks: list[Tensor] = [] + remaining = n + while remaining > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: + self._advance_file() + continue + k = min(remaining, avail) + chunks.append(self.tokens[self.pos : self.pos + k]) + self.pos += k + remaining -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) + + +class DistributedTokenLoader: + # Each call consumes a contiguous chunk from the shared token stream, then slices out + # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting. 
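+ # Worked example with the 8xH100 launcher defaults (TRAIN_BATCH_TOKENS=131072,
+ # WORLD_SIZE=8, TRAIN_MICROBATCH_TOKENS=8192 => grad_accum_steps=2):
+ # local_tokens = 131072 // (8 * 2) = 8192, per_rank_span = 8193,
+ # shared chunk per microstep = 8 * 8193 = 65544 tokens,
+ # and rank r reads the disjoint span [r * 8193, (r + 1) * 8193) of that chunk.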
+ def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank = rank + self.world_size = world_size + self.device = device + self.stream = TokenStream(pattern) + + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + chunk = self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start : start + per_rank_span].to(dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + + +def debug_log_data_sharding( + pattern: str, + world_size: int, + global_tokens: int, + seq_len: int, + grad_accum_steps: int, + steps: int, + log0, +) -> None: + stream = TokenStream(pattern) + local_tokens = global_tokens // (world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + global_cursor = 0 + log0( + f"debug_sharding:world_size:{world_size} global_tokens:{global_tokens} " + f"grad_accum_steps:{grad_accum_steps} local_tokens_per_rank:{local_tokens} per_rank_span:{per_rank_span}" + ) + for step_idx in range(steps): + chunk = stream.take(per_rank_span * world_size) + chunk_start = global_cursor + chunk_end = global_cursor + chunk.numel() + log0(f"debug_sharding_step:{step_idx} shared_chunk_token_range:[{chunk_start},{chunk_end})") + for shard_rank in range(world_size): + start = shard_rank * per_rank_span + end = start + per_rank_span + local = chunk[start:end] + preview = ",".join(str(int(t)) for t in local[: min(6, local.numel())].tolist()) + log0( + f"debug_shard rank:{shard_rank} token_range:[{chunk_start + start},{chunk_start + end}) " + f"preview:{preview}" + ) + global_cursor = chunk_end + +# ----------------------------- +# TRANSFORMER MODULES +# ----------------------------- + +class RMSNorm(nn.Module): + def 
__init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + + +class CastedLinear(nn.Linear): + # Keep weights in fp32 for optimizer/state quality, cast at matmul time for bf16 compute. + def __init__(self, in_features: int, out_features: int, bias: bool = True): + super().__init__(in_features, out_features, bias=bias) + self._cached_weight: Tensor | None = None + self._cached_bias: Tensor | None = None + self._cached_key: tuple[torch.dtype, torch.device, int, int | None] | None = None + + def _cast_params(self, x: Tensor) -> tuple[Tensor, Tensor | None]: + bias_version = self.bias._version if self.bias is not None else None + cache_key = (x.dtype, x.device, self.weight._version, bias_version) + if torch.is_inference_mode_enabled() and self._cached_key == cache_key: + return self._cached_weight, self._cached_bias + weight = self.weight.to(dtype=x.dtype) + bias = self.bias.to(x.dtype) if self.bias is not None else None + if torch.is_inference_mode_enabled(): + self._cached_weight = weight + self._cached_bias = bias + self._cached_key = cache_key + return weight, bias + + def forward(self, x: Tensor) -> Tensor: + weight, bias = self._cast_params(x) + return F.linear(x, weight, bias) + + +class StructuredLinear(nn.Module): + # Lightweight TT-style structured matrix used only in the MLP stack. 
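+ # Parameter-count example (illustrative shapes, not the tuned config): a dense
+ # 1024 -> 4096 matrix stores 1024 * 4096 = 4_194_304 weights, while the two
+ # cores with in_shape (32, 32), out_shape (64, 64) and TT rank r store
+ # 32*64*r + r*32*64 = 4096*r weights, e.g. 1_048_576 at r=256 (a 4x reduction).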
+ def __init__( + self, + in_features: int, + out_features: int, + bias: bool, + linear_type: str, + btt_rank: int, + in_shape_spec: str, + out_shape_spec: str, + btt_init: str, + btt_init_gain: float, + chunk_ranks: int, + eval_chunk_ranks: int, + ): + super().__init__() + self.in_features = in_features + self.out_features = out_features + self.linear_type = linear_type + self.rank = max(int(btt_rank), 1) + self.chunk_ranks = max(int(chunk_ranks), 1) + self.eval_chunk_ranks = max(int(eval_chunk_ranks), 1) + self.btt_init = btt_init + self.btt_init_gain = btt_init_gain + self._cached_core1: Tensor | None = None + self._cached_core2: Tensor | None = None + self._cached_bias: Tensor | None = None + self._cached_key: tuple[torch.dtype, torch.device, int, int, int | None] | None = None + if linear_type == "dense": + self.impl = CastedLinear(in_features, out_features, bias=bias) + self.register_parameter("core1", None) + self.register_parameter("core2", None) + return + if linear_type != "btt_mlp": + raise ValueError(f"Unsupported STRUCTURED_LINEAR_TYPE={linear_type}") + self.impl = None + self.in_shape = parse_factor_shape(in_shape_spec, in_features) + self.out_shape = parse_factor_shape(out_shape_spec, out_features) + n1, n2 = self.in_shape + m1, m2 = self.out_shape + self.core1 = nn.Parameter(torch.empty(n1, m1, self.rank)) + self.core2 = nn.Parameter(torch.empty(self.rank, n2, m2)) + if bias: + self.bias = nn.Parameter(torch.zeros(out_features)) + else: + self.register_parameter("bias", None) + self.reset_parameters() + + def reset_parameters(self) -> None: + if self.impl is not None: + self.impl.reset_parameters() + return + n1, n2 = self.in_shape + if self.btt_init == "mup": + # muP-inspired scaling: keep the per-rank branch update stable while + # preserving O(1) output variance under the sum over TT ranks. 
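+ # Variance check (g = btt_init_gain, unit-variance inputs): each rank branch
+ # contributes n1 * core1_std**2 * n2 * core2_std**2
+ # = n1 * (g**2 / n1) * n2 * (g**2 / (rank * n2)) = g**4 / rank per output,
+ # so the sum over all `rank` branches is g**4, independent of the factor
+ # shapes and the TT rank.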
+ core1_std = self.btt_init_gain / math.sqrt(max(n1, 1)) + core2_std = self.btt_init_gain / math.sqrt(max(self.rank * n2, 1)) + elif self.btt_init == "xavier": + std = self.btt_init_gain / math.sqrt(max(self.in_features, 1)) + core1_std = std / math.sqrt(self.rank) + core2_std = std + else: + raise ValueError(f"Unsupported BTT_INIT={self.btt_init}") + nn.init.normal_(self.core1, mean=0.0, std=core1_std) + nn.init.normal_(self.core2, mean=0.0, std=core2_std) + if self.bias is not None: + nn.init.zeros_(self.bias) + + def _cast_factors(self, x: Tensor) -> tuple[Tensor, Tensor, Tensor | None]: + bias_version = self.bias._version if self.bias is not None else None + cache_key = (x.dtype, x.device, self.core1._version, self.core2._version, bias_version) + if torch.is_inference_mode_enabled() and self._cached_key == cache_key: + return self._cached_core1, self._cached_core2, self._cached_bias + core1 = self.core1.to(dtype=x.dtype).permute(2, 0, 1).contiguous() + core2 = self.core2.to(dtype=x.dtype).contiguous() + bias = self.bias.to(dtype=x.dtype) if self.bias is not None else None + if torch.is_inference_mode_enabled(): + self._cached_core1 = core1 + self._cached_core2 = core2 + self._cached_bias = bias + self._cached_key = cache_key + return core1, core2, bias + + def forward(self, x: Tensor) -> Tensor: + if self.impl is not None: + return self.impl(x) + orig_shape = x.shape[:-1] + x2 = x.reshape(-1, self.in_shape[0], self.in_shape[1]) + x2_t = x2.transpose(1, 2).contiguous() + core1, core2, bias = self._cast_factors(x) + y = x2.new_zeros((x2.shape[0], self.out_shape[0], self.out_shape[1])) + if torch.is_inference_mode_enabled() and self.eval_chunk_ranks > 1: + for start in range(0, self.rank, self.eval_chunk_ranks): + end = min(start + self.eval_chunk_ranks, self.rank) + left = torch.matmul(x2_t.unsqueeze(1), core1[start:end].unsqueeze(0)) + y = y + torch.matmul(left.transpose(-1, -2), core2[start:end].unsqueeze(0)).sum(dim=1) + y = y.reshape(*orig_shape, 
self.out_features) + if bias is not None: + y = y + bias + return y + for start in range(0, self.rank, self.chunk_ranks): + end = min(start + self.chunk_ranks, self.rank) + for offset, core2_r in enumerate(core2[start:end], start=start): + left = torch.matmul(x2_t, core1[offset]) + y = y + torch.matmul(left.transpose(1, 2), core2_r) + y = y.reshape(*orig_shape, self.out_features) + if bias is not None: + y = y + bias + return y + + +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + # Keep small/control parameters in fp32 even when the model body runs in bf16. + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() + + +class Rotary(nn.Module): + # Caches cos/sin tables per sequence length on the current device. + def __init__(self, dim: int, base: float = 10000.0): + super().__init__() + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached != seq_len + or self._cos_cached.device != device + ): + t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) + freqs = torch.outer(t, self.inv_freq.to(device)) + self._cos_cached = freqs.cos()[None, None, :, :] + self._sin_cached = freqs.sin()[None, None, :, :] + self._seq_len_cached = seq_len + cos = self._cos_cached.to(dtype=dtype) + sin = self._sin_cached.to(dtype=dtype) + if not torch.is_inference_mode_enabled(): + cos = cos.clone() + sin = sin.clone() + return cos, sin + + +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> 
Tensor: + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + + +class CausalSelfAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + kv_dim = self.num_kv_heads * self.head_dim + self.c_q = CastedLinear(dim, dim, bias=False) + self.c_k = CastedLinear(dim, kv_dim, bias=False) + self.c_v = CastedLinear(dim, kv_dim, bias=False) + self.proj = CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rotary = Rotary(self.head_dim, base=rope_base) + + def forward(self, x: Tensor) -> Tensor: + bsz, seqlen, dim = x.shape + q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) + k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin) + k = apply_rotary_emb(k, cos, sin) + q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] + y = F.scaled_dot_product_attention( + q, + k, + v, + attn_mask=None, + is_causal=True, + enable_gqa=(self.num_kv_heads != self.num_heads), + ) + y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) + return self.proj(y) + + +class MLP(nn.Module): + # relu^2 MLP from the 
original modded-nanogpt setup + def __init__( + self, + dim: int, + mlp_mult: int, + structured_mlp: bool, + structured_linear_type: str, + btt_rank: int, + btt_in_shape: str, + btt_out_shape: str, + btt_init: str, + btt_init_gain: float, + compile_structured_mlp: bool, + structured_chunk_ranks: int, + structured_eval_chunk_ranks: int, + structured_compile_mode: str, + ): + super().__init__() + hidden = mlp_mult * dim + if structured_mlp: + self.fc = StructuredLinear( + dim, + hidden, + bias=False, + linear_type=structured_linear_type, + btt_rank=btt_rank, + in_shape_spec=btt_in_shape, + out_shape_spec=btt_out_shape, + btt_init=btt_init, + btt_init_gain=btt_init_gain, + chunk_ranks=structured_chunk_ranks, + eval_chunk_ranks=structured_eval_chunk_ranks, + ) + self.proj = StructuredLinear( + hidden, + dim, + bias=False, + linear_type=structured_linear_type, + btt_rank=btt_rank, + in_shape_spec=btt_out_shape, + out_shape_spec=btt_in_shape, + btt_init=btt_init, + btt_init_gain=btt_init_gain, + chunk_ranks=structured_chunk_ranks, + eval_chunk_ranks=structured_eval_chunk_ranks, + ) + if compile_structured_mlp: + compile_kwargs = dict(fullgraph=False, dynamic=False) + if structured_compile_mode == "reduce-overhead": + compile_options = dict(torch._inductor.list_mode_options("reduce-overhead")) + compile_options["triton.cudagraphs"] = False + compile_kwargs["options"] = compile_options + else: + compile_kwargs["mode"] = structured_compile_mode + self.fc = torch.compile( + self.fc, + **compile_kwargs, + ) + self.proj = torch.compile( + self.proj, + **compile_kwargs, + ) + else: + self.fc = CastedLinear(dim, hidden, bias=False) + self.proj = CastedLinear(hidden, dim, bias=False) + self.proj._zero_init = True + + def forward(self, x: Tensor) -> Tensor: + x = torch.relu(self.fc(x)) + return self.proj(x.square()) + + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + rope_base: float, + qk_gain_init: float, + 
structured_mlp: bool, + structured_linear_type: str, + btt_rank: int, + btt_in_shape: str, + btt_out_shape: str, + btt_init: str, + btt_init_gain: float, + compile_structured_mlp: bool, + structured_chunk_ranks: int, + structured_eval_chunk_ranks: int, + structured_compile_mode: str, + ): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) + self.mlp = MLP( + dim, + mlp_mult, + structured_mlp, + structured_linear_type, + btt_rank, + btt_in_shape, + btt_out_shape, + btt_init, + btt_init_gain, + compile_structured_mlp, + structured_chunk_ranks, + structured_eval_chunk_ranks, + structured_compile_mode, + ) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + + def forward(self, x: Tensor, x0: Tensor) -> Tensor: + mix = self.resid_mix.to(dtype=x.dtype) + x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + attn_out = self.attn(self.attn_norm(x)) + x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out + x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) + return x + + +class GPT(nn.Module): + def __init__( + self, + vocab_size: int, + num_layers: int, + model_dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + tie_embeddings: bool, + tied_embed_init_std: float, + logit_softcap: float, + rope_base: float, + qk_gain_init: float, + structured_mlp: bool, + structured_linear_type: str, + btt_rank: int, + btt_in_shape: str, + btt_out_shape: str, + btt_init: str, + btt_init_gain: float, + compile_structured_mlp: bool, + structured_chunk_ranks: int, + structured_eval_chunk_ranks: int, + structured_compile_mode: str, + rate_target_tensors: str, + rate_temperature: float, + ): + super().__init__() + if logit_softcap <= 
0.0: + raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") + self.tie_embeddings = tie_embeddings + self.tied_embed_init_std = tied_embed_init_std + self.logit_softcap = logit_softcap + self.rate_target_tensors = rate_target_tensors + self.rate_temperature = rate_temperature + self.tok_emb = nn.Embedding(vocab_size, model_dim) + self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + self.blocks = nn.ModuleList( + [ + Block( + model_dim, + num_heads, + num_kv_heads, + mlp_mult, + rope_base, + qk_gain_init, + structured_mlp, + structured_linear_type, + btt_rank, + btt_in_shape, + btt_out_shape, + btt_init, + btt_init_gain, + compile_structured_mlp, + structured_chunk_ranks, + structured_eval_chunk_ranks, + structured_compile_mode, + ) + for i in range(num_layers) + ] + ) + self.final_norm = RMSNorm() + self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) + if self.lm_head is not None: + self.lm_head._zero_init = True + self._init_weights() + + def _init_weights(self) -> None: + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + for module in self.modules(): + if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + if isinstance(module, StructuredLinear) and getattr(module, "_zero_init", False): + if module.impl is not None: + nn.init.zeros_(module.impl.weight) + else: + nn.init.zeros_(module.core2) + + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + x = self.tok_emb(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x0 = x + skips: list[Tensor] = [] + + # First half stores skips; second half reuses them in reverse order. 
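+ # E.g. with num_layers=16: blocks 0..7 push their outputs onto `skips`;
+ # decoder step i then pops in reverse order, pairing decoder block 8+i with
+ # encoder block 7-i (U-Net-style long skips), each gated per-channel by
+ # skip_weights[i].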
+ for i in range(self.num_encoder_layers): + x = self.blocks[i](x, x0) + skips.append(x) + for i in range(self.num_decoder_layers): + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + x = self.blocks[self.num_encoder_layers + i](x, x0) + + x = self.final_norm(x).reshape(-1, x.size(-1)) + targets = target_ids.reshape(-1) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(x) + logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + return F.cross_entropy(logits.float(), targets, reduction="mean") + + def rate_regularizer(self) -> tuple[Tensor, Tensor, Tensor]: + total_symbols = 0.0 + weighted_rate = None + weighted_scale = None + estimated_bits = None + for name, param in self.named_parameters(): + if not should_rate_regularize_tensor(name, param, self.rate_target_tensors): + continue + rate_proxy, scale_proxy, tensor_bits = estimate_tensor_rate_proxy(param, self.rate_temperature) + weight = float(param.numel()) + total_symbols += weight + if weighted_rate is None: + weighted_rate = rate_proxy * weight + weighted_scale = scale_proxy * weight + estimated_bits = tensor_bits + else: + weighted_rate = weighted_rate + rate_proxy * weight + weighted_scale = weighted_scale + scale_proxy * weight + estimated_bits = estimated_bits + tensor_bits + if weighted_rate is None or weighted_scale is None or estimated_bits is None: + zero = self.tok_emb.weight.new_zeros(()) + return zero, zero, zero + denom = max(total_symbols, 1.0) + return weighted_rate / denom, weighted_scale / denom, estimated_bits + + +# ----------------------------- +# TRAINING +# ----------------------------- + +def main() -> None: + global zeropower_via_newtonschulz5 + + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + if args.enable_compile: + 
zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) + + # ----------------------------- + # DISTRIBUTED + CUDA SETUP + # ----------------------------- + + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + simulated_world_size = args.simulated_world_size if not distributed and args.simulated_world_size > 0 else 0 + if simulated_world_size > 0: + world_size = simulated_world_size + rank = args.simulated_rank + local_rank = 0 + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if rank < 0 or rank >= world_size: + raise ValueError(f"RANK={rank} must satisfy 0 <= RANK < WORLD_SIZE={world_size}") + if simulated_world_size > 1 and not args.dry_run_init_only: + raise ValueError("SIMULATED_WORLD_SIZE>1 is only supported together with DRY_RUN_INIT_ONLY=1") + if args.train_batch_tokens <= 0: + raise ValueError(f"TRAIN_BATCH_TOKENS must be positive, got {args.train_batch_tokens}") + train_microbatch_tokens = args.train_microbatch_tokens if args.train_microbatch_tokens > 0 else max( + args.train_seq_len, args.train_batch_tokens // 8 + ) + if train_microbatch_tokens <= 0: + raise ValueError(f"TRAIN_MICROBATCH_TOKENS must be positive, got {train_microbatch_tokens}") + if train_microbatch_tokens % args.train_seq_len != 0: + raise ValueError( + f"TRAIN_MICROBATCH_TOKENS={train_microbatch_tokens} must be divisible by TRAIN_SEQ_LEN={args.train_seq_len}" + ) + tokens_per_microstep = world_size * train_microbatch_tokens + if args.grad_accum_steps > 0: + grad_accum_steps = args.grad_accum_steps + effective_train_batch_tokens = tokens_per_microstep * grad_accum_steps + else: + if args.train_batch_tokens % tokens_per_microstep != 0: + raise ValueError( + "TRAIN_BATCH_TOKENS must be divisible by WORLD_SIZE * TRAIN_MICROBATCH_TOKENS " + f"unless GRAD_ACCUM_STEPS is 
set explicitly; got TRAIN_BATCH_TOKENS={args.train_batch_tokens}, " + f"WORLD_SIZE={world_size}, TRAIN_MICROBATCH_TOKENS={train_microbatch_tokens}" + ) + grad_accum_steps = args.train_batch_tokens // tokens_per_microstep + effective_train_batch_tokens = args.train_batch_tokens + if grad_accum_steps <= 0: + raise ValueError(f"GRAD_ACCUM_STEPS must be positive, got {grad_accum_steps}") + grad_scale = 1.0 / grad_accum_steps + batch_multiplier = effective_train_batch_tokens / max(args.lr_reference_batch_tokens, 1) + effective_warmup_steps = scaled_warmup_steps( + args.warmup_steps, + effective_train_batch_tokens / max(args.warmup_reference_batch_tokens, 1), + args.warmup_scale_method, + ) + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 0 + + # Fast math knobs + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + + logfile = None + if master_process: + if args.run_id == "train": + logfile = "train.log" + else: + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" + print(logfile) + + def log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch {torch.__version__}", console=False) + log0( + subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, 
check=False).stdout, + console=False, + ) + log0("=" * 100, console=False) + + # ----------------------------- + # TOKENIZER + VALIDATION METRIC SETUP + # ----------------------------- + + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) + if args.val_token_limit > 0: + usable = min(((args.val_token_limit) // args.train_seq_len) * args.train_seq_len, val_tokens.numel() - 1) + if usable <= 0: + raise ValueError(f"VAL_TOKEN_LIMIT={args.val_token_limit} is too small for TRAIN_SEQ_LEN={args.train_seq_len}") + val_tokens = val_tokens[: usable + 1].contiguous() + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + if args.val_token_limit > 0: + log0(f"val_token_limit:{args.val_token_limit}") + + # ----------------------------- + # MODEL + OPTIMIZER SETUP + # ----------------------------- + + base_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + 
tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, + structured_mlp=args.structured_mlp, + structured_linear_type=args.structured_linear_type if args.structured_mlp else "dense", + btt_rank=args.btt_rank, + btt_in_shape=args.btt_in_shape, + btt_out_shape=args.btt_out_shape, + btt_init=args.btt_init, + btt_init_gain=args.btt_init_gain, + compile_structured_mlp=args.compile_structured_mlp, + structured_chunk_ranks=args.structured_chunk_ranks, + structured_eval_chunk_ranks=args.structured_eval_chunk_ranks, + structured_compile_mode=args.structured_compile_mode, + rate_target_tensors=args.rate_target_tensors, + rate_temperature=args.rate_temperature, + ).to(device).bfloat16() + for module in base_model.modules(): + if isinstance(module, (CastedLinear, StructuredLinear)): + module.float() + restore_low_dim_params_to_fp32(base_model) + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) if args.enable_compile else base_model + model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + + # Optimizer split: + # - token embedding (Adam) uses EMBED_LR + # - untied lm_head (Adam) uses HEAD_LR + # - matrix params in transformer blocks use MATRIX_LR via Muon + # - vectors/scalars use SCALAR_LR via Adam + block_named_params = list(base_model.blocks.named_parameters()) + matrix_params = [ + p + for name, p in block_named_params + if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + scalar_params = [ + p + for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + scaled_token_lr = 
scale_from_batch_multiplier(token_lr, batch_multiplier, args.lr_scale_method) + scaled_head_lr = scale_from_batch_multiplier(args.head_lr, batch_multiplier, args.lr_scale_method) + scaled_matrix_lr = scale_from_batch_multiplier(args.matrix_lr, batch_multiplier, args.lr_scale_method) + scaled_scalar_lr = scale_from_batch_multiplier(args.scalar_lr, batch_multiplier, args.lr_scale_method) + optimizer_tok = torch.optim.Adam( + [{"params": [base_model.tok_emb.weight], "lr": scaled_token_lr, "base_lr": scaled_token_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizer_muon = Muon( + matrix_params, + lr=scaled_matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + ) + for group in optimizer_muon.param_groups: + group["base_lr"] = scaled_matrix_lr + optimizer_scalar = torch.optim.Adam( + [{"params": scalar_params, "lr": scaled_scalar_lr, "base_lr": scaled_scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": scaled_head_lr, "base_lr": scaled_head_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers.insert(1, optimizer_head) + + n_params = sum(p.numel() for p in base_model.parameters()) + log0(f"model_params:{n_params}") + log0( + f"world_size:{world_size} grad_accum_steps:{grad_accum_steps} " + f"train_microbatch_tokens:{train_microbatch_tokens} effective_train_batch_tokens:{effective_train_batch_tokens}" + ) + if simulated_world_size > 0: + log0(f"simulated_world_size:{simulated_world_size} simulated_rank:{rank}") + log0(f"enable_compile:{args.enable_compile}") + log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") + log0(f"attention_mode:gqa num_heads:{args.num_heads} 
num_kv_heads:{args.num_kv_heads}") + log0( + f"quant_scheme:{args.quant_scheme} compressor:{args.compressor} level:{args.compressor_level} " + f"structured_mlp:{args.structured_mlp} structured_linear_type:{args.structured_linear_type} " + f"btt_rank:{args.btt_rank} btt_init:{args.btt_init} compile_structured_mlp:{args.compile_structured_mlp} " + f"structured_chunk_ranks:{args.structured_chunk_ranks} structured_eval_chunk_ranks:{args.structured_eval_chunk_ranks} " + f"structured_compile_mode:{args.structured_compile_mode}" + ) + log0( + f"rate_lambda:{args.rate_lambda} scale_lambda:{args.scale_lambda} " + f"rate_warmup_start_fraction:{args.rate_warmup_start_fraction} rate_full_fraction:{args.rate_full_fraction} " + f"rate_target_tensors:{args.rate_target_tensors} rate_schedule_mode:{args.rate_schedule_mode} " + f"max_expected_steps:{args.max_expected_steps}" + ) + log0( + f"tie_embeddings:{args.tie_embeddings} embed_lr:{scaled_token_lr} " + f"head_lr:{scaled_head_lr if base_model.lm_head is not None else 0.0} " + f"matrix_lr:{scaled_matrix_lr} scalar_lr:{scaled_scalar_lr}" + ) + log0( + f"train_batch_tokens_target:{args.train_batch_tokens} effective_train_batch_tokens:{effective_train_batch_tokens} " + f"train_seq_len:{args.train_seq_len} iterations:{args.iterations} warmup_steps:{effective_warmup_steps} " + f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log0( + f"lr_scale_method:{args.lr_scale_method} lr_reference_batch_tokens:{args.lr_reference_batch_tokens} " + f"batch_multiplier:{batch_multiplier:.4f} warmup_scale_method:{args.warmup_scale_method} " + f"warmup_reference_batch_tokens:{args.warmup_reference_batch_tokens}" + ) + log0(f"seed:{args.seed}") + + if args.debug_data_sharding_steps > 0: + debug_log_data_sharding( + args.train_files, + world_size, + effective_train_batch_tokens, + args.train_seq_len, + grad_accum_steps, + args.debug_data_sharding_steps, + log0, + ) + if args.dry_run_init_only: + log0("dry_run_init_only: exiting after 
init/logging") + return + + # ----------------------------- + # DATA LOADER & MODEL WARMUP + # ----------------------------- + + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + + def lr_mul(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + + def rate_progress(step: int, elapsed_ms: float) -> float: + if args.rate_schedule_mode == "steps": + if args.iterations <= 0: + return 1.0 + return step / max(args.iterations, 1) + if args.rate_schedule_mode == "expected_steps": + if args.max_expected_steps <= 0: + raise ValueError("MAX_EXPECTED_STEPS must be positive when RATE_SCHEDULE_MODE=expected_steps") + return step / max(args.max_expected_steps, 1) + if args.rate_schedule_mode == "wallclock": + if max_wallclock_ms is None: + raise ValueError("RATE_SCHEDULE_MODE=wallclock requires MAX_WALLCLOCK_SECONDS > 0") + return elapsed_ms / max(max_wallclock_ms, 1e-9) + if args.rate_schedule_mode != "auto": + raise ValueError(f"Unsupported RATE_SCHEDULE_MODE={args.rate_schedule_mode}") + if args.max_expected_steps > 0: + return step / max(args.max_expected_steps, 1) + if max_wallclock_ms is not None and step > 0: + step_ms = elapsed_ms / max(step, 1) + expected_steps = max(int(max_wallclock_ms / max(step_ms, 1e-9)), 1) + return step / expected_steps + if args.iterations <= 0: + 
return 1.0 + return step / max(args.iterations, 1) + + def rate_schedule(step: int, elapsed_ms: float) -> float: + if args.rate_lambda <= 0.0 and args.scale_lambda <= 0.0: + return 0.0 + frac = rate_progress(step, elapsed_ms) + start = args.rate_warmup_start_fraction + full = max(args.rate_full_fraction, start + 1e-6) + if frac <= start: + return 0.0 + if frac >= full: + return 1.0 + return (frac - start) / (full - start) + + # Warmup primes the compiled forward/backward/optimizer paths, then we restore the + # initial weights/optimizer state so measured training starts from the true init. + if effective_warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(effective_warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(effective_train_batch_tokens, args.train_seq_len, grad_accum_steps) + maybe_mark_compile_step_begin(args.enable_compile or args.compile_structured_mlp) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + for opt in optimizers: + opt.step() + zero_grad_all() + if ( + effective_warmup_steps <= 20 + or (warmup_step + 1) % 10 == 0 + or warmup_step + 1 == effective_warmup_steps + ): + log0(f"warmup_step:{warmup_step + 1}/{effective_warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + if distributed: + model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + # ----------------------------- + # 
MAIN TRAINING LOOP + # ----------------------------- + + training_time_ms = 0.0 + stop_after_step: int | None = None + torch.cuda.synchronize() + t0 = time.perf_counter() + + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + log0( + f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" + ) + break + + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + zero_grad_all() + train_loss = torch.zeros((), device=device) + train_ce_loss = torch.zeros((), device=device) + train_rate_proxy = torch.zeros((), device=device) + train_scale_proxy = torch.zeros((), device=device) + train_est_bits = torch.zeros((), device=device) + rate_mul = rate_schedule(step, elapsed_ms) + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(effective_train_batch_tokens, args.train_seq_len, grad_accum_steps) + maybe_mark_compile_step_begin(args.enable_compile or args.compile_structured_mlp) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + ce_loss = model(x, y) + total_loss = ce_loss + 
if rate_mul > 0.0: + rate_proxy, scale_proxy, est_bits = base_model.rate_regularizer() + total_loss = total_loss + rate_mul * (args.rate_lambda * rate_proxy + args.scale_lambda * scale_proxy) + train_rate_proxy += rate_proxy.detach() + train_scale_proxy += scale_proxy.detach() + train_est_bits += est_bits.detach() + train_ce_loss += ce_loss.detach() + train_loss += total_loss.detach() + (total_loss * grad_scale).backward() + train_loss /= grad_accum_steps + train_ce_loss /= grad_accum_steps + train_rate_proxy /= grad_accum_steps + train_scale_proxy /= grad_accum_steps + train_est_bits /= grad_accum_steps + + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + for opt in optimizers: + opt.step() + zero_grad_all() + + step += 1 + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + should_log_train = ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + log0( + f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} ce_loss:{train_ce_loss.item():.4f} " + f"rate_mul:{rate_mul:.3f} rate_proxy_bits:{train_rate_proxy.item():.4f} " + f"scale_proxy:{train_scale_proxy.item():.6f} est_model_bits:{train_est_bits.item():.0f} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" + ) + + # Needed to sync whether we've reached the wallclock cap. 
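The flag mentioned in the comment above must agree across ranks, otherwise one process would leave the training loop while the others block in the next collective. A tiny sketch of why a MAX reduction is the right operation (hypothetical helper that simulates `dist.all_reduce(..., op=dist.ReduceOp.MAX)` locally, without a process group):

```python
# Each rank computes a local "reached the wallclock cap" flag; a MAX
# all-reduce guarantees that if any rank crossed the cap, every rank
# sees the flag set and stops after the same step.
def synchronized_stop(local_flags):
    global_flag = max(local_flags)  # what a MAX all-reduce computes
    return [bool(global_flag) for _ in local_flags]

# Rank 2 ran slightly slower and crossed the cap first:
print(synchronized_stop([0, 0, 1, 0]))  # [True, True, True, True]
print(synchronized_stop([0, 0, 0, 0]))  # [False, False, False, False]
```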
+ reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + + log0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + ) + + # ----------------------------- + # SERIALIZATION + ROUNDTRIP VALIDATION + # ----------------------------- + # Save the raw state (useful for debugging/loading in PyTorch directly), then always produce + # the quantized, compressed artifact (int5_odd + zlib by default) and validate the round-tripped weights. + + if master_process: + torch.save(base_model.state_dict(), "final_model.pt") + model_bytes = os.path.getsize("final_model.pt") + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model: {model_bytes} bytes") + log0(f"Code size: {code_bytes} bytes") + log0(f"Total submission size: {model_bytes + code_bytes} bytes") + + quant_obj, quant_stats = quantize_state_dict(base_model.state_dict(), args.quant_scheme) + quant_buf = io.BytesIO() + torch.save(quant_obj, quant_buf) + quant_raw = quant_buf.getvalue() + quant_blob = compress_blob(quant_raw, args.compressor, args.compressor_level) + quant_raw_bytes = len(quant_raw) + quant_filename = f"final_model.{args.quant_scheme}.{args.compressor}.ptc" + if master_process: + with open(quant_filename, "wb") as f: + f.write(quant_blob) + quant_file_bytes = os.path.getsize(quant_filename) + code_bytes = len(code.encode("utf-8")) + ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1) + log0( + f"Serialized model {args.quant_scheme}+{args.compressor}: {quant_file_bytes} bytes " + f"(payload:{quant_stats['int8_payload_bytes']} raw_torch:{quant_raw_bytes}
payload_ratio:{ratio:.2f}x)" + ) + log0(f"Total submission size {args.quant_scheme}+{args.compressor}: {quant_file_bytes + code_bytes} bytes") + + if distributed: + dist.barrier() + with open(quant_filename, "rb") as f: + quant_blob_disk = f.read() + quant_state = torch.load(io.BytesIO(decompress_blob(quant_blob_disk, args.compressor)), map_location="cpu") + base_model.load_state_dict(dequantize_state_dict(quant_state), strict=True) + torch.cuda.synchronize() + t_qeval = time.perf_counter() + q_val_loss, q_val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"final_{args.quant_scheme}_{args.compressor}_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" + ) + log0(f"final_{args.quant_scheme}_{args.compressor}_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + + if distributed: + dist.destroy_process_group() + + +if __name__ == "__main__": + main() + +==================================================================================================== +Running Python 3.10.12 (main, Mar 3 2026, 11:56:32) [GCC 11.4.0] +Running PyTorch 2.11.0+cu130 +Tue Mar 31 18:26:37 2026 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 590.48.01 Driver Version: 590.48.01 CUDA Version: 13.1 | ++-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. 
| +|=========================================+========================+======================| +| 0 NVIDIA GeForce RTX 3060 Off | 00000000:01:00.0 On | N/A | +| 0% 46C P5 21W / 170W | 224MiB / 12288MiB | 26% Default | +| | | N/A | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| 0 N/A N/A 985 G /usr/lib/xorg/Xorg 99MiB | +| 0 N/A N/A 10632 C python3 104MiB | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=/home/hz/Desktop/parameter-golf-main/data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:1 +val_loader:shards pattern=/home/hz/Desktop/parameter-golf-main/data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:1048576 +val_token_limit:1048576 +model_params:8507464 +world_size:1 grad_accum_steps:8 train_microbatch_tokens:2048 effective_train_batch_tokens:16384 +enable_compile:False +sdp_backends:cudnn=False flash=True mem_efficient=False math=False +attention_mode:gqa num_heads:8 num_kv_heads:4 +quant_scheme:int5_odd compressor:zlib level:9 structured_mlp:True structured_linear_type:btt_mlp btt_rank:32 btt_init:mup compile_structured_mlp:False structured_chunk_ranks:64 structured_eval_chunk_ranks:4 structured_compile_mode:reduce-overhead +rate_lambda:2e-05 scale_lambda:0.0002 rate_warmup_start_fraction:0.1 rate_full_fraction:0.4 rate_target_tensors:mlp_only rate_schedule_mode:auto max_expected_steps:0 +tie_embeddings:True embed_lr:0.03 head_lr:0.0 matrix_lr:0.02 scalar_lr:0.02 
+train_batch_tokens_target:16384 effective_train_batch_tokens:16384 train_seq_len:256 iterations:40 warmup_steps:1 max_wallclock_seconds:600.000 +lr_scale_method:none lr_reference_batch_tokens:16384 batch_multiplier:1.0000 warmup_scale_method:none warmup_reference_batch_tokens:16384 +seed:1337 +dry_run_init_only: exiting after init/logging diff --git a/records/track_non_record_16mb/2026-03-31_EntropyAware_Int5Odd_BTTMLP_3060/logs/patch_eval_sanity.txt b/records/track_non_record_16mb/2026-03-31_EntropyAware_Int5Odd_BTTMLP_3060/logs/patch_eval_sanity.txt new file mode 100644 index 000000000..df4f5a2d9 --- /dev/null +++ b/records/track_non_record_16mb/2026-03-31_EntropyAware_Int5Odd_BTTMLP_3060/logs/patch_eval_sanity.txt @@ -0,0 +1,1885 @@ +""" +The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good launching-off points for new participants, not SOTA configs. We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder. + +Hard stop: to keep these scripts readable for newcomers, let's make sure `train_gpt.py` and `train_gpt_mlx.py` are never longer than 1500 lines.
+""" + +from __future__ import annotations + +import copy +import glob +import io +import math +import os +import random +import subprocess +import sys +import time +import uuid +import zlib +from pathlib import Path + +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP + +try: + import zstandard as zstd + + _HAS_ZSTD = True +except ImportError: + zstd = None + _HAS_ZSTD = False + +REPO_ROOT = Path(__file__).resolve().parents[3] + +# ----------------------------- +# HYPERPARAMETERS +# ----------------------------- +# Default run (overridable via environment variables): +# - 9 transformer blocks at width 512 +# - 8 attention heads with 4 KV heads (GQA) and 2x MLP expansion +# - vocab size 1024, sequence length 256, tied embeddings +# - 16,384 train tokens per step for 40 iterations with a ~10 minute cap + +class Hyperparameters: + # Data paths are shard globs produced by the existing preprocessing pipeline. + data_path = os.environ.get("DATA_PATH", str(REPO_ROOT / "data/datasets/fineweb10B_sp1024")) + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", str(REPO_ROOT / "data/tokenizers/fineweb_1024_bpe.model")) + run_id = os.environ.get("RUN_ID", "train") + seed = int(os.environ.get("SEED", 1337)) + + # Validation cadence and batch size. Validation covers the fineweb_val split, truncated to VAL_TOKEN_LIMIT tokens when the limit is positive. + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 32_768)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 10)) + val_token_limit = int(os.environ.get("VAL_TOKEN_LIMIT", 1_048_576)) + + # Training length.
+ iterations = int(os.environ.get("ITERATIONS", 40)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 1200)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 1)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 16_384)) + train_microbatch_tokens = int(os.environ.get("TRAIN_MICROBATCH_TOKENS", 0)) + grad_accum_steps = int(os.environ.get("GRAD_ACCUM_STEPS", 0)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 256)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + enable_compile = bool(int(os.environ.get("ENABLE_COMPILE", "0"))) + + # Model shape. + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers = int(os.environ.get("NUM_LAYERS", 9)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = int(os.environ.get("MLP_MULT", 2)) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + + # Optimizer hyperparameters. 
+ embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.03)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.02)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.02)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.0)) + lr_scale_method = os.environ.get("LR_SCALE_METHOD", "none") + lr_reference_batch_tokens = int(os.environ.get("LR_REFERENCE_BATCH_TOKENS", 16_384)) + warmup_scale_method = os.environ.get("WARMUP_SCALE_METHOD", "none") + warmup_reference_batch_tokens = int(os.environ.get("WARMUP_REFERENCE_BATCH_TOKENS", 16_384)) + simulated_world_size = int(os.environ.get("SIMULATED_WORLD_SIZE", 0)) + simulated_rank = int(os.environ.get("SIMULATED_RANK", 0)) + dry_run_init_only = bool(int(os.environ.get("DRY_RUN_INIT_ONLY", "0"))) + debug_data_sharding_steps = int(os.environ.get("DEBUG_DATA_SHARDING_STEPS", 0)) + + # Export / compression. + quant_scheme = os.environ.get("QUANT_SCHEME", "int5_odd") + compressor = os.environ.get("COMPRESSOR", "zlib") + compressor_level = int(os.environ.get("COMPRESSOR_LEVEL", 9)) + + # Entropy-aware rate regularization. 
+ rate_lambda = float(os.environ.get("RATE_LAMBDA", 0.00002)) + scale_lambda = float(os.environ.get("SCALE_LAMBDA", 0.0002)) + rate_temperature = float(os.environ.get("RATE_TEMPERATURE", 6.0)) + rate_warmup_start_fraction = float(os.environ.get("RATE_WARMUP_START_FRACTION", 0.1)) + rate_full_fraction = float(os.environ.get("RATE_FULL_FRACTION", 0.4)) + rate_target_tensors = os.environ.get("RATE_TARGET_TENSORS", "mlp_only") + rate_schedule_mode = os.environ.get("RATE_SCHEDULE_MODE", "auto") + max_expected_steps = int(os.environ.get("MAX_EXPECTED_STEPS", 0)) + + # Structured MLP. + structured_mlp = bool(int(os.environ.get("STRUCTURED_MLP", "1"))) + structured_linear_type = os.environ.get("STRUCTURED_LINEAR_TYPE", "btt_mlp") + btt_rank = int(os.environ.get("BTT_RANK", 32)) + btt_in_shape = os.environ.get("BTT_IN_SHAPE", "") + btt_out_shape = os.environ.get("BTT_OUT_SHAPE", "") + btt_init = os.environ.get("BTT_INIT", "mup") + btt_init_gain = float(os.environ.get("BTT_INIT_GAIN", 1.0)) + compile_structured_mlp = bool(int(os.environ.get("COMPILE_STRUCTURED_MLP", "1"))) + structured_chunk_ranks = int(os.environ.get("STRUCTURED_CHUNK_RANKS", 64)) + structured_eval_chunk_ranks = int(os.environ.get("STRUCTURED_EVAL_CHUNK_RANKS", 4)) + structured_compile_mode = os.environ.get("STRUCTURED_COMPILE_MODE", "reduce-overhead") + +# ----------------------------- +# MUON OPTIMIZER +# ----------------------------- +# +# As borrowed from modded-nanogpt +# Background on Muon: https://kellerjordan.github.io/posts/muon/ + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + # Orthogonalize a 2D update matrix with a fast Newton-Schulz iteration. + # Muon uses this to normalize matrix-shaped gradients before applying them. 
+ a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X + + +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True): + super().__init__( + params, + dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov), + ) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + + for group in self.param_groups: + params = group["params"] + if not params: + continue + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + + total_params = sum(int(p.numel()) for p in params) + updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) + + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: + g = g.add(buf, alpha=momentum) + g = zeropower_via_newtonschulz5(g, steps=backend_steps) + # Scale correction from Muon reference implementations. 
+                    g *= max(1, g.size(0) / g.size(1)) ** 0.5
+                    updates_flat[curr : curr + p.numel()] = g.reshape(-1)
+                curr += p.numel()
+
+            if distributed:
+                dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM)
+
+            curr = 0
+            for p in params:
+                g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype)
+                p.add_(g, alpha=-lr)
+                curr += p.numel()
+
+        return loss
+
+
+# -----------------------------
+# TOKENIZER-AGNOSTIC EVALUATION SETUP
+# -----------------------------
+#
+# It's common for small models to have a large fraction of their parameters in embeddings, since the 2 * d_model * d_vocab embedding and head matrices can be gigantic.
+# Instead of locking the tokenizer, we let you bring your own and calculate our validation metrics on the average compression of the validation set.
+# We calculate BPB (bits-per-byte) instead of validation loss, so we need methods to count the number of bytes each token decodes to under the tokenizer.
+# Note: Submissions that edit the tokenizer will be examined more carefully, since screwing this up might unjustly improve your score.
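+
+# Worked example of the BPB conversion performed in eval_val below, with
+# made-up numbers: a validation loss of 2.079 nats/token is 2.079 / ln(2)
+# ~= 3.0 bits/token, and if the tokenizer averages 4 bytes per token on the
+# validation text (tokens_per_byte = 0.25), the score is ~= 3.0 * 0.25 = 0.75
+# bits-per-byte. Illustrative only; the eval path does not call this helper.
+def _bits_per_byte_example(val_loss_nats: float, token_count: float, byte_count: float) -> float:
+    bits_per_token = val_loss_nats / math.log(2.0)
+    return bits_per_token * (token_count / byte_count)
+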
+ +def build_sentencepiece_luts( + sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device +) -> tuple[Tensor, Tensor, Tensor]: + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("▁"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) + + +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*. 
+ tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] + + +def eval_val( + args: Hyperparameters, + model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + grad_accum_steps: int, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, +) -> tuple[float, float]: + # Validation computes two metrics: + # - val_loss: token cross-entropy (natural log) + # - val_bpb: tokenizer-agnostic compression metric used by the challenge + local_batch_tokens = args.val_batch_size // world_size + if local_batch_tokens < args.train_seq_len: + raise ValueError( + "VAL_BATCH_SIZE must provide at least one sequence per rank; " + f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " + f"TRAIN_SEQ_LEN={args.train_seq_len}" + ) + local_batch_seqs = local_batch_tokens // args.train_seq_len + total_seqs = (val_tokens.numel() - 1) // args.train_seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + + model.eval() + with torch.inference_mode(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * args.train_seq_len + raw_end = batch_seq_end * args.train_seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, args.train_seq_len) + y = local[1:].reshape(-1, args.train_seq_len) + maybe_mark_compile_step_begin(args.enable_compile or 
+            args.compile_structured_mlp)
+            with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True):
+                batch_loss = model(x, y).detach()
+            batch_token_count = float(y.numel())
+            val_loss_sum += batch_loss.to(torch.float64) * batch_token_count
+            val_token_count += batch_token_count
+            prev_ids = x.reshape(-1)
+            tgt_ids = y.reshape(-1)
+            token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16)
+            token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16)
+            val_byte_count += token_bytes.to(torch.float64).sum()
+
+    if dist.is_available() and dist.is_initialized():
+        dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM)
+        dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM)
+        dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM)
+
+    val_loss = val_loss_sum / val_token_count
+    bits_per_token = val_loss.item() / math.log(2.0)
+    tokens_per_byte = val_token_count.item() / val_byte_count.item()
+    model.train()
+    return float(val_loss.item()), float(bits_per_token * tokens_per_byte)
+
+# -----------------------------
+# POST-TRAINING QUANTIZATION
+# -----------------------------
+#
+# It's silly to export our model, which is trained in bf16 and fp32, at that same precision.
+# Instead, we get approximately the same model (with a small quality hit) by quantizing it to int8 or int5_odd and zlib-compressing the result.
+# We can then decompress the model and run it in higher precision for evaluation, after coming in under the size limit.
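+
+# Worked example of the base-5 packing used by pack_base5_codes below: three
+# 5-level codes c0, c1, c2 in {0..4} share one byte as c0 + 5*c1 + 25*c2
+# (max value 4 + 20 + 100 = 124 < 256), i.e. 8/3 ~= 2.67 bits per weight
+# before zlib sees the stream. Illustrative only; the export path uses the
+# vectorized numpy version.
+def _pack3_base5_example(c0: int, c1: int, c2: int) -> int:
+    return c0 + 5 * c1 + 25 * c2
+
+
+def _unpack3_base5_example(packed: int) -> tuple[int, int, int]:
+    return packed % 5, (packed // 5) % 5, packed // 25
+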
+ +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights", + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 +INT5_ODD_LEVELS = (-2, -1, 0, 1, 2) + +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) + + +def maybe_mark_compile_step_begin(enabled: bool) -> None: + if not enabled: + return + mark_step = getattr(torch.compiler, "cudagraph_mark_step_begin", None) + if mark_step is not None: + mark_step() + + +def scale_from_batch_multiplier(value: float, multiplier: float, method: str) -> float: + if method == "none": + return value + if method == "sqrt": + return value * math.sqrt(multiplier) + if method == "linear": + return value * multiplier + raise ValueError(f"Unsupported scaling method: {method}") + + +def scaled_warmup_steps(base_steps: int, multiplier: float, method: str) -> int: + if base_steps <= 0: + return 0 + scaled = scale_from_batch_multiplier(float(base_steps), multiplier, method) + return max(int(math.ceil(scaled)), 1) + + +def parse_factor_shape(spec: str, dim: int) -> tuple[int, int]: + if spec: + vals = tuple(int(v.strip()) for v in spec.split(",") if v.strip()) + if len(vals) != 2 or vals[0] * vals[1] != dim: + raise ValueError(f"Invalid factor shape '{spec}' for dim={dim}") + return vals + for a in range(int(math.isqrt(dim)), 0, -1): + if dim % a == 0: + return a, dim // a + return 1, dim + +def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> 
Tensor: + if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): + return t.float().contiguous() + if t.dtype in {torch.float32, torch.bfloat16}: + passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") + return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() + return t + +def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + # Matrices get one scale per row, which usually tracks output-channel + # ranges much better than a single tensor-wide scale. + clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) + q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() + return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + + # Vectors / scalars use a simpler per-tensor scale. 
+ clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() + return q, scale + + +def pack_base5_codes(q: Tensor) -> tuple[bytes, int]: + vals = q.detach().to("cpu", dtype=torch.uint8).reshape(-1).numpy() + n_vals = int(vals.size) + pad = (3 - (n_vals % 3)) % 3 + if pad: + vals = np.pad(vals, (0, pad)) + groups = vals.reshape(-1, 3).astype(np.uint8) + packed = groups[:, 0] + 5 * groups[:, 1] + 25 * groups[:, 2] + return packed.tobytes(), n_vals + + +def unpack_base5_codes(data: bytes, n_vals: int) -> Tensor: + packed = np.frombuffer(data, dtype=np.uint8).astype(np.int16) + out = np.zeros((packed.size, 3), dtype=np.uint8) + for i in range(3): + out[:, i] = packed % 5 + packed //= 5 + flat = out.reshape(-1)[:n_vals] + return torch.from_numpy(flat.astype(np.uint8, copy=False)) + + +def quantize_float_tensor_int5_odd(t: Tensor) -> tuple[Tensor, Tensor]: + if t.ndim != 2: + raise ValueError("int5_odd quantization only supports 2D tensors") + t32 = t.float() + clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clip_abs = clip_abs.clamp_min(1e-6) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 2.0).clamp_min(1.0 / 2.0e6) + q = torch.clamp(torch.round(clipped / scale[:, None]), -2, 2).to(torch.int8).contiguous() + return (q + 2).to(torch.uint8).contiguous(), scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + +def quantize_state_dict_int8(state_dict: dict[str, Tensor]): + # Single supported clean-script export format: + # - per-row int8 for 2D float tensors + # - per-tensor int8 for other float tensors + # - exact passthrough for non-floats + # - passthrough for small float tensors, 
stored as fp16 to save bytes + quantized: dict[str, Tensor] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + qmeta: dict[str, dict[str, object]] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + + # Small float tensors are cheap enough to keep directly. We still downcast + # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size. + if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + + stats["num_float_tensors"] += 1 + q, s = quantize_float_tensor(t) + if s.ndim > 0: + qmeta[name] = {"scheme": "per_row", "axis": 0} + quantized[name] = q + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) + + obj: dict[str, object] = { + "__quant_format__": "int8_clean_per_row_v1", + "quantized": quantized, + "scales": scales, + "dtypes": dtypes, + "passthrough": passthrough, + } + if qmeta: + obj["qmeta"] = qmeta + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats + + +def quantize_state_dict_int5_odd(state_dict: dict[str, Tensor]): + packed: dict[str, bytes] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + shapes: dict[str, tuple[int, ...]] = {} + orig_shapes: dict[str, 
tuple[int, ...]] = {} + n_codes: dict[str, int] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + + quantize_this = t.ndim >= 2 and (t.numel() > INT8_KEEP_FLOAT_MAX_NUMEL or ".mlp." in name) + if not quantize_this: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + + stats["num_float_tensors"] += 1 + t2 = t.reshape(t.shape[0], -1) if t.ndim > 2 else t + q, s = quantize_float_tensor_int5_odd(t2) + packed_bytes, count = pack_base5_codes(q) + packed[name] = packed_bytes + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + shapes[name] = tuple(int(v) for v in t2.shape) + orig_shapes[name] = tuple(int(v) for v in t.shape) + n_codes[name] = count + stats["int8_payload_bytes"] += len(packed_bytes) + tensor_nbytes(s) + + obj: dict[str, object] = { + "__quant_format__": "int5_odd_clean_per_row_v1", + "packed": packed, + "scales": scales, + "dtypes": dtypes, + "shapes": shapes, + "orig_shapes": orig_shapes, + "n_codes": n_codes, + "passthrough": passthrough, + } + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats + +def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + qmeta = obj.get("qmeta", {}) + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, q in 
obj["quantized"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + s = obj["scales"][name] + if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: + s = s.to(dtype=torch.float32) + # Broadcast the saved row scale back across trailing dimensions. + out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() + else: + scale = float(s.item()) + out[name] = (q.float() * scale).to(dtype=dtype).contiguous() + for name, t in obj["passthrough"].items(): + # Restore small tensors, undoing the temporary fp16 storage cast if needed. + out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out + + +def dequantize_state_dict_int5_odd(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, packed in obj["packed"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + shape = tuple(int(v) for v in obj["shapes"][name]) + orig_shape = tuple(int(v) for v in obj.get("orig_shapes", {}).get(name, shape)) + q = unpack_base5_codes(packed, int(obj["n_codes"][name])).to(torch.float32).reshape(shape) - 2.0 + s = obj["scales"][name].to(dtype=torch.float32) + deq = (q * s.view(shape[0], *([1] * (len(shape) - 1)))).to(dtype=dtype) + out[name] = deq.reshape(orig_shape).contiguous() + for name, t in obj["passthrough"].items(): + out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out + + +def quantize_state_dict(state_dict: dict[str, Tensor], quant_scheme: str): + if quant_scheme == "int8": + return quantize_state_dict_int8(state_dict) + if quant_scheme == "int5_odd": + return 
quantize_state_dict_int5_odd(state_dict) + raise ValueError(f"Unsupported QUANT_SCHEME={quant_scheme}") + + +def dequantize_state_dict(obj: dict[str, object]) -> dict[str, Tensor]: + quant_format = obj.get("__quant_format__") + if quant_format == "int8_clean_per_row_v1": + return dequantize_state_dict_int8(obj) + if quant_format == "int5_odd_clean_per_row_v1": + return dequantize_state_dict_int5_odd(obj) + raise ValueError(f"Unsupported quant format {quant_format}") + + +def compress_blob(data: bytes, compressor: str, level: int) -> bytes: + if compressor == "zlib": + return zlib.compress(data, level=max(0, min(level, 9))) + if compressor == "zstd": + if not _HAS_ZSTD: + raise RuntimeError("COMPRESSOR=zstd requires zstandard to be installed") + return zstd.ZstdCompressor(level=level).compress(data) + raise ValueError(f"Unsupported COMPRESSOR={compressor}") + + +def decompress_blob(data: bytes, compressor: str) -> bytes: + if compressor == "zlib": + return zlib.decompress(data) + if compressor == "zstd": + if not _HAS_ZSTD: + raise RuntimeError("COMPRESSOR=zstd requires zstandard to be installed") + return zstd.ZstdDecompressor().decompress(data) + raise ValueError(f"Unsupported COMPRESSOR={compressor}") + + +def should_rate_regularize_tensor(name: str, p: Tensor, target: str) -> bool: + if not p.requires_grad or not p.is_floating_point() or p.ndim < 2: + return False + if any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS): + return False + if target == "mlp_only": + return ".mlp." in name + if target == "all_matrices": + return p.numel() > INT8_KEEP_FLOAT_MAX_NUMEL or ".mlp." 
+in name
+    raise ValueError(f"Unsupported RATE_TARGET_TENSORS={target}")
+
+
+def estimate_tensor_rate_proxy(t: Tensor, temperature: float) -> tuple[Tensor, Tensor, Tensor]:
+    t2 = t.float()
+    if t2.ndim > 2:
+        t2 = t2.reshape(t2.shape[0], -1)
+    elif t2.ndim == 1:
+        t2 = t2.reshape(1, -1)
+    row_scale = t2.abs().mean(dim=1, keepdim=True).clamp_min(1e-6)
+    normalized = torch.clamp(t2 / row_scale, -2.5, 2.5)
+    centers = t2.new_tensor(INT5_ODD_LEVELS)
+    logits = -temperature * (normalized.unsqueeze(-1) - centers) ** 2
+    probs = torch.softmax(logits, dim=-1)
+    hist = probs.mean(dim=(0, 1))
+    entropy_nats = -(hist * (hist + 1e-8).log()).sum()
+    entropy_bits = entropy_nats / math.log(2.0)
+    scale_proxy = row_scale.mean()
+    est_bits = entropy_bits * float(t2.numel())
+    return entropy_bits, scale_proxy, est_bits
+
+
+# -----------------------------
+# DATA LOADING
+# -----------------------------
+
+def load_data_shard(file: Path) -> Tensor:
+    # Shard layout follows the modded-nanogpt convention: a 256-entry int32
+    # header (magic, version, token count, ...) followed by uint16 token ids.
+    header_bytes = 256 * np.dtype("<i4").itemsize
+    with file.open("rb") as f:
+        header = np.frombuffer(f.read(header_bytes), dtype="<i4")
+        num_tokens = int(header[2])
+        tokens = np.frombuffer(f.read(num_tokens * np.dtype("<u2").itemsize), dtype="<u2")
+    return torch.from_numpy(tokens.astype(np.int32, copy=False))
+
+
+class TokenStream:
+    # Cycles through the training shards, handing out contiguous token runs.
+    def __init__(self, pattern: str):
+        self.files = [Path(p) for p in sorted(glob.glob(pattern))]
+        if not self.files:
+            raise FileNotFoundError(f"No files found for pattern: {pattern}")
+        self.file_idx = 0
+        self.tokens = load_data_shard(self.files[self.file_idx])
+        self.pos = 0
+
+    def _advance_file(self) -> None:
+        self.file_idx = (self.file_idx + 1) % len(self.files)
+        self.tokens = load_data_shard(self.files[self.file_idx])
+        self.pos = 0
+
+    def take(self, n: int) -> Tensor:
+        chunks: list[Tensor] = []
+        remaining = n
+        while remaining > 0:
+            avail = self.tokens.numel() - self.pos
+            if avail <= 0:
+                self._advance_file()
+                continue
+            k = min(remaining, avail)
+            chunks.append(self.tokens[self.pos : self.pos + k])
+            self.pos += k
+            remaining -= k
+        return chunks[0] if len(chunks) == 1 else torch.cat(chunks)
+
+
+class DistributedTokenLoader:
+    # Each call consumes a contiguous chunk from the shared token stream, then slices out
+    # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting.
+ def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank = rank + self.world_size = world_size + self.device = device + self.stream = TokenStream(pattern) + + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + chunk = self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start : start + per_rank_span].to(dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + + +def debug_log_data_sharding( + pattern: str, + world_size: int, + global_tokens: int, + seq_len: int, + grad_accum_steps: int, + steps: int, + log0, +) -> None: + stream = TokenStream(pattern) + local_tokens = global_tokens // (world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + global_cursor = 0 + log0( + f"debug_sharding:world_size:{world_size} global_tokens:{global_tokens} " + f"grad_accum_steps:{grad_accum_steps} local_tokens_per_rank:{local_tokens} per_rank_span:{per_rank_span}" + ) + for step_idx in range(steps): + chunk = stream.take(per_rank_span * world_size) + chunk_start = global_cursor + chunk_end = global_cursor + chunk.numel() + log0(f"debug_sharding_step:{step_idx} shared_chunk_token_range:[{chunk_start},{chunk_end})") + for shard_rank in range(world_size): + start = shard_rank * per_rank_span + end = start + per_rank_span + local = chunk[start:end] + preview = ",".join(str(int(t)) for t in local[: min(6, local.numel())].tolist()) + log0( + f"debug_shard rank:{shard_rank} token_range:[{chunk_start + start},{chunk_start + end}) " + f"preview:{preview}" + ) + global_cursor = chunk_end + +# ----------------------------- +# TRANSFORMER MODULES +# ----------------------------- + +class RMSNorm(nn.Module): + def 
__init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + + +class CastedLinear(nn.Linear): + # Keep weights in fp32 for optimizer/state quality, cast at matmul time for bf16 compute. + def forward(self, x: Tensor) -> Tensor: + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, self.weight.to(x.dtype), bias) + + +class StructuredLinear(nn.Module): + # Lightweight TT-style structured matrix used only in the MLP stack. + def __init__( + self, + in_features: int, + out_features: int, + bias: bool, + linear_type: str, + btt_rank: int, + in_shape_spec: str, + out_shape_spec: str, + btt_init: str, + btt_init_gain: float, + chunk_ranks: int, + eval_chunk_ranks: int, + ): + super().__init__() + self.in_features = in_features + self.out_features = out_features + self.linear_type = linear_type + self.rank = max(int(btt_rank), 1) + self.chunk_ranks = max(int(chunk_ranks), 1) + self.eval_chunk_ranks = max(int(eval_chunk_ranks), 1) + self.btt_init = btt_init + self.btt_init_gain = btt_init_gain + if linear_type == "dense": + self.impl = CastedLinear(in_features, out_features, bias=bias) + self.register_parameter("core1", None) + self.register_parameter("core2", None) + return + if linear_type != "btt_mlp": + raise ValueError(f"Unsupported STRUCTURED_LINEAR_TYPE={linear_type}") + self.impl = None + self.in_shape = parse_factor_shape(in_shape_spec, in_features) + self.out_shape = parse_factor_shape(out_shape_spec, out_features) + n1, n2 = self.in_shape + m1, m2 = self.out_shape + self.core1 = nn.Parameter(torch.empty(n1, m1, self.rank)) + self.core2 = nn.Parameter(torch.empty(self.rank, n2, m2)) + if bias: + self.bias = nn.Parameter(torch.zeros(out_features)) + else: + self.register_parameter("bias", None) + self.reset_parameters() + + def reset_parameters(self) -> None: + if self.impl is not None: + self.impl.reset_parameters() 
+ return + n1, n2 = self.in_shape + if self.btt_init == "mup": + # muP-inspired scaling: keep the per-rank branch update stable while + # preserving O(1) output variance under the sum over TT ranks. + core1_std = self.btt_init_gain / math.sqrt(max(n1, 1)) + core2_std = self.btt_init_gain / math.sqrt(max(self.rank * n2, 1)) + elif self.btt_init == "xavier": + std = self.btt_init_gain / math.sqrt(max(self.in_features, 1)) + core1_std = std / math.sqrt(self.rank) + core2_std = std + else: + raise ValueError(f"Unsupported BTT_INIT={self.btt_init}") + nn.init.normal_(self.core1, mean=0.0, std=core1_std) + nn.init.normal_(self.core2, mean=0.0, std=core2_std) + if self.bias is not None: + nn.init.zeros_(self.bias) + + def forward(self, x: Tensor) -> Tensor: + if self.impl is not None: + return self.impl(x) + x_dtype = x.dtype + orig_shape = x.shape[:-1] + x2 = x.reshape(-1, self.in_shape[0], self.in_shape[1]) + x2_t = x2.transpose(1, 2).contiguous() + core1 = self.core1.to(dtype=x_dtype).permute(2, 0, 1).contiguous() + core2 = self.core2.to(dtype=x_dtype).contiguous() + y = x2.new_zeros((x2.shape[0], self.out_shape[0], self.out_shape[1])) + if torch.is_inference_mode_enabled() and self.eval_chunk_ranks > 1: + for start in range(0, self.rank, self.eval_chunk_ranks): + end = min(start + self.eval_chunk_ranks, self.rank) + left = torch.matmul(x2_t.unsqueeze(1), core1[start:end].unsqueeze(0)) + y = y + torch.matmul(left.transpose(-1, -2), core2[start:end].unsqueeze(0)).sum(dim=1) + y = y.reshape(*orig_shape, self.out_features) + if self.bias is not None: + y = y + self.bias.to(dtype=y.dtype) + return y + for start in range(0, self.rank, self.chunk_ranks): + end = min(start + self.chunk_ranks, self.rank) + for offset, core2_r in enumerate(core2[start:end], start=start): + left = torch.matmul(x2_t, core1[offset]) + y = y + torch.matmul(left.transpose(1, 2), core2_r) + y = y.reshape(*orig_shape, self.out_features) + if self.bias is not None: + y = y + self.bias.to(dtype=y.dtype) 
+ return y + + +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + # Keep small/control parameters in fp32 even when the model body runs in bf16. + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() + + +class Rotary(nn.Module): + # Caches cos/sin tables per sequence length on the current device. + def __init__(self, dim: int, base: float = 10000.0): + super().__init__() + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached != seq_len + or self._cos_cached.device != device + ): + t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) + freqs = torch.outer(t, self.inv_freq.to(device)) + self._cos_cached = freqs.cos()[None, None, :, :] + self._sin_cached = freqs.sin()[None, None, :, :] + self._seq_len_cached = seq_len + cos = self._cos_cached.to(dtype=dtype) + sin = self._sin_cached.to(dtype=dtype) + if not torch.is_inference_mode_enabled(): + cos = cos.clone() + sin = sin.clone() + return cos, sin + + +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + + +class CausalSelfAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + 
if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + kv_dim = self.num_kv_heads * self.head_dim + self.c_q = CastedLinear(dim, dim, bias=False) + self.c_k = CastedLinear(dim, kv_dim, bias=False) + self.c_v = CastedLinear(dim, kv_dim, bias=False) + self.proj = CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rotary = Rotary(self.head_dim, base=rope_base) + + def forward(self, x: Tensor) -> Tensor: + bsz, seqlen, dim = x.shape + q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) + k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin) + k = apply_rotary_emb(k, cos, sin) + q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] + y = F.scaled_dot_product_attention( + q, + k, + v, + attn_mask=None, + is_causal=True, + enable_gqa=(self.num_kv_heads != self.num_heads), + ) + y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) + return self.proj(y) + + +class MLP(nn.Module): + # relu^2 MLP from the original modded-nanogpt setup + def __init__( + self, + dim: int, + mlp_mult: int, + structured_mlp: bool, + structured_linear_type: str, + btt_rank: int, + btt_in_shape: str, + btt_out_shape: str, + btt_init: str, + btt_init_gain: float, + compile_structured_mlp: bool, + structured_chunk_ranks: int, + structured_eval_chunk_ranks: int, + structured_compile_mode: str, + ): + super().__init__() + hidden = mlp_mult * dim + if 
structured_mlp: + self.fc = StructuredLinear( + dim, + hidden, + bias=False, + linear_type=structured_linear_type, + btt_rank=btt_rank, + in_shape_spec=btt_in_shape, + out_shape_spec=btt_out_shape, + btt_init=btt_init, + btt_init_gain=btt_init_gain, + chunk_ranks=structured_chunk_ranks, + eval_chunk_ranks=structured_eval_chunk_ranks, + ) + self.proj = StructuredLinear( + hidden, + dim, + bias=False, + linear_type=structured_linear_type, + btt_rank=btt_rank, + in_shape_spec=btt_out_shape, + out_shape_spec=btt_in_shape, + btt_init=btt_init, + btt_init_gain=btt_init_gain, + chunk_ranks=structured_chunk_ranks, + eval_chunk_ranks=structured_eval_chunk_ranks, + ) + if compile_structured_mlp: + compile_kwargs = dict(fullgraph=False, dynamic=False) + if structured_compile_mode == "reduce-overhead": + compile_options = dict(torch._inductor.list_mode_options("reduce-overhead")) + compile_options["triton.cudagraphs"] = False + compile_kwargs["options"] = compile_options + else: + compile_kwargs["mode"] = structured_compile_mode + self.fc = torch.compile( + self.fc, + **compile_kwargs, + ) + self.proj = torch.compile( + self.proj, + **compile_kwargs, + ) + else: + self.fc = CastedLinear(dim, hidden, bias=False) + self.proj = CastedLinear(hidden, dim, bias=False) + self.proj._zero_init = True + + def forward(self, x: Tensor) -> Tensor: + x = torch.relu(self.fc(x)) + return self.proj(x.square()) + + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + rope_base: float, + qk_gain_init: float, + structured_mlp: bool, + structured_linear_type: str, + btt_rank: int, + btt_in_shape: str, + btt_out_shape: str, + btt_init: str, + btt_init_gain: float, + compile_structured_mlp: bool, + structured_chunk_ranks: int, + structured_eval_chunk_ranks: int, + structured_compile_mode: str, + ): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention(dim, num_heads, 
num_kv_heads, rope_base, qk_gain_init) + self.mlp = MLP( + dim, + mlp_mult, + structured_mlp, + structured_linear_type, + btt_rank, + btt_in_shape, + btt_out_shape, + btt_init, + btt_init_gain, + compile_structured_mlp, + structured_chunk_ranks, + structured_eval_chunk_ranks, + structured_compile_mode, + ) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + + def forward(self, x: Tensor, x0: Tensor) -> Tensor: + mix = self.resid_mix.to(dtype=x.dtype) + x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + attn_out = self.attn(self.attn_norm(x)) + x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out + x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) + return x + + +class GPT(nn.Module): + def __init__( + self, + vocab_size: int, + num_layers: int, + model_dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + tie_embeddings: bool, + tied_embed_init_std: float, + logit_softcap: float, + rope_base: float, + qk_gain_init: float, + structured_mlp: bool, + structured_linear_type: str, + btt_rank: int, + btt_in_shape: str, + btt_out_shape: str, + btt_init: str, + btt_init_gain: float, + compile_structured_mlp: bool, + structured_chunk_ranks: int, + structured_eval_chunk_ranks: int, + structured_compile_mode: str, + rate_target_tensors: str, + rate_temperature: float, + ): + super().__init__() + if logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") + self.tie_embeddings = tie_embeddings + self.tied_embed_init_std = tied_embed_init_std + self.logit_softcap = logit_softcap + self.rate_target_tensors = rate_target_tensors + self.rate_temperature = rate_temperature + self.tok_emb = nn.Embedding(vocab_size, model_dim) + self.num_encoder_layers = num_layers // 2 + 
self.num_decoder_layers = num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + self.blocks = nn.ModuleList( + [ + Block( + model_dim, + num_heads, + num_kv_heads, + mlp_mult, + rope_base, + qk_gain_init, + structured_mlp, + structured_linear_type, + btt_rank, + btt_in_shape, + btt_out_shape, + btt_init, + btt_init_gain, + compile_structured_mlp, + structured_chunk_ranks, + structured_eval_chunk_ranks, + structured_compile_mode, + ) + for i in range(num_layers) + ] + ) + self.final_norm = RMSNorm() + self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) + if self.lm_head is not None: + self.lm_head._zero_init = True + self._init_weights() + + def _init_weights(self) -> None: + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + for module in self.modules(): + if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + if isinstance(module, StructuredLinear) and getattr(module, "_zero_init", False): + if module.impl is not None: + nn.init.zeros_(module.impl.weight) + else: + nn.init.zeros_(module.core2) + + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + x = self.tok_emb(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x0 = x + skips: list[Tensor] = [] + + # First half stores skips; second half reuses them in reverse order. 
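 + # Worked example (assuming the default NUM_LAYERS=9 below): num_encoder_layers = 4 and + # num_decoder_layers = 5, so blocks 0..3 push four skips; decoder step i (block 4+i) adds + # skip_weights[i] * skips.pop(), pairing block 4+i with encoder block 3-i, and the last + # decoder block sees an empty `skips` list, so it gets no skip connection.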
+ for i in range(self.num_encoder_layers): + x = self.blocks[i](x, x0) + skips.append(x) + for i in range(self.num_decoder_layers): + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + x = self.blocks[self.num_encoder_layers + i](x, x0) + + x = self.final_norm(x).reshape(-1, x.size(-1)) + targets = target_ids.reshape(-1) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(x) + logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + return F.cross_entropy(logits.float(), targets, reduction="mean") + + def rate_regularizer(self) -> tuple[Tensor, Tensor, Tensor]: + total_symbols = 0.0 + weighted_rate = None + weighted_scale = None + estimated_bits = None + for name, param in self.named_parameters(): + if not should_rate_regularize_tensor(name, param, self.rate_target_tensors): + continue + rate_proxy, scale_proxy, tensor_bits = estimate_tensor_rate_proxy(param, self.rate_temperature) + weight = float(param.numel()) + total_symbols += weight + if weighted_rate is None: + weighted_rate = rate_proxy * weight + weighted_scale = scale_proxy * weight + estimated_bits = tensor_bits + else: + weighted_rate = weighted_rate + rate_proxy * weight + weighted_scale = weighted_scale + scale_proxy * weight + estimated_bits = estimated_bits + tensor_bits + if weighted_rate is None or weighted_scale is None or estimated_bits is None: + zero = self.tok_emb.weight.new_zeros(()) + return zero, zero, zero + denom = max(total_symbols, 1.0) + return weighted_rate / denom, weighted_scale / denom, estimated_bits + + +# ----------------------------- +# TRAINING +# ----------------------------- + +def main() -> None: + global zeropower_via_newtonschulz5 + + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + if args.enable_compile: + 
zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) + + # ----------------------------- + # DISTRIBUTED + CUDA SETUP + # ----------------------------- + + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + simulated_world_size = args.simulated_world_size if not distributed and args.simulated_world_size > 0 else 0 + if simulated_world_size > 0: + world_size = simulated_world_size + rank = args.simulated_rank + local_rank = 0 + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if rank < 0 or rank >= world_size: + raise ValueError(f"RANK={rank} must satisfy 0 <= RANK < WORLD_SIZE={world_size}") + if simulated_world_size > 1 and not args.dry_run_init_only: + raise ValueError("SIMULATED_WORLD_SIZE>1 is only supported together with DRY_RUN_INIT_ONLY=1") + if args.train_batch_tokens <= 0: + raise ValueError(f"TRAIN_BATCH_TOKENS must be positive, got {args.train_batch_tokens}") + train_microbatch_tokens = args.train_microbatch_tokens if args.train_microbatch_tokens > 0 else max( + args.train_seq_len, args.train_batch_tokens // 8 + ) + if train_microbatch_tokens <= 0: + raise ValueError(f"TRAIN_MICROBATCH_TOKENS must be positive, got {train_microbatch_tokens}") + if train_microbatch_tokens % args.train_seq_len != 0: + raise ValueError( + f"TRAIN_MICROBATCH_TOKENS={train_microbatch_tokens} must be divisible by TRAIN_SEQ_LEN={args.train_seq_len}" + ) + tokens_per_microstep = world_size * train_microbatch_tokens + if args.grad_accum_steps > 0: + grad_accum_steps = args.grad_accum_steps + effective_train_batch_tokens = tokens_per_microstep * grad_accum_steps + else: + if args.train_batch_tokens % tokens_per_microstep != 0: + raise ValueError( + "TRAIN_BATCH_TOKENS must be divisible by WORLD_SIZE * TRAIN_MICROBATCH_TOKENS " + f"unless GRAD_ACCUM_STEPS is 
set explicitly; got TRAIN_BATCH_TOKENS={args.train_batch_tokens}, " + f"WORLD_SIZE={world_size}, TRAIN_MICROBATCH_TOKENS={train_microbatch_tokens}" + ) + grad_accum_steps = args.train_batch_tokens // tokens_per_microstep + effective_train_batch_tokens = args.train_batch_tokens + if grad_accum_steps <= 0: + raise ValueError(f"GRAD_ACCUM_STEPS must be positive, got {grad_accum_steps}") + grad_scale = 1.0 / grad_accum_steps + batch_multiplier = effective_train_batch_tokens / max(args.lr_reference_batch_tokens, 1) + effective_warmup_steps = scaled_warmup_steps( + args.warmup_steps, + effective_train_batch_tokens / max(args.warmup_reference_batch_tokens, 1), + args.warmup_scale_method, + ) + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 0 + + # Fast math knobs + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + + logfile = None + if master_process: + if args.run_id == "train": + logfile = "train.log" + else: + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" + print(logfile) + + def log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch {torch.__version__}", console=False) + log0( + subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, 
check=False).stdout, + console=False, + ) + log0("=" * 100, console=False) + + # ----------------------------- + # TOKENIZER + VALIDATION METRIC SETUP + # ----------------------------- + + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) + if args.val_token_limit > 0: + usable = min(((args.val_token_limit) // args.train_seq_len) * args.train_seq_len, val_tokens.numel() - 1) + if usable <= 0: + raise ValueError(f"VAL_TOKEN_LIMIT={args.val_token_limit} is too small for TRAIN_SEQ_LEN={args.train_seq_len}") + val_tokens = val_tokens[: usable + 1].contiguous() + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + if args.val_token_limit > 0: + log0(f"val_token_limit:{args.val_token_limit}") + + # ----------------------------- + # MODEL + OPTIMIZER SETUP + # ----------------------------- + + base_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + 
tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, + structured_mlp=args.structured_mlp, + structured_linear_type=args.structured_linear_type if args.structured_mlp else "dense", + btt_rank=args.btt_rank, + btt_in_shape=args.btt_in_shape, + btt_out_shape=args.btt_out_shape, + btt_init=args.btt_init, + btt_init_gain=args.btt_init_gain, + compile_structured_mlp=args.compile_structured_mlp, + structured_chunk_ranks=args.structured_chunk_ranks, + structured_eval_chunk_ranks=args.structured_eval_chunk_ranks, + structured_compile_mode=args.structured_compile_mode, + rate_target_tensors=args.rate_target_tensors, + rate_temperature=args.rate_temperature, + ).to(device).bfloat16() + for module in base_model.modules(): + if isinstance(module, (CastedLinear, StructuredLinear)): + module.float() + restore_low_dim_params_to_fp32(base_model) + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) if args.enable_compile else base_model + model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + + # Optimizer split: + # - token embedding (Adam) uses EMBED_LR + # - untied lm_head (Adam) uses HEAD_LR + # - matrix params in transformer blocks use MATRIX_LR via Muon + # - vectors/scalars use SCALAR_LR via Adam + block_named_params = list(base_model.blocks.named_parameters()) + matrix_params = [ + p + for name, p in block_named_params + if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + scalar_params = [ + p + for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + scaled_token_lr = 
scale_from_batch_multiplier(token_lr, batch_multiplier, args.lr_scale_method) + scaled_head_lr = scale_from_batch_multiplier(args.head_lr, batch_multiplier, args.lr_scale_method) + scaled_matrix_lr = scale_from_batch_multiplier(args.matrix_lr, batch_multiplier, args.lr_scale_method) + scaled_scalar_lr = scale_from_batch_multiplier(args.scalar_lr, batch_multiplier, args.lr_scale_method) + optimizer_tok = torch.optim.Adam( + [{"params": [base_model.tok_emb.weight], "lr": scaled_token_lr, "base_lr": scaled_token_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizer_muon = Muon( + matrix_params, + lr=scaled_matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + ) + for group in optimizer_muon.param_groups: + group["base_lr"] = scaled_matrix_lr + optimizer_scalar = torch.optim.Adam( + [{"params": scalar_params, "lr": scaled_scalar_lr, "base_lr": scaled_scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": scaled_head_lr, "base_lr": scaled_head_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers.insert(1, optimizer_head) + + n_params = sum(p.numel() for p in base_model.parameters()) + log0(f"model_params:{n_params}") + log0( + f"world_size:{world_size} grad_accum_steps:{grad_accum_steps} " + f"train_microbatch_tokens:{train_microbatch_tokens} effective_train_batch_tokens:{effective_train_batch_tokens}" + ) + if simulated_world_size > 0: + log0(f"simulated_world_size:{simulated_world_size} simulated_rank:{rank}") + log0(f"enable_compile:{args.enable_compile}") + log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") + log0(f"attention_mode:gqa num_heads:{args.num_heads} 
num_kv_heads:{args.num_kv_heads}") + log0( + f"quant_scheme:{args.quant_scheme} compressor:{args.compressor} level:{args.compressor_level} " + f"structured_mlp:{args.structured_mlp} structured_linear_type:{args.structured_linear_type} " + f"btt_rank:{args.btt_rank} btt_init:{args.btt_init} compile_structured_mlp:{args.compile_structured_mlp} " + f"structured_chunk_ranks:{args.structured_chunk_ranks} structured_eval_chunk_ranks:{args.structured_eval_chunk_ranks} " + f"structured_compile_mode:{args.structured_compile_mode}" + ) + log0( + f"rate_lambda:{args.rate_lambda} scale_lambda:{args.scale_lambda} " + f"rate_warmup_start_fraction:{args.rate_warmup_start_fraction} rate_full_fraction:{args.rate_full_fraction} " + f"rate_target_tensors:{args.rate_target_tensors} rate_schedule_mode:{args.rate_schedule_mode} " + f"max_expected_steps:{args.max_expected_steps}" + ) + log0( + f"tie_embeddings:{args.tie_embeddings} embed_lr:{scaled_token_lr} " + f"head_lr:{scaled_head_lr if base_model.lm_head is not None else 0.0} " + f"matrix_lr:{scaled_matrix_lr} scalar_lr:{scaled_scalar_lr}" + ) + log0( + f"train_batch_tokens_target:{args.train_batch_tokens} effective_train_batch_tokens:{effective_train_batch_tokens} " + f"train_seq_len:{args.train_seq_len} iterations:{args.iterations} warmup_steps:{effective_warmup_steps} " + f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log0( + f"lr_scale_method:{args.lr_scale_method} lr_reference_batch_tokens:{args.lr_reference_batch_tokens} " + f"batch_multiplier:{batch_multiplier:.4f} warmup_scale_method:{args.warmup_scale_method} " + f"warmup_reference_batch_tokens:{args.warmup_reference_batch_tokens}" + ) + log0(f"seed:{args.seed}") + + if args.debug_data_sharding_steps > 0: + debug_log_data_sharding( + args.train_files, + world_size, + effective_train_batch_tokens, + args.train_seq_len, + grad_accum_steps, + args.debug_data_sharding_steps, + log0, + ) + if args.dry_run_init_only: + log0("dry_run_init_only: exiting after 
init/logging") + return + + # ----------------------------- + # DATA LOADER & MODEL WARMUP + # ----------------------------- + + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + + def lr_mul(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + + def rate_progress(step: int, elapsed_ms: float) -> float: + if args.rate_schedule_mode == "steps": + if args.iterations <= 0: + return 1.0 + return step / max(args.iterations, 1) + if args.rate_schedule_mode == "expected_steps": + if args.max_expected_steps <= 0: + raise ValueError("MAX_EXPECTED_STEPS must be positive when RATE_SCHEDULE_MODE=expected_steps") + return step / max(args.max_expected_steps, 1) + if args.rate_schedule_mode == "wallclock": + if max_wallclock_ms is None: + raise ValueError("RATE_SCHEDULE_MODE=wallclock requires MAX_WALLCLOCK_SECONDS > 0") + return elapsed_ms / max(max_wallclock_ms, 1e-9) + if args.rate_schedule_mode != "auto": + raise ValueError(f"Unsupported RATE_SCHEDULE_MODE={args.rate_schedule_mode}") + if args.max_expected_steps > 0: + return step / max(args.max_expected_steps, 1) + if max_wallclock_ms is not None and step > 0: + step_ms = elapsed_ms / max(step, 1) + expected_steps = max(int(max_wallclock_ms / max(step_ms, 1e-9)), 1) + return step / expected_steps + if args.iterations <= 0: + 
return 1.0 + return step / max(args.iterations, 1) + + def rate_schedule(step: int, elapsed_ms: float) -> float: + if args.rate_lambda <= 0.0 and args.scale_lambda <= 0.0: + return 0.0 + frac = rate_progress(step, elapsed_ms) + start = args.rate_warmup_start_fraction + full = max(args.rate_full_fraction, start + 1e-6) + if frac <= start: + return 0.0 + if frac >= full: + return 1.0 + return (frac - start) / (full - start) + + # Warmup primes the compiled forward/backward/optimizer paths, then we restore the + # initial weights/optimizer state so measured training starts from the true init. + if effective_warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(effective_warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(effective_train_batch_tokens, args.train_seq_len, grad_accum_steps) + maybe_mark_compile_step_begin(args.enable_compile or args.compile_structured_mlp) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + for opt in optimizers: + opt.step() + zero_grad_all() + if ( + effective_warmup_steps <= 20 + or (warmup_step + 1) % 10 == 0 + or warmup_step + 1 == effective_warmup_steps + ): + log0(f"warmup_step:{warmup_step + 1}/{effective_warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + if distributed: + model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + # ----------------------------- + # 
MAIN TRAINING LOOP + # ----------------------------- + + training_time_ms = 0.0 + stop_after_step: int | None = None + torch.cuda.synchronize() + t0 = time.perf_counter() + + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + log0( + f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" + ) + break + + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + zero_grad_all() + train_loss = torch.zeros((), device=device) + train_ce_loss = torch.zeros((), device=device) + train_rate_proxy = torch.zeros((), device=device) + train_scale_proxy = torch.zeros((), device=device) + train_est_bits = torch.zeros((), device=device) + rate_mul = rate_schedule(step, elapsed_ms) + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(effective_train_batch_tokens, args.train_seq_len, grad_accum_steps) + maybe_mark_compile_step_begin(args.enable_compile or args.compile_structured_mlp) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + ce_loss = model(x, y) + total_loss = ce_loss + 
if rate_mul > 0.0: + rate_proxy, scale_proxy, est_bits = base_model.rate_regularizer() + total_loss = total_loss + rate_mul * (args.rate_lambda * rate_proxy + args.scale_lambda * scale_proxy) + train_rate_proxy += rate_proxy.detach() + train_scale_proxy += scale_proxy.detach() + train_est_bits += est_bits.detach() + train_ce_loss += ce_loss.detach() + train_loss += total_loss.detach() + (total_loss * grad_scale).backward() + train_loss /= grad_accum_steps + train_ce_loss /= grad_accum_steps + train_rate_proxy /= grad_accum_steps + train_scale_proxy /= grad_accum_steps + train_est_bits /= grad_accum_steps + + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + for opt in optimizers: + opt.step() + zero_grad_all() + + step += 1 + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + should_log_train = ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + log0( + f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} ce_loss:{train_ce_loss.item():.4f} " + f"rate_mul:{rate_mul:.3f} rate_proxy_bits:{train_rate_proxy.item():.4f} " + f"scale_proxy:{train_scale_proxy.item():.6f} est_model_bits:{train_est_bits.item():.0f} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" + ) + + # Needed to sync whether we've reached the wallclock cap. 
+ reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + + log0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + ) + + # ----------------------------- + # SERIALIZATION + ROUNDTRIP VALIDATION + # ----------------------------- + # Save the raw state (useful for debugging/loading in PyTorch directly), then always produce + # the compressed quantized artifact (quant_scheme + compressor; int5_odd + zlib in this run) and validate the round-tripped weights. + + if master_process: + torch.save(base_model.state_dict(), "final_model.pt") + model_bytes = os.path.getsize("final_model.pt") + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model: {model_bytes} bytes") + log0(f"Code size: {code_bytes} bytes") + log0(f"Total submission size: {model_bytes + code_bytes} bytes") + + quant_obj, quant_stats = quantize_state_dict(base_model.state_dict(), args.quant_scheme) + quant_buf = io.BytesIO() + torch.save(quant_obj, quant_buf) + quant_raw = quant_buf.getvalue() + quant_blob = compress_blob(quant_raw, args.compressor, args.compressor_level) + quant_raw_bytes = len(quant_raw) + quant_filename = f"final_model.{args.quant_scheme}.{args.compressor}.ptc" + if master_process: + with open(quant_filename, "wb") as f: + f.write(quant_blob) + quant_file_bytes = os.path.getsize(quant_filename) + code_bytes = len(code.encode("utf-8")) + ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1) + log0( + f"Serialized model {args.quant_scheme}+{args.compressor}: {quant_file_bytes} bytes " + f"(payload:{quant_stats['int8_payload_bytes']} raw_torch:{quant_raw_bytes}
payload_ratio:{ratio:.2f}x)" + ) + log0(f"Total submission size {args.quant_scheme}+{args.compressor}: {quant_file_bytes + code_bytes} bytes") + + if distributed: + dist.barrier() + with open(quant_filename, "rb") as f: + quant_blob_disk = f.read() + quant_state = torch.load(io.BytesIO(decompress_blob(quant_blob_disk, args.compressor)), map_location="cpu") + base_model.load_state_dict(dequantize_state_dict(quant_state), strict=True) + torch.cuda.synchronize() + t_qeval = time.perf_counter() + q_val_loss, q_val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"final_{args.quant_scheme}_{args.compressor}_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" + ) + log0(f"final_{args.quant_scheme}_{args.compressor}_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + + if distributed: + dist.destroy_process_group() + + +if __name__ == "__main__": + main() + +==================================================================================================== +Running Python 3.10.12 (main, Mar 3 2026, 11:56:32) [GCC 11.4.0] +Running PyTorch 2.11.0+cu130 +Tue Mar 31 18:13:00 2026 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 590.48.01 Driver Version: 590.48.01 CUDA Version: 13.1 | ++-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. 
| +|=========================================+========================+======================| +| 0 NVIDIA GeForce RTX 3060 Off | 00000000:01:00.0 On | N/A | +| 0% 50C P8 21W / 170W | 224MiB / 12288MiB | 31% Default | +| | | N/A | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| 0 N/A N/A 985 G /usr/lib/xorg/Xorg 99MiB | +| 0 N/A N/A 10553 C python3 104MiB | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=/home/hz/Desktop/parameter-golf-main/data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:1 +val_loader:shards pattern=/home/hz/Desktop/parameter-golf-main/data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:1048576 +val_token_limit:1048576 +model_params:8507464 +world_size:1 grad_accum_steps:8 train_microbatch_tokens:2048 effective_train_batch_tokens:16384 +enable_compile:False +sdp_backends:cudnn=False flash=True mem_efficient=False math=False +attention_mode:gqa num_heads:8 num_kv_heads:4 +quant_scheme:int5_odd compressor:zlib level:9 structured_mlp:True structured_linear_type:btt_mlp btt_rank:32 btt_init:mup compile_structured_mlp:False structured_chunk_ranks:64 structured_eval_chunk_ranks:4 structured_compile_mode:reduce-overhead +rate_lambda:2e-05 scale_lambda:0.0002 rate_warmup_start_fraction:0.1 rate_full_fraction:0.4 rate_target_tensors:mlp_only rate_schedule_mode:auto max_expected_steps:0 +tie_embeddings:True embed_lr:0.03 head_lr:0.0 matrix_lr:0.02 scalar_lr:0.02 
+train_batch_tokens_target:16384 effective_train_batch_tokens:16384 train_seq_len:256 iterations:40 warmup_steps:1 max_wallclock_seconds:600.000 +lr_scale_method:none lr_reference_batch_tokens:16384 batch_multiplier:1.0000 warmup_scale_method:none warmup_reference_batch_tokens:16384 +seed:1337 +dry_run_init_only: exiting after init/logging diff --git a/records/track_non_record_16mb/2026-03-31_EntropyAware_Int5Odd_BTTMLP_3060/logs/patch_sanity.txt b/records/track_non_record_16mb/2026-03-31_EntropyAware_Int5Odd_BTTMLP_3060/logs/patch_sanity.txt new file mode 100644 index 000000000..290b0cbc8 --- /dev/null +++ b/records/track_non_record_16mb/2026-03-31_EntropyAware_Int5Odd_BTTMLP_3060/logs/patch_sanity.txt @@ -0,0 +1,5579 @@ +""" +The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good launching-off points for new participants, not SOTA configs. We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder. + +Hard stop: To keep things readable for newcomers, let's make sure `train_gpt.py` and `train_gpt_mlx.py` are never longer than 1500 lines.
+""" + +from __future__ import annotations + +import copy +import glob +import io +import math +import os +import random +import subprocess +import sys +import time +import uuid +import zlib +from pathlib import Path + +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP + +try: + import zstandard as zstd + + _HAS_ZSTD = True +except ImportError: + zstd = None + _HAS_ZSTD = False + +REPO_ROOT = Path(__file__).resolve().parents[3] + +# ----------------------------- +# HYPERPARAMETERS +# ----------------------------- +# Default Simple Baseline run: +# - 9 transformer blocks at width 512 +# - 8 attention heads with 4 KV heads (GQA) and 2x MLP expansion +# - vocab size 1024, sequence length 1024, tied embeddings +# - 524,288 train tokens per step for 20,000 iterations with a ~10 minute cap + +class Hyperparameters: + # Data paths are shard globs produced by the existing preprocessing pipeline. + data_path = os.environ.get("DATA_PATH", str(REPO_ROOT / "data/datasets/fineweb10B_sp1024")) + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", str(REPO_ROOT / "data/tokenizers/fineweb_1024_bpe.model")) + run_id = os.environ.get("RUN_ID", "train") + seed = int(os.environ.get("SEED", 1337)) + + # Validation cadence and batch size. Validation always uses the full fineweb_val split. + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 32_768)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 10)) + val_token_limit = int(os.environ.get("VAL_TOKEN_LIMIT", 1_048_576)) + + # Training length. 
+ iterations = int(os.environ.get("ITERATIONS", 40)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 1200)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 1)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 16_384)) + train_microbatch_tokens = int(os.environ.get("TRAIN_MICROBATCH_TOKENS", 0)) + grad_accum_steps = int(os.environ.get("GRAD_ACCUM_STEPS", 0)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 256)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + enable_compile = bool(int(os.environ.get("ENABLE_COMPILE", "0"))) + + # Model shape. + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers = int(os.environ.get("NUM_LAYERS", 9)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = int(os.environ.get("MLP_MULT", 2)) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + + # Optimizer hyperparameters. 
+ embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.03)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.02)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.02)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.0)) + lr_scale_method = os.environ.get("LR_SCALE_METHOD", "none") + lr_reference_batch_tokens = int(os.environ.get("LR_REFERENCE_BATCH_TOKENS", 16_384)) + warmup_scale_method = os.environ.get("WARMUP_SCALE_METHOD", "none") + warmup_reference_batch_tokens = int(os.environ.get("WARMUP_REFERENCE_BATCH_TOKENS", 16_384)) + simulated_world_size = int(os.environ.get("SIMULATED_WORLD_SIZE", 0)) + simulated_rank = int(os.environ.get("SIMULATED_RANK", 0)) + dry_run_init_only = bool(int(os.environ.get("DRY_RUN_INIT_ONLY", "0"))) + debug_data_sharding_steps = int(os.environ.get("DEBUG_DATA_SHARDING_STEPS", 0)) + + # Export / compression. + quant_scheme = os.environ.get("QUANT_SCHEME", "int5_odd") + compressor = os.environ.get("COMPRESSOR", "zlib") + compressor_level = int(os.environ.get("COMPRESSOR_LEVEL", 9)) + + # Entropy-aware rate regularization. 
+ rate_lambda = float(os.environ.get("RATE_LAMBDA", 0.00002)) + scale_lambda = float(os.environ.get("SCALE_LAMBDA", 0.0002)) + rate_temperature = float(os.environ.get("RATE_TEMPERATURE", 6.0)) + rate_warmup_start_fraction = float(os.environ.get("RATE_WARMUP_START_FRACTION", 0.1)) + rate_full_fraction = float(os.environ.get("RATE_FULL_FRACTION", 0.4)) + rate_target_tensors = os.environ.get("RATE_TARGET_TENSORS", "mlp_only") + rate_schedule_mode = os.environ.get("RATE_SCHEDULE_MODE", "auto") + max_expected_steps = int(os.environ.get("MAX_EXPECTED_STEPS", 0)) + + # Structured MLP. + structured_mlp = bool(int(os.environ.get("STRUCTURED_MLP", "1"))) + structured_linear_type = os.environ.get("STRUCTURED_LINEAR_TYPE", "btt_mlp") + btt_rank = int(os.environ.get("BTT_RANK", 32)) + btt_in_shape = os.environ.get("BTT_IN_SHAPE", "") + btt_out_shape = os.environ.get("BTT_OUT_SHAPE", "") + btt_init = os.environ.get("BTT_INIT", "mup") + btt_init_gain = float(os.environ.get("BTT_INIT_GAIN", 1.0)) + compile_structured_mlp = bool(int(os.environ.get("COMPILE_STRUCTURED_MLP", "1"))) + structured_chunk_ranks = int(os.environ.get("STRUCTURED_CHUNK_RANKS", 64)) + structured_compile_mode = os.environ.get("STRUCTURED_COMPILE_MODE", "reduce-overhead") + +# ----------------------------- +# MUON OPTIMIZER +# ----------------------------- +# +# As borrowed from modded-nanogpt +# Background on Muon: https://kellerjordan.github.io/posts/muon/ + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + # Orthogonalize a 2D update matrix with a fast Newton-Schulz iteration. + # Muon uses this to normalize matrix-shaped gradients before applying them. 
+ a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X + + +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True): + super().__init__( + params, + dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov), + ) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + + for group in self.param_groups: + params = group["params"] + if not params: + continue + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + + total_params = sum(int(p.numel()) for p in params) + updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) + + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: + g = g.add(buf, alpha=momentum) + g = zeropower_via_newtonschulz5(g, steps=backend_steps) + # Scale correction from Muon reference implementations. 
+                    g *= max(1, g.size(0) / g.size(1)) ** 0.5
+                    updates_flat[curr : curr + p.numel()] = g.reshape(-1)
+                curr += p.numel()
+
+            if distributed:
+                dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM)
+
+            curr = 0
+            for p in params:
+                g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype)
+                p.add_(g, alpha=-lr)
+                curr += p.numel()
+
+        return loss
+
+
+# -----------------------------
+# TOKENIZER-AGNOSTIC EVALUATION SETUP
+# -----------------------------
+#
+# It's common for small models to have a large fraction of their parameters in embeddings, since the 2 * d_model * d_vocab embedding vectors can be gigantic.
+# Instead of locking in the tokenizer, we let you bring your own and calculate our validation metrics from the average compression of the validation set.
+# We calculate BPB (bits-per-byte) instead of validation loss, so we need methods to count the number of bytes each token covers under the tokenizer.
+# Note: Submissions that edit the tokenizer will be examined more carefully, since screwing this up might unjustly improve your score.
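The leading-space byte accounting can be sketched standalone. The toy vocabulary and boundary set below are made up for illustration (the real lookup tables are built from the sentencepiece model): each non-control piece contributes its UTF-8 byte length, plus one extra byte for a leading `▁` space whenever the previous token is not a boundary token.

```python
# Toy illustration of the BPB byte accounting (hypothetical vocab, not the
# real sentencepiece model).
TOY_VOCAB = {0: "<s>", 1: "▁hello", 2: "world", 3: "▁there"}
BOUNDARY_IDS = {0}  # control/boundary tokens contribute no leading-space byte


def token_bytes(prev_id: int, tgt_id: int) -> int:
    piece = TOY_VOCAB[tgt_id]
    has_leading_space = piece.startswith("▁")
    if has_leading_space:
        piece = piece[1:]
    n = len(piece.encode("utf-8"))
    # A leading "▁" becomes a real space byte only mid-text, i.e. when the
    # previous token is not a boundary token.
    if has_leading_space and prev_id not in BOUNDARY_IDS:
        n += 1
    return n


assert token_bytes(0, 1) == 5  # "hello" right after <s>: no space byte
assert token_bytes(1, 3) == 6  # " there" mid-text: 5 bytes + 1 space byte
```

This is the same rule the `base_bytes` / `has_leading_space` / `is_boundary_token` lookup tables vectorize over whole batches during validation.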
+ +def build_sentencepiece_luts( + sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device +) -> tuple[Tensor, Tensor, Tensor]: + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("▁"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) + + +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*. 
+ tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] + + +def eval_val( + args: Hyperparameters, + model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + grad_accum_steps: int, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, +) -> tuple[float, float]: + # Validation computes two metrics: + # - val_loss: token cross-entropy (natural log) + # - val_bpb: tokenizer-agnostic compression metric used by the challenge + local_batch_tokens = args.val_batch_size // world_size + if local_batch_tokens < args.train_seq_len: + raise ValueError( + "VAL_BATCH_SIZE must provide at least one sequence per rank; " + f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " + f"TRAIN_SEQ_LEN={args.train_seq_len}" + ) + local_batch_seqs = local_batch_tokens // args.train_seq_len + total_seqs = (val_tokens.numel() - 1) // args.train_seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + + model.eval() + with torch.inference_mode(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * args.train_seq_len + raw_end = batch_seq_end * args.train_seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, args.train_seq_len) + y = local[1:].reshape(-1, args.train_seq_len) + maybe_mark_compile_step_begin(args.enable_compile or 
args.compile_structured_mlp)
+                with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True):
+                    batch_loss = model(x, y).detach()
+                batch_token_count = float(y.numel())
+                val_loss_sum += batch_loss.to(torch.float64) * batch_token_count
+                val_token_count += batch_token_count
+                prev_ids = x.reshape(-1)
+                tgt_ids = y.reshape(-1)
+                token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16)
+                token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16)
+                val_byte_count += token_bytes.to(torch.float64).sum()
+
+    if dist.is_available() and dist.is_initialized():
+        dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM)
+        dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM)
+        dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM)
+
+    val_loss = val_loss_sum / val_token_count
+    bits_per_token = val_loss.item() / math.log(2.0)
+    tokens_per_byte = val_token_count.item() / val_byte_count.item()
+    model.train()
+    return float(val_loss.item()), float(bits_per_token * tokens_per_byte)
+
+# -----------------------------
+# POST-TRAINING QUANTIZATION
+# -----------------------------
+#
+# It's silly to export our model, which is trained in bf16 and fp32, at that same precision.
+# Instead, we get approximately the same model (with a small quality hit) by quantizing it to int8 and zlib-compressing the result.
+# We can then decompress the model and run it at higher precision for evaluation, having come in under the size limit.
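The export idea can be sanity-checked in isolation. The sketch below uses plain per-row absmax scales rather than the percentile-clipped scales the real exporter computes, so it is a simplified model of the pipeline, not the exact code path: quantize to int8 per row, compress the codes, and check that dequantization error is bounded by half a quantization step.

```python
import zlib

import numpy as np

rng = np.random.default_rng(0)
w = rng.standard_normal((64, 64)).astype(np.float32)

# Per-row absmax int8 quantization (simplified: no percentile clipping).
scale = np.abs(w).max(axis=1, keepdims=True) / 127.0
q = np.clip(np.round(w / scale), -127, 127).astype(np.int8)

# Export = int8 codes (plus one scale per row), zlib-compressed.
payload = zlib.compress(q.tobytes(), level=9)

# Decode reverses it: codes * row scale, broadcast across columns.
deq = q.astype(np.float32) * scale

max_err = float(np.abs(w - deq).max())
assert max_err <= float(scale.max()) / 2 + 1e-6  # rounding error <= half a step
assert len(payload) < w.nbytes                   # far smaller than raw fp32 bytes
```

The half-step error bound holds because absmax scaling means no value is clipped; with percentile clipping the outliers can exceed it, which is the small quality hit the comment above refers to.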
+ +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights", + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 +INT5_ODD_LEVELS = (-2, -1, 0, 1, 2) + +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) + + +def maybe_mark_compile_step_begin(enabled: bool) -> None: + if not enabled: + return + mark_step = getattr(torch.compiler, "cudagraph_mark_step_begin", None) + if mark_step is not None: + mark_step() + + +def scale_from_batch_multiplier(value: float, multiplier: float, method: str) -> float: + if method == "none": + return value + if method == "sqrt": + return value * math.sqrt(multiplier) + if method == "linear": + return value * multiplier + raise ValueError(f"Unsupported scaling method: {method}") + + +def scaled_warmup_steps(base_steps: int, multiplier: float, method: str) -> int: + if base_steps <= 0: + return 0 + scaled = scale_from_batch_multiplier(float(base_steps), multiplier, method) + return max(int(math.ceil(scaled)), 1) + + +def parse_factor_shape(spec: str, dim: int) -> tuple[int, int]: + if spec: + vals = tuple(int(v.strip()) for v in spec.split(",") if v.strip()) + if len(vals) != 2 or vals[0] * vals[1] != dim: + raise ValueError(f"Invalid factor shape '{spec}' for dim={dim}") + return vals + for a in range(int(math.isqrt(dim)), 0, -1): + if dim % a == 0: + return a, dim // a + return 1, dim + +def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> 
Tensor: + if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): + return t.float().contiguous() + if t.dtype in {torch.float32, torch.bfloat16}: + passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") + return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() + return t + +def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + # Matrices get one scale per row, which usually tracks output-channel + # ranges much better than a single tensor-wide scale. + clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) + q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() + return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + + # Vectors / scalars use a simpler per-tensor scale. 
+ clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() + return q, scale + + +def pack_base5_codes(q: Tensor) -> tuple[bytes, int]: + vals = q.detach().to("cpu", dtype=torch.uint8).reshape(-1).numpy() + n_vals = int(vals.size) + pad = (3 - (n_vals % 3)) % 3 + if pad: + vals = np.pad(vals, (0, pad)) + groups = vals.reshape(-1, 3).astype(np.uint8) + packed = groups[:, 0] + 5 * groups[:, 1] + 25 * groups[:, 2] + return packed.tobytes(), n_vals + + +def unpack_base5_codes(data: bytes, n_vals: int) -> Tensor: + packed = np.frombuffer(data, dtype=np.uint8).astype(np.int16) + out = np.zeros((packed.size, 3), dtype=np.uint8) + for i in range(3): + out[:, i] = packed % 5 + packed //= 5 + flat = out.reshape(-1)[:n_vals] + return torch.from_numpy(flat.astype(np.uint8, copy=False)) + + +def quantize_float_tensor_int5_odd(t: Tensor) -> tuple[Tensor, Tensor]: + if t.ndim != 2: + raise ValueError("int5_odd quantization only supports 2D tensors") + t32 = t.float() + clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clip_abs = clip_abs.clamp_min(1e-6) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 2.0).clamp_min(1.0 / 2.0e6) + q = torch.clamp(torch.round(clipped / scale[:, None]), -2, 2).to(torch.int8).contiguous() + return (q + 2).to(torch.uint8).contiguous(), scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + +def quantize_state_dict_int8(state_dict: dict[str, Tensor]): + # Single supported clean-script export format: + # - per-row int8 for 2D float tensors + # - per-tensor int8 for other float tensors + # - exact passthrough for non-floats + # - passthrough for small float tensors, 
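A standalone round-trip of the base-5 packing used by the int5_odd export (mirroring `pack_base5_codes` / `unpack_base5_codes` above): three codes in `{0..4}` fit in one byte because `c0 + 5*c1 + 25*c2 <= 124`, giving ~2.67 bits per weight before zlib.

```python
import numpy as np


def pack_base5(vals: np.ndarray) -> tuple[bytes, int]:
    # Pad to a multiple of 3, then combine each triple into one byte.
    n = int(vals.size)
    pad = (3 - n % 3) % 3
    v = np.pad(vals.astype(np.uint8), (0, pad)).reshape(-1, 3)
    return (v[:, 0] + 5 * v[:, 1] + 25 * v[:, 2]).astype(np.uint8).tobytes(), n


def unpack_base5(data: bytes, n: int) -> np.ndarray:
    # Peel off base-5 digits least-significant first, then drop the padding.
    packed = np.frombuffer(data, dtype=np.uint8).astype(np.int16)
    out = np.zeros((packed.size, 3), dtype=np.uint8)
    for i in range(3):
        out[:, i] = packed % 5
        packed //= 5
    return out.reshape(-1)[:n]


codes = np.array([0, 1, 2, 3, 4, 2, 1], dtype=np.uint8)
data, n = pack_base5(codes)
assert n == 7 and len(data) == 3  # ceil(7/3) bytes for 7 codes
assert np.array_equal(unpack_base5(data, n), codes)
```

The quantizer shifts levels `{-2..2}` to `{0..4}` before packing, which is why the decoder subtracts 2 after unpacking.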
stored as fp16 to save bytes + quantized: dict[str, Tensor] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + qmeta: dict[str, dict[str, object]] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + + # Small float tensors are cheap enough to keep directly. We still downcast + # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size. + if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + + stats["num_float_tensors"] += 1 + q, s = quantize_float_tensor(t) + if s.ndim > 0: + qmeta[name] = {"scheme": "per_row", "axis": 0} + quantized[name] = q + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) + + obj: dict[str, object] = { + "__quant_format__": "int8_clean_per_row_v1", + "quantized": quantized, + "scales": scales, + "dtypes": dtypes, + "passthrough": passthrough, + } + if qmeta: + obj["qmeta"] = qmeta + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats + + +def quantize_state_dict_int5_odd(state_dict: dict[str, Tensor]): + packed: dict[str, bytes] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + shapes: dict[str, tuple[int, ...]] = {} + orig_shapes: dict[str, 
tuple[int, ...]] = {} + n_codes: dict[str, int] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + + quantize_this = t.ndim >= 2 and (t.numel() > INT8_KEEP_FLOAT_MAX_NUMEL or ".mlp." in name) + if not quantize_this: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + + stats["num_float_tensors"] += 1 + t2 = t.reshape(t.shape[0], -1) if t.ndim > 2 else t + q, s = quantize_float_tensor_int5_odd(t2) + packed_bytes, count = pack_base5_codes(q) + packed[name] = packed_bytes + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + shapes[name] = tuple(int(v) for v in t2.shape) + orig_shapes[name] = tuple(int(v) for v in t.shape) + n_codes[name] = count + stats["int8_payload_bytes"] += len(packed_bytes) + tensor_nbytes(s) + + obj: dict[str, object] = { + "__quant_format__": "int5_odd_clean_per_row_v1", + "packed": packed, + "scales": scales, + "dtypes": dtypes, + "shapes": shapes, + "orig_shapes": orig_shapes, + "n_codes": n_codes, + "passthrough": passthrough, + } + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats + +def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + qmeta = obj.get("qmeta", {}) + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, q in 
obj["quantized"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + s = obj["scales"][name] + if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: + s = s.to(dtype=torch.float32) + # Broadcast the saved row scale back across trailing dimensions. + out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() + else: + scale = float(s.item()) + out[name] = (q.float() * scale).to(dtype=dtype).contiguous() + for name, t in obj["passthrough"].items(): + # Restore small tensors, undoing the temporary fp16 storage cast if needed. + out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out + + +def dequantize_state_dict_int5_odd(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, packed in obj["packed"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + shape = tuple(int(v) for v in obj["shapes"][name]) + orig_shape = tuple(int(v) for v in obj.get("orig_shapes", {}).get(name, shape)) + q = unpack_base5_codes(packed, int(obj["n_codes"][name])).to(torch.float32).reshape(shape) - 2.0 + s = obj["scales"][name].to(dtype=torch.float32) + deq = (q * s.view(shape[0], *([1] * (len(shape) - 1)))).to(dtype=dtype) + out[name] = deq.reshape(orig_shape).contiguous() + for name, t in obj["passthrough"].items(): + out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out + + +def quantize_state_dict(state_dict: dict[str, Tensor], quant_scheme: str): + if quant_scheme == "int8": + return quantize_state_dict_int8(state_dict) + if quant_scheme == "int5_odd": + return 
quantize_state_dict_int5_odd(state_dict) + raise ValueError(f"Unsupported QUANT_SCHEME={quant_scheme}") + + +def dequantize_state_dict(obj: dict[str, object]) -> dict[str, Tensor]: + quant_format = obj.get("__quant_format__") + if quant_format == "int8_clean_per_row_v1": + return dequantize_state_dict_int8(obj) + if quant_format == "int5_odd_clean_per_row_v1": + return dequantize_state_dict_int5_odd(obj) + raise ValueError(f"Unsupported quant format {quant_format}") + + +def compress_blob(data: bytes, compressor: str, level: int) -> bytes: + if compressor == "zlib": + return zlib.compress(data, level=max(0, min(level, 9))) + if compressor == "zstd": + if not _HAS_ZSTD: + raise RuntimeError("COMPRESSOR=zstd requires zstandard to be installed") + return zstd.ZstdCompressor(level=level).compress(data) + raise ValueError(f"Unsupported COMPRESSOR={compressor}") + + +def decompress_blob(data: bytes, compressor: str) -> bytes: + if compressor == "zlib": + return zlib.decompress(data) + if compressor == "zstd": + if not _HAS_ZSTD: + raise RuntimeError("COMPRESSOR=zstd requires zstandard to be installed") + return zstd.ZstdDecompressor().decompress(data) + raise ValueError(f"Unsupported COMPRESSOR={compressor}") + + +def should_rate_regularize_tensor(name: str, p: Tensor, target: str) -> bool: + if not p.requires_grad or not p.is_floating_point() or p.ndim < 2: + return False + if any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS): + return False + if target == "mlp_only": + return ".mlp." in name + if target == "all_matrices": + return p.numel() > INT8_KEEP_FLOAT_MAX_NUMEL or ".mlp." 
in name
+    raise ValueError(f"Unsupported RATE_TARGET_TENSORS={target}")
+
+
+def estimate_tensor_rate_proxy(t: Tensor, temperature: float) -> tuple[Tensor, Tensor, Tensor]:
+    t2 = t.float()
+    if t2.ndim > 2:
+        t2 = t2.reshape(t2.shape[0], -1)
+    elif t2.ndim == 1:
+        t2 = t2.reshape(1, -1)
+    row_scale = t2.abs().mean(dim=1, keepdim=True).clamp_min(1e-6)
+    normalized = torch.clamp(t2 / row_scale, -2.5, 2.5)
+    centers = t2.new_tensor(INT5_ODD_LEVELS)
+    logits = -temperature * (normalized.unsqueeze(-1) - centers) ** 2
+    probs = torch.softmax(logits, dim=-1)
+    hist = probs.mean(dim=(0, 1))
+    entropy_nats = -(hist * (hist + 1e-8).log()).sum()
+    entropy_bits = entropy_nats / math.log(2.0)
+    scale_proxy = row_scale.mean()
+    est_bits = entropy_bits * float(t2.numel())
+    return entropy_bits, scale_proxy, est_bits
+
+
+# -----------------------------
+# DATA LOADING
+# -----------------------------
+
+def load_data_shard(file: Path) -> Tensor:
+    # Shards carry a 256-entry int32 header (token count assumed at index 2,
+    # following the llm.c-style fineweb .bin layout) followed by uint16 tokens.
+    header_bytes = 256 * np.dtype("<i4").itemsize
+    with file.open("rb") as f:
+        header = np.frombuffer(f.read(header_bytes), dtype="<i4")
+        num_tokens = int(header[2])
+        tokens = np.frombuffer(f.read(num_tokens * 2), dtype="<u2")
+    return torch.from_numpy(tokens.astype(np.int32, copy=False))
+
+
+class TokenStream:
+    # Streams tokens from a sorted list of shard files, wrapping around at the end.
+    def __init__(self, pattern: str):
+        self.files = [Path(p) for p in sorted(glob.glob(pattern))]
+        if not self.files:
+            raise FileNotFoundError(f"No files found for pattern: {pattern}")
+        self.file_idx = 0
+        self.tokens = load_data_shard(self.files[self.file_idx])
+        self.pos = 0
+
+    def _advance_file(self) -> None:
+        self.file_idx = (self.file_idx + 1) % len(self.files)
+        self.tokens = load_data_shard(self.files[self.file_idx])
+        self.pos = 0
+
+    def take(self, n: int) -> Tensor:
+        chunks: list[Tensor] = []
+        remaining = n
+        while remaining > 0:
+            avail = self.tokens.numel() - self.pos
+            if avail <= 0:
+                self._advance_file()
+                continue
+            k = min(remaining, avail)
+            chunks.append(self.tokens[self.pos : self.pos + k])
+            self.pos += k
+            remaining -= k
+        return chunks[0] if len(chunks) == 1 else torch.cat(chunks)
+
+
+class DistributedTokenLoader:
+    # Each call consumes a contiguous chunk from the shared token stream, then slices out
+    # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting.
+ def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank = rank + self.world_size = world_size + self.device = device + self.stream = TokenStream(pattern) + + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + chunk = self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start : start + per_rank_span].to(dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + + +def debug_log_data_sharding( + pattern: str, + world_size: int, + global_tokens: int, + seq_len: int, + grad_accum_steps: int, + steps: int, + log0, +) -> None: + stream = TokenStream(pattern) + local_tokens = global_tokens // (world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + global_cursor = 0 + log0( + f"debug_sharding:world_size:{world_size} global_tokens:{global_tokens} " + f"grad_accum_steps:{grad_accum_steps} local_tokens_per_rank:{local_tokens} per_rank_span:{per_rank_span}" + ) + for step_idx in range(steps): + chunk = stream.take(per_rank_span * world_size) + chunk_start = global_cursor + chunk_end = global_cursor + chunk.numel() + log0(f"debug_sharding_step:{step_idx} shared_chunk_token_range:[{chunk_start},{chunk_end})") + for shard_rank in range(world_size): + start = shard_rank * per_rank_span + end = start + per_rank_span + local = chunk[start:end] + preview = ",".join(str(int(t)) for t in local[: min(6, local.numel())].tolist()) + log0( + f"debug_shard rank:{shard_rank} token_range:[{chunk_start + start},{chunk_start + end}) " + f"preview:{preview}" + ) + global_cursor = chunk_end + +# ----------------------------- +# TRANSFORMER MODULES +# ----------------------------- + +class RMSNorm(nn.Module): + def 
__init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + + +class CastedLinear(nn.Linear): + # Keep weights in fp32 for optimizer/state quality, cast at matmul time for bf16 compute. + def forward(self, x: Tensor) -> Tensor: + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, self.weight.to(x.dtype), bias) + + +class StructuredLinear(nn.Module): + # Lightweight TT-style structured matrix used only in the MLP stack. + def __init__( + self, + in_features: int, + out_features: int, + bias: bool, + linear_type: str, + btt_rank: int, + in_shape_spec: str, + out_shape_spec: str, + btt_init: str, + btt_init_gain: float, + chunk_ranks: int, + ): + super().__init__() + self.in_features = in_features + self.out_features = out_features + self.linear_type = linear_type + self.rank = max(int(btt_rank), 1) + self.chunk_ranks = max(int(chunk_ranks), 1) + self.btt_init = btt_init + self.btt_init_gain = btt_init_gain + if linear_type == "dense": + self.impl = CastedLinear(in_features, out_features, bias=bias) + self.register_parameter("core1", None) + self.register_parameter("core2", None) + return + if linear_type != "btt_mlp": + raise ValueError(f"Unsupported STRUCTURED_LINEAR_TYPE={linear_type}") + self.impl = None + self.in_shape = parse_factor_shape(in_shape_spec, in_features) + self.out_shape = parse_factor_shape(out_shape_spec, out_features) + n1, n2 = self.in_shape + m1, m2 = self.out_shape + self.core1 = nn.Parameter(torch.empty(n1, m1, self.rank)) + self.core2 = nn.Parameter(torch.empty(self.rank, n2, m2)) + if bias: + self.bias = nn.Parameter(torch.zeros(out_features)) + else: + self.register_parameter("bias", None) + self.reset_parameters() + + def reset_parameters(self) -> None: + if self.impl is not None: + self.impl.reset_parameters() + return + n1, n2 = self.in_shape + if self.btt_init == "mup": + # muP-inspired 
scaling: keep the per-rank branch update stable while + # preserving O(1) output variance under the sum over TT ranks. + core1_std = self.btt_init_gain / math.sqrt(max(n1, 1)) + core2_std = self.btt_init_gain / math.sqrt(max(self.rank * n2, 1)) + elif self.btt_init == "xavier": + std = self.btt_init_gain / math.sqrt(max(self.in_features, 1)) + core1_std = std / math.sqrt(self.rank) + core2_std = std + else: + raise ValueError(f"Unsupported BTT_INIT={self.btt_init}") + nn.init.normal_(self.core1, mean=0.0, std=core1_std) + nn.init.normal_(self.core2, mean=0.0, std=core2_std) + if self.bias is not None: + nn.init.zeros_(self.bias) + + def forward(self, x: Tensor) -> Tensor: + if self.impl is not None: + return self.impl(x) + x_dtype = x.dtype + orig_shape = x.shape[:-1] + x2 = x.reshape(-1, self.in_shape[0], self.in_shape[1]) + x2_t = x2.transpose(1, 2).contiguous() + core1 = self.core1.to(dtype=x_dtype).permute(2, 0, 1).contiguous() + core2 = self.core2.to(dtype=x_dtype).contiguous() + y = x2.new_zeros((x2.shape[0], self.out_shape[0], self.out_shape[1])) + for start in range(0, self.rank, self.chunk_ranks): + end = min(start + self.chunk_ranks, self.rank) + left = torch.matmul(x2_t.unsqueeze(1), core1[start:end].unsqueeze(0)) + y = y + torch.matmul(left.transpose(-1, -2), core2[start:end].unsqueeze(0)).sum(dim=1) + y = y.reshape(*orig_shape, self.out_features) + if self.bias is not None: + y = y + self.bias.to(dtype=y.dtype) + return y + + +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + # Keep small/control parameters in fp32 even when the model body runs in bf16. + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() + + +class Rotary(nn.Module): + # Caches cos/sin tables per sequence length on the current device. 
+ def __init__(self, dim: int, base: float = 10000.0): + super().__init__() + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached != seq_len + or self._cos_cached.device != device + ): + t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) + freqs = torch.outer(t, self.inv_freq.to(device)) + self._cos_cached = freqs.cos()[None, None, :, :] + self._sin_cached = freqs.sin()[None, None, :, :] + self._seq_len_cached = seq_len + cos = self._cos_cached.to(dtype=dtype) + sin = self._sin_cached.to(dtype=dtype) + if not torch.is_inference_mode_enabled(): + cos = cos.clone() + sin = sin.clone() + return cos, sin + + +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + + +class CausalSelfAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + kv_dim = self.num_kv_heads * self.head_dim + self.c_q = CastedLinear(dim, dim, bias=False) + self.c_k = CastedLinear(dim, kv_dim, bias=False) + self.c_v = CastedLinear(dim, kv_dim, bias=False) + self.proj 
= CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rotary = Rotary(self.head_dim, base=rope_base) + + def forward(self, x: Tensor) -> Tensor: + bsz, seqlen, dim = x.shape + q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) + k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin) + k = apply_rotary_emb(k, cos, sin) + q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] + y = F.scaled_dot_product_attention( + q, + k, + v, + attn_mask=None, + is_causal=True, + enable_gqa=(self.num_kv_heads != self.num_heads), + ) + y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) + return self.proj(y) + + +class MLP(nn.Module): + # relu^2 MLP from the original modded-nanogpt setup + def __init__( + self, + dim: int, + mlp_mult: int, + structured_mlp: bool, + structured_linear_type: str, + btt_rank: int, + btt_in_shape: str, + btt_out_shape: str, + btt_init: str, + btt_init_gain: float, + compile_structured_mlp: bool, + structured_chunk_ranks: int, + structured_compile_mode: str, + ): + super().__init__() + hidden = mlp_mult * dim + if structured_mlp: + self.fc = StructuredLinear( + dim, + hidden, + bias=False, + linear_type=structured_linear_type, + btt_rank=btt_rank, + in_shape_spec=btt_in_shape, + out_shape_spec=btt_out_shape, + btt_init=btt_init, + btt_init_gain=btt_init_gain, + chunk_ranks=structured_chunk_ranks, + ) + self.proj = StructuredLinear( + hidden, + dim, + bias=False, + linear_type=structured_linear_type, + btt_rank=btt_rank, + in_shape_spec=btt_out_shape, + out_shape_spec=btt_in_shape, + btt_init=btt_init, + 
btt_init_gain=btt_init_gain, + chunk_ranks=structured_chunk_ranks, + ) + if compile_structured_mlp: + compile_options = dict(torch._inductor.list_mode_options(structured_compile_mode)) + compile_options["triton.cudagraphs"] = False + self.fc = torch.compile( + self.fc, + mode=structured_compile_mode, + options=compile_options, + fullgraph=False, + dynamic=False, + ) + self.proj = torch.compile( + self.proj, + mode=structured_compile_mode, + options=compile_options, + fullgraph=False, + dynamic=False, + ) + else: + self.fc = CastedLinear(dim, hidden, bias=False) + self.proj = CastedLinear(hidden, dim, bias=False) + self.proj._zero_init = True + + def forward(self, x: Tensor) -> Tensor: + x = torch.relu(self.fc(x)) + return self.proj(x.square()) + + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + rope_base: float, + qk_gain_init: float, + structured_mlp: bool, + structured_linear_type: str, + btt_rank: int, + btt_in_shape: str, + btt_out_shape: str, + btt_init: str, + btt_init_gain: float, + compile_structured_mlp: bool, + structured_chunk_ranks: int, + structured_compile_mode: str, + ): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) + self.mlp = MLP( + dim, + mlp_mult, + structured_mlp, + structured_linear_type, + btt_rank, + btt_in_shape, + btt_out_shape, + btt_init, + btt_init_gain, + compile_structured_mlp, + structured_chunk_ranks, + structured_compile_mode, + ) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + + def forward(self, x: Tensor, x0: Tensor) -> Tensor: + mix = self.resid_mix.to(dtype=x.dtype) + x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + attn_out = 
self.attn(self.attn_norm(x)) + x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out + x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) + return x + + +class GPT(nn.Module): + def __init__( + self, + vocab_size: int, + num_layers: int, + model_dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + tie_embeddings: bool, + tied_embed_init_std: float, + logit_softcap: float, + rope_base: float, + qk_gain_init: float, + structured_mlp: bool, + structured_linear_type: str, + btt_rank: int, + btt_in_shape: str, + btt_out_shape: str, + btt_init: str, + btt_init_gain: float, + compile_structured_mlp: bool, + structured_chunk_ranks: int, + structured_compile_mode: str, + rate_target_tensors: str, + rate_temperature: float, + ): + super().__init__() + if logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") + self.tie_embeddings = tie_embeddings + self.tied_embed_init_std = tied_embed_init_std + self.logit_softcap = logit_softcap + self.rate_target_tensors = rate_target_tensors + self.rate_temperature = rate_temperature + self.tok_emb = nn.Embedding(vocab_size, model_dim) + self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + self.blocks = nn.ModuleList( + [ + Block( + model_dim, + num_heads, + num_kv_heads, + mlp_mult, + rope_base, + qk_gain_init, + structured_mlp, + structured_linear_type, + btt_rank, + btt_in_shape, + btt_out_shape, + btt_init, + btt_init_gain, + compile_structured_mlp, + structured_chunk_ranks, + structured_compile_mode, + ) + for i in range(num_layers) + ] + ) + self.final_norm = RMSNorm() + self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) + if self.lm_head is not None: + 
self.lm_head._zero_init = True + self._init_weights() + + def _init_weights(self) -> None: + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + for module in self.modules(): + if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + if isinstance(module, StructuredLinear) and getattr(module, "_zero_init", False): + if module.impl is not None: + nn.init.zeros_(module.impl.weight) + else: + nn.init.zeros_(module.core2) + + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + x = self.tok_emb(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x0 = x + skips: list[Tensor] = [] + + # First half stores skips; second half reuses them in reverse order. + for i in range(self.num_encoder_layers): + x = self.blocks[i](x, x0) + skips.append(x) + for i in range(self.num_decoder_layers): + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + x = self.blocks[self.num_encoder_layers + i](x, x0) + + x = self.final_norm(x).reshape(-1, x.size(-1)) + targets = target_ids.reshape(-1) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(x) + logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + return F.cross_entropy(logits.float(), targets, reduction="mean") + + def rate_regularizer(self) -> tuple[Tensor, Tensor, Tensor]: + total_symbols = 0.0 + weighted_rate = None + weighted_scale = None + estimated_bits = None + for name, param in self.named_parameters(): + if not should_rate_regularize_tensor(name, param, self.rate_target_tensors): + continue + rate_proxy, scale_proxy, tensor_bits = estimate_tensor_rate_proxy(param, self.rate_temperature) + weight = float(param.numel()) + total_symbols += weight + if weighted_rate is None: + weighted_rate = rate_proxy 
* weight + weighted_scale = scale_proxy * weight + estimated_bits = tensor_bits + else: + weighted_rate = weighted_rate + rate_proxy * weight + weighted_scale = weighted_scale + scale_proxy * weight + estimated_bits = estimated_bits + tensor_bits + if weighted_rate is None or weighted_scale is None or estimated_bits is None: + zero = self.tok_emb.weight.new_zeros(()) + return zero, zero, zero + denom = max(total_symbols, 1.0) + return weighted_rate / denom, weighted_scale / denom, estimated_bits + + +# ----------------------------- +# TRAINING +# ----------------------------- + +def main() -> None: + global zeropower_via_newtonschulz5 + + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + if args.enable_compile: + zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) + + # ----------------------------- + # DISTRIBUTED + CUDA SETUP + # ----------------------------- + + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + simulated_world_size = args.simulated_world_size if not distributed and args.simulated_world_size > 0 else 0 + if simulated_world_size > 0: + world_size = simulated_world_size + rank = args.simulated_rank + local_rank = 0 + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if rank < 0 or rank >= world_size: + raise ValueError(f"RANK={rank} must satisfy 0 <= RANK < WORLD_SIZE={world_size}") + if simulated_world_size > 1 and not args.dry_run_init_only: + raise ValueError("SIMULATED_WORLD_SIZE>1 is only supported together with DRY_RUN_INIT_ONLY=1") + if args.train_batch_tokens <= 0: + raise ValueError(f"TRAIN_BATCH_TOKENS must be positive, got {args.train_batch_tokens}") + train_microbatch_tokens = args.train_microbatch_tokens if args.train_microbatch_tokens > 0 else max( + args.train_seq_len, 
args.train_batch_tokens // 8 + ) + if train_microbatch_tokens <= 0: + raise ValueError(f"TRAIN_MICROBATCH_TOKENS must be positive, got {train_microbatch_tokens}") + if train_microbatch_tokens % args.train_seq_len != 0: + raise ValueError( + f"TRAIN_MICROBATCH_TOKENS={train_microbatch_tokens} must be divisible by TRAIN_SEQ_LEN={args.train_seq_len}" + ) + tokens_per_microstep = world_size * train_microbatch_tokens + if args.grad_accum_steps > 0: + grad_accum_steps = args.grad_accum_steps + effective_train_batch_tokens = tokens_per_microstep * grad_accum_steps + else: + if args.train_batch_tokens % tokens_per_microstep != 0: + raise ValueError( + "TRAIN_BATCH_TOKENS must be divisible by WORLD_SIZE * TRAIN_MICROBATCH_TOKENS " + f"unless GRAD_ACCUM_STEPS is set explicitly; got TRAIN_BATCH_TOKENS={args.train_batch_tokens}, " + f"WORLD_SIZE={world_size}, TRAIN_MICROBATCH_TOKENS={train_microbatch_tokens}" + ) + grad_accum_steps = args.train_batch_tokens // tokens_per_microstep + effective_train_batch_tokens = args.train_batch_tokens + if grad_accum_steps <= 0: + raise ValueError(f"GRAD_ACCUM_STEPS must be positive, got {grad_accum_steps}") + grad_scale = 1.0 / grad_accum_steps + batch_multiplier = effective_train_batch_tokens / max(args.lr_reference_batch_tokens, 1) + effective_warmup_steps = scaled_warmup_steps( + args.warmup_steps, + effective_train_batch_tokens / max(args.warmup_reference_batch_tokens, 1), + args.warmup_scale_method, + ) + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 0 + + # Fast math knobs + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + + enable_cudnn_sdp(False) + 
enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + + logfile = None + if master_process: + if args.run_id == "train": + logfile = "train.log" + else: + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" + print(logfile) + + def log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch {torch.__version__}", console=False) + log0( + subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, + console=False, + ) + log0("=" * 100, console=False) + + # ----------------------------- + # TOKENIZER + VALIDATION METRIC SETUP + # ----------------------------- + + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) + if args.val_token_limit > 0: + usable = min(((args.val_token_limit) // args.train_seq_len) * args.train_seq_len, val_tokens.numel() - 1) + if usable <= 0: + raise ValueError(f"VAL_TOKEN_LIMIT={args.val_token_limit} is too small for TRAIN_SEQ_LEN={args.train_seq_len}") + val_tokens = val_tokens[: usable + 1].contiguous() + base_bytes_lut, 
has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + if args.val_token_limit > 0: + log0(f"val_token_limit:{args.val_token_limit}") + + # ----------------------------- + # MODEL + OPTIMIZER SETUP + # ----------------------------- + + base_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, + structured_mlp=args.structured_mlp, + structured_linear_type=args.structured_linear_type if args.structured_mlp else "dense", + btt_rank=args.btt_rank, + btt_in_shape=args.btt_in_shape, + btt_out_shape=args.btt_out_shape, + btt_init=args.btt_init, + btt_init_gain=args.btt_init_gain, + compile_structured_mlp=args.compile_structured_mlp, + structured_chunk_ranks=args.structured_chunk_ranks, + structured_compile_mode=args.structured_compile_mode, + rate_target_tensors=args.rate_target_tensors, + rate_temperature=args.rate_temperature, + ).to(device).bfloat16() + for module in base_model.modules(): + if isinstance(module, (CastedLinear, StructuredLinear)): + module.float() + restore_low_dim_params_to_fp32(base_model) + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) if args.enable_compile else base_model + model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + + # Optimizer split: + # - token embedding (Adam) uses EMBED_LR + # - untied lm_head (Adam) uses 
HEAD_LR + # - matrix params in transformer blocks use MATRIX_LR via Muon + # - vectors/scalars use SCALAR_LR via Adam + block_named_params = list(base_model.blocks.named_parameters()) + matrix_params = [ + p + for name, p in block_named_params + if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + scalar_params = [ + p + for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + scaled_token_lr = scale_from_batch_multiplier(token_lr, batch_multiplier, args.lr_scale_method) + scaled_head_lr = scale_from_batch_multiplier(args.head_lr, batch_multiplier, args.lr_scale_method) + scaled_matrix_lr = scale_from_batch_multiplier(args.matrix_lr, batch_multiplier, args.lr_scale_method) + scaled_scalar_lr = scale_from_batch_multiplier(args.scalar_lr, batch_multiplier, args.lr_scale_method) + optimizer_tok = torch.optim.Adam( + [{"params": [base_model.tok_emb.weight], "lr": scaled_token_lr, "base_lr": scaled_token_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizer_muon = Muon( + matrix_params, + lr=scaled_matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + ) + for group in optimizer_muon.param_groups: + group["base_lr"] = scaled_matrix_lr + optimizer_scalar = torch.optim.Adam( + [{"params": scalar_params, "lr": scaled_scalar_lr, "base_lr": scaled_scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": scaled_head_lr, "base_lr": scaled_head_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, 
+ fused=True, + ) + optimizers.insert(1, optimizer_head) + + n_params = sum(p.numel() for p in base_model.parameters()) + log0(f"model_params:{n_params}") + log0( + f"world_size:{world_size} grad_accum_steps:{grad_accum_steps} " + f"train_microbatch_tokens:{train_microbatch_tokens} effective_train_batch_tokens:{effective_train_batch_tokens}" + ) + if simulated_world_size > 0: + log0(f"simulated_world_size:{simulated_world_size} simulated_rank:{rank}") + log0(f"enable_compile:{args.enable_compile}") + log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") + log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") + log0( + f"quant_scheme:{args.quant_scheme} compressor:{args.compressor} level:{args.compressor_level} " + f"structured_mlp:{args.structured_mlp} structured_linear_type:{args.structured_linear_type} " + f"btt_rank:{args.btt_rank} btt_init:{args.btt_init} compile_structured_mlp:{args.compile_structured_mlp} " + f"structured_chunk_ranks:{args.structured_chunk_ranks} structured_compile_mode:{args.structured_compile_mode}" + ) + log0( + f"rate_lambda:{args.rate_lambda} scale_lambda:{args.scale_lambda} " + f"rate_warmup_start_fraction:{args.rate_warmup_start_fraction} rate_full_fraction:{args.rate_full_fraction} " + f"rate_target_tensors:{args.rate_target_tensors} rate_schedule_mode:{args.rate_schedule_mode} " + f"max_expected_steps:{args.max_expected_steps}" + ) + log0( + f"tie_embeddings:{args.tie_embeddings} embed_lr:{scaled_token_lr} " + f"head_lr:{scaled_head_lr if base_model.lm_head is not None else 0.0} " + f"matrix_lr:{scaled_matrix_lr} scalar_lr:{scaled_scalar_lr}" + ) + log0( + f"train_batch_tokens_target:{args.train_batch_tokens} effective_train_batch_tokens:{effective_train_batch_tokens} " + f"train_seq_len:{args.train_seq_len} iterations:{args.iterations} warmup_steps:{effective_warmup_steps} " + f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log0( + 
f"lr_scale_method:{args.lr_scale_method} lr_reference_batch_tokens:{args.lr_reference_batch_tokens} " + f"batch_multiplier:{batch_multiplier:.4f} warmup_scale_method:{args.warmup_scale_method} " + f"warmup_reference_batch_tokens:{args.warmup_reference_batch_tokens}" + ) + log0(f"seed:{args.seed}") + + if args.debug_data_sharding_steps > 0: + debug_log_data_sharding( + args.train_files, + world_size, + effective_train_batch_tokens, + args.train_seq_len, + grad_accum_steps, + args.debug_data_sharding_steps, + log0, + ) + if args.dry_run_init_only: + log0("dry_run_init_only: exiting after init/logging") + return + + # ----------------------------- + # DATA LOADER & MODEL WARMUP + # ----------------------------- + + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + + def lr_mul(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + + def rate_progress(step: int, elapsed_ms: float) -> float: + if args.rate_schedule_mode == "steps": + if args.iterations <= 0: + return 1.0 + return step / max(args.iterations, 1) + if args.rate_schedule_mode == "expected_steps": + if args.max_expected_steps <= 0: + raise ValueError("MAX_EXPECTED_STEPS must be positive when RATE_SCHEDULE_MODE=expected_steps") + return step / max(args.max_expected_steps, 1) + if args.rate_schedule_mode == 
"wallclock": + if max_wallclock_ms is None: + raise ValueError("RATE_SCHEDULE_MODE=wallclock requires MAX_WALLCLOCK_SECONDS > 0") + return elapsed_ms / max(max_wallclock_ms, 1e-9) + if args.rate_schedule_mode != "auto": + raise ValueError(f"Unsupported RATE_SCHEDULE_MODE={args.rate_schedule_mode}") + if args.max_expected_steps > 0: + return step / max(args.max_expected_steps, 1) + if max_wallclock_ms is not None and step > 0: + step_ms = elapsed_ms / max(step, 1) + expected_steps = max(int(max_wallclock_ms / max(step_ms, 1e-9)), 1) + return step / expected_steps + if args.iterations <= 0: + return 1.0 + return step / max(args.iterations, 1) + + def rate_schedule(step: int, elapsed_ms: float) -> float: + if args.rate_lambda <= 0.0 and args.scale_lambda <= 0.0: + return 0.0 + frac = rate_progress(step, elapsed_ms) + start = args.rate_warmup_start_fraction + full = max(args.rate_full_fraction, start + 1e-6) + if frac <= start: + return 0.0 + if frac >= full: + return 1.0 + return (frac - start) / (full - start) + + # Warmup primes the compiled forward/backward/optimizer paths, then we restore the + # initial weights/optimizer state so measured training starts from the true init. 
+ if effective_warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(effective_warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(effective_train_batch_tokens, args.train_seq_len, grad_accum_steps) + maybe_mark_compile_step_begin(args.enable_compile or args.compile_structured_mlp) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + for opt in optimizers: + opt.step() + zero_grad_all() + if ( + effective_warmup_steps <= 20 + or (warmup_step + 1) % 10 == 0 + or warmup_step + 1 == effective_warmup_steps + ): + log0(f"warmup_step:{warmup_step + 1}/{effective_warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + if distributed: + model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + # ----------------------------- + # MAIN TRAINING LOOP + # ----------------------------- + + training_time_ms = 0.0 + stop_after_step: int | None = None + torch.cuda.synchronize() + t0 = time.perf_counter() + + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + 
grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + log0( + f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" + ) + break + + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + zero_grad_all() + train_loss = torch.zeros((), device=device) + train_ce_loss = torch.zeros((), device=device) + train_rate_proxy = torch.zeros((), device=device) + train_scale_proxy = torch.zeros((), device=device) + train_est_bits = torch.zeros((), device=device) + rate_mul = rate_schedule(step, elapsed_ms) + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(effective_train_batch_tokens, args.train_seq_len, grad_accum_steps) + maybe_mark_compile_step_begin(args.enable_compile or args.compile_structured_mlp) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + ce_loss = model(x, y) + total_loss = ce_loss + if rate_mul > 0.0: + rate_proxy, scale_proxy, est_bits = base_model.rate_regularizer() + total_loss = total_loss + rate_mul * (args.rate_lambda * rate_proxy + args.scale_lambda * scale_proxy) + train_rate_proxy += rate_proxy.detach() + train_scale_proxy += scale_proxy.detach() + train_est_bits += est_bits.detach() + train_ce_loss += ce_loss.detach() + train_loss += total_loss.detach() + (total_loss * grad_scale).backward() + train_loss /= grad_accum_steps + train_ce_loss /= grad_accum_steps + train_rate_proxy /= grad_accum_steps + train_scale_proxy /= grad_accum_steps + 
train_est_bits /= grad_accum_steps + + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + for opt in optimizers: + opt.step() + zero_grad_all() + + step += 1 + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + should_log_train = ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + log0( + f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} ce_loss:{train_ce_loss.item():.4f} " + f"rate_mul:{rate_mul:.3f} rate_proxy_bits:{train_rate_proxy.item():.4f} " + f"scale_proxy:{train_scale_proxy.item():.6f} est_model_bits:{train_est_bits.item():.0f} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" + ) + + # Needed to sync whether we've reached the wallclock cap. 
+ reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + + log0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + ) + + # ----------------------------- + # SERIALIZATION + ROUNDTRIP VALIDATION + # ----------------------------- + # Save the raw state (useful for debugging/loading in PyTorch directly), then always produce + # the compressed quantized artifact (QUANT_SCHEME + COMPRESSOR, int5_odd + zlib by default) and validate the round-tripped weights. + + if master_process: + torch.save(base_model.state_dict(), "final_model.pt") + model_bytes = os.path.getsize("final_model.pt") + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model: {model_bytes} bytes") + log0(f"Code size: {code_bytes} bytes") + log0(f"Total submission size: {model_bytes + code_bytes} bytes") + + quant_obj, quant_stats = quantize_state_dict(base_model.state_dict(), args.quant_scheme) + quant_buf = io.BytesIO() + torch.save(quant_obj, quant_buf) + quant_raw = quant_buf.getvalue() + quant_blob = compress_blob(quant_raw, args.compressor, args.compressor_level) + quant_raw_bytes = len(quant_raw) + quant_filename = f"final_model.{args.quant_scheme}.{args.compressor}.ptc" + if master_process: + with open(quant_filename, "wb") as f: + f.write(quant_blob) + quant_file_bytes = os.path.getsize(quant_filename) + code_bytes = len(code.encode("utf-8")) + ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1) + log0( + f"Serialized model {args.quant_scheme}+{args.compressor}: {quant_file_bytes} bytes " + f"(payload:{quant_stats['int8_payload_bytes']} raw_torch:{quant_raw_bytes}
payload_ratio:{ratio:.2f}x)" + ) + log0(f"Total submission size {args.quant_scheme}+{args.compressor}: {quant_file_bytes + code_bytes} bytes") + + if distributed: + dist.barrier() + with open(quant_filename, "rb") as f: + quant_blob_disk = f.read() + quant_state = torch.load(io.BytesIO(decompress_blob(quant_blob_disk, args.compressor)), map_location="cpu") + base_model.load_state_dict(dequantize_state_dict(quant_state), strict=True) + torch.cuda.synchronize() + t_qeval = time.perf_counter() + q_val_loss, q_val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"final_{args.quant_scheme}_{args.compressor}_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" + ) + log0(f"final_{args.quant_scheme}_{args.compressor}_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + + if distributed: + dist.destroy_process_group() + + +if __name__ == "__main__": + main() + +==================================================================================================== +Running Python 3.10.12 (main, Mar 3 2026, 11:56:32) [GCC 11.4.0] +Running PyTorch 2.11.0+cu130 +Tue Mar 31 17:05:12 2026 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 590.48.01 Driver Version: 590.48.01 CUDA Version: 13.1 | ++-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. 
| +|=========================================+========================+======================| +| 0 NVIDIA GeForce RTX 3060 Off | 00000000:01:00.0 On | N/A | +| 0% 49C P5 22W / 170W | 224MiB / 12288MiB | 22% Default | +| | | N/A | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| 0 N/A N/A 985 G /usr/lib/xorg/Xorg 99MiB | +| 0 N/A N/A 9685 C python3 104MiB | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=/home/hz/Desktop/parameter-golf-main/data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:1 +val_loader:shards pattern=/home/hz/Desktop/parameter-golf-main/data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:1048576 +val_token_limit:1048576 +""" +The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good jumping-off points for new participants, not SOTA configs. We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder. + +Hard stop: To keep things readable for newcomers, let's make sure `train_gpt.py` and `train_gpt_mlx.py` never grow beyond 1500 lines.
+""" + +from __future__ import annotations + +import copy +import glob +import io +import math +import os +import random +import subprocess +import sys +import time +import uuid +import zlib +from pathlib import Path + +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP + +try: + import zstandard as zstd + + _HAS_ZSTD = True +except ImportError: + zstd = None + _HAS_ZSTD = False + +REPO_ROOT = Path(__file__).resolve().parents[3] + +# ----------------------------- +# HYPERPARAMETERS +# ----------------------------- +# Default Simple Baseline run: +# - 9 transformer blocks at width 512 +# - 8 attention heads with 4 KV heads (GQA) and 2x MLP expansion +# - vocab size 1024, sequence length 1024, tied embeddings +# - 524,288 train tokens per step for 20,000 iterations with a ~10 minute cap + +class Hyperparameters: + # Data paths are shard globs produced by the existing preprocessing pipeline. + data_path = os.environ.get("DATA_PATH", str(REPO_ROOT / "data/datasets/fineweb10B_sp1024")) + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", str(REPO_ROOT / "data/tokenizers/fineweb_1024_bpe.model")) + run_id = os.environ.get("RUN_ID", "train") + seed = int(os.environ.get("SEED", 1337)) + + # Validation cadence and batch size. Validation always uses the full fineweb_val split. + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 32_768)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 10)) + val_token_limit = int(os.environ.get("VAL_TOKEN_LIMIT", 1_048_576)) + + # Training length. 
+ iterations = int(os.environ.get("ITERATIONS", 40)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 1200)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 1)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 16_384)) + train_microbatch_tokens = int(os.environ.get("TRAIN_MICROBATCH_TOKENS", 0)) + grad_accum_steps = int(os.environ.get("GRAD_ACCUM_STEPS", 0)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 256)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + enable_compile = bool(int(os.environ.get("ENABLE_COMPILE", "0"))) + + # Model shape. + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers = int(os.environ.get("NUM_LAYERS", 9)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = int(os.environ.get("MLP_MULT", 2)) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + + # Optimizer hyperparameters. 
+ embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.03)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.02)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.02)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.0)) + lr_scale_method = os.environ.get("LR_SCALE_METHOD", "none") + lr_reference_batch_tokens = int(os.environ.get("LR_REFERENCE_BATCH_TOKENS", 16_384)) + warmup_scale_method = os.environ.get("WARMUP_SCALE_METHOD", "none") + warmup_reference_batch_tokens = int(os.environ.get("WARMUP_REFERENCE_BATCH_TOKENS", 16_384)) + simulated_world_size = int(os.environ.get("SIMULATED_WORLD_SIZE", 0)) + simulated_rank = int(os.environ.get("SIMULATED_RANK", 0)) + dry_run_init_only = bool(int(os.environ.get("DRY_RUN_INIT_ONLY", "0"))) + debug_data_sharding_steps = int(os.environ.get("DEBUG_DATA_SHARDING_STEPS", 0)) + + # Export / compression. + quant_scheme = os.environ.get("QUANT_SCHEME", "int5_odd") + compressor = os.environ.get("COMPRESSOR", "zlib") + compressor_level = int(os.environ.get("COMPRESSOR_LEVEL", 9)) + + # Entropy-aware rate regularization. 
+ rate_lambda = float(os.environ.get("RATE_LAMBDA", 0.00002)) + scale_lambda = float(os.environ.get("SCALE_LAMBDA", 0.0002)) + rate_temperature = float(os.environ.get("RATE_TEMPERATURE", 6.0)) + rate_warmup_start_fraction = float(os.environ.get("RATE_WARMUP_START_FRACTION", 0.1)) + rate_full_fraction = float(os.environ.get("RATE_FULL_FRACTION", 0.4)) + rate_target_tensors = os.environ.get("RATE_TARGET_TENSORS", "mlp_only") + rate_schedule_mode = os.environ.get("RATE_SCHEDULE_MODE", "auto") + max_expected_steps = int(os.environ.get("MAX_EXPECTED_STEPS", 0)) + + # Structured MLP. + structured_mlp = bool(int(os.environ.get("STRUCTURED_MLP", "1"))) + structured_linear_type = os.environ.get("STRUCTURED_LINEAR_TYPE", "btt_mlp") + btt_rank = int(os.environ.get("BTT_RANK", 32)) + btt_in_shape = os.environ.get("BTT_IN_SHAPE", "") + btt_out_shape = os.environ.get("BTT_OUT_SHAPE", "") + btt_init = os.environ.get("BTT_INIT", "mup") + btt_init_gain = float(os.environ.get("BTT_INIT_GAIN", 1.0)) + compile_structured_mlp = bool(int(os.environ.get("COMPILE_STRUCTURED_MLP", "1"))) + structured_chunk_ranks = int(os.environ.get("STRUCTURED_CHUNK_RANKS", 64)) + structured_compile_mode = os.environ.get("STRUCTURED_COMPILE_MODE", "reduce-overhead") + +# ----------------------------- +# MUON OPTIMIZER +# ----------------------------- +# +# As borrowed from modded-nanogpt +# Background on Muon: https://kellerjordan.github.io/posts/muon/ + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + # Orthogonalize a 2D update matrix with a fast Newton-Schulz iteration. + # Muon uses this to normalize matrix-shaped gradients before applying them. 
+ a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X + + +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True): + super().__init__( + params, + dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov), + ) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + + for group in self.param_groups: + params = group["params"] + if not params: + continue + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + + total_params = sum(int(p.numel()) for p in params) + updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) + + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: + g = g.add(buf, alpha=momentum) + g = zeropower_via_newtonschulz5(g, steps=backend_steps) + # Scale correction from Muon reference implementations. 
+ g *= max(1, g.size(0) / g.size(1)) ** 0.5 + updates_flat[curr : curr + p.numel()] = g.reshape(-1) + curr += p.numel() + + if distributed: + dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) + + curr = 0 + for p in params: + g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) + p.add_(g, alpha=-lr) + curr += p.numel() + + return loss + + +# ----------------------------- +# TOKENIZER-AGNOSTIC EVALUATION SETUP +# ----------------------------- +# +# It's common for small models to have a large fraction of their parameters be embeddings, since the 2 * d_model * d_vocab embedding parameters can be gigantic. +# Instead of locking the tokenizer, we let you bring your own and calculate our validation metrics on the average compression of the validation set. +# We calculate BPB (bits-per-byte) instead of validation loss, so we need methods to count the number of bytes each token decodes to under the tokenizer. +# Note: Submissions that edit the tokenizer will be examined more carefully, since screwing this up might unjustly improve your score.
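The bits-per-byte accounting described in the comments above can be sketched as a tiny standalone helper (illustrative only; `bpb_from_loss` is a hypothetical name, not part of this script): mean token cross-entropy in nats converts to bits per token, and rescaling by tokens-per-byte yields the tokenizer-agnostic metric.

```python
import math

def bpb_from_loss(mean_ce_nats: float, token_count: float, byte_count: float) -> float:
    # Natural-log cross-entropy -> bits per token, then rescale by how many
    # UTF-8 bytes the validation tokens decode to (mirrors eval_val below).
    bits_per_token = mean_ce_nats / math.log(2.0)
    tokens_per_byte = token_count / byte_count
    return bits_per_token * tokens_per_byte

# e.g. 4 bits/token on tokens averaging 4 bytes each -> 1.0 bits/byte
```

A tokenizer with longer pieces thus lowers BPB at equal loss, which is why the byte LUTs below must be computed from the tokenizer itself.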
+ +def build_sentencepiece_luts( + sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device +) -> tuple[Tensor, Tensor, Tensor]: + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("▁"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) + + +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*. 
+ tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] + + +def eval_val( + args: Hyperparameters, + model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + grad_accum_steps: int, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, +) -> tuple[float, float]: + # Validation computes two metrics: + # - val_loss: token cross-entropy (natural log) + # - val_bpb: tokenizer-agnostic compression metric used by the challenge + local_batch_tokens = args.val_batch_size // world_size + if local_batch_tokens < args.train_seq_len: + raise ValueError( + "VAL_BATCH_SIZE must provide at least one sequence per rank; " + f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " + f"TRAIN_SEQ_LEN={args.train_seq_len}" + ) + local_batch_seqs = local_batch_tokens // args.train_seq_len + total_seqs = (val_tokens.numel() - 1) // args.train_seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + + model.eval() + with torch.inference_mode(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * args.train_seq_len + raw_end = batch_seq_end * args.train_seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, args.train_seq_len) + y = local[1:].reshape(-1, args.train_seq_len) + maybe_mark_compile_step_begin(args.enable_compile or 
args.compile_structured_mlp) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + batch_loss = model(x, y).detach() + batch_token_count = float(y.numel()) + val_loss_sum += batch_loss.to(torch.float64) * batch_token_count + val_token_count += batch_token_count + prev_ids = x.reshape(-1) + tgt_ids = y.reshape(-1) + token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + + val_loss = val_loss_sum / val_token_count + bits_per_token = val_loss.item() / math.log(2.0) + tokens_per_byte = val_token_count.item() / val_byte_count.item() + model.train() + return float(val_loss.item()), float(bits_per_token * tokens_per_byte) + +# ----------------------------- +# POST-TRAINING QUANTIZATION +# ----------------------------- +# +# It's silly to export our model, which is trained in bf16 and fp32, at that same precision. +# Instead, we get approximately the same model (with a small hit) by quantizing the model to int8 and zlib-compressing the result. +# We can then decompress the model and run in higher precision for evaluation, after coming in under the size limit.
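As a sanity reference for the scheme described above, here is a minimal symmetric int8 round-trip with a single per-tensor scale (an illustrative sketch; `int8_roundtrip` is a hypothetical helper, not part of the export path, which additionally uses per-row scales and percentile clipping in `quantize_float_tensor` below):

```python
import torch

def int8_roundtrip(t: torch.Tensor) -> torch.Tensor:
    # Symmetric quantization: one float scale maps [-max|t|, max|t|] onto [-127, 127].
    scale = t.abs().max().clamp_min(1e-8) / 127.0
    q = torch.clamp(torch.round(t / scale), -127, 127).to(torch.int8)
    # Dequantize: worst-case per-element error is about scale / 2.
    return q.float() * scale
```

Clipping a high percentile before computing the scale (as the real exporter does) trades a little error on outlier weights for a finer grid everywhere else.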
+ +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights", + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 +INT5_ODD_LEVELS = (-2, -1, 0, 1, 2) + +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) + + +def maybe_mark_compile_step_begin(enabled: bool) -> None: + if not enabled: + return + mark_step = getattr(torch.compiler, "cudagraph_mark_step_begin", None) + if mark_step is not None: + mark_step() + + +def scale_from_batch_multiplier(value: float, multiplier: float, method: str) -> float: + if method == "none": + return value + if method == "sqrt": + return value * math.sqrt(multiplier) + if method == "linear": + return value * multiplier + raise ValueError(f"Unsupported scaling method: {method}") + + +def scaled_warmup_steps(base_steps: int, multiplier: float, method: str) -> int: + if base_steps <= 0: + return 0 + scaled = scale_from_batch_multiplier(float(base_steps), multiplier, method) + return max(int(math.ceil(scaled)), 1) + + +def parse_factor_shape(spec: str, dim: int) -> tuple[int, int]: + if spec: + vals = tuple(int(v.strip()) for v in spec.split(",") if v.strip()) + if len(vals) != 2 or vals[0] * vals[1] != dim: + raise ValueError(f"Invalid factor shape '{spec}' for dim={dim}") + return vals + for a in range(int(math.isqrt(dim)), 0, -1): + if dim % a == 0: + return a, dim // a + return 1, dim + +def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> 
Tensor: + if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): + return t.float().contiguous() + if t.dtype in {torch.float32, torch.bfloat16}: + passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") + return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() + return t + +def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + # Matrices get one scale per row, which usually tracks output-channel + # ranges much better than a single tensor-wide scale. + clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) + q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() + return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + + # Vectors / scalars use a simpler per-tensor scale. 
+ clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() + return q, scale + + +def pack_base5_codes(q: Tensor) -> tuple[bytes, int]: + vals = q.detach().to("cpu", dtype=torch.uint8).reshape(-1).numpy() + n_vals = int(vals.size) + pad = (3 - (n_vals % 3)) % 3 + if pad: + vals = np.pad(vals, (0, pad)) + groups = vals.reshape(-1, 3).astype(np.uint8) + packed = groups[:, 0] + 5 * groups[:, 1] + 25 * groups[:, 2] + return packed.tobytes(), n_vals + + +def unpack_base5_codes(data: bytes, n_vals: int) -> Tensor: + packed = np.frombuffer(data, dtype=np.uint8).astype(np.int16) + out = np.zeros((packed.size, 3), dtype=np.uint8) + for i in range(3): + out[:, i] = packed % 5 + packed //= 5 + flat = out.reshape(-1)[:n_vals] + return torch.from_numpy(flat.astype(np.uint8, copy=False)) + + +def quantize_float_tensor_int5_odd(t: Tensor) -> tuple[Tensor, Tensor]: + if t.ndim != 2: + raise ValueError("int5_odd quantization only supports 2D tensors") + t32 = t.float() + clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clip_abs = clip_abs.clamp_min(1e-6) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 2.0).clamp_min(1.0 / 2.0e6) + q = torch.clamp(torch.round(clipped / scale[:, None]), -2, 2).to(torch.int8).contiguous() + return (q + 2).to(torch.uint8).contiguous(), scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + +def quantize_state_dict_int8(state_dict: dict[str, Tensor]): + # Single supported clean-script export format: + # - per-row int8 for 2D float tensors + # - per-tensor int8 for other float tensors + # - exact passthrough for non-floats + # - passthrough for small float tensors, 
stored as fp16 to save bytes + quantized: dict[str, Tensor] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + qmeta: dict[str, dict[str, object]] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + + # Small float tensors are cheap enough to keep directly. We still downcast + # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size. + if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + + stats["num_float_tensors"] += 1 + q, s = quantize_float_tensor(t) + if s.ndim > 0: + qmeta[name] = {"scheme": "per_row", "axis": 0} + quantized[name] = q + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) + + obj: dict[str, object] = { + "__quant_format__": "int8_clean_per_row_v1", + "quantized": quantized, + "scales": scales, + "dtypes": dtypes, + "passthrough": passthrough, + } + if qmeta: + obj["qmeta"] = qmeta + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats + + +def quantize_state_dict_int5_odd(state_dict: dict[str, Tensor]): + packed: dict[str, bytes] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + shapes: dict[str, tuple[int, ...]] = {} + orig_shapes: dict[str, 
tuple[int, ...]] = {} + n_codes: dict[str, int] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + + quantize_this = t.ndim >= 2 and (t.numel() > INT8_KEEP_FLOAT_MAX_NUMEL or ".mlp." in name) + if not quantize_this: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + + stats["num_float_tensors"] += 1 + t2 = t.reshape(t.shape[0], -1) if t.ndim > 2 else t + q, s = quantize_float_tensor_int5_odd(t2) + packed_bytes, count = pack_base5_codes(q) + packed[name] = packed_bytes + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + shapes[name] = tuple(int(v) for v in t2.shape) + orig_shapes[name] = tuple(int(v) for v in t.shape) + n_codes[name] = count + stats["int8_payload_bytes"] += len(packed_bytes) + tensor_nbytes(s) + + obj: dict[str, object] = { + "__quant_format__": "int5_odd_clean_per_row_v1", + "packed": packed, + "scales": scales, + "dtypes": dtypes, + "shapes": shapes, + "orig_shapes": orig_shapes, + "n_codes": n_codes, + "passthrough": passthrough, + } + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats + +def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + qmeta = obj.get("qmeta", {}) + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, q in 
obj["quantized"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + s = obj["scales"][name] + if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: + s = s.to(dtype=torch.float32) + # Broadcast the saved row scale back across trailing dimensions. + out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() + else: + scale = float(s.item()) + out[name] = (q.float() * scale).to(dtype=dtype).contiguous() + for name, t in obj["passthrough"].items(): + # Restore small tensors, undoing the temporary fp16 storage cast if needed. + out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out + + +def dequantize_state_dict_int5_odd(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, packed in obj["packed"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + shape = tuple(int(v) for v in obj["shapes"][name]) + orig_shape = tuple(int(v) for v in obj.get("orig_shapes", {}).get(name, shape)) + q = unpack_base5_codes(packed, int(obj["n_codes"][name])).to(torch.float32).reshape(shape) - 2.0 + s = obj["scales"][name].to(dtype=torch.float32) + deq = (q * s.view(shape[0], *([1] * (len(shape) - 1)))).to(dtype=dtype) + out[name] = deq.reshape(orig_shape).contiguous() + for name, t in obj["passthrough"].items(): + out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out + + +def quantize_state_dict(state_dict: dict[str, Tensor], quant_scheme: str): + if quant_scheme == "int8": + return quantize_state_dict_int8(state_dict) + if quant_scheme == "int5_odd": + return 
quantize_state_dict_int5_odd(state_dict) + raise ValueError(f"Unsupported QUANT_SCHEME={quant_scheme}") + + +def dequantize_state_dict(obj: dict[str, object]) -> dict[str, Tensor]: + quant_format = obj.get("__quant_format__") + if quant_format == "int8_clean_per_row_v1": + return dequantize_state_dict_int8(obj) + if quant_format == "int5_odd_clean_per_row_v1": + return dequantize_state_dict_int5_odd(obj) + raise ValueError(f"Unsupported quant format {quant_format}") + + +def compress_blob(data: bytes, compressor: str, level: int) -> bytes: + if compressor == "zlib": + return zlib.compress(data, level=max(0, min(level, 9))) + if compressor == "zstd": + if not _HAS_ZSTD: + raise RuntimeError("COMPRESSOR=zstd requires zstandard to be installed") + return zstd.ZstdCompressor(level=level).compress(data) + raise ValueError(f"Unsupported COMPRESSOR={compressor}") + + +def decompress_blob(data: bytes, compressor: str) -> bytes: + if compressor == "zlib": + return zlib.decompress(data) + if compressor == "zstd": + if not _HAS_ZSTD: + raise RuntimeError("COMPRESSOR=zstd requires zstandard to be installed") + return zstd.ZstdDecompressor().decompress(data) + raise ValueError(f"Unsupported COMPRESSOR={compressor}") + + +def should_rate_regularize_tensor(name: str, p: Tensor, target: str) -> bool: + if not p.requires_grad or not p.is_floating_point() or p.ndim < 2: + return False + if any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS): + return False + if target == "mlp_only": + return ".mlp." in name + if target == "all_matrices": + return p.numel() > INT8_KEEP_FLOAT_MAX_NUMEL or ".mlp." 
in name + raise ValueError(f"Unsupported RATE_TARGET_TENSORS={target}") + + +def estimate_tensor_rate_proxy(t: Tensor, temperature: float) -> tuple[Tensor, Tensor, Tensor]: + t2 = t.float() + if t2.ndim > 2: + t2 = t2.reshape(t2.shape[0], -1) + elif t2.ndim == 1: + t2 = t2.reshape(1, -1) + row_scale = t2.abs().mean(dim=1, keepdim=True).clamp_min(1e-6) + normalized = torch.clamp(t2 / row_scale, -2.5, 2.5) + centers = t2.new_tensor(INT5_ODD_LEVELS) + logits = -temperature * (normalized.unsqueeze(-1) - centers) ** 2 + probs = torch.softmax(logits, dim=-1) + hist = probs.mean(dim=(0, 1)) + entropy_nats = -(hist * (hist + 1e-8).log()).sum() + entropy_bits = entropy_nats / math.log(2.0) + scale_proxy = row_scale.mean() + est_bits = entropy_bits * float(t2.numel()) + return entropy_bits, scale_proxy, est_bits + + +# ----------------------------- +# DATA LOADING +# ----------------------------- + +def load_data_shard(file: Path) -> Tensor: + header_bytes = 256 * np.dtype(" None: + self.file_idx = (self.file_idx + 1) % len(self.files) + self.tokens = load_data_shard(self.files[self.file_idx]) + self.pos = 0 + + def take(self, n: int) -> Tensor: + chunks: list[Tensor] = [] + remaining = n + while remaining > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: + self._advance_file() + continue + k = min(remaining, avail) + chunks.append(self.tokens[self.pos : self.pos + k]) + self.pos += k + remaining -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) + + +class DistributedTokenLoader: + # Each call consumes a contiguous chunk from the shared token stream, then slices out + # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting. 
+ def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank = rank + self.world_size = world_size + self.device = device + self.stream = TokenStream(pattern) + + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + chunk = self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start : start + per_rank_span].to(dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + + +def debug_log_data_sharding( + pattern: str, + world_size: int, + global_tokens: int, + seq_len: int, + grad_accum_steps: int, + steps: int, + log0, +) -> None: + stream = TokenStream(pattern) + local_tokens = global_tokens // (world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + global_cursor = 0 + log0( + f"debug_sharding:world_size:{world_size} global_tokens:{global_tokens} " + f"grad_accum_steps:{grad_accum_steps} local_tokens_per_rank:{local_tokens} per_rank_span:{per_rank_span}" + ) + for step_idx in range(steps): + chunk = stream.take(per_rank_span * world_size) + chunk_start = global_cursor + chunk_end = global_cursor + chunk.numel() + log0(f"debug_sharding_step:{step_idx} shared_chunk_token_range:[{chunk_start},{chunk_end})") + for shard_rank in range(world_size): + start = shard_rank * per_rank_span + end = start + per_rank_span + local = chunk[start:end] + preview = ",".join(str(int(t)) for t in local[: min(6, local.numel())].tolist()) + log0( + f"debug_shard rank:{shard_rank} token_range:[{chunk_start + start},{chunk_start + end}) " + f"preview:{preview}" + ) + global_cursor = chunk_end + +# ----------------------------- +# TRANSFORMER MODULES +# ----------------------------- + +class RMSNorm(nn.Module): + def 
__init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + + +class CastedLinear(nn.Linear): + # Keep weights in fp32 for optimizer/state quality, cast at matmul time for bf16 compute. + def forward(self, x: Tensor) -> Tensor: + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, self.weight.to(x.dtype), bias) + + +class StructuredLinear(nn.Module): + # Lightweight TT-style structured matrix used only in the MLP stack. + def __init__( + self, + in_features: int, + out_features: int, + bias: bool, + linear_type: str, + btt_rank: int, + in_shape_spec: str, + out_shape_spec: str, + btt_init: str, + btt_init_gain: float, + chunk_ranks: int, + ): + super().__init__() + self.in_features = in_features + self.out_features = out_features + self.linear_type = linear_type + self.rank = max(int(btt_rank), 1) + self.chunk_ranks = max(int(chunk_ranks), 1) + self.btt_init = btt_init + self.btt_init_gain = btt_init_gain + if linear_type == "dense": + self.impl = CastedLinear(in_features, out_features, bias=bias) + self.register_parameter("core1", None) + self.register_parameter("core2", None) + return + if linear_type != "btt_mlp": + raise ValueError(f"Unsupported STRUCTURED_LINEAR_TYPE={linear_type}") + self.impl = None + self.in_shape = parse_factor_shape(in_shape_spec, in_features) + self.out_shape = parse_factor_shape(out_shape_spec, out_features) + n1, n2 = self.in_shape + m1, m2 = self.out_shape + self.core1 = nn.Parameter(torch.empty(n1, m1, self.rank)) + self.core2 = nn.Parameter(torch.empty(self.rank, n2, m2)) + if bias: + self.bias = nn.Parameter(torch.zeros(out_features)) + else: + self.register_parameter("bias", None) + self.reset_parameters() + + def reset_parameters(self) -> None: + if self.impl is not None: + self.impl.reset_parameters() + return + n1, n2 = self.in_shape + if self.btt_init == "mup": + # muP-inspired 
scaling: keep the per-rank branch update stable while + # preserving O(1) output variance under the sum over TT ranks. + core1_std = self.btt_init_gain / math.sqrt(max(n1, 1)) + core2_std = self.btt_init_gain / math.sqrt(max(self.rank * n2, 1)) + elif self.btt_init == "xavier": + std = self.btt_init_gain / math.sqrt(max(self.in_features, 1)) + core1_std = std / math.sqrt(self.rank) + core2_std = std + else: + raise ValueError(f"Unsupported BTT_INIT={self.btt_init}") + nn.init.normal_(self.core1, mean=0.0, std=core1_std) + nn.init.normal_(self.core2, mean=0.0, std=core2_std) + if self.bias is not None: + nn.init.zeros_(self.bias) + + def forward(self, x: Tensor) -> Tensor: + if self.impl is not None: + return self.impl(x) + x_dtype = x.dtype + orig_shape = x.shape[:-1] + x2 = x.reshape(-1, self.in_shape[0], self.in_shape[1]) + x2_t = x2.transpose(1, 2).contiguous() + core1 = self.core1.to(dtype=x_dtype).permute(2, 0, 1).contiguous() + core2 = self.core2.to(dtype=x_dtype).contiguous() + y = x2.new_zeros((x2.shape[0], self.out_shape[0], self.out_shape[1])) + for start in range(0, self.rank, self.chunk_ranks): + end = min(start + self.chunk_ranks, self.rank) + left = torch.matmul(x2_t.unsqueeze(1), core1[start:end].unsqueeze(0)) + y = y + torch.matmul(left.transpose(-1, -2), core2[start:end].unsqueeze(0)).sum(dim=1) + y = y.reshape(*orig_shape, self.out_features) + if self.bias is not None: + y = y + self.bias.to(dtype=y.dtype) + return y + + +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + # Keep small/control parameters in fp32 even when the model body runs in bf16. + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() + + +class Rotary(nn.Module): + # Caches cos/sin tables per sequence length on the current device. 
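For intuition about the two-core rank loop in `StructuredLinear.forward`, note that the whole contraction is linear in `x`, so the cores can be collapsed once into a dense `(in_features, out_features)` weight. This is the property the evaluation-time materialization mentioned in the summary relies on. The sketch below uses hypothetical small factor shapes, not the production configuration:

```python
import torch

# Hypothetical small shapes: in_features = n1 * n2, out_features = m1 * m2.
n1, n2, m1, m2, rank = 3, 4, 2, 5, 6
core1 = torch.randn(n1, m1, rank)
core2 = torch.randn(rank, n2, m2)

def btt_forward(x_flat: torch.Tensor) -> torch.Tensor:
    # Same contraction as the rank loop above, written as a single einsum:
    # contract n1 against core1, then n2 against core2, summing over ranks.
    x = x_flat.reshape(-1, n1, n2)
    return torch.einsum("bij,iur,rjv->buv", x, core1, core2).reshape(-1, m1 * m2)

# Evaluation-time materialization: collapse both cores into one dense weight
# so inference can run through a plain matmul / F.linear.
dense_w = torch.einsum("iur,rjv->ijuv", core1, core2).reshape(n1 * n2, m1 * m2)

x = torch.randn(7, n1 * n2)
y_structured = btt_forward(x)
y_dense = x @ dense_w
```

The dense path trades the structured layer's parameter savings for a single fast matmul at validation time; the two outputs agree up to floating-point accumulation order.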
+ def __init__(self, dim: int, base: float = 10000.0): + super().__init__() + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached != seq_len + or self._cos_cached.device != device + ): + t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) + freqs = torch.outer(t, self.inv_freq.to(device)) + self._cos_cached = freqs.cos()[None, None, :, :] + self._sin_cached = freqs.sin()[None, None, :, :] + self._seq_len_cached = seq_len + cos = self._cos_cached.to(dtype=dtype) + sin = self._sin_cached.to(dtype=dtype) + if not torch.is_inference_mode_enabled(): + cos = cos.clone() + sin = sin.clone() + return cos, sin + + +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + + +class CausalSelfAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + kv_dim = self.num_kv_heads * self.head_dim + self.c_q = CastedLinear(dim, dim, bias=False) + self.c_k = CastedLinear(dim, kv_dim, bias=False) + self.c_v = CastedLinear(dim, kv_dim, bias=False) + self.proj 
= CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rotary = Rotary(self.head_dim, base=rope_base) + + def forward(self, x: Tensor) -> Tensor: + bsz, seqlen, dim = x.shape + q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) + k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin) + k = apply_rotary_emb(k, cos, sin) + q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] + y = F.scaled_dot_product_attention( + q, + k, + v, + attn_mask=None, + is_causal=True, + enable_gqa=(self.num_kv_heads != self.num_heads), + ) + y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) + return self.proj(y) + + +class MLP(nn.Module): + # relu^2 MLP from the original modded-nanogpt setup + def __init__( + self, + dim: int, + mlp_mult: int, + structured_mlp: bool, + structured_linear_type: str, + btt_rank: int, + btt_in_shape: str, + btt_out_shape: str, + btt_init: str, + btt_init_gain: float, + compile_structured_mlp: bool, + structured_chunk_ranks: int, + structured_compile_mode: str, + ): + super().__init__() + hidden = mlp_mult * dim + if structured_mlp: + self.fc = StructuredLinear( + dim, + hidden, + bias=False, + linear_type=structured_linear_type, + btt_rank=btt_rank, + in_shape_spec=btt_in_shape, + out_shape_spec=btt_out_shape, + btt_init=btt_init, + btt_init_gain=btt_init_gain, + chunk_ranks=structured_chunk_ranks, + ) + self.proj = StructuredLinear( + hidden, + dim, + bias=False, + linear_type=structured_linear_type, + btt_rank=btt_rank, + in_shape_spec=btt_out_shape, + out_shape_spec=btt_in_shape, + btt_init=btt_init, + 
btt_init_gain=btt_init_gain, + chunk_ranks=structured_chunk_ranks, + ) + if compile_structured_mlp: + compile_kwargs = dict(fullgraph=False, dynamic=False) + if structured_compile_mode == "reduce-overhead": + compile_options = dict(torch._inductor.list_mode_options("reduce-overhead")) + compile_options["triton.cudagraphs"] = False + compile_kwargs["options"] = compile_options + else: + compile_kwargs["mode"] = structured_compile_mode + self.fc = torch.compile( + self.fc, + **compile_kwargs, + ) + self.proj = torch.compile( + self.proj, + **compile_kwargs, + ) + else: + self.fc = CastedLinear(dim, hidden, bias=False) + self.proj = CastedLinear(hidden, dim, bias=False) + self.proj._zero_init = True + + def forward(self, x: Tensor) -> Tensor: + x = torch.relu(self.fc(x)) + return self.proj(x.square()) + + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + rope_base: float, + qk_gain_init: float, + structured_mlp: bool, + structured_linear_type: str, + btt_rank: int, + btt_in_shape: str, + btt_out_shape: str, + btt_init: str, + btt_init_gain: float, + compile_structured_mlp: bool, + structured_chunk_ranks: int, + structured_compile_mode: str, + ): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) + self.mlp = MLP( + dim, + mlp_mult, + structured_mlp, + structured_linear_type, + btt_rank, + btt_in_shape, + btt_out_shape, + btt_init, + btt_init_gain, + compile_structured_mlp, + structured_chunk_ranks, + structured_compile_mode, + ) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + + def forward(self, x: Tensor, x0: Tensor) -> Tensor: + mix = self.resid_mix.to(dtype=x.dtype) + x = mix[0][None, None, :] * x + 
mix[1][None, None, :] * x0 + attn_out = self.attn(self.attn_norm(x)) + x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out + x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) + return x + + +class GPT(nn.Module): + def __init__( + self, + vocab_size: int, + num_layers: int, + model_dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + tie_embeddings: bool, + tied_embed_init_std: float, + logit_softcap: float, + rope_base: float, + qk_gain_init: float, + structured_mlp: bool, + structured_linear_type: str, + btt_rank: int, + btt_in_shape: str, + btt_out_shape: str, + btt_init: str, + btt_init_gain: float, + compile_structured_mlp: bool, + structured_chunk_ranks: int, + structured_compile_mode: str, + rate_target_tensors: str, + rate_temperature: float, + ): + super().__init__() + if logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") + self.tie_embeddings = tie_embeddings + self.tied_embed_init_std = tied_embed_init_std + self.logit_softcap = logit_softcap + self.rate_target_tensors = rate_target_tensors + self.rate_temperature = rate_temperature + self.tok_emb = nn.Embedding(vocab_size, model_dim) + self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + self.blocks = nn.ModuleList( + [ + Block( + model_dim, + num_heads, + num_kv_heads, + mlp_mult, + rope_base, + qk_gain_init, + structured_mlp, + structured_linear_type, + btt_rank, + btt_in_shape, + btt_out_shape, + btt_init, + btt_init_gain, + compile_structured_mlp, + structured_chunk_ranks, + structured_compile_mode, + ) + for i in range(num_layers) + ] + ) + self.final_norm = RMSNorm() + self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, 
bias=False) + if self.lm_head is not None: + self.lm_head._zero_init = True + self._init_weights() + + def _init_weights(self) -> None: + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + for module in self.modules(): + if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + if isinstance(module, StructuredLinear) and getattr(module, "_zero_init", False): + if module.impl is not None: + nn.init.zeros_(module.impl.weight) + else: + nn.init.zeros_(module.core2) + + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + x = self.tok_emb(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x0 = x + skips: list[Tensor] = [] + + # First half stores skips; second half reuses them in reverse order. + for i in range(self.num_encoder_layers): + x = self.blocks[i](x, x0) + skips.append(x) + for i in range(self.num_decoder_layers): + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + x = self.blocks[self.num_encoder_layers + i](x, x0) + + x = self.final_norm(x).reshape(-1, x.size(-1)) + targets = target_ids.reshape(-1) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(x) + logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + return F.cross_entropy(logits.float(), targets, reduction="mean") + + def rate_regularizer(self) -> tuple[Tensor, Tensor, Tensor]: + total_symbols = 0.0 + weighted_rate = None + weighted_scale = None + estimated_bits = None + for name, param in self.named_parameters(): + if not should_rate_regularize_tensor(name, param, self.rate_target_tensors): + continue + rate_proxy, scale_proxy, tensor_bits = estimate_tensor_rate_proxy(param, self.rate_temperature) + weight = float(param.numel()) + total_symbols += weight + if 
weighted_rate is None: + weighted_rate = rate_proxy * weight + weighted_scale = scale_proxy * weight + estimated_bits = tensor_bits + else: + weighted_rate = weighted_rate + rate_proxy * weight + weighted_scale = weighted_scale + scale_proxy * weight + estimated_bits = estimated_bits + tensor_bits + if weighted_rate is None or weighted_scale is None or estimated_bits is None: + zero = self.tok_emb.weight.new_zeros(()) + return zero, zero, zero + denom = max(total_symbols, 1.0) + return weighted_rate / denom, weighted_scale / denom, estimated_bits + + +# ----------------------------- +# TRAINING +# ----------------------------- + +def main() -> None: + global zeropower_via_newtonschulz5 + + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + if args.enable_compile: + zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) + + # ----------------------------- + # DISTRIBUTED + CUDA SETUP + # ----------------------------- + + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + simulated_world_size = args.simulated_world_size if not distributed and args.simulated_world_size > 0 else 0 + if simulated_world_size > 0: + world_size = simulated_world_size + rank = args.simulated_rank + local_rank = 0 + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if rank < 0 or rank >= world_size: + raise ValueError(f"RANK={rank} must satisfy 0 <= RANK < WORLD_SIZE={world_size}") + if simulated_world_size > 1 and not args.dry_run_init_only: + raise ValueError("SIMULATED_WORLD_SIZE>1 is only supported together with DRY_RUN_INIT_ONLY=1") + if args.train_batch_tokens <= 0: + raise ValueError(f"TRAIN_BATCH_TOKENS must be positive, got {args.train_batch_tokens}") + train_microbatch_tokens = args.train_microbatch_tokens if 
args.train_microbatch_tokens > 0 else max( + args.train_seq_len, args.train_batch_tokens // 8 + ) + if train_microbatch_tokens <= 0: + raise ValueError(f"TRAIN_MICROBATCH_TOKENS must be positive, got {train_microbatch_tokens}") + if train_microbatch_tokens % args.train_seq_len != 0: + raise ValueError( + f"TRAIN_MICROBATCH_TOKENS={train_microbatch_tokens} must be divisible by TRAIN_SEQ_LEN={args.train_seq_len}" + ) + tokens_per_microstep = world_size * train_microbatch_tokens + if args.grad_accum_steps > 0: + grad_accum_steps = args.grad_accum_steps + effective_train_batch_tokens = tokens_per_microstep * grad_accum_steps + else: + if args.train_batch_tokens % tokens_per_microstep != 0: + raise ValueError( + "TRAIN_BATCH_TOKENS must be divisible by WORLD_SIZE * TRAIN_MICROBATCH_TOKENS " + f"unless GRAD_ACCUM_STEPS is set explicitly; got TRAIN_BATCH_TOKENS={args.train_batch_tokens}, " + f"WORLD_SIZE={world_size}, TRAIN_MICROBATCH_TOKENS={train_microbatch_tokens}" + ) + grad_accum_steps = args.train_batch_tokens // tokens_per_microstep + effective_train_batch_tokens = args.train_batch_tokens + if grad_accum_steps <= 0: + raise ValueError(f"GRAD_ACCUM_STEPS must be positive, got {grad_accum_steps}") + grad_scale = 1.0 / grad_accum_steps + batch_multiplier = effective_train_batch_tokens / max(args.lr_reference_batch_tokens, 1) + effective_warmup_steps = scaled_warmup_steps( + args.warmup_steps, + effective_train_batch_tokens / max(args.warmup_reference_batch_tokens, 1), + args.warmup_scale_method, + ) + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 0 + + # Fast math knobs + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, 
enable_mem_efficient_sdp + + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + + logfile = None + if master_process: + if args.run_id == "train": + logfile = "train.log" + else: + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" + print(logfile) + + def log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch {torch.__version__}", console=False) + log0( + subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, + console=False, + ) + log0("=" * 100, console=False) + + # ----------------------------- + # TOKENIZER + VALIDATION METRIC SETUP + # ----------------------------- + + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) + if args.val_token_limit > 0: + usable = min(((args.val_token_limit) // args.train_seq_len) * args.train_seq_len, val_tokens.numel() - 1) + if usable <= 0: + raise ValueError(f"VAL_TOKEN_LIMIT={args.val_token_limit} is too small for TRAIN_SEQ_LEN={args.train_seq_len}") + val_tokens = 
val_tokens[: usable + 1].contiguous() + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + if args.val_token_limit > 0: + log0(f"val_token_limit:{args.val_token_limit}") + + # ----------------------------- + # MODEL + OPTIMIZER SETUP + # ----------------------------- + + base_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, + structured_mlp=args.structured_mlp, + structured_linear_type=args.structured_linear_type if args.structured_mlp else "dense", + btt_rank=args.btt_rank, + btt_in_shape=args.btt_in_shape, + btt_out_shape=args.btt_out_shape, + btt_init=args.btt_init, + btt_init_gain=args.btt_init_gain, + compile_structured_mlp=args.compile_structured_mlp, + structured_chunk_ranks=args.structured_chunk_ranks, + structured_compile_mode=args.structured_compile_mode, + rate_target_tensors=args.rate_target_tensors, + rate_temperature=args.rate_temperature, + ).to(device).bfloat16() + for module in base_model.modules(): + if isinstance(module, (CastedLinear, StructuredLinear)): + module.float() + restore_low_dim_params_to_fp32(base_model) + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) if args.enable_compile else base_model + model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + + # Optimizer split: + # - token embedding 
(Adam) uses EMBED_LR + # - untied lm_head (Adam) uses HEAD_LR + # - matrix params in transformer blocks use MATRIX_LR via Muon + # - vectors/scalars use SCALAR_LR via Adam + block_named_params = list(base_model.blocks.named_parameters()) + matrix_params = [ + p + for name, p in block_named_params + if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + scalar_params = [ + p + for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + scaled_token_lr = scale_from_batch_multiplier(token_lr, batch_multiplier, args.lr_scale_method) + scaled_head_lr = scale_from_batch_multiplier(args.head_lr, batch_multiplier, args.lr_scale_method) + scaled_matrix_lr = scale_from_batch_multiplier(args.matrix_lr, batch_multiplier, args.lr_scale_method) + scaled_scalar_lr = scale_from_batch_multiplier(args.scalar_lr, batch_multiplier, args.lr_scale_method) + optimizer_tok = torch.optim.Adam( + [{"params": [base_model.tok_emb.weight], "lr": scaled_token_lr, "base_lr": scaled_token_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizer_muon = Muon( + matrix_params, + lr=scaled_matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + ) + for group in optimizer_muon.param_groups: + group["base_lr"] = scaled_matrix_lr + optimizer_scalar = torch.optim.Adam( + [{"params": scalar_params, "lr": scaled_scalar_lr, "base_lr": scaled_scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": scaled_head_lr, "base_lr": scaled_head_lr}], 
+ betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers.insert(1, optimizer_head) + + n_params = sum(p.numel() for p in base_model.parameters()) + log0(f"model_params:{n_params}") + log0( + f"world_size:{world_size} grad_accum_steps:{grad_accum_steps} " + f"train_microbatch_tokens:{train_microbatch_tokens} effective_train_batch_tokens:{effective_train_batch_tokens}" + ) + if simulated_world_size > 0: + log0(f"simulated_world_size:{simulated_world_size} simulated_rank:{rank}") + log0(f"enable_compile:{args.enable_compile}") + log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") + log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") + log0( + f"quant_scheme:{args.quant_scheme} compressor:{args.compressor} level:{args.compressor_level} " + f"structured_mlp:{args.structured_mlp} structured_linear_type:{args.structured_linear_type} " + f"btt_rank:{args.btt_rank} btt_init:{args.btt_init} compile_structured_mlp:{args.compile_structured_mlp} " + f"structured_chunk_ranks:{args.structured_chunk_ranks} structured_compile_mode:{args.structured_compile_mode}" + ) + log0( + f"rate_lambda:{args.rate_lambda} scale_lambda:{args.scale_lambda} " + f"rate_warmup_start_fraction:{args.rate_warmup_start_fraction} rate_full_fraction:{args.rate_full_fraction} " + f"rate_target_tensors:{args.rate_target_tensors} rate_schedule_mode:{args.rate_schedule_mode} " + f"max_expected_steps:{args.max_expected_steps}" + ) + log0( + f"tie_embeddings:{args.tie_embeddings} embed_lr:{scaled_token_lr} " + f"head_lr:{scaled_head_lr if base_model.lm_head is not None else 0.0} " + f"matrix_lr:{scaled_matrix_lr} scalar_lr:{scaled_scalar_lr}" + ) + log0( + f"train_batch_tokens_target:{args.train_batch_tokens} effective_train_batch_tokens:{effective_train_batch_tokens} " + f"train_seq_len:{args.train_seq_len} iterations:{args.iterations} warmup_steps:{effective_warmup_steps} " + 
f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log0( + f"lr_scale_method:{args.lr_scale_method} lr_reference_batch_tokens:{args.lr_reference_batch_tokens} " + f"batch_multiplier:{batch_multiplier:.4f} warmup_scale_method:{args.warmup_scale_method} " + f"warmup_reference_batch_tokens:{args.warmup_reference_batch_tokens}" + ) + log0(f"seed:{args.seed}") + + if args.debug_data_sharding_steps > 0: + debug_log_data_sharding( + args.train_files, + world_size, + effective_train_batch_tokens, + args.train_seq_len, + grad_accum_steps, + args.debug_data_sharding_steps, + log0, + ) + if args.dry_run_init_only: + log0("dry_run_init_only: exiting after init/logging") + return + + # ----------------------------- + # DATA LOADER & MODEL WARMUP + # ----------------------------- + + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + + def lr_mul(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + + def rate_progress(step: int, elapsed_ms: float) -> float: + if args.rate_schedule_mode == "steps": + if args.iterations <= 0: + return 1.0 + return step / max(args.iterations, 1) + if args.rate_schedule_mode == "expected_steps": + if args.max_expected_steps <= 0: + raise ValueError("MAX_EXPECTED_STEPS must be positive when RATE_SCHEDULE_MODE=expected_steps") + return step 
/ max(args.max_expected_steps, 1) + if args.rate_schedule_mode == "wallclock": + if max_wallclock_ms is None: + raise ValueError("RATE_SCHEDULE_MODE=wallclock requires MAX_WALLCLOCK_SECONDS > 0") + return elapsed_ms / max(max_wallclock_ms, 1e-9) + if args.rate_schedule_mode != "auto": + raise ValueError(f"Unsupported RATE_SCHEDULE_MODE={args.rate_schedule_mode}") + if args.max_expected_steps > 0: + return step / max(args.max_expected_steps, 1) + if max_wallclock_ms is not None and step > 0: + step_ms = elapsed_ms / max(step, 1) + expected_steps = max(int(max_wallclock_ms / max(step_ms, 1e-9)), 1) + return step / expected_steps + if args.iterations <= 0: + return 1.0 + return step / max(args.iterations, 1) + + def rate_schedule(step: int, elapsed_ms: float) -> float: + if args.rate_lambda <= 0.0 and args.scale_lambda <= 0.0: + return 0.0 + frac = rate_progress(step, elapsed_ms) + start = args.rate_warmup_start_fraction + full = max(args.rate_full_fraction, start + 1e-6) + if frac <= start: + return 0.0 + if frac >= full: + return 1.0 + return (frac - start) / (full - start) + + # Warmup primes the compiled forward/backward/optimizer paths, then we restore the + # initial weights/optimizer state so measured training starts from the true init. 
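The prime-then-restore pattern described in the comment above reduces, in miniature, to snapshotting the model and optimizer `state_dict()`s before the throwaway steps and loading them back afterwards. A minimal sketch with a toy `nn.Linear` standing in for the real model:

```python
import copy

import torch

model = torch.nn.Linear(4, 4)
opt = torch.optim.SGD(model.parameters(), lr=0.1)

# Snapshot the true initialization before any throwaway work.
init_model = {k: v.detach().clone() for k, v in model.state_dict().items()}
init_opt = copy.deepcopy(opt.state_dict())

for _ in range(3):  # warmup steps whose updates must not leak into the real run
    opt.zero_grad(set_to_none=True)
    model(torch.randn(2, 4)).sum().backward()
    opt.step()

# Roll back so the measured run starts from the exact initial state.
model.load_state_dict(init_model)
opt.load_state_dict(init_opt)
```

The warmup steps still pay the compilation and allocator costs, but none of their weight or momentum updates survive into the timed run.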
+ if effective_warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(effective_warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(effective_train_batch_tokens, args.train_seq_len, grad_accum_steps) + maybe_mark_compile_step_begin(args.enable_compile or args.compile_structured_mlp) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + for opt in optimizers: + opt.step() + zero_grad_all() + if ( + effective_warmup_steps <= 20 + or (warmup_step + 1) % 10 == 0 + or warmup_step + 1 == effective_warmup_steps + ): + log0(f"warmup_step:{warmup_step + 1}/{effective_warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + if distributed: + model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + # ----------------------------- + # MAIN TRAINING LOOP + # ----------------------------- + + training_time_ms = 0.0 + stop_after_step: int | None = None + torch.cuda.synchronize() + t0 = time.perf_counter() + + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + 
grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + log0( + f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" + ) + break + + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + zero_grad_all() + train_loss = torch.zeros((), device=device) + train_ce_loss = torch.zeros((), device=device) + train_rate_proxy = torch.zeros((), device=device) + train_scale_proxy = torch.zeros((), device=device) + train_est_bits = torch.zeros((), device=device) + rate_mul = rate_schedule(step, elapsed_ms) + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(effective_train_batch_tokens, args.train_seq_len, grad_accum_steps) + maybe_mark_compile_step_begin(args.enable_compile or args.compile_structured_mlp) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + ce_loss = model(x, y) + total_loss = ce_loss + if rate_mul > 0.0: + rate_proxy, scale_proxy, est_bits = base_model.rate_regularizer() + total_loss = total_loss + rate_mul * (args.rate_lambda * rate_proxy + args.scale_lambda * scale_proxy) + train_rate_proxy += rate_proxy.detach() + train_scale_proxy += scale_proxy.detach() + train_est_bits += est_bits.detach() + train_ce_loss += ce_loss.detach() + train_loss += total_loss.detach() + (total_loss * grad_scale).backward() + train_loss /= grad_accum_steps + train_ce_loss /= grad_accum_steps + train_rate_proxy /= grad_accum_steps + train_scale_proxy /= grad_accum_steps + 
train_est_bits /= grad_accum_steps + + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + for opt in optimizers: + opt.step() + zero_grad_all() + + step += 1 + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + should_log_train = ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + log0( + f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} ce_loss:{train_ce_loss.item():.4f} " + f"rate_mul:{rate_mul:.3f} rate_proxy_bits:{train_rate_proxy.item():.4f} " + f"scale_proxy:{train_scale_proxy.item():.6f} est_model_bits:{train_est_bits.item():.0f} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" + ) + + # Needed to sync whether we've reached the wallclock cap. 
+ reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + + log0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + ) + + # ----------------------------- + # SERIALIZATION + ROUNDTRIP VALIDATION + # ----------------------------- + # Save the raw state (useful for debugging/loading in PyTorch directly), then always produce + # the compressed int8+zlib artifact and validate the round-tripped weights. + + if master_process: + torch.save(base_model.state_dict(), "final_model.pt") + model_bytes = os.path.getsize("final_model.pt") + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model: {model_bytes} bytes") + log0(f"Code size: {code_bytes} bytes") + log0(f"Total submission size: {model_bytes + code_bytes} bytes") + + quant_obj, quant_stats = quantize_state_dict(base_model.state_dict(), args.quant_scheme) + quant_buf = io.BytesIO() + torch.save(quant_obj, quant_buf) + quant_raw = quant_buf.getvalue() + quant_blob = compress_blob(quant_raw, args.compressor, args.compressor_level) + quant_raw_bytes = len(quant_raw) + quant_filename = f"final_model.{args.quant_scheme}.{args.compressor}.ptc" + if master_process: + with open(quant_filename, "wb") as f: + f.write(quant_blob) + quant_file_bytes = os.path.getsize(quant_filename) + code_bytes = len(code.encode("utf-8")) + ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1) + log0( + f"Serialized model {args.quant_scheme}+{args.compressor}: {quant_file_bytes} bytes " + f"(payload:{quant_stats['int8_payload_bytes']} raw_torch:{quant_raw_bytes} 
payload_ratio:{ratio:.2f}x)" + ) + log0(f"Total submission size {args.quant_scheme}+{args.compressor}: {quant_file_bytes + code_bytes} bytes") + + if distributed: + dist.barrier() + with open(quant_filename, "rb") as f: + quant_blob_disk = f.read() + quant_state = torch.load(io.BytesIO(decompress_blob(quant_blob_disk, args.compressor)), map_location="cpu") + base_model.load_state_dict(dequantize_state_dict(quant_state), strict=True) + torch.cuda.synchronize() + t_qeval = time.perf_counter() + q_val_loss, q_val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"final_{args.quant_scheme}_{args.compressor}_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" + ) + log0(f"final_{args.quant_scheme}_{args.compressor}_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + + if distributed: + dist.destroy_process_group() + + +if __name__ == "__main__": + main() + +==================================================================================================== +Running Python 3.10.12 (main, Mar 3 2026, 11:56:32) [GCC 11.4.0] +Running PyTorch 2.11.0+cu130 +Tue Mar 31 17:05:35 2026 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 590.48.01 Driver Version: 590.48.01 CUDA Version: 13.1 | ++-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. 
| +|=========================================+========================+======================| +| 0 NVIDIA GeForce RTX 3060 Off | 00000000:01:00.0 On | N/A | +| 0% 50C P0 20W / 170W | 224MiB / 12288MiB | 36% Default | +| | | N/A | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| 0 N/A N/A 985 G /usr/lib/xorg/Xorg 99MiB | +| 0 N/A N/A 9707 C python3 104MiB | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=/home/hz/Desktop/parameter-golf-main/data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:1 +val_loader:shards pattern=/home/hz/Desktop/parameter-golf-main/data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:1048576 +val_token_limit:1048576 +model_params:8507464 +world_size:1 grad_accum_steps:8 train_microbatch_tokens:2048 effective_train_batch_tokens:16384 +enable_compile:False +sdp_backends:cudnn=False flash=True mem_efficient=False math=False +attention_mode:gqa num_heads:8 num_kv_heads:4 +quant_scheme:int5_odd compressor:zlib level:9 structured_mlp:True structured_linear_type:btt_mlp btt_rank:32 btt_init:mup compile_structured_mlp:False structured_chunk_ranks:64 structured_compile_mode:reduce-overhead +rate_lambda:2e-05 scale_lambda:0.0002 rate_warmup_start_fraction:0.1 rate_full_fraction:0.4 rate_target_tensors:mlp_only rate_schedule_mode:auto max_expected_steps:0 +tie_embeddings:True embed_lr:0.03 head_lr:0.0 matrix_lr:0.02 scalar_lr:0.02 +train_batch_tokens_target:16384 
effective_train_batch_tokens:16384 train_seq_len:256 iterations:40 warmup_steps:1 max_wallclock_seconds:600.000
+lr_scale_method:none lr_reference_batch_tokens:16384 batch_multiplier:1.0000 warmup_scale_method:none warmup_reference_batch_tokens:16384
+seed:1337
+dry_run_init_only: exiting after init/logging
+"""
+The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good starting points for new participants, not SOTA configs. We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder.
+
+Hard stop: to keep these scripts readable for newcomers, let's make sure `train_gpt.py` and `train_gpt_mlx.py` never grow longer than 1500 lines.
+"""
+
+from __future__ import annotations
+
+import copy
+import glob
+import io
+import math
+import os
+import random
+import subprocess
+import sys
+import time
+import uuid
+import zlib
+from pathlib import Path
+
+import numpy as np
+import sentencepiece as spm
+import torch
+import torch.distributed as dist
+import torch.nn.functional as F
+from torch import Tensor, nn
+from torch.nn.parallel import DistributedDataParallel as DDP
+
+try:
+    import zstandard as zstd
+
+    _HAS_ZSTD = True
+except ImportError:
+    zstd = None
+    _HAS_ZSTD = False
+
+REPO_ROOT = Path(__file__).resolve().parents[3]
+
+# -----------------------------
+# HYPERPARAMETERS
+# -----------------------------
+# Default Simple Baseline run:
+# - 9 transformer blocks at width 512
+# - 8 attention heads with 4 KV heads (GQA) and 2x MLP expansion
+# - vocab size 1024, sequence length 1024, tied embeddings
+# - 524,288 train tokens per step for 20,000 iterations with a ~10 minute cap
+
+class Hyperparameters:
+    # Data paths are shard globs produced by the existing preprocessing pipeline.
+ data_path = os.environ.get("DATA_PATH", str(REPO_ROOT / "data/datasets/fineweb10B_sp1024")) + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", str(REPO_ROOT / "data/tokenizers/fineweb_1024_bpe.model")) + run_id = os.environ.get("RUN_ID", "train") + seed = int(os.environ.get("SEED", 1337)) + + # Validation cadence and batch size. Validation always uses the full fineweb_val split. + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 32_768)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 10)) + val_token_limit = int(os.environ.get("VAL_TOKEN_LIMIT", 1_048_576)) + + # Training length. + iterations = int(os.environ.get("ITERATIONS", 40)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 1200)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 1)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 16_384)) + train_microbatch_tokens = int(os.environ.get("TRAIN_MICROBATCH_TOKENS", 0)) + grad_accum_steps = int(os.environ.get("GRAD_ACCUM_STEPS", 0)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 256)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + enable_compile = bool(int(os.environ.get("ENABLE_COMPILE", "0"))) + + # Model shape. + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers = int(os.environ.get("NUM_LAYERS", 9)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = int(os.environ.get("MLP_MULT", 2)) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + + # Optimizer hyperparameters. 
+ embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.03)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.02)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.02)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.0)) + lr_scale_method = os.environ.get("LR_SCALE_METHOD", "none") + lr_reference_batch_tokens = int(os.environ.get("LR_REFERENCE_BATCH_TOKENS", 16_384)) + warmup_scale_method = os.environ.get("WARMUP_SCALE_METHOD", "none") + warmup_reference_batch_tokens = int(os.environ.get("WARMUP_REFERENCE_BATCH_TOKENS", 16_384)) + simulated_world_size = int(os.environ.get("SIMULATED_WORLD_SIZE", 0)) + simulated_rank = int(os.environ.get("SIMULATED_RANK", 0)) + dry_run_init_only = bool(int(os.environ.get("DRY_RUN_INIT_ONLY", "0"))) + debug_data_sharding_steps = int(os.environ.get("DEBUG_DATA_SHARDING_STEPS", 0)) + + # Export / compression. + quant_scheme = os.environ.get("QUANT_SCHEME", "int5_odd") + compressor = os.environ.get("COMPRESSOR", "zlib") + compressor_level = int(os.environ.get("COMPRESSOR_LEVEL", 9)) + + # Entropy-aware rate regularization. 
+ rate_lambda = float(os.environ.get("RATE_LAMBDA", 0.00002)) + scale_lambda = float(os.environ.get("SCALE_LAMBDA", 0.0002)) + rate_temperature = float(os.environ.get("RATE_TEMPERATURE", 6.0)) + rate_warmup_start_fraction = float(os.environ.get("RATE_WARMUP_START_FRACTION", 0.1)) + rate_full_fraction = float(os.environ.get("RATE_FULL_FRACTION", 0.4)) + rate_target_tensors = os.environ.get("RATE_TARGET_TENSORS", "mlp_only") + rate_schedule_mode = os.environ.get("RATE_SCHEDULE_MODE", "auto") + max_expected_steps = int(os.environ.get("MAX_EXPECTED_STEPS", 0)) + + # Structured MLP. + structured_mlp = bool(int(os.environ.get("STRUCTURED_MLP", "1"))) + structured_linear_type = os.environ.get("STRUCTURED_LINEAR_TYPE", "btt_mlp") + btt_rank = int(os.environ.get("BTT_RANK", 32)) + btt_in_shape = os.environ.get("BTT_IN_SHAPE", "") + btt_out_shape = os.environ.get("BTT_OUT_SHAPE", "") + btt_init = os.environ.get("BTT_INIT", "mup") + btt_init_gain = float(os.environ.get("BTT_INIT_GAIN", 1.0)) + compile_structured_mlp = bool(int(os.environ.get("COMPILE_STRUCTURED_MLP", "1"))) + structured_chunk_ranks = int(os.environ.get("STRUCTURED_CHUNK_RANKS", 64)) + structured_compile_mode = os.environ.get("STRUCTURED_COMPILE_MODE", "reduce-overhead") + +# ----------------------------- +# MUON OPTIMIZER +# ----------------------------- +# +# As borrowed from modded-nanogpt +# Background on Muon: https://kellerjordan.github.io/posts/muon/ + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + # Orthogonalize a 2D update matrix with a fast Newton-Schulz iteration. + # Muon uses this to normalize matrix-shaped gradients before applying them. 
+ a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X + + +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True): + super().__init__( + params, + dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov), + ) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + + for group in self.param_groups: + params = group["params"] + if not params: + continue + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + + total_params = sum(int(p.numel()) for p in params) + updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) + + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: + g = g.add(buf, alpha=momentum) + g = zeropower_via_newtonschulz5(g, steps=backend_steps) + # Scale correction from Muon reference implementations. 
+                    g *= max(1, g.size(0) / g.size(1)) ** 0.5
+                    updates_flat[curr : curr + p.numel()] = g.reshape(-1)
+                curr += p.numel()
+
+            if distributed:
+                dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM)
+
+            curr = 0
+            for p in params:
+                g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype)
+                p.add_(g, alpha=-lr)
+                curr += p.numel()
+
+        return loss
+
+
+# -----------------------------
+# TOKENIZER-AGNOSTIC EVALUATION SETUP
+# -----------------------------
+#
+# It's common for small models to have a large fraction of their parameters in embeddings, since the 2 * d_model * d_vocab embedding parameters can be gigantic.
+# Instead of locking the tokenizer, we let you bring your own and compute our validation metric from the average compression of the validation set.
+# We report BPB (bits-per-byte) instead of raw validation loss, so we need lookup tables that count the number of bytes each token contributes.
+# Note: Submissions that edit the tokenizer will be examined more carefully, since screwing this up might unjustly improve your score.
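As a quick illustration of the metric described above (a minimal sketch, not part of the training path; the helper name `bits_per_byte` is ours), `val_bpb` is just the mean token cross-entropy converted from nats to bits and rescaled by the validation set's tokens-per-byte ratio:

```python
import math

def bits_per_byte(ce_loss_nats: float, n_tokens: int, n_bytes: int) -> float:
    # Mean token cross-entropy in nats -> bits per token, then rescale by tokens/byte.
    bits_per_token = ce_loss_nats / math.log(2.0)
    return bits_per_token * (n_tokens / n_bytes)

# A loss of 2.0 nats/token at 4 bytes/token gives roughly 0.72 bits/byte.
print(bits_per_byte(2.0, 1000, 4000))
```

This is exactly the computation `eval_val` performs below, except that the token and byte counts are accumulated on-device and all-reduced across ranks.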
+ +def build_sentencepiece_luts( + sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device +) -> tuple[Tensor, Tensor, Tensor]: + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("▁"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) + + +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*. 
+ tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] + + +def eval_val( + args: Hyperparameters, + model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + grad_accum_steps: int, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, +) -> tuple[float, float]: + # Validation computes two metrics: + # - val_loss: token cross-entropy (natural log) + # - val_bpb: tokenizer-agnostic compression metric used by the challenge + local_batch_tokens = args.val_batch_size // world_size + if local_batch_tokens < args.train_seq_len: + raise ValueError( + "VAL_BATCH_SIZE must provide at least one sequence per rank; " + f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " + f"TRAIN_SEQ_LEN={args.train_seq_len}" + ) + local_batch_seqs = local_batch_tokens // args.train_seq_len + total_seqs = (val_tokens.numel() - 1) // args.train_seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + + model.eval() + with torch.inference_mode(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * args.train_seq_len + raw_end = batch_seq_end * args.train_seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, args.train_seq_len) + y = local[1:].reshape(-1, args.train_seq_len) + maybe_mark_compile_step_begin(args.enable_compile or 
args.compile_structured_mlp)
+            with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True):
+                batch_loss = model(x, y).detach()
+            batch_token_count = float(y.numel())
+            val_loss_sum += batch_loss.to(torch.float64) * batch_token_count
+            val_token_count += batch_token_count
+            prev_ids = x.reshape(-1)
+            tgt_ids = y.reshape(-1)
+            token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16)
+            token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16)
+            val_byte_count += token_bytes.to(torch.float64).sum()
+
+    if dist.is_available() and dist.is_initialized():
+        dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM)
+        dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM)
+        dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM)
+
+    val_loss = val_loss_sum / val_token_count
+    bits_per_token = val_loss.item() / math.log(2.0)
+    tokens_per_byte = val_token_count.item() / val_byte_count.item()
+    model.train()
+    return float(val_loss.item()), float(bits_per_token * tokens_per_byte)
+
+# -----------------------------
+# POST-TRAINING QUANTIZATION
+# -----------------------------
+#
+# It's silly to export our model, which is trained in bf16 and fp32, at that same precision.
+# Instead, we get approximately the same model (with a small accuracy hit) by quantizing it to int8 and zlib-compressing the result.
+# We can then decompress the model and run it at higher precision for evaluation, after coming in under the size limit.
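Before the real export code below, here is a minimal sketch of the per-row idea it builds on (illustrative only; `per_row_int8_roundtrip` is a made-up helper that uses plain abs-max scaling instead of the percentile clipping the script applies):

```python
import torch

def per_row_int8_roundtrip(w: torch.Tensor) -> torch.Tensor:
    # One scale per output row, as in the per-row scheme (simplified: plain
    # abs-max instead of the INT8_CLIP_PERCENTILE clipping used by the script).
    scale = (w.abs().amax(dim=1) / 127.0).clamp_min(1e-8)
    q = torch.clamp(torch.round(w / scale[:, None]), -127, 127).to(torch.int8)
    return q.float() * scale[:, None]

torch.manual_seed(0)
w = torch.randn(4, 8)
max_err = (w - per_row_int8_roundtrip(w)).abs().max().item()
# Round-trip error is bounded by half a quantization step per row.
assert max_err <= (w.abs().amax(dim=1) / 254.0).max().item() + 1e-6
```

The production path in `quantize_float_tensor` additionally clips each row at a high quantile before rounding, trading a sliver of dynamic range for smaller scales and less rounding error on typical weights.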
+ +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights", + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 +INT5_ODD_LEVELS = (-2, -1, 0, 1, 2) + +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) + + +def maybe_mark_compile_step_begin(enabled: bool) -> None: + if not enabled: + return + mark_step = getattr(torch.compiler, "cudagraph_mark_step_begin", None) + if mark_step is not None: + mark_step() + + +def scale_from_batch_multiplier(value: float, multiplier: float, method: str) -> float: + if method == "none": + return value + if method == "sqrt": + return value * math.sqrt(multiplier) + if method == "linear": + return value * multiplier + raise ValueError(f"Unsupported scaling method: {method}") + + +def scaled_warmup_steps(base_steps: int, multiplier: float, method: str) -> int: + if base_steps <= 0: + return 0 + scaled = scale_from_batch_multiplier(float(base_steps), multiplier, method) + return max(int(math.ceil(scaled)), 1) + + +def parse_factor_shape(spec: str, dim: int) -> tuple[int, int]: + if spec: + vals = tuple(int(v.strip()) for v in spec.split(",") if v.strip()) + if len(vals) != 2 or vals[0] * vals[1] != dim: + raise ValueError(f"Invalid factor shape '{spec}' for dim={dim}") + return vals + for a in range(int(math.isqrt(dim)), 0, -1): + if dim % a == 0: + return a, dim // a + return 1, dim + +def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> 
Tensor: + if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): + return t.float().contiguous() + if t.dtype in {torch.float32, torch.bfloat16}: + passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") + return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() + return t + +def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + # Matrices get one scale per row, which usually tracks output-channel + # ranges much better than a single tensor-wide scale. + clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) + q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() + return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + + # Vectors / scalars use a simpler per-tensor scale. 
+ clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() + return q, scale + + +def pack_base5_codes(q: Tensor) -> tuple[bytes, int]: + vals = q.detach().to("cpu", dtype=torch.uint8).reshape(-1).numpy() + n_vals = int(vals.size) + pad = (3 - (n_vals % 3)) % 3 + if pad: + vals = np.pad(vals, (0, pad)) + groups = vals.reshape(-1, 3).astype(np.uint8) + packed = groups[:, 0] + 5 * groups[:, 1] + 25 * groups[:, 2] + return packed.tobytes(), n_vals + + +def unpack_base5_codes(data: bytes, n_vals: int) -> Tensor: + packed = np.frombuffer(data, dtype=np.uint8).astype(np.int16) + out = np.zeros((packed.size, 3), dtype=np.uint8) + for i in range(3): + out[:, i] = packed % 5 + packed //= 5 + flat = out.reshape(-1)[:n_vals] + return torch.from_numpy(flat.astype(np.uint8, copy=False)) + + +def quantize_float_tensor_int5_odd(t: Tensor) -> tuple[Tensor, Tensor]: + if t.ndim != 2: + raise ValueError("int5_odd quantization only supports 2D tensors") + t32 = t.float() + clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clip_abs = clip_abs.clamp_min(1e-6) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 2.0).clamp_min(1.0 / 2.0e6) + q = torch.clamp(torch.round(clipped / scale[:, None]), -2, 2).to(torch.int8).contiguous() + return (q + 2).to(torch.uint8).contiguous(), scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + +def quantize_state_dict_int8(state_dict: dict[str, Tensor]): + # Single supported clean-script export format: + # - per-row int8 for 2D float tensors + # - per-tensor int8 for other float tensors + # - exact passthrough for non-floats + # - passthrough for small float tensors, 
stored as fp16 to save bytes + quantized: dict[str, Tensor] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + qmeta: dict[str, dict[str, object]] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + + # Small float tensors are cheap enough to keep directly. We still downcast + # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size. + if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + + stats["num_float_tensors"] += 1 + q, s = quantize_float_tensor(t) + if s.ndim > 0: + qmeta[name] = {"scheme": "per_row", "axis": 0} + quantized[name] = q + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) + + obj: dict[str, object] = { + "__quant_format__": "int8_clean_per_row_v1", + "quantized": quantized, + "scales": scales, + "dtypes": dtypes, + "passthrough": passthrough, + } + if qmeta: + obj["qmeta"] = qmeta + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats + + +def quantize_state_dict_int5_odd(state_dict: dict[str, Tensor]): + packed: dict[str, bytes] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + shapes: dict[str, tuple[int, ...]] = {} + orig_shapes: dict[str, 
tuple[int, ...]] = {} + n_codes: dict[str, int] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + + quantize_this = t.ndim >= 2 and (t.numel() > INT8_KEEP_FLOAT_MAX_NUMEL or ".mlp." in name) + if not quantize_this: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + + stats["num_float_tensors"] += 1 + t2 = t.reshape(t.shape[0], -1) if t.ndim > 2 else t + q, s = quantize_float_tensor_int5_odd(t2) + packed_bytes, count = pack_base5_codes(q) + packed[name] = packed_bytes + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + shapes[name] = tuple(int(v) for v in t2.shape) + orig_shapes[name] = tuple(int(v) for v in t.shape) + n_codes[name] = count + stats["int8_payload_bytes"] += len(packed_bytes) + tensor_nbytes(s) + + obj: dict[str, object] = { + "__quant_format__": "int5_odd_clean_per_row_v1", + "packed": packed, + "scales": scales, + "dtypes": dtypes, + "shapes": shapes, + "orig_shapes": orig_shapes, + "n_codes": n_codes, + "passthrough": passthrough, + } + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats + +def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + qmeta = obj.get("qmeta", {}) + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, q in 
obj["quantized"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + s = obj["scales"][name] + if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: + s = s.to(dtype=torch.float32) + # Broadcast the saved row scale back across trailing dimensions. + out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() + else: + scale = float(s.item()) + out[name] = (q.float() * scale).to(dtype=dtype).contiguous() + for name, t in obj["passthrough"].items(): + # Restore small tensors, undoing the temporary fp16 storage cast if needed. + out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out + + +def dequantize_state_dict_int5_odd(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, packed in obj["packed"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + shape = tuple(int(v) for v in obj["shapes"][name]) + orig_shape = tuple(int(v) for v in obj.get("orig_shapes", {}).get(name, shape)) + q = unpack_base5_codes(packed, int(obj["n_codes"][name])).to(torch.float32).reshape(shape) - 2.0 + s = obj["scales"][name].to(dtype=torch.float32) + deq = (q * s.view(shape[0], *([1] * (len(shape) - 1)))).to(dtype=dtype) + out[name] = deq.reshape(orig_shape).contiguous() + for name, t in obj["passthrough"].items(): + out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out + + +def quantize_state_dict(state_dict: dict[str, Tensor], quant_scheme: str): + if quant_scheme == "int8": + return quantize_state_dict_int8(state_dict) + if quant_scheme == "int5_odd": + return 
quantize_state_dict_int5_odd(state_dict) + raise ValueError(f"Unsupported QUANT_SCHEME={quant_scheme}") + + +def dequantize_state_dict(obj: dict[str, object]) -> dict[str, Tensor]: + quant_format = obj.get("__quant_format__") + if quant_format == "int8_clean_per_row_v1": + return dequantize_state_dict_int8(obj) + if quant_format == "int5_odd_clean_per_row_v1": + return dequantize_state_dict_int5_odd(obj) + raise ValueError(f"Unsupported quant format {quant_format}") + + +def compress_blob(data: bytes, compressor: str, level: int) -> bytes: + if compressor == "zlib": + return zlib.compress(data, level=max(0, min(level, 9))) + if compressor == "zstd": + if not _HAS_ZSTD: + raise RuntimeError("COMPRESSOR=zstd requires zstandard to be installed") + return zstd.ZstdCompressor(level=level).compress(data) + raise ValueError(f"Unsupported COMPRESSOR={compressor}") + + +def decompress_blob(data: bytes, compressor: str) -> bytes: + if compressor == "zlib": + return zlib.decompress(data) + if compressor == "zstd": + if not _HAS_ZSTD: + raise RuntimeError("COMPRESSOR=zstd requires zstandard to be installed") + return zstd.ZstdDecompressor().decompress(data) + raise ValueError(f"Unsupported COMPRESSOR={compressor}") + + +def should_rate_regularize_tensor(name: str, p: Tensor, target: str) -> bool: + if not p.requires_grad or not p.is_floating_point() or p.ndim < 2: + return False + if any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS): + return False + if target == "mlp_only": + return ".mlp." in name + if target == "all_matrices": + return p.numel() > INT8_KEEP_FLOAT_MAX_NUMEL or ".mlp." 
in name + raise ValueError(f"Unsupported RATE_TARGET_TENSORS={target}") + + + def estimate_tensor_rate_proxy(t: Tensor, temperature: float) -> tuple[Tensor, Tensor, Tensor]: + t2 = t.float() + if t2.ndim > 2: + t2 = t2.reshape(t2.shape[0], -1) + elif t2.ndim == 1: + t2 = t2.reshape(1, -1) + row_scale = t2.abs().mean(dim=1, keepdim=True).clamp_min(1e-6) + normalized = torch.clamp(t2 / row_scale, -2.5, 2.5) + centers = t2.new_tensor(INT5_ODD_LEVELS) + logits = -temperature * (normalized.unsqueeze(-1) - centers) ** 2 + probs = torch.softmax(logits, dim=-1) + hist = probs.mean(dim=(0, 1)) + entropy_nats = -(hist * (hist + 1e-8).log()).sum() + entropy_bits = entropy_nats / math.log(2.0) + scale_proxy = row_scale.mean() + est_bits = entropy_bits * float(t2.numel()) + return entropy_bits, scale_proxy, est_bits + + + # ----------------------------- + # DATA LOADING + # ----------------------------- + + def load_data_shard(file: Path) -> Tensor: + # Assumes the standard fineweb .bin shard layout: 256 little-endian int32 header + # words with header[2] = token count, followed by uint16 token ids. + header_bytes = 256 * np.dtype("<i4").itemsize + with file.open("rb") as f: + header = np.frombuffer(f.read(header_bytes), dtype="<i4") + num_tokens = int(header[2]) + tokens = np.frombuffer(f.read(2 * num_tokens), dtype="<u2") + return torch.from_numpy(tokens.astype(np.int32)) + + + class TokenStream: + # Cycles through all shard files matching `pattern` in sorted order. + def __init__(self, pattern: str): + pattern_path = Path(pattern) + self.files = sorted(pattern_path.parent.glob(pattern_path.name)) + if not self.files: + raise FileNotFoundError(f"no token shards match pattern {pattern}") + self.file_idx = 0 + self.tokens = load_data_shard(self.files[self.file_idx]) + self.pos = 0 + + def _advance_file(self) -> None: + self.file_idx = (self.file_idx + 1) % len(self.files) + self.tokens = load_data_shard(self.files[self.file_idx]) + self.pos = 0 + + def take(self, n: int) -> Tensor: + chunks: list[Tensor] = [] + remaining = n + while remaining > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: + self._advance_file() + continue + k = min(remaining, avail) + chunks.append(self.tokens[self.pos : self.pos + k]) + self.pos += k + remaining -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) + + + class DistributedTokenLoader: + # Each call consumes a contiguous chunk from the shared token stream, then slices out + # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting.
+ def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank = rank + self.world_size = world_size + self.device = device + self.stream = TokenStream(pattern) + + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + chunk = self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start : start + per_rank_span].to(dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + + +def debug_log_data_sharding( + pattern: str, + world_size: int, + global_tokens: int, + seq_len: int, + grad_accum_steps: int, + steps: int, + log0, +) -> None: + stream = TokenStream(pattern) + local_tokens = global_tokens // (world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + global_cursor = 0 + log0( + f"debug_sharding:world_size:{world_size} global_tokens:{global_tokens} " + f"grad_accum_steps:{grad_accum_steps} local_tokens_per_rank:{local_tokens} per_rank_span:{per_rank_span}" + ) + for step_idx in range(steps): + chunk = stream.take(per_rank_span * world_size) + chunk_start = global_cursor + chunk_end = global_cursor + chunk.numel() + log0(f"debug_sharding_step:{step_idx} shared_chunk_token_range:[{chunk_start},{chunk_end})") + for shard_rank in range(world_size): + start = shard_rank * per_rank_span + end = start + per_rank_span + local = chunk[start:end] + preview = ",".join(str(int(t)) for t in local[: min(6, local.numel())].tolist()) + log0( + f"debug_shard rank:{shard_rank} token_range:[{chunk_start + start},{chunk_start + end}) " + f"preview:{preview}" + ) + global_cursor = chunk_end + +# ----------------------------- +# TRANSFORMER MODULES +# ----------------------------- + +class RMSNorm(nn.Module): + def 
__init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + + +class CastedLinear(nn.Linear): + # Keep weights in fp32 for optimizer/state quality, cast at matmul time for bf16 compute. + def forward(self, x: Tensor) -> Tensor: + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, self.weight.to(x.dtype), bias) + + +class StructuredLinear(nn.Module): + # Lightweight TT-style structured matrix used only in the MLP stack. + def __init__( + self, + in_features: int, + out_features: int, + bias: bool, + linear_type: str, + btt_rank: int, + in_shape_spec: str, + out_shape_spec: str, + btt_init: str, + btt_init_gain: float, + chunk_ranks: int, + ): + super().__init__() + self.in_features = in_features + self.out_features = out_features + self.linear_type = linear_type + self.rank = max(int(btt_rank), 1) + self.chunk_ranks = max(int(chunk_ranks), 1) + self.btt_init = btt_init + self.btt_init_gain = btt_init_gain + if linear_type == "dense": + self.impl = CastedLinear(in_features, out_features, bias=bias) + self.register_parameter("core1", None) + self.register_parameter("core2", None) + return + if linear_type != "btt_mlp": + raise ValueError(f"Unsupported STRUCTURED_LINEAR_TYPE={linear_type}") + self.impl = None + self.in_shape = parse_factor_shape(in_shape_spec, in_features) + self.out_shape = parse_factor_shape(out_shape_spec, out_features) + n1, n2 = self.in_shape + m1, m2 = self.out_shape + self.core1 = nn.Parameter(torch.empty(n1, m1, self.rank)) + self.core2 = nn.Parameter(torch.empty(self.rank, n2, m2)) + if bias: + self.bias = nn.Parameter(torch.zeros(out_features)) + else: + self.register_parameter("bias", None) + self.reset_parameters() + + def reset_parameters(self) -> None: + if self.impl is not None: + self.impl.reset_parameters() + return + n1, n2 = self.in_shape + if self.btt_init == "mup": + # muP-inspired 
scaling: keep the per-rank branch update stable while + # preserving O(1) output variance under the sum over TT ranks. + core1_std = self.btt_init_gain / math.sqrt(max(n1, 1)) + core2_std = self.btt_init_gain / math.sqrt(max(self.rank * n2, 1)) + elif self.btt_init == "xavier": + std = self.btt_init_gain / math.sqrt(max(self.in_features, 1)) + core1_std = std / math.sqrt(self.rank) + core2_std = std + else: + raise ValueError(f"Unsupported BTT_INIT={self.btt_init}") + nn.init.normal_(self.core1, mean=0.0, std=core1_std) + nn.init.normal_(self.core2, mean=0.0, std=core2_std) + if self.bias is not None: + nn.init.zeros_(self.bias) + + def forward(self, x: Tensor) -> Tensor: + if self.impl is not None: + return self.impl(x) + x_dtype = x.dtype + orig_shape = x.shape[:-1] + x2 = x.reshape(-1, self.in_shape[0], self.in_shape[1]) + x2_t = x2.transpose(1, 2).contiguous() + core1 = self.core1.to(dtype=x_dtype).permute(2, 0, 1).contiguous() + core2 = self.core2.to(dtype=x_dtype).contiguous() + y = x2.new_zeros((x2.shape[0], self.out_shape[0], self.out_shape[1])) + for start in range(0, self.rank, self.chunk_ranks): + end = min(start + self.chunk_ranks, self.rank) + for offset, core2_r in enumerate(core2[start:end], start=start): + left = torch.matmul(x2_t, core1[offset]) + y = y + torch.matmul(left.transpose(1, 2), core2_r) + y = y.reshape(*orig_shape, self.out_features) + if self.bias is not None: + y = y + self.bias.to(dtype=y.dtype) + return y + + +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + # Keep small/control parameters in fp32 even when the model body runs in bf16. + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() + + +class Rotary(nn.Module): + # Caches cos/sin tables per sequence length on the current device. 
+ def __init__(self, dim: int, base: float = 10000.0): + super().__init__() + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached != seq_len + or self._cos_cached.device != device + ): + t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) + freqs = torch.outer(t, self.inv_freq.to(device)) + self._cos_cached = freqs.cos()[None, None, :, :] + self._sin_cached = freqs.sin()[None, None, :, :] + self._seq_len_cached = seq_len + cos = self._cos_cached.to(dtype=dtype) + sin = self._sin_cached.to(dtype=dtype) + if not torch.is_inference_mode_enabled(): + cos = cos.clone() + sin = sin.clone() + return cos, sin + + +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + + +class CausalSelfAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + kv_dim = self.num_kv_heads * self.head_dim + self.c_q = CastedLinear(dim, dim, bias=False) + self.c_k = CastedLinear(dim, kv_dim, bias=False) + self.c_v = CastedLinear(dim, kv_dim, bias=False) + self.proj 
= CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rotary = Rotary(self.head_dim, base=rope_base) + + def forward(self, x: Tensor) -> Tensor: + bsz, seqlen, dim = x.shape + q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) + k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin) + k = apply_rotary_emb(k, cos, sin) + q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] + y = F.scaled_dot_product_attention( + q, + k, + v, + attn_mask=None, + is_causal=True, + enable_gqa=(self.num_kv_heads != self.num_heads), + ) + y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) + return self.proj(y) + + +class MLP(nn.Module): + # relu^2 MLP from the original modded-nanogpt setup + def __init__( + self, + dim: int, + mlp_mult: int, + structured_mlp: bool, + structured_linear_type: str, + btt_rank: int, + btt_in_shape: str, + btt_out_shape: str, + btt_init: str, + btt_init_gain: float, + compile_structured_mlp: bool, + structured_chunk_ranks: int, + structured_compile_mode: str, + ): + super().__init__() + hidden = mlp_mult * dim + if structured_mlp: + self.fc = StructuredLinear( + dim, + hidden, + bias=False, + linear_type=structured_linear_type, + btt_rank=btt_rank, + in_shape_spec=btt_in_shape, + out_shape_spec=btt_out_shape, + btt_init=btt_init, + btt_init_gain=btt_init_gain, + chunk_ranks=structured_chunk_ranks, + ) + self.proj = StructuredLinear( + hidden, + dim, + bias=False, + linear_type=structured_linear_type, + btt_rank=btt_rank, + in_shape_spec=btt_out_shape, + out_shape_spec=btt_in_shape, + btt_init=btt_init, + 
btt_init_gain=btt_init_gain, + chunk_ranks=structured_chunk_ranks, + ) + if compile_structured_mlp: + compile_kwargs = dict(fullgraph=False, dynamic=False) + if structured_compile_mode == "reduce-overhead": + compile_options = dict(torch._inductor.list_mode_options("reduce-overhead")) + compile_options["triton.cudagraphs"] = False + compile_kwargs["options"] = compile_options + else: + compile_kwargs["mode"] = structured_compile_mode + self.fc = torch.compile( + self.fc, + **compile_kwargs, + ) + self.proj = torch.compile( + self.proj, + **compile_kwargs, + ) + else: + self.fc = CastedLinear(dim, hidden, bias=False) + self.proj = CastedLinear(hidden, dim, bias=False) + self.proj._zero_init = True + + def forward(self, x: Tensor) -> Tensor: + x = torch.relu(self.fc(x)) + return self.proj(x.square()) + + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + rope_base: float, + qk_gain_init: float, + structured_mlp: bool, + structured_linear_type: str, + btt_rank: int, + btt_in_shape: str, + btt_out_shape: str, + btt_init: str, + btt_init_gain: float, + compile_structured_mlp: bool, + structured_chunk_ranks: int, + structured_compile_mode: str, + ): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) + self.mlp = MLP( + dim, + mlp_mult, + structured_mlp, + structured_linear_type, + btt_rank, + btt_in_shape, + btt_out_shape, + btt_init, + btt_init_gain, + compile_structured_mlp, + structured_chunk_ranks, + structured_compile_mode, + ) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + + def forward(self, x: Tensor, x0: Tensor) -> Tensor: + mix = self.resid_mix.to(dtype=x.dtype) + x = mix[0][None, None, :] * x + 
mix[1][None, None, :] * x0 + attn_out = self.attn(self.attn_norm(x)) + x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out + x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) + return x + + +class GPT(nn.Module): + def __init__( + self, + vocab_size: int, + num_layers: int, + model_dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + tie_embeddings: bool, + tied_embed_init_std: float, + logit_softcap: float, + rope_base: float, + qk_gain_init: float, + structured_mlp: bool, + structured_linear_type: str, + btt_rank: int, + btt_in_shape: str, + btt_out_shape: str, + btt_init: str, + btt_init_gain: float, + compile_structured_mlp: bool, + structured_chunk_ranks: int, + structured_compile_mode: str, + rate_target_tensors: str, + rate_temperature: float, + ): + super().__init__() + if logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") + self.tie_embeddings = tie_embeddings + self.tied_embed_init_std = tied_embed_init_std + self.logit_softcap = logit_softcap + self.rate_target_tensors = rate_target_tensors + self.rate_temperature = rate_temperature + self.tok_emb = nn.Embedding(vocab_size, model_dim) + self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + self.blocks = nn.ModuleList( + [ + Block( + model_dim, + num_heads, + num_kv_heads, + mlp_mult, + rope_base, + qk_gain_init, + structured_mlp, + structured_linear_type, + btt_rank, + btt_in_shape, + btt_out_shape, + btt_init, + btt_init_gain, + compile_structured_mlp, + structured_chunk_ranks, + structured_compile_mode, + ) + for i in range(num_layers) + ] + ) + self.final_norm = RMSNorm() + self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, 
bias=False) + if self.lm_head is not None: + self.lm_head._zero_init = True + self._init_weights() + + def _init_weights(self) -> None: + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + for module in self.modules(): + if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + if isinstance(module, StructuredLinear) and getattr(module, "_zero_init", False): + if module.impl is not None: + nn.init.zeros_(module.impl.weight) + else: + nn.init.zeros_(module.core2) + + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + x = self.tok_emb(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x0 = x + skips: list[Tensor] = [] + + # First half stores skips; second half reuses them in reverse order. + for i in range(self.num_encoder_layers): + x = self.blocks[i](x, x0) + skips.append(x) + for i in range(self.num_decoder_layers): + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + x = self.blocks[self.num_encoder_layers + i](x, x0) + + x = self.final_norm(x).reshape(-1, x.size(-1)) + targets = target_ids.reshape(-1) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(x) + logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + return F.cross_entropy(logits.float(), targets, reduction="mean") + + def rate_regularizer(self) -> tuple[Tensor, Tensor, Tensor]: + total_symbols = 0.0 + weighted_rate = None + weighted_scale = None + estimated_bits = None + for name, param in self.named_parameters(): + if not should_rate_regularize_tensor(name, param, self.rate_target_tensors): + continue + rate_proxy, scale_proxy, tensor_bits = estimate_tensor_rate_proxy(param, self.rate_temperature) + weight = float(param.numel()) + total_symbols += weight + if 
weighted_rate is None: + weighted_rate = rate_proxy * weight + weighted_scale = scale_proxy * weight + estimated_bits = tensor_bits + else: + weighted_rate = weighted_rate + rate_proxy * weight + weighted_scale = weighted_scale + scale_proxy * weight + estimated_bits = estimated_bits + tensor_bits + if weighted_rate is None or weighted_scale is None or estimated_bits is None: + zero = self.tok_emb.weight.new_zeros(()) + return zero, zero, zero + denom = max(total_symbols, 1.0) + return weighted_rate / denom, weighted_scale / denom, estimated_bits + + +# ----------------------------- +# TRAINING +# ----------------------------- + +def main() -> None: + global zeropower_via_newtonschulz5 + + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + if args.enable_compile: + zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) + + # ----------------------------- + # DISTRIBUTED + CUDA SETUP + # ----------------------------- + + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + simulated_world_size = args.simulated_world_size if not distributed and args.simulated_world_size > 0 else 0 + if simulated_world_size > 0: + world_size = simulated_world_size + rank = args.simulated_rank + local_rank = 0 + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if rank < 0 or rank >= world_size: + raise ValueError(f"RANK={rank} must satisfy 0 <= RANK < WORLD_SIZE={world_size}") + if simulated_world_size > 1 and not args.dry_run_init_only: + raise ValueError("SIMULATED_WORLD_SIZE>1 is only supported together with DRY_RUN_INIT_ONLY=1") + if args.train_batch_tokens <= 0: + raise ValueError(f"TRAIN_BATCH_TOKENS must be positive, got {args.train_batch_tokens}") + train_microbatch_tokens = args.train_microbatch_tokens if 
args.train_microbatch_tokens > 0 else max( + args.train_seq_len, args.train_batch_tokens // 8 + ) + if train_microbatch_tokens <= 0: + raise ValueError(f"TRAIN_MICROBATCH_TOKENS must be positive, got {train_microbatch_tokens}") + if train_microbatch_tokens % args.train_seq_len != 0: + raise ValueError( + f"TRAIN_MICROBATCH_TOKENS={train_microbatch_tokens} must be divisible by TRAIN_SEQ_LEN={args.train_seq_len}" + ) + tokens_per_microstep = world_size * train_microbatch_tokens + if args.grad_accum_steps > 0: + grad_accum_steps = args.grad_accum_steps + effective_train_batch_tokens = tokens_per_microstep * grad_accum_steps + else: + if args.train_batch_tokens % tokens_per_microstep != 0: + raise ValueError( + "TRAIN_BATCH_TOKENS must be divisible by WORLD_SIZE * TRAIN_MICROBATCH_TOKENS " + f"unless GRAD_ACCUM_STEPS is set explicitly; got TRAIN_BATCH_TOKENS={args.train_batch_tokens}, " + f"WORLD_SIZE={world_size}, TRAIN_MICROBATCH_TOKENS={train_microbatch_tokens}" + ) + grad_accum_steps = args.train_batch_tokens // tokens_per_microstep + effective_train_batch_tokens = args.train_batch_tokens + if grad_accum_steps <= 0: + raise ValueError(f"GRAD_ACCUM_STEPS must be positive, got {grad_accum_steps}") + grad_scale = 1.0 / grad_accum_steps + batch_multiplier = effective_train_batch_tokens / max(args.lr_reference_batch_tokens, 1) + effective_warmup_steps = scaled_warmup_steps( + args.warmup_steps, + effective_train_batch_tokens / max(args.warmup_reference_batch_tokens, 1), + args.warmup_scale_method, + ) + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 0 + + # Fast math knobs + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, 
enable_mem_efficient_sdp + + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + + logfile = None + if master_process: + if args.run_id == "train": + logfile = "train.log" + else: + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" + print(logfile) + + def log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch {torch.__version__}", console=False) + log0( + subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, + console=False, + ) + log0("=" * 100, console=False) + + # ----------------------------- + # TOKENIZER + VALIDATION METRIC SETUP + # ----------------------------- + + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) + if args.val_token_limit > 0: + usable = min(((args.val_token_limit) // args.train_seq_len) * args.train_seq_len, val_tokens.numel() - 1) + if usable <= 0: + raise ValueError(f"VAL_TOKEN_LIMIT={args.val_token_limit} is too small for TRAIN_SEQ_LEN={args.train_seq_len}") + val_tokens = 
val_tokens[: usable + 1].contiguous() + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + if args.val_token_limit > 0: + log0(f"val_token_limit:{args.val_token_limit}") + + # ----------------------------- + # MODEL + OPTIMIZER SETUP + # ----------------------------- + + base_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, + structured_mlp=args.structured_mlp, + structured_linear_type=args.structured_linear_type if args.structured_mlp else "dense", + btt_rank=args.btt_rank, + btt_in_shape=args.btt_in_shape, + btt_out_shape=args.btt_out_shape, + btt_init=args.btt_init, + btt_init_gain=args.btt_init_gain, + compile_structured_mlp=args.compile_structured_mlp, + structured_chunk_ranks=args.structured_chunk_ranks, + structured_compile_mode=args.structured_compile_mode, + rate_target_tensors=args.rate_target_tensors, + rate_temperature=args.rate_temperature, + ).to(device).bfloat16() + for module in base_model.modules(): + if isinstance(module, (CastedLinear, StructuredLinear)): + module.float() + restore_low_dim_params_to_fp32(base_model) + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) if args.enable_compile else base_model + model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + + # Optimizer split: + # - token embedding 
(Adam) uses EMBED_LR + # - untied lm_head (Adam) uses HEAD_LR + # - matrix params in transformer blocks use MATRIX_LR via Muon + # - vectors/scalars use SCALAR_LR via Adam + block_named_params = list(base_model.blocks.named_parameters()) + matrix_params = [ + p + for name, p in block_named_params + if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + scalar_params = [ + p + for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + scaled_token_lr = scale_from_batch_multiplier(token_lr, batch_multiplier, args.lr_scale_method) + scaled_head_lr = scale_from_batch_multiplier(args.head_lr, batch_multiplier, args.lr_scale_method) + scaled_matrix_lr = scale_from_batch_multiplier(args.matrix_lr, batch_multiplier, args.lr_scale_method) + scaled_scalar_lr = scale_from_batch_multiplier(args.scalar_lr, batch_multiplier, args.lr_scale_method) + optimizer_tok = torch.optim.Adam( + [{"params": [base_model.tok_emb.weight], "lr": scaled_token_lr, "base_lr": scaled_token_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizer_muon = Muon( + matrix_params, + lr=scaled_matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + ) + for group in optimizer_muon.param_groups: + group["base_lr"] = scaled_matrix_lr + optimizer_scalar = torch.optim.Adam( + [{"params": scalar_params, "lr": scaled_scalar_lr, "base_lr": scaled_scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": scaled_head_lr, "base_lr": scaled_head_lr}], 
+ betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers.insert(1, optimizer_head) + + n_params = sum(p.numel() for p in base_model.parameters()) + log0(f"model_params:{n_params}") + log0( + f"world_size:{world_size} grad_accum_steps:{grad_accum_steps} " + f"train_microbatch_tokens:{train_microbatch_tokens} effective_train_batch_tokens:{effective_train_batch_tokens}" + ) + if simulated_world_size > 0: + log0(f"simulated_world_size:{simulated_world_size} simulated_rank:{rank}") + log0(f"enable_compile:{args.enable_compile}") + log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") + log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") + log0( + f"quant_scheme:{args.quant_scheme} compressor:{args.compressor} level:{args.compressor_level} " + f"structured_mlp:{args.structured_mlp} structured_linear_type:{args.structured_linear_type} " + f"btt_rank:{args.btt_rank} btt_init:{args.btt_init} compile_structured_mlp:{args.compile_structured_mlp} " + f"structured_chunk_ranks:{args.structured_chunk_ranks} structured_compile_mode:{args.structured_compile_mode}" + ) + log0( + f"rate_lambda:{args.rate_lambda} scale_lambda:{args.scale_lambda} " + f"rate_warmup_start_fraction:{args.rate_warmup_start_fraction} rate_full_fraction:{args.rate_full_fraction} " + f"rate_target_tensors:{args.rate_target_tensors} rate_schedule_mode:{args.rate_schedule_mode} " + f"max_expected_steps:{args.max_expected_steps}" + ) + log0( + f"tie_embeddings:{args.tie_embeddings} embed_lr:{scaled_token_lr} " + f"head_lr:{scaled_head_lr if base_model.lm_head is not None else 0.0} " + f"matrix_lr:{scaled_matrix_lr} scalar_lr:{scaled_scalar_lr}" + ) + log0( + f"train_batch_tokens_target:{args.train_batch_tokens} effective_train_batch_tokens:{effective_train_batch_tokens} " + f"train_seq_len:{args.train_seq_len} iterations:{args.iterations} warmup_steps:{effective_warmup_steps} " + 
f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log0( + f"lr_scale_method:{args.lr_scale_method} lr_reference_batch_tokens:{args.lr_reference_batch_tokens} " + f"batch_multiplier:{batch_multiplier:.4f} warmup_scale_method:{args.warmup_scale_method} " + f"warmup_reference_batch_tokens:{args.warmup_reference_batch_tokens}" + ) + log0(f"seed:{args.seed}") + + if args.debug_data_sharding_steps > 0: + debug_log_data_sharding( + args.train_files, + world_size, + effective_train_batch_tokens, + args.train_seq_len, + grad_accum_steps, + args.debug_data_sharding_steps, + log0, + ) + if args.dry_run_init_only: + log0("dry_run_init_only: exiting after init/logging") + return + + # ----------------------------- + # DATA LOADER & MODEL WARMUP + # ----------------------------- + + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + + def lr_mul(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + + def rate_progress(step: int, elapsed_ms: float) -> float: + if args.rate_schedule_mode == "steps": + if args.iterations <= 0: + return 1.0 + return step / max(args.iterations, 1) + if args.rate_schedule_mode == "expected_steps": + if args.max_expected_steps <= 0: + raise ValueError("MAX_EXPECTED_STEPS must be positive when RATE_SCHEDULE_MODE=expected_steps") + return step 
/ max(args.max_expected_steps, 1) + if args.rate_schedule_mode == "wallclock": + if max_wallclock_ms is None: + raise ValueError("RATE_SCHEDULE_MODE=wallclock requires MAX_WALLCLOCK_SECONDS > 0") + return elapsed_ms / max(max_wallclock_ms, 1e-9) + if args.rate_schedule_mode != "auto": + raise ValueError(f"Unsupported RATE_SCHEDULE_MODE={args.rate_schedule_mode}") + if args.max_expected_steps > 0: + return step / max(args.max_expected_steps, 1) + if max_wallclock_ms is not None and step > 0: + step_ms = elapsed_ms / max(step, 1) + expected_steps = max(int(max_wallclock_ms / max(step_ms, 1e-9)), 1) + return step / expected_steps + if args.iterations <= 0: + return 1.0 + return step / max(args.iterations, 1) + + def rate_schedule(step: int, elapsed_ms: float) -> float: + if args.rate_lambda <= 0.0 and args.scale_lambda <= 0.0: + return 0.0 + frac = rate_progress(step, elapsed_ms) + start = args.rate_warmup_start_fraction + full = max(args.rate_full_fraction, start + 1e-6) + if frac <= start: + return 0.0 + if frac >= full: + return 1.0 + return (frac - start) / (full - start) + + # Warmup primes the compiled forward/backward/optimizer paths, then we restore the + # initial weights/optimizer state so measured training starts from the true init. 
+ if effective_warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(effective_warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(effective_train_batch_tokens, args.train_seq_len, grad_accum_steps) + maybe_mark_compile_step_begin(args.enable_compile or args.compile_structured_mlp) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + for opt in optimizers: + opt.step() + zero_grad_all() + if ( + effective_warmup_steps <= 20 + or (warmup_step + 1) % 10 == 0 + or warmup_step + 1 == effective_warmup_steps + ): + log0(f"warmup_step:{warmup_step + 1}/{effective_warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + if distributed: + model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + # ----------------------------- + # MAIN TRAINING LOOP + # ----------------------------- + + training_time_ms = 0.0 + stop_after_step: int | None = None + torch.cuda.synchronize() + t0 = time.perf_counter() + + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + 
grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + log0( + f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" + ) + break + + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + zero_grad_all() + train_loss = torch.zeros((), device=device) + train_ce_loss = torch.zeros((), device=device) + train_rate_proxy = torch.zeros((), device=device) + train_scale_proxy = torch.zeros((), device=device) + train_est_bits = torch.zeros((), device=device) + rate_mul = rate_schedule(step, elapsed_ms) + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(effective_train_batch_tokens, args.train_seq_len, grad_accum_steps) + maybe_mark_compile_step_begin(args.enable_compile or args.compile_structured_mlp) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + ce_loss = model(x, y) + total_loss = ce_loss + if rate_mul > 0.0: + rate_proxy, scale_proxy, est_bits = base_model.rate_regularizer() + total_loss = total_loss + rate_mul * (args.rate_lambda * rate_proxy + args.scale_lambda * scale_proxy) + train_rate_proxy += rate_proxy.detach() + train_scale_proxy += scale_proxy.detach() + train_est_bits += est_bits.detach() + train_ce_loss += ce_loss.detach() + train_loss += total_loss.detach() + (total_loss * grad_scale).backward() + train_loss /= grad_accum_steps + train_ce_loss /= grad_accum_steps + train_rate_proxy /= grad_accum_steps + train_scale_proxy /= grad_accum_steps + 
train_est_bits /= grad_accum_steps + + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + for opt in optimizers: + opt.step() + zero_grad_all() + + step += 1 + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + should_log_train = ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + log0( + f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} ce_loss:{train_ce_loss.item():.4f} " + f"rate_mul:{rate_mul:.3f} rate_proxy_bits:{train_rate_proxy.item():.4f} " + f"scale_proxy:{train_scale_proxy.item():.6f} est_model_bits:{train_est_bits.item():.0f} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" + ) + + # Needed to sync whether we've reached the wallclock cap. 
+        reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms
+        if distributed and max_wallclock_ms is not None:
+            reached_cap_tensor = torch.tensor(int(reached_cap), device=device)
+            dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX)
+            reached_cap = bool(reached_cap_tensor.item())
+        if stop_after_step is None and reached_cap:
+            stop_after_step = step
+
+    log0(
+        f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB "
+        f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB"
+    )
+
+    # -----------------------------
+    # SERIALIZATION + ROUNDTRIP VALIDATION
+    # -----------------------------
+    # Save the raw state (useful for debugging/loading in PyTorch directly), then always produce
+    # the compressed quantized artifact (QUANT_SCHEME + COMPRESSOR, e.g. int5_odd+zlib) and
+    # validate the round-tripped weights.
+
+    if master_process:
+        torch.save(base_model.state_dict(), "final_model.pt")
+        model_bytes = os.path.getsize("final_model.pt")
+        code_bytes = len(code.encode("utf-8"))
+        log0(f"Serialized model: {model_bytes} bytes")
+        log0(f"Code size: {code_bytes} bytes")
+        log0(f"Total submission size: {model_bytes + code_bytes} bytes")
+
+    quant_obj, quant_stats = quantize_state_dict(base_model.state_dict(), args.quant_scheme)
+    quant_buf = io.BytesIO()
+    torch.save(quant_obj, quant_buf)
+    quant_raw = quant_buf.getvalue()
+    quant_blob = compress_blob(quant_raw, args.compressor, args.compressor_level)
+    quant_raw_bytes = len(quant_raw)
+    quant_filename = f"final_model.{args.quant_scheme}.{args.compressor}.ptc"
+    if master_process:
+        with open(quant_filename, "wb") as f:
+            f.write(quant_blob)
+        quant_file_bytes = os.path.getsize(quant_filename)
+        code_bytes = len(code.encode("utf-8"))
+        ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1)
+        log0(
+            f"Serialized model {args.quant_scheme}+{args.compressor}: {quant_file_bytes} bytes "
+            f"(payload:{quant_stats['int8_payload_bytes']} raw_torch:{quant_raw_bytes} 
payload_ratio:{ratio:.2f}x)" + ) + log0(f"Total submission size {args.quant_scheme}+{args.compressor}: {quant_file_bytes + code_bytes} bytes") + + if distributed: + dist.barrier() + with open(quant_filename, "rb") as f: + quant_blob_disk = f.read() + quant_state = torch.load(io.BytesIO(decompress_blob(quant_blob_disk, args.compressor)), map_location="cpu") + base_model.load_state_dict(dequantize_state_dict(quant_state), strict=True) + torch.cuda.synchronize() + t_qeval = time.perf_counter() + q_val_loss, q_val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"final_{args.quant_scheme}_{args.compressor}_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" + ) + log0(f"final_{args.quant_scheme}_{args.compressor}_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + + if distributed: + dist.destroy_process_group() + + +if __name__ == "__main__": + main() + +==================================================================================================== +Running Python 3.10.12 (main, Mar 3 2026, 11:56:32) [GCC 11.4.0] +Running PyTorch 2.11.0+cu130 +Tue Mar 31 17:06:56 2026 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 590.48.01 Driver Version: 590.48.01 CUDA Version: 13.1 | ++-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. 
| +|=========================================+========================+======================| +| 0 NVIDIA GeForce RTX 3060 Off | 00000000:01:00.0 On | N/A | +| 0% 50C P5 22W / 170W | 224MiB / 12288MiB | 8% Default | +| | | N/A | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| 0 N/A N/A 985 G /usr/lib/xorg/Xorg 99MiB | +| 0 N/A N/A 9738 C python3 104MiB | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=/home/hz/Desktop/parameter-golf-main/data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:1 +val_loader:shards pattern=/home/hz/Desktop/parameter-golf-main/data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:1048576 +val_token_limit:1048576 +model_params:8507464 +world_size:1 grad_accum_steps:8 train_microbatch_tokens:2048 effective_train_batch_tokens:16384 +enable_compile:False +sdp_backends:cudnn=False flash=True mem_efficient=False math=False +attention_mode:gqa num_heads:8 num_kv_heads:4 +quant_scheme:int5_odd compressor:zlib level:9 structured_mlp:True structured_linear_type:btt_mlp btt_rank:32 btt_init:mup compile_structured_mlp:False structured_chunk_ranks:64 structured_compile_mode:reduce-overhead +rate_lambda:2e-05 scale_lambda:0.0002 rate_warmup_start_fraction:0.1 rate_full_fraction:0.4 rate_target_tensors:mlp_only rate_schedule_mode:auto max_expected_steps:0 +tie_embeddings:True embed_lr:0.03 head_lr:0.0 matrix_lr:0.02 scalar_lr:0.02 +train_batch_tokens_target:16384 
effective_train_batch_tokens:16384 train_seq_len:256 iterations:40 warmup_steps:1 max_wallclock_seconds:600.000 +lr_scale_method:none lr_reference_batch_tokens:16384 batch_multiplier:1.0000 warmup_scale_method:none warmup_reference_batch_tokens:16384 +seed:1337 +dry_run_init_only: exiting after init/logging diff --git a/records/track_non_record_16mb/2026-03-31_EntropyAware_Int5Odd_BTTMLP_3060/logs/probe_r256_l16_seq1024_mb2048.txt b/records/track_non_record_16mb/2026-03-31_EntropyAware_Int5Odd_BTTMLP_3060/logs/probe_r256_l16_seq1024_mb2048.txt new file mode 100644 index 000000000..7737f9381 --- /dev/null +++ b/records/track_non_record_16mb/2026-03-31_EntropyAware_Int5Odd_BTTMLP_3060/logs/probe_r256_l16_seq1024_mb2048.txt @@ -0,0 +1,1794 @@ +logs/probe_r256_l16_seq1024_mb2048.txt +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=/workspace/parameter-golf/data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:80 +val_loader:shards pattern=/workspace/parameter-golf/data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:65536 +val_token_limit:65536 +model_params:25727104 +world_size:1 grad_accum_steps:8 train_microbatch_tokens:2048 effective_train_batch_tokens:16384 +enable_compile:False +sdp_backends:cudnn=False flash=True mem_efficient=False math=False +attention_mode:gqa num_heads:8 num_kv_heads:4 +quant_scheme:int5_odd compressor:zlib level:9 structured_mlp:True structured_linear_type:btt_mlp btt_rank:256 btt_init:mup compile_structured_mlp:False +rate_lambda:2e-05 scale_lambda:0.0002 rate_warmup_start_fraction:0.1 rate_full_fraction:0.4 rate_target_tensors:mlp_only +tie_embeddings:True embed_lr:0.03 head_lr:0.0 matrix_lr:0.02 scalar_lr:0.02 +train_batch_tokens_target:16384 effective_train_batch_tokens:16384 train_seq_len:1024 iterations:1 warmup_steps:1 max_wallclock_seconds:0.000 +lr_scale_method:none lr_reference_batch_tokens:16384 batch_multiplier:1.0000 warmup_scale_method:none 
warmup_reference_batch_tokens:16384 +seed:1337 +warmup_step:1/1 +step:1/1 train_loss:6.9404 ce_loss:6.9404 rate_mul:0.000 rate_proxy_bits:0.0000 scale_proxy:0.000000 est_model_bits:0 train_time:33066ms step_avg:33066.06ms +step:1/1 val_loss:6.9380 val_bpb:4.1022 train_time:33067ms step_avg:33066.60ms +peak memory allocated: 50323 MiB reserved: 53188 MiB +Serialized model: 101929819 bytes +Code size: 75576 bytes +Total submission size: 102005395 bytes +Serialized model int5_odd+zlib: 5112978 bytes (payload:8780523 raw_torch:8851579 payload_ratio:11.60x) +Total submission size int5_odd+zlib: 5188554 bytes +final_int5_odd_zlib_roundtrip val_loss:6.9418 val_bpb:4.1044 eval_time:5855ms +final_int5_odd_zlib_roundtrip_exact val_loss:6.94184875 val_bpb:4.10444078 + Validation cadence and batch size. Validation always uses the full fineweb_val split. + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 32_768)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 10)) + val_token_limit = int(os.environ.get("VAL_TOKEN_LIMIT", 1_048_576)) + + # Training length. + iterations = int(os.environ.get("ITERATIONS", 40)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 1200)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 1)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 16_384)) + train_microbatch_tokens = int(os.environ.get("TRAIN_MICROBATCH_TOKENS", 0)) + grad_accum_steps = int(os.environ.get("GRAD_ACCUM_STEPS", 0)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 256)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + enable_compile = bool(int(os.environ.get("ENABLE_COMPILE", "0"))) + + # Model shape. 
+ vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers = int(os.environ.get("NUM_LAYERS", 9)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = int(os.environ.get("MLP_MULT", 2)) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + + # Optimizer hyperparameters. + embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.03)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.02)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.02)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.0)) + lr_scale_method = os.environ.get("LR_SCALE_METHOD", "none") + lr_reference_batch_tokens = int(os.environ.get("LR_REFERENCE_BATCH_TOKENS", 16_384)) + warmup_scale_method = os.environ.get("WARMUP_SCALE_METHOD", "none") + warmup_reference_batch_tokens = int(os.environ.get("WARMUP_REFERENCE_BATCH_TOKENS", 16_384)) + simulated_world_size = int(os.environ.get("SIMULATED_WORLD_SIZE", 0)) + simulated_rank = int(os.environ.get("SIMULATED_RANK", 0)) + dry_run_init_only = bool(int(os.environ.get("DRY_RUN_INIT_ONLY", "0"))) + debug_data_sharding_steps = 
int(os.environ.get("DEBUG_DATA_SHARDING_STEPS", 0)) + + # Export / compression. + quant_scheme = os.environ.get("QUANT_SCHEME", "int5_odd") + compressor = os.environ.get("COMPRESSOR", "zlib") + compressor_level = int(os.environ.get("COMPRESSOR_LEVEL", 9)) + + # Entropy-aware rate regularization. + rate_lambda = float(os.environ.get("RATE_LAMBDA", 0.00002)) + scale_lambda = float(os.environ.get("SCALE_LAMBDA", 0.0002)) + rate_temperature = float(os.environ.get("RATE_TEMPERATURE", 6.0)) + rate_warmup_start_fraction = float(os.environ.get("RATE_WARMUP_START_FRACTION", 0.1)) + rate_full_fraction = float(os.environ.get("RATE_FULL_FRACTION", 0.4)) + rate_target_tensors = os.environ.get("RATE_TARGET_TENSORS", "mlp_only") + + # Structured MLP. + structured_mlp = bool(int(os.environ.get("STRUCTURED_MLP", "1"))) + structured_linear_type = os.environ.get("STRUCTURED_LINEAR_TYPE", "btt_mlp") + btt_rank = int(os.environ.get("BTT_RANK", 32)) + btt_in_shape = os.environ.get("BTT_IN_SHAPE", "") + btt_out_shape = os.environ.get("BTT_OUT_SHAPE", "") + btt_init = os.environ.get("BTT_INIT", "mup") + btt_init_gain = float(os.environ.get("BTT_INIT_GAIN", 1.0)) + compile_structured_mlp = bool(int(os.environ.get("COMPILE_STRUCTURED_MLP", "1"))) + +# ----------------------------- +# MUON OPTIMIZER +# ----------------------------- +# +# As borrowed from modded-nanogpt +# Background on Muon: https://kellerjordan.github.io/posts/muon/ + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + # Orthogonalize a 2D update matrix with a fast Newton-Schulz iteration. + # Muon uses this to normalize matrix-shaped gradients before applying them. 
+ a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X + + +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True): + super().__init__( + params, + dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov), + ) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + + for group in self.param_groups: + params = group["params"] + if not params: + continue + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + + total_params = sum(int(p.numel()) for p in params) + updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) + + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: + g = g.add(buf, alpha=momentum) + g = zeropower_via_newtonschulz5(g, steps=backend_steps) + # Scale correction from Muon reference implementations. 
+                    g *= max(1, g.size(0) / g.size(1)) ** 0.5
+                    updates_flat[curr : curr + p.numel()] = g.reshape(-1)
+                curr += p.numel()
+
+            if distributed:
+                dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM)
+
+            curr = 0
+            for p in params:
+                g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype)
+                p.add_(g, alpha=-lr)
+                curr += p.numel()
+
+        return loss
+
+
+# -----------------------------
+# TOKENIZER-AGNOSTIC EVALUATION SETUP
+# -----------------------------
+#
+# It's common for small models to have a large fraction of their parameters in embeddings, since the 2 * d_model * d_vocab vectors can be gigantic.
+# Instead of locking the tokenizer, we let you bring your own and calculate our validation metrics on the average compression of the validation set.
+# We calculate BPB (bits-per-byte) instead of validation loss, so we need methods to count the number of bytes each token covers under the tokenizer.
+# Note: Submissions that edit the tokenizer will be examined more carefully, since screwing this up might unjustly improve your score.
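As a minimal illustration of the metric itself (not the harness code, which follows below): bits-per-byte converts the mean token cross-entropy from nats to bits, then rescales by the tokenizer's token-to-byte ratio on the validation text. The helper name below is hypothetical, used only for this sketch:

```python
import math

def bits_per_byte(mean_nll_nats: float, n_tokens: int, n_bytes: int) -> float:
    # Mean cross-entropy (natural log) -> bits per token, then rescale by
    # how many tokens the tokenizer spends per byte of raw validation text.
    bits_per_token = mean_nll_nats / math.log(2.0)
    return bits_per_token * (n_tokens / n_bytes)

# A loss of ln(2) nats is exactly 1 bit/token; at 4 bytes/token that is 0.25 bpb.
print(bits_per_byte(math.log(2.0), 1000, 4000))
```

This is why a tokenizer with longer pieces can lower bpb at equal token loss: the same bits/token are amortized over more bytes.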
+ +def build_sentencepiece_luts( + sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device +) -> tuple[Tensor, Tensor, Tensor]: + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("▁"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) + + +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*. 
+ tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] + + +def eval_val( + args: Hyperparameters, + model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + grad_accum_steps: int, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, +) -> tuple[float, float]: + # Validation computes two metrics: + # - val_loss: token cross-entropy (natural log) + # - val_bpb: tokenizer-agnostic compression metric used by the challenge + local_batch_tokens = args.val_batch_size // world_size + if local_batch_tokens < args.train_seq_len: + raise ValueError( + "VAL_BATCH_SIZE must provide at least one sequence per rank; " + f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " + f"TRAIN_SEQ_LEN={args.train_seq_len}" + ) + local_batch_seqs = local_batch_tokens // args.train_seq_len + total_seqs = (val_tokens.numel() - 1) // args.train_seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + + model.eval() + with torch.inference_mode(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * args.train_seq_len + raw_end = batch_seq_end * args.train_seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, args.train_seq_len) + y = local[1:].reshape(-1, args.train_seq_len) + maybe_mark_compile_step_begin(args.enable_compile or 
args.compile_structured_mlp)
+            with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True):
+                batch_loss = model(x, y).detach()
+            batch_token_count = float(y.numel())
+            val_loss_sum += batch_loss.to(torch.float64) * batch_token_count
+            val_token_count += batch_token_count
+            prev_ids = x.reshape(-1)
+            tgt_ids = y.reshape(-1)
+            token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16)
+            token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16)
+            val_byte_count += token_bytes.to(torch.float64).sum()
+
+    if dist.is_available() and dist.is_initialized():
+        dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM)
+        dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM)
+        dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM)
+
+    val_loss = val_loss_sum / val_token_count
+    bits_per_token = val_loss.item() / math.log(2.0)
+    tokens_per_byte = val_token_count.item() / val_byte_count.item()
+    model.train()
+    return float(val_loss.item()), float(bits_per_token * tokens_per_byte)
+
+# -----------------------------
+# POST-TRAINING QUANTIZATION
+# -----------------------------
+#
+# It's silly to export our model, which is trained in bf16 and fp32, at that same precision.
+# Instead, we get approximately the same model (with a small quality hit) by quantizing the weights to int8 and zlib-compressing the result.
+# We can then decompress the model and run it at higher precision for evaluation, while still coming in under the size limit.
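A simplified sketch of the per-row symmetric int8 scheme used here, in NumPy for brevity. It omits the percentile clipping that the real `quantize_float_tensor` applies (it uses the row max instead) but keeps the two choices that matter for the artifact: one scale per output row, stored in float16:

```python
import numpy as np

def quantize_per_row_int8(w: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
    # One scale per output row, symmetric codebook [-127, 127];
    # clamp the scale so all-zero rows stay representable.
    scale = np.maximum(np.abs(w).max(axis=1) / 127.0, 1.0 / 127.0)
    q = np.clip(np.round(w / scale[:, None]), -127, 127).astype(np.int8)
    return q, scale.astype(np.float16)

def dequantize_per_row_int8(q: np.ndarray, scale: np.ndarray) -> np.ndarray:
    # Reconstruct in float32; error per element is at most about half a step.
    return q.astype(np.float32) * scale.astype(np.float32)[:, None]

rng = np.random.default_rng(0)
w = rng.standard_normal((4, 8)).astype(np.float32)
q, s = quantize_per_row_int8(w)
err = np.abs(dequantize_per_row_int8(q, s) - w).max()
print(err)
```

Per-row scales are what make this safe for weight matrices: a single tensor-wide scale would let one large output channel blow up the rounding error of every other row.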
+ +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights", + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 +INT5_ODD_LEVELS = (-2, -1, 0, 1, 2) + +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) + + +def maybe_mark_compile_step_begin(enabled: bool) -> None: + if not enabled: + return + mark_step = getattr(torch.compiler, "cudagraph_mark_step_begin", None) + if mark_step is not None: + mark_step() + + +def scale_from_batch_multiplier(value: float, multiplier: float, method: str) -> float: + if method == "none": + return value + if method == "sqrt": + return value * math.sqrt(multiplier) + if method == "linear": + return value * multiplier + raise ValueError(f"Unsupported scaling method: {method}") + + +def scaled_warmup_steps(base_steps: int, multiplier: float, method: str) -> int: + if base_steps <= 0: + return 0 + scaled = scale_from_batch_multiplier(float(base_steps), multiplier, method) + return max(int(math.ceil(scaled)), 1) + + +def parse_factor_shape(spec: str, dim: int) -> tuple[int, int]: + if spec: + vals = tuple(int(v.strip()) for v in spec.split(",") if v.strip()) + if len(vals) != 2 or vals[0] * vals[1] != dim: + raise ValueError(f"Invalid factor shape '{spec}' for dim={dim}") + return vals + for a in range(int(math.isqrt(dim)), 0, -1): + if dim % a == 0: + return a, dim // a + return 1, dim + +def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> 
Tensor: + if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): + return t.float().contiguous() + if t.dtype in {torch.float32, torch.bfloat16}: + passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") + return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() + return t + +def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + # Matrices get one scale per row, which usually tracks output-channel + # ranges much better than a single tensor-wide scale. + clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) + q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() + return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + + # Vectors / scalars use a simpler per-tensor scale. 
+ clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() + return q, scale + + +def pack_base5_codes(q: Tensor) -> tuple[bytes, int]: + vals = q.detach().to("cpu", dtype=torch.uint8).reshape(-1).numpy() + n_vals = int(vals.size) + pad = (3 - (n_vals % 3)) % 3 + if pad: + vals = np.pad(vals, (0, pad)) + groups = vals.reshape(-1, 3).astype(np.uint8) + packed = groups[:, 0] + 5 * groups[:, 1] + 25 * groups[:, 2] + return packed.tobytes(), n_vals + + +def unpack_base5_codes(data: bytes, n_vals: int) -> Tensor: + packed = np.frombuffer(data, dtype=np.uint8).astype(np.int16) + out = np.zeros((packed.size, 3), dtype=np.uint8) + for i in range(3): + out[:, i] = packed % 5 + packed //= 5 + flat = out.reshape(-1)[:n_vals] + return torch.from_numpy(flat.astype(np.uint8, copy=False)) + + +def quantize_float_tensor_int5_odd(t: Tensor) -> tuple[Tensor, Tensor]: + if t.ndim != 2: + raise ValueError("int5_odd quantization only supports 2D tensors") + t32 = t.float() + clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clip_abs = clip_abs.clamp_min(1e-6) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 2.0).clamp_min(1.0 / 2.0e6) + q = torch.clamp(torch.round(clipped / scale[:, None]), -2, 2).to(torch.int8).contiguous() + return (q + 2).to(torch.uint8).contiguous(), scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + +def quantize_state_dict_int8(state_dict: dict[str, Tensor]): + # Single supported clean-script export format: + # - per-row int8 for 2D float tensors + # - per-tensor int8 for other float tensors + # - exact passthrough for non-floats + # - passthrough for small float tensors, 
stored as fp16 to save bytes + quantized: dict[str, Tensor] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + qmeta: dict[str, dict[str, object]] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + + # Small float tensors are cheap enough to keep directly. We still downcast + # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size. + if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + + stats["num_float_tensors"] += 1 + q, s = quantize_float_tensor(t) + if s.ndim > 0: + qmeta[name] = {"scheme": "per_row", "axis": 0} + quantized[name] = q + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) + + obj: dict[str, object] = { + "__quant_format__": "int8_clean_per_row_v1", + "quantized": quantized, + "scales": scales, + "dtypes": dtypes, + "passthrough": passthrough, + } + if qmeta: + obj["qmeta"] = qmeta + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats + + +def quantize_state_dict_int5_odd(state_dict: dict[str, Tensor]): + packed: dict[str, bytes] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + shapes: dict[str, tuple[int, ...]] = {} + orig_shapes: dict[str, 
tuple[int, ...]] = {} + n_codes: dict[str, int] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + + quantize_this = t.ndim >= 2 and (t.numel() > INT8_KEEP_FLOAT_MAX_NUMEL or ".mlp." in name) + if not quantize_this: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + + stats["num_float_tensors"] += 1 + t2 = t.reshape(t.shape[0], -1) if t.ndim > 2 else t + q, s = quantize_float_tensor_int5_odd(t2) + packed_bytes, count = pack_base5_codes(q) + packed[name] = packed_bytes + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + shapes[name] = tuple(int(v) for v in t2.shape) + orig_shapes[name] = tuple(int(v) for v in t.shape) + n_codes[name] = count + stats["int8_payload_bytes"] += len(packed_bytes) + tensor_nbytes(s) + + obj: dict[str, object] = { + "__quant_format__": "int5_odd_clean_per_row_v1", + "packed": packed, + "scales": scales, + "dtypes": dtypes, + "shapes": shapes, + "orig_shapes": orig_shapes, + "n_codes": n_codes, + "passthrough": passthrough, + } + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats + +def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + qmeta = obj.get("qmeta", {}) + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, q in 
obj["quantized"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + s = obj["scales"][name] + if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: + s = s.to(dtype=torch.float32) + # Broadcast the saved row scale back across trailing dimensions. + out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() + else: + scale = float(s.item()) + out[name] = (q.float() * scale).to(dtype=dtype).contiguous() + for name, t in obj["passthrough"].items(): + # Restore small tensors, undoing the temporary fp16 storage cast if needed. + out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out + + +def dequantize_state_dict_int5_odd(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, packed in obj["packed"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + shape = tuple(int(v) for v in obj["shapes"][name]) + orig_shape = tuple(int(v) for v in obj.get("orig_shapes", {}).get(name, shape)) + q = unpack_base5_codes(packed, int(obj["n_codes"][name])).to(torch.float32).reshape(shape) - 2.0 + s = obj["scales"][name].to(dtype=torch.float32) + deq = (q * s.view(shape[0], *([1] * (len(shape) - 1)))).to(dtype=dtype) + out[name] = deq.reshape(orig_shape).contiguous() + for name, t in obj["passthrough"].items(): + out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out + + +def quantize_state_dict(state_dict: dict[str, Tensor], quant_scheme: str): + if quant_scheme == "int8": + return quantize_state_dict_int8(state_dict) + if quant_scheme == "int5_odd": + return 
quantize_state_dict_int5_odd(state_dict) + raise ValueError(f"Unsupported QUANT_SCHEME={quant_scheme}") + + +def dequantize_state_dict(obj: dict[str, object]) -> dict[str, Tensor]: + quant_format = obj.get("__quant_format__") + if quant_format == "int8_clean_per_row_v1": + return dequantize_state_dict_int8(obj) + if quant_format == "int5_odd_clean_per_row_v1": + return dequantize_state_dict_int5_odd(obj) + raise ValueError(f"Unsupported quant format {quant_format}") + + +def compress_blob(data: bytes, compressor: str, level: int) -> bytes: + if compressor == "zlib": + return zlib.compress(data, level=max(0, min(level, 9))) + if compressor == "zstd": + if not _HAS_ZSTD: + raise RuntimeError("COMPRESSOR=zstd requires zstandard to be installed") + return zstd.ZstdCompressor(level=level).compress(data) + raise ValueError(f"Unsupported COMPRESSOR={compressor}") + + +def decompress_blob(data: bytes, compressor: str) -> bytes: + if compressor == "zlib": + return zlib.decompress(data) + if compressor == "zstd": + if not _HAS_ZSTD: + raise RuntimeError("COMPRESSOR=zstd requires zstandard to be installed") + return zstd.ZstdDecompressor().decompress(data) + raise ValueError(f"Unsupported COMPRESSOR={compressor}") + + +def should_rate_regularize_tensor(name: str, p: Tensor, target: str) -> bool: + if not p.requires_grad or not p.is_floating_point() or p.ndim < 2: + return False + if any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS): + return False + if target == "mlp_only": + return ".mlp." in name + if target == "all_matrices": + return p.numel() > INT8_KEEP_FLOAT_MAX_NUMEL or ".mlp." 
in name
+    raise ValueError(f"Unsupported RATE_TARGET_TENSORS={target}")
+
+
+def estimate_tensor_rate_proxy(t: Tensor, temperature: float) -> tuple[Tensor, Tensor, Tensor]:
+    t2 = t.float()
+    if t2.ndim > 2:
+        t2 = t2.reshape(t2.shape[0], -1)
+    elif t2.ndim == 1:
+        t2 = t2.reshape(1, -1)
+    row_scale = t2.abs().mean(dim=1, keepdim=True).clamp_min(1e-6)
+    normalized = torch.clamp(t2 / row_scale, -2.5, 2.5)
+    centers = t2.new_tensor(INT5_ODD_LEVELS)
+    logits = -temperature * (normalized.unsqueeze(-1) - centers) ** 2
+    probs = torch.softmax(logits, dim=-1)
+    hist = probs.mean(dim=(0, 1))
+    entropy_nats = -(hist * (hist + 1e-8).log()).sum()
+    entropy_bits = entropy_nats / math.log(2.0)
+    scale_proxy = row_scale.mean()
+    est_bits = entropy_bits * float(t2.numel())
+    return entropy_bits, scale_proxy, est_bits
+
+
+# -----------------------------
+# DATA LOADING
+# -----------------------------
+
+def load_data_shard(file: Path) -> Tensor:
+    # Assumes the fineweb .bin shard layout: a 256-entry little-endian int32 header
+    # whose third field is the token count, followed by the raw uint16 token ids.
+    header_bytes = 256 * np.dtype("<i4").itemsize
+    with file.open("rb") as f:
+        header = np.frombuffer(f.read(header_bytes), dtype="<i4")
+        num_tokens = int(header[2])
+        data = np.frombuffer(f.read(2 * num_tokens), dtype=np.uint16)
+    return torch.from_numpy(data.astype(np.int32, copy=False))
+
+
+class TokenStream:
+    # Sequential reader over all shards matching `pattern`, cycling back to the
+    # first shard once the last one is exhausted.
+    def __init__(self, pattern: str):
+        self.files = sorted(Path(".").glob(pattern))
+        if not self.files:
+            raise FileNotFoundError(f"No token shards match pattern: {pattern}")
+        self.file_idx = 0
+        self.tokens = load_data_shard(self.files[self.file_idx])
+        self.pos = 0
+
+    def _advance_file(self) -> None:
+        self.file_idx = (self.file_idx + 1) % len(self.files)
+        self.tokens = load_data_shard(self.files[self.file_idx])
+        self.pos = 0
+
+    def take(self, n: int) -> Tensor:
+        chunks: list[Tensor] = []
+        remaining = n
+        while remaining > 0:
+            avail = self.tokens.numel() - self.pos
+            if avail <= 0:
+                self._advance_file()
+                continue
+            k = min(remaining, avail)
+            chunks.append(self.tokens[self.pos : self.pos + k])
+            self.pos += k
+            remaining -= k
+        return chunks[0] if len(chunks) == 1 else torch.cat(chunks)
+
+
+class DistributedTokenLoader:
+    # Each call consumes a contiguous chunk from the shared token stream, then slices out
+    # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting. 
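+    # Worked example (single-microstep case, matching the 8xH100 launcher settings):
+    # with global_tokens=131072, world_size=8 and grad_accum_steps=1, each rank gets
+    # local_tokens = 131072 // 8 = 16384 and per_rank_span = 16385, so the shared
+    # chunk holds 8 * 16385 = 131080 tokens and rank r reads
+    # chunk[r * 16385 : (r + 1) * 16385].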
+ def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank = rank + self.world_size = world_size + self.device = device + self.stream = TokenStream(pattern) + + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + chunk = self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start : start + per_rank_span].to(dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + + +def debug_log_data_sharding( + pattern: str, + world_size: int, + global_tokens: int, + seq_len: int, + grad_accum_steps: int, + steps: int, + log0, +) -> None: + stream = TokenStream(pattern) + local_tokens = global_tokens // (world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + global_cursor = 0 + log0( + f"debug_sharding:world_size:{world_size} global_tokens:{global_tokens} " + f"grad_accum_steps:{grad_accum_steps} local_tokens_per_rank:{local_tokens} per_rank_span:{per_rank_span}" + ) + for step_idx in range(steps): + chunk = stream.take(per_rank_span * world_size) + chunk_start = global_cursor + chunk_end = global_cursor + chunk.numel() + log0(f"debug_sharding_step:{step_idx} shared_chunk_token_range:[{chunk_start},{chunk_end})") + for shard_rank in range(world_size): + start = shard_rank * per_rank_span + end = start + per_rank_span + local = chunk[start:end] + preview = ",".join(str(int(t)) for t in local[: min(6, local.numel())].tolist()) + log0( + f"debug_shard rank:{shard_rank} token_range:[{chunk_start + start},{chunk_start + end}) " + f"preview:{preview}" + ) + global_cursor = chunk_end + +# ----------------------------- +# TRANSFORMER MODULES +# ----------------------------- + +class RMSNorm(nn.Module): + def 
__init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + + +class CastedLinear(nn.Linear): + # Keep weights in fp32 for optimizer/state quality, cast at matmul time for bf16 compute. + def forward(self, x: Tensor) -> Tensor: + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, self.weight.to(x.dtype), bias) + + +class StructuredLinear(nn.Module): + # Lightweight TT-style structured matrix used only in the MLP stack. + def __init__( + self, + in_features: int, + out_features: int, + bias: bool, + linear_type: str, + btt_rank: int, + in_shape_spec: str, + out_shape_spec: str, + btt_init: str, + btt_init_gain: float, + ): + super().__init__() + self.in_features = in_features + self.out_features = out_features + self.linear_type = linear_type + self.rank = max(int(btt_rank), 1) + self.btt_init = btt_init + self.btt_init_gain = btt_init_gain + if linear_type == "dense": + self.impl = CastedLinear(in_features, out_features, bias=bias) + self.register_parameter("core1", None) + self.register_parameter("core2", None) + return + if linear_type != "btt_mlp": + raise ValueError(f"Unsupported STRUCTURED_LINEAR_TYPE={linear_type}") + self.impl = None + self.in_shape = parse_factor_shape(in_shape_spec, in_features) + self.out_shape = parse_factor_shape(out_shape_spec, out_features) + n1, n2 = self.in_shape + m1, m2 = self.out_shape + self.core1 = nn.Parameter(torch.empty(n1, m1, self.rank)) + self.core2 = nn.Parameter(torch.empty(self.rank, n2, m2)) + if bias: + self.bias = nn.Parameter(torch.zeros(out_features)) + else: + self.register_parameter("bias", None) + self.reset_parameters() + + def reset_parameters(self) -> None: + if self.impl is not None: + self.impl.reset_parameters() + return + n1, n2 = self.in_shape + if self.btt_init == "mup": + # muP-inspired scaling: keep the per-rank branch update stable while + # 
preserving O(1) output variance under the sum over TT ranks. + core1_std = self.btt_init_gain / math.sqrt(max(n1, 1)) + core2_std = self.btt_init_gain / math.sqrt(max(self.rank * n2, 1)) + elif self.btt_init == "xavier": + std = self.btt_init_gain / math.sqrt(max(self.in_features, 1)) + core1_std = std / math.sqrt(self.rank) + core2_std = std + else: + raise ValueError(f"Unsupported BTT_INIT={self.btt_init}") + nn.init.normal_(self.core1, mean=0.0, std=core1_std) + nn.init.normal_(self.core2, mean=0.0, std=core2_std) + if self.bias is not None: + nn.init.zeros_(self.bias) + + def forward(self, x: Tensor) -> Tensor: + if self.impl is not None: + return self.impl(x) + x_dtype = x.dtype + orig_shape = x.shape[:-1] + x2 = x.reshape(-1, self.in_shape[0], self.in_shape[1]) + core1 = self.core1.to(dtype=x_dtype) + core2 = self.core2.to(dtype=x_dtype) + y = x2.new_zeros((x2.shape[0], self.out_shape[0], self.out_shape[1])) + for r in range(self.rank): + left = torch.einsum("buv,um->bmv", x2, core1[:, :, r]) + y = y + torch.einsum("bmv,vn->bmn", left, core2[r]) + y = y.reshape(*orig_shape, self.out_features) + if self.bias is not None: + y = y + self.bias.to(dtype=y.dtype) + return y + + +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + # Keep small/control parameters in fp32 even when the model body runs in bf16. + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() + + +class Rotary(nn.Module): + # Caches cos/sin tables per sequence length on the current device. 
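+    # For head_dim d there are d/2 inverse frequencies base**(-2i/d), i = 0..d/2-1,
+    # spanning 1.0 down to base**(-(d-2)/d). The cached cos/sin tables are rebuilt
+    # only when seq_len or device changes; the dtype cast happens on every call.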
+ def __init__(self, dim: int, base: float = 10000.0): + super().__init__() + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached != seq_len + or self._cos_cached.device != device + ): + t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) + freqs = torch.outer(t, self.inv_freq.to(device)) + self._cos_cached = freqs.cos()[None, None, :, :] + self._sin_cached = freqs.sin()[None, None, :, :] + self._seq_len_cached = seq_len + cos = self._cos_cached.to(dtype=dtype) + sin = self._sin_cached.to(dtype=dtype) + if not torch.is_inference_mode_enabled(): + cos = cos.clone() + sin = sin.clone() + return cos, sin + + +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + + +class CausalSelfAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + kv_dim = self.num_kv_heads * self.head_dim + self.c_q = CastedLinear(dim, dim, bias=False) + self.c_k = CastedLinear(dim, kv_dim, bias=False) + self.c_v = CastedLinear(dim, kv_dim, bias=False) + self.proj 
= CastedLinear(dim, dim, bias=False)
+        self.proj._zero_init = True
+        self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32))
+        self.rotary = Rotary(self.head_dim, base=rope_base)
+
+    def forward(self, x: Tensor) -> Tensor:
+        bsz, seqlen, dim = x.shape
+        q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2)
+        k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2)
+        v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2)
+        q = F.rms_norm(q, (q.size(-1),))
+        k = F.rms_norm(k, (k.size(-1),))
+        cos, sin = self.rotary(seqlen, x.device, q.dtype)
+        q = apply_rotary_emb(q, cos, sin)
+        k = apply_rotary_emb(k, cos, sin)
+        q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None]
+        y = F.scaled_dot_product_attention(
+            q,
+            k,
+            v,
+            attn_mask=None,
+            is_causal=True,
+            enable_gqa=(self.num_kv_heads != self.num_heads),
+        )
+        y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim)
+        return self.proj(y)
+
+
+class MLP(nn.Module):
+    # relu^2 MLP from the original modded-nanogpt setup
+    def __init__(
+        self,
+        dim: int,
+        mlp_mult: int,
+        structured_mlp: bool,
+        structured_linear_type: str,
+        btt_rank: int,
+        btt_in_shape: str,
+        btt_out_shape: str,
+        btt_init: str,
+        btt_init_gain: float,
+        compile_structured_mlp: bool,
+    ):
+        super().__init__()
+        hidden = mlp_mult * dim
+        if structured_mlp:
+            self.fc = StructuredLinear(
+                dim,
+                hidden,
+                bias=False,
+                linear_type=structured_linear_type,
+                btt_rank=btt_rank,
+                in_shape_spec=btt_in_shape,
+                out_shape_spec=btt_out_shape,
+                btt_init=btt_init,
+                btt_init_gain=btt_init_gain,
+            )
+            self.proj = StructuredLinear(
+                hidden,
+                dim,
+                bias=False,
+                linear_type=structured_linear_type,
+                btt_rank=btt_rank,
+                in_shape_spec=btt_out_shape,
+                out_shape_spec=btt_in_shape,
+                btt_init=btt_init,
+                btt_init_gain=btt_init_gain,
+            )
+            # Mark the structured output projection for zero-init (core2 -> 0) before
+            # any torch.compile wrapping; otherwise the StructuredLinear branch of
+            # GPT._init_weights never fires and the residual branch starts non-zero,
+            # unlike the dense path below.
+            self.proj._zero_init = True
+            if compile_structured_mlp:
+                compile_options = 
dict(torch._inductor.list_mode_options("reduce-overhead")) + compile_options["triton.cudagraphs"] = False + self.fc = torch.compile( + self.fc, + options=compile_options, + fullgraph=False, + dynamic=False, + ) + self.proj = torch.compile( + self.proj, + options=compile_options, + fullgraph=False, + dynamic=False, + ) + else: + self.fc = CastedLinear(dim, hidden, bias=False) + self.proj = CastedLinear(hidden, dim, bias=False) + self.proj._zero_init = True + + def forward(self, x: Tensor) -> Tensor: + x = torch.relu(self.fc(x)) + return self.proj(x.square()) + + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + rope_base: float, + qk_gain_init: float, + structured_mlp: bool, + structured_linear_type: str, + btt_rank: int, + btt_in_shape: str, + btt_out_shape: str, + btt_init: str, + btt_init_gain: float, + compile_structured_mlp: bool, + ): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) + self.mlp = MLP( + dim, + mlp_mult, + structured_mlp, + structured_linear_type, + btt_rank, + btt_in_shape, + btt_out_shape, + btt_init, + btt_init_gain, + compile_structured_mlp, + ) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + + def forward(self, x: Tensor, x0: Tensor) -> Tensor: + mix = self.resid_mix.to(dtype=x.dtype) + x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + attn_out = self.attn(self.attn_norm(x)) + x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out + x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) + return x + + +class GPT(nn.Module): + def __init__( + self, + vocab_size: int, + num_layers: int, + model_dim: int, + num_heads: int, + 
num_kv_heads: int, + mlp_mult: int, + tie_embeddings: bool, + tied_embed_init_std: float, + logit_softcap: float, + rope_base: float, + qk_gain_init: float, + structured_mlp: bool, + structured_linear_type: str, + btt_rank: int, + btt_in_shape: str, + btt_out_shape: str, + btt_init: str, + btt_init_gain: float, + compile_structured_mlp: bool, + rate_target_tensors: str, + rate_temperature: float, + ): + super().__init__() + if logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") + self.tie_embeddings = tie_embeddings + self.tied_embed_init_std = tied_embed_init_std + self.logit_softcap = logit_softcap + self.rate_target_tensors = rate_target_tensors + self.rate_temperature = rate_temperature + self.tok_emb = nn.Embedding(vocab_size, model_dim) + self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + self.blocks = nn.ModuleList( + [ + Block( + model_dim, + num_heads, + num_kv_heads, + mlp_mult, + rope_base, + qk_gain_init, + structured_mlp, + structured_linear_type, + btt_rank, + btt_in_shape, + btt_out_shape, + btt_init, + btt_init_gain, + compile_structured_mlp, + ) + for i in range(num_layers) + ] + ) + self.final_norm = RMSNorm() + self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) + if self.lm_head is not None: + self.lm_head._zero_init = True + self._init_weights() + + def _init_weights(self) -> None: + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + for module in self.modules(): + if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + if isinstance(module, StructuredLinear) and getattr(module, "_zero_init", False): + if module.impl is 
not None: + nn.init.zeros_(module.impl.weight) + else: + nn.init.zeros_(module.core2) + + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + x = self.tok_emb(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x0 = x + skips: list[Tensor] = [] + + # First half stores skips; second half reuses them in reverse order. + for i in range(self.num_encoder_layers): + x = self.blocks[i](x, x0) + skips.append(x) + for i in range(self.num_decoder_layers): + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + x = self.blocks[self.num_encoder_layers + i](x, x0) + + x = self.final_norm(x).reshape(-1, x.size(-1)) + targets = target_ids.reshape(-1) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(x) + logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + return F.cross_entropy(logits.float(), targets, reduction="mean") + + def rate_regularizer(self) -> tuple[Tensor, Tensor, Tensor]: + total_symbols = 0.0 + weighted_rate = None + weighted_scale = None + estimated_bits = None + for name, param in self.named_parameters(): + if not should_rate_regularize_tensor(name, param, self.rate_target_tensors): + continue + rate_proxy, scale_proxy, tensor_bits = estimate_tensor_rate_proxy(param, self.rate_temperature) + weight = float(param.numel()) + total_symbols += weight + if weighted_rate is None: + weighted_rate = rate_proxy * weight + weighted_scale = scale_proxy * weight + estimated_bits = tensor_bits + else: + weighted_rate = weighted_rate + rate_proxy * weight + weighted_scale = weighted_scale + scale_proxy * weight + estimated_bits = estimated_bits + tensor_bits + if weighted_rate is None or weighted_scale is None or estimated_bits is None: + zero = self.tok_emb.weight.new_zeros(()) + return zero, zero, zero + denom = max(total_symbols, 1.0) + return 
weighted_rate / denom, weighted_scale / denom, estimated_bits + + +# ----------------------------- +# TRAINING +# ----------------------------- + +def main() -> None: + global zeropower_via_newtonschulz5 + + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + if args.enable_compile: + zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) + + # ----------------------------- + # DISTRIBUTED + CUDA SETUP + # ----------------------------- + + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + simulated_world_size = args.simulated_world_size if not distributed and args.simulated_world_size > 0 else 0 + if simulated_world_size > 0: + world_size = simulated_world_size + rank = args.simulated_rank + local_rank = 0 + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if rank < 0 or rank >= world_size: + raise ValueError(f"RANK={rank} must satisfy 0 <= RANK < WORLD_SIZE={world_size}") + if simulated_world_size > 1 and not args.dry_run_init_only: + raise ValueError("SIMULATED_WORLD_SIZE>1 is only supported together with DRY_RUN_INIT_ONLY=1") + if args.train_batch_tokens <= 0: + raise ValueError(f"TRAIN_BATCH_TOKENS must be positive, got {args.train_batch_tokens}") + train_microbatch_tokens = args.train_microbatch_tokens if args.train_microbatch_tokens > 0 else max( + args.train_seq_len, args.train_batch_tokens // 8 + ) + if train_microbatch_tokens <= 0: + raise ValueError(f"TRAIN_MICROBATCH_TOKENS must be positive, got {train_microbatch_tokens}") + if train_microbatch_tokens % args.train_seq_len != 0: + raise ValueError( + f"TRAIN_MICROBATCH_TOKENS={train_microbatch_tokens} must be divisible by TRAIN_SEQ_LEN={args.train_seq_len}" + ) + tokens_per_microstep = world_size * train_microbatch_tokens + if args.grad_accum_steps > 0: + 
grad_accum_steps = args.grad_accum_steps + effective_train_batch_tokens = tokens_per_microstep * grad_accum_steps + else: + if args.train_batch_tokens % tokens_per_microstep != 0: + raise ValueError( + "TRAIN_BATCH_TOKENS must be divisible by WORLD_SIZE * TRAIN_MICROBATCH_TOKENS " + f"unless GRAD_ACCUM_STEPS is set explicitly; got TRAIN_BATCH_TOKENS={args.train_batch_tokens}, " + f"WORLD_SIZE={world_size}, TRAIN_MICROBATCH_TOKENS={train_microbatch_tokens}" + ) + grad_accum_steps = args.train_batch_tokens // tokens_per_microstep + effective_train_batch_tokens = args.train_batch_tokens + if grad_accum_steps <= 0: + raise ValueError(f"GRAD_ACCUM_STEPS must be positive, got {grad_accum_steps}") + grad_scale = 1.0 / grad_accum_steps + batch_multiplier = effective_train_batch_tokens / max(args.lr_reference_batch_tokens, 1) + effective_warmup_steps = scaled_warmup_steps( + args.warmup_steps, + effective_train_batch_tokens / max(args.warmup_reference_batch_tokens, 1), + args.warmup_scale_method, + ) + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 0 + + # Fast math knobs + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + + logfile = None + if master_process: + if args.run_id == "train": + logfile = "train.log" + else: + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" + print(logfile) + + def log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as 
f: + print(msg, file=f) + + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch {torch.__version__}", console=False) + log0( + subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, + console=False, + ) + log0("=" * 100, console=False) + + # ----------------------------- + # TOKENIZER + VALIDATION METRIC SETUP + # ----------------------------- + + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) + if args.val_token_limit > 0: + usable = min(((args.val_token_limit) // args.train_seq_len) * args.train_seq_len, val_tokens.numel() - 1) + if usable <= 0: + raise ValueError(f"VAL_TOKEN_LIMIT={args.val_token_limit} is too small for TRAIN_SEQ_LEN={args.train_seq_len}") + val_tokens = val_tokens[: usable + 1].contiguous() + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + if args.val_token_limit > 0: + log0(f"val_token_limit:{args.val_token_limit}") + + # 
----------------------------- + # MODEL + OPTIMIZER SETUP + # ----------------------------- + + base_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, + structured_mlp=args.structured_mlp, + structured_linear_type=args.structured_linear_type if args.structured_mlp else "dense", + btt_rank=args.btt_rank, + btt_in_shape=args.btt_in_shape, + btt_out_shape=args.btt_out_shape, + btt_init=args.btt_init, + btt_init_gain=args.btt_init_gain, + compile_structured_mlp=args.compile_structured_mlp, + rate_target_tensors=args.rate_target_tensors, + rate_temperature=args.rate_temperature, + ).to(device).bfloat16() + for module in base_model.modules(): + if isinstance(module, (CastedLinear, StructuredLinear)): + module.float() + restore_low_dim_params_to_fp32(base_model) + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) if args.enable_compile else base_model + model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + + # Optimizer split: + # - token embedding (Adam) uses EMBED_LR + # - untied lm_head (Adam) uses HEAD_LR + # - matrix params in transformer blocks use MATRIX_LR via Muon + # - vectors/scalars use SCALAR_LR via Adam + block_named_params = list(base_model.blocks.named_parameters()) + matrix_params = [ + p + for name, p in block_named_params + if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + scalar_params = [ + p + for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.skip_weights.numel() > 0: + 
scalar_params.append(base_model.skip_weights) + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + scaled_token_lr = scale_from_batch_multiplier(token_lr, batch_multiplier, args.lr_scale_method) + scaled_head_lr = scale_from_batch_multiplier(args.head_lr, batch_multiplier, args.lr_scale_method) + scaled_matrix_lr = scale_from_batch_multiplier(args.matrix_lr, batch_multiplier, args.lr_scale_method) + scaled_scalar_lr = scale_from_batch_multiplier(args.scalar_lr, batch_multiplier, args.lr_scale_method) + optimizer_tok = torch.optim.Adam( + [{"params": [base_model.tok_emb.weight], "lr": scaled_token_lr, "base_lr": scaled_token_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizer_muon = Muon( + matrix_params, + lr=scaled_matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + ) + for group in optimizer_muon.param_groups: + group["base_lr"] = scaled_matrix_lr + optimizer_scalar = torch.optim.Adam( + [{"params": scalar_params, "lr": scaled_scalar_lr, "base_lr": scaled_scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": scaled_head_lr, "base_lr": scaled_head_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers.insert(1, optimizer_head) + + n_params = sum(p.numel() for p in base_model.parameters()) + log0(f"model_params:{n_params}") + log0( + f"world_size:{world_size} grad_accum_steps:{grad_accum_steps} " + f"train_microbatch_tokens:{train_microbatch_tokens} effective_train_batch_tokens:{effective_train_batch_tokens}" + ) + if simulated_world_size > 0: + log0(f"simulated_world_size:{simulated_world_size} simulated_rank:{rank}") + log0(f"enable_compile:{args.enable_compile}") + 
log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") + log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") + log0( + f"quant_scheme:{args.quant_scheme} compressor:{args.compressor} level:{args.compressor_level} " + f"structured_mlp:{args.structured_mlp} structured_linear_type:{args.structured_linear_type} " + f"btt_rank:{args.btt_rank} btt_init:{args.btt_init} compile_structured_mlp:{args.compile_structured_mlp}" + ) + log0( + f"rate_lambda:{args.rate_lambda} scale_lambda:{args.scale_lambda} " + f"rate_warmup_start_fraction:{args.rate_warmup_start_fraction} rate_full_fraction:{args.rate_full_fraction} " + f"rate_target_tensors:{args.rate_target_tensors}" + ) + log0( + f"tie_embeddings:{args.tie_embeddings} embed_lr:{scaled_token_lr} " + f"head_lr:{scaled_head_lr if base_model.lm_head is not None else 0.0} " + f"matrix_lr:{scaled_matrix_lr} scalar_lr:{scaled_scalar_lr}" + ) + log0( + f"train_batch_tokens_target:{args.train_batch_tokens} effective_train_batch_tokens:{effective_train_batch_tokens} " + f"train_seq_len:{args.train_seq_len} iterations:{args.iterations} warmup_steps:{effective_warmup_steps} " + f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log0( + f"lr_scale_method:{args.lr_scale_method} lr_reference_batch_tokens:{args.lr_reference_batch_tokens} " + f"batch_multiplier:{batch_multiplier:.4f} warmup_scale_method:{args.warmup_scale_method} " + f"warmup_reference_batch_tokens:{args.warmup_reference_batch_tokens}" + ) + log0(f"seed:{args.seed}") + + if args.debug_data_sharding_steps > 0: + debug_log_data_sharding( + args.train_files, + world_size, + effective_train_batch_tokens, + args.train_seq_len, + grad_accum_steps, + args.debug_data_sharding_steps, + log0, + ) + if args.dry_run_init_only: + log0("dry_run_init_only: exiting after init/logging") + return + + # ----------------------------- + # DATA LOADER & MODEL WARMUP + # ----------------------------- + + train_loader = 
DistributedTokenLoader(args.train_files, rank, world_size, device) + + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + + def lr_mul(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + + def rate_schedule(step: int) -> float: + if args.rate_lambda <= 0.0 and args.scale_lambda <= 0.0: + return 0.0 + if args.iterations <= 0: + return 1.0 + frac = step / max(args.iterations, 1) + start = args.rate_warmup_start_fraction + full = max(args.rate_full_fraction, start + 1e-6) + if frac <= start: + return 0.0 + if frac >= full: + return 1.0 + return (frac - start) / (full - start) + + # Warmup primes the compiled forward/backward/optimizer paths, then we restore the + # initial weights/optimizer state so measured training starts from the true init. 
+ if effective_warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(effective_warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(effective_train_batch_tokens, args.train_seq_len, grad_accum_steps) + maybe_mark_compile_step_begin(args.enable_compile or args.compile_structured_mlp) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + for opt in optimizers: + opt.step() + zero_grad_all() + if ( + effective_warmup_steps <= 20 + or (warmup_step + 1) % 10 == 0 + or warmup_step + 1 == effective_warmup_steps + ): + log0(f"warmup_step:{warmup_step + 1}/{effective_warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + if distributed: + model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + # ----------------------------- + # MAIN TRAINING LOOP + # ----------------------------- + + training_time_ms = 0.0 + stop_after_step: int | None = None + torch.cuda.synchronize() + t0 = time.perf_counter() + + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + 
grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + log0( + f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" + ) + break + + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + zero_grad_all() + train_loss = torch.zeros((), device=device) + train_ce_loss = torch.zeros((), device=device) + train_rate_proxy = torch.zeros((), device=device) + train_scale_proxy = torch.zeros((), device=device) + train_est_bits = torch.zeros((), device=device) + rate_mul = rate_schedule(step) + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(effective_train_batch_tokens, args.train_seq_len, grad_accum_steps) + maybe_mark_compile_step_begin(args.enable_compile or args.compile_structured_mlp) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + ce_loss = model(x, y) + total_loss = ce_loss + if rate_mul > 0.0: + rate_proxy, scale_proxy, est_bits = base_model.rate_regularizer() + total_loss = total_loss + rate_mul * (args.rate_lambda * rate_proxy + args.scale_lambda * scale_proxy) + train_rate_proxy += rate_proxy.detach() + train_scale_proxy += scale_proxy.detach() + train_est_bits += est_bits.detach() + train_ce_loss += ce_loss.detach() + train_loss += total_loss.detach() + (total_loss * grad_scale).backward() + train_loss /= grad_accum_steps + train_ce_loss /= grad_accum_steps + train_rate_proxy /= grad_accum_steps + train_scale_proxy /= grad_accum_steps + train_est_bits /= 
grad_accum_steps + + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + for opt in optimizers: + opt.step() + zero_grad_all() + + step += 1 + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + should_log_train = ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + log0( + f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} ce_loss:{train_ce_loss.item():.4f} " + f"rate_mul:{rate_mul:.3f} rate_proxy_bits:{train_rate_proxy.item():.4f} " + f"scale_proxy:{train_scale_proxy.item():.6f} est_model_bits:{train_est_bits.item():.0f} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" + ) + + # Needed to sync whether we've reached the wallclock cap. 
+        reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms
+        if distributed and max_wallclock_ms is not None:
+            reached_cap_tensor = torch.tensor(int(reached_cap), device=device)
+            dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX)
+            reached_cap = bool(reached_cap_tensor.item())
+        if stop_after_step is None and reached_cap:
+            stop_after_step = step
+
+    log0(
+        f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB "
+        f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB"
+    )
+
+    # -----------------------------
+    # SERIALIZATION + ROUNDTRIP VALIDATION
+    # -----------------------------
+    # Save the raw state (useful for debugging/loading in PyTorch directly), then always produce
+    # the compressed quantized artifact (QUANT_SCHEME + COMPRESSOR; int5_odd+zlib here) and validate the round-tripped weights.
+
+    if master_process:
+        torch.save(base_model.state_dict(), "final_model.pt")
+        model_bytes = os.path.getsize("final_model.pt")
+        code_bytes = len(code.encode("utf-8"))
+        log0(f"Serialized model: {model_bytes} bytes")
+        log0(f"Code size: {code_bytes} bytes")
+        log0(f"Total submission size: {model_bytes + code_bytes} bytes")
+
+    quant_obj, quant_stats = quantize_state_dict(base_model.state_dict(), args.quant_scheme)
+    quant_buf = io.BytesIO()
+    torch.save(quant_obj, quant_buf)
+    quant_raw = quant_buf.getvalue()
+    quant_blob = compress_blob(quant_raw, args.compressor, args.compressor_level)
+    quant_raw_bytes = len(quant_raw)
+    quant_filename = f"final_model.{args.quant_scheme}.{args.compressor}.ptc"
+    if master_process:
+        with open(quant_filename, "wb") as f:
+            f.write(quant_blob)
+        quant_file_bytes = os.path.getsize(quant_filename)
+        code_bytes = len(code.encode("utf-8"))
+        ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1)
+        log0(
+            f"Serialized model {args.quant_scheme}+{args.compressor}: {quant_file_bytes} bytes "
+            f"(payload:{quant_stats['int8_payload_bytes']} raw_torch:{quant_raw_bytes} 
payload_ratio:{ratio:.2f}x)" + ) + log0(f"Total submission size {args.quant_scheme}+{args.compressor}: {quant_file_bytes + code_bytes} bytes") + + if distributed: + dist.barrier() + with open(quant_filename, "rb") as f: + quant_blob_disk = f.read() + quant_state = torch.load(io.BytesIO(decompress_blob(quant_blob_disk, args.compressor)), map_location="cpu") + base_model.load_state_dict(dequantize_state_dict(quant_state), strict=True) + torch.cuda.synchronize() + t_qeval = time.perf_counter() + q_val_loss, q_val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"final_{args.quant_scheme}_{args.compressor}_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" + ) + log0(f"final_{args.quant_scheme}_{args.compressor}_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + + if distributed: + dist.destroy_process_group() + + +if __name__ == "__main__": + main() + +==================================================================================================== +Running Python 3.12.3 (main, Nov 6 2025, 13:44:16) [GCC 13.3.0] +Running PyTorch 2.9.1+cu128 +Tue Mar 31 16:24:44 2026 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 580.126.09 Driver Version: 580.126.09 CUDA Version: 13.0 | ++-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. 
| +|=========================================+========================+======================| +| 0 NVIDIA H100 80GB HBM3 On | 00000000:CB:00.0 Off | 0 | +| N/A 36C P0 92W / 700W | 1185MiB / 81559MiB | 1% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| No running processes found | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=/workspace/parameter-golf/data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:80 +val_loader:shards pattern=/workspace/parameter-golf/data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:65536 +val_token_limit:65536 +model_params:25727104 +world_size:1 grad_accum_steps:8 train_microbatch_tokens:2048 effective_train_batch_tokens:16384 +enable_compile:False +sdp_backends:cudnn=False flash=True mem_efficient=False math=False +attention_mode:gqa num_heads:8 num_kv_heads:4 +quant_scheme:int5_odd compressor:zlib level:9 structured_mlp:True structured_linear_type:btt_mlp btt_rank:256 btt_init:mup compile_structured_mlp:False +rate_lambda:2e-05 scale_lambda:0.0002 rate_warmup_start_fraction:0.1 rate_full_fraction:0.4 rate_target_tensors:mlp_only +tie_embeddings:True embed_lr:0.03 head_lr:0.0 matrix_lr:0.02 scalar_lr:0.02 +train_batch_tokens_target:16384 effective_train_batch_tokens:16384 train_seq_len:1024 iterations:1 warmup_steps:1 max_wallclock_seconds:0.000 +lr_scale_method:none lr_reference_batch_tokens:16384 batch_multiplier:1.0000 warmup_scale_method:none 
warmup_reference_batch_tokens:16384
+seed:1337
+warmup_step:1/1
+step:1/1 train_loss:6.9404 ce_loss:6.9404 rate_mul:0.000 rate_proxy_bits:0.0000 scale_proxy:0.000000 est_model_bits:0 train_time:33066ms step_avg:33066.06ms
+step:1/1 val_loss:6.9380 val_bpb:4.1022 train_time:33067ms step_avg:33066.60ms
+peak memory allocated: 50323 MiB reserved: 53188 MiB
+Serialized model: 101929819 bytes
+Code size: 75576 bytes
+Total submission size: 102005395 bytes
+Serialized model int5_odd+zlib: 5112978 bytes (payload:8780523 raw_torch:8851579 payload_ratio:11.60x)
+Total submission size int5_odd+zlib: 5188554 bytes
+final_int5_odd_zlib_roundtrip val_loss:6.9418 val_bpb:4.1044 eval_time:5855ms
+final_int5_odd_zlib_roundtrip_exact val_loss:6.94184875 val_bpb:4.10444078
diff --git a/records/track_non_record_16mb/2026-03-31_EntropyAware_Int5Odd_BTTMLP_3060/logs/probe_r256_l16_seq1024_mb2048_bmm2.txt b/records/track_non_record_16mb/2026-03-31_EntropyAware_Int5Odd_BTTMLP_3060/logs/probe_r256_l16_seq1024_mb2048_bmm2.txt
new file mode 100644
index 000000000..4a5d78a8d
--- /dev/null
+++ b/records/track_non_record_16mb/2026-03-31_EntropyAware_Int5Odd_BTTMLP_3060/logs/probe_r256_l16_seq1024_mb2048_bmm2.txt
@@ -0,0 +1,1842 @@
+logs/probe_r256_l16_seq1024_mb2048_bmm2.txt
+val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=/workspace/parameter-golf/data/tokenizers/fineweb_1024_bpe.model
+train_loader:dataset:fineweb10B_sp1024 train_shards:80
+val_loader:shards pattern=/workspace/parameter-golf/data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:65536
+val_token_limit:65536
+model_params:25727104
+world_size:1 grad_accum_steps:8 train_microbatch_tokens:2048 effective_train_batch_tokens:16384
+enable_compile:False
+sdp_backends:cudnn=False flash=True mem_efficient=False math=False
+attention_mode:gqa num_heads:8 num_kv_heads:4
+quant_scheme:int5_odd compressor:zlib level:9 structured_mlp:True structured_linear_type:btt_mlp btt_rank:256 btt_init:mup compile_structured_mlp:False
structured_chunk_ranks:64 structured_compile_mode:reduce-overhead
+rate_lambda:2e-05 scale_lambda:0.0002 rate_warmup_start_fraction:0.1 rate_full_fraction:0.4 rate_target_tensors:mlp_only rate_schedule_mode:auto max_expected_steps:0
+tie_embeddings:True embed_lr:0.03 head_lr:0.0 matrix_lr:0.02 scalar_lr:0.02
+train_batch_tokens_target:16384 effective_train_batch_tokens:16384 train_seq_len:1024 iterations:1 warmup_steps:1 max_wallclock_seconds:0.000
+lr_scale_method:none lr_reference_batch_tokens:16384 batch_multiplier:1.0000 warmup_scale_method:none warmup_reference_batch_tokens:16384
+seed:1337
+warmup_step:1/1
+step:1/1 train_loss:6.9404 ce_loss:6.9404 rate_mul:0.000 rate_proxy_bits:0.0000 scale_proxy:0.000000 est_model_bits:0 train_time:18770ms step_avg:18770.28ms
+step:1/1 val_loss:6.9380 val_bpb:4.1022 train_time:18771ms step_avg:18770.85ms
+peak memory allocated: 49813 MiB reserved: 52668 MiB
+Serialized model: 101929819 bytes
+Code size: 78465 bytes
+Total submission size: 102008284 bytes
+Serialized model int5_odd+zlib: 5112978 bytes (payload:8780523 raw_torch:8851579 payload_ratio:11.60x)
+Total submission size int5_odd+zlib: 5191443 bytes
+final_int5_odd_zlib_roundtrip val_loss:6.9418 val_bpb:4.1044 eval_time:6506ms
+final_int5_odd_zlib_roundtrip_exact val_loss:6.94184875 val_bpb:4.10444078
+(os.environ.get("VAL_BATCH_SIZE", 32_768))
+    val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000))
+    train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 10))
+    val_token_limit = int(os.environ.get("VAL_TOKEN_LIMIT", 1_048_576))
+
+    # Training length.
+    iterations = int(os.environ.get("ITERATIONS", 40))
+    warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 1200))
+    warmup_steps = int(os.environ.get("WARMUP_STEPS", 1))
+    train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 16_384))
+    train_microbatch_tokens = int(os.environ.get("TRAIN_MICROBATCH_TOKENS", 0))
+    grad_accum_steps = int(os.environ.get("GRAD_ACCUM_STEPS", 0))
+    train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 256))
+    max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0))
+    qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5))
+    enable_compile = bool(int(os.environ.get("ENABLE_COMPILE", "0")))
+
+    # Model shape.
+    vocab_size = int(os.environ.get("VOCAB_SIZE", 1024))
+    num_layers = int(os.environ.get("NUM_LAYERS", 9))
+    num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4))
+    model_dim = int(os.environ.get("MODEL_DIM", 512))
+    num_heads = int(os.environ.get("NUM_HEADS", 8))
+    mlp_mult = int(os.environ.get("MLP_MULT", 2))
+    tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1")))
+    rope_base = float(os.environ.get("ROPE_BASE", 10000.0))
+    logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0))
+
+    # Optimizer hyperparameters.
+    embed_lr = float(os.environ.get("EMBED_LR", 0.6))
+    head_lr = float(os.environ.get("HEAD_LR", 0.008))
+    tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.03))
+    tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005))
+    matrix_lr = float(os.environ.get("MATRIX_LR", 0.02))
+    scalar_lr = float(os.environ.get("SCALAR_LR", 0.02))
+    muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95))
+    muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5))
+    muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85))
+    muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500))
+    beta1 = float(os.environ.get("BETA1", 0.9))
+    beta2 = float(os.environ.get("BETA2", 0.95))
+    adam_eps = float(os.environ.get("ADAM_EPS", 1e-8))
+    grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.0))
+    lr_scale_method = os.environ.get("LR_SCALE_METHOD", "none")
+    lr_reference_batch_tokens = int(os.environ.get("LR_REFERENCE_BATCH_TOKENS", 16_384))
+    warmup_scale_method = os.environ.get("WARMUP_SCALE_METHOD", "none")
+    warmup_reference_batch_tokens = int(os.environ.get("WARMUP_REFERENCE_BATCH_TOKENS", 16_384))
+    simulated_world_size = int(os.environ.get("SIMULATED_WORLD_SIZE", 0))
+    simulated_rank = int(os.environ.get("SIMULATED_RANK", 0))
+    dry_run_init_only = bool(int(os.environ.get("DRY_RUN_INIT_ONLY", "0")))
+    debug_data_sharding_steps = int(os.environ.get("DEBUG_DATA_SHARDING_STEPS", 0))
+
+    # Export / compression.
+    quant_scheme = os.environ.get("QUANT_SCHEME", "int5_odd")
+    compressor = os.environ.get("COMPRESSOR", "zlib")
+    compressor_level = int(os.environ.get("COMPRESSOR_LEVEL", 9))
+
+    # Entropy-aware rate regularization.
+    rate_lambda = float(os.environ.get("RATE_LAMBDA", 0.00002))
+    scale_lambda = float(os.environ.get("SCALE_LAMBDA", 0.0002))
+    rate_temperature = float(os.environ.get("RATE_TEMPERATURE", 6.0))
+    rate_warmup_start_fraction = float(os.environ.get("RATE_WARMUP_START_FRACTION", 0.1))
+    rate_full_fraction = float(os.environ.get("RATE_FULL_FRACTION", 0.4))
+    rate_target_tensors = os.environ.get("RATE_TARGET_TENSORS", "mlp_only")
+    rate_schedule_mode = os.environ.get("RATE_SCHEDULE_MODE", "auto")
+    max_expected_steps = int(os.environ.get("MAX_EXPECTED_STEPS", 0))
+
+    # Structured MLP.
+    structured_mlp = bool(int(os.environ.get("STRUCTURED_MLP", "1")))
+    structured_linear_type = os.environ.get("STRUCTURED_LINEAR_TYPE", "btt_mlp")
+    btt_rank = int(os.environ.get("BTT_RANK", 32))
+    btt_in_shape = os.environ.get("BTT_IN_SHAPE", "")
+    btt_out_shape = os.environ.get("BTT_OUT_SHAPE", "")
+    btt_init = os.environ.get("BTT_INIT", "mup")
+    btt_init_gain = float(os.environ.get("BTT_INIT_GAIN", 1.0))
+    compile_structured_mlp = bool(int(os.environ.get("COMPILE_STRUCTURED_MLP", "1")))
+    structured_chunk_ranks = int(os.environ.get("STRUCTURED_CHUNK_RANKS", 64))
+    structured_compile_mode = os.environ.get("STRUCTURED_COMPILE_MODE", "reduce-overhead")
+
+# -----------------------------
+# MUON OPTIMIZER
+# -----------------------------
+#
+# As borrowed from modded-nanogpt
+# Background on Muon: https://kellerjordan.github.io/posts/muon/
+
+def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor:
+    # Orthogonalize a 2D update matrix with a fast Newton-Schulz iteration.
+    # Muon uses this to normalize matrix-shaped gradients before applying them.
+    a, b, c = (3.4445, -4.7750, 2.0315)
+    X = G.bfloat16()
+    X /= X.norm() + eps
+    transposed = G.size(0) > G.size(1)
+    if transposed:
+        X = X.T
+    for _ in range(steps):
+        A = X @ X.T
+        B = b * A + c * A @ A
+        X = a * X + B @ X
+    return X.T if transposed else X
+
+
+class Muon(torch.optim.Optimizer):
+    def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True):
+        super().__init__(
+            params,
+            dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov),
+        )
+
+    @torch.no_grad()
+    def step(self, closure=None):
+        loss = None
+        if closure is not None:
+            with torch.enable_grad():
+                loss = closure()
+
+        distributed = dist.is_available() and dist.is_initialized()
+        world_size = dist.get_world_size() if distributed else 1
+        rank = dist.get_rank() if distributed else 0
+
+        for group in self.param_groups:
+            params = group["params"]
+            if not params:
+                continue
+            lr = group["lr"]
+            momentum = group["momentum"]
+            backend_steps = group["backend_steps"]
+            nesterov = group["nesterov"]
+
+            total_params = sum(int(p.numel()) for p in params)
+            updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16)
+
+            curr = 0
+            for i, p in enumerate(params):
+                if i % world_size == rank and p.grad is not None:
+                    g = p.grad
+                    state = self.state[p]
+                    if "momentum_buffer" not in state:
+                        state["momentum_buffer"] = torch.zeros_like(g)
+                    buf = state["momentum_buffer"]
+                    buf.mul_(momentum).add_(g)
+                    if nesterov:
+                        g = g.add(buf, alpha=momentum)
+                    g = zeropower_via_newtonschulz5(g, steps=backend_steps)
+                    # Scale correction from Muon reference implementations.
+                    g *= max(1, g.size(0) / g.size(1)) ** 0.5
+                    updates_flat[curr : curr + p.numel()] = g.reshape(-1)
+                curr += p.numel()
+
+            if distributed:
+                dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM)
+
+            curr = 0
+            for p in params:
+                g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype)
+                p.add_(g, alpha=-lr)
+                curr += p.numel()
+
+        return loss
+
+
+# -----------------------------
+# TOKENIZER-AGNOSTIC EVALUATION SETUP
+# -----------------------------
+#
+# It's common for small models to have a large fraction of their parameters be embeddings, since the 2 * d_model * d_vocab vectors can be gigantic.
+# Instead of locking the tokenizer, we let you bring your own and calculate our validation metrics on the average compression of the validation set.
+# We calculate BPB (bits-per-byte) instead of validation loss, so we need methods to count the number of bytes per token under the tokenizer.
+# Note: Submissions that edit the tokenizer will be examined more carefully, since screwing this up might unjustly improve your score.
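+
+# Hedged illustration (example_bits_per_byte is illustrative only, not part of
+# the harness): BPB follows from the mean token cross-entropy in nats together
+# with the token and byte counts of the validation set, mirroring the
+# arithmetic at the end of eval_val.
+def example_bits_per_byte(mean_ce_nats: float, n_tokens: int, n_bytes: int) -> float:
+    # bits/token = nats / ln(2); bits/byte = bits/token * tokens/byte
+    bits_per_token = mean_ce_nats / math.log(2.0)
+    return bits_per_token * (n_tokens / n_bytes)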
+
+def build_sentencepiece_luts(
+    sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device
+) -> tuple[Tensor, Tensor, Tensor]:
+    sp_vocab_size = int(sp.vocab_size())
+    table_size = max(sp_vocab_size, vocab_size)
+    base_bytes_np = np.zeros((table_size,), dtype=np.int16)
+    has_leading_space_np = np.zeros((table_size,), dtype=np.bool_)
+    is_boundary_token_np = np.ones((table_size,), dtype=np.bool_)
+    for token_id in range(sp_vocab_size):
+        if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id):
+            continue
+        is_boundary_token_np[token_id] = False
+        if sp.is_byte(token_id):
+            base_bytes_np[token_id] = 1
+            continue
+        piece = sp.id_to_piece(token_id)
+        if piece.startswith("▁"):
+            has_leading_space_np[token_id] = True
+            piece = piece[1:]
+        base_bytes_np[token_id] = len(piece.encode("utf-8"))
+    return (
+        torch.tensor(base_bytes_np, dtype=torch.int16, device=device),
+        torch.tensor(has_leading_space_np, dtype=torch.bool, device=device),
+        torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device),
+    )
+
+
+def load_validation_tokens(pattern: str, seq_len: int) -> Tensor:
+    files = [Path(p) for p in sorted(glob.glob(pattern))]
+    if not files:
+        raise FileNotFoundError(f"No files found for pattern: {pattern}")
+    # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*.
+    tokens = torch.cat([load_data_shard(file) for file in files]).contiguous()
+    usable = ((tokens.numel() - 1) // seq_len) * seq_len
+    if usable <= 0:
+        raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}")
+    return tokens[: usable + 1]
+
+
+def eval_val(
+    args: Hyperparameters,
+    model: nn.Module,
+    rank: int,
+    world_size: int,
+    device: torch.device,
+    grad_accum_steps: int,
+    val_tokens: Tensor,
+    base_bytes_lut: Tensor,
+    has_leading_space_lut: Tensor,
+    is_boundary_token_lut: Tensor,
+) -> tuple[float, float]:
+    # Validation computes two metrics:
+    # - val_loss: token cross-entropy (natural log)
+    # - val_bpb: tokenizer-agnostic compression metric used by the challenge
+    local_batch_tokens = args.val_batch_size // world_size
+    if local_batch_tokens < args.train_seq_len:
+        raise ValueError(
+            "VAL_BATCH_SIZE must provide at least one sequence per rank; "
+            f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, "
+            f"TRAIN_SEQ_LEN={args.train_seq_len}"
+        )
+    local_batch_seqs = local_batch_tokens // args.train_seq_len
+    total_seqs = (val_tokens.numel() - 1) // args.train_seq_len
+    seq_start = (total_seqs * rank) // world_size
+    seq_end = (total_seqs * (rank + 1)) // world_size
+    val_loss_sum = torch.zeros((), device=device, dtype=torch.float64)
+    val_token_count = torch.zeros((), device=device, dtype=torch.float64)
+    val_byte_count = torch.zeros((), device=device, dtype=torch.float64)
+
+    model.eval()
+    with torch.inference_mode():
+        for batch_seq_start in range(seq_start, seq_end, local_batch_seqs):
+            batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end)
+            raw_start = batch_seq_start * args.train_seq_len
+            raw_end = batch_seq_end * args.train_seq_len + 1
+            local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True)
+            x = local[:-1].reshape(-1, args.train_seq_len)
+            y = local[1:].reshape(-1, args.train_seq_len)
+            maybe_mark_compile_step_begin(args.enable_compile or
args.compile_structured_mlp)
+            with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True):
+                batch_loss = model(x, y).detach()
+            batch_token_count = float(y.numel())
+            val_loss_sum += batch_loss.to(torch.float64) * batch_token_count
+            val_token_count += batch_token_count
+            prev_ids = x.reshape(-1)
+            tgt_ids = y.reshape(-1)
+            token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16)
+            token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16)
+            val_byte_count += token_bytes.to(torch.float64).sum()
+
+    if dist.is_available() and dist.is_initialized():
+        dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM)
+        dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM)
+        dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM)
+
+    val_loss = val_loss_sum / val_token_count
+    bits_per_token = val_loss.item() / math.log(2.0)
+    tokens_per_byte = val_token_count.item() / val_byte_count.item()
+    model.train()
+    return float(val_loss.item()), float(bits_per_token * tokens_per_byte)
+
+# -----------------------------
+# POST-TRAINING QUANTIZATION
+# -----------------------------
+#
+# It's silly to export our model, which is trained in bf16 and fp32, at that same precision.
+# Instead, we get approximately the same model (with a small accuracy hit) by quantizing it to int8
+# (or the 5-level int5_odd grid) and zlib-compressing the result.
+# We can then decompress the model and run evaluation in higher precision, once the artifact fits under the size limit.
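The comment above sketches the export scheme. As a standalone illustration (separate from the exporter below; all names are local to this sketch), the int5_odd path's base-5 packing stores three 5-level codes per byte, which is what makes the 5-bin grid cheaper than int8 even before zlib:

```python
import numpy as np

def pack_base5(codes: np.ndarray) -> tuple[bytes, int]:
    # Three base-5 digits fit in one byte (5**3 = 125 <= 256),
    # i.e. ~2.67 bits/weight before zlib, vs 8 bits/weight for int8.
    n = int(codes.size)
    pad = (3 - n % 3) % 3
    v = np.pad(codes.astype(np.uint8), (0, pad)).reshape(-1, 3)
    return (v[:, 0] + 5 * v[:, 1] + 25 * v[:, 2]).astype(np.uint8).tobytes(), n

def unpack_base5(data: bytes, n: int) -> np.ndarray:
    packed = np.frombuffer(data, dtype=np.uint8).astype(np.int16)
    digits = np.empty((packed.size, 3), dtype=np.uint8)
    for i in range(3):
        digits[:, i] = packed % 5
        packed = packed // 5
    return digits.reshape(-1)[:n]

codes = np.array([0, 1, 2, 3, 4, 4, 0], dtype=np.uint8)  # levels {-2..2} shifted by +2
blob, n = pack_base5(codes)
assert len(blob) == 3  # ceil(7 / 3) bytes
assert np.array_equal(unpack_base5(blob, n), codes)
```

This is the same idea as `pack_base5_codes`/`unpack_base5_codes` below; `estimate_tensor_rate_proxy` then estimates the entropy of this 5-bin code distribution so training can be nudged toward weights whose packed bytes compress better.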
+ +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights", + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 +INT5_ODD_LEVELS = (-2, -1, 0, 1, 2) + +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) + + +def maybe_mark_compile_step_begin(enabled: bool) -> None: + if not enabled: + return + mark_step = getattr(torch.compiler, "cudagraph_mark_step_begin", None) + if mark_step is not None: + mark_step() + + +def scale_from_batch_multiplier(value: float, multiplier: float, method: str) -> float: + if method == "none": + return value + if method == "sqrt": + return value * math.sqrt(multiplier) + if method == "linear": + return value * multiplier + raise ValueError(f"Unsupported scaling method: {method}") + + +def scaled_warmup_steps(base_steps: int, multiplier: float, method: str) -> int: + if base_steps <= 0: + return 0 + scaled = scale_from_batch_multiplier(float(base_steps), multiplier, method) + return max(int(math.ceil(scaled)), 1) + + +def parse_factor_shape(spec: str, dim: int) -> tuple[int, int]: + if spec: + vals = tuple(int(v.strip()) for v in spec.split(",") if v.strip()) + if len(vals) != 2 or vals[0] * vals[1] != dim: + raise ValueError(f"Invalid factor shape '{spec}' for dim={dim}") + return vals + for a in range(int(math.isqrt(dim)), 0, -1): + if dim % a == 0: + return a, dim // a + return 1, dim + +def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> 
Tensor: + if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): + return t.float().contiguous() + if t.dtype in {torch.float32, torch.bfloat16}: + passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") + return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() + return t + +def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + # Matrices get one scale per row, which usually tracks output-channel + # ranges much better than a single tensor-wide scale. + clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) + q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() + return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + + # Vectors / scalars use a simpler per-tensor scale. 
+ clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() + return q, scale + + +def pack_base5_codes(q: Tensor) -> tuple[bytes, int]: + vals = q.detach().to("cpu", dtype=torch.uint8).reshape(-1).numpy() + n_vals = int(vals.size) + pad = (3 - (n_vals % 3)) % 3 + if pad: + vals = np.pad(vals, (0, pad)) + groups = vals.reshape(-1, 3).astype(np.uint8) + packed = groups[:, 0] + 5 * groups[:, 1] + 25 * groups[:, 2] + return packed.tobytes(), n_vals + + +def unpack_base5_codes(data: bytes, n_vals: int) -> Tensor: + packed = np.frombuffer(data, dtype=np.uint8).astype(np.int16) + out = np.zeros((packed.size, 3), dtype=np.uint8) + for i in range(3): + out[:, i] = packed % 5 + packed //= 5 + flat = out.reshape(-1)[:n_vals] + return torch.from_numpy(flat.astype(np.uint8, copy=False)) + + +def quantize_float_tensor_int5_odd(t: Tensor) -> tuple[Tensor, Tensor]: + if t.ndim != 2: + raise ValueError("int5_odd quantization only supports 2D tensors") + t32 = t.float() + clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clip_abs = clip_abs.clamp_min(1e-6) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 2.0).clamp_min(1.0 / 2.0e6) + q = torch.clamp(torch.round(clipped / scale[:, None]), -2, 2).to(torch.int8).contiguous() + return (q + 2).to(torch.uint8).contiguous(), scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + +def quantize_state_dict_int8(state_dict: dict[str, Tensor]): + # Single supported clean-script export format: + # - per-row int8 for 2D float tensors + # - per-tensor int8 for other float tensors + # - exact passthrough for non-floats + # - passthrough for small float tensors, 
stored as fp16 to save bytes + quantized: dict[str, Tensor] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + qmeta: dict[str, dict[str, object]] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + + # Small float tensors are cheap enough to keep directly. We still downcast + # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size. + if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + + stats["num_float_tensors"] += 1 + q, s = quantize_float_tensor(t) + if s.ndim > 0: + qmeta[name] = {"scheme": "per_row", "axis": 0} + quantized[name] = q + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) + + obj: dict[str, object] = { + "__quant_format__": "int8_clean_per_row_v1", + "quantized": quantized, + "scales": scales, + "dtypes": dtypes, + "passthrough": passthrough, + } + if qmeta: + obj["qmeta"] = qmeta + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats + + +def quantize_state_dict_int5_odd(state_dict: dict[str, Tensor]): + packed: dict[str, bytes] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + shapes: dict[str, tuple[int, ...]] = {} + orig_shapes: dict[str, 
tuple[int, ...]] = {} + n_codes: dict[str, int] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + + quantize_this = t.ndim >= 2 and (t.numel() > INT8_KEEP_FLOAT_MAX_NUMEL or ".mlp." in name) + if not quantize_this: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + + stats["num_float_tensors"] += 1 + t2 = t.reshape(t.shape[0], -1) if t.ndim > 2 else t + q, s = quantize_float_tensor_int5_odd(t2) + packed_bytes, count = pack_base5_codes(q) + packed[name] = packed_bytes + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + shapes[name] = tuple(int(v) for v in t2.shape) + orig_shapes[name] = tuple(int(v) for v in t.shape) + n_codes[name] = count + stats["int8_payload_bytes"] += len(packed_bytes) + tensor_nbytes(s) + + obj: dict[str, object] = { + "__quant_format__": "int5_odd_clean_per_row_v1", + "packed": packed, + "scales": scales, + "dtypes": dtypes, + "shapes": shapes, + "orig_shapes": orig_shapes, + "n_codes": n_codes, + "passthrough": passthrough, + } + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats + +def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + qmeta = obj.get("qmeta", {}) + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, q in 
obj["quantized"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + s = obj["scales"][name] + if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: + s = s.to(dtype=torch.float32) + # Broadcast the saved row scale back across trailing dimensions. + out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() + else: + scale = float(s.item()) + out[name] = (q.float() * scale).to(dtype=dtype).contiguous() + for name, t in obj["passthrough"].items(): + # Restore small tensors, undoing the temporary fp16 storage cast if needed. + out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out + + +def dequantize_state_dict_int5_odd(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, packed in obj["packed"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + shape = tuple(int(v) for v in obj["shapes"][name]) + orig_shape = tuple(int(v) for v in obj.get("orig_shapes", {}).get(name, shape)) + q = unpack_base5_codes(packed, int(obj["n_codes"][name])).to(torch.float32).reshape(shape) - 2.0 + s = obj["scales"][name].to(dtype=torch.float32) + deq = (q * s.view(shape[0], *([1] * (len(shape) - 1)))).to(dtype=dtype) + out[name] = deq.reshape(orig_shape).contiguous() + for name, t in obj["passthrough"].items(): + out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out + + +def quantize_state_dict(state_dict: dict[str, Tensor], quant_scheme: str): + if quant_scheme == "int8": + return quantize_state_dict_int8(state_dict) + if quant_scheme == "int5_odd": + return 
quantize_state_dict_int5_odd(state_dict) + raise ValueError(f"Unsupported QUANT_SCHEME={quant_scheme}") + + +def dequantize_state_dict(obj: dict[str, object]) -> dict[str, Tensor]: + quant_format = obj.get("__quant_format__") + if quant_format == "int8_clean_per_row_v1": + return dequantize_state_dict_int8(obj) + if quant_format == "int5_odd_clean_per_row_v1": + return dequantize_state_dict_int5_odd(obj) + raise ValueError(f"Unsupported quant format {quant_format}") + + +def compress_blob(data: bytes, compressor: str, level: int) -> bytes: + if compressor == "zlib": + return zlib.compress(data, level=max(0, min(level, 9))) + if compressor == "zstd": + if not _HAS_ZSTD: + raise RuntimeError("COMPRESSOR=zstd requires zstandard to be installed") + return zstd.ZstdCompressor(level=level).compress(data) + raise ValueError(f"Unsupported COMPRESSOR={compressor}") + + +def decompress_blob(data: bytes, compressor: str) -> bytes: + if compressor == "zlib": + return zlib.decompress(data) + if compressor == "zstd": + if not _HAS_ZSTD: + raise RuntimeError("COMPRESSOR=zstd requires zstandard to be installed") + return zstd.ZstdDecompressor().decompress(data) + raise ValueError(f"Unsupported COMPRESSOR={compressor}") + + +def should_rate_regularize_tensor(name: str, p: Tensor, target: str) -> bool: + if not p.requires_grad or not p.is_floating_point() or p.ndim < 2: + return False + if any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS): + return False + if target == "mlp_only": + return ".mlp." in name + if target == "all_matrices": + return p.numel() > INT8_KEEP_FLOAT_MAX_NUMEL or ".mlp." 
in name
+    raise ValueError(f"Unsupported RATE_TARGET_TENSORS={target}")
+
+
+def estimate_tensor_rate_proxy(t: Tensor, temperature: float) -> tuple[Tensor, Tensor, Tensor]:
+    t2 = t.float()
+    if t2.ndim > 2:
+        t2 = t2.reshape(t2.shape[0], -1)
+    elif t2.ndim == 1:
+        t2 = t2.reshape(1, -1)
+    row_scale = t2.abs().mean(dim=1, keepdim=True).clamp_min(1e-6)
+    normalized = torch.clamp(t2 / row_scale, -2.5, 2.5)
+    centers = t2.new_tensor(INT5_ODD_LEVELS)
+    logits = -temperature * (normalized.unsqueeze(-1) - centers) ** 2
+    probs = torch.softmax(logits, dim=-1)
+    hist = probs.mean(dim=(0, 1))
+    entropy_nats = -(hist * (hist + 1e-8).log()).sum()
+    entropy_bits = entropy_nats / math.log(2.0)
+    scale_proxy = row_scale.mean()
+    est_bits = entropy_bits * float(t2.numel())
+    return entropy_bits, scale_proxy, est_bits
+
+
+# -----------------------------
+# DATA LOADING
+# -----------------------------
+
+def load_data_shard(file: Path) -> Tensor:
+    # Shards use the modded-nanogpt .bin layout: a 256-entry little-endian int32
+    # header (magic, version, token count), followed by the uint16 token ids.
+    header_bytes = 256 * np.dtype("<i4").itemsize
+    with file.open("rb") as f:
+        header = np.frombuffer(f.read(header_bytes), dtype="<i4")
+        num_tokens = int(header[2])
+        tokens = np.frombuffer(f.read(2 * num_tokens), dtype="<u2")
+    if tokens.size != num_tokens:
+        raise ValueError(f"Shard {file} is truncated: expected {num_tokens} tokens, got {tokens.size}")
+    # Widen to int32 so torch indexing/cat behave the same across torch versions.
+    return torch.from_numpy(tokens.astype(np.int32))
+
+
+class TokenStream:
+    # Sequential reader over all shards matching the pattern, wrapping around at the end.
+    def __init__(self, pattern: str):
+        self.files = [Path(p) for p in sorted(glob.glob(pattern))]
+        if not self.files:
+            raise FileNotFoundError(f"No files found for pattern: {pattern}")
+        self.file_idx = 0
+        self.tokens = load_data_shard(self.files[self.file_idx])
+        self.pos = 0
+
+    def _advance_file(self) -> None:
+        self.file_idx = (self.file_idx + 1) % len(self.files)
+        self.tokens = load_data_shard(self.files[self.file_idx])
+        self.pos = 0
+
+    def take(self, n: int) -> Tensor:
+        chunks: list[Tensor] = []
+        remaining = n
+        while remaining > 0:
+            avail = self.tokens.numel() - self.pos
+            if avail <= 0:
+                self._advance_file()
+                continue
+            k = min(remaining, avail)
+            chunks.append(self.tokens[self.pos : self.pos + k])
+            self.pos += k
+            remaining -= k
+        return chunks[0] if len(chunks) == 1 else torch.cat(chunks)
+
+
+class DistributedTokenLoader:
+    # Each call consumes a contiguous chunk from the shared token stream, then slices out
+    # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting.
+ def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank = rank + self.world_size = world_size + self.device = device + self.stream = TokenStream(pattern) + + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + chunk = self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start : start + per_rank_span].to(dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + + +def debug_log_data_sharding( + pattern: str, + world_size: int, + global_tokens: int, + seq_len: int, + grad_accum_steps: int, + steps: int, + log0, +) -> None: + stream = TokenStream(pattern) + local_tokens = global_tokens // (world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + global_cursor = 0 + log0( + f"debug_sharding:world_size:{world_size} global_tokens:{global_tokens} " + f"grad_accum_steps:{grad_accum_steps} local_tokens_per_rank:{local_tokens} per_rank_span:{per_rank_span}" + ) + for step_idx in range(steps): + chunk = stream.take(per_rank_span * world_size) + chunk_start = global_cursor + chunk_end = global_cursor + chunk.numel() + log0(f"debug_sharding_step:{step_idx} shared_chunk_token_range:[{chunk_start},{chunk_end})") + for shard_rank in range(world_size): + start = shard_rank * per_rank_span + end = start + per_rank_span + local = chunk[start:end] + preview = ",".join(str(int(t)) for t in local[: min(6, local.numel())].tolist()) + log0( + f"debug_shard rank:{shard_rank} token_range:[{chunk_start + start},{chunk_start + end}) " + f"preview:{preview}" + ) + global_cursor = chunk_end + +# ----------------------------- +# TRANSFORMER MODULES +# ----------------------------- + +class RMSNorm(nn.Module): + def 
__init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + + +class CastedLinear(nn.Linear): + # Keep weights in fp32 for optimizer/state quality, cast at matmul time for bf16 compute. + def forward(self, x: Tensor) -> Tensor: + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, self.weight.to(x.dtype), bias) + + +class StructuredLinear(nn.Module): + # Lightweight TT-style structured matrix used only in the MLP stack. + def __init__( + self, + in_features: int, + out_features: int, + bias: bool, + linear_type: str, + btt_rank: int, + in_shape_spec: str, + out_shape_spec: str, + btt_init: str, + btt_init_gain: float, + chunk_ranks: int, + ): + super().__init__() + self.in_features = in_features + self.out_features = out_features + self.linear_type = linear_type + self.rank = max(int(btt_rank), 1) + self.chunk_ranks = max(int(chunk_ranks), 1) + self.btt_init = btt_init + self.btt_init_gain = btt_init_gain + if linear_type == "dense": + self.impl = CastedLinear(in_features, out_features, bias=bias) + self.register_parameter("core1", None) + self.register_parameter("core2", None) + return + if linear_type != "btt_mlp": + raise ValueError(f"Unsupported STRUCTURED_LINEAR_TYPE={linear_type}") + self.impl = None + self.in_shape = parse_factor_shape(in_shape_spec, in_features) + self.out_shape = parse_factor_shape(out_shape_spec, out_features) + n1, n2 = self.in_shape + m1, m2 = self.out_shape + self.core1 = nn.Parameter(torch.empty(n1, m1, self.rank)) + self.core2 = nn.Parameter(torch.empty(self.rank, n2, m2)) + if bias: + self.bias = nn.Parameter(torch.zeros(out_features)) + else: + self.register_parameter("bias", None) + self.reset_parameters() + + def reset_parameters(self) -> None: + if self.impl is not None: + self.impl.reset_parameters() + return + n1, n2 = self.in_shape + if self.btt_init == "mup": + # muP-inspired 
scaling: keep the per-rank branch update stable while + # preserving O(1) output variance under the sum over TT ranks. + core1_std = self.btt_init_gain / math.sqrt(max(n1, 1)) + core2_std = self.btt_init_gain / math.sqrt(max(self.rank * n2, 1)) + elif self.btt_init == "xavier": + std = self.btt_init_gain / math.sqrt(max(self.in_features, 1)) + core1_std = std / math.sqrt(self.rank) + core2_std = std + else: + raise ValueError(f"Unsupported BTT_INIT={self.btt_init}") + nn.init.normal_(self.core1, mean=0.0, std=core1_std) + nn.init.normal_(self.core2, mean=0.0, std=core2_std) + if self.bias is not None: + nn.init.zeros_(self.bias) + + def forward(self, x: Tensor) -> Tensor: + if self.impl is not None: + return self.impl(x) + x_dtype = x.dtype + orig_shape = x.shape[:-1] + x2 = x.reshape(-1, self.in_shape[0], self.in_shape[1]) + x2_t = x2.transpose(1, 2).contiguous() + core1 = self.core1.to(dtype=x_dtype).permute(2, 0, 1).contiguous() + core2 = self.core2.to(dtype=x_dtype).contiguous() + y = x2.new_zeros((x2.shape[0], self.out_shape[0], self.out_shape[1])) + for start in range(0, self.rank, self.chunk_ranks): + end = min(start + self.chunk_ranks, self.rank) + for offset, core2_r in enumerate(core2[start:end], start=start): + left = torch.matmul(x2_t, core1[offset]) + y = y + torch.matmul(left.transpose(1, 2), core2_r) + y = y.reshape(*orig_shape, self.out_features) + if self.bias is not None: + y = y + self.bias.to(dtype=y.dtype) + return y + + +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + # Keep small/control parameters in fp32 even when the model body runs in bf16. + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() + + +class Rotary(nn.Module): + # Caches cos/sin tables per sequence length on the current device. 
+ def __init__(self, dim: int, base: float = 10000.0): + super().__init__() + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached != seq_len + or self._cos_cached.device != device + ): + t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) + freqs = torch.outer(t, self.inv_freq.to(device)) + self._cos_cached = freqs.cos()[None, None, :, :] + self._sin_cached = freqs.sin()[None, None, :, :] + self._seq_len_cached = seq_len + cos = self._cos_cached.to(dtype=dtype) + sin = self._sin_cached.to(dtype=dtype) + if not torch.is_inference_mode_enabled(): + cos = cos.clone() + sin = sin.clone() + return cos, sin + + +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + + +class CausalSelfAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + kv_dim = self.num_kv_heads * self.head_dim + self.c_q = CastedLinear(dim, dim, bias=False) + self.c_k = CastedLinear(dim, kv_dim, bias=False) + self.c_v = CastedLinear(dim, kv_dim, bias=False) + self.proj 
= CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rotary = Rotary(self.head_dim, base=rope_base) + + def forward(self, x: Tensor) -> Tensor: + bsz, seqlen, dim = x.shape + q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) + k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin) + k = apply_rotary_emb(k, cos, sin) + q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] + y = F.scaled_dot_product_attention( + q, + k, + v, + attn_mask=None, + is_causal=True, + enable_gqa=(self.num_kv_heads != self.num_heads), + ) + y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) + return self.proj(y) + + +class MLP(nn.Module): + # relu^2 MLP from the original modded-nanogpt setup + def __init__( + self, + dim: int, + mlp_mult: int, + structured_mlp: bool, + structured_linear_type: str, + btt_rank: int, + btt_in_shape: str, + btt_out_shape: str, + btt_init: str, + btt_init_gain: float, + compile_structured_mlp: bool, + structured_chunk_ranks: int, + structured_compile_mode: str, + ): + super().__init__() + hidden = mlp_mult * dim + if structured_mlp: + self.fc = StructuredLinear( + dim, + hidden, + bias=False, + linear_type=structured_linear_type, + btt_rank=btt_rank, + in_shape_spec=btt_in_shape, + out_shape_spec=btt_out_shape, + btt_init=btt_init, + btt_init_gain=btt_init_gain, + chunk_ranks=structured_chunk_ranks, + ) + self.proj = StructuredLinear( + hidden, + dim, + bias=False, + linear_type=structured_linear_type, + btt_rank=btt_rank, + in_shape_spec=btt_out_shape, + out_shape_spec=btt_in_shape, + btt_init=btt_init, + 
btt_init_gain=btt_init_gain, + chunk_ranks=structured_chunk_ranks, + ) + if compile_structured_mlp: + compile_kwargs = dict(fullgraph=False, dynamic=False) + if structured_compile_mode == "reduce-overhead": + compile_options = dict(torch._inductor.list_mode_options("reduce-overhead")) + compile_options["triton.cudagraphs"] = False + compile_kwargs["options"] = compile_options + else: + compile_kwargs["mode"] = structured_compile_mode + self.fc = torch.compile( + self.fc, + **compile_kwargs, + ) + self.proj = torch.compile( + self.proj, + **compile_kwargs, + ) + else: + self.fc = CastedLinear(dim, hidden, bias=False) + self.proj = CastedLinear(hidden, dim, bias=False) + self.proj._zero_init = True + + def forward(self, x: Tensor) -> Tensor: + x = torch.relu(self.fc(x)) + return self.proj(x.square()) + + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + rope_base: float, + qk_gain_init: float, + structured_mlp: bool, + structured_linear_type: str, + btt_rank: int, + btt_in_shape: str, + btt_out_shape: str, + btt_init: str, + btt_init_gain: float, + compile_structured_mlp: bool, + structured_chunk_ranks: int, + structured_compile_mode: str, + ): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) + self.mlp = MLP( + dim, + mlp_mult, + structured_mlp, + structured_linear_type, + btt_rank, + btt_in_shape, + btt_out_shape, + btt_init, + btt_init_gain, + compile_structured_mlp, + structured_chunk_ranks, + structured_compile_mode, + ) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + + def forward(self, x: Tensor, x0: Tensor) -> Tensor: + mix = self.resid_mix.to(dtype=x.dtype) + x = mix[0][None, None, :] * x + 
mix[1][None, None, :] * x0 + attn_out = self.attn(self.attn_norm(x)) + x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out + x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) + return x + + +class GPT(nn.Module): + def __init__( + self, + vocab_size: int, + num_layers: int, + model_dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + tie_embeddings: bool, + tied_embed_init_std: float, + logit_softcap: float, + rope_base: float, + qk_gain_init: float, + structured_mlp: bool, + structured_linear_type: str, + btt_rank: int, + btt_in_shape: str, + btt_out_shape: str, + btt_init: str, + btt_init_gain: float, + compile_structured_mlp: bool, + structured_chunk_ranks: int, + structured_compile_mode: str, + rate_target_tensors: str, + rate_temperature: float, + ): + super().__init__() + if logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") + self.tie_embeddings = tie_embeddings + self.tied_embed_init_std = tied_embed_init_std + self.logit_softcap = logit_softcap + self.rate_target_tensors = rate_target_tensors + self.rate_temperature = rate_temperature + self.tok_emb = nn.Embedding(vocab_size, model_dim) + self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + self.blocks = nn.ModuleList( + [ + Block( + model_dim, + num_heads, + num_kv_heads, + mlp_mult, + rope_base, + qk_gain_init, + structured_mlp, + structured_linear_type, + btt_rank, + btt_in_shape, + btt_out_shape, + btt_init, + btt_init_gain, + compile_structured_mlp, + structured_chunk_ranks, + structured_compile_mode, + ) + for i in range(num_layers) + ] + ) + self.final_norm = RMSNorm() + self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, 
bias=False) + if self.lm_head is not None: + self.lm_head._zero_init = True + self._init_weights() + + def _init_weights(self) -> None: + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + for module in self.modules(): + if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + if isinstance(module, StructuredLinear) and getattr(module, "_zero_init", False): + if module.impl is not None: + nn.init.zeros_(module.impl.weight) + else: + nn.init.zeros_(module.core2) + + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + x = self.tok_emb(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x0 = x + skips: list[Tensor] = [] + + # First half stores skips; second half reuses them in reverse order. + for i in range(self.num_encoder_layers): + x = self.blocks[i](x, x0) + skips.append(x) + for i in range(self.num_decoder_layers): + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + x = self.blocks[self.num_encoder_layers + i](x, x0) + + x = self.final_norm(x).reshape(-1, x.size(-1)) + targets = target_ids.reshape(-1) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(x) + logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + return F.cross_entropy(logits.float(), targets, reduction="mean") + + def rate_regularizer(self) -> tuple[Tensor, Tensor, Tensor]: + total_symbols = 0.0 + weighted_rate = None + weighted_scale = None + estimated_bits = None + for name, param in self.named_parameters(): + if not should_rate_regularize_tensor(name, param, self.rate_target_tensors): + continue + rate_proxy, scale_proxy, tensor_bits = estimate_tensor_rate_proxy(param, self.rate_temperature) + weight = float(param.numel()) + total_symbols += weight + if 
weighted_rate is None: + weighted_rate = rate_proxy * weight + weighted_scale = scale_proxy * weight + estimated_bits = tensor_bits + else: + weighted_rate = weighted_rate + rate_proxy * weight + weighted_scale = weighted_scale + scale_proxy * weight + estimated_bits = estimated_bits + tensor_bits + if weighted_rate is None or weighted_scale is None or estimated_bits is None: + zero = self.tok_emb.weight.new_zeros(()) + return zero, zero, zero + denom = max(total_symbols, 1.0) + return weighted_rate / denom, weighted_scale / denom, estimated_bits + + +# ----------------------------- +# TRAINING +# ----------------------------- + +def main() -> None: + global zeropower_via_newtonschulz5 + + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + if args.enable_compile: + zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) + + # ----------------------------- + # DISTRIBUTED + CUDA SETUP + # ----------------------------- + + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + simulated_world_size = args.simulated_world_size if not distributed and args.simulated_world_size > 0 else 0 + if simulated_world_size > 0: + world_size = simulated_world_size + rank = args.simulated_rank + local_rank = 0 + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if rank < 0 or rank >= world_size: + raise ValueError(f"RANK={rank} must satisfy 0 <= RANK < WORLD_SIZE={world_size}") + if simulated_world_size > 1 and not args.dry_run_init_only: + raise ValueError("SIMULATED_WORLD_SIZE>1 is only supported together with DRY_RUN_INIT_ONLY=1") + if args.train_batch_tokens <= 0: + raise ValueError(f"TRAIN_BATCH_TOKENS must be positive, got {args.train_batch_tokens}") + train_microbatch_tokens = args.train_microbatch_tokens if 
args.train_microbatch_tokens > 0 else max( + args.train_seq_len, args.train_batch_tokens // 8 + ) + if train_microbatch_tokens <= 0: + raise ValueError(f"TRAIN_MICROBATCH_TOKENS must be positive, got {train_microbatch_tokens}") + if train_microbatch_tokens % args.train_seq_len != 0: + raise ValueError( + f"TRAIN_MICROBATCH_TOKENS={train_microbatch_tokens} must be divisible by TRAIN_SEQ_LEN={args.train_seq_len}" + ) + tokens_per_microstep = world_size * train_microbatch_tokens + if args.grad_accum_steps > 0: + grad_accum_steps = args.grad_accum_steps + effective_train_batch_tokens = tokens_per_microstep * grad_accum_steps + else: + if args.train_batch_tokens % tokens_per_microstep != 0: + raise ValueError( + "TRAIN_BATCH_TOKENS must be divisible by WORLD_SIZE * TRAIN_MICROBATCH_TOKENS " + f"unless GRAD_ACCUM_STEPS is set explicitly; got TRAIN_BATCH_TOKENS={args.train_batch_tokens}, " + f"WORLD_SIZE={world_size}, TRAIN_MICROBATCH_TOKENS={train_microbatch_tokens}" + ) + grad_accum_steps = args.train_batch_tokens // tokens_per_microstep + effective_train_batch_tokens = args.train_batch_tokens + if grad_accum_steps <= 0: + raise ValueError(f"GRAD_ACCUM_STEPS must be positive, got {grad_accum_steps}") + grad_scale = 1.0 / grad_accum_steps + batch_multiplier = effective_train_batch_tokens / max(args.lr_reference_batch_tokens, 1) + effective_warmup_steps = scaled_warmup_steps( + args.warmup_steps, + effective_train_batch_tokens / max(args.warmup_reference_batch_tokens, 1), + args.warmup_scale_method, + ) + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 0 + + # Fast math knobs + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, 
enable_mem_efficient_sdp + + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + + logfile = None + if master_process: + if args.run_id == "train": + logfile = "train.log" + else: + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" + print(logfile) + + def log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch {torch.__version__}", console=False) + log0( + subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, + console=False, + ) + log0("=" * 100, console=False) + + # ----------------------------- + # TOKENIZER + VALIDATION METRIC SETUP + # ----------------------------- + + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) + if args.val_token_limit > 0: + usable = min((args.val_token_limit // args.train_seq_len) * args.train_seq_len, val_tokens.numel() - 1) + if usable <= 0: + raise ValueError(f"VAL_TOKEN_LIMIT={args.val_token_limit} is too small for TRAIN_SEQ_LEN={args.train_seq_len}") + val_tokens = 
val_tokens[: usable + 1].contiguous() + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + if args.val_token_limit > 0: + log0(f"val_token_limit:{args.val_token_limit}") + + # ----------------------------- + # MODEL + OPTIMIZER SETUP + # ----------------------------- + + base_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, + structured_mlp=args.structured_mlp, + structured_linear_type=args.structured_linear_type if args.structured_mlp else "dense", + btt_rank=args.btt_rank, + btt_in_shape=args.btt_in_shape, + btt_out_shape=args.btt_out_shape, + btt_init=args.btt_init, + btt_init_gain=args.btt_init_gain, + compile_structured_mlp=args.compile_structured_mlp, + structured_chunk_ranks=args.structured_chunk_ranks, + structured_compile_mode=args.structured_compile_mode, + rate_target_tensors=args.rate_target_tensors, + rate_temperature=args.rate_temperature, + ).to(device).bfloat16() + for module in base_model.modules(): + if isinstance(module, (CastedLinear, StructuredLinear)): + module.float() + restore_low_dim_params_to_fp32(base_model) + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) if args.enable_compile else base_model + model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + + # Optimizer split: + # - token embedding 
(Adam) uses EMBED_LR + # - untied lm_head (Adam) uses HEAD_LR + # - matrix params in transformer blocks use MATRIX_LR via Muon + # - vectors/scalars use SCALAR_LR via Adam + block_named_params = list(base_model.blocks.named_parameters()) + matrix_params = [ + p + for name, p in block_named_params + if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + scalar_params = [ + p + for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + scaled_token_lr = scale_from_batch_multiplier(token_lr, batch_multiplier, args.lr_scale_method) + scaled_head_lr = scale_from_batch_multiplier(args.head_lr, batch_multiplier, args.lr_scale_method) + scaled_matrix_lr = scale_from_batch_multiplier(args.matrix_lr, batch_multiplier, args.lr_scale_method) + scaled_scalar_lr = scale_from_batch_multiplier(args.scalar_lr, batch_multiplier, args.lr_scale_method) + optimizer_tok = torch.optim.Adam( + [{"params": [base_model.tok_emb.weight], "lr": scaled_token_lr, "base_lr": scaled_token_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizer_muon = Muon( + matrix_params, + lr=scaled_matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + ) + for group in optimizer_muon.param_groups: + group["base_lr"] = scaled_matrix_lr + optimizer_scalar = torch.optim.Adam( + [{"params": scalar_params, "lr": scaled_scalar_lr, "base_lr": scaled_scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": scaled_head_lr, "base_lr": scaled_head_lr}], 
+ betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers.insert(1, optimizer_head) + + n_params = sum(p.numel() for p in base_model.parameters()) + log0(f"model_params:{n_params}") + log0( + f"world_size:{world_size} grad_accum_steps:{grad_accum_steps} " + f"train_microbatch_tokens:{train_microbatch_tokens} effective_train_batch_tokens:{effective_train_batch_tokens}" + ) + if simulated_world_size > 0: + log0(f"simulated_world_size:{simulated_world_size} simulated_rank:{rank}") + log0(f"enable_compile:{args.enable_compile}") + log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") + log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") + log0( + f"quant_scheme:{args.quant_scheme} compressor:{args.compressor} level:{args.compressor_level} " + f"structured_mlp:{args.structured_mlp} structured_linear_type:{args.structured_linear_type} " + f"btt_rank:{args.btt_rank} btt_init:{args.btt_init} compile_structured_mlp:{args.compile_structured_mlp} " + f"structured_chunk_ranks:{args.structured_chunk_ranks} structured_compile_mode:{args.structured_compile_mode}" + ) + log0( + f"rate_lambda:{args.rate_lambda} scale_lambda:{args.scale_lambda} " + f"rate_warmup_start_fraction:{args.rate_warmup_start_fraction} rate_full_fraction:{args.rate_full_fraction} " + f"rate_target_tensors:{args.rate_target_tensors} rate_schedule_mode:{args.rate_schedule_mode} " + f"max_expected_steps:{args.max_expected_steps}" + ) + log0( + f"tie_embeddings:{args.tie_embeddings} embed_lr:{scaled_token_lr} " + f"head_lr:{scaled_head_lr if base_model.lm_head is not None else 0.0} " + f"matrix_lr:{scaled_matrix_lr} scalar_lr:{scaled_scalar_lr}" + ) + log0( + f"train_batch_tokens_target:{args.train_batch_tokens} effective_train_batch_tokens:{effective_train_batch_tokens} " + f"train_seq_len:{args.train_seq_len} iterations:{args.iterations} warmup_steps:{effective_warmup_steps} " + 
f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log0( + f"lr_scale_method:{args.lr_scale_method} lr_reference_batch_tokens:{args.lr_reference_batch_tokens} " + f"batch_multiplier:{batch_multiplier:.4f} warmup_scale_method:{args.warmup_scale_method} " + f"warmup_reference_batch_tokens:{args.warmup_reference_batch_tokens}" + ) + log0(f"seed:{args.seed}") + + if args.debug_data_sharding_steps > 0: + debug_log_data_sharding( + args.train_files, + world_size, + effective_train_batch_tokens, + args.train_seq_len, + grad_accum_steps, + args.debug_data_sharding_steps, + log0, + ) + if args.dry_run_init_only: + log0("dry_run_init_only: exiting after init/logging") + return + + # ----------------------------- + # DATA LOADER & MODEL WARMUP + # ----------------------------- + + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + + def lr_mul(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + + def rate_progress(step: int, elapsed_ms: float) -> float: + if args.rate_schedule_mode == "steps": + if args.iterations <= 0: + return 1.0 + return step / max(args.iterations, 1) + if args.rate_schedule_mode == "expected_steps": + if args.max_expected_steps <= 0: + raise ValueError("MAX_EXPECTED_STEPS must be positive when RATE_SCHEDULE_MODE=expected_steps") + return step 
/ max(args.max_expected_steps, 1) + if args.rate_schedule_mode == "wallclock": + if max_wallclock_ms is None: + raise ValueError("RATE_SCHEDULE_MODE=wallclock requires MAX_WALLCLOCK_SECONDS > 0") + return elapsed_ms / max(max_wallclock_ms, 1e-9) + if args.rate_schedule_mode != "auto": + raise ValueError(f"Unsupported RATE_SCHEDULE_MODE={args.rate_schedule_mode}") + if args.max_expected_steps > 0: + return step / max(args.max_expected_steps, 1) + if max_wallclock_ms is not None and step > 0: + step_ms = elapsed_ms / max(step, 1) + expected_steps = max(int(max_wallclock_ms / max(step_ms, 1e-9)), 1) + return step / expected_steps + if args.iterations <= 0: + return 1.0 + return step / max(args.iterations, 1) + + def rate_schedule(step: int, elapsed_ms: float) -> float: + if args.rate_lambda <= 0.0 and args.scale_lambda <= 0.0: + return 0.0 + frac = rate_progress(step, elapsed_ms) + start = args.rate_warmup_start_fraction + full = max(args.rate_full_fraction, start + 1e-6) + if frac <= start: + return 0.0 + if frac >= full: + return 1.0 + return (frac - start) / (full - start) + + # Warmup primes the compiled forward/backward/optimizer paths, then we restore the + # initial weights/optimizer state so measured training starts from the true init. 
+ if effective_warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(effective_warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(effective_train_batch_tokens, args.train_seq_len, grad_accum_steps) + maybe_mark_compile_step_begin(args.enable_compile or args.compile_structured_mlp) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + for opt in optimizers: + opt.step() + zero_grad_all() + if ( + effective_warmup_steps <= 20 + or (warmup_step + 1) % 10 == 0 + or warmup_step + 1 == effective_warmup_steps + ): + log0(f"warmup_step:{warmup_step + 1}/{effective_warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + if distributed: + model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + # ----------------------------- + # MAIN TRAINING LOOP + # ----------------------------- + + training_time_ms = 0.0 + stop_after_step: int | None = None + torch.cuda.synchronize() + t0 = time.perf_counter() + + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + 
grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + log0( + f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" + ) + break + + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + zero_grad_all() + train_loss = torch.zeros((), device=device) + train_ce_loss = torch.zeros((), device=device) + train_rate_proxy = torch.zeros((), device=device) + train_scale_proxy = torch.zeros((), device=device) + train_est_bits = torch.zeros((), device=device) + rate_mul = rate_schedule(step, elapsed_ms) + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(effective_train_batch_tokens, args.train_seq_len, grad_accum_steps) + maybe_mark_compile_step_begin(args.enable_compile or args.compile_structured_mlp) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + ce_loss = model(x, y) + total_loss = ce_loss + if rate_mul > 0.0: + rate_proxy, scale_proxy, est_bits = base_model.rate_regularizer() + total_loss = total_loss + rate_mul * (args.rate_lambda * rate_proxy + args.scale_lambda * scale_proxy) + train_rate_proxy += rate_proxy.detach() + train_scale_proxy += scale_proxy.detach() + train_est_bits += est_bits.detach() + train_ce_loss += ce_loss.detach() + train_loss += total_loss.detach() + (total_loss * grad_scale).backward() + train_loss /= grad_accum_steps + train_ce_loss /= grad_accum_steps + train_rate_proxy /= grad_accum_steps + train_scale_proxy /= grad_accum_steps + 
train_est_bits /= grad_accum_steps + + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + for opt in optimizers: + opt.step() + zero_grad_all() + + step += 1 + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + should_log_train = ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + log0( + f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} ce_loss:{train_ce_loss.item():.4f} " + f"rate_mul:{rate_mul:.3f} rate_proxy_bits:{train_rate_proxy.item():.4f} " + f"scale_proxy:{train_scale_proxy.item():.6f} est_model_bits:{train_est_bits.item():.0f} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" + ) + + # Needed to sync whether we've reached the wallclock cap. 
+ reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + + log0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + ) + + # ----------------------------- + # SERIALIZATION + ROUNDTRIP VALIDATION + # ----------------------------- + # Save the raw state (useful for debugging/loading in PyTorch directly), then always produce + # the compressed int8+zlib artifact and validate the round-tripped weights. + + if master_process: + torch.save(base_model.state_dict(), "final_model.pt") + model_bytes = os.path.getsize("final_model.pt") + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model: {model_bytes} bytes") + log0(f"Code size: {code_bytes} bytes") + log0(f"Total submission size: {model_bytes + code_bytes} bytes") + + quant_obj, quant_stats = quantize_state_dict(base_model.state_dict(), args.quant_scheme) + quant_buf = io.BytesIO() + torch.save(quant_obj, quant_buf) + quant_raw = quant_buf.getvalue() + quant_blob = compress_blob(quant_raw, args.compressor, args.compressor_level) + quant_raw_bytes = len(quant_raw) + quant_filename = f"final_model.{args.quant_scheme}.{args.compressor}.ptc" + if master_process: + with open(quant_filename, "wb") as f: + f.write(quant_blob) + quant_file_bytes = os.path.getsize(quant_filename) + code_bytes = len(code.encode("utf-8")) + ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1) + log0( + f"Serialized model {args.quant_scheme}+{args.compressor}: {quant_file_bytes} bytes " + f"(payload:{quant_stats['int8_payload_bytes']} raw_torch:{quant_raw_bytes} 
payload_ratio:{ratio:.2f}x)" + ) + log0(f"Total submission size {args.quant_scheme}+{args.compressor}: {quant_file_bytes + code_bytes} bytes") + + if distributed: + dist.barrier() + with open(quant_filename, "rb") as f: + quant_blob_disk = f.read() + quant_state = torch.load(io.BytesIO(decompress_blob(quant_blob_disk, args.compressor)), map_location="cpu") + base_model.load_state_dict(dequantize_state_dict(quant_state), strict=True) + torch.cuda.synchronize() + t_qeval = time.perf_counter() + q_val_loss, q_val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"final_{args.quant_scheme}_{args.compressor}_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" + ) + log0(f"final_{args.quant_scheme}_{args.compressor}_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + + if distributed: + dist.destroy_process_group() + + +if __name__ == "__main__": + main() + +==================================================================================================== +Running Python 3.12.3 (main, Nov 6 2025, 13:44:16) [GCC 13.3.0] +Running PyTorch 2.9.1+cu128 +Tue Mar 31 17:07:30 2026 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 580.126.09 Driver Version: 580.126.09 CUDA Version: 13.0 | ++-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. 
| +|=========================================+========================+======================| +| 0 NVIDIA H100 80GB HBM3 On | 00000000:CB:00.0 Off | 0 | +| N/A 33C P0 93W / 700W | 1185MiB / 81559MiB | 6% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| No running processes found | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=/workspace/parameter-golf/data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:80 +val_loader:shards pattern=/workspace/parameter-golf/data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:65536 +val_token_limit:65536 +model_params:25727104 +world_size:1 grad_accum_steps:8 train_microbatch_tokens:2048 effective_train_batch_tokens:16384 +enable_compile:False +sdp_backends:cudnn=False flash=True mem_efficient=False math=False +attention_mode:gqa num_heads:8 num_kv_heads:4 +quant_scheme:int5_odd compressor:zlib level:9 structured_mlp:True structured_linear_type:btt_mlp btt_rank:256 btt_init:mup compile_structured_mlp:False structured_chunk_ranks:64 structured_compile_mode:reduce-overhead +rate_lambda:2e-05 scale_lambda:0.0002 rate_warmup_start_fraction:0.1 rate_full_fraction:0.4 rate_target_tensors:mlp_only rate_schedule_mode:auto max_expected_steps:0 +tie_embeddings:True embed_lr:0.03 head_lr:0.0 matrix_lr:0.02 scalar_lr:0.02 +train_batch_tokens_target:16384 effective_train_batch_tokens:16384 train_seq_len:1024 iterations:1 warmup_steps:1 
max_wallclock_seconds:0.000 +lr_scale_method:none lr_reference_batch_tokens:16384 batch_multiplier:1.0000 warmup_scale_method:none warmup_reference_batch_tokens:16384 +seed:1337 +warmup_step:1/1 +step:1/1 train_loss:6.9404 ce_loss:6.9404 rate_mul:0.000 rate_proxy_bits:0.0000 scale_proxy:0.000000 est_model_bits:0 train_time:18770ms step_avg:18770.28ms +step:1/1 val_loss:6.9380 val_bpb:4.1022 train_time:18771ms step_avg:18770.85ms +peak memory allocated: 49813 MiB reserved: 52668 MiB +Serialized model: 101929819 bytes +Code size: 78465 bytes +Total submission size: 102008284 bytes +Serialized model int5_odd+zlib: 5112978 bytes (payload:8780523 raw_torch:8851579 payload_ratio:11.60x) +Total submission size int5_odd+zlib: 5191443 bytes +final_int5_odd_zlib_roundtrip val_loss:6.9418 val_bpb:4.1044 eval_time:6506ms +final_int5_odd_zlib_roundtrip_exact val_loss:6.94184875 val_bpb:4.10444078 diff --git a/records/track_non_record_16mb/2026-03-31_EntropyAware_Int5Odd_BTTMLP_3060/logs/probe_rate_auto_bmm.txt b/records/track_non_record_16mb/2026-03-31_EntropyAware_Int5Odd_BTTMLP_3060/logs/probe_rate_auto_bmm.txt new file mode 100644 index 000000000..bd372a9b8 --- /dev/null +++ b/records/track_non_record_16mb/2026-03-31_EntropyAware_Int5Odd_BTTMLP_3060/logs/probe_rate_auto_bmm.txt @@ -0,0 +1,1836 @@ +logs/probe_rate_auto_bmm.txt +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=/workspace/parameter-golf/data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:80 +val_loader:shards pattern=/workspace/parameter-golf/data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:65536 +val_token_limit:65536 +model_params:25727104 +world_size:1 grad_accum_steps:8 train_microbatch_tokens:2048 effective_train_batch_tokens:16384 +enable_compile:False +sdp_backends:cudnn=False flash=True mem_efficient=False math=False +attention_mode:gqa num_heads:8 num_kv_heads:4 +quant_scheme:int5_odd compressor:zlib level:9 structured_mlp:True 
structured_linear_type:btt_mlp btt_rank:256 btt_init:mup compile_structured_mlp:False structured_chunk_ranks:64 structured_compile_mode:reduce-overhead +rate_lambda:2e-05 scale_lambda:0.0002 rate_warmup_start_fraction:0.1 rate_full_fraction:0.4 rate_target_tensors:mlp_only rate_schedule_mode:auto max_expected_steps:0 +tie_embeddings:True embed_lr:0.03 head_lr:0.0 matrix_lr:0.02 scalar_lr:0.02 +train_batch_tokens_target:16384 effective_train_batch_tokens:16384 train_seq_len:1024 iterations:20000 warmup_steps:1 max_wallclock_seconds:130.000 +lr_scale_method:none lr_reference_batch_tokens:16384 batch_multiplier:1.0000 warmup_scale_method:none warmup_reference_batch_tokens:16384 +seed:1337 +warmup_step:1/1 +step:1/20000 train_loss:6.9404 ce_loss:6.9404 rate_mul:0.000 rate_proxy_bits:0.0000 scale_proxy:0.000000 est_model_bits:0 train_time:18753ms step_avg:18752.97ms +step:2/20000 train_loss:12.7073 ce_loss:12.7072 rate_mul:0.222 rate_proxy_bits:1.4982 scale_proxy:0.059731 est_model_bits:18851774 train_time:38430ms step_avg:19214.81ms +step:3/20000 train_loss:12.6905 ce_loss:12.6905 rate_mul:0.778 rate_proxy_bits:1.4982 scale_proxy:0.059731 est_model_bits:18851774 train_time:58096ms step_avg:19365.29ms +step:4/20000 train_loss:12.6512 ce_loss:12.6512 rate_mul:1.000 rate_proxy_bits:1.4982 scale_proxy:0.059731 est_model_bits:18851774 train_time:77808ms step_avg:19451.99ms +step:5/20000 train_loss:12.6780 ce_loss:12.6780 rate_mul:1.000 rate_proxy_bits:1.4982 scale_proxy:0.059731 est_model_bits:18851774 train_time:97504ms step_avg:19500.86ms +step:6/20000 train_loss:12.6050 ce_loss:12.6049 rate_mul:1.000 rate_proxy_bits:1.4982 scale_proxy:0.059731 est_model_bits:18851774 train_time:117102ms step_avg:19517.05ms +step:7/20000 train_loss:12.6482 ce_loss:12.6481 rate_mul:1.000 rate_proxy_bits:1.4982 scale_proxy:0.059731 est_model_bits:18851774 train_time:136790ms step_avg:19541.49ms +step:7/20000 val_loss:12.6308 val_bpb:7.4681 train_time:136791ms step_avg:19541.59ms 
+stopping_early: wallclock_cap train_time:136791ms step:7/20000 +peak memory allocated: 50388 MiB reserved: 53228 MiB +Serialized model: 101929819 bytes +Code size: 78465 bytes +Total submission size: 102008284 bytes +Serialized model int5_odd+zlib: 5180456 bytes (payload:8780523 raw_torch:8851579 payload_ratio:11.60x) +Total submission size int5_odd+zlib: 5258921 bytes +final_int5_odd_zlib_roundtrip val_loss:12.1881 val_bpb:7.2064 eval_time:6512ms +final_int5_odd_zlib_roundtrip_exact val_loss:12.18812561 val_bpb:7.20635691 +viron.get("NUM_KV_HEADS", 4)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = int(os.environ.get("MLP_MULT", 2)) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + + # Optimizer hyperparameters. + embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.03)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.02)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.02)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.0)) + lr_scale_method = os.environ.get("LR_SCALE_METHOD", "none") + lr_reference_batch_tokens = int(os.environ.get("LR_REFERENCE_BATCH_TOKENS", 16_384)) + warmup_scale_method = 
os.environ.get("WARMUP_SCALE_METHOD", "none") + warmup_reference_batch_tokens = int(os.environ.get("WARMUP_REFERENCE_BATCH_TOKENS", 16_384)) + simulated_world_size = int(os.environ.get("SIMULATED_WORLD_SIZE", 0)) + simulated_rank = int(os.environ.get("SIMULATED_RANK", 0)) + dry_run_init_only = bool(int(os.environ.get("DRY_RUN_INIT_ONLY", "0"))) + debug_data_sharding_steps = int(os.environ.get("DEBUG_DATA_SHARDING_STEPS", 0)) + + # Export / compression. + quant_scheme = os.environ.get("QUANT_SCHEME", "int5_odd") + compressor = os.environ.get("COMPRESSOR", "zlib") + compressor_level = int(os.environ.get("COMPRESSOR_LEVEL", 9)) + + # Entropy-aware rate regularization. + rate_lambda = float(os.environ.get("RATE_LAMBDA", 0.00002)) + scale_lambda = float(os.environ.get("SCALE_LAMBDA", 0.0002)) + rate_temperature = float(os.environ.get("RATE_TEMPERATURE", 6.0)) + rate_warmup_start_fraction = float(os.environ.get("RATE_WARMUP_START_FRACTION", 0.1)) + rate_full_fraction = float(os.environ.get("RATE_FULL_FRACTION", 0.4)) + rate_target_tensors = os.environ.get("RATE_TARGET_TENSORS", "mlp_only") + rate_schedule_mode = os.environ.get("RATE_SCHEDULE_MODE", "auto") + max_expected_steps = int(os.environ.get("MAX_EXPECTED_STEPS", 0)) + + # Structured MLP. 
+ structured_mlp = bool(int(os.environ.get("STRUCTURED_MLP", "1"))) + structured_linear_type = os.environ.get("STRUCTURED_LINEAR_TYPE", "btt_mlp") + btt_rank = int(os.environ.get("BTT_RANK", 32)) + btt_in_shape = os.environ.get("BTT_IN_SHAPE", "") + btt_out_shape = os.environ.get("BTT_OUT_SHAPE", "") + btt_init = os.environ.get("BTT_INIT", "mup") + btt_init_gain = float(os.environ.get("BTT_INIT_GAIN", 1.0)) + compile_structured_mlp = bool(int(os.environ.get("COMPILE_STRUCTURED_MLP", "1"))) + structured_chunk_ranks = int(os.environ.get("STRUCTURED_CHUNK_RANKS", 64)) + structured_compile_mode = os.environ.get("STRUCTURED_COMPILE_MODE", "reduce-overhead") + +# ----------------------------- +# MUON OPTIMIZER +# ----------------------------- +# +# As borrowed from modded-nanogpt +# Background on Muon: https://kellerjordan.github.io/posts/muon/ + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + # Orthogonalize a 2D update matrix with a fast Newton-Schulz iteration. + # Muon uses this to normalize matrix-shaped gradients before applying them. 
+ a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X + + +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True): + super().__init__( + params, + dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov), + ) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + + for group in self.param_groups: + params = group["params"] + if not params: + continue + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + + total_params = sum(int(p.numel()) for p in params) + updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) + + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: + g = g.add(buf, alpha=momentum) + g = zeropower_via_newtonschulz5(g, steps=backend_steps) + # Scale correction from Muon reference implementations. 
+ g *= max(1, g.size(0) / g.size(1)) ** 0.5 + updates_flat[curr : curr + p.numel()] = g.reshape(-1) + curr += p.numel() + + if distributed: + dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) + + curr = 0 + for p in params: + g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) + p.add_(g, alpha=-lr) + curr += p.numel() + + return loss + + +# ----------------------------- +# TOKENIZER-AGNOSTIC EVALUATION SETUP +# ----------------------------- +# +# It's common for small models to have a large fraction of their parameters in embeddings, since the 2 * d_model * d_vocab embedding vectors can be gigantic. +# Instead of locking the tokenizer, we let you bring your own and calculate our validation metrics from the average compression of the validation set. +# We calculate BPB (bits-per-byte) instead of validation loss, so we need methods to count the number of bytes each token decodes to under the tokenizer. +# Note: Submissions that edit the tokenizer will be examined more carefully, since screwing this up might unjustly improve your score. 
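The conversion the evaluator performs can be sketched in isolation. A minimal illustration (not part of the training script; `bits_per_byte` is a hypothetical helper): cross-entropy in nats becomes bits per token, which is then rescaled by the tokenizer's tokens-per-byte ratio.

```python
import math

def bits_per_byte(mean_loss_nats: float, num_tokens: int, num_bytes: int) -> float:
    # Tokenizer-agnostic compression metric: bits/token rescaled to bits/byte.
    bits_per_token = mean_loss_nats / math.log(2.0)
    return bits_per_token * (num_tokens / num_bytes)

# A tokenizer averaging 4 bytes/token with a loss of ln(2) nats/token
# compresses the text to 0.25 bits per byte.
print(bits_per_byte(math.log(2.0), 1_000, 4_000))
```

Because the byte count comes from the tokenizer's own decode lengths, two submissions with different tokenizers remain directly comparable on this metric.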
+ +def build_sentencepiece_luts( + sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device +) -> tuple[Tensor, Tensor, Tensor]: + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("▁"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) + + +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*. 
+ tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] + + +def eval_val( + args: Hyperparameters, + model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + grad_accum_steps: int, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, +) -> tuple[float, float]: + # Validation computes two metrics: + # - val_loss: token cross-entropy (natural log) + # - val_bpb: tokenizer-agnostic compression metric used by the challenge + local_batch_tokens = args.val_batch_size // world_size + if local_batch_tokens < args.train_seq_len: + raise ValueError( + "VAL_BATCH_SIZE must provide at least one sequence per rank; " + f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " + f"TRAIN_SEQ_LEN={args.train_seq_len}" + ) + local_batch_seqs = local_batch_tokens // args.train_seq_len + total_seqs = (val_tokens.numel() - 1) // args.train_seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + + model.eval() + with torch.inference_mode(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * args.train_seq_len + raw_end = batch_seq_end * args.train_seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, args.train_seq_len) + y = local[1:].reshape(-1, args.train_seq_len) + maybe_mark_compile_step_begin(args.enable_compile or 
args.compile_structured_mlp) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + batch_loss = model(x, y).detach() + batch_token_count = float(y.numel()) + val_loss_sum += batch_loss.to(torch.float64) * batch_token_count + val_token_count += batch_token_count + prev_ids = x.reshape(-1) + tgt_ids = y.reshape(-1) + token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + + val_loss = val_loss_sum / val_token_count + bits_per_token = val_loss.item() / math.log(2.0) + tokens_per_byte = val_token_count.item() / val_byte_count.item() + model.train() + return float(val_loss.item()), float(bits_per_token * tokens_per_byte) + +# ----------------------------- +# POST-TRAINING QUANTIZATION +# ----------------------------- +# +# It's silly to export our model, which is trained in bf16 and fp32, at that same precision. +# Instead, we get approximately the same model (with a small accuracy hit) by quantizing it to int8 and zlib-compressing the result. +# We can then decompress the model and run it at higher precision for evaluation, once the compressed artifact comes in under the size limit. 
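The per-row scheme described above can be sketched standalone (a simplified illustration under stated assumptions: NumPy instead of torch, and clipping at the row max rather than the `INT8_CLIP_PERCENTILE` quantile the exporter uses):

```python
import numpy as np

def quantize_per_row_int8(w: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
    # One scale per row: scale = max|row| / 127, codes clamped to [-127, 127].
    scale = np.abs(w).max(axis=1, keepdims=True) / 127.0
    scale = np.maximum(scale, 1.0 / 127.0)  # avoid zero scales, as in the exporter
    q = np.clip(np.round(w / scale), -127, 127).astype(np.int8)
    return q, scale.astype(np.float32)

def dequantize_per_row_int8(q: np.ndarray, scale: np.ndarray) -> np.ndarray:
    # Broadcast the saved row scale back across the row.
    return q.astype(np.float32) * scale

rng = np.random.default_rng(0)
w = rng.standard_normal((4, 8)).astype(np.float32)
q, s = quantize_per_row_int8(w)
w_hat = dequantize_per_row_int8(q, s)
# Rounding error is bounded by half a quantization step per element.
assert np.all(np.abs(w - w_hat) <= 0.5 * s + 1e-6)
```

Per-row scales matter because output-channel magnitudes vary widely in trained matrices; a single tensor-wide scale would waste most of the int8 range on small rows.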
+ +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights", + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 +INT5_ODD_LEVELS = (-2, -1, 0, 1, 2) + +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) + + +def maybe_mark_compile_step_begin(enabled: bool) -> None: + if not enabled: + return + mark_step = getattr(torch.compiler, "cudagraph_mark_step_begin", None) + if mark_step is not None: + mark_step() + + +def scale_from_batch_multiplier(value: float, multiplier: float, method: str) -> float: + if method == "none": + return value + if method == "sqrt": + return value * math.sqrt(multiplier) + if method == "linear": + return value * multiplier + raise ValueError(f"Unsupported scaling method: {method}") + + +def scaled_warmup_steps(base_steps: int, multiplier: float, method: str) -> int: + if base_steps <= 0: + return 0 + scaled = scale_from_batch_multiplier(float(base_steps), multiplier, method) + return max(int(math.ceil(scaled)), 1) + + +def parse_factor_shape(spec: str, dim: int) -> tuple[int, int]: + if spec: + vals = tuple(int(v.strip()) for v in spec.split(",") if v.strip()) + if len(vals) != 2 or vals[0] * vals[1] != dim: + raise ValueError(f"Invalid factor shape '{spec}' for dim={dim}") + return vals + for a in range(int(math.isqrt(dim)), 0, -1): + if dim % a == 0: + return a, dim // a + return 1, dim + +def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> 
Tensor: + if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): + return t.float().contiguous() + if t.dtype in {torch.float32, torch.bfloat16}: + passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") + return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() + return t + +def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + # Matrices get one scale per row, which usually tracks output-channel + # ranges much better than a single tensor-wide scale. + clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) + q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() + return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + + # Vectors / scalars use a simpler per-tensor scale. 
+ clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() + return q, scale + + +def pack_base5_codes(q: Tensor) -> tuple[bytes, int]: + vals = q.detach().to("cpu", dtype=torch.uint8).reshape(-1).numpy() + n_vals = int(vals.size) + pad = (3 - (n_vals % 3)) % 3 + if pad: + vals = np.pad(vals, (0, pad)) + groups = vals.reshape(-1, 3).astype(np.uint8) + packed = groups[:, 0] + 5 * groups[:, 1] + 25 * groups[:, 2] + return packed.tobytes(), n_vals + + +def unpack_base5_codes(data: bytes, n_vals: int) -> Tensor: + packed = np.frombuffer(data, dtype=np.uint8).astype(np.int16) + out = np.zeros((packed.size, 3), dtype=np.uint8) + for i in range(3): + out[:, i] = packed % 5 + packed //= 5 + flat = out.reshape(-1)[:n_vals] + return torch.from_numpy(flat.astype(np.uint8, copy=False)) + + +def quantize_float_tensor_int5_odd(t: Tensor) -> tuple[Tensor, Tensor]: + if t.ndim != 2: + raise ValueError("int5_odd quantization only supports 2D tensors") + t32 = t.float() + clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clip_abs = clip_abs.clamp_min(1e-6) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 2.0).clamp_min(1.0 / 2.0e6) + q = torch.clamp(torch.round(clipped / scale[:, None]), -2, 2).to(torch.int8).contiguous() + return (q + 2).to(torch.uint8).contiguous(), scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + +def quantize_state_dict_int8(state_dict: dict[str, Tensor]): + # Single supported clean-script export format: + # - per-row int8 for 2D float tensors + # - per-tensor int8 for other float tensors + # - exact passthrough for non-floats + # - passthrough for small float tensors, 
stored as fp16 to save bytes + quantized: dict[str, Tensor] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + qmeta: dict[str, dict[str, object]] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + + # Small float tensors are cheap enough to keep directly. We still downcast + # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size. + if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + + stats["num_float_tensors"] += 1 + q, s = quantize_float_tensor(t) + if s.ndim > 0: + qmeta[name] = {"scheme": "per_row", "axis": 0} + quantized[name] = q + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) + + obj: dict[str, object] = { + "__quant_format__": "int8_clean_per_row_v1", + "quantized": quantized, + "scales": scales, + "dtypes": dtypes, + "passthrough": passthrough, + } + if qmeta: + obj["qmeta"] = qmeta + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats + + +def quantize_state_dict_int5_odd(state_dict: dict[str, Tensor]): + packed: dict[str, bytes] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + shapes: dict[str, tuple[int, ...]] = {} + orig_shapes: dict[str, 
tuple[int, ...]] = {} + n_codes: dict[str, int] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + + quantize_this = t.ndim >= 2 and (t.numel() > INT8_KEEP_FLOAT_MAX_NUMEL or ".mlp." in name) + if not quantize_this: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + + stats["num_float_tensors"] += 1 + t2 = t.reshape(t.shape[0], -1) if t.ndim > 2 else t + q, s = quantize_float_tensor_int5_odd(t2) + packed_bytes, count = pack_base5_codes(q) + packed[name] = packed_bytes + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + shapes[name] = tuple(int(v) for v in t2.shape) + orig_shapes[name] = tuple(int(v) for v in t.shape) + n_codes[name] = count + stats["int8_payload_bytes"] += len(packed_bytes) + tensor_nbytes(s) + + obj: dict[str, object] = { + "__quant_format__": "int5_odd_clean_per_row_v1", + "packed": packed, + "scales": scales, + "dtypes": dtypes, + "shapes": shapes, + "orig_shapes": orig_shapes, + "n_codes": n_codes, + "passthrough": passthrough, + } + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats + +def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + qmeta = obj.get("qmeta", {}) + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, q in 
obj["quantized"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + s = obj["scales"][name] + if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: + s = s.to(dtype=torch.float32) + # Broadcast the saved row scale back across trailing dimensions. + out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() + else: + scale = float(s.item()) + out[name] = (q.float() * scale).to(dtype=dtype).contiguous() + for name, t in obj["passthrough"].items(): + # Restore small tensors, undoing the temporary fp16 storage cast if needed. + out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out + + +def dequantize_state_dict_int5_odd(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, packed in obj["packed"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + shape = tuple(int(v) for v in obj["shapes"][name]) + orig_shape = tuple(int(v) for v in obj.get("orig_shapes", {}).get(name, shape)) + q = unpack_base5_codes(packed, int(obj["n_codes"][name])).to(torch.float32).reshape(shape) - 2.0 + s = obj["scales"][name].to(dtype=torch.float32) + deq = (q * s.view(shape[0], *([1] * (len(shape) - 1)))).to(dtype=dtype) + out[name] = deq.reshape(orig_shape).contiguous() + for name, t in obj["passthrough"].items(): + out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out + + +def quantize_state_dict(state_dict: dict[str, Tensor], quant_scheme: str): + if quant_scheme == "int8": + return quantize_state_dict_int8(state_dict) + if quant_scheme == "int5_odd": + return 
quantize_state_dict_int5_odd(state_dict) + raise ValueError(f"Unsupported QUANT_SCHEME={quant_scheme}") + + +def dequantize_state_dict(obj: dict[str, object]) -> dict[str, Tensor]: + quant_format = obj.get("__quant_format__") + if quant_format == "int8_clean_per_row_v1": + return dequantize_state_dict_int8(obj) + if quant_format == "int5_odd_clean_per_row_v1": + return dequantize_state_dict_int5_odd(obj) + raise ValueError(f"Unsupported quant format {quant_format}") + + +def compress_blob(data: bytes, compressor: str, level: int) -> bytes: + if compressor == "zlib": + return zlib.compress(data, level=max(0, min(level, 9))) + if compressor == "zstd": + if not _HAS_ZSTD: + raise RuntimeError("COMPRESSOR=zstd requires zstandard to be installed") + return zstd.ZstdCompressor(level=level).compress(data) + raise ValueError(f"Unsupported COMPRESSOR={compressor}") + + +def decompress_blob(data: bytes, compressor: str) -> bytes: + if compressor == "zlib": + return zlib.decompress(data) + if compressor == "zstd": + if not _HAS_ZSTD: + raise RuntimeError("COMPRESSOR=zstd requires zstandard to be installed") + return zstd.ZstdDecompressor().decompress(data) + raise ValueError(f"Unsupported COMPRESSOR={compressor}") + + +def should_rate_regularize_tensor(name: str, p: Tensor, target: str) -> bool: + if not p.requires_grad or not p.is_floating_point() or p.ndim < 2: + return False + if any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS): + return False + if target == "mlp_only": + return ".mlp." in name + if target == "all_matrices": + return p.numel() > INT8_KEEP_FLOAT_MAX_NUMEL or ".mlp." 
in name + raise ValueError(f"Unsupported RATE_TARGET_TENSORS={target}") + + +def estimate_tensor_rate_proxy(t: Tensor, temperature: float) -> tuple[Tensor, Tensor, Tensor]: + t2 = t.float() + if t2.ndim > 2: + t2 = t2.reshape(t2.shape[0], -1) + elif t2.ndim == 1: + t2 = t2.reshape(1, -1) + row_scale = t2.abs().mean(dim=1, keepdim=True).clamp_min(1e-6) + normalized = torch.clamp(t2 / row_scale, -2.5, 2.5) + centers = t2.new_tensor(INT5_ODD_LEVELS) + logits = -temperature * (normalized.unsqueeze(-1) - centers) ** 2 + probs = torch.softmax(logits, dim=-1) + hist = probs.mean(dim=(0, 1)) + entropy_nats = -(hist * (hist + 1e-8).log()).sum() + entropy_bits = entropy_nats / math.log(2.0) + scale_proxy = row_scale.mean() + est_bits = entropy_bits * float(t2.numel()) + return entropy_bits, scale_proxy, est_bits + + +# ----------------------------- +# DATA LOADING +# ----------------------------- + +def load_data_shard(file: Path) -> Tensor: + # Shards use the standard fineweb/llm.c layout: a 256-entry int32 header followed by uint16 token ids. + header_bytes = 256 * np.dtype("<i4").itemsize + data = np.memmap(file, dtype=np.uint16, mode="r", offset=header_bytes) + return torch.from_numpy(data.astype(np.int32)) + + +class TokenStream: + # Sequential reader over every shard matching a glob pattern, wrapping around at the end. + def __init__(self, pattern: str): + self.files = [Path(p) for p in sorted(glob.glob(pattern))] + if not self.files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + self.file_idx = 0 + self.tokens = load_data_shard(self.files[0]) + self.pos = 0 + + def _advance_file(self) -> None: + self.file_idx = (self.file_idx + 1) % len(self.files) + self.tokens = load_data_shard(self.files[self.file_idx]) + self.pos = 0 + + def take(self, n: int) -> Tensor: + chunks: list[Tensor] = [] + remaining = n + while remaining > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: + self._advance_file() + continue + k = min(remaining, avail) + chunks.append(self.tokens[self.pos : self.pos + k]) + self.pos += k + remaining -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) + + +class DistributedTokenLoader: + # Each call consumes a contiguous chunk from the shared token stream, then slices out + # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting. 
+ def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank = rank + self.world_size = world_size + self.device = device + self.stream = TokenStream(pattern) + + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + chunk = self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start : start + per_rank_span].to(dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + + +def debug_log_data_sharding( + pattern: str, + world_size: int, + global_tokens: int, + seq_len: int, + grad_accum_steps: int, + steps: int, + log0, +) -> None: + stream = TokenStream(pattern) + local_tokens = global_tokens // (world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + global_cursor = 0 + log0( + f"debug_sharding:world_size:{world_size} global_tokens:{global_tokens} " + f"grad_accum_steps:{grad_accum_steps} local_tokens_per_rank:{local_tokens} per_rank_span:{per_rank_span}" + ) + for step_idx in range(steps): + chunk = stream.take(per_rank_span * world_size) + chunk_start = global_cursor + chunk_end = global_cursor + chunk.numel() + log0(f"debug_sharding_step:{step_idx} shared_chunk_token_range:[{chunk_start},{chunk_end})") + for shard_rank in range(world_size): + start = shard_rank * per_rank_span + end = start + per_rank_span + local = chunk[start:end] + preview = ",".join(str(int(t)) for t in local[: min(6, local.numel())].tolist()) + log0( + f"debug_shard rank:{shard_rank} token_range:[{chunk_start + start},{chunk_start + end}) " + f"preview:{preview}" + ) + global_cursor = chunk_end + +# ----------------------------- +# TRANSFORMER MODULES +# ----------------------------- + +class RMSNorm(nn.Module): + def 
__init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + + +class CastedLinear(nn.Linear): + # Keep weights in fp32 for optimizer/state quality, cast at matmul time for bf16 compute. + def forward(self, x: Tensor) -> Tensor: + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, self.weight.to(x.dtype), bias) + + +class StructuredLinear(nn.Module): + # Lightweight TT-style structured matrix used only in the MLP stack. + def __init__( + self, + in_features: int, + out_features: int, + bias: bool, + linear_type: str, + btt_rank: int, + in_shape_spec: str, + out_shape_spec: str, + btt_init: str, + btt_init_gain: float, + chunk_ranks: int, + ): + super().__init__() + self.in_features = in_features + self.out_features = out_features + self.linear_type = linear_type + self.rank = max(int(btt_rank), 1) + self.chunk_ranks = max(int(chunk_ranks), 1) + self.btt_init = btt_init + self.btt_init_gain = btt_init_gain + if linear_type == "dense": + self.impl = CastedLinear(in_features, out_features, bias=bias) + self.register_parameter("core1", None) + self.register_parameter("core2", None) + return + if linear_type != "btt_mlp": + raise ValueError(f"Unsupported STRUCTURED_LINEAR_TYPE={linear_type}") + self.impl = None + self.in_shape = parse_factor_shape(in_shape_spec, in_features) + self.out_shape = parse_factor_shape(out_shape_spec, out_features) + n1, n2 = self.in_shape + m1, m2 = self.out_shape + self.core1 = nn.Parameter(torch.empty(n1, m1, self.rank)) + self.core2 = nn.Parameter(torch.empty(self.rank, n2, m2)) + if bias: + self.bias = nn.Parameter(torch.zeros(out_features)) + else: + self.register_parameter("bias", None) + self.reset_parameters() + + def reset_parameters(self) -> None: + if self.impl is not None: + self.impl.reset_parameters() + return + n1, n2 = self.in_shape + if self.btt_init == "mup": + # muP-inspired 
scaling: keep the per-rank branch update stable while + # preserving O(1) output variance under the sum over TT ranks. + core1_std = self.btt_init_gain / math.sqrt(max(n1, 1)) + core2_std = self.btt_init_gain / math.sqrt(max(self.rank * n2, 1)) + elif self.btt_init == "xavier": + std = self.btt_init_gain / math.sqrt(max(self.in_features, 1)) + core1_std = std / math.sqrt(self.rank) + core2_std = std + else: + raise ValueError(f"Unsupported BTT_INIT={self.btt_init}") + nn.init.normal_(self.core1, mean=0.0, std=core1_std) + nn.init.normal_(self.core2, mean=0.0, std=core2_std) + if self.bias is not None: + nn.init.zeros_(self.bias) + + def forward(self, x: Tensor) -> Tensor: + if self.impl is not None: + return self.impl(x) + x_dtype = x.dtype + orig_shape = x.shape[:-1] + x2 = x.reshape(-1, self.in_shape[0], self.in_shape[1]) + x2_t = x2.transpose(1, 2).contiguous() + core1 = self.core1.to(dtype=x_dtype).permute(2, 0, 1).contiguous() + core2 = self.core2.to(dtype=x_dtype).contiguous() + y = x2.new_zeros((x2.shape[0], self.out_shape[0], self.out_shape[1])) + for start in range(0, self.rank, self.chunk_ranks): + end = min(start + self.chunk_ranks, self.rank) + for offset, core2_r in enumerate(core2[start:end], start=start): + left = torch.matmul(x2_t, core1[offset]) + y = y + torch.matmul(left.transpose(1, 2), core2_r) + y = y.reshape(*orig_shape, self.out_features) + if self.bias is not None: + y = y + self.bias.to(dtype=y.dtype) + return y + + +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + # Keep small/control parameters in fp32 even when the model body runs in bf16. + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() + + +class Rotary(nn.Module): + # Caches cos/sin tables per sequence length on the current device. 
+ def __init__(self, dim: int, base: float = 10000.0): + super().__init__() + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached != seq_len + or self._cos_cached.device != device + ): + t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) + freqs = torch.outer(t, self.inv_freq.to(device)) + self._cos_cached = freqs.cos()[None, None, :, :] + self._sin_cached = freqs.sin()[None, None, :, :] + self._seq_len_cached = seq_len + cos = self._cos_cached.to(dtype=dtype) + sin = self._sin_cached.to(dtype=dtype) + if not torch.is_inference_mode_enabled(): + cos = cos.clone() + sin = sin.clone() + return cos, sin + + +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + + +class CausalSelfAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + kv_dim = self.num_kv_heads * self.head_dim + self.c_q = CastedLinear(dim, dim, bias=False) + self.c_k = CastedLinear(dim, kv_dim, bias=False) + self.c_v = CastedLinear(dim, kv_dim, bias=False) + self.proj 
= CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rotary = Rotary(self.head_dim, base=rope_base) + + def forward(self, x: Tensor) -> Tensor: + bsz, seqlen, dim = x.shape + q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) + k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin) + k = apply_rotary_emb(k, cos, sin) + q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] + y = F.scaled_dot_product_attention( + q, + k, + v, + attn_mask=None, + is_causal=True, + enable_gqa=(self.num_kv_heads != self.num_heads), + ) + y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) + return self.proj(y) + + +class MLP(nn.Module): + # relu^2 MLP from the original modded-nanogpt setup + def __init__( + self, + dim: int, + mlp_mult: int, + structured_mlp: bool, + structured_linear_type: str, + btt_rank: int, + btt_in_shape: str, + btt_out_shape: str, + btt_init: str, + btt_init_gain: float, + compile_structured_mlp: bool, + structured_chunk_ranks: int, + structured_compile_mode: str, + ): + super().__init__() + hidden = mlp_mult * dim + if structured_mlp: + self.fc = StructuredLinear( + dim, + hidden, + bias=False, + linear_type=structured_linear_type, + btt_rank=btt_rank, + in_shape_spec=btt_in_shape, + out_shape_spec=btt_out_shape, + btt_init=btt_init, + btt_init_gain=btt_init_gain, + chunk_ranks=structured_chunk_ranks, + ) + self.proj = StructuredLinear( + hidden, + dim, + bias=False, + linear_type=structured_linear_type, + btt_rank=btt_rank, + in_shape_spec=btt_out_shape, + out_shape_spec=btt_in_shape, + btt_init=btt_init, + 
btt_init_gain=btt_init_gain, + chunk_ranks=structured_chunk_ranks, + ) + if compile_structured_mlp: + compile_kwargs = dict(fullgraph=False, dynamic=False) + if structured_compile_mode == "reduce-overhead": + compile_options = dict(torch._inductor.list_mode_options("reduce-overhead")) + compile_options["triton.cudagraphs"] = False + compile_kwargs["options"] = compile_options + else: + compile_kwargs["mode"] = structured_compile_mode + self.fc = torch.compile( + self.fc, + **compile_kwargs, + ) + self.proj = torch.compile( + self.proj, + **compile_kwargs, + ) + else: + self.fc = CastedLinear(dim, hidden, bias=False) + self.proj = CastedLinear(hidden, dim, bias=False) + self.proj._zero_init = True + + def forward(self, x: Tensor) -> Tensor: + x = torch.relu(self.fc(x)) + return self.proj(x.square()) + + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + rope_base: float, + qk_gain_init: float, + structured_mlp: bool, + structured_linear_type: str, + btt_rank: int, + btt_in_shape: str, + btt_out_shape: str, + btt_init: str, + btt_init_gain: float, + compile_structured_mlp: bool, + structured_chunk_ranks: int, + structured_compile_mode: str, + ): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) + self.mlp = MLP( + dim, + mlp_mult, + structured_mlp, + structured_linear_type, + btt_rank, + btt_in_shape, + btt_out_shape, + btt_init, + btt_init_gain, + compile_structured_mlp, + structured_chunk_ranks, + structured_compile_mode, + ) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + + def forward(self, x: Tensor, x0: Tensor) -> Tensor: + mix = self.resid_mix.to(dtype=x.dtype) + x = mix[0][None, None, :] * x + 
mix[1][None, None, :] * x0 + attn_out = self.attn(self.attn_norm(x)) + x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out + x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) + return x + + +class GPT(nn.Module): + def __init__( + self, + vocab_size: int, + num_layers: int, + model_dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + tie_embeddings: bool, + tied_embed_init_std: float, + logit_softcap: float, + rope_base: float, + qk_gain_init: float, + structured_mlp: bool, + structured_linear_type: str, + btt_rank: int, + btt_in_shape: str, + btt_out_shape: str, + btt_init: str, + btt_init_gain: float, + compile_structured_mlp: bool, + structured_chunk_ranks: int, + structured_compile_mode: str, + rate_target_tensors: str, + rate_temperature: float, + ): + super().__init__() + if logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") + self.tie_embeddings = tie_embeddings + self.tied_embed_init_std = tied_embed_init_std + self.logit_softcap = logit_softcap + self.rate_target_tensors = rate_target_tensors + self.rate_temperature = rate_temperature + self.tok_emb = nn.Embedding(vocab_size, model_dim) + self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + self.blocks = nn.ModuleList( + [ + Block( + model_dim, + num_heads, + num_kv_heads, + mlp_mult, + rope_base, + qk_gain_init, + structured_mlp, + structured_linear_type, + btt_rank, + btt_in_shape, + btt_out_shape, + btt_init, + btt_init_gain, + compile_structured_mlp, + structured_chunk_ranks, + structured_compile_mode, + ) + for i in range(num_layers) + ] + ) + self.final_norm = RMSNorm() + self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, 
bias=False) + if self.lm_head is not None: + self.lm_head._zero_init = True + self._init_weights() + + def _init_weights(self) -> None: + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + for module in self.modules(): + if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + if isinstance(module, StructuredLinear) and getattr(module, "_zero_init", False): + if module.impl is not None: + nn.init.zeros_(module.impl.weight) + else: + nn.init.zeros_(module.core2) + + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + x = self.tok_emb(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x0 = x + skips: list[Tensor] = [] + + # First half stores skips; second half reuses them in reverse order. + for i in range(self.num_encoder_layers): + x = self.blocks[i](x, x0) + skips.append(x) + for i in range(self.num_decoder_layers): + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + x = self.blocks[self.num_encoder_layers + i](x, x0) + + x = self.final_norm(x).reshape(-1, x.size(-1)) + targets = target_ids.reshape(-1) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(x) + logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + return F.cross_entropy(logits.float(), targets, reduction="mean") + + def rate_regularizer(self) -> tuple[Tensor, Tensor, Tensor]: + total_symbols = 0.0 + weighted_rate = None + weighted_scale = None + estimated_bits = None + for name, param in self.named_parameters(): + if not should_rate_regularize_tensor(name, param, self.rate_target_tensors): + continue + rate_proxy, scale_proxy, tensor_bits = estimate_tensor_rate_proxy(param, self.rate_temperature) + weight = float(param.numel()) + total_symbols += weight + if 
weighted_rate is None: + weighted_rate = rate_proxy * weight + weighted_scale = scale_proxy * weight + estimated_bits = tensor_bits + else: + weighted_rate = weighted_rate + rate_proxy * weight + weighted_scale = weighted_scale + scale_proxy * weight + estimated_bits = estimated_bits + tensor_bits + if weighted_rate is None or weighted_scale is None or estimated_bits is None: + zero = self.tok_emb.weight.new_zeros(()) + return zero, zero, zero + denom = max(total_symbols, 1.0) + return weighted_rate / denom, weighted_scale / denom, estimated_bits + + +# ----------------------------- +# TRAINING +# ----------------------------- + +def main() -> None: + global zeropower_via_newtonschulz5 + + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + if args.enable_compile: + zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) + + # ----------------------------- + # DISTRIBUTED + CUDA SETUP + # ----------------------------- + + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + simulated_world_size = args.simulated_world_size if not distributed and args.simulated_world_size > 0 else 0 + if simulated_world_size > 0: + world_size = simulated_world_size + rank = args.simulated_rank + local_rank = 0 + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if rank < 0 or rank >= world_size: + raise ValueError(f"RANK={rank} must satisfy 0 <= RANK < WORLD_SIZE={world_size}") + if simulated_world_size > 1 and not args.dry_run_init_only: + raise ValueError("SIMULATED_WORLD_SIZE>1 is only supported together with DRY_RUN_INIT_ONLY=1") + if args.train_batch_tokens <= 0: + raise ValueError(f"TRAIN_BATCH_TOKENS must be positive, got {args.train_batch_tokens}") + train_microbatch_tokens = args.train_microbatch_tokens if 
args.train_microbatch_tokens > 0 else max( + args.train_seq_len, args.train_batch_tokens // 8 + ) + if train_microbatch_tokens <= 0: + raise ValueError(f"TRAIN_MICROBATCH_TOKENS must be positive, got {train_microbatch_tokens}") + if train_microbatch_tokens % args.train_seq_len != 0: + raise ValueError( + f"TRAIN_MICROBATCH_TOKENS={train_microbatch_tokens} must be divisible by TRAIN_SEQ_LEN={args.train_seq_len}" + ) + tokens_per_microstep = world_size * train_microbatch_tokens + if args.grad_accum_steps > 0: + grad_accum_steps = args.grad_accum_steps + effective_train_batch_tokens = tokens_per_microstep * grad_accum_steps + else: + if args.train_batch_tokens % tokens_per_microstep != 0: + raise ValueError( + "TRAIN_BATCH_TOKENS must be divisible by WORLD_SIZE * TRAIN_MICROBATCH_TOKENS " + f"unless GRAD_ACCUM_STEPS is set explicitly; got TRAIN_BATCH_TOKENS={args.train_batch_tokens}, " + f"WORLD_SIZE={world_size}, TRAIN_MICROBATCH_TOKENS={train_microbatch_tokens}" + ) + grad_accum_steps = args.train_batch_tokens // tokens_per_microstep + effective_train_batch_tokens = args.train_batch_tokens + if grad_accum_steps <= 0: + raise ValueError(f"GRAD_ACCUM_STEPS must be positive, got {grad_accum_steps}") + grad_scale = 1.0 / grad_accum_steps + batch_multiplier = effective_train_batch_tokens / max(args.lr_reference_batch_tokens, 1) + effective_warmup_steps = scaled_warmup_steps( + args.warmup_steps, + effective_train_batch_tokens / max(args.warmup_reference_batch_tokens, 1), + args.warmup_scale_method, + ) + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 0 + + # Fast math knobs + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, 
enable_mem_efficient_sdp + + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + + logfile = None + if master_process: + if args.run_id == "train": + logfile = "train.log" + else: + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" + print(logfile) + + def log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch {torch.__version__}", console=False) + log0( + subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, + console=False, + ) + log0("=" * 100, console=False) + + # ----------------------------- + # TOKENIZER + VALIDATION METRIC SETUP + # ----------------------------- + + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) + if args.val_token_limit > 0: + usable = min(((args.val_token_limit) // args.train_seq_len) * args.train_seq_len, val_tokens.numel() - 1) + if usable <= 0: + raise ValueError(f"VAL_TOKEN_LIMIT={args.val_token_limit} is too small for TRAIN_SEQ_LEN={args.train_seq_len}") + val_tokens = 
val_tokens[: usable + 1].contiguous() + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + if args.val_token_limit > 0: + log0(f"val_token_limit:{args.val_token_limit}") + + # ----------------------------- + # MODEL + OPTIMIZER SETUP + # ----------------------------- + + base_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, + structured_mlp=args.structured_mlp, + structured_linear_type=args.structured_linear_type if args.structured_mlp else "dense", + btt_rank=args.btt_rank, + btt_in_shape=args.btt_in_shape, + btt_out_shape=args.btt_out_shape, + btt_init=args.btt_init, + btt_init_gain=args.btt_init_gain, + compile_structured_mlp=args.compile_structured_mlp, + structured_chunk_ranks=args.structured_chunk_ranks, + structured_compile_mode=args.structured_compile_mode, + rate_target_tensors=args.rate_target_tensors, + rate_temperature=args.rate_temperature, + ).to(device).bfloat16() + for module in base_model.modules(): + if isinstance(module, (CastedLinear, StructuredLinear)): + module.float() + restore_low_dim_params_to_fp32(base_model) + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) if args.enable_compile else base_model + model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + + # Optimizer split: + # - token embedding 
(Adam) uses EMBED_LR + # - untied lm_head (Adam) uses HEAD_LR + # - matrix params in transformer blocks use MATRIX_LR via Muon + # - vectors/scalars use SCALAR_LR via Adam + block_named_params = list(base_model.blocks.named_parameters()) + matrix_params = [ + p + for name, p in block_named_params + if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + scalar_params = [ + p + for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + scaled_token_lr = scale_from_batch_multiplier(token_lr, batch_multiplier, args.lr_scale_method) + scaled_head_lr = scale_from_batch_multiplier(args.head_lr, batch_multiplier, args.lr_scale_method) + scaled_matrix_lr = scale_from_batch_multiplier(args.matrix_lr, batch_multiplier, args.lr_scale_method) + scaled_scalar_lr = scale_from_batch_multiplier(args.scalar_lr, batch_multiplier, args.lr_scale_method) + optimizer_tok = torch.optim.Adam( + [{"params": [base_model.tok_emb.weight], "lr": scaled_token_lr, "base_lr": scaled_token_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizer_muon = Muon( + matrix_params, + lr=scaled_matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + ) + for group in optimizer_muon.param_groups: + group["base_lr"] = scaled_matrix_lr + optimizer_scalar = torch.optim.Adam( + [{"params": scalar_params, "lr": scaled_scalar_lr, "base_lr": scaled_scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": scaled_head_lr, "base_lr": scaled_head_lr}], 
+ betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers.insert(1, optimizer_head) + + n_params = sum(p.numel() for p in base_model.parameters()) + log0(f"model_params:{n_params}") + log0( + f"world_size:{world_size} grad_accum_steps:{grad_accum_steps} " + f"train_microbatch_tokens:{train_microbatch_tokens} effective_train_batch_tokens:{effective_train_batch_tokens}" + ) + if simulated_world_size > 0: + log0(f"simulated_world_size:{simulated_world_size} simulated_rank:{rank}") + log0(f"enable_compile:{args.enable_compile}") + log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") + log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") + log0( + f"quant_scheme:{args.quant_scheme} compressor:{args.compressor} level:{args.compressor_level} " + f"structured_mlp:{args.structured_mlp} structured_linear_type:{args.structured_linear_type} " + f"btt_rank:{args.btt_rank} btt_init:{args.btt_init} compile_structured_mlp:{args.compile_structured_mlp} " + f"structured_chunk_ranks:{args.structured_chunk_ranks} structured_compile_mode:{args.structured_compile_mode}" + ) + log0( + f"rate_lambda:{args.rate_lambda} scale_lambda:{args.scale_lambda} " + f"rate_warmup_start_fraction:{args.rate_warmup_start_fraction} rate_full_fraction:{args.rate_full_fraction} " + f"rate_target_tensors:{args.rate_target_tensors} rate_schedule_mode:{args.rate_schedule_mode} " + f"max_expected_steps:{args.max_expected_steps}" + ) + log0( + f"tie_embeddings:{args.tie_embeddings} embed_lr:{scaled_token_lr} " + f"head_lr:{scaled_head_lr if base_model.lm_head is not None else 0.0} " + f"matrix_lr:{scaled_matrix_lr} scalar_lr:{scaled_scalar_lr}" + ) + log0( + f"train_batch_tokens_target:{args.train_batch_tokens} effective_train_batch_tokens:{effective_train_batch_tokens} " + f"train_seq_len:{args.train_seq_len} iterations:{args.iterations} warmup_steps:{effective_warmup_steps} " + 
f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log0( + f"lr_scale_method:{args.lr_scale_method} lr_reference_batch_tokens:{args.lr_reference_batch_tokens} " + f"batch_multiplier:{batch_multiplier:.4f} warmup_scale_method:{args.warmup_scale_method} " + f"warmup_reference_batch_tokens:{args.warmup_reference_batch_tokens}" + ) + log0(f"seed:{args.seed}") + + if args.debug_data_sharding_steps > 0: + debug_log_data_sharding( + args.train_files, + world_size, + effective_train_batch_tokens, + args.train_seq_len, + grad_accum_steps, + args.debug_data_sharding_steps, + log0, + ) + if args.dry_run_init_only: + log0("dry_run_init_only: exiting after init/logging") + return + + # ----------------------------- + # DATA LOADER & MODEL WARMUP + # ----------------------------- + + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + + def lr_mul(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + + def rate_progress(step: int, elapsed_ms: float) -> float: + if args.rate_schedule_mode == "steps": + if args.iterations <= 0: + return 1.0 + return step / max(args.iterations, 1) + if args.rate_schedule_mode == "expected_steps": + if args.max_expected_steps <= 0: + raise ValueError("MAX_EXPECTED_STEPS must be positive when RATE_SCHEDULE_MODE=expected_steps") + return step 
/ max(args.max_expected_steps, 1) + if args.rate_schedule_mode == "wallclock": + if max_wallclock_ms is None: + raise ValueError("RATE_SCHEDULE_MODE=wallclock requires MAX_WALLCLOCK_SECONDS > 0") + return elapsed_ms / max(max_wallclock_ms, 1e-9) + if args.rate_schedule_mode != "auto": + raise ValueError(f"Unsupported RATE_SCHEDULE_MODE={args.rate_schedule_mode}") + if args.max_expected_steps > 0: + return step / max(args.max_expected_steps, 1) + if max_wallclock_ms is not None and step > 0: + step_ms = elapsed_ms / max(step, 1) + expected_steps = max(int(max_wallclock_ms / max(step_ms, 1e-9)), 1) + return step / expected_steps + if args.iterations <= 0: + return 1.0 + return step / max(args.iterations, 1) + + def rate_schedule(step: int, elapsed_ms: float) -> float: + if args.rate_lambda <= 0.0 and args.scale_lambda <= 0.0: + return 0.0 + frac = rate_progress(step, elapsed_ms) + start = args.rate_warmup_start_fraction + full = max(args.rate_full_fraction, start + 1e-6) + if frac <= start: + return 0.0 + if frac >= full: + return 1.0 + return (frac - start) / (full - start) + + # Warmup primes the compiled forward/backward/optimizer paths, then we restore the + # initial weights/optimizer state so measured training starts from the true init. 
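A minimal standalone sketch of this snapshot/restore pattern, using a toy model and optimizer (the names here are illustrative, not the script's):

```python
import copy

import torch
from torch import nn

# Toy stand-ins for the real model/optimizers; the pattern is what matters.
model = nn.Linear(4, 4)
opt = torch.optim.SGD(model.parameters(), lr=0.1)

# Snapshot the true initial state before any priming steps.
init_model = {k: v.detach().clone() for k, v in model.state_dict().items()}
init_opt = copy.deepcopy(opt.state_dict())

# "Warmup" steps that mutate the weights (e.g. to prime compiled kernels).
for _ in range(3):
    opt.zero_grad(set_to_none=True)
    model(torch.randn(2, 4)).sum().backward()
    opt.step()

# Restore, so measured training starts from the original init.
model.load_state_dict(init_model, strict=True)
opt.load_state_dict(init_opt)
```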
+ if effective_warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(effective_warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(effective_train_batch_tokens, args.train_seq_len, grad_accum_steps) + maybe_mark_compile_step_begin(args.enable_compile or args.compile_structured_mlp) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + for opt in optimizers: + opt.step() + zero_grad_all() + if ( + effective_warmup_steps <= 20 + or (warmup_step + 1) % 10 == 0 + or warmup_step + 1 == effective_warmup_steps + ): + log0(f"warmup_step:{warmup_step + 1}/{effective_warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + if distributed: + model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + # ----------------------------- + # MAIN TRAINING LOOP + # ----------------------------- + + training_time_ms = 0.0 + stop_after_step: int | None = None + torch.cuda.synchronize() + t0 = time.perf_counter() + + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + 
grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + log0( + f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" + ) + break + + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + zero_grad_all() + train_loss = torch.zeros((), device=device) + train_ce_loss = torch.zeros((), device=device) + train_rate_proxy = torch.zeros((), device=device) + train_scale_proxy = torch.zeros((), device=device) + train_est_bits = torch.zeros((), device=device) + rate_mul = rate_schedule(step, elapsed_ms) + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(effective_train_batch_tokens, args.train_seq_len, grad_accum_steps) + maybe_mark_compile_step_begin(args.enable_compile or args.compile_structured_mlp) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + ce_loss = model(x, y) + total_loss = ce_loss + if rate_mul > 0.0: + rate_proxy, scale_proxy, est_bits = base_model.rate_regularizer() + total_loss = total_loss + rate_mul * (args.rate_lambda * rate_proxy + args.scale_lambda * scale_proxy) + train_rate_proxy += rate_proxy.detach() + train_scale_proxy += scale_proxy.detach() + train_est_bits += est_bits.detach() + train_ce_loss += ce_loss.detach() + train_loss += total_loss.detach() + (total_loss * grad_scale).backward() + train_loss /= grad_accum_steps + train_ce_loss /= grad_accum_steps + train_rate_proxy /= grad_accum_steps + train_scale_proxy /= grad_accum_steps + 
train_est_bits /= grad_accum_steps + + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + for opt in optimizers: + opt.step() + zero_grad_all() + + step += 1 + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + should_log_train = ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + log0( + f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} ce_loss:{train_ce_loss.item():.4f} " + f"rate_mul:{rate_mul:.3f} rate_proxy_bits:{train_rate_proxy.item():.4f} " + f"scale_proxy:{train_scale_proxy.item():.6f} est_model_bits:{train_est_bits.item():.0f} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" + ) + + # Needed to sync whether we've reached the wallclock cap. 
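The MAX-reduction semantics used for that sync can be emulated locally without a process group (the per-rank flag values below are hypothetical):

```python
import torch

# If ANY rank has hit the wallclock cap, every rank must stop together,
# otherwise collectives in the next step would deadlock.
rank_flags = [False, False, True, False]  # hypothetical per-rank reached_cap values
flags = torch.tensor([int(f) for f in rank_flags])

# dist.all_reduce(tensor, op=dist.ReduceOp.MAX) computes this max globally,
# leaving the same value on every rank.
reached_cap = bool(flags.max().item())
```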
reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + + log0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + ) + + # ----------------------------- + # SERIALIZATION + ROUNDTRIP VALIDATION + # ----------------------------- + # Save the raw state (useful for debugging/loading in PyTorch directly), then always produce + # the quantized+compressed artifact for the configured QUANT_SCHEME+COMPRESSOR (int5_odd+zlib here) + # and validate the round-tripped weights. + + if master_process: + torch.save(base_model.state_dict(), "final_model.pt") + model_bytes = os.path.getsize("final_model.pt") + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model: {model_bytes} bytes") + log0(f"Code size: {code_bytes} bytes") + log0(f"Total submission size: {model_bytes + code_bytes} bytes") + + quant_obj, quant_stats = quantize_state_dict(base_model.state_dict(), args.quant_scheme) + quant_buf = io.BytesIO() + torch.save(quant_obj, quant_buf) + quant_raw = quant_buf.getvalue() + quant_blob = compress_blob(quant_raw, args.compressor, args.compressor_level) + quant_raw_bytes = len(quant_raw) + quant_filename = f"final_model.{args.quant_scheme}.{args.compressor}.ptc" + if master_process: + with open(quant_filename, "wb") as f: + f.write(quant_blob) + quant_file_bytes = os.path.getsize(quant_filename) + code_bytes = len(code.encode("utf-8")) + ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1) + log0( + f"Serialized model {args.quant_scheme}+{args.compressor}: {quant_file_bytes} bytes " + f"(payload:{quant_stats['int8_payload_bytes']} raw_torch:{quant_raw_bytes} 
payload_ratio:{ratio:.2f}x)" + ) + log0(f"Total submission size {args.quant_scheme}+{args.compressor}: {quant_file_bytes + code_bytes} bytes") + + if distributed: + dist.barrier() + with open(quant_filename, "rb") as f: + quant_blob_disk = f.read() + quant_state = torch.load(io.BytesIO(decompress_blob(quant_blob_disk, args.compressor)), map_location="cpu") + base_model.load_state_dict(dequantize_state_dict(quant_state), strict=True) + torch.cuda.synchronize() + t_qeval = time.perf_counter() + q_val_loss, q_val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"final_{args.quant_scheme}_{args.compressor}_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" + ) + log0(f"final_{args.quant_scheme}_{args.compressor}_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + + if distributed: + dist.destroy_process_group() + + +if __name__ == "__main__": + main() + +==================================================================================================== +Running Python 3.12.3 (main, Nov 6 2025, 13:44:16) [GCC 13.3.0] +Running PyTorch 2.9.1+cu128 +Tue Mar 31 17:09:00 2026 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 580.126.09 Driver Version: 580.126.09 CUDA Version: 13.0 | ++-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. 
| +|=========================================+========================+======================| +| 0 NVIDIA H100 80GB HBM3 On | 00000000:CB:00.0 Off | 0 | +| N/A 35C P0 88W / 700W | 1185MiB / 81559MiB | 1% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| No running processes found | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=/workspace/parameter-golf/data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:80 +val_loader:shards pattern=/workspace/parameter-golf/data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:65536 +val_token_limit:65536 +model_params:25727104 +world_size:1 grad_accum_steps:8 train_microbatch_tokens:2048 effective_train_batch_tokens:16384 +enable_compile:False +sdp_backends:cudnn=False flash=True mem_efficient=False math=False +attention_mode:gqa num_heads:8 num_kv_heads:4 +quant_scheme:int5_odd compressor:zlib level:9 structured_mlp:True structured_linear_type:btt_mlp btt_rank:256 btt_init:mup compile_structured_mlp:False structured_chunk_ranks:64 structured_compile_mode:reduce-overhead +rate_lambda:2e-05 scale_lambda:0.0002 rate_warmup_start_fraction:0.1 rate_full_fraction:0.4 rate_target_tensors:mlp_only rate_schedule_mode:auto max_expected_steps:0 +tie_embeddings:True embed_lr:0.03 head_lr:0.0 matrix_lr:0.02 scalar_lr:0.02 +train_batch_tokens_target:16384 effective_train_batch_tokens:16384 train_seq_len:1024 iterations:20000 warmup_steps:1 
max_wallclock_seconds:130.000 +lr_scale_method:none lr_reference_batch_tokens:16384 batch_multiplier:1.0000 warmup_scale_method:none warmup_reference_batch_tokens:16384 +seed:1337 +warmup_step:1/1 +step:1/20000 train_loss:6.9404 ce_loss:6.9404 rate_mul:0.000 rate_proxy_bits:0.0000 scale_proxy:0.000000 est_model_bits:0 train_time:18753ms step_avg:18752.97ms +step:2/20000 train_loss:12.7073 ce_loss:12.7072 rate_mul:0.222 rate_proxy_bits:1.4982 scale_proxy:0.059731 est_model_bits:18851774 train_time:38430ms step_avg:19214.81ms +step:3/20000 train_loss:12.6905 ce_loss:12.6905 rate_mul:0.778 rate_proxy_bits:1.4982 scale_proxy:0.059731 est_model_bits:18851774 train_time:58096ms step_avg:19365.29ms +step:4/20000 train_loss:12.6512 ce_loss:12.6512 rate_mul:1.000 rate_proxy_bits:1.4982 scale_proxy:0.059731 est_model_bits:18851774 train_time:77808ms step_avg:19451.99ms +step:5/20000 train_loss:12.6780 ce_loss:12.6780 rate_mul:1.000 rate_proxy_bits:1.4982 scale_proxy:0.059731 est_model_bits:18851774 train_time:97504ms step_avg:19500.86ms +step:6/20000 train_loss:12.6050 ce_loss:12.6049 rate_mul:1.000 rate_proxy_bits:1.4982 scale_proxy:0.059731 est_model_bits:18851774 train_time:117102ms step_avg:19517.05ms +step:7/20000 train_loss:12.6482 ce_loss:12.6481 rate_mul:1.000 rate_proxy_bits:1.4982 scale_proxy:0.059731 est_model_bits:18851774 train_time:136790ms step_avg:19541.49ms +step:7/20000 val_loss:12.6308 val_bpb:7.4681 train_time:136791ms step_avg:19541.59ms +stopping_early: wallclock_cap train_time:136791ms step:7/20000 +peak memory allocated: 50388 MiB reserved: 53228 MiB +Serialized model: 101929819 bytes +Code size: 78465 bytes +Total submission size: 102008284 bytes +Serialized model int5_odd+zlib: 5180456 bytes (payload:8780523 raw_torch:8851579 payload_ratio:11.60x) +Total submission size int5_odd+zlib: 5258921 bytes +final_int5_odd_zlib_roundtrip val_loss:12.1881 val_bpb:7.2064 eval_time:6512ms +final_int5_odd_zlib_roundtrip_exact val_loss:12.18812561 
val_bpb:7.20635691
diff --git a/records/track_non_record_16mb/2026-03-31_EntropyAware_Int5Odd_BTTMLP_3060/logs/sanity_dense.txt b/records/track_non_record_16mb/2026-03-31_EntropyAware_Int5Odd_BTTMLP_3060/logs/sanity_dense.txt
new file mode 100644
index 000000000..81f6e45fa
--- /dev/null
+++ b/records/track_non_record_16mb/2026-03-31_EntropyAware_Int5Odd_BTTMLP_3060/logs/sanity_dense.txt
@@ -0,0 +1,1583 @@
+"""
+The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good jumping-off points for new participants, not SOTA configs. We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder.
+
+Hard stop: To keep things readable for newcomers, let's make sure `train_gpt.py` and `train_gpt_mlx.py` never exceed 1500 lines.
+"""
+
+from __future__ import annotations
+
+import copy
+import glob
+import io
+import math
+import os
+import random
+import subprocess
+import sys
+import time
+import uuid
+import zlib
+from pathlib import Path
+
+import numpy as np
+import sentencepiece as spm
+import torch
+import torch.distributed as dist
+import torch.nn.functional as F
+from torch import Tensor, nn
+from torch.nn.parallel import DistributedDataParallel as DDP
+
+try:
+    import zstandard as zstd
+
+    _HAS_ZSTD = True
+except ImportError:
+    zstd = None
+    _HAS_ZSTD = False
+
+# -----------------------------
+# HYPERPARAMETERS
+# -----------------------------
+# Default Simple Baseline run:
+# - 9 transformer blocks at width 512
+# - 8 attention heads with 4 KV heads (GQA) and 2x MLP expansion
+# - vocab size 1024, sequence length 1024, tied embeddings
+# - 524,288 train tokens per step for 20,000 iterations with a ~10 minute cap
+
+class Hyperparameters:
+    # Data paths are shard globs produced by the existing preprocessing pipeline.
+ data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + seed = int(os.environ.get("SEED", 1337)) + + # Validation cadence and batch size. Validation always uses the full fineweb_val split. + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 200)) + + # Training length. + iterations = int(os.environ.get("ITERATIONS", 20000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 1200)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + enable_compile = bool(int(os.environ.get("ENABLE_COMPILE", "1"))) + + # Model shape. + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers = int(os.environ.get("NUM_LAYERS", 9)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = int(os.environ.get("MLP_MULT", 2)) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + + # Optimizer hyperparameters. 
+ embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.05)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.04)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.04)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.0)) + + # Export / compression. + quant_scheme = os.environ.get("QUANT_SCHEME", "int5_odd") + compressor = os.environ.get("COMPRESSOR", "zlib") + compressor_level = int(os.environ.get("COMPRESSOR_LEVEL", 9)) + + # Entropy-aware rate regularization. + rate_lambda = float(os.environ.get("RATE_LAMBDA", 0.02)) + scale_lambda = float(os.environ.get("SCALE_LAMBDA", 0.0005)) + rate_temperature = float(os.environ.get("RATE_TEMPERATURE", 6.0)) + rate_warmup_start_fraction = float(os.environ.get("RATE_WARMUP_START_FRACTION", 0.2)) + rate_full_fraction = float(os.environ.get("RATE_FULL_FRACTION", 0.6)) + rate_target_tensors = os.environ.get("RATE_TARGET_TENSORS", "mlp_only") + + # Structured MLP. 
+ structured_mlp = bool(int(os.environ.get("STRUCTURED_MLP", "0"))) + structured_linear_type = os.environ.get("STRUCTURED_LINEAR_TYPE", "btt_mlp") + btt_rank = int(os.environ.get("BTT_RANK", 32)) + btt_in_shape = os.environ.get("BTT_IN_SHAPE", "") + btt_out_shape = os.environ.get("BTT_OUT_SHAPE", "") + +# ----------------------------- +# MUON OPTIMIZER +# ----------------------------- +# +# As borrowed from modded-nanogpt +# Background on Muon: https://kellerjordan.github.io/posts/muon/ + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + # Orthogonalize a 2D update matrix with a fast Newton-Schulz iteration. + # Muon uses this to normalize matrix-shaped gradients before applying them. + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X + + +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True): + super().__init__( + params, + dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov), + ) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + + for group in self.param_groups: + params = group["params"] + if not params: + continue + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + + total_params = sum(int(p.numel()) for p in params) + updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) + + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and 
p.grad is not None:
+                    g = p.grad
+                    state = self.state[p]
+                    if "momentum_buffer" not in state:
+                        state["momentum_buffer"] = torch.zeros_like(g)
+                    buf = state["momentum_buffer"]
+                    buf.mul_(momentum).add_(g)
+                    if nesterov:
+                        g = g.add(buf, alpha=momentum)
+                    g = zeropower_via_newtonschulz5(g, steps=backend_steps)
+                    # Scale correction from Muon reference implementations.
+                    g *= max(1, g.size(0) / g.size(1)) ** 0.5
+                    updates_flat[curr : curr + p.numel()] = g.reshape(-1)
+                # Advance the offset for every parameter, owned or not, so all
+                # ranks agree on where each update lives in the flat buffer.
+                curr += p.numel()
+
+            if distributed:
+                dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM)
+
+            curr = 0
+            for p in params:
+                g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype)
+                p.add_(g, alpha=-lr)
+                curr += p.numel()
+
+        return loss
+
+
+# -----------------------------
+# TOKENIZER-AGNOSTIC EVALUATION SETUP
+# -----------------------------
+#
+# It's common for small models to have a large fraction of their parameters in embeddings, since the 2 * d_model * d_vocab embedding and unembedding entries can be gigantic.
+# Instead of locking the tokenizer, we let you bring your own and calculate our validation metrics on the average compression of the validation set.
+# We calculate BPB (bits-per-byte) instead of validation loss, so we need a way to count the number of bytes each token decodes to.
+# Note: Submissions that edit the tokenizer will be examined more carefully, since screwing this up might unjustly improve your score.
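For a concrete sense of the metric, here is a minimal, self-contained sketch (with made-up example numbers, not values from this run) of how a mean cross-entropy in nats is converted to bits-per-byte, following the same arithmetic the script uses in `eval_val`:

```python
import math

def bits_per_byte(mean_ce_nats: float, total_tokens: int, total_bytes: int) -> float:
    # Cross-entropy in nats -> bits per token, then re-normalize by the number
    # of bytes the tokens decode to, making the metric tokenizer-agnostic.
    bits_per_token = mean_ce_nats / math.log(2.0)
    return bits_per_token * (total_tokens / total_bytes)

# A tokenizer whose tokens decode to more bytes is charged fewer bits per
# byte at the same per-token cross-entropy.
print(bits_per_byte(2.0, 1_000, 4_300))
```

This is why vocabulary choices trade off directly against model loss: a bigger vocabulary lowers bytes-per-token pressure but grows the embedding tables that count against the artifact size.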
+ +def build_sentencepiece_luts( + sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device +) -> tuple[Tensor, Tensor, Tensor]: + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("▁"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) + + +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*. 
+ tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] + + +def eval_val( + args: Hyperparameters, + model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + grad_accum_steps: int, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, +) -> tuple[float, float]: + # Validation computes two metrics: + # - val_loss: token cross-entropy (natural log) + # - val_bpb: tokenizer-agnostic compression metric used by the challenge + local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) + if local_batch_tokens < args.train_seq_len: + raise ValueError( + "VAL_BATCH_SIZE must provide at least one sequence per rank; " + f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " + f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}" + ) + local_batch_seqs = local_batch_tokens // args.train_seq_len + total_seqs = (val_tokens.numel() - 1) // args.train_seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + + model.eval() + with torch.inference_mode(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * args.train_seq_len + raw_end = batch_seq_end * args.train_seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, args.train_seq_len) + y = local[1:].reshape(-1, args.train_seq_len) + with 
torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True):
+                    batch_loss = model(x, y).detach()
+                batch_token_count = float(y.numel())
+                val_loss_sum += batch_loss.to(torch.float64) * batch_token_count
+                val_token_count += batch_token_count
+                prev_ids = x.reshape(-1)
+                tgt_ids = y.reshape(-1)
+                token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16)
+                token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16)
+                val_byte_count += token_bytes.to(torch.float64).sum()
+
+    if dist.is_available() and dist.is_initialized():
+        dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM)
+        dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM)
+        dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM)
+
+    val_loss = val_loss_sum / val_token_count
+    bits_per_token = val_loss.item() / math.log(2.0)
+    tokens_per_byte = val_token_count.item() / val_byte_count.item()
+    model.train()
+    return float(val_loss.item()), float(bits_per_token * tokens_per_byte)
+
+# -----------------------------
+# POST-TRAINING QUANTIZATION
+# -----------------------------
+#
+# It's silly to export our model, which is trained in bf16 and fp32, at that same precision.
+# Instead, we get approximately the same model (with a small accuracy hit) by quantizing it to int8 and zlib-compressing the result.
+# We can then decompress the model and run it in higher precision for evaluation, after coming in under the size limit.
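As a sketch of the byte accounting behind the `int5_odd` export path defined below: three base-5 codes fit in one byte (since 5**3 = 125 <= 256), so each 5-level weight costs 8/3 ≈ 2.67 bits before zlib, versus 8 bits for plain int8. A minimal pure-Python round trip (the exporter's `pack_base5_codes` / `unpack_base5_codes` are the vectorized NumPy equivalents):

```python
def pack3(codes: list[int]) -> tuple[bytes, int]:
    # Pack base-5 digits (0..4) three per byte: b = c0 + 5*c1 + 25*c2.
    padded = codes + [0] * (-len(codes) % 3)
    data = bytes(padded[i] + 5 * padded[i + 1] + 25 * padded[i + 2]
                 for i in range(0, len(padded), 3))
    return data, len(codes)

def unpack3(data: bytes, n_vals: int) -> list[int]:
    # Peel the three base-5 digits back out of each byte, least digit first.
    out = []
    for b in data:
        for _ in range(3):
            out.append(b % 5)
            b //= 5
    return out[:n_vals]

codes = [0, 4, 2, 3, 1]
packed, n = pack3(codes)
assert unpack3(packed, n) == codes  # lossless round trip
```

The `n_vals` count is stored alongside each packed blob so the padding digits can be discarded on the way back.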
+ +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights", + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 +INT5_ODD_LEVELS = (-2, -1, 0, 1, 2) + +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) + + +def parse_factor_shape(spec: str, dim: int) -> tuple[int, int]: + if spec: + vals = tuple(int(v.strip()) for v in spec.split(",") if v.strip()) + if len(vals) != 2 or vals[0] * vals[1] != dim: + raise ValueError(f"Invalid factor shape '{spec}' for dim={dim}") + return vals + for a in range(int(math.isqrt(dim)), 0, -1): + if dim % a == 0: + return a, dim // a + return 1, dim + +def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: + if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): + return t.float().contiguous() + if t.dtype in {torch.float32, torch.bfloat16}: + passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") + return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() + return t + +def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + # Matrices get one scale per row, which usually tracks output-channel + # ranges much better than a single tensor-wide scale. 
+ clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) + q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() + return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + + # Vectors / scalars use a simpler per-tensor scale. + clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() + return q, scale + + +def pack_base5_codes(q: Tensor) -> tuple[bytes, int]: + vals = q.detach().to("cpu", dtype=torch.uint8).reshape(-1).numpy() + n_vals = int(vals.size) + pad = (3 - (n_vals % 3)) % 3 + if pad: + vals = np.pad(vals, (0, pad)) + groups = vals.reshape(-1, 3).astype(np.uint8) + packed = groups[:, 0] + 5 * groups[:, 1] + 25 * groups[:, 2] + return packed.tobytes(), n_vals + + +def unpack_base5_codes(data: bytes, n_vals: int) -> Tensor: + packed = np.frombuffer(data, dtype=np.uint8).astype(np.int16) + out = np.zeros((packed.size, 3), dtype=np.uint8) + for i in range(3): + out[:, i] = packed % 5 + packed //= 5 + flat = out.reshape(-1)[:n_vals] + return torch.from_numpy(flat.astype(np.uint8, copy=False)) + + +def quantize_float_tensor_int5_odd(t: Tensor) -> tuple[Tensor, Tensor]: + if t.ndim != 2: + raise ValueError("int5_odd quantization only supports 2D tensors") + t32 = t.float() + clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clip_abs = clip_abs.clamp_min(1e-6) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 
2.0).clamp_min(1.0 / 2.0e6) + q = torch.clamp(torch.round(clipped / scale[:, None]), -2, 2).to(torch.int8).contiguous() + return (q + 2).to(torch.uint8).contiguous(), scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + +def quantize_state_dict_int8(state_dict: dict[str, Tensor]): + # Single supported clean-script export format: + # - per-row int8 for 2D float tensors + # - per-tensor int8 for other float tensors + # - exact passthrough for non-floats + # - passthrough for small float tensors, stored as fp16 to save bytes + quantized: dict[str, Tensor] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + qmeta: dict[str, dict[str, object]] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + + # Small float tensors are cheap enough to keep directly. We still downcast + # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size. 
+ if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + + stats["num_float_tensors"] += 1 + q, s = quantize_float_tensor(t) + if s.ndim > 0: + qmeta[name] = {"scheme": "per_row", "axis": 0} + quantized[name] = q + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) + + obj: dict[str, object] = { + "__quant_format__": "int8_clean_per_row_v1", + "quantized": quantized, + "scales": scales, + "dtypes": dtypes, + "passthrough": passthrough, + } + if qmeta: + obj["qmeta"] = qmeta + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats + + +def quantize_state_dict_int5_odd(state_dict: dict[str, Tensor]): + packed: dict[str, bytes] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + shapes: dict[str, tuple[int, ...]] = {} + n_codes: dict[str, int] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + + if t.ndim != 2 or t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + + stats["num_float_tensors"] += 1 + q, s = quantize_float_tensor_int5_odd(t) + packed_bytes, count = 
pack_base5_codes(q) + packed[name] = packed_bytes + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + shapes[name] = tuple(int(v) for v in t.shape) + n_codes[name] = count + stats["int8_payload_bytes"] += len(packed_bytes) + tensor_nbytes(s) + + obj: dict[str, object] = { + "__quant_format__": "int5_odd_clean_per_row_v1", + "packed": packed, + "scales": scales, + "dtypes": dtypes, + "shapes": shapes, + "n_codes": n_codes, + "passthrough": passthrough, + } + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats + +def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + qmeta = obj.get("qmeta", {}) + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, q in obj["quantized"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + s = obj["scales"][name] + if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: + s = s.to(dtype=torch.float32) + # Broadcast the saved row scale back across trailing dimensions. + out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() + else: + scale = float(s.item()) + out[name] = (q.float() * scale).to(dtype=dtype).contiguous() + for name, t in obj["passthrough"].items(): + # Restore small tensors, undoing the temporary fp16 storage cast if needed. 
+ out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out + + +def dequantize_state_dict_int5_odd(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, packed in obj["packed"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + shape = tuple(int(v) for v in obj["shapes"][name]) + q = unpack_base5_codes(packed, int(obj["n_codes"][name])).to(torch.float32).reshape(shape) - 2.0 + s = obj["scales"][name].to(dtype=torch.float32) + out[name] = (q * s.view(shape[0], *([1] * (len(shape) - 1)))).to(dtype=dtype).contiguous() + for name, t in obj["passthrough"].items(): + out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out + + +def quantize_state_dict(state_dict: dict[str, Tensor], quant_scheme: str): + if quant_scheme == "int8": + return quantize_state_dict_int8(state_dict) + if quant_scheme == "int5_odd": + return quantize_state_dict_int5_odd(state_dict) + raise ValueError(f"Unsupported QUANT_SCHEME={quant_scheme}") + + +def dequantize_state_dict(obj: dict[str, object]) -> dict[str, Tensor]: + quant_format = obj.get("__quant_format__") + if quant_format == "int8_clean_per_row_v1": + return dequantize_state_dict_int8(obj) + if quant_format == "int5_odd_clean_per_row_v1": + return dequantize_state_dict_int5_odd(obj) + raise ValueError(f"Unsupported quant format {quant_format}") + + +def compress_blob(data: bytes, compressor: str, level: int) -> bytes: + if compressor == "zlib": + return zlib.compress(data, level=max(0, min(level, 9))) + if compressor == "zstd": + if not _HAS_ZSTD: + raise RuntimeError("COMPRESSOR=zstd requires zstandard 
to be installed")
+        return zstd.ZstdCompressor(level=level).compress(data)
+    raise ValueError(f"Unsupported COMPRESSOR={compressor}")
+
+
+def decompress_blob(data: bytes, compressor: str) -> bytes:
+    if compressor == "zlib":
+        return zlib.decompress(data)
+    if compressor == "zstd":
+        if not _HAS_ZSTD:
+            raise RuntimeError("COMPRESSOR=zstd requires zstandard to be installed")
+        return zstd.ZstdDecompressor().decompress(data)
+    raise ValueError(f"Unsupported COMPRESSOR={compressor}")
+
+
+def should_rate_regularize_tensor(name: str, p: Tensor, target: str) -> bool:
+    if not p.requires_grad or not p.is_floating_point() or p.ndim < 2:
+        return False
+    if p.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL:
+        return False
+    if any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS):
+        return False
+    if target == "mlp_only":
+        return ".mlp." in name
+    if target == "all_matrices":
+        return True
+    raise ValueError(f"Unsupported RATE_TARGET_TENSORS={target}")
+
+
+def estimate_tensor_rate_proxy(t: Tensor, temperature: float) -> tuple[Tensor, Tensor, Tensor]:
+    t2 = t.float()
+    if t2.ndim > 2:
+        t2 = t2.reshape(t2.shape[0], -1)
+    elif t2.ndim == 1:
+        t2 = t2.reshape(1, -1)
+    row_scale = t2.abs().mean(dim=1, keepdim=True).clamp_min(1e-6)
+    normalized = torch.clamp(t2 / row_scale, -2.5, 2.5)
+    centers = t2.new_tensor(INT5_ODD_LEVELS)
+    logits = -temperature * (normalized.unsqueeze(-1) - centers) ** 2
+    probs = torch.softmax(logits, dim=-1)
+    hist = probs.mean(dim=(0, 1))
+    entropy_nats = -(hist * (hist + 1e-8).log()).sum()
+    entropy_bits = entropy_nats / math.log(2.0)
+    scale_proxy = row_scale.mean()
+    est_bits = entropy_bits * float(t2.numel())
+    return entropy_bits, scale_proxy, est_bits
+
+
+# -----------------------------
+# DATA LOADING
+# -----------------------------
+
+def load_data_shard(file: Path) -> Tensor:
+    # NOTE: shard format assumed here: a 256-entry int32 header followed by
+    # uint16 token ids, as written by the preprocessing pipeline.
+    header_bytes = 256 * np.dtype("<i4").itemsize
+    data = file.read_bytes()
+    tokens = np.frombuffer(data, dtype="<u2", offset=header_bytes)
+    return torch.from_numpy(tokens.astype(np.int32))
+
+
+class TokenStream:
+    # Sequentially walks the shard files, wrapping around at the end.
+    def __init__(self, pattern: str):
+        self.files = [Path(p) for p in sorted(glob.glob(pattern))]
+        if not self.files:
+            raise FileNotFoundError(f"No files found for pattern: {pattern}")
+        self.file_idx = 0
+        self.tokens = load_data_shard(self.files[self.file_idx])
+        self.pos = 0
+
+    def _advance_file(self) -> None:
+        self.file_idx = (self.file_idx + 1) % len(self.files)
+        self.tokens = load_data_shard(self.files[self.file_idx])
self.pos = 0 + + def take(self, n: int) -> Tensor: + chunks: list[Tensor] = [] + remaining = n + while remaining > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: + self._advance_file() + continue + k = min(remaining, avail) + chunks.append(self.tokens[self.pos : self.pos + k]) + self.pos += k + remaining -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) + + +class DistributedTokenLoader: + # Each call consumes a contiguous chunk from the shared token stream, then slices out + # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting. + def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank = rank + self.world_size = world_size + self.device = device + self.stream = TokenStream(pattern) + + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + chunk = self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start : start + per_rank_span].to(dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + +# ----------------------------- +# TRANSFORMER MODULES +# ----------------------------- + +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + + +class CastedLinear(nn.Linear): + # Keep weights in fp32 for optimizer/state quality, cast at matmul time for bf16 compute. 
+ def forward(self, x: Tensor) -> Tensor: + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, self.weight.to(x.dtype), bias) + + +class StructuredLinear(nn.Module): + # Lightweight TT-style structured matrix used only in the MLP stack. + def __init__( + self, + in_features: int, + out_features: int, + bias: bool, + linear_type: str, + btt_rank: int, + in_shape_spec: str, + out_shape_spec: str, + ): + super().__init__() + self.in_features = in_features + self.out_features = out_features + self.linear_type = linear_type + self.rank = max(int(btt_rank), 1) + if linear_type == "dense": + self.impl = CastedLinear(in_features, out_features, bias=bias) + self.register_parameter("core1", None) + self.register_parameter("core2", None) + return + if linear_type != "btt_mlp": + raise ValueError(f"Unsupported STRUCTURED_LINEAR_TYPE={linear_type}") + self.impl = None + self.in_shape = parse_factor_shape(in_shape_spec, in_features) + self.out_shape = parse_factor_shape(out_shape_spec, out_features) + n1, n2 = self.in_shape + m1, m2 = self.out_shape + self.core1 = nn.Parameter(torch.empty(n1, m1, self.rank)) + self.core2 = nn.Parameter(torch.empty(m1, self.rank, n2, m2)) + if bias: + self.bias = nn.Parameter(torch.zeros(out_features)) + else: + self.register_parameter("bias", None) + self.reset_parameters() + + def reset_parameters(self) -> None: + if self.impl is not None: + self.impl.reset_parameters() + return + std = 1.0 / math.sqrt(self.in_features) + nn.init.normal_(self.core1, mean=0.0, std=std / math.sqrt(self.rank)) + nn.init.normal_(self.core2, mean=0.0, std=std) + if self.bias is not None: + nn.init.zeros_(self.bias) + + def forward(self, x: Tensor) -> Tensor: + if self.impl is not None: + return self.impl(x) + x_dtype = x.dtype + orig_shape = x.shape[:-1] + x2 = x.reshape(-1, self.in_shape[0], self.in_shape[1]) + core1 = self.core1.to(dtype=x_dtype) + core2 = self.core2.to(dtype=x_dtype) + tmp = torch.einsum("buv,umr->bmrv", x2, 
core1)
+        # Contract the shared inner axis (v = n2) against core2 and keep the
+        # (m1, m2) output factors, so the reshape below yields out_features.
+        y = torch.einsum("bmrv,mrvn->bmn", tmp, core2)
+        y = y.reshape(*orig_shape, self.out_features)
+        if self.bias is not None:
+            y = y + self.bias.to(dtype=y.dtype)
+        return y
+
+
+def restore_low_dim_params_to_fp32(module: nn.Module) -> None:
+    # Keep small/control parameters in fp32 even when the model body runs in bf16.
+    with torch.no_grad():
+        for name, param in module.named_parameters():
+            if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32:
+                param.data = param.data.float()
+
+
+class Rotary(nn.Module):
+    # Caches cos/sin tables per sequence length on the current device.
+    def __init__(self, dim: int, base: float = 10000.0):
+        super().__init__()
+        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
+        self.register_buffer("inv_freq", inv_freq, persistent=False)
+        self._seq_len_cached = 0
+        self._cos_cached: Tensor | None = None
+        self._sin_cached: Tensor | None = None
+
+    def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]:
+        if (
+            self._cos_cached is None
+            or self._sin_cached is None
+            or self._seq_len_cached != seq_len
+            or self._cos_cached.device != device
+        ):
+            t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype)
+            freqs = torch.outer(t, self.inv_freq.to(device))
+            self._cos_cached = freqs.cos()[None, None, :, :]
+            self._sin_cached = freqs.sin()[None, None, :, :]
+            self._seq_len_cached = seq_len
+        return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype)
+
+
+def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor:
+    half = x.size(-1) // 2
+    x1, x2 = x[..., :half], x[..., half:]
+    return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1)
+
+
+class CausalSelfAttention(nn.Module):
+    def __init__(
+        self,
+        dim: int,
+        num_heads: int,
+        num_kv_heads: int,
+        rope_base: float,
+        qk_gain_init: float,
+    ):
+        super().__init__()
+        if dim % num_heads != 0:
raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + kv_dim = self.num_kv_heads * self.head_dim + self.c_q = CastedLinear(dim, dim, bias=False) + self.c_k = CastedLinear(dim, kv_dim, bias=False) + self.c_v = CastedLinear(dim, kv_dim, bias=False) + self.proj = CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rotary = Rotary(self.head_dim, base=rope_base) + + def forward(self, x: Tensor) -> Tensor: + bsz, seqlen, dim = x.shape + q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) + k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin) + k = apply_rotary_emb(k, cos, sin) + q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] + y = F.scaled_dot_product_attention( + q, + k, + v, + attn_mask=None, + is_causal=True, + enable_gqa=(self.num_kv_heads != self.num_heads), + ) + y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) + return self.proj(y) + + +class MLP(nn.Module): + # relu^2 MLP from the original modded-nanogpt setup + def __init__( + self, + dim: int, + mlp_mult: int, + structured_mlp: bool, + structured_linear_type: str, + btt_rank: int, + btt_in_shape: str, + btt_out_shape: str, + ): + super().__init__() + hidden = mlp_mult * dim + if structured_mlp: + self.fc = StructuredLinear( + dim, + hidden, + bias=False, + 
linear_type=structured_linear_type, + btt_rank=btt_rank, + in_shape_spec=btt_in_shape, + out_shape_spec=btt_out_shape, + ) + self.proj = StructuredLinear( + hidden, + dim, + bias=False, + linear_type=structured_linear_type, + btt_rank=btt_rank, + in_shape_spec=btt_out_shape, + out_shape_spec=btt_in_shape, + ) + else: + self.fc = CastedLinear(dim, hidden, bias=False) + self.proj = CastedLinear(hidden, dim, bias=False) + self.proj._zero_init = True + + def forward(self, x: Tensor) -> Tensor: + x = torch.relu(self.fc(x)) + return self.proj(x.square()) + + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + rope_base: float, + qk_gain_init: float, + structured_mlp: bool, + structured_linear_type: str, + btt_rank: int, + btt_in_shape: str, + btt_out_shape: str, + ): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) + self.mlp = MLP(dim, mlp_mult, structured_mlp, structured_linear_type, btt_rank, btt_in_shape, btt_out_shape) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + + def forward(self, x: Tensor, x0: Tensor) -> Tensor: + mix = self.resid_mix.to(dtype=x.dtype) + x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + attn_out = self.attn(self.attn_norm(x)) + x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out + x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) + return x + + +class GPT(nn.Module): + def __init__( + self, + vocab_size: int, + num_layers: int, + model_dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + tie_embeddings: bool, + tied_embed_init_std: float, + logit_softcap: float, + rope_base: float, + qk_gain_init: float, 
+ structured_mlp: bool, + structured_linear_type: str, + btt_rank: int, + btt_in_shape: str, + btt_out_shape: str, + rate_target_tensors: str, + rate_temperature: float, + ): + super().__init__() + if logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") + self.tie_embeddings = tie_embeddings + self.tied_embed_init_std = tied_embed_init_std + self.logit_softcap = logit_softcap + self.rate_target_tensors = rate_target_tensors + self.rate_temperature = rate_temperature + self.tok_emb = nn.Embedding(vocab_size, model_dim) + self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + self.blocks = nn.ModuleList( + [ + Block( + model_dim, + num_heads, + num_kv_heads, + mlp_mult, + rope_base, + qk_gain_init, + structured_mlp, + structured_linear_type, + btt_rank, + btt_in_shape, + btt_out_shape, + ) + for i in range(num_layers) + ] + ) + self.final_norm = RMSNorm() + self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) + if self.lm_head is not None: + self.lm_head._zero_init = True + self._init_weights() + + def _init_weights(self) -> None: + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + for module in self.modules(): + if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + if isinstance(module, StructuredLinear) and getattr(module, "_zero_init", False): + if module.impl is not None: + nn.init.zeros_(module.impl.weight) + else: + nn.init.zeros_(module.core2) + + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + x = self.tok_emb(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x0 = x + skips: list[Tensor] = [] + + # First half stores 
skips; second half reuses them in reverse order. + for i in range(self.num_encoder_layers): + x = self.blocks[i](x, x0) + skips.append(x) + for i in range(self.num_decoder_layers): + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + x = self.blocks[self.num_encoder_layers + i](x, x0) + + x = self.final_norm(x).reshape(-1, x.size(-1)) + targets = target_ids.reshape(-1) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(x) + logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + return F.cross_entropy(logits.float(), targets, reduction="mean") + + def rate_regularizer(self) -> tuple[Tensor, Tensor, Tensor]: + total_symbols = 0.0 + weighted_rate = None + weighted_scale = None + estimated_bits = None + for name, param in self.named_parameters(): + if not should_rate_regularize_tensor(name, param, self.rate_target_tensors): + continue + rate_proxy, scale_proxy, tensor_bits = estimate_tensor_rate_proxy(param, self.rate_temperature) + weight = float(param.numel()) + total_symbols += weight + if weighted_rate is None: + weighted_rate = rate_proxy * weight + weighted_scale = scale_proxy * weight + estimated_bits = tensor_bits + else: + weighted_rate = weighted_rate + rate_proxy * weight + weighted_scale = weighted_scale + scale_proxy * weight + estimated_bits = estimated_bits + tensor_bits + if weighted_rate is None or weighted_scale is None or estimated_bits is None: + zero = self.tok_emb.weight.new_zeros(()) + return zero, zero, zero + denom = max(total_symbols, 1.0) + return weighted_rate / denom, weighted_scale / denom, estimated_bits + + +# ----------------------------- +# TRAINING +# ----------------------------- + +def main() -> None: + global zeropower_via_newtonschulz5 + + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() 
+ if args.enable_compile: + zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) + + # ----------------------------- + # DISTRIBUTED + CUDA SETUP + # ----------------------------- + + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if 8 % world_size != 0: + raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") + grad_accum_steps = 8 // world_size + grad_scale = 1.0 / grad_accum_steps + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 0 + + # Fast math knobs + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + + logfile = None + if master_process: + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" + print(logfile) + + def log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch {torch.__version__}", console=False) + log0( + subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, + console=False, + ) + log0("=" * 
100, console=False) + + # ----------------------------- + # TOKENIZER + VALIDATION METRIC SETUP + # ----------------------------- + + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + + # ----------------------------- + # MODEL + OPTIMIZER SETUP + # ----------------------------- + + base_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, + structured_mlp=args.structured_mlp, + structured_linear_type=args.structured_linear_type if args.structured_mlp else "dense", + btt_rank=args.btt_rank, + btt_in_shape=args.btt_in_shape, + btt_out_shape=args.btt_out_shape, + rate_target_tensors=args.rate_target_tensors, + 
rate_temperature=args.rate_temperature, + ).to(device).bfloat16() + for module in base_model.modules(): + if isinstance(module, (CastedLinear, StructuredLinear)): + module.float() + restore_low_dim_params_to_fp32(base_model) + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) if args.enable_compile else base_model + model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + + # Optimizer split: + # - token embedding (Adam) uses EMBED_LR + # - untied lm_head (Adam) uses HEAD_LR + # - matrix params in transformer blocks use MATRIX_LR via Muon + # - vectors/scalars use SCALAR_LR via Adam + block_named_params = list(base_model.blocks.named_parameters()) + matrix_params = [ + p + for name, p in block_named_params + if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + scalar_params = [ + p + for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + optimizer_tok = torch.optim.Adam( + [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizer_muon = Muon( + matrix_params, + lr=args.matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + ) + for group in optimizer_muon.param_groups: + group["base_lr"] = args.matrix_lr + optimizer_scalar = torch.optim.Adam( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": 
[base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers.insert(1, optimizer_head) + + n_params = sum(p.numel() for p in base_model.parameters()) + log0(f"model_params:{n_params}") + log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") + log0(f"enable_compile:{args.enable_compile}") + log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") + log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") + log0( + f"quant_scheme:{args.quant_scheme} compressor:{args.compressor} level:{args.compressor_level} " + f"structured_mlp:{args.structured_mlp} structured_linear_type:{args.structured_linear_type}" + ) + log0( + f"rate_lambda:{args.rate_lambda} scale_lambda:{args.scale_lambda} " + f"rate_warmup_start_fraction:{args.rate_warmup_start_fraction} rate_full_fraction:{args.rate_full_fraction} " + f"rate_target_tensors:{args.rate_target_tensors}" + ) + log0( + f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " + f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " + f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" + ) + log0( + f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " + f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " + f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log0(f"seed:{args.seed}") + + # ----------------------------- + # DATA LOADER & MODEL WARMUP + # ----------------------------- + + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + + def lr_mul(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + 
warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + + def rate_schedule(step: int) -> float: + if args.rate_lambda <= 0.0 and args.scale_lambda <= 0.0: + return 0.0 + if args.iterations <= 0: + return 1.0 + frac = step / max(args.iterations, 1) + start = args.rate_warmup_start_fraction + full = max(args.rate_full_fraction, start + 1e-6) + if frac <= start: + return 0.0 + if frac >= full: + return 1.0 + return (frac - start) / (full - start) + + # Warmup primes the compiled forward/backward/optimizer paths, then we restore the + # initial weights/optimizer state so measured training starts from the true init. + if args.warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(args.warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + for opt in optimizers: + opt.step() + zero_grad_all() + if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: + log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, 
initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + if distributed: + model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + # ----------------------------- + # MAIN TRAINING LOOP + # ----------------------------- + + training_time_ms = 0.0 + stop_after_step: int | None = None + torch.cuda.synchronize() + t0 = time.perf_counter() + + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + log0( + f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" + ) + break + + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + zero_grad_all() + train_loss = torch.zeros((), device=device) + train_ce_loss = torch.zeros((), device=device) + train_rate_proxy = torch.zeros((), device=device) + train_scale_proxy = torch.zeros((), device=device) + train_est_bits = torch.zeros((), device=device) + rate_mul = rate_schedule(step) + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = 
train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + ce_loss = model(x, y) + total_loss = ce_loss + if rate_mul > 0.0: + rate_proxy, scale_proxy, est_bits = base_model.rate_regularizer() + total_loss = total_loss + rate_mul * (args.rate_lambda * rate_proxy + args.scale_lambda * scale_proxy) + train_rate_proxy += rate_proxy.detach() + train_scale_proxy += scale_proxy.detach() + train_est_bits += est_bits.detach() + train_ce_loss += ce_loss.detach() + train_loss += total_loss.detach() + (total_loss * grad_scale).backward() + train_loss /= grad_accum_steps + train_ce_loss /= grad_accum_steps + train_rate_proxy /= grad_accum_steps + train_scale_proxy /= grad_accum_steps + train_est_bits /= grad_accum_steps + + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + for opt in optimizers: + opt.step() + zero_grad_all() + + step += 1 + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + should_log_train = ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + log0( + f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} ce_loss:{train_ce_loss.item():.4f} " + f"rate_mul:{rate_mul:.3f} rate_proxy_bits:{train_rate_proxy.item():.4f} " + f"scale_proxy:{train_scale_proxy.item():.6f} est_model_bits:{train_est_bits.item():.0f} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / 
step:.2f}ms"
+            )
+
+        # Needed to sync whether we've reached the wallclock cap.
+        reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms
+        if distributed and max_wallclock_ms is not None:
+            reached_cap_tensor = torch.tensor(int(reached_cap), device=device)
+            dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX)
+            reached_cap = bool(reached_cap_tensor.item())
+        if stop_after_step is None and reached_cap:
+            stop_after_step = step
+
+    log0(
+        f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB "
+        f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB"
+    )
+
+    # -----------------------------
+    # SERIALIZATION + ROUNDTRIP VALIDATION
+    # -----------------------------
+    # Save the raw state (useful for debugging/loading in PyTorch directly), then always produce
+    # the compressed quantized artifact (args.quant_scheme + args.compressor, int5_odd+zlib by
+    # default) and validate the round-tripped weights.
+
+    if master_process:
+        torch.save(base_model.state_dict(), "final_model.pt")
+        model_bytes = os.path.getsize("final_model.pt")
+        code_bytes = len(code.encode("utf-8"))
+        log0(f"Serialized model: {model_bytes} bytes")
+        log0(f"Code size: {code_bytes} bytes")
+        log0(f"Total submission size: {model_bytes + code_bytes} bytes")
+
+    quant_obj, quant_stats = quantize_state_dict(base_model.state_dict(), args.quant_scheme)
+    quant_buf = io.BytesIO()
+    torch.save(quant_obj, quant_buf)
+    quant_raw = quant_buf.getvalue()
+    quant_blob = compress_blob(quant_raw, args.compressor, args.compressor_level)
+    quant_raw_bytes = len(quant_raw)
+    quant_filename = f"final_model.{args.quant_scheme}.{args.compressor}.ptc"
+    if master_process:
+        with open(quant_filename, "wb") as f:
+            f.write(quant_blob)
+    quant_file_bytes = os.path.getsize(quant_filename)
+    code_bytes = len(code.encode("utf-8"))
+    ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1)
+    log0(
+        f"Serialized model {args.quant_scheme}+{args.compressor}: {quant_file_bytes} bytes "
+
f"(payload:{quant_stats['int8_payload_bytes']} raw_torch:{quant_raw_bytes} payload_ratio:{ratio:.2f}x)" + ) + log0(f"Total submission size {args.quant_scheme}+{args.compressor}: {quant_file_bytes + code_bytes} bytes") + + if distributed: + dist.barrier() + with open(quant_filename, "rb") as f: + quant_blob_disk = f.read() + quant_state = torch.load(io.BytesIO(decompress_blob(quant_blob_disk, args.compressor)), map_location="cpu") + base_model.load_state_dict(dequantize_state_dict(quant_state), strict=True) + torch.cuda.synchronize() + t_qeval = time.perf_counter() + q_val_loss, q_val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"final_{args.quant_scheme}_{args.compressor}_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" + ) + log0(f"final_{args.quant_scheme}_{args.compressor}_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + + if distributed: + dist.destroy_process_group() + + +if __name__ == "__main__": + main() + +==================================================================================================== +Running Python 3.10.12 (main, Mar 3 2026, 11:56:32) [GCC 11.4.0] +Running PyTorch 2.11.0+cu130 +Tue Mar 31 02:44:46 2026 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 590.48.01 Driver Version: 590.48.01 CUDA Version: 13.1 | ++-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. 
| +|=========================================+========================+======================| +| 0 NVIDIA GeForce RTX 3060 Off | 00000000:01:00.0 On | N/A | +| 0% 42C P8 18W / 170W | 214MiB / 12288MiB | 23% Default | +| | | N/A | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| 0 N/A N/A 975 G /usr/lib/xorg/Xorg 89MiB | +| 0 N/A N/A 4060 C python3 104MiB | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=/home/hz/Desktop/parameter-golf-main/data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:1 +val_loader:shards pattern=/home/hz/Desktop/parameter-golf-main/data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 +model_params:17059912 +world_size:1 grad_accum_steps:8 +enable_compile:False +sdp_backends:cudnn=False flash=True mem_efficient=False math=False +attention_mode:gqa num_heads:8 num_kv_heads:4 +quant_scheme:int5_odd compressor:zlib level:9 structured_mlp:False structured_linear_type:btt_mlp +rate_lambda:0.01 scale_lambda:0.0002 rate_warmup_start_fraction:0.1 rate_full_fraction:0.4 rate_target_tensors:mlp_only +tie_embeddings:True embed_lr:0.05 head_lr:0.0 matrix_lr:0.04 scalar_lr:0.04 +train_batch_tokens:8192 train_seq_len:256 iterations:2 warmup_steps:0 max_wallclock_seconds:600.000 +seed:1337 +step:1/2 train_loss:6.9448 ce_loss:6.9448 rate_mul:0.000 rate_proxy_bits:0.0000 scale_proxy:0.000000 est_model_bits:0 train_time:1128ms step_avg:1128.30ms +step:2/2 train_loss:18.9179 ce_loss:18.8956 
rate_mul:1.000 rate_proxy_bits:2.2353 scale_proxy:0.011464 est_model_bits:21094658 train_time:1810ms step_avg:905.15ms
diff --git a/records/track_non_record_16mb/2026-03-31_EntropyAware_Int5Odd_BTTMLP_3060/logs/scout_R128_L12.txt b/records/track_non_record_16mb/2026-03-31_EntropyAware_Int5Odd_BTTMLP_3060/logs/scout_R128_L12.txt
new file mode 100644
index 000000000..60ceb4a82
--- /dev/null
+++ b/records/track_non_record_16mb/2026-03-31_EntropyAware_Int5Odd_BTTMLP_3060/logs/scout_R128_L12.txt
@@ -0,0 +1,1668 @@
+"""
+The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good jumping-off points for new participants, not SOTA configs. We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder.
+
+Hard stop: to keep things readable for newcomers, let's make sure `train_gpt.py` and `train_gpt_mlx.py` are never longer than 1500 lines.
+"""
+
+from __future__ import annotations
+
+import copy
+import glob
+import io
+import math
+import os
+import random
+import subprocess
+import sys
+import time
+import uuid
+import zlib
+from pathlib import Path
+
+import numpy as np
+import sentencepiece as spm
+import torch
+import torch.distributed as dist
+import torch.nn.functional as F
+from torch import Tensor, nn
+from torch.nn.parallel import DistributedDataParallel as DDP
+
+try:
+    import zstandard as zstd
+
+    _HAS_ZSTD = True
+except ImportError:
+    zstd = None
+    _HAS_ZSTD = False
+
+REPO_ROOT = Path(__file__).resolve().parents[3]
+
+# -----------------------------
+# HYPERPARAMETERS
+# -----------------------------
+# Default configuration for this script (all values env-overridable):
+# - 9 transformer blocks at width 512
+# - 8 attention heads with 4 KV heads (GQA) and 2x MLP expansion
+# - vocab size 1024, sequence length 256, tied embeddings
+# - 16,384 train tokens per step for 40 iterations with a ~10 minute wallclock cap
+
+class Hyperparameters:
+    # Data paths are shard globs produced by the 
existing preprocessing pipeline.
+    data_path = os.environ.get("DATA_PATH", str(REPO_ROOT / "data/datasets/fineweb10B_sp1024"))
+    train_files = os.path.join(data_path, "fineweb_train_*.bin")
+    val_files = os.path.join(data_path, "fineweb_val_*.bin")
+    tokenizer_path = os.environ.get("TOKENIZER_PATH", str(REPO_ROOT / "data/tokenizers/fineweb_1024_bpe.model"))
+    run_id = os.environ.get("RUN_ID", "train")
+    seed = int(os.environ.get("SEED", 1337))
+
+    # Validation cadence and batch size. Validation covers the full fineweb_val split unless
+    # VAL_TOKEN_LIMIT > 0 caps the evaluated tokens (0 = full split).
+    val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 262_144))
+    val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000))
+    train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 10))
+    val_token_limit = int(os.environ.get("VAL_TOKEN_LIMIT", 1_048_576))
+
+    # Training length.
+    iterations = int(os.environ.get("ITERATIONS", 40))
+    warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 1200))
+    warmup_steps = int(os.environ.get("WARMUP_STEPS", 0))
+    train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 16_384))
+    train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 256))
+    max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0))
+    qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5))
+    enable_compile = bool(int(os.environ.get("ENABLE_COMPILE", "0")))
+
+    # Model shape.
+    vocab_size = int(os.environ.get("VOCAB_SIZE", 1024))
+    num_layers = int(os.environ.get("NUM_LAYERS", 9))
+    num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4))
+    model_dim = int(os.environ.get("MODEL_DIM", 512))
+    num_heads = int(os.environ.get("NUM_HEADS", 8))
+    mlp_mult = int(os.environ.get("MLP_MULT", 2))
+    tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1")))
+    rope_base = float(os.environ.get("ROPE_BASE", 10000.0))
+    logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0))
+
+    # Optimizer hyperparameters.
+ embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.03)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.02)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.02)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.0)) + + # Export / compression. + quant_scheme = os.environ.get("QUANT_SCHEME", "int5_odd") + compressor = os.environ.get("COMPRESSOR", "zlib") + compressor_level = int(os.environ.get("COMPRESSOR_LEVEL", 9)) + + # Entropy-aware rate regularization. + rate_lambda = float(os.environ.get("RATE_LAMBDA", 0.002)) + scale_lambda = float(os.environ.get("SCALE_LAMBDA", 0.0002)) + rate_temperature = float(os.environ.get("RATE_TEMPERATURE", 6.0)) + rate_warmup_start_fraction = float(os.environ.get("RATE_WARMUP_START_FRACTION", 0.1)) + rate_full_fraction = float(os.environ.get("RATE_FULL_FRACTION", 0.4)) + rate_target_tensors = os.environ.get("RATE_TARGET_TENSORS", "mlp_only") + + # Structured MLP. 
+ structured_mlp = bool(int(os.environ.get("STRUCTURED_MLP", "1"))) + structured_linear_type = os.environ.get("STRUCTURED_LINEAR_TYPE", "btt_mlp") + btt_rank = int(os.environ.get("BTT_RANK", 32)) + btt_in_shape = os.environ.get("BTT_IN_SHAPE", "") + btt_out_shape = os.environ.get("BTT_OUT_SHAPE", "") + btt_init = os.environ.get("BTT_INIT", "mup") + btt_init_gain = float(os.environ.get("BTT_INIT_GAIN", 1.0)) + compile_structured_mlp = bool(int(os.environ.get("COMPILE_STRUCTURED_MLP", "1"))) + +# ----------------------------- +# MUON OPTIMIZER +# ----------------------------- +# +# As borrowed from modded-nanogpt +# Background on Muon: https://kellerjordan.github.io/posts/muon/ + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + # Orthogonalize a 2D update matrix with a fast Newton-Schulz iteration. + # Muon uses this to normalize matrix-shaped gradients before applying them. + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X + + +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True): + super().__init__( + params, + dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov), + ) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + + for group in self.param_groups: + params = group["params"] + if not params: + continue + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + + total_params = 
sum(int(p.numel()) for p in params) + updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) + + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: + g = g.add(buf, alpha=momentum) + g = zeropower_via_newtonschulz5(g, steps=backend_steps) + # Scale correction from Muon reference implementations. + g *= max(1, g.size(0) / g.size(1)) ** 0.5 + updates_flat[curr : curr + p.numel()] = g.reshape(-1) + curr += p.numel() + + if distributed: + dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) + + curr = 0 + for p in params: + g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) + p.add_(g, alpha=-lr) + curr += p.numel() + + return loss + + +# ----------------------------- +# TOKENIZER-AGNOSTIC EVALUATION SETUP +# ----------------------------- +# +# It's common for small models have a large fraction of their parameters be embeddings, since the 2 * d_model * d_vocab vectors can be gigantic. +# Instead of locking the tokenizer, we let you bring your own and calculate our validation metrics on the average compression of the validation set. +# We calculate BPB (bits-per-byte) instead of validation loss, so we need methods to count the number of bits per token in the tokenizer. +# Note: Submissions that edit the tokenizer will be examined more carefully, since screwing this up might unjustly improve your score. 
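The tokenizer-agnostic BPB metric described above reduces to simple arithmetic once the token and byte counts are accumulated. A minimal sketch of the conversion performed at the end of `eval_val` (function name and arguments here are illustrative, not part of the training script):

```python
import math

def bits_per_byte(val_loss_nats: float, token_count: float, byte_count: float) -> float:
    # Mean token cross-entropy is in nats; divide by ln(2) to get bits/token,
    # then rescale by tokens-per-byte to get the compression rate in bits/byte.
    bits_per_token = val_loss_nats / math.log(2.0)
    tokens_per_byte = token_count / byte_count
    return bits_per_token * tokens_per_byte

# A model at 4 bits/token on text averaging 4 bytes/token compresses at 1 bit/byte.
example_bpb = bits_per_byte(math.log(2.0) * 4.0, token_count=100.0, byte_count=400.0)
```

This is why a fatter tokenizer does not automatically help: fewer tokens per byte raises `byte_count` relative to `token_count`, which offsets a lower per-token loss.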
+ +def build_sentencepiece_luts( + sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device +) -> tuple[Tensor, Tensor, Tensor]: + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("▁"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) + + +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*. 
+ tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] + + +def eval_val( + args: Hyperparameters, + model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + grad_accum_steps: int, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, +) -> tuple[float, float]: + # Validation computes two metrics: + # - val_loss: token cross-entropy (natural log) + # - val_bpb: tokenizer-agnostic compression metric used by the challenge + local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) + if local_batch_tokens < args.train_seq_len: + raise ValueError( + "VAL_BATCH_SIZE must provide at least one sequence per rank; " + f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " + f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}" + ) + local_batch_seqs = local_batch_tokens // args.train_seq_len + total_seqs = (val_tokens.numel() - 1) // args.train_seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + + model.eval() + with torch.inference_mode(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * args.train_seq_len + raw_end = batch_seq_end * args.train_seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, args.train_seq_len) + y = local[1:].reshape(-1, args.train_seq_len) + with 
torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + batch_loss = model(x, y).detach() + batch_token_count = float(y.numel()) + val_loss_sum += batch_loss.to(torch.float64) * batch_token_count + val_token_count += batch_token_count + prev_ids = x.reshape(-1) + tgt_ids = y.reshape(-1) + token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + + val_loss = val_loss_sum / val_token_count + bits_per_token = val_loss.item() / math.log(2.0) + tokens_per_byte = val_token_count.item() / val_byte_count.item() + model.train() + return float(val_loss.item()), float(bits_per_token * tokens_per_byte) + +# ----------------------------- +# POST-TRAINING QUANTIZATION +# ----------------------------- +# +# It's silly to export our model, which is trained in bf16 and fp32, at that same precision. +# Instead, we get approximately the same model (with a small hit) by quantizing the model to int8 & zlib compressing. +# We can then decompress the model and run in higher precision for evaluation, after closing in under the size limit. 
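For the `int5_odd` export, three 5-level codes fit in one byte since `c0 + 5*c1 + 25*c2 <= 124`. A NumPy-only sketch of the same pack/unpack roundtrip implemented below by `pack_base5_codes` / `unpack_base5_codes` (stripped of the torch tensor glue):

```python
import numpy as np

def pack_base5(codes: np.ndarray) -> tuple[bytes, int]:
    # codes holds shifted int5_odd symbols in {0..4}; pad to a multiple of 3
    # and pack each triple into a single byte in little-endian base-5 order.
    n = int(codes.size)
    pad = (-n) % 3
    groups = np.pad(codes.astype(np.uint8), (0, pad)).reshape(-1, 3)
    packed = groups[:, 0] + 5 * groups[:, 1] + 25 * groups[:, 2]
    return packed.astype(np.uint8).tobytes(), n

def unpack_base5(data: bytes, n: int) -> np.ndarray:
    packed = np.frombuffer(data, dtype=np.uint8).astype(np.int16)
    out = np.empty((packed.size, 3), dtype=np.uint8)
    for i in range(3):
        out[:, i] = packed % 5   # peel off one base-5 digit per pass
        packed = packed // 5
    return out.reshape(-1)[:n]

codes = np.array([0, 4, 2, 3, 1], dtype=np.uint8)  # {-2..2} shifted by +2
blob, n = pack_base5(codes)
restored = unpack_base5(blob, n)
```

Packing before `zlib` matters: the entropy-aware regularizer pushes codes toward a skewed histogram, and the packed byte stream exposes that redundancy to the final compressor at 3 symbols per byte instead of 1.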
+ +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights", + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 +INT5_ODD_LEVELS = (-2, -1, 0, 1, 2) + +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) + + +def parse_factor_shape(spec: str, dim: int) -> tuple[int, int]: + if spec: + vals = tuple(int(v.strip()) for v in spec.split(",") if v.strip()) + if len(vals) != 2 or vals[0] * vals[1] != dim: + raise ValueError(f"Invalid factor shape '{spec}' for dim={dim}") + return vals + for a in range(int(math.isqrt(dim)), 0, -1): + if dim % a == 0: + return a, dim // a + return 1, dim + +def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: + if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): + return t.float().contiguous() + if t.dtype in {torch.float32, torch.bfloat16}: + passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") + return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() + return t + +def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + # Matrices get one scale per row, which usually tracks output-channel + # ranges much better than a single tensor-wide scale. 
+ clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) + q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() + return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + + # Vectors / scalars use a simpler per-tensor scale. + clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() + return q, scale + + +def pack_base5_codes(q: Tensor) -> tuple[bytes, int]: + vals = q.detach().to("cpu", dtype=torch.uint8).reshape(-1).numpy() + n_vals = int(vals.size) + pad = (3 - (n_vals % 3)) % 3 + if pad: + vals = np.pad(vals, (0, pad)) + groups = vals.reshape(-1, 3).astype(np.uint8) + packed = groups[:, 0] + 5 * groups[:, 1] + 25 * groups[:, 2] + return packed.tobytes(), n_vals + + +def unpack_base5_codes(data: bytes, n_vals: int) -> Tensor: + packed = np.frombuffer(data, dtype=np.uint8).astype(np.int16) + out = np.zeros((packed.size, 3), dtype=np.uint8) + for i in range(3): + out[:, i] = packed % 5 + packed //= 5 + flat = out.reshape(-1)[:n_vals] + return torch.from_numpy(flat.astype(np.uint8, copy=False)) + + +def quantize_float_tensor_int5_odd(t: Tensor) -> tuple[Tensor, Tensor]: + if t.ndim != 2: + raise ValueError("int5_odd quantization only supports 2D tensors") + t32 = t.float() + clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clip_abs = clip_abs.clamp_min(1e-6) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 
2.0).clamp_min(1.0 / 2.0e6) + q = torch.clamp(torch.round(clipped / scale[:, None]), -2, 2).to(torch.int8).contiguous() + return (q + 2).to(torch.uint8).contiguous(), scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + +def quantize_state_dict_int8(state_dict: dict[str, Tensor]): + # Single supported clean-script export format: + # - per-row int8 for 2D float tensors + # - per-tensor int8 for other float tensors + # - exact passthrough for non-floats + # - passthrough for small float tensors, stored as fp16 to save bytes + quantized: dict[str, Tensor] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + qmeta: dict[str, dict[str, object]] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + + # Small float tensors are cheap enough to keep directly. We still downcast + # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size. 
+ if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + + stats["num_float_tensors"] += 1 + q, s = quantize_float_tensor(t) + if s.ndim > 0: + qmeta[name] = {"scheme": "per_row", "axis": 0} + quantized[name] = q + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) + + obj: dict[str, object] = { + "__quant_format__": "int8_clean_per_row_v1", + "quantized": quantized, + "scales": scales, + "dtypes": dtypes, + "passthrough": passthrough, + } + if qmeta: + obj["qmeta"] = qmeta + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats + + +def quantize_state_dict_int5_odd(state_dict: dict[str, Tensor]): + packed: dict[str, bytes] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + shapes: dict[str, tuple[int, ...]] = {} + orig_shapes: dict[str, tuple[int, ...]] = {} + n_codes: dict[str, int] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + + quantize_this = t.ndim >= 2 and (t.numel() > INT8_KEEP_FLOAT_MAX_NUMEL or ".mlp." 
in name) + if not quantize_this: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + + stats["num_float_tensors"] += 1 + t2 = t.reshape(t.shape[0], -1) if t.ndim > 2 else t + q, s = quantize_float_tensor_int5_odd(t2) + packed_bytes, count = pack_base5_codes(q) + packed[name] = packed_bytes + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + shapes[name] = tuple(int(v) for v in t2.shape) + orig_shapes[name] = tuple(int(v) for v in t.shape) + n_codes[name] = count + stats["int8_payload_bytes"] += len(packed_bytes) + tensor_nbytes(s) + + obj: dict[str, object] = { + "__quant_format__": "int5_odd_clean_per_row_v1", + "packed": packed, + "scales": scales, + "dtypes": dtypes, + "shapes": shapes, + "orig_shapes": orig_shapes, + "n_codes": n_codes, + "passthrough": passthrough, + } + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats + +def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + qmeta = obj.get("qmeta", {}) + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, q in obj["quantized"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + s = obj["scales"][name] + if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: + s = s.to(dtype=torch.float32) + # Broadcast the saved row scale back across trailing dimensions. + out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() + else: + scale = float(s.item()) + out[name] = (q.float() * scale).to(dtype=dtype).contiguous() + for name, t in obj["passthrough"].items(): + # Restore small tensors, undoing the temporary fp16 storage cast if needed. 
+ out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out + + +def dequantize_state_dict_int5_odd(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, packed in obj["packed"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + shape = tuple(int(v) for v in obj["shapes"][name]) + orig_shape = tuple(int(v) for v in obj.get("orig_shapes", {}).get(name, shape)) + q = unpack_base5_codes(packed, int(obj["n_codes"][name])).to(torch.float32).reshape(shape) - 2.0 + s = obj["scales"][name].to(dtype=torch.float32) + deq = (q * s.view(shape[0], *([1] * (len(shape) - 1)))).to(dtype=dtype) + out[name] = deq.reshape(orig_shape).contiguous() + for name, t in obj["passthrough"].items(): + out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out + + +def quantize_state_dict(state_dict: dict[str, Tensor], quant_scheme: str): + if quant_scheme == "int8": + return quantize_state_dict_int8(state_dict) + if quant_scheme == "int5_odd": + return quantize_state_dict_int5_odd(state_dict) + raise ValueError(f"Unsupported QUANT_SCHEME={quant_scheme}") + + +def dequantize_state_dict(obj: dict[str, object]) -> dict[str, Tensor]: + quant_format = obj.get("__quant_format__") + if quant_format == "int8_clean_per_row_v1": + return dequantize_state_dict_int8(obj) + if quant_format == "int5_odd_clean_per_row_v1": + return dequantize_state_dict_int5_odd(obj) + raise ValueError(f"Unsupported quant format {quant_format}") + + +def compress_blob(data: bytes, compressor: str, level: int) -> bytes: + if compressor == "zlib": + return zlib.compress(data, level=max(0, 
min(level, 9))) + if compressor == "zstd": + if not _HAS_ZSTD: + raise RuntimeError("COMPRESSOR=zstd requires zstandard to be installed") + return zstd.ZstdCompressor(level=level).compress(data) + raise ValueError(f"Unsupported COMPRESSOR={compressor}") + + +def decompress_blob(data: bytes, compressor: str) -> bytes: + if compressor == "zlib": + return zlib.decompress(data) + if compressor == "zstd": + if not _HAS_ZSTD: + raise RuntimeError("COMPRESSOR=zstd requires zstandard to be installed") + return zstd.ZstdDecompressor().decompress(data) + raise ValueError(f"Unsupported COMPRESSOR={compressor}") + + +def should_rate_regularize_tensor(name: str, p: Tensor, target: str) -> bool: + if not p.requires_grad or not p.is_floating_point() or p.ndim < 2: + return False + if any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS): + return False + if target == "mlp_only": + return ".mlp." in name + if target == "all_matrices": + return p.numel() > INT8_KEEP_FLOAT_MAX_NUMEL or ".mlp." in name + raise ValueError(f"Unsupported RATE_TARGET_TENSORS={target}") + + +def estimate_tensor_rate_proxy(t: Tensor, temperature: float) -> tuple[Tensor, Tensor, Tensor]: + t2 = t.float() + if t2.ndim > 2: + t2 = t2.reshape(t2.shape[0], -1) + elif t2.ndim == 1: + t2 = t2.reshape(1, -1) + row_scale = t2.abs().mean(dim=1, keepdim=True).clamp_min(1e-6) + normalized = torch.clamp(t2 / row_scale, -2.5, 2.5) + centers = t2.new_tensor(INT5_ODD_LEVELS) + logits = -temperature * (normalized.unsqueeze(-1) - centers) ** 2 + probs = torch.softmax(logits, dim=-1) + hist = probs.mean(dim=(0, 1)) + entropy_nats = -(hist * (hist + 1e-8).log()).sum() + entropy_bits = entropy_nats / math.log(2.0) + scale_proxy = row_scale.mean() + est_bits = entropy_bits * float(t2.numel()) + return entropy_bits, scale_proxy, est_bits + + +# ----------------------------- +# DATA LOADING +# ----------------------------- + +def load_data_shard(file: Path) -> Tensor: + header_bytes = 256 * np.dtype(" None: + 
self.file_idx = (self.file_idx + 1) % len(self.files) + self.tokens = load_data_shard(self.files[self.file_idx]) + self.pos = 0 + + def take(self, n: int) -> Tensor: + chunks: list[Tensor] = [] + remaining = n + while remaining > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: + self._advance_file() + continue + k = min(remaining, avail) + chunks.append(self.tokens[self.pos : self.pos + k]) + self.pos += k + remaining -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) + + +class DistributedTokenLoader: + # Each call consumes a contiguous chunk from the shared token stream, then slices out + # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting. + def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank = rank + self.world_size = world_size + self.device = device + self.stream = TokenStream(pattern) + + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + chunk = self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start : start + per_rank_span].to(dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + +# ----------------------------- +# TRANSFORMER MODULES +# ----------------------------- + +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + + +class CastedLinear(nn.Linear): + # Keep weights in fp32 for optimizer/state quality, cast at matmul time for bf16 compute. 
+ def forward(self, x: Tensor) -> Tensor: + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, self.weight.to(x.dtype), bias) + + +class StructuredLinear(nn.Module): + # Lightweight TT-style structured matrix used only in the MLP stack. + def __init__( + self, + in_features: int, + out_features: int, + bias: bool, + linear_type: str, + btt_rank: int, + in_shape_spec: str, + out_shape_spec: str, + btt_init: str, + btt_init_gain: float, + ): + super().__init__() + self.in_features = in_features + self.out_features = out_features + self.linear_type = linear_type + self.rank = max(int(btt_rank), 1) + self.btt_init = btt_init + self.btt_init_gain = btt_init_gain + if linear_type == "dense": + self.impl = CastedLinear(in_features, out_features, bias=bias) + self.register_parameter("core1", None) + self.register_parameter("core2", None) + return + if linear_type != "btt_mlp": + raise ValueError(f"Unsupported STRUCTURED_LINEAR_TYPE={linear_type}") + self.impl = None + self.in_shape = parse_factor_shape(in_shape_spec, in_features) + self.out_shape = parse_factor_shape(out_shape_spec, out_features) + n1, n2 = self.in_shape + m1, m2 = self.out_shape + self.core1 = nn.Parameter(torch.empty(n1, m1, self.rank)) + self.core2 = nn.Parameter(torch.empty(self.rank, n2, m2)) + if bias: + self.bias = nn.Parameter(torch.zeros(out_features)) + else: + self.register_parameter("bias", None) + self.reset_parameters() + + def reset_parameters(self) -> None: + if self.impl is not None: + self.impl.reset_parameters() + return + n1, n2 = self.in_shape + if self.btt_init == "mup": + # muP-inspired scaling: keep the per-rank branch update stable while + # preserving O(1) output variance under the sum over TT ranks. 
+ core1_std = self.btt_init_gain / math.sqrt(max(n1, 1)) + core2_std = self.btt_init_gain / math.sqrt(max(self.rank * n2, 1)) + elif self.btt_init == "xavier": + std = self.btt_init_gain / math.sqrt(max(self.in_features, 1)) + core1_std = std / math.sqrt(self.rank) + core2_std = std + else: + raise ValueError(f"Unsupported BTT_INIT={self.btt_init}") + nn.init.normal_(self.core1, mean=0.0, std=core1_std) + nn.init.normal_(self.core2, mean=0.0, std=core2_std) + if self.bias is not None: + nn.init.zeros_(self.bias) + + def forward(self, x: Tensor) -> Tensor: + if self.impl is not None: + return self.impl(x) + x_dtype = x.dtype + orig_shape = x.shape[:-1] + x2 = x.reshape(-1, self.in_shape[0], self.in_shape[1]) + core1 = self.core1.to(dtype=x_dtype) + core2 = self.core2.to(dtype=x_dtype) + y = x2.new_zeros((x2.shape[0], self.out_shape[0], self.out_shape[1])) + for r in range(self.rank): + left = torch.einsum("buv,um->bmv", x2, core1[:, :, r]) + y = y + torch.einsum("bmv,vn->bmn", left, core2[r]) + y = y.reshape(*orig_shape, self.out_features) + if self.bias is not None: + y = y + self.bias.to(dtype=y.dtype) + return y + + +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + # Keep small/control parameters in fp32 even when the model body runs in bf16. + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() + + +class Rotary(nn.Module): + # Caches cos/sin tables per sequence length on the current device. 
+ def __init__(self, dim: int, base: float = 10000.0): + super().__init__() + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached != seq_len + or self._cos_cached.device != device + ): + t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) + freqs = torch.outer(t, self.inv_freq.to(device)) + self._cos_cached = freqs.cos()[None, None, :, :] + self._sin_cached = freqs.sin()[None, None, :, :] + self._seq_len_cached = seq_len + cos = self._cos_cached.to(dtype=dtype) + sin = self._sin_cached.to(dtype=dtype) + if not torch.is_inference_mode_enabled(): + cos = cos.clone() + sin = sin.clone() + return cos, sin + + +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + + +class CausalSelfAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + kv_dim = self.num_kv_heads * self.head_dim + self.c_q = CastedLinear(dim, dim, bias=False) + self.c_k = CastedLinear(dim, kv_dim, bias=False) + self.c_v = CastedLinear(dim, kv_dim, bias=False) + self.proj 
= CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rotary = Rotary(self.head_dim, base=rope_base) + + def forward(self, x: Tensor) -> Tensor: + bsz, seqlen, dim = x.shape + q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) + k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin) + k = apply_rotary_emb(k, cos, sin) + q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] + y = F.scaled_dot_product_attention( + q, + k, + v, + attn_mask=None, + is_causal=True, + enable_gqa=(self.num_kv_heads != self.num_heads), + ) + y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) + return self.proj(y) + + +class MLP(nn.Module): + # relu^2 MLP from the original modded-nanogpt setup + def __init__( + self, + dim: int, + mlp_mult: int, + structured_mlp: bool, + structured_linear_type: str, + btt_rank: int, + btt_in_shape: str, + btt_out_shape: str, + btt_init: str, + btt_init_gain: float, + compile_structured_mlp: bool, + ): + super().__init__() + hidden = mlp_mult * dim + if structured_mlp: + self.fc = StructuredLinear( + dim, + hidden, + bias=False, + linear_type=structured_linear_type, + btt_rank=btt_rank, + in_shape_spec=btt_in_shape, + out_shape_spec=btt_out_shape, + btt_init=btt_init, + btt_init_gain=btt_init_gain, + ) + self.proj = StructuredLinear( + hidden, + dim, + bias=False, + linear_type=structured_linear_type, + btt_rank=btt_rank, + in_shape_spec=btt_out_shape, + out_shape_spec=btt_in_shape, + btt_init=btt_init, + btt_init_gain=btt_init_gain, + ) + if compile_structured_mlp: + self.fc = torch.compile(self.fc, mode="reduce-overhead", 
fullgraph=False, dynamic=False) + self.proj = torch.compile(self.proj, mode="reduce-overhead", fullgraph=False, dynamic=False) + else: + self.fc = CastedLinear(dim, hidden, bias=False) + self.proj = CastedLinear(hidden, dim, bias=False) + self.proj._zero_init = True + + def forward(self, x: Tensor) -> Tensor: + x = torch.relu(self.fc(x)) + return self.proj(x.square()) + + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + rope_base: float, + qk_gain_init: float, + structured_mlp: bool, + structured_linear_type: str, + btt_rank: int, + btt_in_shape: str, + btt_out_shape: str, + btt_init: str, + btt_init_gain: float, + compile_structured_mlp: bool, + ): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) + self.mlp = MLP( + dim, + mlp_mult, + structured_mlp, + structured_linear_type, + btt_rank, + btt_in_shape, + btt_out_shape, + btt_init, + btt_init_gain, + compile_structured_mlp, + ) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + + def forward(self, x: Tensor, x0: Tensor) -> Tensor: + mix = self.resid_mix.to(dtype=x.dtype) + x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + attn_out = self.attn(self.attn_norm(x)) + x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out + x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) + return x + + +class GPT(nn.Module): + def __init__( + self, + vocab_size: int, + num_layers: int, + model_dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + tie_embeddings: bool, + tied_embed_init_std: float, + logit_softcap: float, + rope_base: float, + qk_gain_init: float, + structured_mlp: bool, + 
structured_linear_type: str, + btt_rank: int, + btt_in_shape: str, + btt_out_shape: str, + btt_init: str, + btt_init_gain: float, + compile_structured_mlp: bool, + rate_target_tensors: str, + rate_temperature: float, + ): + super().__init__() + if logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") + self.tie_embeddings = tie_embeddings + self.tied_embed_init_std = tied_embed_init_std + self.logit_softcap = logit_softcap + self.rate_target_tensors = rate_target_tensors + self.rate_temperature = rate_temperature + self.tok_emb = nn.Embedding(vocab_size, model_dim) + self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + self.blocks = nn.ModuleList( + [ + Block( + model_dim, + num_heads, + num_kv_heads, + mlp_mult, + rope_base, + qk_gain_init, + structured_mlp, + structured_linear_type, + btt_rank, + btt_in_shape, + btt_out_shape, + btt_init, + btt_init_gain, + compile_structured_mlp, + ) + for i in range(num_layers) + ] + ) + self.final_norm = RMSNorm() + self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) + if self.lm_head is not None: + self.lm_head._zero_init = True + self._init_weights() + + def _init_weights(self) -> None: + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + for module in self.modules(): + if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + if isinstance(module, StructuredLinear) and getattr(module, "_zero_init", False): + if module.impl is not None: + nn.init.zeros_(module.impl.weight) + else: + nn.init.zeros_(module.core2) + + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + x = 
self.tok_emb(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x0 = x + skips: list[Tensor] = [] + + # First half stores skips; second half reuses them in reverse order. + for i in range(self.num_encoder_layers): + x = self.blocks[i](x, x0) + skips.append(x) + for i in range(self.num_decoder_layers): + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + x = self.blocks[self.num_encoder_layers + i](x, x0) + + x = self.final_norm(x).reshape(-1, x.size(-1)) + targets = target_ids.reshape(-1) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(x) + logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + return F.cross_entropy(logits.float(), targets, reduction="mean") + + def rate_regularizer(self) -> tuple[Tensor, Tensor, Tensor]: + total_symbols = 0.0 + weighted_rate = None + weighted_scale = None + estimated_bits = None + for name, param in self.named_parameters(): + if not should_rate_regularize_tensor(name, param, self.rate_target_tensors): + continue + rate_proxy, scale_proxy, tensor_bits = estimate_tensor_rate_proxy(param, self.rate_temperature) + weight = float(param.numel()) + total_symbols += weight + if weighted_rate is None: + weighted_rate = rate_proxy * weight + weighted_scale = scale_proxy * weight + estimated_bits = tensor_bits + else: + weighted_rate = weighted_rate + rate_proxy * weight + weighted_scale = weighted_scale + scale_proxy * weight + estimated_bits = estimated_bits + tensor_bits + if weighted_rate is None or weighted_scale is None or estimated_bits is None: + zero = self.tok_emb.weight.new_zeros(()) + return zero, zero, zero + denom = max(total_symbols, 1.0) + return weighted_rate / denom, weighted_scale / denom, estimated_bits + + +# ----------------------------- +# TRAINING +# ----------------------------- + +def main() -> None: 
+ global zeropower_via_newtonschulz5 + + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + if args.enable_compile: + zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) + + # ----------------------------- + # DISTRIBUTED + CUDA SETUP + # ----------------------------- + + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if 8 % world_size != 0: + raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") + grad_accum_steps = 8 // world_size + grad_scale = 1.0 / grad_accum_steps + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 0 + + # Fast math knobs + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + + logfile = None + if master_process: + if args.run_id == "train": + logfile = "train.log" + else: + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" + print(logfile) + + def log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch 
{torch.__version__}", console=False) + log0( + subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, + console=False, + ) + log0("=" * 100, console=False) + + # ----------------------------- + # TOKENIZER + VALIDATION METRIC SETUP + # ----------------------------- + + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) + if args.val_token_limit > 0: + usable = min(((args.val_token_limit) // args.train_seq_len) * args.train_seq_len, val_tokens.numel() - 1) + if usable <= 0: + raise ValueError(f"VAL_TOKEN_LIMIT={args.val_token_limit} is too small for TRAIN_SEQ_LEN={args.train_seq_len}") + val_tokens = val_tokens[: usable + 1].contiguous() + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + if args.val_token_limit > 0: + log0(f"val_token_limit:{args.val_token_limit}") + + # ----------------------------- + # MODEL + OPTIMIZER SETUP + # ----------------------------- + + base_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + 
model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, + structured_mlp=args.structured_mlp, + structured_linear_type=args.structured_linear_type if args.structured_mlp else "dense", + btt_rank=args.btt_rank, + btt_in_shape=args.btt_in_shape, + btt_out_shape=args.btt_out_shape, + btt_init=args.btt_init, + btt_init_gain=args.btt_init_gain, + compile_structured_mlp=args.compile_structured_mlp, + rate_target_tensors=args.rate_target_tensors, + rate_temperature=args.rate_temperature, + ).to(device).bfloat16() + for module in base_model.modules(): + if isinstance(module, (CastedLinear, StructuredLinear)): + module.float() + restore_low_dim_params_to_fp32(base_model) + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) if args.enable_compile else base_model + model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + + # Optimizer split: + # - token embedding (Adam) uses EMBED_LR + # - untied lm_head (Adam) uses HEAD_LR + # - matrix params in transformer blocks use MATRIX_LR via Muon + # - vectors/scalars use SCALAR_LR via Adam + block_named_params = list(base_model.blocks.named_parameters()) + matrix_params = [ + p + for name, p in block_named_params + if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + scalar_params = [ + p + for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + optimizer_tok = torch.optim.Adam( + [{"params": [base_model.tok_emb.weight], "lr": token_lr, 
"base_lr": token_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizer_muon = Muon( + matrix_params, + lr=args.matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + ) + for group in optimizer_muon.param_groups: + group["base_lr"] = args.matrix_lr + optimizer_scalar = torch.optim.Adam( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers.insert(1, optimizer_head) + + n_params = sum(p.numel() for p in base_model.parameters()) + log0(f"model_params:{n_params}") + log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") + log0(f"enable_compile:{args.enable_compile}") + log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") + log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") + log0( + f"quant_scheme:{args.quant_scheme} compressor:{args.compressor} level:{args.compressor_level} " + f"structured_mlp:{args.structured_mlp} structured_linear_type:{args.structured_linear_type} " + f"btt_rank:{args.btt_rank} btt_init:{args.btt_init} compile_structured_mlp:{args.compile_structured_mlp}" + ) + log0( + f"rate_lambda:{args.rate_lambda} scale_lambda:{args.scale_lambda} " + f"rate_warmup_start_fraction:{args.rate_warmup_start_fraction} rate_full_fraction:{args.rate_full_fraction} " + f"rate_target_tensors:{args.rate_target_tensors}" + ) + log0( + f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " + f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " + f"matrix_lr:{args.matrix_lr} 
scalar_lr:{args.scalar_lr}" + ) + log0( + f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " + f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " + f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log0(f"seed:{args.seed}") + + # ----------------------------- + # DATA LOADER & MODEL WARMUP + # ----------------------------- + + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + + def lr_mul(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + + def rate_schedule(step: int) -> float: + if args.rate_lambda <= 0.0 and args.scale_lambda <= 0.0: + return 0.0 + if args.iterations <= 0: + return 1.0 + frac = step / max(args.iterations, 1) + start = args.rate_warmup_start_fraction + full = max(args.rate_full_fraction, start + 1e-6) + if frac <= start: + return 0.0 + if frac >= full: + return 1.0 + return (frac - start) / (full - start) + + # Warmup primes the compiled forward/backward/optimizer paths, then we restore the + # initial weights/optimizer state so measured training starts from the true init. 
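The `rate_schedule` helper above ramps the regularizer linearly between two fractions of training. A standalone sketch of that piecewise-linear ramp (constants mirror the defaults `RATE_WARMUP_START_FRACTION=0.1` and `RATE_FULL_FRACTION=0.4`):

```python
def rate_ramp(step: int, iterations: int, start_frac: float = 0.1, full_frac: float = 0.4) -> float:
    # 0 before start_frac, 1 after full_frac, linear in between
    # (mirrors rate_schedule in the training script).
    frac = step / max(iterations, 1)
    full = max(full_frac, start_frac + 1e-6)
    if frac <= start_frac:
        return 0.0
    if frac >= full:
        return 1.0
    return (frac - start_frac) / (full - start_frac)

for s in (0, 100, 250, 400, 1000):
    print(s, rate_ramp(s, 1000))
```

The warmup keeps the rate penalty off while the loss landscape is still settling, then phases it in so quantization pressure only shapes weights once training is productive.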
+ if args.warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(args.warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + for opt in optimizers: + opt.step() + zero_grad_all() + if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: + log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + if distributed: + model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + # ----------------------------- + # MAIN TRAINING LOOP + # ----------------------------- + + training_time_ms = 0.0 + stop_after_step: int | None = None + torch.cuda.synchronize() + t0 = time.perf_counter() + + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + log0( + 
f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" + ) + break + + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + zero_grad_all() + train_loss = torch.zeros((), device=device) + train_ce_loss = torch.zeros((), device=device) + train_rate_proxy = torch.zeros((), device=device) + train_scale_proxy = torch.zeros((), device=device) + train_est_bits = torch.zeros((), device=device) + rate_mul = rate_schedule(step) + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + ce_loss = model(x, y) + total_loss = ce_loss + if rate_mul > 0.0: + rate_proxy, scale_proxy, est_bits = base_model.rate_regularizer() + total_loss = total_loss + rate_mul * (args.rate_lambda * rate_proxy + args.scale_lambda * scale_proxy) + train_rate_proxy += rate_proxy.detach() + train_scale_proxy += scale_proxy.detach() + train_est_bits += est_bits.detach() + train_ce_loss += ce_loss.detach() + train_loss += total_loss.detach() + (total_loss * grad_scale).backward() + train_loss /= grad_accum_steps + train_ce_loss /= grad_accum_steps + train_rate_proxy /= grad_accum_steps + train_scale_proxy /= grad_accum_steps + train_est_bits /= grad_accum_steps + + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * 
args.muon_momentum + for group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + for opt in optimizers: + opt.step() + zero_grad_all() + + step += 1 + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + should_log_train = ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + log0( + f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} ce_loss:{train_ce_loss.item():.4f} " + f"rate_mul:{rate_mul:.3f} rate_proxy_bits:{train_rate_proxy.item():.4f} " + f"scale_proxy:{train_scale_proxy.item():.6f} est_model_bits:{train_est_bits.item():.0f} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" + ) + + # Needed to sync whether we've reached the wallclock cap. + reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + + log0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + ) + + # ----------------------------- + # SERIALIZATION + ROUNDTRIP VALIDATION + # ----------------------------- + # Save the raw state (useful for debugging/loading in PyTorch directly), then always produce + # the compressed `QUANT_SCHEME`+`COMPRESSOR` artifact (int5_odd+zlib by default here) and validate the round-tripped weights. 
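Why the exported payload compresses so well: symbols drawn from a 5-level grid carry at most log2(5) ≈ 2.32 bits each, so zlib can pack the int8-coded weights well below 1 byte per weight, while high-entropy bytes are essentially incompressible. A self-contained illustration on synthetic data (not the repo's tensors):

```python
import zlib
import numpy as np

rng = np.random.default_rng(0)
n = 100_000
# Uniform draws from the odd 5-level grid {-2,...,2}: <= log2(5) bits/symbol.
grid_codes = rng.integers(-2, 3, size=n).astype(np.int8)
# Uniform random bytes: ~8 bits/symbol, essentially incompressible.
raw_bytes = rng.integers(0, 256, size=n).astype(np.uint8)

grid_ratio = len(zlib.compress(grid_codes.tobytes(), 9)) / n
raw_ratio = len(zlib.compress(raw_bytes.tobytes(), 9)) / n
print(f"grid bytes/weight:   {grid_ratio:.3f}")
print(f"random bytes/weight: {raw_ratio:.3f}")
```

Real trained weights compress even better than this uniform-grid bound, because the rate regularizer pushes the symbol distribution toward zero and the entropy coder exploits the skew.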
+ + if master_process: + torch.save(base_model.state_dict(), "final_model.pt") + model_bytes = os.path.getsize("final_model.pt") + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model: {model_bytes} bytes") + log0(f"Code size: {code_bytes} bytes") + log0(f"Total submission size: {model_bytes + code_bytes} bytes") + + quant_obj, quant_stats = quantize_state_dict(base_model.state_dict(), args.quant_scheme) + quant_buf = io.BytesIO() + torch.save(quant_obj, quant_buf) + quant_raw = quant_buf.getvalue() + quant_blob = compress_blob(quant_raw, args.compressor, args.compressor_level) + quant_raw_bytes = len(quant_raw) + quant_filename = f"final_model.{args.quant_scheme}.{args.compressor}.ptc" + if master_process: + with open(quant_filename, "wb") as f: + f.write(quant_blob) + quant_file_bytes = os.path.getsize(quant_filename) + code_bytes = len(code.encode("utf-8")) + ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1) + log0( + f"Serialized model {args.quant_scheme}+{args.compressor}: {quant_file_bytes} bytes " + f"(payload:{quant_stats['int8_payload_bytes']} raw_torch:{quant_raw_bytes} payload_ratio:{ratio:.2f}x)" + ) + log0(f"Total submission size {args.quant_scheme}+{args.compressor}: {quant_file_bytes + code_bytes} bytes") + + if distributed: + dist.barrier() + with open(quant_filename, "rb") as f: + quant_blob_disk = f.read() + quant_state = torch.load(io.BytesIO(decompress_blob(quant_blob_disk, args.compressor)), map_location="cpu") + base_model.load_state_dict(dequantize_state_dict(quant_state), strict=True) + torch.cuda.synchronize() + t_qeval = time.perf_counter() + q_val_loss, q_val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"final_{args.quant_scheme}_{args.compressor}_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + 
f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" + ) + log0(f"final_{args.quant_scheme}_{args.compressor}_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + + if distributed: + dist.destroy_process_group() + + +if __name__ == "__main__": + main() + +==================================================================================================== +Running Python 3.10.12 (main, Mar 3 2026, 11:56:32) [GCC 11.4.0] +Running PyTorch 2.11.0+cu130 +Tue Mar 31 11:55:03 2026 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 590.48.01 Driver Version: 590.48.01 CUDA Version: 13.1 | ++-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. | +|=========================================+========================+======================| +| 0 NVIDIA GeForce RTX 3060 Off | 00000000:01:00.0 On | N/A | +| 0% 39C P8 19W / 170W | 242MiB / 12288MiB | 29% Default | +| | | N/A | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| 0 N/A N/A 985 G /usr/lib/xorg/Xorg 117MiB | +| 0 N/A N/A 3562 C python3 104MiB | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=/home/hz/Desktop/parameter-golf-main/data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:1 +val_loader:shards 
pattern=/home/hz/Desktop/parameter-golf-main/data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:65536 +val_token_limit:65536 +model_params:14707808 +world_size:1 grad_accum_steps:8 +enable_compile:False +sdp_backends:cudnn=False flash=True mem_efficient=False math=False +attention_mode:gqa num_heads:8 num_kv_heads:4 +quant_scheme:int5_odd compressor:zlib level:9 structured_mlp:True structured_linear_type:btt_mlp btt_rank:128 btt_init:mup compile_structured_mlp:False +rate_lambda:0.0002 scale_lambda:0.0002 rate_warmup_start_fraction:0.1 rate_full_fraction:0.4 rate_target_tensors:mlp_only +tie_embeddings:True embed_lr:0.03 head_lr:0.0 matrix_lr:0.02 scalar_lr:0.02 +train_batch_tokens:4096 train_seq_len:256 iterations:0 warmup_steps:0 max_wallclock_seconds:600.000 +seed:1337 +step:0/0 val_loss:6.9408 val_bpb:4.1038 train_time:0ms step_avg:0.04ms +peak memory allocated: 335 MiB reserved: 378 MiB +Serialized model: 57835323 bytes +Code size: 68319 bytes +Total submission size: 57903642 bytes +Serialized model int5_odd+zlib: 2571594 bytes (payload:5050587 raw_torch:5103707 payload_ratio:11.44x) +Total submission size int5_odd+zlib: 2639913 bytes +final_int5_odd_zlib_roundtrip val_loss:6.9448 val_bpb:4.1062 eval_time:16556ms +final_int5_odd_zlib_roundtrip_exact val_loss:6.94475281 val_bpb:4.10615783 diff --git a/records/track_non_record_16mb/2026-03-31_EntropyAware_Int5Odd_BTTMLP_3060/logs/scout_R256_L16.txt b/records/track_non_record_16mb/2026-03-31_EntropyAware_Int5Odd_BTTMLP_3060/logs/scout_R256_L16.txt new file mode 100644 index 000000000..23ebb0c9d --- /dev/null +++ b/records/track_non_record_16mb/2026-03-31_EntropyAware_Int5Odd_BTTMLP_3060/logs/scout_R256_L16.txt @@ -0,0 +1,1668 @@ +""" +The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good launching-off points for new participants, not SOTA configs. 
We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder. + +Hard stop: To keep these scripts readable for newcomers, let's make sure `train_gpt.py` and `train_gpt_mlx.py` are never longer than 1500 lines. +""" + +from __future__ import annotations + +import copy +import glob +import io +import math +import os +import random +import subprocess +import sys +import time +import uuid +import zlib +from pathlib import Path + +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP + +try: + import zstandard as zstd + + _HAS_ZSTD = True +except ImportError: + zstd = None + _HAS_ZSTD = False + +REPO_ROOT = Path(__file__).resolve().parents[3] + +# ----------------------------- +# HYPERPARAMETERS +# ----------------------------- +# Default run (matching the class defaults below): +# - 9 transformer blocks at width 512 +# - 8 attention heads with 4 KV heads (GQA) and 2x MLP expansion +# - vocab size 1024, sequence length 256, tied embeddings +# - 16,384 train tokens per step for 40 iterations with a ~10 minute wallclock cap + +class Hyperparameters: + # Data paths are shard globs produced by the existing preprocessing pipeline. + data_path = os.environ.get("DATA_PATH", str(REPO_ROOT / "data/datasets/fineweb10B_sp1024")) + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", str(REPO_ROOT / "data/tokenizers/fineweb_1024_bpe.model")) + run_id = os.environ.get("RUN_ID", "train") + seed = int(os.environ.get("SEED", 1337)) + + # Validation cadence and batch size. Validation uses the full fineweb_val split unless VAL_TOKEN_LIMIT caps it. 
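The `VAL_TOKEN_LIMIT` cap interacts with sequence length exactly as in `main()` above: the limit is rounded down to a whole number of sequences, then clipped to the tokens actually on disk (minus one, since targets are shifted by one). A tiny standalone sketch of that arithmetic:

```python
def usable_val_tokens(limit: int, seq_len: int, tokens_on_disk: int) -> int:
    # Round the cap down to whole sequences, then clip to what the
    # shards actually contain (mirrors the VAL_TOKEN_LIMIT handling in main()).
    return min((limit // seq_len) * seq_len, tokens_on_disk - 1)

# 65,537 tokens on disk, seq_len 256, default 1,048,576-token cap:
print(usable_val_tokens(1_048_576, 256, 65_537))  # 65536
```

This is why the run log above reports `tokens:65536` even though the default cap is sixteen times larger: the local shard, not the limit, was the binding constraint.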
+ val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 262_144)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 10)) + val_token_limit = int(os.environ.get("VAL_TOKEN_LIMIT", 1_048_576)) + + # Training length. + iterations = int(os.environ.get("ITERATIONS", 40)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 1200)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 0)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 16_384)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 256)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + enable_compile = bool(int(os.environ.get("ENABLE_COMPILE", "0"))) + + # Model shape. + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers = int(os.environ.get("NUM_LAYERS", 9)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = int(os.environ.get("MLP_MULT", 2)) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + + # Optimizer hyperparameters. 
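The Muon momentum knobs defined below feed a linear warmup in the training loop, ramping from `MUON_MOMENTUM_WARMUP_START` to `MUON_MOMENTUM` over `MUON_MOMENTUM_WARMUP_STEPS`. A standalone sketch of that schedule with the default values:

```python
def muon_momentum_at(step: int, warmup_steps: int = 500,
                     start: float = 0.85, final: float = 0.95) -> float:
    # Linear ramp from `start` to `final`, then hold
    # (mirrors the momentum update in the main training loop).
    frac = min(step / warmup_steps, 1.0) if warmup_steps > 0 else 1.0
    return (1 - frac) * start + frac * final

for s in (0, 250, 500, 1000):
    print(s, muon_momentum_at(s))
```

Starting at lower momentum keeps early orthogonalized updates from overcommitting to noisy gradient directions before the momentum buffers are populated.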
+ embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.03)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.02)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.02)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.0)) + + # Export / compression. + quant_scheme = os.environ.get("QUANT_SCHEME", "int5_odd") + compressor = os.environ.get("COMPRESSOR", "zlib") + compressor_level = int(os.environ.get("COMPRESSOR_LEVEL", 9)) + + # Entropy-aware rate regularization. + rate_lambda = float(os.environ.get("RATE_LAMBDA", 0.002)) + scale_lambda = float(os.environ.get("SCALE_LAMBDA", 0.0002)) + rate_temperature = float(os.environ.get("RATE_TEMPERATURE", 6.0)) + rate_warmup_start_fraction = float(os.environ.get("RATE_WARMUP_START_FRACTION", 0.1)) + rate_full_fraction = float(os.environ.get("RATE_FULL_FRACTION", 0.4)) + rate_target_tensors = os.environ.get("RATE_TARGET_TENSORS", "mlp_only") + + # Structured MLP. 
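The structured-MLP knobs below trade a dense weight matrix for small factor cores. For intuition only, here is a parameter-count sketch of a 2-core TT-style factorization with hypothetical shapes (this is not the repo's `StructuredLinear`, whose core layout and contraction may differ):

```python
# Hypothetical factor shapes: d_in = a*b = 512, d_out = c*d = 1024, rank r.
a, b, c, d, r = 16, 32, 16, 64, 8

dense_params = (a * b) * (c * d)  # full 512 x 1024 matrix
core1_params = r * a * c          # first core mixes the `a` input factor
core2_params = r * b * d          # second core mixes the `b` input factor
tt_params = core1_params + core2_params

print(dense_params, tt_params, round(dense_params / tt_params, 1))
```

Even a modest rank cuts the MLP's stored parameters by more than an order of magnitude, which is the point on a size-limited artifact track; the README's evaluation-time materialization then recovers dense `F.linear` speed for validation.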
+ structured_mlp = bool(int(os.environ.get("STRUCTURED_MLP", "1"))) + structured_linear_type = os.environ.get("STRUCTURED_LINEAR_TYPE", "btt_mlp") + btt_rank = int(os.environ.get("BTT_RANK", 32)) + btt_in_shape = os.environ.get("BTT_IN_SHAPE", "") + btt_out_shape = os.environ.get("BTT_OUT_SHAPE", "") + btt_init = os.environ.get("BTT_INIT", "mup") + btt_init_gain = float(os.environ.get("BTT_INIT_GAIN", 1.0)) + compile_structured_mlp = bool(int(os.environ.get("COMPILE_STRUCTURED_MLP", "1"))) + +# ----------------------------- +# MUON OPTIMIZER +# ----------------------------- +# +# As borrowed from modded-nanogpt +# Background on Muon: https://kellerjordan.github.io/posts/muon/ + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + # Orthogonalize a 2D update matrix with a fast Newton-Schulz iteration. + # Muon uses this to normalize matrix-shaped gradients before applying them. + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X + + +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True): + super().__init__( + params, + dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov), + ) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + + for group in self.param_groups: + params = group["params"] + if not params: + continue + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + + total_params = 
sum(int(p.numel()) for p in params) + updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) + + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: + g = g.add(buf, alpha=momentum) + g = zeropower_via_newtonschulz5(g, steps=backend_steps) + # Scale correction from Muon reference implementations. + g *= max(1, g.size(0) / g.size(1)) ** 0.5 + updates_flat[curr : curr + p.numel()] = g.reshape(-1) + curr += p.numel() + + if distributed: + dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) + + curr = 0 + for p in params: + g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) + p.add_(g, alpha=-lr) + curr += p.numel() + + return loss + + +# ----------------------------- +# TOKENIZER-AGNOSTIC EVALUATION SETUP +# ----------------------------- +# +# It's common for small models to have a large fraction of their parameters in embeddings, since the 2 * d_model * d_vocab embedding vectors can be gigantic. +# Instead of locking the tokenizer, we let you bring your own and calculate our validation metrics on the average compression of the validation set. +# We calculate BPB (bits-per-byte) instead of validation loss, so we need methods to count the number of bytes each token contributes under the tokenizer. +# Note: Submissions that edit the tokenizer will be examined more carefully, since screwing this up might unjustly improve your score. 
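Putting the two validation metrics together: `val_bpb` converts cross-entropy from nats to bits per token, then multiplies by the tokenizer's tokens-per-byte ratio, as `eval_val` does below. A worked sketch of the arithmetic with made-up totals (the byte count is illustrative, not the real validation set's):

```python
import math

def bits_per_byte(ce_loss_nats: float, total_tokens: int, total_bytes: int) -> float:
    # (bits / token) * (tokens / byte) = bits / byte
    bits_per_token = ce_loss_nats / math.log(2.0)
    return bits_per_token * (total_tokens / total_bytes)

# Made-up totals: loss of 6.94 nats over 65,536 tokens spanning 160,000 bytes.
print(round(bits_per_byte(6.94, 65_536, 160_000), 3))
```

Because bytes are tokenizer-independent, a submission cannot lower its score just by switching to a coarser vocabulary: fewer tokens per byte raises the per-token bar by exactly the same factor.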
+ +def build_sentencepiece_luts( + sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device +) -> tuple[Tensor, Tensor, Tensor]: + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("▁"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) + + +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*. 
+ tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] + + +def eval_val( + args: Hyperparameters, + model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + grad_accum_steps: int, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, +) -> tuple[float, float]: + # Validation computes two metrics: + # - val_loss: token cross-entropy (natural log) + # - val_bpb: tokenizer-agnostic compression metric used by the challenge + local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) + if local_batch_tokens < args.train_seq_len: + raise ValueError( + "VAL_BATCH_SIZE must provide at least one sequence per rank; " + f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " + f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}" + ) + local_batch_seqs = local_batch_tokens // args.train_seq_len + total_seqs = (val_tokens.numel() - 1) // args.train_seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + + model.eval() + with torch.inference_mode(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * args.train_seq_len + raw_end = batch_seq_end * args.train_seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, args.train_seq_len) + y = local[1:].reshape(-1, args.train_seq_len) + with 
torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + batch_loss = model(x, y).detach() + batch_token_count = float(y.numel()) + val_loss_sum += batch_loss.to(torch.float64) * batch_token_count + val_token_count += batch_token_count + prev_ids = x.reshape(-1) + tgt_ids = y.reshape(-1) + token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + + val_loss = val_loss_sum / val_token_count + bits_per_token = val_loss.item() / math.log(2.0) + tokens_per_byte = val_token_count.item() / val_byte_count.item() + model.train() + return float(val_loss.item()), float(bits_per_token * tokens_per_byte) + +# ----------------------------- +# POST-TRAINING QUANTIZATION +# ----------------------------- +# +# It's silly to export our model, which is trained in bf16 and fp32, at that same precision. +# Instead, we get approximately the same model (with a small accuracy hit) by quantizing it to int8 and zlib-compressing the result. +# We can then decompress the model and run it in higher precision for evaluation, after coming in under the size limit. 
+ +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights", + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 +INT5_ODD_LEVELS = (-2, -1, 0, 1, 2) + +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) + + +def parse_factor_shape(spec: str, dim: int) -> tuple[int, int]: + if spec: + vals = tuple(int(v.strip()) for v in spec.split(",") if v.strip()) + if len(vals) != 2 or vals[0] * vals[1] != dim: + raise ValueError(f"Invalid factor shape '{spec}' for dim={dim}") + return vals + for a in range(int(math.isqrt(dim)), 0, -1): + if dim % a == 0: + return a, dim // a + return 1, dim + +def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: + if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): + return t.float().contiguous() + if t.dtype in {torch.float32, torch.bfloat16}: + passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") + return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() + return t + +def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + # Matrices get one scale per row, which usually tracks output-channel + # ranges much better than a single tensor-wide scale. 
+ clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) + q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() + return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + + # Vectors / scalars use a simpler per-tensor scale. + clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() + return q, scale + + +def pack_base5_codes(q: Tensor) -> tuple[bytes, int]: + vals = q.detach().to("cpu", dtype=torch.uint8).reshape(-1).numpy() + n_vals = int(vals.size) + pad = (3 - (n_vals % 3)) % 3 + if pad: + vals = np.pad(vals, (0, pad)) + groups = vals.reshape(-1, 3).astype(np.uint8) + packed = groups[:, 0] + 5 * groups[:, 1] + 25 * groups[:, 2] + return packed.tobytes(), n_vals + + +def unpack_base5_codes(data: bytes, n_vals: int) -> Tensor: + packed = np.frombuffer(data, dtype=np.uint8).astype(np.int16) + out = np.zeros((packed.size, 3), dtype=np.uint8) + for i in range(3): + out[:, i] = packed % 5 + packed //= 5 + flat = out.reshape(-1)[:n_vals] + return torch.from_numpy(flat.astype(np.uint8, copy=False)) + + +def quantize_float_tensor_int5_odd(t: Tensor) -> tuple[Tensor, Tensor]: + if t.ndim != 2: + raise ValueError("int5_odd quantization only supports 2D tensors") + t32 = t.float() + clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clip_abs = clip_abs.clamp_min(1e-6) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 
2.0).clamp_min(1.0 / 2.0e6) + q = torch.clamp(torch.round(clipped / scale[:, None]), -2, 2).to(torch.int8).contiguous() + return (q + 2).to(torch.uint8).contiguous(), scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + +def quantize_state_dict_int8(state_dict: dict[str, Tensor]): + # Single supported clean-script export format: + # - per-row int8 for 2D float tensors + # - per-tensor int8 for other float tensors + # - exact passthrough for non-floats + # - passthrough for small float tensors, stored as fp16 to save bytes + quantized: dict[str, Tensor] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + qmeta: dict[str, dict[str, object]] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + + # Small float tensors are cheap enough to keep directly. We still downcast + # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size. 
+ if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + + stats["num_float_tensors"] += 1 + q, s = quantize_float_tensor(t) + if s.ndim > 0: + qmeta[name] = {"scheme": "per_row", "axis": 0} + quantized[name] = q + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) + + obj: dict[str, object] = { + "__quant_format__": "int8_clean_per_row_v1", + "quantized": quantized, + "scales": scales, + "dtypes": dtypes, + "passthrough": passthrough, + } + if qmeta: + obj["qmeta"] = qmeta + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats + + +def quantize_state_dict_int5_odd(state_dict: dict[str, Tensor]): + packed: dict[str, bytes] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + shapes: dict[str, tuple[int, ...]] = {} + orig_shapes: dict[str, tuple[int, ...]] = {} + n_codes: dict[str, int] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + + quantize_this = t.ndim >= 2 and (t.numel() > INT8_KEEP_FLOAT_MAX_NUMEL or ".mlp." 
in name) + if not quantize_this: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + + stats["num_float_tensors"] += 1 + t2 = t.reshape(t.shape[0], -1) if t.ndim > 2 else t + q, s = quantize_float_tensor_int5_odd(t2) + packed_bytes, count = pack_base5_codes(q) + packed[name] = packed_bytes + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + shapes[name] = tuple(int(v) for v in t2.shape) + orig_shapes[name] = tuple(int(v) for v in t.shape) + n_codes[name] = count + stats["int8_payload_bytes"] += len(packed_bytes) + tensor_nbytes(s) + + obj: dict[str, object] = { + "__quant_format__": "int5_odd_clean_per_row_v1", + "packed": packed, + "scales": scales, + "dtypes": dtypes, + "shapes": shapes, + "orig_shapes": orig_shapes, + "n_codes": n_codes, + "passthrough": passthrough, + } + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats + +def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + qmeta = obj.get("qmeta", {}) + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, q in obj["quantized"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + s = obj["scales"][name] + if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: + s = s.to(dtype=torch.float32) + # Broadcast the saved row scale back across trailing dimensions. + out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() + else: + scale = float(s.item()) + out[name] = (q.float() * scale).to(dtype=dtype).contiguous() + for name, t in obj["passthrough"].items(): + # Restore small tensors, undoing the temporary fp16 storage cast if needed. 
+ out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out + + +def dequantize_state_dict_int5_odd(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, packed in obj["packed"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + shape = tuple(int(v) for v in obj["shapes"][name]) + orig_shape = tuple(int(v) for v in obj.get("orig_shapes", {}).get(name, shape)) + q = unpack_base5_codes(packed, int(obj["n_codes"][name])).to(torch.float32).reshape(shape) - 2.0 + s = obj["scales"][name].to(dtype=torch.float32) + deq = (q * s.view(shape[0], *([1] * (len(shape) - 1)))).to(dtype=dtype) + out[name] = deq.reshape(orig_shape).contiguous() + for name, t in obj["passthrough"].items(): + out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out + + +def quantize_state_dict(state_dict: dict[str, Tensor], quant_scheme: str): + if quant_scheme == "int8": + return quantize_state_dict_int8(state_dict) + if quant_scheme == "int5_odd": + return quantize_state_dict_int5_odd(state_dict) + raise ValueError(f"Unsupported QUANT_SCHEME={quant_scheme}") + + +def dequantize_state_dict(obj: dict[str, object]) -> dict[str, Tensor]: + quant_format = obj.get("__quant_format__") + if quant_format == "int8_clean_per_row_v1": + return dequantize_state_dict_int8(obj) + if quant_format == "int5_odd_clean_per_row_v1": + return dequantize_state_dict_int5_odd(obj) + raise ValueError(f"Unsupported quant format {quant_format}") + + +def compress_blob(data: bytes, compressor: str, level: int) -> bytes: + if compressor == "zlib": + return zlib.compress(data, level=max(0, 
min(level, 9))) + if compressor == "zstd": + if not _HAS_ZSTD: + raise RuntimeError("COMPRESSOR=zstd requires zstandard to be installed") + return zstd.ZstdCompressor(level=level).compress(data) + raise ValueError(f"Unsupported COMPRESSOR={compressor}") + + +def decompress_blob(data: bytes, compressor: str) -> bytes: + if compressor == "zlib": + return zlib.decompress(data) + if compressor == "zstd": + if not _HAS_ZSTD: + raise RuntimeError("COMPRESSOR=zstd requires zstandard to be installed") + return zstd.ZstdDecompressor().decompress(data) + raise ValueError(f"Unsupported COMPRESSOR={compressor}") + + +def should_rate_regularize_tensor(name: str, p: Tensor, target: str) -> bool: + if not p.requires_grad or not p.is_floating_point() or p.ndim < 2: + return False + if any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS): + return False + if target == "mlp_only": + return ".mlp." in name + if target == "all_matrices": + return p.numel() > INT8_KEEP_FLOAT_MAX_NUMEL or ".mlp." in name + raise ValueError(f"Unsupported RATE_TARGET_TENSORS={target}") + + +def estimate_tensor_rate_proxy(t: Tensor, temperature: float) -> tuple[Tensor, Tensor, Tensor]: + t2 = t.float() + if t2.ndim > 2: + t2 = t2.reshape(t2.shape[0], -1) + elif t2.ndim == 1: + t2 = t2.reshape(1, -1) + row_scale = t2.abs().mean(dim=1, keepdim=True).clamp_min(1e-6) + normalized = torch.clamp(t2 / row_scale, -2.5, 2.5) + centers = t2.new_tensor(INT5_ODD_LEVELS) + logits = -temperature * (normalized.unsqueeze(-1) - centers) ** 2 + probs = torch.softmax(logits, dim=-1) + hist = probs.mean(dim=(0, 1)) + entropy_nats = -(hist * (hist + 1e-8).log()).sum() + entropy_bits = entropy_nats / math.log(2.0) + scale_proxy = row_scale.mean() + est_bits = entropy_bits * float(t2.numel()) + return entropy_bits, scale_proxy, est_bits + + +# ----------------------------- +# DATA LOADING +# ----------------------------- + +def load_data_shard(file: Path) -> Tensor: + header_bytes = 256 * np.dtype(" None: + 
self.file_idx = (self.file_idx + 1) % len(self.files) + self.tokens = load_data_shard(self.files[self.file_idx]) + self.pos = 0 + + def take(self, n: int) -> Tensor: + chunks: list[Tensor] = [] + remaining = n + while remaining > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: + self._advance_file() + continue + k = min(remaining, avail) + chunks.append(self.tokens[self.pos : self.pos + k]) + self.pos += k + remaining -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) + + +class DistributedTokenLoader: + # Each call consumes a contiguous chunk from the shared token stream, then slices out + # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting. + def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank = rank + self.world_size = world_size + self.device = device + self.stream = TokenStream(pattern) + + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + chunk = self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start : start + per_rank_span].to(dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + +# ----------------------------- +# TRANSFORMER MODULES +# ----------------------------- + +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + + +class CastedLinear(nn.Linear): + # Keep weights in fp32 for optimizer/state quality, cast at matmul time for bf16 compute. 
+ def forward(self, x: Tensor) -> Tensor: + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, self.weight.to(x.dtype), bias) + + +class StructuredLinear(nn.Module): + # Lightweight TT-style structured matrix used only in the MLP stack. + def __init__( + self, + in_features: int, + out_features: int, + bias: bool, + linear_type: str, + btt_rank: int, + in_shape_spec: str, + out_shape_spec: str, + btt_init: str, + btt_init_gain: float, + ): + super().__init__() + self.in_features = in_features + self.out_features = out_features + self.linear_type = linear_type + self.rank = max(int(btt_rank), 1) + self.btt_init = btt_init + self.btt_init_gain = btt_init_gain + if linear_type == "dense": + self.impl = CastedLinear(in_features, out_features, bias=bias) + self.register_parameter("core1", None) + self.register_parameter("core2", None) + return + if linear_type != "btt_mlp": + raise ValueError(f"Unsupported STRUCTURED_LINEAR_TYPE={linear_type}") + self.impl = None + self.in_shape = parse_factor_shape(in_shape_spec, in_features) + self.out_shape = parse_factor_shape(out_shape_spec, out_features) + n1, n2 = self.in_shape + m1, m2 = self.out_shape + self.core1 = nn.Parameter(torch.empty(n1, m1, self.rank)) + self.core2 = nn.Parameter(torch.empty(self.rank, n2, m2)) + if bias: + self.bias = nn.Parameter(torch.zeros(out_features)) + else: + self.register_parameter("bias", None) + self.reset_parameters() + + def reset_parameters(self) -> None: + if self.impl is not None: + self.impl.reset_parameters() + return + n1, n2 = self.in_shape + if self.btt_init == "mup": + # muP-inspired scaling: keep the per-rank branch update stable while + # preserving O(1) output variance under the sum over TT ranks. 
+ core1_std = self.btt_init_gain / math.sqrt(max(n1, 1)) + core2_std = self.btt_init_gain / math.sqrt(max(self.rank * n2, 1)) + elif self.btt_init == "xavier": + std = self.btt_init_gain / math.sqrt(max(self.in_features, 1)) + core1_std = std / math.sqrt(self.rank) + core2_std = std + else: + raise ValueError(f"Unsupported BTT_INIT={self.btt_init}") + nn.init.normal_(self.core1, mean=0.0, std=core1_std) + nn.init.normal_(self.core2, mean=0.0, std=core2_std) + if self.bias is not None: + nn.init.zeros_(self.bias) + + def forward(self, x: Tensor) -> Tensor: + if self.impl is not None: + return self.impl(x) + x_dtype = x.dtype + orig_shape = x.shape[:-1] + x2 = x.reshape(-1, self.in_shape[0], self.in_shape[1]) + core1 = self.core1.to(dtype=x_dtype) + core2 = self.core2.to(dtype=x_dtype) + y = x2.new_zeros((x2.shape[0], self.out_shape[0], self.out_shape[1])) + for r in range(self.rank): + left = torch.einsum("buv,um->bmv", x2, core1[:, :, r]) + y = y + torch.einsum("bmv,vn->bmn", left, core2[r]) + y = y.reshape(*orig_shape, self.out_features) + if self.bias is not None: + y = y + self.bias.to(dtype=y.dtype) + return y + + +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + # Keep small/control parameters in fp32 even when the model body runs in bf16. + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() + + +class Rotary(nn.Module): + # Caches cos/sin tables per sequence length on the current device. 
+ def __init__(self, dim: int, base: float = 10000.0): + super().__init__() + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached != seq_len + or self._cos_cached.device != device + ): + t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) + freqs = torch.outer(t, self.inv_freq.to(device)) + self._cos_cached = freqs.cos()[None, None, :, :] + self._sin_cached = freqs.sin()[None, None, :, :] + self._seq_len_cached = seq_len + cos = self._cos_cached.to(dtype=dtype) + sin = self._sin_cached.to(dtype=dtype) + if not torch.is_inference_mode_enabled(): + cos = cos.clone() + sin = sin.clone() + return cos, sin + + +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + + +class CausalSelfAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + kv_dim = self.num_kv_heads * self.head_dim + self.c_q = CastedLinear(dim, dim, bias=False) + self.c_k = CastedLinear(dim, kv_dim, bias=False) + self.c_v = CastedLinear(dim, kv_dim, bias=False) + self.proj 
= CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rotary = Rotary(self.head_dim, base=rope_base) + + def forward(self, x: Tensor) -> Tensor: + bsz, seqlen, dim = x.shape + q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) + k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin) + k = apply_rotary_emb(k, cos, sin) + q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] + y = F.scaled_dot_product_attention( + q, + k, + v, + attn_mask=None, + is_causal=True, + enable_gqa=(self.num_kv_heads != self.num_heads), + ) + y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) + return self.proj(y) + + +class MLP(nn.Module): + # relu^2 MLP from the original modded-nanogpt setup + def __init__( + self, + dim: int, + mlp_mult: int, + structured_mlp: bool, + structured_linear_type: str, + btt_rank: int, + btt_in_shape: str, + btt_out_shape: str, + btt_init: str, + btt_init_gain: float, + compile_structured_mlp: bool, + ): + super().__init__() + hidden = mlp_mult * dim + if structured_mlp: + self.fc = StructuredLinear( + dim, + hidden, + bias=False, + linear_type=structured_linear_type, + btt_rank=btt_rank, + in_shape_spec=btt_in_shape, + out_shape_spec=btt_out_shape, + btt_init=btt_init, + btt_init_gain=btt_init_gain, + ) + self.proj = StructuredLinear( + hidden, + dim, + bias=False, + linear_type=structured_linear_type, + btt_rank=btt_rank, + in_shape_spec=btt_out_shape, + out_shape_spec=btt_in_shape, + btt_init=btt_init, + btt_init_gain=btt_init_gain, + ) + if compile_structured_mlp: + self.fc = torch.compile(self.fc, mode="reduce-overhead", 
fullgraph=False, dynamic=False) + self.proj = torch.compile(self.proj, mode="reduce-overhead", fullgraph=False, dynamic=False) + else: + self.fc = CastedLinear(dim, hidden, bias=False) + self.proj = CastedLinear(hidden, dim, bias=False) + self.proj._zero_init = True + + def forward(self, x: Tensor) -> Tensor: + x = torch.relu(self.fc(x)) + return self.proj(x.square()) + + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + rope_base: float, + qk_gain_init: float, + structured_mlp: bool, + structured_linear_type: str, + btt_rank: int, + btt_in_shape: str, + btt_out_shape: str, + btt_init: str, + btt_init_gain: float, + compile_structured_mlp: bool, + ): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) + self.mlp = MLP( + dim, + mlp_mult, + structured_mlp, + structured_linear_type, + btt_rank, + btt_in_shape, + btt_out_shape, + btt_init, + btt_init_gain, + compile_structured_mlp, + ) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + + def forward(self, x: Tensor, x0: Tensor) -> Tensor: + mix = self.resid_mix.to(dtype=x.dtype) + x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + attn_out = self.attn(self.attn_norm(x)) + x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out + x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) + return x + + +class GPT(nn.Module): + def __init__( + self, + vocab_size: int, + num_layers: int, + model_dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + tie_embeddings: bool, + tied_embed_init_std: float, + logit_softcap: float, + rope_base: float, + qk_gain_init: float, + structured_mlp: bool, + 
structured_linear_type: str, + btt_rank: int, + btt_in_shape: str, + btt_out_shape: str, + btt_init: str, + btt_init_gain: float, + compile_structured_mlp: bool, + rate_target_tensors: str, + rate_temperature: float, + ): + super().__init__() + if logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") + self.tie_embeddings = tie_embeddings + self.tied_embed_init_std = tied_embed_init_std + self.logit_softcap = logit_softcap + self.rate_target_tensors = rate_target_tensors + self.rate_temperature = rate_temperature + self.tok_emb = nn.Embedding(vocab_size, model_dim) + self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + self.blocks = nn.ModuleList( + [ + Block( + model_dim, + num_heads, + num_kv_heads, + mlp_mult, + rope_base, + qk_gain_init, + structured_mlp, + structured_linear_type, + btt_rank, + btt_in_shape, + btt_out_shape, + btt_init, + btt_init_gain, + compile_structured_mlp, + ) + for i in range(num_layers) + ] + ) + self.final_norm = RMSNorm() + self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) + if self.lm_head is not None: + self.lm_head._zero_init = True + self._init_weights() + + def _init_weights(self) -> None: + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + for module in self.modules(): + if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + if isinstance(module, StructuredLinear) and getattr(module, "_zero_init", False): + if module.impl is not None: + nn.init.zeros_(module.impl.weight) + else: + nn.init.zeros_(module.core2) + + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + x = 
self.tok_emb(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x0 = x + skips: list[Tensor] = [] + + # First half stores skips; second half reuses them in reverse order. + for i in range(self.num_encoder_layers): + x = self.blocks[i](x, x0) + skips.append(x) + for i in range(self.num_decoder_layers): + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + x = self.blocks[self.num_encoder_layers + i](x, x0) + + x = self.final_norm(x).reshape(-1, x.size(-1)) + targets = target_ids.reshape(-1) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(x) + logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + return F.cross_entropy(logits.float(), targets, reduction="mean") + + def rate_regularizer(self) -> tuple[Tensor, Tensor, Tensor]: + total_symbols = 0.0 + weighted_rate = None + weighted_scale = None + estimated_bits = None + for name, param in self.named_parameters(): + if not should_rate_regularize_tensor(name, param, self.rate_target_tensors): + continue + rate_proxy, scale_proxy, tensor_bits = estimate_tensor_rate_proxy(param, self.rate_temperature) + weight = float(param.numel()) + total_symbols += weight + if weighted_rate is None: + weighted_rate = rate_proxy * weight + weighted_scale = scale_proxy * weight + estimated_bits = tensor_bits + else: + weighted_rate = weighted_rate + rate_proxy * weight + weighted_scale = weighted_scale + scale_proxy * weight + estimated_bits = estimated_bits + tensor_bits + if weighted_rate is None or weighted_scale is None or estimated_bits is None: + zero = self.tok_emb.weight.new_zeros(()) + return zero, zero, zero + denom = max(total_symbols, 1.0) + return weighted_rate / denom, weighted_scale / denom, estimated_bits + + +# ----------------------------- +# TRAINING +# ----------------------------- + +def main() -> None: 
+ global zeropower_via_newtonschulz5 + + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + if args.enable_compile: + zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) + + # ----------------------------- + # DISTRIBUTED + CUDA SETUP + # ----------------------------- + + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if 8 % world_size != 0: + raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") + grad_accum_steps = 8 // world_size + grad_scale = 1.0 / grad_accum_steps + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 0 + + # Fast math knobs + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + + logfile = None + if master_process: + if args.run_id == "train": + logfile = "train.log" + else: + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" + print(logfile) + + def log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch 
{torch.__version__}", console=False) + log0( + subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, + console=False, + ) + log0("=" * 100, console=False) + + # ----------------------------- + # TOKENIZER + VALIDATION METRIC SETUP + # ----------------------------- + + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) + if args.val_token_limit > 0: + usable = min(((args.val_token_limit) // args.train_seq_len) * args.train_seq_len, val_tokens.numel() - 1) + if usable <= 0: + raise ValueError(f"VAL_TOKEN_LIMIT={args.val_token_limit} is too small for TRAIN_SEQ_LEN={args.train_seq_len}") + val_tokens = val_tokens[: usable + 1].contiguous() + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + if args.val_token_limit > 0: + log0(f"val_token_limit:{args.val_token_limit}") + + # ----------------------------- + # MODEL + OPTIMIZER SETUP + # ----------------------------- + + base_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + 
model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, + structured_mlp=args.structured_mlp, + structured_linear_type=args.structured_linear_type if args.structured_mlp else "dense", + btt_rank=args.btt_rank, + btt_in_shape=args.btt_in_shape, + btt_out_shape=args.btt_out_shape, + btt_init=args.btt_init, + btt_init_gain=args.btt_init_gain, + compile_structured_mlp=args.compile_structured_mlp, + rate_target_tensors=args.rate_target_tensors, + rate_temperature=args.rate_temperature, + ).to(device).bfloat16() + for module in base_model.modules(): + if isinstance(module, (CastedLinear, StructuredLinear)): + module.float() + restore_low_dim_params_to_fp32(base_model) + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) if args.enable_compile else base_model + model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + + # Optimizer split: + # - token embedding (Adam) uses EMBED_LR + # - untied lm_head (Adam) uses HEAD_LR + # - matrix params in transformer blocks use MATRIX_LR via Muon + # - vectors/scalars use SCALAR_LR via Adam + block_named_params = list(base_model.blocks.named_parameters()) + matrix_params = [ + p + for name, p in block_named_params + if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + scalar_params = [ + p + for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + optimizer_tok = torch.optim.Adam( + [{"params": [base_model.tok_emb.weight], "lr": token_lr, 
"base_lr": token_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizer_muon = Muon( + matrix_params, + lr=args.matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + ) + for group in optimizer_muon.param_groups: + group["base_lr"] = args.matrix_lr + optimizer_scalar = torch.optim.Adam( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers.insert(1, optimizer_head) + + n_params = sum(p.numel() for p in base_model.parameters()) + log0(f"model_params:{n_params}") + log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") + log0(f"enable_compile:{args.enable_compile}") + log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") + log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") + log0( + f"quant_scheme:{args.quant_scheme} compressor:{args.compressor} level:{args.compressor_level} " + f"structured_mlp:{args.structured_mlp} structured_linear_type:{args.structured_linear_type} " + f"btt_rank:{args.btt_rank} btt_init:{args.btt_init} compile_structured_mlp:{args.compile_structured_mlp}" + ) + log0( + f"rate_lambda:{args.rate_lambda} scale_lambda:{args.scale_lambda} " + f"rate_warmup_start_fraction:{args.rate_warmup_start_fraction} rate_full_fraction:{args.rate_full_fraction} " + f"rate_target_tensors:{args.rate_target_tensors}" + ) + log0( + f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " + f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " + f"matrix_lr:{args.matrix_lr} 
scalar_lr:{args.scalar_lr}" + ) + log0( + f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " + f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " + f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log0(f"seed:{args.seed}") + + # ----------------------------- + # DATA LOADER & MODEL WARMUP + # ----------------------------- + + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + + def lr_mul(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + + def rate_schedule(step: int) -> float: + if args.rate_lambda <= 0.0 and args.scale_lambda <= 0.0: + return 0.0 + if args.iterations <= 0: + return 1.0 + frac = step / max(args.iterations, 1) + start = args.rate_warmup_start_fraction + full = max(args.rate_full_fraction, start + 1e-6) + if frac <= start: + return 0.0 + if frac >= full: + return 1.0 + return (frac - start) / (full - start) + + # Warmup primes the compiled forward/backward/optimizer paths, then we restore the + # initial weights/optimizer state so measured training starts from the true init. 
+ if args.warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(args.warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + for opt in optimizers: + opt.step() + zero_grad_all() + if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: + log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + if distributed: + model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + # ----------------------------- + # MAIN TRAINING LOOP + # ----------------------------- + + training_time_ms = 0.0 + stop_after_step: int | None = None + torch.cuda.synchronize() + t0 = time.perf_counter() + + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + log0( + 
f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" + ) + break + + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + zero_grad_all() + train_loss = torch.zeros((), device=device) + train_ce_loss = torch.zeros((), device=device) + train_rate_proxy = torch.zeros((), device=device) + train_scale_proxy = torch.zeros((), device=device) + train_est_bits = torch.zeros((), device=device) + rate_mul = rate_schedule(step) + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + ce_loss = model(x, y) + total_loss = ce_loss + if rate_mul > 0.0: + rate_proxy, scale_proxy, est_bits = base_model.rate_regularizer() + total_loss = total_loss + rate_mul * (args.rate_lambda * rate_proxy + args.scale_lambda * scale_proxy) + train_rate_proxy += rate_proxy.detach() + train_scale_proxy += scale_proxy.detach() + train_est_bits += est_bits.detach() + train_ce_loss += ce_loss.detach() + train_loss += total_loss.detach() + (total_loss * grad_scale).backward() + train_loss /= grad_accum_steps + train_ce_loss /= grad_accum_steps + train_rate_proxy /= grad_accum_steps + train_scale_proxy /= grad_accum_steps + train_est_bits /= grad_accum_steps + + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * 
args.muon_momentum + for group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + for opt in optimizers: + opt.step() + zero_grad_all() + + step += 1 + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + should_log_train = ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + log0( + f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} ce_loss:{train_ce_loss.item():.4f} " + f"rate_mul:{rate_mul:.3f} rate_proxy_bits:{train_rate_proxy.item():.4f} " + f"scale_proxy:{train_scale_proxy.item():.6f} est_model_bits:{train_est_bits.item():.0f} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" + ) + + # Needed to sync whether we've reached the wallclock cap. + reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + + log0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + ) + + # ----------------------------- + # SERIALIZATION + ROUNDTRIP VALIDATION + # ----------------------------- + # Save the raw state (useful for debugging/loading in PyTorch directly), then always produce + # the compressed QUANT_SCHEME+COMPRESSOR artifact (int5_odd+zlib by default) and validate the round-tripped weights.
+ + if master_process: + torch.save(base_model.state_dict(), "final_model.pt") + model_bytes = os.path.getsize("final_model.pt") + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model: {model_bytes} bytes") + log0(f"Code size: {code_bytes} bytes") + log0(f"Total submission size: {model_bytes + code_bytes} bytes") + + quant_obj, quant_stats = quantize_state_dict(base_model.state_dict(), args.quant_scheme) + quant_buf = io.BytesIO() + torch.save(quant_obj, quant_buf) + quant_raw = quant_buf.getvalue() + quant_blob = compress_blob(quant_raw, args.compressor, args.compressor_level) + quant_raw_bytes = len(quant_raw) + quant_filename = f"final_model.{args.quant_scheme}.{args.compressor}.ptc" + if master_process: + with open(quant_filename, "wb") as f: + f.write(quant_blob) + quant_file_bytes = os.path.getsize(quant_filename) + code_bytes = len(code.encode("utf-8")) + ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1) + log0( + f"Serialized model {args.quant_scheme}+{args.compressor}: {quant_file_bytes} bytes " + f"(payload:{quant_stats['int8_payload_bytes']} raw_torch:{quant_raw_bytes} payload_ratio:{ratio:.2f}x)" + ) + log0(f"Total submission size {args.quant_scheme}+{args.compressor}: {quant_file_bytes + code_bytes} bytes") + + if distributed: + dist.barrier() + with open(quant_filename, "rb") as f: + quant_blob_disk = f.read() + quant_state = torch.load(io.BytesIO(decompress_blob(quant_blob_disk, args.compressor)), map_location="cpu") + base_model.load_state_dict(dequantize_state_dict(quant_state), strict=True) + torch.cuda.synchronize() + t_qeval = time.perf_counter() + q_val_loss, q_val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"final_{args.quant_scheme}_{args.compressor}_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + 
f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" + ) + log0(f"final_{args.quant_scheme}_{args.compressor}_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + + if distributed: + dist.destroy_process_group() + + +if __name__ == "__main__": + main() + +==================================================================================================== +Running Python 3.10.12 (main, Mar 3 2026, 11:56:32) [GCC 11.4.0] +Running PyTorch 2.11.0+cu130 +Tue Mar 31 11:55:43 2026 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 590.48.01 Driver Version: 590.48.01 CUDA Version: 13.1 | ++-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. | +|=========================================+========================+======================| +| 0 NVIDIA GeForce RTX 3060 Off | 00000000:01:00.0 On | N/A | +| 39% 46C P3 37W / 170W | 242MiB / 12288MiB | 6% Default | +| | | N/A | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| 0 N/A N/A 985 G /usr/lib/xorg/Xorg 117MiB | +| 0 N/A N/A 3579 C python3 104MiB | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=/home/hz/Desktop/parameter-golf-main/data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:1 +val_loader:shards 
pattern=/home/hz/Desktop/parameter-golf-main/data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:65536 +val_token_limit:65536 +model_params:25727104 +world_size:1 grad_accum_steps:8 +enable_compile:False +sdp_backends:cudnn=False flash=True mem_efficient=False math=False +attention_mode:gqa num_heads:8 num_kv_heads:4 +quant_scheme:int5_odd compressor:zlib level:9 structured_mlp:True structured_linear_type:btt_mlp btt_rank:256 btt_init:mup compile_structured_mlp:False +rate_lambda:0.0002 scale_lambda:0.0002 rate_warmup_start_fraction:0.1 rate_full_fraction:0.4 rate_target_tensors:mlp_only +tie_embeddings:True embed_lr:0.03 head_lr:0.0 matrix_lr:0.02 scalar_lr:0.02 +train_batch_tokens:4096 train_seq_len:256 iterations:0 warmup_steps:0 max_wallclock_seconds:600.000 +seed:1337 +step:0/0 val_loss:6.9401 val_bpb:4.1034 train_time:0ms step_avg:0.02ms +peak memory allocated: 409 MiB reserved: 452 MiB +Serialized model: 101929819 bytes +Code size: 68319 bytes +Total submission size: 101998138 bytes +Serialized model int5_odd+zlib: 4207457 bytes (payload:8780523 raw_torch:8851579 payload_ratio:11.60x) +Total submission size int5_odd+zlib: 4275776 bytes +final_int5_odd_zlib_roundtrip val_loss:6.9441 val_bpb:4.1058 eval_time:43807ms +final_int5_odd_zlib_roundtrip_exact val_loss:6.94414002 val_bpb:4.10579551 diff --git a/records/track_non_record_16mb/2026-03-31_EntropyAware_Int5Odd_BTTMLP_3060/logs/scout_R512_L16.txt b/records/track_non_record_16mb/2026-03-31_EntropyAware_Int5Odd_BTTMLP_3060/logs/scout_R512_L16.txt new file mode 100644 index 000000000..361be7faa --- /dev/null +++ b/records/track_non_record_16mb/2026-03-31_EntropyAware_Int5Odd_BTTMLP_3060/logs/scout_R512_L16.txt @@ -0,0 +1,1668 @@ +""" +The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good launching-off points for new participants, not SOTA configs. 
We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder. + +Hard stop: to keep things readable for newcomers, let's make sure `train_gpt.py` and `train_gpt_mlx.py` are never longer than 1500 lines. +""" + +from __future__ import annotations + +import copy +import glob +import io +import math +import os +import random +import subprocess +import sys +import time +import uuid +import zlib +from pathlib import Path + +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP + +try: + import zstandard as zstd + + _HAS_ZSTD = True +except ImportError: + zstd = None + _HAS_ZSTD = False + +REPO_ROOT = Path(__file__).resolve().parents[3] + +# ----------------------------- +# HYPERPARAMETERS +# ----------------------------- +# Default Simple Baseline run: +# - 9 transformer blocks at width 512 +# - 8 attention heads with 4 KV heads (GQA) and 2x MLP expansion +# - vocab size 1024, sequence length 1024, tied embeddings +# - 524,288 train tokens per step for 20,000 iterations with a ~10 minute cap + +class Hyperparameters: + # Data paths are shard globs produced by the existing preprocessing pipeline. + data_path = os.environ.get("DATA_PATH", str(REPO_ROOT / "data/datasets/fineweb10B_sp1024")) + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", str(REPO_ROOT / "data/tokenizers/fineweb_1024_bpe.model")) + run_id = os.environ.get("RUN_ID", "train") + seed = int(os.environ.get("SEED", 1337)) + + # Validation cadence and batch size. Validation uses the full fineweb_val split unless VAL_TOKEN_LIMIT caps it.
+ val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 262_144)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 10)) + val_token_limit = int(os.environ.get("VAL_TOKEN_LIMIT", 1_048_576)) + + # Training length. + iterations = int(os.environ.get("ITERATIONS", 40)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 1200)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 0)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 16_384)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 256)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + enable_compile = bool(int(os.environ.get("ENABLE_COMPILE", "0"))) + + # Model shape. + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers = int(os.environ.get("NUM_LAYERS", 9)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = int(os.environ.get("MLP_MULT", 2)) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + + # Optimizer hyperparameters. 
+ embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.03)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.02)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.02)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.0)) + + # Export / compression. + quant_scheme = os.environ.get("QUANT_SCHEME", "int5_odd") + compressor = os.environ.get("COMPRESSOR", "zlib") + compressor_level = int(os.environ.get("COMPRESSOR_LEVEL", 9)) + + # Entropy-aware rate regularization. + rate_lambda = float(os.environ.get("RATE_LAMBDA", 0.002)) + scale_lambda = float(os.environ.get("SCALE_LAMBDA", 0.0002)) + rate_temperature = float(os.environ.get("RATE_TEMPERATURE", 6.0)) + rate_warmup_start_fraction = float(os.environ.get("RATE_WARMUP_START_FRACTION", 0.1)) + rate_full_fraction = float(os.environ.get("RATE_FULL_FRACTION", 0.4)) + rate_target_tensors = os.environ.get("RATE_TARGET_TENSORS", "mlp_only") + + # Structured MLP. 
+ structured_mlp = bool(int(os.environ.get("STRUCTURED_MLP", "1"))) + structured_linear_type = os.environ.get("STRUCTURED_LINEAR_TYPE", "btt_mlp") + btt_rank = int(os.environ.get("BTT_RANK", 32)) + btt_in_shape = os.environ.get("BTT_IN_SHAPE", "") + btt_out_shape = os.environ.get("BTT_OUT_SHAPE", "") + btt_init = os.environ.get("BTT_INIT", "mup") + btt_init_gain = float(os.environ.get("BTT_INIT_GAIN", 1.0)) + compile_structured_mlp = bool(int(os.environ.get("COMPILE_STRUCTURED_MLP", "1"))) + +# ----------------------------- +# MUON OPTIMIZER +# ----------------------------- +# +# As borrowed from modded-nanogpt +# Background on Muon: https://kellerjordan.github.io/posts/muon/ + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + # Orthogonalize a 2D update matrix with a fast Newton-Schulz iteration. + # Muon uses this to normalize matrix-shaped gradients before applying them. + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X + + +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True): + super().__init__( + params, + dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov), + ) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + + for group in self.param_groups: + params = group["params"] + if not params: + continue + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + + total_params = 
sum(int(p.numel()) for p in params) + updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) + + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: + g = g.add(buf, alpha=momentum) + g = zeropower_via_newtonschulz5(g, steps=backend_steps) + # Scale correction from Muon reference implementations. + g *= max(1, g.size(0) / g.size(1)) ** 0.5 + updates_flat[curr : curr + p.numel()] = g.reshape(-1) + curr += p.numel() + + if distributed: + dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) + + curr = 0 + for p in params: + g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) + p.add_(g, alpha=-lr) + curr += p.numel() + + return loss + + +# ----------------------------- +# TOKENIZER-AGNOSTIC EVALUATION SETUP +# ----------------------------- +# +# It's common for small models to have a large fraction of their parameters be embeddings, since the 2 * d_model * d_vocab embedding parameters can be gigantic. +# Instead of locking the tokenizer, we let you bring your own and calculate our validation metrics on the average compression of the validation set. +# We calculate BPB (bits-per-byte) instead of validation loss, so we need methods to count the number of bytes per token in the tokenizer. +# Note: Submissions that edit the tokenizer will be examined more carefully, since screwing this up might unjustly improve your score.
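The byte-count LUTs feed the BPB formula at the end of `eval_val`: bits per byte is bits per token times tokens per byte. A standalone sketch of just that conversion follows; `bpb_from_ce` and the toy counts are illustrative, not part of the script.

```python
import math

# Toy sketch of the BPB conversion used by eval_val:
# bits/byte = (nats/token / ln 2) * (tokens / bytes).
# bpb_from_ce and the counts below are made up for illustration.
def bpb_from_ce(val_loss_nats: float, token_count: float, byte_count: float) -> float:
    bits_per_token = val_loss_nats / math.log(2.0)
    tokens_per_byte = token_count / byte_count
    return bits_per_token * tokens_per_byte

# A loss of ln(2) nats is exactly 1 bit/token; at 4 bytes/token that is 0.25 bpb.
print(bpb_from_ce(math.log(2.0), 1000.0, 4000.0))  # 0.25
```

Note the same tokenizer can score differently on bpb than on raw loss: longer tokens lower tokens-per-byte, so a higher per-token loss can still compress bytes better.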
+ +def build_sentencepiece_luts( + sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device +) -> tuple[Tensor, Tensor, Tensor]: + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("▁"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) + + +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*. 
+ tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] + + +def eval_val( + args: Hyperparameters, + model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + grad_accum_steps: int, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, +) -> tuple[float, float]: + # Validation computes two metrics: + # - val_loss: token cross-entropy (natural log) + # - val_bpb: tokenizer-agnostic compression metric used by the challenge + local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) + if local_batch_tokens < args.train_seq_len: + raise ValueError( + "VAL_BATCH_SIZE must provide at least one sequence per rank; " + f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " + f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}" + ) + local_batch_seqs = local_batch_tokens // args.train_seq_len + total_seqs = (val_tokens.numel() - 1) // args.train_seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + + model.eval() + with torch.inference_mode(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * args.train_seq_len + raw_end = batch_seq_end * args.train_seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, args.train_seq_len) + y = local[1:].reshape(-1, args.train_seq_len) + with 
torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + batch_loss = model(x, y).detach() + batch_token_count = float(y.numel()) + val_loss_sum += batch_loss.to(torch.float64) * batch_token_count + val_token_count += batch_token_count + prev_ids = x.reshape(-1) + tgt_ids = y.reshape(-1) + token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + + val_loss = val_loss_sum / val_token_count + bits_per_token = val_loss.item() / math.log(2.0) + tokens_per_byte = val_token_count.item() / val_byte_count.item() + model.train() + return float(val_loss.item()), float(bits_per_token * tokens_per_byte) + +# ----------------------------- +# POST-TRAINING QUANTIZATION +# ----------------------------- +# +# It's silly to export our model, which is trained in bf16 and fp32, at that same precision. +# Instead, we get approximately the same model (with a small hit) by quantizing the model to int8 and zlib-compressing it. +# We can then decompress the model and run it in higher precision for evaluation, after coming in under the size limit.
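A minimal standalone sketch of that export path, using numpy and zlib so it runs without torch; the per-row percentile clip from `quantize_float_tensor` is simplified here to a plain abs-max (an assumption for brevity, not the script's exact scheme).

```python
import zlib
import numpy as np

def quant_dequant_roundtrip(w: np.ndarray) -> np.ndarray:
    # Symmetric per-row int8 quantization: abs-max clip, one scale per row.
    clip = np.abs(w).max(axis=1, keepdims=True)
    scale = np.maximum(clip / 127.0, 1.0 / 127.0)
    q = np.clip(np.round(w / scale), -127, 127).astype(np.int8)
    # Compress the int8 payload (the artifact), then decompress and dequantize.
    blob = zlib.compress(q.tobytes(), 9)
    q_back = np.frombuffer(zlib.decompress(blob), dtype=np.int8).reshape(w.shape)
    return q_back.astype(np.float32) * scale

w = np.array([[0.5, -1.0, 0.25]], dtype=np.float32)
w_hat = quant_dequant_roundtrip(w)
# The zlib stage is lossless, so reconstruction error comes only from rounding
# and is bounded by one quantization step per row.
print(float(np.abs(w - w_hat).max()) <= 1.0 / 127.0)  # True
```

The "small hit" in the comment above is exactly this rounding error; the roundtrip eval at the end of `main` measures its effect on validation loss.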
+ +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights", + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 +INT5_ODD_LEVELS = (-2, -1, 0, 1, 2) + +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) + + +def parse_factor_shape(spec: str, dim: int) -> tuple[int, int]: + if spec: + vals = tuple(int(v.strip()) for v in spec.split(",") if v.strip()) + if len(vals) != 2 or vals[0] * vals[1] != dim: + raise ValueError(f"Invalid factor shape '{spec}' for dim={dim}") + return vals + for a in range(int(math.isqrt(dim)), 0, -1): + if dim % a == 0: + return a, dim // a + return 1, dim + +def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: + if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): + return t.float().contiguous() + if t.dtype in {torch.float32, torch.bfloat16}: + passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") + return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() + return t + +def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + # Matrices get one scale per row, which usually tracks output-channel + # ranges much better than a single tensor-wide scale. 
+ clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) + q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() + return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + + # Vectors / scalars use a simpler per-tensor scale. + clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() + return q, scale + + +def pack_base5_codes(q: Tensor) -> tuple[bytes, int]: + vals = q.detach().to("cpu", dtype=torch.uint8).reshape(-1).numpy() + n_vals = int(vals.size) + pad = (3 - (n_vals % 3)) % 3 + if pad: + vals = np.pad(vals, (0, pad)) + groups = vals.reshape(-1, 3).astype(np.uint8) + packed = groups[:, 0] + 5 * groups[:, 1] + 25 * groups[:, 2] + return packed.tobytes(), n_vals + + +def unpack_base5_codes(data: bytes, n_vals: int) -> Tensor: + packed = np.frombuffer(data, dtype=np.uint8).astype(np.int16) + out = np.zeros((packed.size, 3), dtype=np.uint8) + for i in range(3): + out[:, i] = packed % 5 + packed //= 5 + flat = out.reshape(-1)[:n_vals] + return torch.from_numpy(flat.astype(np.uint8, copy=False)) + + +def quantize_float_tensor_int5_odd(t: Tensor) -> tuple[Tensor, Tensor]: + if t.ndim != 2: + raise ValueError("int5_odd quantization only supports 2D tensors") + t32 = t.float() + clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clip_abs = clip_abs.clamp_min(1e-6) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 
2.0).clamp_min(1.0 / 2.0e6) + q = torch.clamp(torch.round(clipped / scale[:, None]), -2, 2).to(torch.int8).contiguous() + return (q + 2).to(torch.uint8).contiguous(), scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + +def quantize_state_dict_int8(state_dict: dict[str, Tensor]): + # Single supported clean-script export format: + # - per-row int8 for 2D float tensors + # - per-tensor int8 for other float tensors + # - exact passthrough for non-floats + # - passthrough for small float tensors, stored as fp16 to save bytes + quantized: dict[str, Tensor] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + qmeta: dict[str, dict[str, object]] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + + # Small float tensors are cheap enough to keep directly. We still downcast + # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size. 
+ if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + + stats["num_float_tensors"] += 1 + q, s = quantize_float_tensor(t) + if s.ndim > 0: + qmeta[name] = {"scheme": "per_row", "axis": 0} + quantized[name] = q + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) + + obj: dict[str, object] = { + "__quant_format__": "int8_clean_per_row_v1", + "quantized": quantized, + "scales": scales, + "dtypes": dtypes, + "passthrough": passthrough, + } + if qmeta: + obj["qmeta"] = qmeta + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats + + +def quantize_state_dict_int5_odd(state_dict: dict[str, Tensor]): + packed: dict[str, bytes] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + shapes: dict[str, tuple[int, ...]] = {} + orig_shapes: dict[str, tuple[int, ...]] = {} + n_codes: dict[str, int] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + + quantize_this = t.ndim >= 2 and (t.numel() > INT8_KEEP_FLOAT_MAX_NUMEL or ".mlp." 
in name) + if not quantize_this: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + + stats["num_float_tensors"] += 1 + t2 = t.reshape(t.shape[0], -1) if t.ndim > 2 else t + q, s = quantize_float_tensor_int5_odd(t2) + packed_bytes, count = pack_base5_codes(q) + packed[name] = packed_bytes + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + shapes[name] = tuple(int(v) for v in t2.shape) + orig_shapes[name] = tuple(int(v) for v in t.shape) + n_codes[name] = count + stats["int8_payload_bytes"] += len(packed_bytes) + tensor_nbytes(s) + + obj: dict[str, object] = { + "__quant_format__": "int5_odd_clean_per_row_v1", + "packed": packed, + "scales": scales, + "dtypes": dtypes, + "shapes": shapes, + "orig_shapes": orig_shapes, + "n_codes": n_codes, + "passthrough": passthrough, + } + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats + +def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + qmeta = obj.get("qmeta", {}) + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, q in obj["quantized"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + s = obj["scales"][name] + if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: + s = s.to(dtype=torch.float32) + # Broadcast the saved row scale back across trailing dimensions. + out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() + else: + scale = float(s.item()) + out[name] = (q.float() * scale).to(dtype=dtype).contiguous() + for name, t in obj["passthrough"].items(): + # Restore small tensors, undoing the temporary fp16 storage cast if needed. 
+ out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out + + +def dequantize_state_dict_int5_odd(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, packed in obj["packed"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + shape = tuple(int(v) for v in obj["shapes"][name]) + orig_shape = tuple(int(v) for v in obj.get("orig_shapes", {}).get(name, shape)) + q = unpack_base5_codes(packed, int(obj["n_codes"][name])).to(torch.float32).reshape(shape) - 2.0 + s = obj["scales"][name].to(dtype=torch.float32) + deq = (q * s.view(shape[0], *([1] * (len(shape) - 1)))).to(dtype=dtype) + out[name] = deq.reshape(orig_shape).contiguous() + for name, t in obj["passthrough"].items(): + out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out + + +def quantize_state_dict(state_dict: dict[str, Tensor], quant_scheme: str): + if quant_scheme == "int8": + return quantize_state_dict_int8(state_dict) + if quant_scheme == "int5_odd": + return quantize_state_dict_int5_odd(state_dict) + raise ValueError(f"Unsupported QUANT_SCHEME={quant_scheme}") + + +def dequantize_state_dict(obj: dict[str, object]) -> dict[str, Tensor]: + quant_format = obj.get("__quant_format__") + if quant_format == "int8_clean_per_row_v1": + return dequantize_state_dict_int8(obj) + if quant_format == "int5_odd_clean_per_row_v1": + return dequantize_state_dict_int5_odd(obj) + raise ValueError(f"Unsupported quant format {quant_format}") + + +def compress_blob(data: bytes, compressor: str, level: int) -> bytes: + if compressor == "zlib": + return zlib.compress(data, level=max(0, 
min(level, 9)))
+    if compressor == "zstd":
+        if not _HAS_ZSTD:
+            raise RuntimeError("COMPRESSOR=zstd requires zstandard to be installed")
+        return zstd.ZstdCompressor(level=level).compress(data)
+    raise ValueError(f"Unsupported COMPRESSOR={compressor}")
+
+
+def decompress_blob(data: bytes, compressor: str) -> bytes:
+    if compressor == "zlib":
+        return zlib.decompress(data)
+    if compressor == "zstd":
+        if not _HAS_ZSTD:
+            raise RuntimeError("COMPRESSOR=zstd requires zstandard to be installed")
+        return zstd.ZstdDecompressor().decompress(data)
+    raise ValueError(f"Unsupported COMPRESSOR={compressor}")
+
+
+def should_rate_regularize_tensor(name: str, p: Tensor, target: str) -> bool:
+    if not p.requires_grad or not p.is_floating_point() or p.ndim < 2:
+        return False
+    if any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS):
+        return False
+    if target == "mlp_only":
+        return ".mlp." in name
+    if target == "all_matrices":
+        return p.numel() > INT8_KEEP_FLOAT_MAX_NUMEL or ".mlp." in name
+    raise ValueError(f"Unsupported RATE_TARGET_TENSORS={target}")
+
+
+def estimate_tensor_rate_proxy(t: Tensor, temperature: float) -> tuple[Tensor, Tensor, Tensor]:
+    t2 = t.float()
+    if t2.ndim > 2:
+        t2 = t2.reshape(t2.shape[0], -1)
+    elif t2.ndim == 1:
+        t2 = t2.reshape(1, -1)
+    row_scale = t2.abs().mean(dim=1, keepdim=True).clamp_min(1e-6)
+    normalized = torch.clamp(t2 / row_scale, -2.5, 2.5)
+    centers = t2.new_tensor(INT5_ODD_LEVELS)
+    logits = -temperature * (normalized.unsqueeze(-1) - centers) ** 2
+    probs = torch.softmax(logits, dim=-1)
+    hist = probs.mean(dim=(0, 1))
+    entropy_nats = -(hist * (hist + 1e-8).log()).sum()
+    entropy_bits = entropy_nats / math.log(2.0)
+    scale_proxy = row_scale.mean()
+    est_bits = entropy_bits * float(t2.numel())
+    return entropy_bits, scale_proxy, est_bits
+
+
+# -----------------------------
+# DATA LOADING
+# -----------------------------
+
+def load_data_shard(file: Path) -> Tensor:
+    # Reconstructed from the surrounding usage; assumes the llm.c-style shard
+    # layout of a 256-int32 little-endian header (magic, version, num_tokens)
+    # followed by uint16 token ids.
+    header_bytes = 256 * np.dtype("<i4").itemsize
+    with file.open("rb") as f:
+        header = np.frombuffer(f.read(header_bytes), dtype="<i4")
+        num_tokens = int(header[2])
+        tokens = np.frombuffer(f.read(num_tokens * 2), dtype="<u2")
+    return torch.from_numpy(tokens.astype(np.int32, copy=False))
+
+
+class TokenStream:
+    # Reconstructed: cycles through the shard files matching `pattern`,
+    # exposing a contiguous token stream via take().
+    def __init__(self, pattern: str):
+        self.files = sorted(Path(pattern).parent.glob(Path(pattern).name))
+        if not self.files:
+            raise FileNotFoundError(f"No token shards match pattern: {pattern}")
+        self.file_idx = 0
+        self.tokens = load_data_shard(self.files[self.file_idx])
+        self.pos = 0
+
+    def _advance_file(self) -> None:
+ 
self.file_idx = (self.file_idx + 1) % len(self.files) + self.tokens = load_data_shard(self.files[self.file_idx]) + self.pos = 0 + + def take(self, n: int) -> Tensor: + chunks: list[Tensor] = [] + remaining = n + while remaining > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: + self._advance_file() + continue + k = min(remaining, avail) + chunks.append(self.tokens[self.pos : self.pos + k]) + self.pos += k + remaining -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) + + +class DistributedTokenLoader: + # Each call consumes a contiguous chunk from the shared token stream, then slices out + # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting. + def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank = rank + self.world_size = world_size + self.device = device + self.stream = TokenStream(pattern) + + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + chunk = self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start : start + per_rank_span].to(dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + +# ----------------------------- +# TRANSFORMER MODULES +# ----------------------------- + +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + + +class CastedLinear(nn.Linear): + # Keep weights in fp32 for optimizer/state quality, cast at matmul time for bf16 compute. 
+ def forward(self, x: Tensor) -> Tensor: + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, self.weight.to(x.dtype), bias) + + +class StructuredLinear(nn.Module): + # Lightweight TT-style structured matrix used only in the MLP stack. + def __init__( + self, + in_features: int, + out_features: int, + bias: bool, + linear_type: str, + btt_rank: int, + in_shape_spec: str, + out_shape_spec: str, + btt_init: str, + btt_init_gain: float, + ): + super().__init__() + self.in_features = in_features + self.out_features = out_features + self.linear_type = linear_type + self.rank = max(int(btt_rank), 1) + self.btt_init = btt_init + self.btt_init_gain = btt_init_gain + if linear_type == "dense": + self.impl = CastedLinear(in_features, out_features, bias=bias) + self.register_parameter("core1", None) + self.register_parameter("core2", None) + return + if linear_type != "btt_mlp": + raise ValueError(f"Unsupported STRUCTURED_LINEAR_TYPE={linear_type}") + self.impl = None + self.in_shape = parse_factor_shape(in_shape_spec, in_features) + self.out_shape = parse_factor_shape(out_shape_spec, out_features) + n1, n2 = self.in_shape + m1, m2 = self.out_shape + self.core1 = nn.Parameter(torch.empty(n1, m1, self.rank)) + self.core2 = nn.Parameter(torch.empty(self.rank, n2, m2)) + if bias: + self.bias = nn.Parameter(torch.zeros(out_features)) + else: + self.register_parameter("bias", None) + self.reset_parameters() + + def reset_parameters(self) -> None: + if self.impl is not None: + self.impl.reset_parameters() + return + n1, n2 = self.in_shape + if self.btt_init == "mup": + # muP-inspired scaling: keep the per-rank branch update stable while + # preserving O(1) output variance under the sum over TT ranks. 
+ core1_std = self.btt_init_gain / math.sqrt(max(n1, 1)) + core2_std = self.btt_init_gain / math.sqrt(max(self.rank * n2, 1)) + elif self.btt_init == "xavier": + std = self.btt_init_gain / math.sqrt(max(self.in_features, 1)) + core1_std = std / math.sqrt(self.rank) + core2_std = std + else: + raise ValueError(f"Unsupported BTT_INIT={self.btt_init}") + nn.init.normal_(self.core1, mean=0.0, std=core1_std) + nn.init.normal_(self.core2, mean=0.0, std=core2_std) + if self.bias is not None: + nn.init.zeros_(self.bias) + + def forward(self, x: Tensor) -> Tensor: + if self.impl is not None: + return self.impl(x) + x_dtype = x.dtype + orig_shape = x.shape[:-1] + x2 = x.reshape(-1, self.in_shape[0], self.in_shape[1]) + core1 = self.core1.to(dtype=x_dtype) + core2 = self.core2.to(dtype=x_dtype) + y = x2.new_zeros((x2.shape[0], self.out_shape[0], self.out_shape[1])) + for r in range(self.rank): + left = torch.einsum("buv,um->bmv", x2, core1[:, :, r]) + y = y + torch.einsum("bmv,vn->bmn", left, core2[r]) + y = y.reshape(*orig_shape, self.out_features) + if self.bias is not None: + y = y + self.bias.to(dtype=y.dtype) + return y + + +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + # Keep small/control parameters in fp32 even when the model body runs in bf16. + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() + + +class Rotary(nn.Module): + # Caches cos/sin tables per sequence length on the current device. 
+ def __init__(self, dim: int, base: float = 10000.0): + super().__init__() + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached != seq_len + or self._cos_cached.device != device + ): + t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) + freqs = torch.outer(t, self.inv_freq.to(device)) + self._cos_cached = freqs.cos()[None, None, :, :] + self._sin_cached = freqs.sin()[None, None, :, :] + self._seq_len_cached = seq_len + cos = self._cos_cached.to(dtype=dtype) + sin = self._sin_cached.to(dtype=dtype) + if not torch.is_inference_mode_enabled(): + cos = cos.clone() + sin = sin.clone() + return cos, sin + + +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + + +class CausalSelfAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + kv_dim = self.num_kv_heads * self.head_dim + self.c_q = CastedLinear(dim, dim, bias=False) + self.c_k = CastedLinear(dim, kv_dim, bias=False) + self.c_v = CastedLinear(dim, kv_dim, bias=False) + self.proj 
= CastedLinear(dim, dim, bias=False)
+        self.proj._zero_init = True
+        self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32))
+        self.rotary = Rotary(self.head_dim, base=rope_base)
+
+    def forward(self, x: Tensor) -> Tensor:
+        bsz, seqlen, dim = x.shape
+        q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2)
+        k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2)
+        v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2)
+        q = F.rms_norm(q, (q.size(-1),))
+        k = F.rms_norm(k, (k.size(-1),))
+        cos, sin = self.rotary(seqlen, x.device, q.dtype)
+        q = apply_rotary_emb(q, cos, sin)
+        k = apply_rotary_emb(k, cos, sin)
+        q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None]
+        y = F.scaled_dot_product_attention(
+            q,
+            k,
+            v,
+            attn_mask=None,
+            is_causal=True,
+            enable_gqa=(self.num_kv_heads != self.num_heads),
+        )
+        y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim)
+        return self.proj(y)
+
+
+class MLP(nn.Module):
+    # relu^2 MLP from the original modded-nanogpt setup
+    def __init__(
+        self,
+        dim: int,
+        mlp_mult: int,
+        structured_mlp: bool,
+        structured_linear_type: str,
+        btt_rank: int,
+        btt_in_shape: str,
+        btt_out_shape: str,
+        btt_init: str,
+        btt_init_gain: float,
+        compile_structured_mlp: bool,
+    ):
+        super().__init__()
+        hidden = mlp_mult * dim
+        if structured_mlp:
+            self.fc = StructuredLinear(
+                dim,
+                hidden,
+                bias=False,
+                linear_type=structured_linear_type,
+                btt_rank=btt_rank,
+                in_shape_spec=btt_in_shape,
+                out_shape_spec=btt_out_shape,
+                btt_init=btt_init,
+                btt_init_gain=btt_init_gain,
+            )
+            self.proj = StructuredLinear(
+                hidden,
+                dim,
+                bias=False,
+                linear_type=structured_linear_type,
+                btt_rank=btt_rank,
+                in_shape_spec=btt_out_shape,
+                out_shape_spec=btt_in_shape,
+                btt_init=btt_init,
+                btt_init_gain=btt_init_gain,
+            )
+            # Mark the output projection for zero-init (matching the dense path
+            # and the StructuredLinear handling in GPT._init_weights); set this
+            # before any torch.compile wrapping so the flag lives on the module.
+            self.proj._zero_init = True
+            if compile_structured_mlp:
+                self.fc = torch.compile(self.fc, mode="reduce-overhead", 
fullgraph=False, dynamic=False) + self.proj = torch.compile(self.proj, mode="reduce-overhead", fullgraph=False, dynamic=False) + else: + self.fc = CastedLinear(dim, hidden, bias=False) + self.proj = CastedLinear(hidden, dim, bias=False) + self.proj._zero_init = True + + def forward(self, x: Tensor) -> Tensor: + x = torch.relu(self.fc(x)) + return self.proj(x.square()) + + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + rope_base: float, + qk_gain_init: float, + structured_mlp: bool, + structured_linear_type: str, + btt_rank: int, + btt_in_shape: str, + btt_out_shape: str, + btt_init: str, + btt_init_gain: float, + compile_structured_mlp: bool, + ): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) + self.mlp = MLP( + dim, + mlp_mult, + structured_mlp, + structured_linear_type, + btt_rank, + btt_in_shape, + btt_out_shape, + btt_init, + btt_init_gain, + compile_structured_mlp, + ) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + + def forward(self, x: Tensor, x0: Tensor) -> Tensor: + mix = self.resid_mix.to(dtype=x.dtype) + x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + attn_out = self.attn(self.attn_norm(x)) + x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out + x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) + return x + + +class GPT(nn.Module): + def __init__( + self, + vocab_size: int, + num_layers: int, + model_dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + tie_embeddings: bool, + tied_embed_init_std: float, + logit_softcap: float, + rope_base: float, + qk_gain_init: float, + structured_mlp: bool, + 
structured_linear_type: str, + btt_rank: int, + btt_in_shape: str, + btt_out_shape: str, + btt_init: str, + btt_init_gain: float, + compile_structured_mlp: bool, + rate_target_tensors: str, + rate_temperature: float, + ): + super().__init__() + if logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") + self.tie_embeddings = tie_embeddings + self.tied_embed_init_std = tied_embed_init_std + self.logit_softcap = logit_softcap + self.rate_target_tensors = rate_target_tensors + self.rate_temperature = rate_temperature + self.tok_emb = nn.Embedding(vocab_size, model_dim) + self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + self.blocks = nn.ModuleList( + [ + Block( + model_dim, + num_heads, + num_kv_heads, + mlp_mult, + rope_base, + qk_gain_init, + structured_mlp, + structured_linear_type, + btt_rank, + btt_in_shape, + btt_out_shape, + btt_init, + btt_init_gain, + compile_structured_mlp, + ) + for i in range(num_layers) + ] + ) + self.final_norm = RMSNorm() + self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) + if self.lm_head is not None: + self.lm_head._zero_init = True + self._init_weights() + + def _init_weights(self) -> None: + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + for module in self.modules(): + if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + if isinstance(module, StructuredLinear) and getattr(module, "_zero_init", False): + if module.impl is not None: + nn.init.zeros_(module.impl.weight) + else: + nn.init.zeros_(module.core2) + + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + x = 
self.tok_emb(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x0 = x + skips: list[Tensor] = [] + + # First half stores skips; second half reuses them in reverse order. + for i in range(self.num_encoder_layers): + x = self.blocks[i](x, x0) + skips.append(x) + for i in range(self.num_decoder_layers): + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + x = self.blocks[self.num_encoder_layers + i](x, x0) + + x = self.final_norm(x).reshape(-1, x.size(-1)) + targets = target_ids.reshape(-1) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(x) + logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + return F.cross_entropy(logits.float(), targets, reduction="mean") + + def rate_regularizer(self) -> tuple[Tensor, Tensor, Tensor]: + total_symbols = 0.0 + weighted_rate = None + weighted_scale = None + estimated_bits = None + for name, param in self.named_parameters(): + if not should_rate_regularize_tensor(name, param, self.rate_target_tensors): + continue + rate_proxy, scale_proxy, tensor_bits = estimate_tensor_rate_proxy(param, self.rate_temperature) + weight = float(param.numel()) + total_symbols += weight + if weighted_rate is None: + weighted_rate = rate_proxy * weight + weighted_scale = scale_proxy * weight + estimated_bits = tensor_bits + else: + weighted_rate = weighted_rate + rate_proxy * weight + weighted_scale = weighted_scale + scale_proxy * weight + estimated_bits = estimated_bits + tensor_bits + if weighted_rate is None or weighted_scale is None or estimated_bits is None: + zero = self.tok_emb.weight.new_zeros(()) + return zero, zero, zero + denom = max(total_symbols, 1.0) + return weighted_rate / denom, weighted_scale / denom, estimated_bits + + +# ----------------------------- +# TRAINING +# ----------------------------- + +def main() -> None: 
+ global zeropower_via_newtonschulz5 + + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + if args.enable_compile: + zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) + + # ----------------------------- + # DISTRIBUTED + CUDA SETUP + # ----------------------------- + + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if 8 % world_size != 0: + raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") + grad_accum_steps = 8 // world_size + grad_scale = 1.0 / grad_accum_steps + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 0 + + # Fast math knobs + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + + logfile = None + if master_process: + if args.run_id == "train": + logfile = "train.log" + else: + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" + print(logfile) + + def log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch 
{torch.__version__}", console=False) + log0( + subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, + console=False, + ) + log0("=" * 100, console=False) + + # ----------------------------- + # TOKENIZER + VALIDATION METRIC SETUP + # ----------------------------- + + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) + if args.val_token_limit > 0: + usable = min(((args.val_token_limit) // args.train_seq_len) * args.train_seq_len, val_tokens.numel() - 1) + if usable <= 0: + raise ValueError(f"VAL_TOKEN_LIMIT={args.val_token_limit} is too small for TRAIN_SEQ_LEN={args.train_seq_len}") + val_tokens = val_tokens[: usable + 1].contiguous() + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + if args.val_token_limit > 0: + log0(f"val_token_limit:{args.val_token_limit}") + + # ----------------------------- + # MODEL + OPTIMIZER SETUP + # ----------------------------- + + base_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + 
model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, + structured_mlp=args.structured_mlp, + structured_linear_type=args.structured_linear_type if args.structured_mlp else "dense", + btt_rank=args.btt_rank, + btt_in_shape=args.btt_in_shape, + btt_out_shape=args.btt_out_shape, + btt_init=args.btt_init, + btt_init_gain=args.btt_init_gain, + compile_structured_mlp=args.compile_structured_mlp, + rate_target_tensors=args.rate_target_tensors, + rate_temperature=args.rate_temperature, + ).to(device).bfloat16() + for module in base_model.modules(): + if isinstance(module, (CastedLinear, StructuredLinear)): + module.float() + restore_low_dim_params_to_fp32(base_model) + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) if args.enable_compile else base_model + model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + + # Optimizer split: + # - token embedding (Adam) uses EMBED_LR + # - untied lm_head (Adam) uses HEAD_LR + # - matrix params in transformer blocks use MATRIX_LR via Muon + # - vectors/scalars use SCALAR_LR via Adam + block_named_params = list(base_model.blocks.named_parameters()) + matrix_params = [ + p + for name, p in block_named_params + if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + scalar_params = [ + p + for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + optimizer_tok = torch.optim.Adam( + [{"params": [base_model.tok_emb.weight], "lr": token_lr, 
"base_lr": token_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizer_muon = Muon( + matrix_params, + lr=args.matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + ) + for group in optimizer_muon.param_groups: + group["base_lr"] = args.matrix_lr + optimizer_scalar = torch.optim.Adam( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers.insert(1, optimizer_head) + + n_params = sum(p.numel() for p in base_model.parameters()) + log0(f"model_params:{n_params}") + log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") + log0(f"enable_compile:{args.enable_compile}") + log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") + log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") + log0( + f"quant_scheme:{args.quant_scheme} compressor:{args.compressor} level:{args.compressor_level} " + f"structured_mlp:{args.structured_mlp} structured_linear_type:{args.structured_linear_type} " + f"btt_rank:{args.btt_rank} btt_init:{args.btt_init} compile_structured_mlp:{args.compile_structured_mlp}" + ) + log0( + f"rate_lambda:{args.rate_lambda} scale_lambda:{args.scale_lambda} " + f"rate_warmup_start_fraction:{args.rate_warmup_start_fraction} rate_full_fraction:{args.rate_full_fraction} " + f"rate_target_tensors:{args.rate_target_tensors}" + ) + log0( + f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " + f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " + f"matrix_lr:{args.matrix_lr} 
scalar_lr:{args.scalar_lr}" + ) + log0( + f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " + f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " + f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log0(f"seed:{args.seed}") + + # ----------------------------- + # DATA LOADER & MODEL WARMUP + # ----------------------------- + + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + + def lr_mul(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + + def rate_schedule(step: int) -> float: + if args.rate_lambda <= 0.0 and args.scale_lambda <= 0.0: + return 0.0 + if args.iterations <= 0: + return 1.0 + frac = step / max(args.iterations, 1) + start = args.rate_warmup_start_fraction + full = max(args.rate_full_fraction, start + 1e-6) + if frac <= start: + return 0.0 + if frac >= full: + return 1.0 + return (frac - start) / (full - start) + + # Warmup primes the compiled forward/backward/optimizer paths, then we restore the + # initial weights/optimizer state so measured training starts from the true init. 
+ if args.warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(args.warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + for opt in optimizers: + opt.step() + zero_grad_all() + if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: + log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + if distributed: + model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + # ----------------------------- + # MAIN TRAINING LOOP + # ----------------------------- + + training_time_ms = 0.0 + stop_after_step: int | None = None + torch.cuda.synchronize() + t0 = time.perf_counter() + + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + log0( + 
f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" + ) + break + + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + zero_grad_all() + train_loss = torch.zeros((), device=device) + train_ce_loss = torch.zeros((), device=device) + train_rate_proxy = torch.zeros((), device=device) + train_scale_proxy = torch.zeros((), device=device) + train_est_bits = torch.zeros((), device=device) + rate_mul = rate_schedule(step) + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + ce_loss = model(x, y) + total_loss = ce_loss + if rate_mul > 0.0: + rate_proxy, scale_proxy, est_bits = base_model.rate_regularizer() + total_loss = total_loss + rate_mul * (args.rate_lambda * rate_proxy + args.scale_lambda * scale_proxy) + train_rate_proxy += rate_proxy.detach() + train_scale_proxy += scale_proxy.detach() + train_est_bits += est_bits.detach() + train_ce_loss += ce_loss.detach() + train_loss += total_loss.detach() + (total_loss * grad_scale).backward() + train_loss /= grad_accum_steps + train_ce_loss /= grad_accum_steps + train_rate_proxy /= grad_accum_steps + train_scale_proxy /= grad_accum_steps + train_est_bits /= grad_accum_steps + + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * 
args.muon_momentum
+        for group in optimizer_muon.param_groups:
+            group["momentum"] = muon_momentum
+
+        for opt in optimizers:
+            for group in opt.param_groups:
+                group["lr"] = group["base_lr"] * scale
+
+        if args.grad_clip_norm > 0:
+            torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm)
+        for opt in optimizers:
+            opt.step()
+        zero_grad_all()
+
+        step += 1
+        approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0)
+        should_log_train = (
+            args.train_log_every > 0
+            and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None)
+        )
+        if should_log_train:
+            log0(
+                f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} ce_loss:{train_ce_loss.item():.4f} "
+                f"rate_mul:{rate_mul:.3f} rate_proxy_bits:{train_rate_proxy.item():.4f} "
+                f"scale_proxy:{train_scale_proxy.item():.6f} est_model_bits:{train_est_bits.item():.0f} "
+                f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms"
+            )
+
+        # Needed to sync whether we've reached the wallclock cap.
+        reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms
+        if distributed and max_wallclock_ms is not None:
+            reached_cap_tensor = torch.tensor(int(reached_cap), device=device)
+            dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX)
+            reached_cap = bool(reached_cap_tensor.item())
+        if stop_after_step is None and reached_cap:
+            stop_after_step = step
+
+    log0(
+        f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB "
+        f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB"
+    )
+
+    # -----------------------------
+    # SERIALIZATION + ROUNDTRIP VALIDATION
+    # -----------------------------
+    # Save the raw state (useful for debugging/loading in PyTorch directly), then always produce
+    # the compressed artifact for the configured QUANT_SCHEME + COMPRESSOR (int5_odd + zlib by
+    # default here) and validate the round-tripped weights.
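For intuition, the pack-and-compress leg of this export can be exercised in isolation. The sketch below is an illustrative, stdlib-only analogue of the script's `pack_base5_codes`/`unpack_base5_codes` pair plus the zlib stage; the code values and helper names here are for the sketch only, not the real exporter:

```python
import zlib

def pack_base5(codes):
    # codes are ints in 0..4 (int5_odd levels {-2..2} offset by +2); three fit per byte.
    padded = codes + [0] * (-len(codes) % 3)
    return bytes(a + 5 * b + 25 * c for a, b, c in zip(padded[0::3], padded[1::3], padded[2::3]))

def unpack_base5(data, n):
    # Invert the base-5 packing and drop the padding codes.
    out = []
    for byte in data:
        for _ in range(3):
            out.append(byte % 5)
            byte //= 5
    return out[:n]

codes = [(i * 7) % 5 for i in range(100)]          # stand-in quantized codes
blob = zlib.compress(pack_base5(codes), 9)         # same compressor/level as the default export
assert unpack_base5(zlib.decompress(blob), len(codes)) == codes  # exact roundtrip
```

The roundtrip is exact: all loss comes from quantization, never from packing or compression.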
+ + if master_process: + torch.save(base_model.state_dict(), "final_model.pt") + model_bytes = os.path.getsize("final_model.pt") + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model: {model_bytes} bytes") + log0(f"Code size: {code_bytes} bytes") + log0(f"Total submission size: {model_bytes + code_bytes} bytes") + + quant_obj, quant_stats = quantize_state_dict(base_model.state_dict(), args.quant_scheme) + quant_buf = io.BytesIO() + torch.save(quant_obj, quant_buf) + quant_raw = quant_buf.getvalue() + quant_blob = compress_blob(quant_raw, args.compressor, args.compressor_level) + quant_raw_bytes = len(quant_raw) + quant_filename = f"final_model.{args.quant_scheme}.{args.compressor}.ptc" + if master_process: + with open(quant_filename, "wb") as f: + f.write(quant_blob) + quant_file_bytes = os.path.getsize(quant_filename) + code_bytes = len(code.encode("utf-8")) + ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1) + log0( + f"Serialized model {args.quant_scheme}+{args.compressor}: {quant_file_bytes} bytes " + f"(payload:{quant_stats['int8_payload_bytes']} raw_torch:{quant_raw_bytes} payload_ratio:{ratio:.2f}x)" + ) + log0(f"Total submission size {args.quant_scheme}+{args.compressor}: {quant_file_bytes + code_bytes} bytes") + + if distributed: + dist.barrier() + with open(quant_filename, "rb") as f: + quant_blob_disk = f.read() + quant_state = torch.load(io.BytesIO(decompress_blob(quant_blob_disk, args.compressor)), map_location="cpu") + base_model.load_state_dict(dequantize_state_dict(quant_state), strict=True) + torch.cuda.synchronize() + t_qeval = time.perf_counter() + q_val_loss, q_val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"final_{args.quant_scheme}_{args.compressor}_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + 
f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" + ) + log0(f"final_{args.quant_scheme}_{args.compressor}_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + + if distributed: + dist.destroy_process_group() + + +if __name__ == "__main__": + main() + +==================================================================================================== +Running Python 3.10.12 (main, Mar 3 2026, 11:56:32) [GCC 11.4.0] +Running PyTorch 2.11.0+cu130 +Tue Mar 31 11:57:17 2026 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 590.48.01 Driver Version: 590.48.01 CUDA Version: 13.1 | ++-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. | +|=========================================+========================+======================| +| 0 NVIDIA GeForce RTX 3060 Off | 00000000:01:00.0 On | N/A | +| 41% 48C P0 36W / 170W | 242MiB / 12288MiB | 4% Default | +| | | N/A | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| 0 N/A N/A 985 G /usr/lib/xorg/Xorg 117MiB | +| 0 N/A N/A 3600 C python3 104MiB | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=/home/hz/Desktop/parameter-golf-main/data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:1 +val_loader:shards 
pattern=/home/hz/Desktop/parameter-golf-main/data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:65536 +val_token_limit:65536 +model_params:38310016 +world_size:1 grad_accum_steps:8 +enable_compile:False +sdp_backends:cudnn=False flash=True mem_efficient=False math=False +attention_mode:gqa num_heads:8 num_kv_heads:4 +quant_scheme:int5_odd compressor:zlib level:9 structured_mlp:True structured_linear_type:btt_mlp btt_rank:512 btt_init:mup compile_structured_mlp:False +rate_lambda:0.0002 scale_lambda:0.0002 rate_warmup_start_fraction:0.1 rate_full_fraction:0.4 rate_target_tensors:mlp_only +tie_embeddings:True embed_lr:0.03 head_lr:0.0 matrix_lr:0.02 scalar_lr:0.02 +train_batch_tokens:4096 train_seq_len:256 iterations:0 warmup_steps:0 max_wallclock_seconds:600.000 +seed:1337 +step:0/0 val_loss:6.9408 val_bpb:4.1038 train_time:0ms step_avg:0.04ms +peak memory allocated: 457 MiB reserved: 502 MiB +Serialized model: 152261467 bytes +Code size: 68319 bytes +Total submission size: 152329786 bytes +Serialized model int5_odd+zlib: 5808009 bytes (payload:12991211 raw_torch:13062267 payload_ratio:11.71x) +Total submission size int5_odd+zlib: 5876328 bytes +final_int5_odd_zlib_roundtrip val_loss:6.9463 val_bpb:4.1071 eval_time:87220ms +final_int5_odd_zlib_roundtrip_exact val_loss:6.94626862 val_bpb:4.10705407 diff --git a/records/track_non_record_16mb/2026-03-31_EntropyAware_Int5Odd_BTTMLP_3060/logs/scout_R512_L20.txt b/records/track_non_record_16mb/2026-03-31_EntropyAware_Int5Odd_BTTMLP_3060/logs/scout_R512_L20.txt new file mode 100644 index 000000000..0dccba049 --- /dev/null +++ b/records/track_non_record_16mb/2026-03-31_EntropyAware_Int5Odd_BTTMLP_3060/logs/scout_R512_L20.txt @@ -0,0 +1,1668 @@ +""" +The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good launching-off points for new participants, not SOTA configs. 
We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder.
+
+Hard stop: To keep these scripts readable for newcomers, let's make sure `train_gpt.py` and `train_gpt_mlx.py` are never longer than 1500 lines.
+"""
+
+from __future__ import annotations
+
+import copy
+import glob
+import io
+import math
+import os
+import random
+import subprocess
+import sys
+import time
+import uuid
+import zlib
+from pathlib import Path
+
+import numpy as np
+import sentencepiece as spm
+import torch
+import torch.distributed as dist
+import torch.nn.functional as F
+from torch import Tensor, nn
+from torch.nn.parallel import DistributedDataParallel as DDP
+
+try:
+    import zstandard as zstd
+
+    _HAS_ZSTD = True
+except ImportError:
+    zstd = None
+    _HAS_ZSTD = False
+
+REPO_ROOT = Path(__file__).resolve().parents[3]
+
+# -----------------------------
+# HYPERPARAMETERS
+# -----------------------------
+# Defaults for this script (every field is overridable via environment variables):
+# - 9 transformer blocks at width 512
+# - 8 attention heads with 4 KV heads (GQA) and 2x MLP expansion
+# - vocab size 1024, sequence length 256, tied embeddings
+# - 16,384 train tokens per step under a ~10 minute wallclock cap
+
+class Hyperparameters:
+    # Data paths are shard globs produced by the existing preprocessing pipeline.
+    data_path = os.environ.get("DATA_PATH", str(REPO_ROOT / "data/datasets/fineweb10B_sp1024"))
+    train_files = os.path.join(data_path, "fineweb_train_*.bin")
+    val_files = os.path.join(data_path, "fineweb_val_*.bin")
+    tokenizer_path = os.environ.get("TOKENIZER_PATH", str(REPO_ROOT / "data/tokenizers/fineweb_1024_bpe.model"))
+    run_id = os.environ.get("RUN_ID", "train")
+    seed = int(os.environ.get("SEED", 1337))
+
+    # Validation cadence and batch size. Validation uses the fineweb_val split, truncated when VAL_TOKEN_LIMIT is set above zero.
+ val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 262_144)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 10)) + val_token_limit = int(os.environ.get("VAL_TOKEN_LIMIT", 1_048_576)) + + # Training length. + iterations = int(os.environ.get("ITERATIONS", 40)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 1200)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 0)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 16_384)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 256)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + enable_compile = bool(int(os.environ.get("ENABLE_COMPILE", "0"))) + + # Model shape. + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers = int(os.environ.get("NUM_LAYERS", 9)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = int(os.environ.get("MLP_MULT", 2)) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + + # Optimizer hyperparameters. 
+ embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.03)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.02)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.02)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.0)) + + # Export / compression. + quant_scheme = os.environ.get("QUANT_SCHEME", "int5_odd") + compressor = os.environ.get("COMPRESSOR", "zlib") + compressor_level = int(os.environ.get("COMPRESSOR_LEVEL", 9)) + + # Entropy-aware rate regularization. + rate_lambda = float(os.environ.get("RATE_LAMBDA", 0.002)) + scale_lambda = float(os.environ.get("SCALE_LAMBDA", 0.0002)) + rate_temperature = float(os.environ.get("RATE_TEMPERATURE", 6.0)) + rate_warmup_start_fraction = float(os.environ.get("RATE_WARMUP_START_FRACTION", 0.1)) + rate_full_fraction = float(os.environ.get("RATE_FULL_FRACTION", 0.4)) + rate_target_tensors = os.environ.get("RATE_TARGET_TENSORS", "mlp_only") + + # Structured MLP. 
+ structured_mlp = bool(int(os.environ.get("STRUCTURED_MLP", "1"))) + structured_linear_type = os.environ.get("STRUCTURED_LINEAR_TYPE", "btt_mlp") + btt_rank = int(os.environ.get("BTT_RANK", 32)) + btt_in_shape = os.environ.get("BTT_IN_SHAPE", "") + btt_out_shape = os.environ.get("BTT_OUT_SHAPE", "") + btt_init = os.environ.get("BTT_INIT", "mup") + btt_init_gain = float(os.environ.get("BTT_INIT_GAIN", 1.0)) + compile_structured_mlp = bool(int(os.environ.get("COMPILE_STRUCTURED_MLP", "1"))) + +# ----------------------------- +# MUON OPTIMIZER +# ----------------------------- +# +# As borrowed from modded-nanogpt +# Background on Muon: https://kellerjordan.github.io/posts/muon/ + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + # Orthogonalize a 2D update matrix with a fast Newton-Schulz iteration. + # Muon uses this to normalize matrix-shaped gradients before applying them. + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X + + +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True): + super().__init__( + params, + dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov), + ) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + + for group in self.param_groups: + params = group["params"] + if not params: + continue + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + + total_params = 
sum(int(p.numel()) for p in params)
+            updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16)
+
+            curr = 0
+            for i, p in enumerate(params):
+                if i % world_size == rank and p.grad is not None:
+                    g = p.grad
+                    state = self.state[p]
+                    if "momentum_buffer" not in state:
+                        state["momentum_buffer"] = torch.zeros_like(g)
+                    buf = state["momentum_buffer"]
+                    buf.mul_(momentum).add_(g)
+                    if nesterov:
+                        g = g.add(buf, alpha=momentum)
+                    g = zeropower_via_newtonschulz5(g, steps=backend_steps)
+                    # Scale correction from Muon reference implementations.
+                    g *= max(1, g.size(0) / g.size(1)) ** 0.5
+                    updates_flat[curr : curr + p.numel()] = g.reshape(-1)
+                curr += p.numel()
+
+            if distributed:
+                dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM)
+
+            curr = 0
+            for p in params:
+                g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype)
+                p.add_(g, alpha=-lr)
+                curr += p.numel()
+
+        return loss
+
+
+# -----------------------------
+# TOKENIZER-AGNOSTIC EVALUATION SETUP
+# -----------------------------
+#
+# It's common for small models to have a large fraction of their parameters in embeddings, since the 2 * d_model * d_vocab embedding weights can be gigantic.
+# Instead of locking the tokenizer, we let you bring your own and calculate our validation metrics on the average compression of the validation set.
+# We calculate BPB (bits-per-byte) instead of validation loss, so we need methods to count the number of bytes per token in the tokenizer.
+# Note: Submissions that edit the tokenizer will be examined more carefully, since screwing this up might unjustly improve your score.
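Numerically, the BPB metric this section computes reduces to a two-factor product: bits per token (from the natural-log cross-entropy) times tokens per byte (from the tokenizer lookup tables). A minimal arithmetic sketch, with illustrative numbers:

```python
import math

def bits_per_byte(mean_nat_loss, n_tokens, n_bytes):
    # nats/token -> bits/token, then rescale by tokens/byte so the metric is
    # comparable across tokenizers with different compression rates.
    bits_per_token = mean_nat_loss / math.log(2.0)
    return bits_per_token * (n_tokens / n_bytes)

# A tokenizer averaging 4 bytes/token at a loss of ln(4) nats (2 bits/token)
# compresses the validation text at 0.5 bits per byte.
assert abs(bits_per_byte(math.log(4.0), 1000, 4000) - 0.5) < 1e-9
```

A more aggressive tokenizer (more bytes per token) lowers BPB at the same token loss, which is why the byte counting must be audited rather than trusted.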
+ +def build_sentencepiece_luts( + sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device +) -> tuple[Tensor, Tensor, Tensor]: + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("▁"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) + + +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*. 
+ tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] + + +def eval_val( + args: Hyperparameters, + model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + grad_accum_steps: int, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, +) -> tuple[float, float]: + # Validation computes two metrics: + # - val_loss: token cross-entropy (natural log) + # - val_bpb: tokenizer-agnostic compression metric used by the challenge + local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) + if local_batch_tokens < args.train_seq_len: + raise ValueError( + "VAL_BATCH_SIZE must provide at least one sequence per rank; " + f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " + f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}" + ) + local_batch_seqs = local_batch_tokens // args.train_seq_len + total_seqs = (val_tokens.numel() - 1) // args.train_seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + + model.eval() + with torch.inference_mode(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * args.train_seq_len + raw_end = batch_seq_end * args.train_seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, args.train_seq_len) + y = local[1:].reshape(-1, args.train_seq_len) + with 
torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True):
+                batch_loss = model(x, y).detach()
+            batch_token_count = float(y.numel())
+            val_loss_sum += batch_loss.to(torch.float64) * batch_token_count
+            val_token_count += batch_token_count
+            prev_ids = x.reshape(-1)
+            tgt_ids = y.reshape(-1)
+            token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16)
+            token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16)
+            val_byte_count += token_bytes.to(torch.float64).sum()
+
+    if dist.is_available() and dist.is_initialized():
+        dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM)
+        dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM)
+        dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM)
+
+    val_loss = val_loss_sum / val_token_count
+    bits_per_token = val_loss.item() / math.log(2.0)
+    tokens_per_byte = val_token_count.item() / val_byte_count.item()
+    model.train()
+    return float(val_loss.item()), float(bits_per_token * tokens_per_byte)
+
+# -----------------------------
+# POST-TRAINING QUANTIZATION
+# -----------------------------
+#
+# It's silly to export our model, which is trained in bf16 and fp32, at that same precision.
+# Instead, we get approximately the same model (with a small accuracy hit) by quantizing it to int8 and zlib-compressing the result.
+# We can then decompress the model and run it in higher precision for evaluation, while still coming in under the size limit.
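A minimal numpy sketch of the per-row symmetric int8 roundtrip this section implements; percentile clipping is omitted for brevity, but the `1/127` scale floor mirrors the real code, and the function names are for the sketch only:

```python
import numpy as np

def quantize_rows(w):
    # One scale per output row, floored at 1/127 like the exporter's clamp_min.
    scale = np.maximum(np.abs(w).max(axis=1) / 127.0, 1.0 / 127.0)
    q = np.clip(np.round(w / scale[:, None]), -127, 127).astype(np.int8)
    return q, scale

def dequantize_rows(q, scale):
    return q.astype(np.float32) * scale[:, None]

rng = np.random.default_rng(0)
w = rng.standard_normal((4, 16)).astype(np.float32)
q, scale = quantize_rows(w)
err = np.abs(dequantize_rows(q, scale) - w).max()
assert err <= scale.max() / 2 + 1e-6  # rounding error is bounded by half a quantization step
```

Per-row scales track output-channel ranges, so one outlier row does not inflate the quantization step for every other row the way a single tensor-wide scale would.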
+ +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights", + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 +INT5_ODD_LEVELS = (-2, -1, 0, 1, 2) + +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) + + +def parse_factor_shape(spec: str, dim: int) -> tuple[int, int]: + if spec: + vals = tuple(int(v.strip()) for v in spec.split(",") if v.strip()) + if len(vals) != 2 or vals[0] * vals[1] != dim: + raise ValueError(f"Invalid factor shape '{spec}' for dim={dim}") + return vals + for a in range(int(math.isqrt(dim)), 0, -1): + if dim % a == 0: + return a, dim // a + return 1, dim + +def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: + if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): + return t.float().contiguous() + if t.dtype in {torch.float32, torch.bfloat16}: + passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") + return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() + return t + +def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + # Matrices get one scale per row, which usually tracks output-channel + # ranges much better than a single tensor-wide scale. 
+ clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) + q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() + return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + + # Vectors / scalars use a simpler per-tensor scale. + clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() + return q, scale + + +def pack_base5_codes(q: Tensor) -> tuple[bytes, int]: + vals = q.detach().to("cpu", dtype=torch.uint8).reshape(-1).numpy() + n_vals = int(vals.size) + pad = (3 - (n_vals % 3)) % 3 + if pad: + vals = np.pad(vals, (0, pad)) + groups = vals.reshape(-1, 3).astype(np.uint8) + packed = groups[:, 0] + 5 * groups[:, 1] + 25 * groups[:, 2] + return packed.tobytes(), n_vals + + +def unpack_base5_codes(data: bytes, n_vals: int) -> Tensor: + packed = np.frombuffer(data, dtype=np.uint8).astype(np.int16) + out = np.zeros((packed.size, 3), dtype=np.uint8) + for i in range(3): + out[:, i] = packed % 5 + packed //= 5 + flat = out.reshape(-1)[:n_vals] + return torch.from_numpy(flat.astype(np.uint8, copy=False)) + + +def quantize_float_tensor_int5_odd(t: Tensor) -> tuple[Tensor, Tensor]: + if t.ndim != 2: + raise ValueError("int5_odd quantization only supports 2D tensors") + t32 = t.float() + clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clip_abs = clip_abs.clamp_min(1e-6) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 
2.0).clamp_min(1.0 / 2.0e6) + q = torch.clamp(torch.round(clipped / scale[:, None]), -2, 2).to(torch.int8).contiguous() + return (q + 2).to(torch.uint8).contiguous(), scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + +def quantize_state_dict_int8(state_dict: dict[str, Tensor]): + # Single supported clean-script export format: + # - per-row int8 for 2D float tensors + # - per-tensor int8 for other float tensors + # - exact passthrough for non-floats + # - passthrough for small float tensors, stored as fp16 to save bytes + quantized: dict[str, Tensor] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + qmeta: dict[str, dict[str, object]] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + + # Small float tensors are cheap enough to keep directly. We still downcast + # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size. 
+ if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + + stats["num_float_tensors"] += 1 + q, s = quantize_float_tensor(t) + if s.ndim > 0: + qmeta[name] = {"scheme": "per_row", "axis": 0} + quantized[name] = q + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) + + obj: dict[str, object] = { + "__quant_format__": "int8_clean_per_row_v1", + "quantized": quantized, + "scales": scales, + "dtypes": dtypes, + "passthrough": passthrough, + } + if qmeta: + obj["qmeta"] = qmeta + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats + + +def quantize_state_dict_int5_odd(state_dict: dict[str, Tensor]): + packed: dict[str, bytes] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + shapes: dict[str, tuple[int, ...]] = {} + orig_shapes: dict[str, tuple[int, ...]] = {} + n_codes: dict[str, int] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + + quantize_this = t.ndim >= 2 and (t.numel() > INT8_KEEP_FLOAT_MAX_NUMEL or ".mlp." 
in name) + if not quantize_this: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + + stats["num_float_tensors"] += 1 + t2 = t.reshape(t.shape[0], -1) if t.ndim > 2 else t + q, s = quantize_float_tensor_int5_odd(t2) + packed_bytes, count = pack_base5_codes(q) + packed[name] = packed_bytes + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + shapes[name] = tuple(int(v) for v in t2.shape) + orig_shapes[name] = tuple(int(v) for v in t.shape) + n_codes[name] = count + stats["int8_payload_bytes"] += len(packed_bytes) + tensor_nbytes(s) + + obj: dict[str, object] = { + "__quant_format__": "int5_odd_clean_per_row_v1", + "packed": packed, + "scales": scales, + "dtypes": dtypes, + "shapes": shapes, + "orig_shapes": orig_shapes, + "n_codes": n_codes, + "passthrough": passthrough, + } + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats + +def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + qmeta = obj.get("qmeta", {}) + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, q in obj["quantized"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + s = obj["scales"][name] + if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: + s = s.to(dtype=torch.float32) + # Broadcast the saved row scale back across trailing dimensions. + out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() + else: + scale = float(s.item()) + out[name] = (q.float() * scale).to(dtype=dtype).contiguous() + for name, t in obj["passthrough"].items(): + # Restore small tensors, undoing the temporary fp16 storage cast if needed. 
+ out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out + + +def dequantize_state_dict_int5_odd(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, packed in obj["packed"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + shape = tuple(int(v) for v in obj["shapes"][name]) + orig_shape = tuple(int(v) for v in obj.get("orig_shapes", {}).get(name, shape)) + q = unpack_base5_codes(packed, int(obj["n_codes"][name])).to(torch.float32).reshape(shape) - 2.0 + s = obj["scales"][name].to(dtype=torch.float32) + deq = (q * s.view(shape[0], *([1] * (len(shape) - 1)))).to(dtype=dtype) + out[name] = deq.reshape(orig_shape).contiguous() + for name, t in obj["passthrough"].items(): + out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out + + +def quantize_state_dict(state_dict: dict[str, Tensor], quant_scheme: str): + if quant_scheme == "int8": + return quantize_state_dict_int8(state_dict) + if quant_scheme == "int5_odd": + return quantize_state_dict_int5_odd(state_dict) + raise ValueError(f"Unsupported QUANT_SCHEME={quant_scheme}") + + +def dequantize_state_dict(obj: dict[str, object]) -> dict[str, Tensor]: + quant_format = obj.get("__quant_format__") + if quant_format == "int8_clean_per_row_v1": + return dequantize_state_dict_int8(obj) + if quant_format == "int5_odd_clean_per_row_v1": + return dequantize_state_dict_int5_odd(obj) + raise ValueError(f"Unsupported quant format {quant_format}") + + +def compress_blob(data: bytes, compressor: str, level: int) -> bytes: + if compressor == "zlib": + return zlib.compress(data, level=max(0, 
min(level, 9))) + if compressor == "zstd": + if not _HAS_ZSTD: + raise RuntimeError("COMPRESSOR=zstd requires zstandard to be installed") + return zstd.ZstdCompressor(level=level).compress(data) + raise ValueError(f"Unsupported COMPRESSOR={compressor}") + + +def decompress_blob(data: bytes, compressor: str) -> bytes: + if compressor == "zlib": + return zlib.decompress(data) + if compressor == "zstd": + if not _HAS_ZSTD: + raise RuntimeError("COMPRESSOR=zstd requires zstandard to be installed") + return zstd.ZstdDecompressor().decompress(data) + raise ValueError(f"Unsupported COMPRESSOR={compressor}") + + +def should_rate_regularize_tensor(name: str, p: Tensor, target: str) -> bool: + if not p.requires_grad or not p.is_floating_point() or p.ndim < 2: + return False + if any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS): + return False + if target == "mlp_only": + return ".mlp." in name + if target == "all_matrices": + return p.numel() > INT8_KEEP_FLOAT_MAX_NUMEL or ".mlp." in name + raise ValueError(f"Unsupported RATE_TARGET_TENSORS={target}") + + +def estimate_tensor_rate_proxy(t: Tensor, temperature: float) -> tuple[Tensor, Tensor, Tensor]: + t2 = t.float() + if t2.ndim > 2: + t2 = t2.reshape(t2.shape[0], -1) + elif t2.ndim == 1: + t2 = t2.reshape(1, -1) + row_scale = t2.abs().mean(dim=1, keepdim=True).clamp_min(1e-6) + normalized = torch.clamp(t2 / row_scale, -2.5, 2.5) + centers = t2.new_tensor(INT5_ODD_LEVELS) + logits = -temperature * (normalized.unsqueeze(-1) - centers) ** 2 + probs = torch.softmax(logits, dim=-1) + hist = probs.mean(dim=(0, 1)) + entropy_nats = -(hist * (hist + 1e-8).log()).sum() + entropy_bits = entropy_nats / math.log(2.0) + scale_proxy = row_scale.mean() + est_bits = entropy_bits * float(t2.numel()) + return entropy_bits, scale_proxy, est_bits + + +# ----------------------------- +# DATA LOADING +# ----------------------------- + +def load_data_shard(file: Path) -> Tensor: + header_bytes = 256 * np.dtype(" None: + 
self.file_idx = (self.file_idx + 1) % len(self.files) + self.tokens = load_data_shard(self.files[self.file_idx]) + self.pos = 0 + + def take(self, n: int) -> Tensor: + chunks: list[Tensor] = [] + remaining = n + while remaining > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: + self._advance_file() + continue + k = min(remaining, avail) + chunks.append(self.tokens[self.pos : self.pos + k]) + self.pos += k + remaining -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) + + +class DistributedTokenLoader: + # Each call consumes a contiguous chunk from the shared token stream, then slices out + # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting. + def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank = rank + self.world_size = world_size + self.device = device + self.stream = TokenStream(pattern) + + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + chunk = self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start : start + per_rank_span].to(dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + +# ----------------------------- +# TRANSFORMER MODULES +# ----------------------------- + +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + + +class CastedLinear(nn.Linear): + # Keep weights in fp32 for optimizer/state quality, cast at matmul time for bf16 compute. 
+ def forward(self, x: Tensor) -> Tensor: + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, self.weight.to(x.dtype), bias) + + +class StructuredLinear(nn.Module): + # Lightweight TT-style structured matrix used only in the MLP stack. + def __init__( + self, + in_features: int, + out_features: int, + bias: bool, + linear_type: str, + btt_rank: int, + in_shape_spec: str, + out_shape_spec: str, + btt_init: str, + btt_init_gain: float, + ): + super().__init__() + self.in_features = in_features + self.out_features = out_features + self.linear_type = linear_type + self.rank = max(int(btt_rank), 1) + self.btt_init = btt_init + self.btt_init_gain = btt_init_gain + if linear_type == "dense": + self.impl = CastedLinear(in_features, out_features, bias=bias) + self.register_parameter("core1", None) + self.register_parameter("core2", None) + return + if linear_type != "btt_mlp": + raise ValueError(f"Unsupported STRUCTURED_LINEAR_TYPE={linear_type}") + self.impl = None + self.in_shape = parse_factor_shape(in_shape_spec, in_features) + self.out_shape = parse_factor_shape(out_shape_spec, out_features) + n1, n2 = self.in_shape + m1, m2 = self.out_shape + self.core1 = nn.Parameter(torch.empty(n1, m1, self.rank)) + self.core2 = nn.Parameter(torch.empty(self.rank, n2, m2)) + if bias: + self.bias = nn.Parameter(torch.zeros(out_features)) + else: + self.register_parameter("bias", None) + self.reset_parameters() + + def reset_parameters(self) -> None: + if self.impl is not None: + self.impl.reset_parameters() + return + n1, n2 = self.in_shape + if self.btt_init == "mup": + # muP-inspired scaling: keep the per-rank branch update stable while + # preserving O(1) output variance under the sum over TT ranks. 
+ core1_std = self.btt_init_gain / math.sqrt(max(n1, 1)) + core2_std = self.btt_init_gain / math.sqrt(max(self.rank * n2, 1)) + elif self.btt_init == "xavier": + std = self.btt_init_gain / math.sqrt(max(self.in_features, 1)) + core1_std = std / math.sqrt(self.rank) + core2_std = std + else: + raise ValueError(f"Unsupported BTT_INIT={self.btt_init}") + nn.init.normal_(self.core1, mean=0.0, std=core1_std) + nn.init.normal_(self.core2, mean=0.0, std=core2_std) + if self.bias is not None: + nn.init.zeros_(self.bias) + + def forward(self, x: Tensor) -> Tensor: + if self.impl is not None: + return self.impl(x) + x_dtype = x.dtype + orig_shape = x.shape[:-1] + x2 = x.reshape(-1, self.in_shape[0], self.in_shape[1]) + core1 = self.core1.to(dtype=x_dtype) + core2 = self.core2.to(dtype=x_dtype) + y = x2.new_zeros((x2.shape[0], self.out_shape[0], self.out_shape[1])) + for r in range(self.rank): + left = torch.einsum("buv,um->bmv", x2, core1[:, :, r]) + y = y + torch.einsum("bmv,vn->bmn", left, core2[r]) + y = y.reshape(*orig_shape, self.out_features) + if self.bias is not None: + y = y + self.bias.to(dtype=y.dtype) + return y + + +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + # Keep small/control parameters in fp32 even when the model body runs in bf16. + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() + + +class Rotary(nn.Module): + # Caches cos/sin tables per sequence length on the current device. 
+ def __init__(self, dim: int, base: float = 10000.0): + super().__init__() + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached != seq_len + or self._cos_cached.device != device + ): + t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) + freqs = torch.outer(t, self.inv_freq.to(device)) + self._cos_cached = freqs.cos()[None, None, :, :] + self._sin_cached = freqs.sin()[None, None, :, :] + self._seq_len_cached = seq_len + cos = self._cos_cached.to(dtype=dtype) + sin = self._sin_cached.to(dtype=dtype) + if not torch.is_inference_mode_enabled(): + cos = cos.clone() + sin = sin.clone() + return cos, sin + + +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + + +class CausalSelfAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + kv_dim = self.num_kv_heads * self.head_dim + self.c_q = CastedLinear(dim, dim, bias=False) + self.c_k = CastedLinear(dim, kv_dim, bias=False) + self.c_v = CastedLinear(dim, kv_dim, bias=False) + self.proj 
= CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rotary = Rotary(self.head_dim, base=rope_base) + + def forward(self, x: Tensor) -> Tensor: + bsz, seqlen, dim = x.shape + q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) + k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin) + k = apply_rotary_emb(k, cos, sin) + q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] + y = F.scaled_dot_product_attention( + q, + k, + v, + attn_mask=None, + is_causal=True, + enable_gqa=(self.num_kv_heads != self.num_heads), + ) + y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) + return self.proj(y) + + +class MLP(nn.Module): + # relu^2 MLP from the original modded-nanogpt setup + def __init__( + self, + dim: int, + mlp_mult: int, + structured_mlp: bool, + structured_linear_type: str, + btt_rank: int, + btt_in_shape: str, + btt_out_shape: str, + btt_init: str, + btt_init_gain: float, + compile_structured_mlp: bool, + ): + super().__init__() + hidden = mlp_mult * dim + if structured_mlp: + self.fc = StructuredLinear( + dim, + hidden, + bias=False, + linear_type=structured_linear_type, + btt_rank=btt_rank, + in_shape_spec=btt_in_shape, + out_shape_spec=btt_out_shape, + btt_init=btt_init, + btt_init_gain=btt_init_gain, + ) + self.proj = StructuredLinear( + hidden, + dim, + bias=False, + linear_type=structured_linear_type, + btt_rank=btt_rank, + in_shape_spec=btt_out_shape, + out_shape_spec=btt_in_shape, + btt_init=btt_init, + btt_init_gain=btt_init_gain, + ) + if compile_structured_mlp: + self.fc = torch.compile(self.fc, mode="reduce-overhead", 
fullgraph=False, dynamic=False) + self.proj = torch.compile(self.proj, mode="reduce-overhead", fullgraph=False, dynamic=False) + else: + self.fc = CastedLinear(dim, hidden, bias=False) + self.proj = CastedLinear(hidden, dim, bias=False) + self.proj._zero_init = True + + def forward(self, x: Tensor) -> Tensor: + x = torch.relu(self.fc(x)) + return self.proj(x.square()) + + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + rope_base: float, + qk_gain_init: float, + structured_mlp: bool, + structured_linear_type: str, + btt_rank: int, + btt_in_shape: str, + btt_out_shape: str, + btt_init: str, + btt_init_gain: float, + compile_structured_mlp: bool, + ): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) + self.mlp = MLP( + dim, + mlp_mult, + structured_mlp, + structured_linear_type, + btt_rank, + btt_in_shape, + btt_out_shape, + btt_init, + btt_init_gain, + compile_structured_mlp, + ) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + + def forward(self, x: Tensor, x0: Tensor) -> Tensor: + mix = self.resid_mix.to(dtype=x.dtype) + x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + attn_out = self.attn(self.attn_norm(x)) + x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out + x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) + return x + + +class GPT(nn.Module): + def __init__( + self, + vocab_size: int, + num_layers: int, + model_dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + tie_embeddings: bool, + tied_embed_init_std: float, + logit_softcap: float, + rope_base: float, + qk_gain_init: float, + structured_mlp: bool, + 
structured_linear_type: str, + btt_rank: int, + btt_in_shape: str, + btt_out_shape: str, + btt_init: str, + btt_init_gain: float, + compile_structured_mlp: bool, + rate_target_tensors: str, + rate_temperature: float, + ): + super().__init__() + if logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") + self.tie_embeddings = tie_embeddings + self.tied_embed_init_std = tied_embed_init_std + self.logit_softcap = logit_softcap + self.rate_target_tensors = rate_target_tensors + self.rate_temperature = rate_temperature + self.tok_emb = nn.Embedding(vocab_size, model_dim) + self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + self.blocks = nn.ModuleList( + [ + Block( + model_dim, + num_heads, + num_kv_heads, + mlp_mult, + rope_base, + qk_gain_init, + structured_mlp, + structured_linear_type, + btt_rank, + btt_in_shape, + btt_out_shape, + btt_init, + btt_init_gain, + compile_structured_mlp, + ) + for i in range(num_layers) + ] + ) + self.final_norm = RMSNorm() + self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) + if self.lm_head is not None: + self.lm_head._zero_init = True + self._init_weights() + + def _init_weights(self) -> None: + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + for module in self.modules(): + if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + if isinstance(module, StructuredLinear) and getattr(module, "_zero_init", False): + if module.impl is not None: + nn.init.zeros_(module.impl.weight) + else: + nn.init.zeros_(module.core2) + + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + x = 
self.tok_emb(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x0 = x + skips: list[Tensor] = [] + + # First half stores skips; second half reuses them in reverse order. + for i in range(self.num_encoder_layers): + x = self.blocks[i](x, x0) + skips.append(x) + for i in range(self.num_decoder_layers): + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + x = self.blocks[self.num_encoder_layers + i](x, x0) + + x = self.final_norm(x).reshape(-1, x.size(-1)) + targets = target_ids.reshape(-1) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(x) + logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + return F.cross_entropy(logits.float(), targets, reduction="mean") + + def rate_regularizer(self) -> tuple[Tensor, Tensor, Tensor]: + total_symbols = 0.0 + weighted_rate = None + weighted_scale = None + estimated_bits = None + for name, param in self.named_parameters(): + if not should_rate_regularize_tensor(name, param, self.rate_target_tensors): + continue + rate_proxy, scale_proxy, tensor_bits = estimate_tensor_rate_proxy(param, self.rate_temperature) + weight = float(param.numel()) + total_symbols += weight + if weighted_rate is None: + weighted_rate = rate_proxy * weight + weighted_scale = scale_proxy * weight + estimated_bits = tensor_bits + else: + weighted_rate = weighted_rate + rate_proxy * weight + weighted_scale = weighted_scale + scale_proxy * weight + estimated_bits = estimated_bits + tensor_bits + if weighted_rate is None or weighted_scale is None or estimated_bits is None: + zero = self.tok_emb.weight.new_zeros(()) + return zero, zero, zero + denom = max(total_symbols, 1.0) + return weighted_rate / denom, weighted_scale / denom, estimated_bits + + +# ----------------------------- +# TRAINING +# ----------------------------- + +def main() -> None: 
+ global zeropower_via_newtonschulz5 + + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + if args.enable_compile: + zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) + + # ----------------------------- + # DISTRIBUTED + CUDA SETUP + # ----------------------------- + + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if 8 % world_size != 0: + raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") + grad_accum_steps = 8 // world_size + grad_scale = 1.0 / grad_accum_steps + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 0 + + # Fast math knobs + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + + logfile = None + if master_process: + if args.run_id == "train": + logfile = "train.log" + else: + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" + print(logfile) + + def log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch 
{torch.__version__}", console=False) + log0( + subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, + console=False, + ) + log0("=" * 100, console=False) + + # ----------------------------- + # TOKENIZER + VALIDATION METRIC SETUP + # ----------------------------- + + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) + if args.val_token_limit > 0: + usable = min(((args.val_token_limit) // args.train_seq_len) * args.train_seq_len, val_tokens.numel() - 1) + if usable <= 0: + raise ValueError(f"VAL_TOKEN_LIMIT={args.val_token_limit} is too small for TRAIN_SEQ_LEN={args.train_seq_len}") + val_tokens = val_tokens[: usable + 1].contiguous() + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + if args.val_token_limit > 0: + log0(f"val_token_limit:{args.val_token_limit}") + + # ----------------------------- + # MODEL + OPTIMIZER SETUP + # ----------------------------- + + base_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + 
model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, + structured_mlp=args.structured_mlp, + structured_linear_type=args.structured_linear_type if args.structured_mlp else "dense", + btt_rank=args.btt_rank, + btt_in_shape=args.btt_in_shape, + btt_out_shape=args.btt_out_shape, + btt_init=args.btt_init, + btt_init_gain=args.btt_init_gain, + compile_structured_mlp=args.compile_structured_mlp, + rate_target_tensors=args.rate_target_tensors, + rate_temperature=args.rate_temperature, + ).to(device).bfloat16() + for module in base_model.modules(): + if isinstance(module, (CastedLinear, StructuredLinear)): + module.float() + restore_low_dim_params_to_fp32(base_model) + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) if args.enable_compile else base_model + model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + + # Optimizer split: + # - token embedding (Adam) uses EMBED_LR + # - untied lm_head (Adam) uses HEAD_LR + # - matrix params in transformer blocks use MATRIX_LR via Muon + # - vectors/scalars use SCALAR_LR via Adam + block_named_params = list(base_model.blocks.named_parameters()) + matrix_params = [ + p + for name, p in block_named_params + if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + scalar_params = [ + p + for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + optimizer_tok = torch.optim.Adam( + [{"params": [base_model.tok_emb.weight], "lr": token_lr, 
"base_lr": token_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizer_muon = Muon( + matrix_params, + lr=args.matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + ) + for group in optimizer_muon.param_groups: + group["base_lr"] = args.matrix_lr + optimizer_scalar = torch.optim.Adam( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers.insert(1, optimizer_head) + + n_params = sum(p.numel() for p in base_model.parameters()) + log0(f"model_params:{n_params}") + log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") + log0(f"enable_compile:{args.enable_compile}") + log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") + log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") + log0( + f"quant_scheme:{args.quant_scheme} compressor:{args.compressor} level:{args.compressor_level} " + f"structured_mlp:{args.structured_mlp} structured_linear_type:{args.structured_linear_type} " + f"btt_rank:{args.btt_rank} btt_init:{args.btt_init} compile_structured_mlp:{args.compile_structured_mlp}" + ) + log0( + f"rate_lambda:{args.rate_lambda} scale_lambda:{args.scale_lambda} " + f"rate_warmup_start_fraction:{args.rate_warmup_start_fraction} rate_full_fraction:{args.rate_full_fraction} " + f"rate_target_tensors:{args.rate_target_tensors}" + ) + log0( + f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " + f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " + f"matrix_lr:{args.matrix_lr} 
scalar_lr:{args.scalar_lr}" + ) + log0( + f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " + f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " + f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log0(f"seed:{args.seed}") + + # ----------------------------- + # DATA LOADER & MODEL WARMUP + # ----------------------------- + + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + + def lr_mul(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + + def rate_schedule(step: int) -> float: + if args.rate_lambda <= 0.0 and args.scale_lambda <= 0.0: + return 0.0 + if args.iterations <= 0: + return 1.0 + frac = step / max(args.iterations, 1) + start = args.rate_warmup_start_fraction + full = max(args.rate_full_fraction, start + 1e-6) + if frac <= start: + return 0.0 + if frac >= full: + return 1.0 + return (frac - start) / (full - start) + + # Warmup primes the compiled forward/backward/optimizer paths, then we restore the + # initial weights/optimizer state so measured training starts from the true init. 
+ if args.warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(args.warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + for opt in optimizers: + opt.step() + zero_grad_all() + if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: + log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + if distributed: + model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + # ----------------------------- + # MAIN TRAINING LOOP + # ----------------------------- + + training_time_ms = 0.0 + stop_after_step: int | None = None + torch.cuda.synchronize() + t0 = time.perf_counter() + + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + log0( + 
f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" + ) + break + + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + zero_grad_all() + train_loss = torch.zeros((), device=device) + train_ce_loss = torch.zeros((), device=device) + train_rate_proxy = torch.zeros((), device=device) + train_scale_proxy = torch.zeros((), device=device) + train_est_bits = torch.zeros((), device=device) + rate_mul = rate_schedule(step) + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + ce_loss = model(x, y) + total_loss = ce_loss + if rate_mul > 0.0: + rate_proxy, scale_proxy, est_bits = base_model.rate_regularizer() + total_loss = total_loss + rate_mul * (args.rate_lambda * rate_proxy + args.scale_lambda * scale_proxy) + train_rate_proxy += rate_proxy.detach() + train_scale_proxy += scale_proxy.detach() + train_est_bits += est_bits.detach() + train_ce_loss += ce_loss.detach() + train_loss += total_loss.detach() + (total_loss * grad_scale).backward() + train_loss /= grad_accum_steps + train_ce_loss /= grad_accum_steps + train_rate_proxy /= grad_accum_steps + train_scale_proxy /= grad_accum_steps + train_est_bits /= grad_accum_steps + + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * 
args.muon_momentum + for group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + for opt in optimizers: + opt.step() + zero_grad_all() + + step += 1 + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + should_log_train = ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + log0( + f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} ce_loss:{train_ce_loss.item():.4f} " + f"rate_mul:{rate_mul:.3f} rate_proxy_bits:{train_rate_proxy.item():.4f} " + f"scale_proxy:{train_scale_proxy.item():.6f} est_model_bits:{train_est_bits.item():.0f} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" + ) + + # Needed to sync whether we've reached the wallclock cap. + reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + + log0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + ) + + # ----------------------------- + # SERIALIZATION + ROUNDTRIP VALIDATION + # ----------------------------- + # Save the raw state (useful for debugging/loading in PyTorch directly), then always produce + # the compressed QUANT_SCHEME + COMPRESSOR artifact (int5_odd + zlib here) and validate the round-tripped weights.
+ + if master_process: + torch.save(base_model.state_dict(), "final_model.pt") + model_bytes = os.path.getsize("final_model.pt") + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model: {model_bytes} bytes") + log0(f"Code size: {code_bytes} bytes") + log0(f"Total submission size: {model_bytes + code_bytes} bytes") + + quant_obj, quant_stats = quantize_state_dict(base_model.state_dict(), args.quant_scheme) + quant_buf = io.BytesIO() + torch.save(quant_obj, quant_buf) + quant_raw = quant_buf.getvalue() + quant_blob = compress_blob(quant_raw, args.compressor, args.compressor_level) + quant_raw_bytes = len(quant_raw) + quant_filename = f"final_model.{args.quant_scheme}.{args.compressor}.ptc" + if master_process: + with open(quant_filename, "wb") as f: + f.write(quant_blob) + quant_file_bytes = os.path.getsize(quant_filename) + code_bytes = len(code.encode("utf-8")) + ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1) + log0( + f"Serialized model {args.quant_scheme}+{args.compressor}: {quant_file_bytes} bytes " + f"(payload:{quant_stats['int8_payload_bytes']} raw_torch:{quant_raw_bytes} payload_ratio:{ratio:.2f}x)" + ) + log0(f"Total submission size {args.quant_scheme}+{args.compressor}: {quant_file_bytes + code_bytes} bytes") + + if distributed: + dist.barrier() + with open(quant_filename, "rb") as f: + quant_blob_disk = f.read() + quant_state = torch.load(io.BytesIO(decompress_blob(quant_blob_disk, args.compressor)), map_location="cpu") + base_model.load_state_dict(dequantize_state_dict(quant_state), strict=True) + torch.cuda.synchronize() + t_qeval = time.perf_counter() + q_val_loss, q_val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"final_{args.quant_scheme}_{args.compressor}_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + 
f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" + ) + log0(f"final_{args.quant_scheme}_{args.compressor}_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + + if distributed: + dist.destroy_process_group() + + +if __name__ == "__main__": + main() + +==================================================================================================== +Running Python 3.10.12 (main, Mar 3 2026, 11:56:32) [GCC 11.4.0] +Running PyTorch 2.11.0+cu130 +Tue Mar 31 12:00:20 2026 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 590.48.01 Driver Version: 590.48.01 CUDA Version: 13.1 | ++-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. | +|=========================================+========================+======================| +| 0 NVIDIA GeForce RTX 3060 Off | 00000000:01:00.0 On | N/A | +| 42% 49C P2 40W / 170W | 242MiB / 12288MiB | 3% Default | +| | | N/A | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| 0 N/A N/A 985 G /usr/lib/xorg/Xorg 117MiB | +| 0 N/A N/A 3705 C python3 104MiB | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=/home/hz/Desktop/parameter-golf-main/data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:1 +val_loader:shards 
pattern=/home/hz/Desktop/parameter-golf-main/data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:65536 +val_token_limit:65536 +model_params:47756448 +world_size:1 grad_accum_steps:8 +enable_compile:False +sdp_backends:cudnn=False flash=True mem_efficient=False math=False +attention_mode:gqa num_heads:8 num_kv_heads:4 +quant_scheme:int5_odd compressor:zlib level:9 structured_mlp:True structured_linear_type:btt_mlp btt_rank:512 btt_init:mup compile_structured_mlp:False +rate_lambda:0.0002 scale_lambda:0.0002 rate_warmup_start_fraction:0.1 rate_full_fraction:0.4 rate_target_tensors:mlp_only +tie_embeddings:True embed_lr:0.03 head_lr:0.0 matrix_lr:0.02 scalar_lr:0.02 +train_batch_tokens:4096 train_seq_len:256 iterations:0 warmup_steps:0 max_wallclock_seconds:600.000 +seed:1337 +step:0/0 val_loss:6.9399 val_bpb:4.1033 train_time:0ms step_avg:0.03ms +peak memory allocated: 526 MiB reserved: 570 MiB +Serialized model: 190064443 bytes +Code size: 68319 bytes +Total submission size: 190132762 bytes +Serialized model int5_odd+zlib: 7228717 bytes (payload:16194811 raw_torch:16283675 payload_ratio:11.73x) +Total submission size int5_odd+zlib: 7297036 bytes +final_int5_odd_zlib_roundtrip val_loss:6.9432 val_bpb:4.1052 eval_time:109429ms +final_int5_odd_zlib_roundtrip_exact val_loss:6.94316441 val_bpb:4.10521867 diff --git a/records/track_non_record_16mb/2026-03-31_EntropyAware_Int5Odd_BTTMLP_3060/logs/smoke_btt_40.txt b/records/track_non_record_16mb/2026-03-31_EntropyAware_Int5Odd_BTTMLP_3060/logs/smoke_btt_40.txt new file mode 100644 index 000000000..8b7946704 --- /dev/null +++ b/records/track_non_record_16mb/2026-03-31_EntropyAware_Int5Odd_BTTMLP_3060/logs/smoke_btt_40.txt @@ -0,0 +1,4822 @@ +""" +The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good launching-off points for new participants, not SOTA configs. 
We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder. + +Hard stop: To keep these scripts readable for newcomers, let's make sure `train_gpt.py` and `train_gpt_mlx.py` are never longer than 1500 lines. +""" + +from __future__ import annotations + +import copy +import glob +import io +import math +import os +import random +import subprocess +import sys +import time +import uuid +import zlib +from pathlib import Path + +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP + +try: + import zstandard as zstd + + _HAS_ZSTD = True +except ImportError: + zstd = None + _HAS_ZSTD = False + +# ----------------------------- +# HYPERPARAMETERS +# ----------------------------- +# Default Simple Baseline run: +# - 9 transformer blocks at width 512 +# - 8 attention heads with 4 KV heads (GQA) and 2x MLP expansion +# - vocab size 1024, sequence length 1024, tied embeddings +# - 524,288 train tokens per step for 20,000 iterations with a ~10 minute cap + +class Hyperparameters: + # Data paths are shard globs produced by the existing preprocessing pipeline. + data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + seed = int(os.environ.get("SEED", 1337)) + + # Validation cadence and batch size. Validation uses the full fineweb_val split unless VAL_TOKEN_LIMIT trims it.
+ val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 200)) + val_token_limit = int(os.environ.get("VAL_TOKEN_LIMIT", 0)) + + # Training length. + iterations = int(os.environ.get("ITERATIONS", 20000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 1200)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + enable_compile = bool(int(os.environ.get("ENABLE_COMPILE", "1"))) + + # Model shape. + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers = int(os.environ.get("NUM_LAYERS", 9)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = int(os.environ.get("MLP_MULT", 2)) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + + # Optimizer hyperparameters. 
+ embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.05)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.04)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.04)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.0)) + + # Export / compression. + quant_scheme = os.environ.get("QUANT_SCHEME", "int5_odd") + compressor = os.environ.get("COMPRESSOR", "zlib") + compressor_level = int(os.environ.get("COMPRESSOR_LEVEL", 9)) + + # Entropy-aware rate regularization. + rate_lambda = float(os.environ.get("RATE_LAMBDA", 0.02)) + scale_lambda = float(os.environ.get("SCALE_LAMBDA", 0.0005)) + rate_temperature = float(os.environ.get("RATE_TEMPERATURE", 6.0)) + rate_warmup_start_fraction = float(os.environ.get("RATE_WARMUP_START_FRACTION", 0.2)) + rate_full_fraction = float(os.environ.get("RATE_FULL_FRACTION", 0.6)) + rate_target_tensors = os.environ.get("RATE_TARGET_TENSORS", "mlp_only") + + # Structured MLP. 
+ structured_mlp = bool(int(os.environ.get("STRUCTURED_MLP", "0"))) + structured_linear_type = os.environ.get("STRUCTURED_LINEAR_TYPE", "btt_mlp") + btt_rank = int(os.environ.get("BTT_RANK", 32)) + btt_in_shape = os.environ.get("BTT_IN_SHAPE", "") + btt_out_shape = os.environ.get("BTT_OUT_SHAPE", "") + +# ----------------------------- +# MUON OPTIMIZER +# ----------------------------- +# +# As borrowed from modded-nanogpt +# Background on Muon: https://kellerjordan.github.io/posts/muon/ + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + # Orthogonalize a 2D update matrix with a fast Newton-Schulz iteration. + # Muon uses this to normalize matrix-shaped gradients before applying them. + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X + + +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True): + super().__init__( + params, + dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov), + ) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + + for group in self.param_groups: + params = group["params"] + if not params: + continue + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + + total_params = sum(int(p.numel()) for p in params) + updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) + + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and 
p.grad is not None: + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: + g = g.add(buf, alpha=momentum) + g = zeropower_via_newtonschulz5(g, steps=backend_steps) + # Scale correction from Muon reference implementations. + g *= max(1, g.size(0) / g.size(1)) ** 0.5 + updates_flat[curr : curr + p.numel()] = g.reshape(-1) + curr += p.numel() + + if distributed: + dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) + + curr = 0 + for p in params: + g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) + p.add_(g, alpha=-lr) + curr += p.numel() + + return loss + + +# ----------------------------- +# TOKENIZER-AGNOSTIC EVALUATION SETUP +# ----------------------------- +# +# It's common for small models to have a large fraction of their parameters in embeddings, since the 2 * d_model * d_vocab embedding vectors can be gigantic. +# Instead of locking the tokenizer, we let you bring your own and compute our validation metric from the average compression of the validation set. +# We calculate BPB (bits-per-byte) instead of validation loss, so we need methods to count the number of bytes each token decodes to. +# Note: Submissions that edit the tokenizer will be examined more carefully, since screwing this up might unjustly improve your score.
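The conversion described in the comment block above can be sketched standalone (a hedged illustration, not part of the harness): bits-per-byte is the token cross-entropy converted from nats to bits, rescaled by the tokenizer's tokens-per-byte ratio.

```python
import math

def bits_per_byte(mean_loss_nats: float, n_tokens: int, n_bytes: int) -> float:
    # nats/token -> bits/token, then bits/token * tokens/byte -> bits/byte.
    bits_per_token = mean_loss_nats / math.log(2.0)
    return bits_per_token * (n_tokens / n_bytes)

# Sanity check: a model at ln(2)*8 nats/token on byte-level tokens
# (one byte per token) compresses no better than raw bytes: 8 bits/byte.
assert abs(bits_per_byte(math.log(2.0) * 8, 1000, 1000) - 8.0) < 1e-9
```

This mirrors the tail of `eval_val`, which accumulates `val_loss_sum`, `val_token_count`, and `val_byte_count` before performing the same conversion.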
+ +def build_sentencepiece_luts( + sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device +) -> tuple[Tensor, Tensor, Tensor]: + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("▁"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) + + +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*. 
+ tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] + + +def eval_val( + args: Hyperparameters, + model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + grad_accum_steps: int, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, +) -> tuple[float, float]: + # Validation computes two metrics: + # - val_loss: token cross-entropy (natural log) + # - val_bpb: tokenizer-agnostic compression metric used by the challenge + local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) + if local_batch_tokens < args.train_seq_len: + raise ValueError( + "VAL_BATCH_SIZE must provide at least one sequence per rank; " + f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " + f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}" + ) + local_batch_seqs = local_batch_tokens // args.train_seq_len + total_seqs = (val_tokens.numel() - 1) // args.train_seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + + model.eval() + with torch.inference_mode(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * args.train_seq_len + raw_end = batch_seq_end * args.train_seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, args.train_seq_len) + y = local[1:].reshape(-1, args.train_seq_len) + with 
torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + batch_loss = model(x, y).detach() + batch_token_count = float(y.numel()) + val_loss_sum += batch_loss.to(torch.float64) * batch_token_count + val_token_count += batch_token_count + prev_ids = x.reshape(-1) + tgt_ids = y.reshape(-1) + token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + + val_loss = val_loss_sum / val_token_count + bits_per_token = val_loss.item() / math.log(2.0) + tokens_per_byte = val_token_count.item() / val_byte_count.item() + model.train() + return float(val_loss.item()), float(bits_per_token * tokens_per_byte) + +# ----------------------------- +# POST-TRAINING QUANTIZATION +# ----------------------------- +# +# It's silly to export our model, which is trained in bf16 and fp32, at that same precision. +# Instead, we get approximately the same model (with a small quality hit) by quantizing the model to int8 and zlib-compressing the result. +# We can then decompress the model and run it at higher precision for evaluation, after coming in under the size limit.
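The quantize-then-compress idea described above can be shown end to end in a few lines of NumPy (an illustrative sketch with made-up shapes, not the actual exporter below): quantize per row to int8, store codes plus fp16 scales, zlib the payload, then dequantize and verify reconstruction error stays near half a quantization step.

```python
import zlib
import numpy as np

def int8_roundtrip(w: np.ndarray) -> tuple[np.ndarray, int]:
    # Symmetric per-row int8: one scale per output row, codes in [-127, 127].
    scale = np.abs(w).max(axis=1, keepdims=True) / 127.0
    scale = np.maximum(scale, 1e-4).astype(np.float16)  # fp16 scales, clamped away from 0
    q = np.clip(np.round(w / scale.astype(np.float32)), -127, 127).astype(np.int8)
    payload = zlib.compress(q.tobytes() + scale.tobytes(), 9)
    w_hat = q.astype(np.float32) * scale.astype(np.float32)
    return w_hat, len(payload)

rng = np.random.default_rng(0)
w = rng.standard_normal((512, 1024)).astype(np.float32)
w_hat, n_bytes = int8_roundtrip(w)
# Per-element error is bounded by roughly half a step of the row's scale.
assert np.abs(w - w_hat).max() <= 0.51 * (np.abs(w).max(axis=1) / 127.0).max()
assert n_bytes < w.nbytes  # compressed payload beats raw fp32 storage
```

The real export path (`quantize_state_dict_int8` below) additionally percentile-clips outliers before scaling and keeps small float tensors as fp16 passthrough.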
+ +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights", + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 +INT5_ODD_LEVELS = (-2, -1, 0, 1, 2) + +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) + + +def parse_factor_shape(spec: str, dim: int) -> tuple[int, int]: + if spec: + vals = tuple(int(v.strip()) for v in spec.split(",") if v.strip()) + if len(vals) != 2 or vals[0] * vals[1] != dim: + raise ValueError(f"Invalid factor shape '{spec}' for dim={dim}") + return vals + for a in range(int(math.isqrt(dim)), 0, -1): + if dim % a == 0: + return a, dim // a + return 1, dim + +def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: + if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): + return t.float().contiguous() + if t.dtype in {torch.float32, torch.bfloat16}: + passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") + return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() + return t + +def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + # Matrices get one scale per row, which usually tracks output-channel + # ranges much better than a single tensor-wide scale. 
+ clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) + q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() + return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + + # Vectors / scalars use a simpler per-tensor scale. + clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() + return q, scale + + +def pack_base5_codes(q: Tensor) -> tuple[bytes, int]: + vals = q.detach().to("cpu", dtype=torch.uint8).reshape(-1).numpy() + n_vals = int(vals.size) + pad = (3 - (n_vals % 3)) % 3 + if pad: + vals = np.pad(vals, (0, pad)) + groups = vals.reshape(-1, 3).astype(np.uint8) + packed = groups[:, 0] + 5 * groups[:, 1] + 25 * groups[:, 2] + return packed.tobytes(), n_vals + + +def unpack_base5_codes(data: bytes, n_vals: int) -> Tensor: + packed = np.frombuffer(data, dtype=np.uint8).astype(np.int16) + out = np.zeros((packed.size, 3), dtype=np.uint8) + for i in range(3): + out[:, i] = packed % 5 + packed //= 5 + flat = out.reshape(-1)[:n_vals] + return torch.from_numpy(flat.astype(np.uint8, copy=False)) + + +def quantize_float_tensor_int5_odd(t: Tensor) -> tuple[Tensor, Tensor]: + if t.ndim != 2: + raise ValueError("int5_odd quantization only supports 2D tensors") + t32 = t.float() + clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clip_abs = clip_abs.clamp_min(1e-6) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 
2.0).clamp_min(1.0 / 2.0e6) + q = torch.clamp(torch.round(clipped / scale[:, None]), -2, 2).to(torch.int8).contiguous() + return (q + 2).to(torch.uint8).contiguous(), scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + +def quantize_state_dict_int8(state_dict: dict[str, Tensor]): + # Single supported clean-script export format: + # - per-row int8 for 2D float tensors + # - per-tensor int8 for other float tensors + # - exact passthrough for non-floats + # - passthrough for small float tensors, stored as fp16 to save bytes + quantized: dict[str, Tensor] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + qmeta: dict[str, dict[str, object]] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + + # Small float tensors are cheap enough to keep directly. We still downcast + # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size. 
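+        # E.g. a 1024-entry fp32 gain vector drops from 4096 bytes to 2048 bytes
+        # as fp16, which still carries roughly three significant decimal digits.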
+ if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + + stats["num_float_tensors"] += 1 + q, s = quantize_float_tensor(t) + if s.ndim > 0: + qmeta[name] = {"scheme": "per_row", "axis": 0} + quantized[name] = q + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) + + obj: dict[str, object] = { + "__quant_format__": "int8_clean_per_row_v1", + "quantized": quantized, + "scales": scales, + "dtypes": dtypes, + "passthrough": passthrough, + } + if qmeta: + obj["qmeta"] = qmeta + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats + + +def quantize_state_dict_int5_odd(state_dict: dict[str, Tensor]): + packed: dict[str, bytes] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + shapes: dict[str, tuple[int, ...]] = {} + n_codes: dict[str, int] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + + if t.ndim != 2 or t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + + stats["num_float_tensors"] += 1 + q, s = quantize_float_tensor_int5_odd(t) + packed_bytes, count = 
pack_base5_codes(q) + packed[name] = packed_bytes + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + shapes[name] = tuple(int(v) for v in t.shape) + n_codes[name] = count + stats["int8_payload_bytes"] += len(packed_bytes) + tensor_nbytes(s) + + obj: dict[str, object] = { + "__quant_format__": "int5_odd_clean_per_row_v1", + "packed": packed, + "scales": scales, + "dtypes": dtypes, + "shapes": shapes, + "n_codes": n_codes, + "passthrough": passthrough, + } + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats + +def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + qmeta = obj.get("qmeta", {}) + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, q in obj["quantized"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + s = obj["scales"][name] + if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: + s = s.to(dtype=torch.float32) + # Broadcast the saved row scale back across trailing dimensions. + out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() + else: + scale = float(s.item()) + out[name] = (q.float() * scale).to(dtype=dtype).contiguous() + for name, t in obj["passthrough"].items(): + # Restore small tensors, undoing the temporary fp16 storage cast if needed. 
+ out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out + + +def dequantize_state_dict_int5_odd(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, packed in obj["packed"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + shape = tuple(int(v) for v in obj["shapes"][name]) + q = unpack_base5_codes(packed, int(obj["n_codes"][name])).to(torch.float32).reshape(shape) - 2.0 + s = obj["scales"][name].to(dtype=torch.float32) + out[name] = (q * s.view(shape[0], *([1] * (len(shape) - 1)))).to(dtype=dtype).contiguous() + for name, t in obj["passthrough"].items(): + out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out + + +def quantize_state_dict(state_dict: dict[str, Tensor], quant_scheme: str): + if quant_scheme == "int8": + return quantize_state_dict_int8(state_dict) + if quant_scheme == "int5_odd": + return quantize_state_dict_int5_odd(state_dict) + raise ValueError(f"Unsupported QUANT_SCHEME={quant_scheme}") + + +def dequantize_state_dict(obj: dict[str, object]) -> dict[str, Tensor]: + quant_format = obj.get("__quant_format__") + if quant_format == "int8_clean_per_row_v1": + return dequantize_state_dict_int8(obj) + if quant_format == "int5_odd_clean_per_row_v1": + return dequantize_state_dict_int5_odd(obj) + raise ValueError(f"Unsupported quant format {quant_format}") + + +def compress_blob(data: bytes, compressor: str, level: int) -> bytes: + if compressor == "zlib": + return zlib.compress(data, level=max(0, min(level, 9))) + if compressor == "zstd": + if not _HAS_ZSTD: + raise RuntimeError("COMPRESSOR=zstd requires zstandard 
to be installed")
+        return zstd.ZstdCompressor(level=level).compress(data)
+    raise ValueError(f"Unsupported COMPRESSOR={compressor}")
+
+
+def decompress_blob(data: bytes, compressor: str) -> bytes:
+    if compressor == "zlib":
+        return zlib.decompress(data)
+    if compressor == "zstd":
+        if not _HAS_ZSTD:
+            raise RuntimeError("COMPRESSOR=zstd requires zstandard to be installed")
+        return zstd.ZstdDecompressor().decompress(data)
+    raise ValueError(f"Unsupported COMPRESSOR={compressor}")
+
+
+def should_rate_regularize_tensor(name: str, p: Tensor, target: str) -> bool:
+    if not p.requires_grad or not p.is_floating_point() or p.ndim < 2:
+        return False
+    if p.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL:
+        return False
+    if any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS):
+        return False
+    if target == "mlp_only":
+        return ".mlp." in name
+    if target == "all_matrices":
+        return True
+    raise ValueError(f"Unsupported RATE_TARGET_TENSORS={target}")
+
+
+def estimate_tensor_rate_proxy(t: Tensor, temperature: float) -> tuple[Tensor, Tensor, Tensor]:
+    t2 = t.float()
+    if t2.ndim > 2:
+        t2 = t2.reshape(t2.shape[0], -1)
+    elif t2.ndim == 1:
+        t2 = t2.reshape(1, -1)
+    row_scale = t2.abs().mean(dim=1, keepdim=True).clamp_min(1e-6)
+    normalized = torch.clamp(t2 / row_scale, -2.5, 2.5)
+    centers = t2.new_tensor(INT5_ODD_LEVELS)
+    logits = -temperature * (normalized.unsqueeze(-1) - centers) ** 2
+    probs = torch.softmax(logits, dim=-1)
+    hist = probs.mean(dim=(0, 1))
+    entropy_nats = -(hist * (hist + 1e-8).log()).sum()
+    entropy_bits = entropy_nats / math.log(2.0)
+    scale_proxy = row_scale.mean()
+    est_bits = entropy_bits * float(t2.numel())
+    return entropy_bits, scale_proxy, est_bits
+
+
+# -----------------------------
+# DATA LOADING
+# -----------------------------
+
+def load_data_shard(file: Path) -> Tensor:
+    # Shards store a 256-word little-endian int32 header followed by uint16 token ids.
+    header_bytes = 256 * np.dtype("<i4").itemsize
+    with file.open("rb") as f:
+        f.seek(header_bytes)
+        raw = np.frombuffer(f.read(), dtype="<u2")
+    # np.frombuffer returns a read-only view; astype copies into a writable int32 buffer.
+    return torch.from_numpy(raw.astype(np.int32))
+
+
+class TokenStream:
+    # Sequential reader over the shard files matching a (relative) glob pattern,
+    # cycling back to the first shard when the last one is exhausted.
+    def __init__(self, pattern: str):
+        self.files = sorted(Path(".").glob(pattern))
+        if not self.files:
+            raise FileNotFoundError(f"No token shards match pattern {pattern!r}")
+        self.file_idx = 0
+        self.tokens = load_data_shard(self.files[self.file_idx])
+        self.pos = 0
+
+    def _advance_file(self) -> None:
+        self.file_idx = (self.file_idx + 1) % len(self.files)
+        self.tokens = load_data_shard(self.files[self.file_idx])
+        
self.pos = 0 + + def take(self, n: int) -> Tensor: + chunks: list[Tensor] = [] + remaining = n + while remaining > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: + self._advance_file() + continue + k = min(remaining, avail) + chunks.append(self.tokens[self.pos : self.pos + k]) + self.pos += k + remaining -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) + + +class DistributedTokenLoader: + # Each call consumes a contiguous chunk from the shared token stream, then slices out + # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting. + def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank = rank + self.world_size = world_size + self.device = device + self.stream = TokenStream(pattern) + + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + chunk = self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start : start + per_rank_span].to(dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + +# ----------------------------- +# TRANSFORMER MODULES +# ----------------------------- + +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + + +class CastedLinear(nn.Linear): + # Keep weights in fp32 for optimizer/state quality, cast at matmul time for bf16 compute. 
+ def forward(self, x: Tensor) -> Tensor: + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, self.weight.to(x.dtype), bias) + + +class StructuredLinear(nn.Module): + # Lightweight TT-style structured matrix used only in the MLP stack. + def __init__( + self, + in_features: int, + out_features: int, + bias: bool, + linear_type: str, + btt_rank: int, + in_shape_spec: str, + out_shape_spec: str, + ): + super().__init__() + self.in_features = in_features + self.out_features = out_features + self.linear_type = linear_type + self.rank = max(int(btt_rank), 1) + if linear_type == "dense": + self.impl = CastedLinear(in_features, out_features, bias=bias) + self.register_parameter("core1", None) + self.register_parameter("core2", None) + return + if linear_type != "btt_mlp": + raise ValueError(f"Unsupported STRUCTURED_LINEAR_TYPE={linear_type}") + self.impl = None + self.in_shape = parse_factor_shape(in_shape_spec, in_features) + self.out_shape = parse_factor_shape(out_shape_spec, out_features) + n1, n2 = self.in_shape + m1, m2 = self.out_shape + self.core1 = nn.Parameter(torch.empty(n1, m1, self.rank)) + self.core2 = nn.Parameter(torch.empty(m1, self.rank, n2, m2)) + if bias: + self.bias = nn.Parameter(torch.zeros(out_features)) + else: + self.register_parameter("bias", None) + self.reset_parameters() + + def reset_parameters(self) -> None: + if self.impl is not None: + self.impl.reset_parameters() + return + std = 1.0 / math.sqrt(self.in_features) + nn.init.normal_(self.core1, mean=0.0, std=std / math.sqrt(self.rank)) + nn.init.normal_(self.core2, mean=0.0, std=std) + if self.bias is not None: + nn.init.zeros_(self.bias) + + def forward(self, x: Tensor) -> Tensor: + if self.impl is not None: + return self.impl(x) + x_dtype = x.dtype + orig_shape = x.shape[:-1] + x2 = x.reshape(-1, self.in_shape[0], self.in_shape[1]) + core1 = self.core1.to(dtype=x_dtype) + core2 = self.core2.to(dtype=x_dtype) + tmp = torch.einsum("buv,umr->bmrv", x2, 
core1)
+        # Contract over the rank and the second input factor:
+        # y[b, m1, m2] = sum_{r, n2} tmp[b, m1, r, n2] * core2[m1, r, n2, m2].
+        y = torch.einsum("bmrv,mrvn->bmn", tmp, core2)
+        y = y.reshape(*orig_shape, self.out_features)
+        if self.bias is not None:
+            y = y + self.bias.to(dtype=y.dtype)
+        return y
+
+
+def restore_low_dim_params_to_fp32(module: nn.Module) -> None:
+    # Keep small/control parameters in fp32 even when the model body runs in bf16.
+    with torch.no_grad():
+        for name, param in module.named_parameters():
+            if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32:
+                param.data = param.data.float()
+
+
+class Rotary(nn.Module):
+    # Caches cos/sin tables per sequence length on the current device.
+    def __init__(self, dim: int, base: float = 10000.0):
+        super().__init__()
+        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
+        self.register_buffer("inv_freq", inv_freq, persistent=False)
+        self._seq_len_cached = 0
+        self._cos_cached: Tensor | None = None
+        self._sin_cached: Tensor | None = None
+
+    def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]:
+        if (
+            self._cos_cached is None
+            or self._sin_cached is None
+            or self._seq_len_cached != seq_len
+            or self._cos_cached.device != device
+        ):
+            t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype)
+            freqs = torch.outer(t, self.inv_freq.to(device))
+            self._cos_cached = freqs.cos()[None, None, :, :]
+            self._sin_cached = freqs.sin()[None, None, :, :]
+            self._seq_len_cached = seq_len
+        cos = self._cos_cached.to(dtype=dtype)
+        sin = self._sin_cached.to(dtype=dtype)
+        if not torch.is_inference_mode_enabled():
+            cos = cos.clone()
+            sin = sin.clone()
+        return cos, sin
+
+
+def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor:
+    half = x.size(-1) // 2
+    x1, x2 = x[..., :half], x[..., half:]
+    return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1)
+
+
+class CausalSelfAttention(nn.Module):
+    def __init__(
+        self,
+        dim: int,
+        num_heads: int,
+        
num_kv_heads: int, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + kv_dim = self.num_kv_heads * self.head_dim + self.c_q = CastedLinear(dim, dim, bias=False) + self.c_k = CastedLinear(dim, kv_dim, bias=False) + self.c_v = CastedLinear(dim, kv_dim, bias=False) + self.proj = CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rotary = Rotary(self.head_dim, base=rope_base) + + def forward(self, x: Tensor) -> Tensor: + bsz, seqlen, dim = x.shape + q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) + k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin) + k = apply_rotary_emb(k, cos, sin) + q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] + y = F.scaled_dot_product_attention( + q, + k, + v, + attn_mask=None, + is_causal=True, + enable_gqa=(self.num_kv_heads != self.num_heads), + ) + y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) + return self.proj(y) + + +class MLP(nn.Module): + # relu^2 MLP from the original modded-nanogpt setup + def __init__( + self, + dim: int, + mlp_mult: int, + structured_mlp: bool, + structured_linear_type: str, + btt_rank: int, + btt_in_shape: str, + btt_out_shape: str, + ): + super().__init__() + hidden = mlp_mult * 
dim + if structured_mlp: + self.fc = StructuredLinear( + dim, + hidden, + bias=False, + linear_type=structured_linear_type, + btt_rank=btt_rank, + in_shape_spec=btt_in_shape, + out_shape_spec=btt_out_shape, + ) + self.proj = StructuredLinear( + hidden, + dim, + bias=False, + linear_type=structured_linear_type, + btt_rank=btt_rank, + in_shape_spec=btt_out_shape, + out_shape_spec=btt_in_shape, + ) + else: + self.fc = CastedLinear(dim, hidden, bias=False) + self.proj = CastedLinear(hidden, dim, bias=False) + self.proj._zero_init = True + + def forward(self, x: Tensor) -> Tensor: + x = torch.relu(self.fc(x)) + return self.proj(x.square()) + + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + rope_base: float, + qk_gain_init: float, + structured_mlp: bool, + structured_linear_type: str, + btt_rank: int, + btt_in_shape: str, + btt_out_shape: str, + ): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) + self.mlp = MLP(dim, mlp_mult, structured_mlp, structured_linear_type, btt_rank, btt_in_shape, btt_out_shape) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + + def forward(self, x: Tensor, x0: Tensor) -> Tensor: + mix = self.resid_mix.to(dtype=x.dtype) + x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + attn_out = self.attn(self.attn_norm(x)) + x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out + x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) + return x + + +class GPT(nn.Module): + def __init__( + self, + vocab_size: int, + num_layers: int, + model_dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + tie_embeddings: bool, + 
tied_embed_init_std: float, + logit_softcap: float, + rope_base: float, + qk_gain_init: float, + structured_mlp: bool, + structured_linear_type: str, + btt_rank: int, + btt_in_shape: str, + btt_out_shape: str, + rate_target_tensors: str, + rate_temperature: float, + ): + super().__init__() + if logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") + self.tie_embeddings = tie_embeddings + self.tied_embed_init_std = tied_embed_init_std + self.logit_softcap = logit_softcap + self.rate_target_tensors = rate_target_tensors + self.rate_temperature = rate_temperature + self.tok_emb = nn.Embedding(vocab_size, model_dim) + self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + self.blocks = nn.ModuleList( + [ + Block( + model_dim, + num_heads, + num_kv_heads, + mlp_mult, + rope_base, + qk_gain_init, + structured_mlp, + structured_linear_type, + btt_rank, + btt_in_shape, + btt_out_shape, + ) + for i in range(num_layers) + ] + ) + self.final_norm = RMSNorm() + self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) + if self.lm_head is not None: + self.lm_head._zero_init = True + self._init_weights() + + def _init_weights(self) -> None: + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + for module in self.modules(): + if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + if isinstance(module, StructuredLinear) and getattr(module, "_zero_init", False): + if module.impl is not None: + nn.init.zeros_(module.impl.weight) + else: + nn.init.zeros_(module.core2) + + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + x = self.tok_emb(input_ids) + x 
= F.rms_norm(x, (x.size(-1),)) + x0 = x + skips: list[Tensor] = [] + + # First half stores skips; second half reuses them in reverse order. + for i in range(self.num_encoder_layers): + x = self.blocks[i](x, x0) + skips.append(x) + for i in range(self.num_decoder_layers): + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + x = self.blocks[self.num_encoder_layers + i](x, x0) + + x = self.final_norm(x).reshape(-1, x.size(-1)) + targets = target_ids.reshape(-1) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(x) + logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + return F.cross_entropy(logits.float(), targets, reduction="mean") + + def rate_regularizer(self) -> tuple[Tensor, Tensor, Tensor]: + total_symbols = 0.0 + weighted_rate = None + weighted_scale = None + estimated_bits = None + for name, param in self.named_parameters(): + if not should_rate_regularize_tensor(name, param, self.rate_target_tensors): + continue + rate_proxy, scale_proxy, tensor_bits = estimate_tensor_rate_proxy(param, self.rate_temperature) + weight = float(param.numel()) + total_symbols += weight + if weighted_rate is None: + weighted_rate = rate_proxy * weight + weighted_scale = scale_proxy * weight + estimated_bits = tensor_bits + else: + weighted_rate = weighted_rate + rate_proxy * weight + weighted_scale = weighted_scale + scale_proxy * weight + estimated_bits = estimated_bits + tensor_bits + if weighted_rate is None or weighted_scale is None or estimated_bits is None: + zero = self.tok_emb.weight.new_zeros(()) + return zero, zero, zero + denom = max(total_symbols, 1.0) + return weighted_rate / denom, weighted_scale / denom, estimated_bits + + +# ----------------------------- +# TRAINING +# ----------------------------- + +def main() -> None: + global 
zeropower_via_newtonschulz5 + + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + if args.enable_compile: + zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) + + # ----------------------------- + # DISTRIBUTED + CUDA SETUP + # ----------------------------- + + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if 8 % world_size != 0: + raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") + grad_accum_steps = 8 // world_size + grad_scale = 1.0 / grad_accum_steps + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 0 + + # Fast math knobs + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + + logfile = None + if master_process: + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" + print(logfile) + + def log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch {torch.__version__}", console=False) + log0( + subprocess.run(["nvidia-smi"], 
stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, + console=False, + ) + log0("=" * 100, console=False) + + # ----------------------------- + # TOKENIZER + VALIDATION METRIC SETUP + # ----------------------------- + + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) + if args.val_token_limit > 0: + usable = min(((args.val_token_limit) // args.train_seq_len) * args.train_seq_len, val_tokens.numel() - 1) + if usable <= 0: + raise ValueError(f"VAL_TOKEN_LIMIT={args.val_token_limit} is too small for TRAIN_SEQ_LEN={args.train_seq_len}") + val_tokens = val_tokens[: usable + 1].contiguous() + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + if args.val_token_limit > 0: + log0(f"val_token_limit:{args.val_token_limit}") + + # ----------------------------- + # MODEL + OPTIMIZER SETUP + # ----------------------------- + + base_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + model_dim=args.model_dim, + num_heads=args.num_heads, + 
num_kv_heads=args.num_kv_heads,
+        mlp_mult=args.mlp_mult,
+        tie_embeddings=args.tie_embeddings,
+        tied_embed_init_std=args.tied_embed_init_std,
+        logit_softcap=args.logit_softcap,
+        rope_base=args.rope_base,
+        qk_gain_init=args.qk_gain_init,
+        structured_mlp=args.structured_mlp,
+        structured_linear_type=args.structured_linear_type if args.structured_mlp else "dense",
+        btt_rank=args.btt_rank,
+        btt_in_shape=args.btt_in_shape,
+        btt_out_shape=args.btt_out_shape,
+        rate_target_tensors=args.rate_target_tensors,
+        rate_temperature=args.rate_temperature,
+    ).to(device).bfloat16()
+    for module in base_model.modules():
+        if isinstance(module, (CastedLinear, StructuredLinear)):
+            module.float()
+    restore_low_dim_params_to_fp32(base_model)
+    compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) if args.enable_compile else base_model
+    model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model
+
+    # Optimizer split:
+    # - token embedding (Adam) uses EMBED_LR
+    # - untied lm_head (Adam) uses HEAD_LR
+    # - matrix params in transformer blocks use MATRIX_LR via Muon
+    # - vectors/scalars use SCALAR_LR via Adam
+    block_named_params = list(base_model.blocks.named_parameters())
+    # Muon takes every non-control block parameter with ndim >= 2; this includes
+    # the 3D/4D StructuredLinear cores, which would otherwise land in no
+    # optimizer group at all (neither the matrix nor the scalar filter).
+    matrix_params = [
+        p
+        for name, p in block_named_params
+        if p.ndim >= 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)
+    ]
+    scalar_params = [
+        p
+        for name, p in block_named_params
+        if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)
+    ]
+    if base_model.skip_weights.numel() > 0:
+        scalar_params.append(base_model.skip_weights)
+    token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr
+    optimizer_tok = torch.optim.Adam(
+        [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}],
+        betas=(args.beta1, args.beta2),
+        eps=args.adam_eps,
+        fused=True,
+    )
+    optimizer_muon = Muon(
+        matrix_params,
+        lr=args.matrix_lr,
+        
momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + ) + for group in optimizer_muon.param_groups: + group["base_lr"] = args.matrix_lr + optimizer_scalar = torch.optim.Adam( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers.insert(1, optimizer_head) + + n_params = sum(p.numel() for p in base_model.parameters()) + log0(f"model_params:{n_params}") + log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") + log0(f"enable_compile:{args.enable_compile}") + log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") + log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") + log0( + f"quant_scheme:{args.quant_scheme} compressor:{args.compressor} level:{args.compressor_level} " + f"structured_mlp:{args.structured_mlp} structured_linear_type:{args.structured_linear_type}" + ) + log0( + f"rate_lambda:{args.rate_lambda} scale_lambda:{args.scale_lambda} " + f"rate_warmup_start_fraction:{args.rate_warmup_start_fraction} rate_full_fraction:{args.rate_full_fraction} " + f"rate_target_tensors:{args.rate_target_tensors}" + ) + log0( + f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " + f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " + f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" + ) + log0( + f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " + f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " + f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + 
log0(f"seed:{args.seed}") + + # ----------------------------- + # DATA LOADER & MODEL WARMUP + # ----------------------------- + + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + + def lr_mul(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + + def rate_schedule(step: int) -> float: + if args.rate_lambda <= 0.0 and args.scale_lambda <= 0.0: + return 0.0 + if args.iterations <= 0: + return 1.0 + frac = step / max(args.iterations, 1) + start = args.rate_warmup_start_fraction + full = max(args.rate_full_fraction, start + 1e-6) + if frac <= start: + return 0.0 + if frac >= full: + return 1.0 + return (frac - start) / (full - start) + + # Warmup primes the compiled forward/backward/optimizer paths, then we restore the + # initial weights/optimizer state so measured training starts from the true init. 
+ if args.warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(args.warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + for opt in optimizers: + opt.step() + zero_grad_all() + if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: + log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + if distributed: + model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + # ----------------------------- + # MAIN TRAINING LOOP + # ----------------------------- + + training_time_ms = 0.0 + stop_after_step: int | None = None + torch.cuda.synchronize() + t0 = time.perf_counter() + + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + log0( + 
f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" + ) + break + + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + zero_grad_all() + train_loss = torch.zeros((), device=device) + train_ce_loss = torch.zeros((), device=device) + train_rate_proxy = torch.zeros((), device=device) + train_scale_proxy = torch.zeros((), device=device) + train_est_bits = torch.zeros((), device=device) + rate_mul = rate_schedule(step) + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + ce_loss = model(x, y) + total_loss = ce_loss + if rate_mul > 0.0: + rate_proxy, scale_proxy, est_bits = base_model.rate_regularizer() + total_loss = total_loss + rate_mul * (args.rate_lambda * rate_proxy + args.scale_lambda * scale_proxy) + train_rate_proxy += rate_proxy.detach() + train_scale_proxy += scale_proxy.detach() + train_est_bits += est_bits.detach() + train_ce_loss += ce_loss.detach() + train_loss += total_loss.detach() + (total_loss * grad_scale).backward() + train_loss /= grad_accum_steps + train_ce_loss /= grad_accum_steps + train_rate_proxy /= grad_accum_steps + train_scale_proxy /= grad_accum_steps + train_est_bits /= grad_accum_steps + + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * 
args.muon_momentum + for group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + for opt in optimizers: + opt.step() + zero_grad_all() + + step += 1 + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + should_log_train = ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + log0( + f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} ce_loss:{train_ce_loss.item():.4f} " + f"rate_mul:{rate_mul:.3f} rate_proxy_bits:{train_rate_proxy.item():.4f} " + f"scale_proxy:{train_scale_proxy.item():.6f} est_model_bits:{train_est_bits.item():.0f} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" + ) + + # Needed to sync whether we've reached the wallclock cap. + reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + + log0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + ) + + # ----------------------------- + # SERIALIZATION + ROUNDTRIP VALIDATION + # ----------------------------- + # Save the raw state (useful for debugging/loading in PyTorch directly), then always produce + # the compressed quantized artifact (QUANT_SCHEME + COMPRESSOR; int5_odd + zlib by default here) and validate the round-tripped weights. 
+ + if master_process: + torch.save(base_model.state_dict(), "final_model.pt") + model_bytes = os.path.getsize("final_model.pt") + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model: {model_bytes} bytes") + log0(f"Code size: {code_bytes} bytes") + log0(f"Total submission size: {model_bytes + code_bytes} bytes") + + quant_obj, quant_stats = quantize_state_dict(base_model.state_dict(), args.quant_scheme) + quant_buf = io.BytesIO() + torch.save(quant_obj, quant_buf) + quant_raw = quant_buf.getvalue() + quant_blob = compress_blob(quant_raw, args.compressor, args.compressor_level) + quant_raw_bytes = len(quant_raw) + quant_filename = f"final_model.{args.quant_scheme}.{args.compressor}.ptc" + if master_process: + with open(quant_filename, "wb") as f: + f.write(quant_blob) + quant_file_bytes = os.path.getsize(quant_filename) + code_bytes = len(code.encode("utf-8")) + ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1) + log0( + f"Serialized model {args.quant_scheme}+{args.compressor}: {quant_file_bytes} bytes " + f"(payload:{quant_stats['int8_payload_bytes']} raw_torch:{quant_raw_bytes} payload_ratio:{ratio:.2f}x)" + ) + log0(f"Total submission size {args.quant_scheme}+{args.compressor}: {quant_file_bytes + code_bytes} bytes") + + if distributed: + dist.barrier() + with open(quant_filename, "rb") as f: + quant_blob_disk = f.read() + quant_state = torch.load(io.BytesIO(decompress_blob(quant_blob_disk, args.compressor)), map_location="cpu") + base_model.load_state_dict(dequantize_state_dict(quant_state), strict=True) + torch.cuda.synchronize() + t_qeval = time.perf_counter() + q_val_loss, q_val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"final_{args.quant_scheme}_{args.compressor}_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + 
f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" + ) + log0(f"final_{args.quant_scheme}_{args.compressor}_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + + if distributed: + dist.destroy_process_group() + + +if __name__ == "__main__": + main() + +==================================================================================================== +Running Python 3.10.12 (main, Mar 3 2026, 11:56:32) [GCC 11.4.0] +Running PyTorch 2.11.0+cu130 +Tue Mar 31 02:50:44 2026 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 590.48.01 Driver Version: 590.48.01 CUDA Version: 13.1 | ++-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. | +|=========================================+========================+======================| +| 0 NVIDIA GeForce RTX 3060 Off | 00000000:01:00.0 On | N/A | +| 30% 37C P8 19W / 170W | 214MiB / 12288MiB | 42% Default | +| | | N/A | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| 0 N/A N/A 975 G /usr/lib/xorg/Xorg 89MiB | +| 0 N/A N/A 4154 C python3 104MiB | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=/home/hz/Desktop/parameter-golf-main/data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:1 +val_loader:shards 
pattern=/home/hz/Desktop/parameter-golf-main/data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:1048576 +val_token_limit:1048576 +model_params:22073416 +world_size:1 grad_accum_steps:8 +enable_compile:False +sdp_backends:cudnn=False flash=True mem_efficient=False math=False +attention_mode:gqa num_heads:8 num_kv_heads:4 +quant_scheme:int5_odd compressor:zlib level:9 structured_mlp:True structured_linear_type:btt_mlp +rate_lambda:0.002 scale_lambda:0.0002 rate_warmup_start_fraction:0.1 rate_full_fraction:0.4 rate_target_tensors:mlp_only +tie_embeddings:True embed_lr:0.03 head_lr:0.0 matrix_lr:0.02 scalar_lr:0.02 +train_batch_tokens:16384 train_seq_len:256 iterations:40 warmup_steps:0 max_wallclock_seconds:600.000 +seed:1337 +""" +The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good jumping-off points for new participants, not SOTA configs. We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing their complexity, but competitive submissions should stay in the `/records` folder. + +Hard stop: to keep these scripts readable for newcomers, let's make sure `train_gpt.py` and `train_gpt_mlx.py` are never longer than 1500 lines. 
+""" + +from __future__ import annotations + +import copy +import glob +import io +import math +import os +import random +import subprocess +import sys +import time +import uuid +import zlib +from pathlib import Path + +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP + +try: + import zstandard as zstd + + _HAS_ZSTD = True +except ImportError: + zstd = None + _HAS_ZSTD = False + +# ----------------------------- +# HYPERPARAMETERS +# ----------------------------- +# Default Simple Baseline run: +# - 9 transformer blocks at width 512 +# - 8 attention heads with 4 KV heads (GQA) and 2x MLP expansion +# - vocab size 1024, sequence length 1024, tied embeddings +# - 524,288 train tokens per step for 20,000 iterations with a ~10 minute cap + +class Hyperparameters: + # Data paths are shard globs produced by the existing preprocessing pipeline. + data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + seed = int(os.environ.get("SEED", 1337)) + + # Validation cadence and batch size. Validation always uses the full fineweb_val split. + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 200)) + val_token_limit = int(os.environ.get("VAL_TOKEN_LIMIT", 0)) + + # Training length. 
+ iterations = int(os.environ.get("ITERATIONS", 20000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 1200)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + enable_compile = bool(int(os.environ.get("ENABLE_COMPILE", "1"))) + + # Model shape. + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers = int(os.environ.get("NUM_LAYERS", 9)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = int(os.environ.get("MLP_MULT", 2)) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + + # Optimizer hyperparameters. + embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.05)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.04)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.04)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.0)) + + # Export / compression. 
+ quant_scheme = os.environ.get("QUANT_SCHEME", "int5_odd") + compressor = os.environ.get("COMPRESSOR", "zlib") + compressor_level = int(os.environ.get("COMPRESSOR_LEVEL", 9)) + + # Entropy-aware rate regularization. + rate_lambda = float(os.environ.get("RATE_LAMBDA", 0.02)) + scale_lambda = float(os.environ.get("SCALE_LAMBDA", 0.0005)) + rate_temperature = float(os.environ.get("RATE_TEMPERATURE", 6.0)) + rate_warmup_start_fraction = float(os.environ.get("RATE_WARMUP_START_FRACTION", 0.2)) + rate_full_fraction = float(os.environ.get("RATE_FULL_FRACTION", 0.6)) + rate_target_tensors = os.environ.get("RATE_TARGET_TENSORS", "mlp_only") + + # Structured MLP. + structured_mlp = bool(int(os.environ.get("STRUCTURED_MLP", "0"))) + structured_linear_type = os.environ.get("STRUCTURED_LINEAR_TYPE", "btt_mlp") + btt_rank = int(os.environ.get("BTT_RANK", 32)) + btt_in_shape = os.environ.get("BTT_IN_SHAPE", "") + btt_out_shape = os.environ.get("BTT_OUT_SHAPE", "") + +# ----------------------------- +# MUON OPTIMIZER +# ----------------------------- +# +# As borrowed from modded-nanogpt +# Background on Muon: https://kellerjordan.github.io/posts/muon/ + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + # Orthogonalize a 2D update matrix with a fast Newton-Schulz iteration. + # Muon uses this to normalize matrix-shaped gradients before applying them. 
+ a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X + + +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True): + super().__init__( + params, + dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov), + ) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + + for group in self.param_groups: + params = group["params"] + if not params: + continue + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + + total_params = sum(int(p.numel()) for p in params) + updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) + + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: + g = g.add(buf, alpha=momentum) + g = zeropower_via_newtonschulz5(g, steps=backend_steps) + # Scale correction from Muon reference implementations. 
g *= max(1, g.size(0) / g.size(1)) ** 0.5 + updates_flat[curr : curr + p.numel()] = g.reshape(-1) + curr += p.numel() + + if distributed: + dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) + + curr = 0 + for p in params: + g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) + p.add_(g, alpha=-lr) + curr += p.numel() + + return loss + + +# ----------------------------- +# TOKENIZER-AGNOSTIC EVALUATION SETUP +# ----------------------------- +# +# It's common for small models to have a large fraction of their parameters in embeddings, since the 2 * d_model * d_vocab embedding/unembedding parameters can be gigantic. +# Instead of locking the tokenizer, we let you bring your own and calculate our validation metrics on the average compression of the validation set. +# We calculate BPB (bits-per-byte) instead of validation loss, so we need methods to count the number of bytes each token contributes under the tokenizer. +# Note: Submissions that edit the tokenizer will be examined more carefully, since screwing this up might unjustly improve your score. 
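As a worked example of the metric described above, BPB is just the nat-valued cross-entropy converted to bits and re-weighted by tokens per byte. A minimal standalone sketch with made-up totals (all numbers here are illustrative, not from any real run):

```python
import math

# Hypothetical validation totals (illustrative only):
val_loss_nats = 2.772      # mean token cross-entropy in nats
num_tokens = 1_048_576     # number of predicted tokens
num_bytes = 4_500_000      # UTF-8 bytes those tokens decode to

# Mirrors the final lines of eval_val:
#   bpb = (loss / ln 2) * (tokens / bytes)
bits_per_token = val_loss_nats / math.log(2.0)
bits_per_byte = bits_per_token * (num_tokens / num_bytes)
print(f"val_bpb:{bits_per_byte:.4f}")
```

Because the byte count comes from the LUTs rather than the token count, a tokenizer with longer pieces lowers tokens-per-byte and the two effects trade off.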
+ +def build_sentencepiece_luts( + sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device +) -> tuple[Tensor, Tensor, Tensor]: + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("▁"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) + + +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*. 
+ tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] + + +def eval_val( + args: Hyperparameters, + model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + grad_accum_steps: int, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, +) -> tuple[float, float]: + # Validation computes two metrics: + # - val_loss: token cross-entropy (natural log) + # - val_bpb: tokenizer-agnostic compression metric used by the challenge + local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) + if local_batch_tokens < args.train_seq_len: + raise ValueError( + "VAL_BATCH_SIZE must provide at least one sequence per rank; " + f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " + f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}" + ) + local_batch_seqs = local_batch_tokens // args.train_seq_len + total_seqs = (val_tokens.numel() - 1) // args.train_seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + + model.eval() + with torch.inference_mode(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * args.train_seq_len + raw_end = batch_seq_end * args.train_seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, args.train_seq_len) + y = local[1:].reshape(-1, args.train_seq_len) + with 
torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + batch_loss = model(x, y).detach() + batch_token_count = float(y.numel()) + val_loss_sum += batch_loss.to(torch.float64) * batch_token_count + val_token_count += batch_token_count + prev_ids = x.reshape(-1) + tgt_ids = y.reshape(-1) + token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + + val_loss = val_loss_sum / val_token_count + bits_per_token = val_loss.item() / math.log(2.0) + tokens_per_byte = val_token_count.item() / val_byte_count.item() + model.train() + return float(val_loss.item()), float(bits_per_token * tokens_per_byte) + +# ----------------------------- +# POST-TRAINING QUANTIZATION +# ----------------------------- +# +# It's silly to export our model, which is trained in bf16 and fp32, at that same precision. +# Instead, we get approximately the same model (with a small accuracy hit) by quantizing the weights to a low-bit grid and zlib-compressing the result. +# We can then decompress the model and run in higher precision for evaluation, after coming in under the size limit. 
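The int5_odd export relies on the fact that three base-5 codes fit in one byte (5**3 = 125 <= 256). A minimal pure-Python sketch of that pack/unpack round trip, with toy code values standing in for quantized weights (the real path is `pack_base5_codes` / `unpack_base5_codes` over NumPy arrays):

```python
# Toy int5_odd codes: quantized levels {-2..2} shifted by +2 into {0..4}.
codes = [0, 4, 2, 3, 1, 2, 4]  # hypothetical values, not real weights

def pack(codes):
    # Pad to a multiple of 3, then fold each triple (a, b, c) into one byte
    # as a + 5*b + 25*c; valid because 5**3 = 125 <= 256.
    padded = codes + [0] * (-len(codes) % 3)
    triples = zip(*[iter(padded)] * 3)
    return bytes(a + 5 * b + 25 * c for a, b, c in triples), len(codes)

def unpack(data, n):
    out = []
    for byte in data:
        for _ in range(3):  # peel off three base-5 digits per byte
            out.append(byte % 5)
            byte //= 5
    return out[:n]

packed, n = pack(codes)
assert unpack(packed, n) == codes  # exact round trip
# 7 codes fit in 3 bytes; asymptotically 8/3 ~ 2.67 bits per weight
# versus 8 bits for int8, before zlib even sees the byte stream.
```

Packing before compression matters: zlib then operates on a denser byte stream whose statistics reflect the 5-level code distribution the rate regularizer is shaping.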
+ +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights", + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 +INT5_ODD_LEVELS = (-2, -1, 0, 1, 2) + +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) + + +def parse_factor_shape(spec: str, dim: int) -> tuple[int, int]: + if spec: + vals = tuple(int(v.strip()) for v in spec.split(",") if v.strip()) + if len(vals) != 2 or vals[0] * vals[1] != dim: + raise ValueError(f"Invalid factor shape '{spec}' for dim={dim}") + return vals + for a in range(int(math.isqrt(dim)), 0, -1): + if dim % a == 0: + return a, dim // a + return 1, dim + +def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: + if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): + return t.float().contiguous() + if t.dtype in {torch.float32, torch.bfloat16}: + passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") + return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() + return t + +def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + # Matrices get one scale per row, which usually tracks output-channel + # ranges much better than a single tensor-wide scale. 
+ clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) + q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() + return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + + # Vectors / scalars use a simpler per-tensor scale. + clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() + return q, scale + + +def pack_base5_codes(q: Tensor) -> tuple[bytes, int]: + vals = q.detach().to("cpu", dtype=torch.uint8).reshape(-1).numpy() + n_vals = int(vals.size) + pad = (3 - (n_vals % 3)) % 3 + if pad: + vals = np.pad(vals, (0, pad)) + groups = vals.reshape(-1, 3).astype(np.uint8) + packed = groups[:, 0] + 5 * groups[:, 1] + 25 * groups[:, 2] + return packed.tobytes(), n_vals + + +def unpack_base5_codes(data: bytes, n_vals: int) -> Tensor: + packed = np.frombuffer(data, dtype=np.uint8).astype(np.int16) + out = np.zeros((packed.size, 3), dtype=np.uint8) + for i in range(3): + out[:, i] = packed % 5 + packed //= 5 + flat = out.reshape(-1)[:n_vals] + return torch.from_numpy(flat.astype(np.uint8, copy=False)) + + +def quantize_float_tensor_int5_odd(t: Tensor) -> tuple[Tensor, Tensor]: + if t.ndim != 2: + raise ValueError("int5_odd quantization only supports 2D tensors") + t32 = t.float() + clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clip_abs = clip_abs.clamp_min(1e-6) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 
2.0).clamp_min(1.0 / 2.0e6) + q = torch.clamp(torch.round(clipped / scale[:, None]), -2, 2).to(torch.int8).contiguous() + return (q + 2).to(torch.uint8).contiguous(), scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + +def quantize_state_dict_int8(state_dict: dict[str, Tensor]): + # Single supported clean-script export format: + # - per-row int8 for 2D float tensors + # - per-tensor int8 for other float tensors + # - exact passthrough for non-floats + # - passthrough for small float tensors, stored as fp16 to save bytes + quantized: dict[str, Tensor] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + qmeta: dict[str, dict[str, object]] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + + # Small float tensors are cheap enough to keep directly. We still downcast + # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size. 
+ if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + + stats["num_float_tensors"] += 1 + q, s = quantize_float_tensor(t) + if s.ndim > 0: + qmeta[name] = {"scheme": "per_row", "axis": 0} + quantized[name] = q + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) + + obj: dict[str, object] = { + "__quant_format__": "int8_clean_per_row_v1", + "quantized": quantized, + "scales": scales, + "dtypes": dtypes, + "passthrough": passthrough, + } + if qmeta: + obj["qmeta"] = qmeta + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats + + +def quantize_state_dict_int5_odd(state_dict: dict[str, Tensor]): + packed: dict[str, bytes] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + shapes: dict[str, tuple[int, ...]] = {} + n_codes: dict[str, int] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + + if t.ndim != 2 or t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + + stats["num_float_tensors"] += 1 + q, s = quantize_float_tensor_int5_odd(t) + packed_bytes, count = 
pack_base5_codes(q) + packed[name] = packed_bytes + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + shapes[name] = tuple(int(v) for v in t.shape) + n_codes[name] = count + stats["int8_payload_bytes"] += len(packed_bytes) + tensor_nbytes(s) + + obj: dict[str, object] = { + "__quant_format__": "int5_odd_clean_per_row_v1", + "packed": packed, + "scales": scales, + "dtypes": dtypes, + "shapes": shapes, + "n_codes": n_codes, + "passthrough": passthrough, + } + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats + +def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + qmeta = obj.get("qmeta", {}) + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, q in obj["quantized"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + s = obj["scales"][name] + if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: + s = s.to(dtype=torch.float32) + # Broadcast the saved row scale back across trailing dimensions. + out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() + else: + scale = float(s.item()) + out[name] = (q.float() * scale).to(dtype=dtype).contiguous() + for name, t in obj["passthrough"].items(): + # Restore small tensors, undoing the temporary fp16 storage cast if needed. 
+ out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out + + +def dequantize_state_dict_int5_odd(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, packed in obj["packed"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + shape = tuple(int(v) for v in obj["shapes"][name]) + q = unpack_base5_codes(packed, int(obj["n_codes"][name])).to(torch.float32).reshape(shape) - 2.0 + s = obj["scales"][name].to(dtype=torch.float32) + out[name] = (q * s.view(shape[0], *([1] * (len(shape) - 1)))).to(dtype=dtype).contiguous() + for name, t in obj["passthrough"].items(): + out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out + + +def quantize_state_dict(state_dict: dict[str, Tensor], quant_scheme: str): + if quant_scheme == "int8": + return quantize_state_dict_int8(state_dict) + if quant_scheme == "int5_odd": + return quantize_state_dict_int5_odd(state_dict) + raise ValueError(f"Unsupported QUANT_SCHEME={quant_scheme}") + + +def dequantize_state_dict(obj: dict[str, object]) -> dict[str, Tensor]: + quant_format = obj.get("__quant_format__") + if quant_format == "int8_clean_per_row_v1": + return dequantize_state_dict_int8(obj) + if quant_format == "int5_odd_clean_per_row_v1": + return dequantize_state_dict_int5_odd(obj) + raise ValueError(f"Unsupported quant format {quant_format}") + + +def compress_blob(data: bytes, compressor: str, level: int) -> bytes: + if compressor == "zlib": + return zlib.compress(data, level=max(0, min(level, 9))) + if compressor == "zstd": + if not _HAS_ZSTD: + raise RuntimeError("COMPRESSOR=zstd requires zstandard 
to be installed")
+        return zstd.ZstdCompressor(level=level).compress(data)
+    raise ValueError(f"Unsupported COMPRESSOR={compressor}")
+
+
+def decompress_blob(data: bytes, compressor: str) -> bytes:
+    if compressor == "zlib":
+        return zlib.decompress(data)
+    if compressor == "zstd":
+        if not _HAS_ZSTD:
+            raise RuntimeError("COMPRESSOR=zstd requires zstandard to be installed")
+        return zstd.ZstdDecompressor().decompress(data)
+    raise ValueError(f"Unsupported COMPRESSOR={compressor}")
+
+
+def should_rate_regularize_tensor(name: str, p: Tensor, target: str) -> bool:
+    if not p.requires_grad or not p.is_floating_point() or p.ndim < 2:
+        return False
+    if p.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL:
+        return False
+    if any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS):
+        return False
+    if target == "mlp_only":
+        return ".mlp." in name
+    if target == "all_matrices":
+        return True
+    raise ValueError(f"Unsupported RATE_TARGET_TENSORS={target}")
+
+
+def estimate_tensor_rate_proxy(t: Tensor, temperature: float) -> tuple[Tensor, Tensor, Tensor]:
+    t2 = t.float()
+    if t2.ndim > 2:
+        t2 = t2.reshape(t2.shape[0], -1)
+    elif t2.ndim == 1:
+        t2 = t2.reshape(1, -1)
+    row_scale = t2.abs().mean(dim=1, keepdim=True).clamp_min(1e-6)
+    normalized = torch.clamp(t2 / row_scale, -2.5, 2.5)
+    centers = t2.new_tensor(INT5_ODD_LEVELS)
+    logits = -temperature * (normalized.unsqueeze(-1) - centers) ** 2
+    probs = torch.softmax(logits, dim=-1)
+    hist = probs.mean(dim=(0, 1))
+    entropy_nats = -(hist * (hist + 1e-8).log()).sum()
+    entropy_bits = entropy_nats / math.log(2.0)
+    scale_proxy = row_scale.mean()
+    est_bits = entropy_bits * float(t2.numel())
+    return entropy_bits, scale_proxy, est_bits
+
+
+# -----------------------------
+# DATA LOADING
+# -----------------------------
+
+def load_data_shard(file: Path) -> Tensor:
+    # 256-entry little-endian int32 header (token count at header[2]), then uint16 tokens.
+    header_bytes = 256 * np.dtype("<i4").itemsize
+    with file.open("rb") as f:
+        header = np.frombuffer(f.read(header_bytes), dtype="<i4")
+        num_tokens = int(header[2])
+        tokens = np.frombuffer(f.read(num_tokens * np.dtype("<u2").itemsize), dtype="<u2")
+    return torch.from_numpy(tokens.astype(np.int32))
+
+
+class TokenStream:
+    # Streams tokens in order from the shard files matching `pattern`, wrapping back to
+    # the first shard after the last one is exhausted.
+    def __init__(self, pattern: str):
+        path = Path(pattern)
+        self.files = sorted(path.parent.glob(path.name))
+        if not self.files:
+            raise FileNotFoundError(f"No token shards match {pattern}")
+        self.file_idx = 0
+        self.tokens = load_data_shard(self.files[self.file_idx])
+        self.pos = 0
+
+    def _advance_file(self) -> None:
+        self.file_idx = (self.file_idx + 1) % len(self.files)
+        self.tokens = load_data_shard(self.files[self.file_idx])
self.pos = 0 + + def take(self, n: int) -> Tensor: + chunks: list[Tensor] = [] + remaining = n + while remaining > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: + self._advance_file() + continue + k = min(remaining, avail) + chunks.append(self.tokens[self.pos : self.pos + k]) + self.pos += k + remaining -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) + + +class DistributedTokenLoader: + # Each call consumes a contiguous chunk from the shared token stream, then slices out + # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting. + def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank = rank + self.world_size = world_size + self.device = device + self.stream = TokenStream(pattern) + + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + chunk = self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start : start + per_rank_span].to(dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + +# ----------------------------- +# TRANSFORMER MODULES +# ----------------------------- + +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + + +class CastedLinear(nn.Linear): + # Keep weights in fp32 for optimizer/state quality, cast at matmul time for bf16 compute. 
+ def forward(self, x: Tensor) -> Tensor: + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, self.weight.to(x.dtype), bias) + + +class StructuredLinear(nn.Module): + # Lightweight TT-style structured matrix used only in the MLP stack. + def __init__( + self, + in_features: int, + out_features: int, + bias: bool, + linear_type: str, + btt_rank: int, + in_shape_spec: str, + out_shape_spec: str, + ): + super().__init__() + self.in_features = in_features + self.out_features = out_features + self.linear_type = linear_type + self.rank = max(int(btt_rank), 1) + if linear_type == "dense": + self.impl = CastedLinear(in_features, out_features, bias=bias) + self.register_parameter("core1", None) + self.register_parameter("core2", None) + return + if linear_type != "btt_mlp": + raise ValueError(f"Unsupported STRUCTURED_LINEAR_TYPE={linear_type}") + self.impl = None + self.in_shape = parse_factor_shape(in_shape_spec, in_features) + self.out_shape = parse_factor_shape(out_shape_spec, out_features) + n1, n2 = self.in_shape + m1, m2 = self.out_shape + self.core1 = nn.Parameter(torch.empty(n1, m1, self.rank)) + self.core2 = nn.Parameter(torch.empty(self.rank, n2, m2)) + if bias: + self.bias = nn.Parameter(torch.zeros(out_features)) + else: + self.register_parameter("bias", None) + self.reset_parameters() + + def reset_parameters(self) -> None: + if self.impl is not None: + self.impl.reset_parameters() + return + std = 1.0 / math.sqrt(self.in_features) + nn.init.normal_(self.core1, mean=0.0, std=std / math.sqrt(self.rank)) + nn.init.normal_(self.core2, mean=0.0, std=std) + if self.bias is not None: + nn.init.zeros_(self.bias) + + def forward(self, x: Tensor) -> Tensor: + if self.impl is not None: + return self.impl(x) + x_dtype = x.dtype + orig_shape = x.shape[:-1] + x2 = x.reshape(-1, self.in_shape[0], self.in_shape[1]) + core1 = self.core1.to(dtype=x_dtype) + core2 = self.core2.to(dtype=x_dtype) + tmp = torch.einsum("buv,umr->bmrv", x2, core1) + 
y = torch.einsum("bmrv,rvn->bmn", tmp, core2) + y = y.reshape(*orig_shape, self.out_features) + if self.bias is not None: + y = y + self.bias.to(dtype=y.dtype) + return y + + +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + # Keep small/control parameters in fp32 even when the model body runs in bf16. + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() + + +class Rotary(nn.Module): + # Caches cos/sin tables per sequence length on the current device. + def __init__(self, dim: int, base: float = 10000.0): + super().__init__() + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached != seq_len + or self._cos_cached.device != device + ): + t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) + freqs = torch.outer(t, self.inv_freq.to(device)) + self._cos_cached = freqs.cos()[None, None, :, :] + self._sin_cached = freqs.sin()[None, None, :, :] + self._seq_len_cached = seq_len + cos = self._cos_cached.to(dtype=dtype) + sin = self._sin_cached.to(dtype=dtype) + if not torch.is_inference_mode_enabled(): + cos = cos.clone() + sin = sin.clone() + return cos, sin + + +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + + +class CausalSelfAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, 
+ rope_base: float, + qk_gain_init: float, + ): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + kv_dim = self.num_kv_heads * self.head_dim + self.c_q = CastedLinear(dim, dim, bias=False) + self.c_k = CastedLinear(dim, kv_dim, bias=False) + self.c_v = CastedLinear(dim, kv_dim, bias=False) + self.proj = CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rotary = Rotary(self.head_dim, base=rope_base) + + def forward(self, x: Tensor) -> Tensor: + bsz, seqlen, dim = x.shape + q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) + k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin) + k = apply_rotary_emb(k, cos, sin) + q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] + y = F.scaled_dot_product_attention( + q, + k, + v, + attn_mask=None, + is_causal=True, + enable_gqa=(self.num_kv_heads != self.num_heads), + ) + y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) + return self.proj(y) + + +class MLP(nn.Module): + # relu^2 MLP from the original modded-nanogpt setup + def __init__( + self, + dim: int, + mlp_mult: int, + structured_mlp: bool, + structured_linear_type: str, + btt_rank: int, + btt_in_shape: str, + btt_out_shape: str, + ): + super().__init__() + hidden = mlp_mult * dim + if 
structured_mlp: + self.fc = StructuredLinear( + dim, + hidden, + bias=False, + linear_type=structured_linear_type, + btt_rank=btt_rank, + in_shape_spec=btt_in_shape, + out_shape_spec=btt_out_shape, + ) + self.proj = StructuredLinear( + hidden, + dim, + bias=False, + linear_type=structured_linear_type, + btt_rank=btt_rank, + in_shape_spec=btt_out_shape, + out_shape_spec=btt_in_shape, + ) + else: + self.fc = CastedLinear(dim, hidden, bias=False) + self.proj = CastedLinear(hidden, dim, bias=False) + self.proj._zero_init = True + + def forward(self, x: Tensor) -> Tensor: + x = torch.relu(self.fc(x)) + return self.proj(x.square()) + + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + rope_base: float, + qk_gain_init: float, + structured_mlp: bool, + structured_linear_type: str, + btt_rank: int, + btt_in_shape: str, + btt_out_shape: str, + ): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) + self.mlp = MLP(dim, mlp_mult, structured_mlp, structured_linear_type, btt_rank, btt_in_shape, btt_out_shape) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + + def forward(self, x: Tensor, x0: Tensor) -> Tensor: + mix = self.resid_mix.to(dtype=x.dtype) + x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + attn_out = self.attn(self.attn_norm(x)) + x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out + x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) + return x + + +class GPT(nn.Module): + def __init__( + self, + vocab_size: int, + num_layers: int, + model_dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + tie_embeddings: bool, + 
tied_embed_init_std: float, + logit_softcap: float, + rope_base: float, + qk_gain_init: float, + structured_mlp: bool, + structured_linear_type: str, + btt_rank: int, + btt_in_shape: str, + btt_out_shape: str, + rate_target_tensors: str, + rate_temperature: float, + ): + super().__init__() + if logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") + self.tie_embeddings = tie_embeddings + self.tied_embed_init_std = tied_embed_init_std + self.logit_softcap = logit_softcap + self.rate_target_tensors = rate_target_tensors + self.rate_temperature = rate_temperature + self.tok_emb = nn.Embedding(vocab_size, model_dim) + self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + self.blocks = nn.ModuleList( + [ + Block( + model_dim, + num_heads, + num_kv_heads, + mlp_mult, + rope_base, + qk_gain_init, + structured_mlp, + structured_linear_type, + btt_rank, + btt_in_shape, + btt_out_shape, + ) + for i in range(num_layers) + ] + ) + self.final_norm = RMSNorm() + self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) + if self.lm_head is not None: + self.lm_head._zero_init = True + self._init_weights() + + def _init_weights(self) -> None: + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + for module in self.modules(): + if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + if isinstance(module, StructuredLinear) and getattr(module, "_zero_init", False): + if module.impl is not None: + nn.init.zeros_(module.impl.weight) + else: + nn.init.zeros_(module.core2) + + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + x = self.tok_emb(input_ids) + x 
= F.rms_norm(x, (x.size(-1),)) + x0 = x + skips: list[Tensor] = [] + + # First half stores skips; second half reuses them in reverse order. + for i in range(self.num_encoder_layers): + x = self.blocks[i](x, x0) + skips.append(x) + for i in range(self.num_decoder_layers): + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + x = self.blocks[self.num_encoder_layers + i](x, x0) + + x = self.final_norm(x).reshape(-1, x.size(-1)) + targets = target_ids.reshape(-1) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(x) + logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + return F.cross_entropy(logits.float(), targets, reduction="mean") + + def rate_regularizer(self) -> tuple[Tensor, Tensor, Tensor]: + total_symbols = 0.0 + weighted_rate = None + weighted_scale = None + estimated_bits = None + for name, param in self.named_parameters(): + if not should_rate_regularize_tensor(name, param, self.rate_target_tensors): + continue + rate_proxy, scale_proxy, tensor_bits = estimate_tensor_rate_proxy(param, self.rate_temperature) + weight = float(param.numel()) + total_symbols += weight + if weighted_rate is None: + weighted_rate = rate_proxy * weight + weighted_scale = scale_proxy * weight + estimated_bits = tensor_bits + else: + weighted_rate = weighted_rate + rate_proxy * weight + weighted_scale = weighted_scale + scale_proxy * weight + estimated_bits = estimated_bits + tensor_bits + if weighted_rate is None or weighted_scale is None or estimated_bits is None: + zero = self.tok_emb.weight.new_zeros(()) + return zero, zero, zero + denom = max(total_symbols, 1.0) + return weighted_rate / denom, weighted_scale / denom, estimated_bits + + +# ----------------------------- +# TRAINING +# ----------------------------- + +def main() -> None: + global 
zeropower_via_newtonschulz5 + + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + if args.enable_compile: + zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) + + # ----------------------------- + # DISTRIBUTED + CUDA SETUP + # ----------------------------- + + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if 8 % world_size != 0: + raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") + grad_accum_steps = 8 // world_size + grad_scale = 1.0 / grad_accum_steps + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 0 + + # Fast math knobs + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + + logfile = None + if master_process: + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" + print(logfile) + + def log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch {torch.__version__}", console=False) + log0( + subprocess.run(["nvidia-smi"], 
stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, + console=False, + ) + log0("=" * 100, console=False) + + # ----------------------------- + # TOKENIZER + VALIDATION METRIC SETUP + # ----------------------------- + + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) + if args.val_token_limit > 0: + usable = min(((args.val_token_limit) // args.train_seq_len) * args.train_seq_len, val_tokens.numel() - 1) + if usable <= 0: + raise ValueError(f"VAL_TOKEN_LIMIT={args.val_token_limit} is too small for TRAIN_SEQ_LEN={args.train_seq_len}") + val_tokens = val_tokens[: usable + 1].contiguous() + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + if args.val_token_limit > 0: + log0(f"val_token_limit:{args.val_token_limit}") + + # ----------------------------- + # MODEL + OPTIMIZER SETUP + # ----------------------------- + + base_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + model_dim=args.model_dim, + num_heads=args.num_heads, + 
num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, + structured_mlp=args.structured_mlp, + structured_linear_type=args.structured_linear_type if args.structured_mlp else "dense", + btt_rank=args.btt_rank, + btt_in_shape=args.btt_in_shape, + btt_out_shape=args.btt_out_shape, + rate_target_tensors=args.rate_target_tensors, + rate_temperature=args.rate_temperature, + ).to(device).bfloat16() + for module in base_model.modules(): + if isinstance(module, (CastedLinear, StructuredLinear)): + module.float() + restore_low_dim_params_to_fp32(base_model) + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) if args.enable_compile else base_model + model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + + # Optimizer split: + # - token embedding (Adam) uses EMBED_LR + # - untied lm_head (Adam) uses HEAD_LR + # - matrix params in transformer blocks use MATRIX_LR via Muon + # - vectors/scalars use SCALAR_LR via Adam + block_named_params = list(base_model.blocks.named_parameters()) + matrix_params = [ + p + for name, p in block_named_params + if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + scalar_params = [ + p + for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + optimizer_tok = torch.optim.Adam( + [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizer_muon = Muon( + matrix_params, + lr=args.matrix_lr, + 
momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + ) + for group in optimizer_muon.param_groups: + group["base_lr"] = args.matrix_lr + optimizer_scalar = torch.optim.Adam( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers.insert(1, optimizer_head) + + n_params = sum(p.numel() for p in base_model.parameters()) + log0(f"model_params:{n_params}") + log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") + log0(f"enable_compile:{args.enable_compile}") + log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") + log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") + log0( + f"quant_scheme:{args.quant_scheme} compressor:{args.compressor} level:{args.compressor_level} " + f"structured_mlp:{args.structured_mlp} structured_linear_type:{args.structured_linear_type}" + ) + log0( + f"rate_lambda:{args.rate_lambda} scale_lambda:{args.scale_lambda} " + f"rate_warmup_start_fraction:{args.rate_warmup_start_fraction} rate_full_fraction:{args.rate_full_fraction} " + f"rate_target_tensors:{args.rate_target_tensors}" + ) + log0( + f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " + f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " + f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" + ) + log0( + f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " + f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " + f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + 
log0(f"seed:{args.seed}") + + # ----------------------------- + # DATA LOADER & MODEL WARMUP + # ----------------------------- + + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + + def lr_mul(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + + def rate_schedule(step: int) -> float: + if args.rate_lambda <= 0.0 and args.scale_lambda <= 0.0: + return 0.0 + if args.iterations <= 0: + return 1.0 + frac = step / max(args.iterations, 1) + start = args.rate_warmup_start_fraction + full = max(args.rate_full_fraction, start + 1e-6) + if frac <= start: + return 0.0 + if frac >= full: + return 1.0 + return (frac - start) / (full - start) + + # Warmup primes the compiled forward/backward/optimizer paths, then we restore the + # initial weights/optimizer state so measured training starts from the true init. 
+ if args.warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(args.warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + for opt in optimizers: + opt.step() + zero_grad_all() + if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: + log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + if distributed: + model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + # ----------------------------- + # MAIN TRAINING LOOP + # ----------------------------- + + training_time_ms = 0.0 + stop_after_step: int | None = None + torch.cuda.synchronize() + t0 = time.perf_counter() + + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + log0( + 
f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" + ) + break + + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + zero_grad_all() + train_loss = torch.zeros((), device=device) + train_ce_loss = torch.zeros((), device=device) + train_rate_proxy = torch.zeros((), device=device) + train_scale_proxy = torch.zeros((), device=device) + train_est_bits = torch.zeros((), device=device) + rate_mul = rate_schedule(step) + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + ce_loss = model(x, y) + total_loss = ce_loss + if rate_mul > 0.0: + rate_proxy, scale_proxy, est_bits = base_model.rate_regularizer() + total_loss = total_loss + rate_mul * (args.rate_lambda * rate_proxy + args.scale_lambda * scale_proxy) + train_rate_proxy += rate_proxy.detach() + train_scale_proxy += scale_proxy.detach() + train_est_bits += est_bits.detach() + train_ce_loss += ce_loss.detach() + train_loss += total_loss.detach() + (total_loss * grad_scale).backward() + train_loss /= grad_accum_steps + train_ce_loss /= grad_accum_steps + train_rate_proxy /= grad_accum_steps + train_scale_proxy /= grad_accum_steps + train_est_bits /= grad_accum_steps + + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * 
args.muon_momentum + for group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + for opt in optimizers: + opt.step() + zero_grad_all() + + step += 1 + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + should_log_train = ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + log0( + f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} ce_loss:{train_ce_loss.item():.4f} " + f"rate_mul:{rate_mul:.3f} rate_proxy_bits:{train_rate_proxy.item():.4f} " + f"scale_proxy:{train_scale_proxy.item():.6f} est_model_bits:{train_est_bits.item():.0f} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" + ) + + # Needed to sync whether we've reached the wallclock cap. + reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + + log0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + ) + + # ----------------------------- + # SERIALIZATION + ROUNDTRIP VALIDATION + # ----------------------------- + # Save the raw state (useful for debugging/loading in PyTorch directly), then always produce + # the compressed int8+zlib artifact and validate the round-tripped weights. 
+ + if master_process: + torch.save(base_model.state_dict(), "final_model.pt") + model_bytes = os.path.getsize("final_model.pt") + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model: {model_bytes} bytes") + log0(f"Code size: {code_bytes} bytes") + log0(f"Total submission size: {model_bytes + code_bytes} bytes") + + quant_obj, quant_stats = quantize_state_dict(base_model.state_dict(), args.quant_scheme) + quant_buf = io.BytesIO() + torch.save(quant_obj, quant_buf) + quant_raw = quant_buf.getvalue() + quant_blob = compress_blob(quant_raw, args.compressor, args.compressor_level) + quant_raw_bytes = len(quant_raw) + quant_filename = f"final_model.{args.quant_scheme}.{args.compressor}.ptc" + if master_process: + with open(quant_filename, "wb") as f: + f.write(quant_blob) + quant_file_bytes = os.path.getsize(quant_filename) + code_bytes = len(code.encode("utf-8")) + ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1) + log0( + f"Serialized model {args.quant_scheme}+{args.compressor}: {quant_file_bytes} bytes " + f"(payload:{quant_stats['int8_payload_bytes']} raw_torch:{quant_raw_bytes} payload_ratio:{ratio:.2f}x)" + ) + log0(f"Total submission size {args.quant_scheme}+{args.compressor}: {quant_file_bytes + code_bytes} bytes") + + if distributed: + dist.barrier() + with open(quant_filename, "rb") as f: + quant_blob_disk = f.read() + quant_state = torch.load(io.BytesIO(decompress_blob(quant_blob_disk, args.compressor)), map_location="cpu") + base_model.load_state_dict(dequantize_state_dict(quant_state), strict=True) + torch.cuda.synchronize() + t_qeval = time.perf_counter() + q_val_loss, q_val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"final_{args.quant_scheme}_{args.compressor}_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + 
f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" + ) + log0(f"final_{args.quant_scheme}_{args.compressor}_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + + if distributed: + dist.destroy_process_group() + + +if __name__ == "__main__": + main() + +==================================================================================================== +Running Python 3.10.12 (main, Mar 3 2026, 11:56:32) [GCC 11.4.0] +Running PyTorch 2.11.0+cu130 +Tue Mar 31 02:51:23 2026 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 590.48.01 Driver Version: 590.48.01 CUDA Version: 13.1 | ++-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. | +|=========================================+========================+======================| +| 0 NVIDIA GeForce RTX 3060 Off | 00000000:01:00.0 On | N/A | +| 30% 34C P5 21W / 170W | 214MiB / 12288MiB | 36% Default | +| | | N/A | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| 0 N/A N/A 975 G /usr/lib/xorg/Xorg 89MiB | +| 0 N/A N/A 4182 C python3 104MiB | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=/home/hz/Desktop/parameter-golf-main/data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:1 +val_loader:shards 
pattern=/home/hz/Desktop/parameter-golf-main/data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:1048576 +val_token_limit:1048576 +model_params:8507464 +world_size:1 grad_accum_steps:8 +enable_compile:False +sdp_backends:cudnn=False flash=True mem_efficient=False math=False +attention_mode:gqa num_heads:8 num_kv_heads:4 +quant_scheme:int5_odd compressor:zlib level:9 structured_mlp:True structured_linear_type:btt_mlp +rate_lambda:0.002 scale_lambda:0.0002 rate_warmup_start_fraction:0.1 rate_full_fraction:0.4 rate_target_tensors:mlp_only +tie_embeddings:True embed_lr:0.03 head_lr:0.0 matrix_lr:0.02 scalar_lr:0.02 +train_batch_tokens:16384 train_seq_len:256 iterations:40 warmup_steps:0 max_wallclock_seconds:600.000 +seed:1337 +step:0/40 val_loss:6.9337 val_bpb:4.1543 train_time:0ms step_avg:0.02ms +step:1/40 train_loss:6.9388 ce_loss:6.9388 rate_mul:0.000 rate_proxy_bits:0.0000 scale_proxy:0.000000 est_model_bits:0 train_time:1000ms step_avg:1000.05ms +step:2/40 train_loss:12.9680 ce_loss:12.9680 rate_mul:0.000 rate_proxy_bits:0.0000 scale_proxy:0.000000 est_model_bits:0 train_time:1915ms step_avg:957.62ms +step:3/40 train_loss:10.5748 ce_loss:10.5748 rate_mul:0.000 rate_proxy_bits:0.0000 scale_proxy:0.000000 est_model_bits:0 train_time:2823ms step_avg:941.16ms +step:4/40 train_loss:7.8792 ce_loss:7.8792 rate_mul:0.000 rate_proxy_bits:0.0000 scale_proxy:0.000000 est_model_bits:0 train_time:3749ms step_avg:937.18ms +step:5/40 train_loss:6.6087 ce_loss:6.6087 rate_mul:0.000 rate_proxy_bits:0.0000 scale_proxy:0.000000 est_model_bits:0 train_time:4647ms step_avg:929.40ms +step:6/40 train_loss:6.2112 ce_loss:6.2112 rate_mul:0.083 rate_proxy_bits:0.0000 scale_proxy:0.000000 est_model_bits:0 train_time:5574ms step_avg:928.97ms +step:7/40 train_loss:6.2583 ce_loss:6.2583 rate_mul:0.167 rate_proxy_bits:0.0000 scale_proxy:0.000000 est_model_bits:0 train_time:6463ms step_avg:923.32ms +step:8/40 train_loss:6.3075 ce_loss:6.3075 rate_mul:0.250 rate_proxy_bits:0.0000 
scale_proxy:0.000000 est_model_bits:0 train_time:7403ms step_avg:925.40ms +step:9/40 train_loss:6.1007 ce_loss:6.1007 rate_mul:0.333 rate_proxy_bits:0.0000 scale_proxy:0.000000 est_model_bits:0 train_time:8323ms step_avg:924.83ms +step:10/40 train_loss:6.1337 ce_loss:6.1337 rate_mul:0.417 rate_proxy_bits:0.0000 scale_proxy:0.000000 est_model_bits:0 train_time:9275ms step_avg:927.52ms +step:20/40 train_loss:5.6381 ce_loss:5.6381 rate_mul:1.000 rate_proxy_bits:0.0000 scale_proxy:0.000000 est_model_bits:0 train_time:18640ms step_avg:932.01ms +step:30/40 train_loss:5.5897 ce_loss:5.5897 rate_mul:1.000 rate_proxy_bits:0.0000 scale_proxy:0.000000 est_model_bits:0 train_time:28017ms step_avg:933.88ms +step:40/40 train_loss:5.3881 ce_loss:5.3881 rate_mul:1.000 rate_proxy_bits:0.0000 scale_proxy:0.000000 est_model_bits:0 train_time:37395ms step_avg:934.87ms +""" +The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as solid jumping-off points for new participants, not SOTA configs. We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder. + +Hard stop: to keep things readable for newcomers, let's make sure `train_gpt.py` and `train_gpt_mlx.py` never grow beyond 1500 lines.
+""" + +from __future__ import annotations + +import copy +import glob +import io +import math +import os +import random +import subprocess +import sys +import time +import uuid +import zlib +from pathlib import Path + +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP + +try: + import zstandard as zstd + + _HAS_ZSTD = True +except ImportError: + zstd = None + _HAS_ZSTD = False + +# ----------------------------- +# HYPERPARAMETERS +# ----------------------------- +# Default Simple Baseline run: +# - 9 transformer blocks at width 512 +# - 8 attention heads with 4 KV heads (GQA) and 2x MLP expansion +# - vocab size 1024, sequence length 1024, tied embeddings +# - 524,288 train tokens per step for 20,000 iterations with a ~10 minute cap + +class Hyperparameters: + # Data paths are shard globs produced by the existing preprocessing pipeline. + data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + seed = int(os.environ.get("SEED", 1337)) + + # Validation cadence and batch size. Validation always uses the full fineweb_val split. + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 200)) + val_token_limit = int(os.environ.get("VAL_TOKEN_LIMIT", 0)) + + # Training length. 
+ iterations = int(os.environ.get("ITERATIONS", 20000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 1200)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + enable_compile = bool(int(os.environ.get("ENABLE_COMPILE", "1"))) + + # Model shape. + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers = int(os.environ.get("NUM_LAYERS", 9)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = int(os.environ.get("MLP_MULT", 2)) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + + # Optimizer hyperparameters. + embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.05)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.04)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.04)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.0)) + + # Export / compression. 
+ quant_scheme = os.environ.get("QUANT_SCHEME", "int5_odd") + compressor = os.environ.get("COMPRESSOR", "zlib") + compressor_level = int(os.environ.get("COMPRESSOR_LEVEL", 9)) + + # Entropy-aware rate regularization. + rate_lambda = float(os.environ.get("RATE_LAMBDA", 0.02)) + scale_lambda = float(os.environ.get("SCALE_LAMBDA", 0.0005)) + rate_temperature = float(os.environ.get("RATE_TEMPERATURE", 6.0)) + rate_warmup_start_fraction = float(os.environ.get("RATE_WARMUP_START_FRACTION", 0.2)) + rate_full_fraction = float(os.environ.get("RATE_FULL_FRACTION", 0.6)) + rate_target_tensors = os.environ.get("RATE_TARGET_TENSORS", "mlp_only") + + # Structured MLP. + structured_mlp = bool(int(os.environ.get("STRUCTURED_MLP", "0"))) + structured_linear_type = os.environ.get("STRUCTURED_LINEAR_TYPE", "btt_mlp") + btt_rank = int(os.environ.get("BTT_RANK", 32)) + btt_in_shape = os.environ.get("BTT_IN_SHAPE", "") + btt_out_shape = os.environ.get("BTT_OUT_SHAPE", "") + +# ----------------------------- +# MUON OPTIMIZER +# ----------------------------- +# +# As borrowed from modded-nanogpt +# Background on Muon: https://kellerjordan.github.io/posts/muon/ + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + # Orthogonalize a 2D update matrix with a fast Newton-Schulz iteration. + # Muon uses this to normalize matrix-shaped gradients before applying them. 
+ a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X + + +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True): + super().__init__( + params, + dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov), + ) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + + for group in self.param_groups: + params = group["params"] + if not params: + continue + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + + total_params = sum(int(p.numel()) for p in params) + updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) + + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: + g = g.add(buf, alpha=momentum) + g = zeropower_via_newtonschulz5(g, steps=backend_steps) + # Scale correction from Muon reference implementations. 
g *= max(1, g.size(0) / g.size(1)) ** 0.5 + updates_flat[curr : curr + p.numel()] = g.reshape(-1) + curr += p.numel() + + if distributed: + dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) + + curr = 0 + for p in params: + g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) + p.add_(g, alpha=-lr) + curr += p.numel() + + return loss + + +# ----------------------------- +# TOKENIZER-AGNOSTIC EVALUATION SETUP +# ----------------------------- +# +# It's common for small models to have a large fraction of their parameters in embeddings, since the 2 * d_model * d_vocab embedding vectors can be gigantic. +# Instead of locking the tokenizer, we let you bring your own and calculate our validation metrics on the average compression of the validation set. +# We calculate BPB (bits-per-byte) instead of validation loss, so we need lookup tables that count the number of bytes each token covers under the tokenizer. +# Note: Submissions that edit the tokenizer will be examined more carefully, since screwing this up might unjustly improve your score.
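The byte accounting behind the BPB metric can be sketched standalone. This is a toy illustration of the LUT logic built below, not the script's API: the vocab, byte counts, and helper names here (`sequence_bytes`, `bits_per_byte`) are hypothetical, and each token's byte count is its piece's UTF-8 length plus one extra byte for a leading "▁" space unless the previous token is a boundary/control token.

```python
import math

# Hypothetical 4-token vocab (NOT the real SentencePiece model):
# token_id -> len(piece.encode("utf-8")), leading-space flag, boundary flag.
base_bytes = {0: 0, 1: 3, 2: 5, 3: 1}
has_leading_space = {0: False, 1: True, 2: False, 3: True}
is_boundary = {0: True, 1: False, 2: False, 3: False}  # e.g. BOS/control tokens

def sequence_bytes(prev_ids, tgt_ids):
    # Mirror of the LUT lookups in eval_val: base bytes plus one byte for a
    # leading space, unless the previous token was a boundary token.
    total = 0
    for prev, tgt in zip(prev_ids, tgt_ids):
        total += base_bytes[tgt]
        if has_leading_space[tgt] and not is_boundary[prev]:
            total += 1
    return total

def bits_per_byte(mean_nat_loss, n_tokens, n_bytes):
    # val_loss is a natural-log cross entropy; convert nats/token -> bits/byte.
    bits_per_token = mean_nat_loss / math.log(2.0)
    return bits_per_token * n_tokens / n_bytes

n_bytes = sequence_bytes([0, 1, 2], [1, 2, 3])  # 3 + 5 + (1 + 1 space byte)
print(n_bytes, round(bits_per_byte(2.0, 3, n_bytes), 4))
```

Because the byte count depends only on the validation text, a model cannot lower BPB by swapping in a tokenizer with fewer tokens: fewer tokens means more bytes per token, and the two factors cancel.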
+ +def build_sentencepiece_luts( + sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device +) -> tuple[Tensor, Tensor, Tensor]: + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("▁"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) + + +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*. 
+ tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] + + +def eval_val( + args: Hyperparameters, + model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + grad_accum_steps: int, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, +) -> tuple[float, float]: + # Validation computes two metrics: + # - val_loss: token cross-entropy (natural log) + # - val_bpb: tokenizer-agnostic compression metric used by the challenge + local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) + if local_batch_tokens < args.train_seq_len: + raise ValueError( + "VAL_BATCH_SIZE must provide at least one sequence per rank; " + f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " + f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}" + ) + local_batch_seqs = local_batch_tokens // args.train_seq_len + total_seqs = (val_tokens.numel() - 1) // args.train_seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + + model.eval() + with torch.inference_mode(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * args.train_seq_len + raw_end = batch_seq_end * args.train_seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, args.train_seq_len) + y = local[1:].reshape(-1, args.train_seq_len) + with 
torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + batch_loss = model(x, y).detach() + batch_token_count = float(y.numel()) + val_loss_sum += batch_loss.to(torch.float64) * batch_token_count + val_token_count += batch_token_count + prev_ids = x.reshape(-1) + tgt_ids = y.reshape(-1) + token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + + val_loss = val_loss_sum / val_token_count + bits_per_token = val_loss.item() / math.log(2.0) + tokens_per_byte = val_token_count.item() / val_byte_count.item() + model.train() + return float(val_loss.item()), float(bits_per_token * tokens_per_byte) + +# ----------------------------- +# POST-TRAINING QUANTIZATION +# ----------------------------- +# +# It's wasteful to export our model, which is trained in bf16 and fp32, at that same precision. +# Instead, we get approximately the same model (with a small quality hit) by quantizing it to int8 (or int5_odd) and zlib-compressing the result. +# We can then decompress the model and run it at higher precision for evaluation, after coming in under the size limit.
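The per-row quantize/compress/dequantize cycle described above can be sketched in a few lines. This is a minimal NumPy stand-in, not the script's implementation: the helper names are illustrative, the clip here uses the per-row max rather than the high quantile the real code uses, and the fp16 scale storage is applied only to the serialized blob for clarity.

```python
import zlib
import numpy as np

def quantize_per_row(w: np.ndarray):
    # One scale per output row, chosen so the row's max maps to +/-127.
    clip = np.abs(w).max(axis=1)
    scale = np.maximum(clip / 127.0, 1.0 / 127.0)  # avoid zero scales
    q = np.clip(np.round(w / scale[:, None]), -127, 127).astype(np.int8)
    return q, scale

def dequantize_per_row(q: np.ndarray, scale: np.ndarray) -> np.ndarray:
    return q.astype(np.float32) * scale[:, None]

rng = np.random.default_rng(0)
w = rng.standard_normal((4, 8)).astype(np.float32)
q, scale = quantize_per_row(w)
# Serialize: int8 codes plus fp16 scales, then zlib at max level.
blob = zlib.compress(q.tobytes() + scale.astype(np.float16).tobytes(), level=9)
w_hat = dequantize_per_row(q, scale)
# With max-based clipping, error is bounded by half a quantization step per row.
assert np.all(np.abs(w - w_hat) <= scale[:, None] / 2 + 1e-6)
```

The payload is 1 byte per weight plus 2 bytes per row before compression; zlib then squeezes out whatever redundancy remains in the int8 codes.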
+ +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights", + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 +INT5_ODD_LEVELS = (-2, -1, 0, 1, 2) + +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) + + +def parse_factor_shape(spec: str, dim: int) -> tuple[int, int]: + if spec: + vals = tuple(int(v.strip()) for v in spec.split(",") if v.strip()) + if len(vals) != 2 or vals[0] * vals[1] != dim: + raise ValueError(f"Invalid factor shape '{spec}' for dim={dim}") + return vals + for a in range(int(math.isqrt(dim)), 0, -1): + if dim % a == 0: + return a, dim // a + return 1, dim + +def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: + if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): + return t.float().contiguous() + if t.dtype in {torch.float32, torch.bfloat16}: + passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") + return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() + return t + +def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + # Matrices get one scale per row, which usually tracks output-channel + # ranges much better than a single tensor-wide scale. 
+ clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) + q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() + return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + + # Vectors / scalars use a simpler per-tensor scale. + clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() + return q, scale + + +def pack_base5_codes(q: Tensor) -> tuple[bytes, int]: + vals = q.detach().to("cpu", dtype=torch.uint8).reshape(-1).numpy() + n_vals = int(vals.size) + pad = (3 - (n_vals % 3)) % 3 + if pad: + vals = np.pad(vals, (0, pad)) + groups = vals.reshape(-1, 3).astype(np.uint8) + packed = groups[:, 0] + 5 * groups[:, 1] + 25 * groups[:, 2] + return packed.tobytes(), n_vals + + +def unpack_base5_codes(data: bytes, n_vals: int) -> Tensor: + packed = np.frombuffer(data, dtype=np.uint8).astype(np.int16) + out = np.zeros((packed.size, 3), dtype=np.uint8) + for i in range(3): + out[:, i] = packed % 5 + packed //= 5 + flat = out.reshape(-1)[:n_vals] + return torch.from_numpy(flat.astype(np.uint8, copy=False)) + + +def quantize_float_tensor_int5_odd(t: Tensor) -> tuple[Tensor, Tensor]: + if t.ndim != 2: + raise ValueError("int5_odd quantization only supports 2D tensors") + t32 = t.float() + clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clip_abs = clip_abs.clamp_min(1e-6) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 
2.0).clamp_min(1.0 / 2.0e6) + q = torch.clamp(torch.round(clipped / scale[:, None]), -2, 2).to(torch.int8).contiguous() + return (q + 2).to(torch.uint8).contiguous(), scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + +def quantize_state_dict_int8(state_dict: dict[str, Tensor]): + # Single supported clean-script export format: + # - per-row int8 for 2D float tensors + # - per-tensor int8 for other float tensors + # - exact passthrough for non-floats + # - passthrough for small float tensors, stored as fp16 to save bytes + quantized: dict[str, Tensor] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + qmeta: dict[str, dict[str, object]] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + + # Small float tensors are cheap enough to keep directly. We still downcast + # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size. 
+ if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + + stats["num_float_tensors"] += 1 + q, s = quantize_float_tensor(t) + if s.ndim > 0: + qmeta[name] = {"scheme": "per_row", "axis": 0} + quantized[name] = q + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) + + obj: dict[str, object] = { + "__quant_format__": "int8_clean_per_row_v1", + "quantized": quantized, + "scales": scales, + "dtypes": dtypes, + "passthrough": passthrough, + } + if qmeta: + obj["qmeta"] = qmeta + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats + + +def quantize_state_dict_int5_odd(state_dict: dict[str, Tensor]): + packed: dict[str, bytes] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + shapes: dict[str, tuple[int, ...]] = {} + n_codes: dict[str, int] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + + if t.ndim != 2 or t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + + stats["num_float_tensors"] += 1 + q, s = quantize_float_tensor_int5_odd(t) + packed_bytes, count = 
pack_base5_codes(q) + packed[name] = packed_bytes + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + shapes[name] = tuple(int(v) for v in t.shape) + n_codes[name] = count + stats["int8_payload_bytes"] += len(packed_bytes) + tensor_nbytes(s) + + obj: dict[str, object] = { + "__quant_format__": "int5_odd_clean_per_row_v1", + "packed": packed, + "scales": scales, + "dtypes": dtypes, + "shapes": shapes, + "n_codes": n_codes, + "passthrough": passthrough, + } + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats + +def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + qmeta = obj.get("qmeta", {}) + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, q in obj["quantized"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + s = obj["scales"][name] + if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: + s = s.to(dtype=torch.float32) + # Broadcast the saved row scale back across trailing dimensions. + out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() + else: + scale = float(s.item()) + out[name] = (q.float() * scale).to(dtype=dtype).contiguous() + for name, t in obj["passthrough"].items(): + # Restore small tensors, undoing the temporary fp16 storage cast if needed. 
+ out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out + + +def dequantize_state_dict_int5_odd(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, packed in obj["packed"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + shape = tuple(int(v) for v in obj["shapes"][name]) + q = unpack_base5_codes(packed, int(obj["n_codes"][name])).to(torch.float32).reshape(shape) - 2.0 + s = obj["scales"][name].to(dtype=torch.float32) + out[name] = (q * s.view(shape[0], *([1] * (len(shape) - 1)))).to(dtype=dtype).contiguous() + for name, t in obj["passthrough"].items(): + out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out + + +def quantize_state_dict(state_dict: dict[str, Tensor], quant_scheme: str): + if quant_scheme == "int8": + return quantize_state_dict_int8(state_dict) + if quant_scheme == "int5_odd": + return quantize_state_dict_int5_odd(state_dict) + raise ValueError(f"Unsupported QUANT_SCHEME={quant_scheme}") + + +def dequantize_state_dict(obj: dict[str, object]) -> dict[str, Tensor]: + quant_format = obj.get("__quant_format__") + if quant_format == "int8_clean_per_row_v1": + return dequantize_state_dict_int8(obj) + if quant_format == "int5_odd_clean_per_row_v1": + return dequantize_state_dict_int5_odd(obj) + raise ValueError(f"Unsupported quant format {quant_format}") + + +def compress_blob(data: bytes, compressor: str, level: int) -> bytes: + if compressor == "zlib": + return zlib.compress(data, level=max(0, min(level, 9))) + if compressor == "zstd": + if not _HAS_ZSTD: + raise RuntimeError("COMPRESSOR=zstd requires zstandard 
to be installed") + return zstd.ZstdCompressor(level=level).compress(data) + raise ValueError(f"Unsupported COMPRESSOR={compressor}") + + +def decompress_blob(data: bytes, compressor: str) -> bytes: + if compressor == "zlib": + return zlib.decompress(data) + if compressor == "zstd": + if not _HAS_ZSTD: + raise RuntimeError("COMPRESSOR=zstd requires zstandard to be installed") + return zstd.ZstdDecompressor().decompress(data) + raise ValueError(f"Unsupported COMPRESSOR={compressor}") + + +def should_rate_regularize_tensor(name: str, p: Tensor, target: str) -> bool: + if not p.requires_grad or not p.is_floating_point() or p.ndim < 2: + return False + if any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS): + return False + if target == "mlp_only": + return ".mlp." in name + if target == "all_matrices": + return p.numel() > INT8_KEEP_FLOAT_MAX_NUMEL or ".mlp." in name + raise ValueError(f"Unsupported RATE_TARGET_TENSORS={target}") + + +def estimate_tensor_rate_proxy(t: Tensor, temperature: float) -> tuple[Tensor, Tensor, Tensor]: + t2 = t.float() + if t2.ndim > 2: + t2 = t2.reshape(t2.shape[0], -1) + elif t2.ndim == 1: + t2 = t2.reshape(1, -1) + row_scale = t2.abs().mean(dim=1, keepdim=True).clamp_min(1e-6) + normalized = torch.clamp(t2 / row_scale, -2.5, 2.5) + centers = t2.new_tensor(INT5_ODD_LEVELS) + logits = -temperature * (normalized.unsqueeze(-1) - centers) ** 2 + probs = torch.softmax(logits, dim=-1) + hist = probs.mean(dim=(0, 1)) + entropy_nats = -(hist * (hist + 1e-8).log()).sum() + entropy_bits = entropy_nats / math.log(2.0) + scale_proxy = row_scale.mean() + est_bits = entropy_bits * float(t2.numel()) + return entropy_bits, scale_proxy, est_bits + + +# ----------------------------- +# DATA LOADING +# ----------------------------- + +def load_data_shard(file: Path) -> Tensor: + header_bytes = 256 * np.dtype(" None: + self.file_idx = (self.file_idx + 1) % len(self.files) + self.tokens = load_data_shard(self.files[self.file_idx]) + self.pos = 0 + 
+ def take(self, n: int) -> Tensor: + chunks: list[Tensor] = [] + remaining = n + while remaining > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: + self._advance_file() + continue + k = min(remaining, avail) + chunks.append(self.tokens[self.pos : self.pos + k]) + self.pos += k + remaining -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) + + +class DistributedTokenLoader: + # Each call consumes a contiguous chunk from the shared token stream, then slices out + # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting. + def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank = rank + self.world_size = world_size + self.device = device + self.stream = TokenStream(pattern) + + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + chunk = self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start : start + per_rank_span].to(dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + +# ----------------------------- +# TRANSFORMER MODULES +# ----------------------------- + +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + + +class CastedLinear(nn.Linear): + # Keep weights in fp32 for optimizer/state quality, cast at matmul time for bf16 compute. 
+ def forward(self, x: Tensor) -> Tensor: + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, self.weight.to(x.dtype), bias) + + +class StructuredLinear(nn.Module): + # Lightweight TT-style structured matrix used only in the MLP stack. + def __init__( + self, + in_features: int, + out_features: int, + bias: bool, + linear_type: str, + btt_rank: int, + in_shape_spec: str, + out_shape_spec: str, + ): + super().__init__() + self.in_features = in_features + self.out_features = out_features + self.linear_type = linear_type + self.rank = max(int(btt_rank), 1) + if linear_type == "dense": + self.impl = CastedLinear(in_features, out_features, bias=bias) + self.register_parameter("core1", None) + self.register_parameter("core2", None) + return + if linear_type != "btt_mlp": + raise ValueError(f"Unsupported STRUCTURED_LINEAR_TYPE={linear_type}") + self.impl = None + self.in_shape = parse_factor_shape(in_shape_spec, in_features) + self.out_shape = parse_factor_shape(out_shape_spec, out_features) + n1, n2 = self.in_shape + m1, m2 = self.out_shape + self.core1 = nn.Parameter(torch.empty(n1, m1, self.rank)) + self.core2 = nn.Parameter(torch.empty(self.rank, n2, m2)) + if bias: + self.bias = nn.Parameter(torch.zeros(out_features)) + else: + self.register_parameter("bias", None) + self.reset_parameters() + + def reset_parameters(self) -> None: + if self.impl is not None: + self.impl.reset_parameters() + return + std = 1.0 / math.sqrt(self.in_features) + nn.init.normal_(self.core1, mean=0.0, std=std / math.sqrt(self.rank)) + nn.init.normal_(self.core2, mean=0.0, std=std) + if self.bias is not None: + nn.init.zeros_(self.bias) + + def forward(self, x: Tensor) -> Tensor: + if self.impl is not None: + return self.impl(x) + x_dtype = x.dtype + orig_shape = x.shape[:-1] + x2 = x.reshape(-1, self.in_shape[0], self.in_shape[1]) + core1 = self.core1.to(dtype=x_dtype) + core2 = self.core2.to(dtype=x_dtype) + y = x2.new_zeros((x2.shape[0], 
self.out_shape[0], self.out_shape[1])) + for r in range(self.rank): + left = torch.einsum("buv,um->bmv", x2, core1[:, :, r]) + y = y + torch.einsum("bmv,vn->bmn", left, core2[r]) + y = y.reshape(*orig_shape, self.out_features) + if self.bias is not None: + y = y + self.bias.to(dtype=y.dtype) + return y + + +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + # Keep small/control parameters in fp32 even when the model body runs in bf16. + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() + + +class Rotary(nn.Module): + # Caches cos/sin tables per sequence length on the current device. + def __init__(self, dim: int, base: float = 10000.0): + super().__init__() + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached != seq_len + or self._cos_cached.device != device + ): + t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) + freqs = torch.outer(t, self.inv_freq.to(device)) + self._cos_cached = freqs.cos()[None, None, :, :] + self._sin_cached = freqs.sin()[None, None, :, :] + self._seq_len_cached = seq_len + cos = self._cos_cached.to(dtype=dtype) + sin = self._sin_cached.to(dtype=dtype) + if not torch.is_inference_mode_enabled(): + cos = cos.clone() + sin = sin.clone() + return cos, sin + + +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * 
cos), dim=-1) + + +class CausalSelfAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + kv_dim = self.num_kv_heads * self.head_dim + self.c_q = CastedLinear(dim, dim, bias=False) + self.c_k = CastedLinear(dim, kv_dim, bias=False) + self.c_v = CastedLinear(dim, kv_dim, bias=False) + self.proj = CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rotary = Rotary(self.head_dim, base=rope_base) + + def forward(self, x: Tensor) -> Tensor: + bsz, seqlen, dim = x.shape + q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) + k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin) + k = apply_rotary_emb(k, cos, sin) + q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] + y = F.scaled_dot_product_attention( + q, + k, + v, + attn_mask=None, + is_causal=True, + enable_gqa=(self.num_kv_heads != self.num_heads), + ) + y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) + return self.proj(y) + + +class MLP(nn.Module): + # relu^2 MLP from the original modded-nanogpt setup + def __init__( + self, + dim: int, + mlp_mult: int, + structured_mlp: bool, + structured_linear_type: 
str, + btt_rank: int, + btt_in_shape: str, + btt_out_shape: str, + ): + super().__init__() + hidden = mlp_mult * dim + if structured_mlp: + self.fc = StructuredLinear( + dim, + hidden, + bias=False, + linear_type=structured_linear_type, + btt_rank=btt_rank, + in_shape_spec=btt_in_shape, + out_shape_spec=btt_out_shape, + ) + self.proj = StructuredLinear( + hidden, + dim, + bias=False, + linear_type=structured_linear_type, + btt_rank=btt_rank, + in_shape_spec=btt_out_shape, + out_shape_spec=btt_in_shape, + ) + else: + self.fc = CastedLinear(dim, hidden, bias=False) + self.proj = CastedLinear(hidden, dim, bias=False) + self.proj._zero_init = True + + def forward(self, x: Tensor) -> Tensor: + x = torch.relu(self.fc(x)) + return self.proj(x.square()) + + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + rope_base: float, + qk_gain_init: float, + structured_mlp: bool, + structured_linear_type: str, + btt_rank: int, + btt_in_shape: str, + btt_out_shape: str, + ): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) + self.mlp = MLP(dim, mlp_mult, structured_mlp, structured_linear_type, btt_rank, btt_in_shape, btt_out_shape) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + + def forward(self, x: Tensor, x0: Tensor) -> Tensor: + mix = self.resid_mix.to(dtype=x.dtype) + x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + attn_out = self.attn(self.attn_norm(x)) + x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out + x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) + return x + + +class GPT(nn.Module): + def __init__( + self, + vocab_size: int, + 
num_layers: int, + model_dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + tie_embeddings: bool, + tied_embed_init_std: float, + logit_softcap: float, + rope_base: float, + qk_gain_init: float, + structured_mlp: bool, + structured_linear_type: str, + btt_rank: int, + btt_in_shape: str, + btt_out_shape: str, + rate_target_tensors: str, + rate_temperature: float, + ): + super().__init__() + if logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") + self.tie_embeddings = tie_embeddings + self.tied_embed_init_std = tied_embed_init_std + self.logit_softcap = logit_softcap + self.rate_target_tensors = rate_target_tensors + self.rate_temperature = rate_temperature + self.tok_emb = nn.Embedding(vocab_size, model_dim) + self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + self.blocks = nn.ModuleList( + [ + Block( + model_dim, + num_heads, + num_kv_heads, + mlp_mult, + rope_base, + qk_gain_init, + structured_mlp, + structured_linear_type, + btt_rank, + btt_in_shape, + btt_out_shape, + ) + for i in range(num_layers) + ] + ) + self.final_norm = RMSNorm() + self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) + if self.lm_head is not None: + self.lm_head._zero_init = True + self._init_weights() + + def _init_weights(self) -> None: + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + for module in self.modules(): + if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + if isinstance(module, StructuredLinear) and getattr(module, "_zero_init", False): + if module.impl is not None: + nn.init.zeros_(module.impl.weight) + else: + 
nn.init.zeros_(module.core2) + + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + x = self.tok_emb(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x0 = x + skips: list[Tensor] = [] + + # First half stores skips; second half reuses them in reverse order. + for i in range(self.num_encoder_layers): + x = self.blocks[i](x, x0) + skips.append(x) + for i in range(self.num_decoder_layers): + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + x = self.blocks[self.num_encoder_layers + i](x, x0) + + x = self.final_norm(x).reshape(-1, x.size(-1)) + targets = target_ids.reshape(-1) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(x) + logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + return F.cross_entropy(logits.float(), targets, reduction="mean") + + def rate_regularizer(self) -> tuple[Tensor, Tensor, Tensor]: + total_symbols = 0.0 + weighted_rate = None + weighted_scale = None + estimated_bits = None + for name, param in self.named_parameters(): + if not should_rate_regularize_tensor(name, param, self.rate_target_tensors): + continue + rate_proxy, scale_proxy, tensor_bits = estimate_tensor_rate_proxy(param, self.rate_temperature) + weight = float(param.numel()) + total_symbols += weight + if weighted_rate is None: + weighted_rate = rate_proxy * weight + weighted_scale = scale_proxy * weight + estimated_bits = tensor_bits + else: + weighted_rate = weighted_rate + rate_proxy * weight + weighted_scale = weighted_scale + scale_proxy * weight + estimated_bits = estimated_bits + tensor_bits + if weighted_rate is None or weighted_scale is None or estimated_bits is None: + zero = self.tok_emb.weight.new_zeros(()) + return zero, zero, zero + denom = max(total_symbols, 1.0) + return weighted_rate / denom, weighted_scale / denom, 
estimated_bits + + +# ----------------------------- +# TRAINING +# ----------------------------- + +def main() -> None: + global zeropower_via_newtonschulz5 + + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + if args.enable_compile: + zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) + + # ----------------------------- + # DISTRIBUTED + CUDA SETUP + # ----------------------------- + + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if 8 % world_size != 0: + raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") + grad_accum_steps = 8 // world_size + grad_scale = 1.0 / grad_accum_steps + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 0 + + # Fast math knobs + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + + logfile = None + if master_process: + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" + print(logfile) + + def log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python 
{sys.version}", console=False) + log0(f"Running PyTorch {torch.__version__}", console=False) + log0( + subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, + console=False, + ) + log0("=" * 100, console=False) + + # ----------------------------- + # TOKENIZER + VALIDATION METRIC SETUP + # ----------------------------- + + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) + if args.val_token_limit > 0: + usable = min(((args.val_token_limit) // args.train_seq_len) * args.train_seq_len, val_tokens.numel() - 1) + if usable <= 0: + raise ValueError(f"VAL_TOKEN_LIMIT={args.val_token_limit} is too small for TRAIN_SEQ_LEN={args.train_seq_len}") + val_tokens = val_tokens[: usable + 1].contiguous() + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + if args.val_token_limit > 0: + log0(f"val_token_limit:{args.val_token_limit}") + + # ----------------------------- + # MODEL + OPTIMIZER SETUP + # ----------------------------- + + base_model = GPT( + 
vocab_size=args.vocab_size, + num_layers=args.num_layers, + model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, + structured_mlp=args.structured_mlp, + structured_linear_type=args.structured_linear_type if args.structured_mlp else "dense", + btt_rank=args.btt_rank, + btt_in_shape=args.btt_in_shape, + btt_out_shape=args.btt_out_shape, + rate_target_tensors=args.rate_target_tensors, + rate_temperature=args.rate_temperature, + ).to(device).bfloat16() + for module in base_model.modules(): + if isinstance(module, (CastedLinear, StructuredLinear)): + module.float() + restore_low_dim_params_to_fp32(base_model) + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) if args.enable_compile else base_model + model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + + # Optimizer split: + # - token embedding (Adam) uses EMBED_LR + # - untied lm_head (Adam) uses HEAD_LR + # - matrix params in transformer blocks use MATRIX_LR via Muon + # - vectors/scalars use SCALAR_LR via Adam + block_named_params = list(base_model.blocks.named_parameters()) + matrix_params = [ + p + for name, p in block_named_params + if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + scalar_params = [ + p + for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + optimizer_tok = torch.optim.Adam( + [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}], + betas=(args.beta1, args.beta2), 
+ eps=args.adam_eps, + fused=True, + ) + optimizer_muon = Muon( + matrix_params, + lr=args.matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + ) + for group in optimizer_muon.param_groups: + group["base_lr"] = args.matrix_lr + optimizer_scalar = torch.optim.Adam( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers.insert(1, optimizer_head) + + n_params = sum(p.numel() for p in base_model.parameters()) + log0(f"model_params:{n_params}") + log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") + log0(f"enable_compile:{args.enable_compile}") + log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") + log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") + log0( + f"quant_scheme:{args.quant_scheme} compressor:{args.compressor} level:{args.compressor_level} " + f"structured_mlp:{args.structured_mlp} structured_linear_type:{args.structured_linear_type}" + ) + log0( + f"rate_lambda:{args.rate_lambda} scale_lambda:{args.scale_lambda} " + f"rate_warmup_start_fraction:{args.rate_warmup_start_fraction} rate_full_fraction:{args.rate_full_fraction} " + f"rate_target_tensors:{args.rate_target_tensors}" + ) + log0( + f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " + f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " + f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" + ) + log0( + f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " + f"iterations:{args.iterations} 
warmup_steps:{args.warmup_steps} " + f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log0(f"seed:{args.seed}") + + # ----------------------------- + # DATA LOADER & MODEL WARMUP + # ----------------------------- + + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + + def lr_mul(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + + def rate_schedule(step: int) -> float: + if args.rate_lambda <= 0.0 and args.scale_lambda <= 0.0: + return 0.0 + if args.iterations <= 0: + return 1.0 + frac = step / max(args.iterations, 1) + start = args.rate_warmup_start_fraction + full = max(args.rate_full_fraction, start + 1e-6) + if frac <= start: + return 0.0 + if frac >= full: + return 1.0 + return (frac - start) / (full - start) + + # Warmup primes the compiled forward/backward/optimizer paths, then we restore the + # initial weights/optimizer state so measured training starts from the true init. 
+ if args.warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(args.warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + for opt in optimizers: + opt.step() + zero_grad_all() + if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: + log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + if distributed: + model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + # ----------------------------- + # MAIN TRAINING LOOP + # ----------------------------- + + training_time_ms = 0.0 + stop_after_step: int | None = None + torch.cuda.synchronize() + t0 = time.perf_counter() + + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + log0( + 
f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" + ) + break + + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + zero_grad_all() + train_loss = torch.zeros((), device=device) + train_ce_loss = torch.zeros((), device=device) + train_rate_proxy = torch.zeros((), device=device) + train_scale_proxy = torch.zeros((), device=device) + train_est_bits = torch.zeros((), device=device) + rate_mul = rate_schedule(step) + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + ce_loss = model(x, y) + total_loss = ce_loss + if rate_mul > 0.0: + rate_proxy, scale_proxy, est_bits = base_model.rate_regularizer() + total_loss = total_loss + rate_mul * (args.rate_lambda * rate_proxy + args.scale_lambda * scale_proxy) + train_rate_proxy += rate_proxy.detach() + train_scale_proxy += scale_proxy.detach() + train_est_bits += est_bits.detach() + train_ce_loss += ce_loss.detach() + train_loss += total_loss.detach() + (total_loss * grad_scale).backward() + train_loss /= grad_accum_steps + train_ce_loss /= grad_accum_steps + train_rate_proxy /= grad_accum_steps + train_scale_proxy /= grad_accum_steps + train_est_bits /= grad_accum_steps + + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * 
args.muon_momentum + for group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + for opt in optimizers: + opt.step() + zero_grad_all() + + step += 1 + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + should_log_train = ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + log0( + f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} ce_loss:{train_ce_loss.item():.4f} " + f"rate_mul:{rate_mul:.3f} rate_proxy_bits:{train_rate_proxy.item():.4f} " + f"scale_proxy:{train_scale_proxy.item():.6f} est_model_bits:{train_est_bits.item():.0f} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" + ) + + # Needed to sync whether we've reached the wallclock cap. + reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + + log0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + ) + + # ----------------------------- + # SERIALIZATION + ROUNDTRIP VALIDATION + # ----------------------------- + # Save the raw state (useful for debugging/loading in PyTorch directly), then always produce + # the quantized, compressed artifact (QUANT_SCHEME + COMPRESSOR, int5_odd+zlib for this run) and validate the round-tripped weights. 
+ + if master_process: + torch.save(base_model.state_dict(), "final_model.pt") + model_bytes = os.path.getsize("final_model.pt") + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model: {model_bytes} bytes") + log0(f"Code size: {code_bytes} bytes") + log0(f"Total submission size: {model_bytes + code_bytes} bytes") + + quant_obj, quant_stats = quantize_state_dict(base_model.state_dict(), args.quant_scheme) + quant_buf = io.BytesIO() + torch.save(quant_obj, quant_buf) + quant_raw = quant_buf.getvalue() + quant_blob = compress_blob(quant_raw, args.compressor, args.compressor_level) + quant_raw_bytes = len(quant_raw) + quant_filename = f"final_model.{args.quant_scheme}.{args.compressor}.ptc" + if master_process: + with open(quant_filename, "wb") as f: + f.write(quant_blob) + quant_file_bytes = os.path.getsize(quant_filename) + code_bytes = len(code.encode("utf-8")) + ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1) + log0( + f"Serialized model {args.quant_scheme}+{args.compressor}: {quant_file_bytes} bytes " + f"(payload:{quant_stats['int8_payload_bytes']} raw_torch:{quant_raw_bytes} payload_ratio:{ratio:.2f}x)" + ) + log0(f"Total submission size {args.quant_scheme}+{args.compressor}: {quant_file_bytes + code_bytes} bytes") + + if distributed: + dist.barrier() + with open(quant_filename, "rb") as f: + quant_blob_disk = f.read() + quant_state = torch.load(io.BytesIO(decompress_blob(quant_blob_disk, args.compressor)), map_location="cpu") + base_model.load_state_dict(dequantize_state_dict(quant_state), strict=True) + torch.cuda.synchronize() + t_qeval = time.perf_counter() + q_val_loss, q_val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"final_{args.quant_scheme}_{args.compressor}_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + 
f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" + ) + log0(f"final_{args.quant_scheme}_{args.compressor}_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + + if distributed: + dist.destroy_process_group() + + +if __name__ == "__main__": + main() + +==================================================================================================== +Running Python 3.10.12 (main, Mar 3 2026, 11:56:32) [GCC 11.4.0] +Running PyTorch 2.11.0+cu130 +Tue Mar 31 02:53:38 2026 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 590.48.01 Driver Version: 590.48.01 CUDA Version: 13.1 | ++-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. | +|=========================================+========================+======================| +| 0 NVIDIA GeForce RTX 3060 Off | 00000000:01:00.0 On | N/A | +| 30% 33C P5 19W / 170W | 214MiB / 12288MiB | 5% Default | +| | | N/A | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| 0 N/A N/A 975 G /usr/lib/xorg/Xorg 89MiB | +| 0 N/A N/A 4230 C python3 104MiB | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=/home/hz/Desktop/parameter-golf-main/data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:1 +val_loader:shards 
pattern=/home/hz/Desktop/parameter-golf-main/data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:1048576 +val_token_limit:1048576 +model_params:8507464 +world_size:1 grad_accum_steps:8 +enable_compile:False +sdp_backends:cudnn=False flash=True mem_efficient=False math=False +attention_mode:gqa num_heads:8 num_kv_heads:4 +quant_scheme:int5_odd compressor:zlib level:9 structured_mlp:True structured_linear_type:btt_mlp +rate_lambda:0.002 scale_lambda:0.0002 rate_warmup_start_fraction:0.1 rate_full_fraction:0.4 rate_target_tensors:mlp_only +tie_embeddings:True embed_lr:0.03 head_lr:0.0 matrix_lr:0.02 scalar_lr:0.02 +train_batch_tokens:16384 train_seq_len:256 iterations:40 warmup_steps:0 max_wallclock_seconds:600.000 +seed:1337 +step:0/40 val_loss:6.9336 val_bpb:4.1543 train_time:0ms step_avg:0.03ms +step:1/40 train_loss:6.9388 ce_loss:6.9388 rate_mul:0.000 rate_proxy_bits:0.0000 scale_proxy:0.000000 est_model_bits:0 train_time:4609ms step_avg:4608.77ms +step:2/40 train_loss:12.9960 ce_loss:12.9960 rate_mul:0.000 rate_proxy_bits:0.0000 scale_proxy:0.000000 est_model_bits:0 train_time:9250ms step_avg:4624.80ms +step:3/40 train_loss:12.4420 ce_loss:12.4420 rate_mul:0.000 rate_proxy_bits:0.0000 scale_proxy:0.000000 est_model_bits:0 train_time:13919ms step_avg:4639.74ms +step:4/40 train_loss:11.5348 ce_loss:11.5348 rate_mul:0.000 rate_proxy_bits:0.0000 scale_proxy:0.000000 est_model_bits:0 train_time:18580ms step_avg:4644.88ms +step:5/40 train_loss:10.6759 ce_loss:10.6759 rate_mul:0.000 rate_proxy_bits:0.0000 scale_proxy:0.000000 est_model_bits:0 train_time:23232ms step_avg:4646.30ms +step:6/40 train_loss:9.6460 ce_loss:9.6457 rate_mul:0.083 rate_proxy_bits:1.4980 scale_proxy:0.013530 est_model_bits:1325359 train_time:28381ms step_avg:4730.19ms +step:7/40 train_loss:8.7863 ce_loss:8.7858 rate_mul:0.167 rate_proxy_bits:1.4980 scale_proxy:0.013530 est_model_bits:1325359 train_time:33433ms step_avg:4776.16ms +step:8/40 train_loss:8.0623 ce_loss:8.0616 rate_mul:0.250 
rate_proxy_bits:1.4980 scale_proxy:0.013530 est_model_bits:1325359 train_time:38479ms step_avg:4809.85ms +step:9/40 train_loss:7.2233 ce_loss:7.2223 rate_mul:0.333 rate_proxy_bits:1.4980 scale_proxy:0.013530 est_model_bits:1325359 train_time:43525ms step_avg:4836.11ms +step:10/40 train_loss:6.8557 ce_loss:6.8545 rate_mul:0.417 rate_proxy_bits:1.4980 scale_proxy:0.013530 est_model_bits:1325359 train_time:48605ms step_avg:4860.50ms +step:20/40 train_loss:5.6887 ce_loss:5.6857 rate_mul:1.000 rate_proxy_bits:1.4980 scale_proxy:0.013530 est_model_bits:1325359 train_time:99262ms step_avg:4963.11ms +step:30/40 train_loss:5.6327 ce_loss:5.6297 rate_mul:1.000 rate_proxy_bits:1.4980 scale_proxy:0.013530 est_model_bits:1325359 train_time:149610ms step_avg:4987.01ms +step:40/40 train_loss:5.4766 ce_loss:5.4736 rate_mul:1.000 rate_proxy_bits:1.4980 scale_proxy:0.013530 est_model_bits:1325359 train_time:199784ms step_avg:4994.61ms +step:40/40 val_loss:5.4830 val_bpb:3.2851 train_time:199785ms step_avg:4994.62ms +peak memory allocated: 4044 MiB reserved: 4090 MiB +Serialized model: 33020979 bytes +Code size: 65623 bytes +Total submission size: 33086602 bytes +Serialized model int5_odd+zlib: 2922261 bytes (payload:4415453 raw_torch:4451491 payload_ratio:7.47x) +Total submission size int5_odd+zlib: 2987884 bytes +final_int5_odd_zlib_roundtrip val_loss:5.6865 val_bpb:3.4071 eval_time:52555ms +final_int5_odd_zlib_roundtrip_exact val_loss:5.68650442 val_bpb:3.40705054 diff --git a/records/track_non_record_16mb/2026-03-31_EntropyAware_Int5Odd_BTTMLP_3060/logs/smoke_dense_40.txt b/records/track_non_record_16mb/2026-03-31_EntropyAware_Int5Odd_BTTMLP_3060/logs/smoke_dense_40.txt new file mode 100644 index 000000000..4690c497b --- /dev/null +++ b/records/track_non_record_16mb/2026-03-31_EntropyAware_Int5Odd_BTTMLP_3060/logs/smoke_dense_40.txt @@ -0,0 +1,3209 @@ +""" +The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good launching-off points for new participants, not SOTA 
configs. We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder.
+
+Hard stop: To keep things readable for newcomers, let's make sure `train_gpt.py` and `train_gpt_mlx.py` are never longer than 1500 lines.
+"""
+
+from __future__ import annotations
+
+import copy
+import glob
+import io
+import math
+import os
+import random
+import subprocess
+import sys
+import time
+import uuid
+import zlib
+from pathlib import Path
+
+import numpy as np
+import sentencepiece as spm
+import torch
+import torch.distributed as dist
+import torch.nn.functional as F
+from torch import Tensor, nn
+from torch.nn.parallel import DistributedDataParallel as DDP
+
+try:
+    import zstandard as zstd
+
+    _HAS_ZSTD = True
+except ImportError:
+    zstd = None
+    _HAS_ZSTD = False
+
+# -----------------------------
+# HYPERPARAMETERS
+# -----------------------------
+# Default Simple Baseline run:
+# - 9 transformer blocks at width 512
+# - 8 attention heads with 4 KV heads (GQA) and 2x MLP expansion
+# - vocab size 1024, sequence length 1024, tied embeddings
+# - 524,288 train tokens per step for 20,000 iterations with a ~10 minute cap
+
+class Hyperparameters:
+    # Data paths are shard globs produced by the existing preprocessing pipeline.
+    data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024")
+    train_files = os.path.join(data_path, "fineweb_train_*.bin")
+    val_files = os.path.join(data_path, "fineweb_val_*.bin")
+    tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model")
+    run_id = os.environ.get("RUN_ID", str(uuid.uuid4()))
+    seed = int(os.environ.get("SEED", 1337))
+
+    # Validation cadence and batch size. Validation uses the full fineweb_val split unless VAL_TOKEN_LIMIT caps it.
+ val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 200)) + val_token_limit = int(os.environ.get("VAL_TOKEN_LIMIT", 0)) + + # Training length. + iterations = int(os.environ.get("ITERATIONS", 20000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 1200)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + enable_compile = bool(int(os.environ.get("ENABLE_COMPILE", "1"))) + + # Model shape. + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers = int(os.environ.get("NUM_LAYERS", 9)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = int(os.environ.get("MLP_MULT", 2)) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + + # Optimizer hyperparameters. 
+ embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.05)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.04)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.04)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.0)) + + # Export / compression. + quant_scheme = os.environ.get("QUANT_SCHEME", "int5_odd") + compressor = os.environ.get("COMPRESSOR", "zlib") + compressor_level = int(os.environ.get("COMPRESSOR_LEVEL", 9)) + + # Entropy-aware rate regularization. + rate_lambda = float(os.environ.get("RATE_LAMBDA", 0.02)) + scale_lambda = float(os.environ.get("SCALE_LAMBDA", 0.0005)) + rate_temperature = float(os.environ.get("RATE_TEMPERATURE", 6.0)) + rate_warmup_start_fraction = float(os.environ.get("RATE_WARMUP_START_FRACTION", 0.2)) + rate_full_fraction = float(os.environ.get("RATE_FULL_FRACTION", 0.6)) + rate_target_tensors = os.environ.get("RATE_TARGET_TENSORS", "mlp_only") + + # Structured MLP. 
+ structured_mlp = bool(int(os.environ.get("STRUCTURED_MLP", "0"))) + structured_linear_type = os.environ.get("STRUCTURED_LINEAR_TYPE", "btt_mlp") + btt_rank = int(os.environ.get("BTT_RANK", 32)) + btt_in_shape = os.environ.get("BTT_IN_SHAPE", "") + btt_out_shape = os.environ.get("BTT_OUT_SHAPE", "") + +# ----------------------------- +# MUON OPTIMIZER +# ----------------------------- +# +# As borrowed from modded-nanogpt +# Background on Muon: https://kellerjordan.github.io/posts/muon/ + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + # Orthogonalize a 2D update matrix with a fast Newton-Schulz iteration. + # Muon uses this to normalize matrix-shaped gradients before applying them. + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X + + +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True): + super().__init__( + params, + dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov), + ) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + + for group in self.param_groups: + params = group["params"] + if not params: + continue + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + + total_params = sum(int(p.numel()) for p in params) + updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) + + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and 
p.grad is not None:
+                    g = p.grad
+                    state = self.state[p]
+                    if "momentum_buffer" not in state:
+                        state["momentum_buffer"] = torch.zeros_like(g)
+                    buf = state["momentum_buffer"]
+                    buf.mul_(momentum).add_(g)
+                    if nesterov:
+                        g = g.add(buf, alpha=momentum)
+                    g = zeropower_via_newtonschulz5(g, steps=backend_steps)
+                    # Scale correction from Muon reference implementations.
+                    g *= max(1, g.size(0) / g.size(1)) ** 0.5
+                    updates_flat[curr : curr + p.numel()] = g.reshape(-1)
+                # Advance the flat offset for every param, not just the ones this
+                # rank owns, so all ranks agree on the slot layout before the sum.
+                curr += p.numel()
+
+            if distributed:
+                dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM)
+
+            curr = 0
+            for p in params:
+                g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype)
+                p.add_(g, alpha=-lr)
+                curr += p.numel()
+
+        return loss
+
+
+# -----------------------------
+# TOKENIZER-AGNOSTIC EVALUATION SETUP
+# -----------------------------
+#
+# It's common for small models to have a large fraction of their parameters be embeddings, since the 2 * d_model * d_vocab vectors can be gigantic.
+# Instead of locking the tokenizer, we let you bring your own and calculate our validation metrics on the average compression of the validation set.
+# We calculate BPB (bits-per-byte) instead of validation loss, so we need methods to count the number of bytes each token decodes to.
+# Note: Submissions that edit the tokenizer will be examined more carefully, since screwing this up might unjustly improve your score.
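The BPB bookkeeping described above reduces to a small piece of arithmetic: convert mean token cross-entropy from nats to bits, then rescale by the tokenizer's compression ratio. A minimal sketch with hypothetical token/byte counts (the real counts come from the byte LUTs built below):

```python
import math

# Hypothetical numbers for illustration only: mean token cross-entropy in
# nats, plus token and decoded-byte counts over a validation slice.
val_loss_nats = 5.4830
val_tokens = 1_048_576
val_bytes = 1_750_000

bits_per_token = val_loss_nats / math.log(2.0)   # nats -> bits per token
tokens_per_byte = val_tokens / val_bytes         # tokenizer compression ratio
val_bpb = bits_per_token * tokens_per_byte       # bits per byte of raw text
```

Because `val_bpb` normalizes by decoded bytes rather than tokens, a fatter tokenizer cannot improve the score by shrinking the token count alone.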
+ +def build_sentencepiece_luts( + sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device +) -> tuple[Tensor, Tensor, Tensor]: + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("▁"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) + + +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*. 
+ tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] + + +def eval_val( + args: Hyperparameters, + model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + grad_accum_steps: int, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, +) -> tuple[float, float]: + # Validation computes two metrics: + # - val_loss: token cross-entropy (natural log) + # - val_bpb: tokenizer-agnostic compression metric used by the challenge + local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) + if local_batch_tokens < args.train_seq_len: + raise ValueError( + "VAL_BATCH_SIZE must provide at least one sequence per rank; " + f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " + f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}" + ) + local_batch_seqs = local_batch_tokens // args.train_seq_len + total_seqs = (val_tokens.numel() - 1) // args.train_seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + + model.eval() + with torch.inference_mode(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * args.train_seq_len + raw_end = batch_seq_end * args.train_seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, args.train_seq_len) + y = local[1:].reshape(-1, args.train_seq_len) + with 
torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True):
+                batch_loss = model(x, y).detach()
+            batch_token_count = float(y.numel())
+            val_loss_sum += batch_loss.to(torch.float64) * batch_token_count
+            val_token_count += batch_token_count
+            prev_ids = x.reshape(-1)
+            tgt_ids = y.reshape(-1)
+            token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16)
+            token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16)
+            val_byte_count += token_bytes.to(torch.float64).sum()
+
+    if dist.is_available() and dist.is_initialized():
+        dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM)
+        dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM)
+        dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM)
+
+    val_loss = val_loss_sum / val_token_count
+    bits_per_token = val_loss.item() / math.log(2.0)
+    tokens_per_byte = val_token_count.item() / val_byte_count.item()
+    model.train()
+    return float(val_loss.item()), float(bits_per_token * tokens_per_byte)
+
+# -----------------------------
+# POST-TRAINING QUANTIZATION
+# -----------------------------
+#
+# It's silly to export our model, which is trained in bf16 and fp32, at that same precision.
+# Instead, we get approximately the same model (with a small hit) by quantizing the model to int8 and zlib-compressing it.
+# We can then decompress the model and run in higher precision for evaluation, after coming in under the size limit.
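The export path described in the comments above can be sketched end to end with numpy and zlib. This is illustrative only: plain per-row symmetric int8, without the percentile outlier clipping, float passthrough, or int5 packing that the real exporter below applies.

```python
import zlib

import numpy as np

# Toy weight matrix standing in for a model tensor.
rng = np.random.default_rng(0)
w = rng.standard_normal((4, 8)).astype(np.float32)

# One scale per row, so each output channel keeps its own dynamic range.
scale = np.abs(w).max(axis=1, keepdims=True) / 127.0
q = np.clip(np.round(w / scale), -127, 127).astype(np.int8)

# What lands on disk: the compressed int8 codes (scales ship alongside).
blob = zlib.compress(q.tobytes(), level=9)

# Evaluation side: decompress, then dequantize back to float.
q_back = np.frombuffer(zlib.decompress(blob), dtype=np.int8).reshape(w.shape)
w_back = q_back.astype(np.float32) * scale

# Round-to-nearest bounds the per-element error by half a scale step.
max_err = np.abs(w - w_back).max()
```

The compression is lossless over the int8 codes, so the only accuracy hit comes from the quantization step itself, which is exactly what the roundtrip eval at the end of training measures.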
+ +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights", + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 +INT5_ODD_LEVELS = (-2, -1, 0, 1, 2) + +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) + + +def parse_factor_shape(spec: str, dim: int) -> tuple[int, int]: + if spec: + vals = tuple(int(v.strip()) for v in spec.split(",") if v.strip()) + if len(vals) != 2 or vals[0] * vals[1] != dim: + raise ValueError(f"Invalid factor shape '{spec}' for dim={dim}") + return vals + for a in range(int(math.isqrt(dim)), 0, -1): + if dim % a == 0: + return a, dim // a + return 1, dim + +def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: + if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): + return t.float().contiguous() + if t.dtype in {torch.float32, torch.bfloat16}: + passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") + return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() + return t + +def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + # Matrices get one scale per row, which usually tracks output-channel + # ranges much better than a single tensor-wide scale. 
+ clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) + q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() + return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + + # Vectors / scalars use a simpler per-tensor scale. + clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() + return q, scale + + +def pack_base5_codes(q: Tensor) -> tuple[bytes, int]: + vals = q.detach().to("cpu", dtype=torch.uint8).reshape(-1).numpy() + n_vals = int(vals.size) + pad = (3 - (n_vals % 3)) % 3 + if pad: + vals = np.pad(vals, (0, pad)) + groups = vals.reshape(-1, 3).astype(np.uint8) + packed = groups[:, 0] + 5 * groups[:, 1] + 25 * groups[:, 2] + return packed.tobytes(), n_vals + + +def unpack_base5_codes(data: bytes, n_vals: int) -> Tensor: + packed = np.frombuffer(data, dtype=np.uint8).astype(np.int16) + out = np.zeros((packed.size, 3), dtype=np.uint8) + for i in range(3): + out[:, i] = packed % 5 + packed //= 5 + flat = out.reshape(-1)[:n_vals] + return torch.from_numpy(flat.astype(np.uint8, copy=False)) + + +def quantize_float_tensor_int5_odd(t: Tensor) -> tuple[Tensor, Tensor]: + if t.ndim != 2: + raise ValueError("int5_odd quantization only supports 2D tensors") + t32 = t.float() + clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clip_abs = clip_abs.clamp_min(1e-6) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 
2.0).clamp_min(1.0 / 2.0e6) + q = torch.clamp(torch.round(clipped / scale[:, None]), -2, 2).to(torch.int8).contiguous() + return (q + 2).to(torch.uint8).contiguous(), scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + +def quantize_state_dict_int8(state_dict: dict[str, Tensor]): + # Single supported clean-script export format: + # - per-row int8 for 2D float tensors + # - per-tensor int8 for other float tensors + # - exact passthrough for non-floats + # - passthrough for small float tensors, stored as fp16 to save bytes + quantized: dict[str, Tensor] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + qmeta: dict[str, dict[str, object]] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + + # Small float tensors are cheap enough to keep directly. We still downcast + # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size. 
+ if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + + stats["num_float_tensors"] += 1 + q, s = quantize_float_tensor(t) + if s.ndim > 0: + qmeta[name] = {"scheme": "per_row", "axis": 0} + quantized[name] = q + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) + + obj: dict[str, object] = { + "__quant_format__": "int8_clean_per_row_v1", + "quantized": quantized, + "scales": scales, + "dtypes": dtypes, + "passthrough": passthrough, + } + if qmeta: + obj["qmeta"] = qmeta + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats + + +def quantize_state_dict_int5_odd(state_dict: dict[str, Tensor]): + packed: dict[str, bytes] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + shapes: dict[str, tuple[int, ...]] = {} + n_codes: dict[str, int] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + + if t.ndim != 2 or t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + + stats["num_float_tensors"] += 1 + q, s = quantize_float_tensor_int5_odd(t) + packed_bytes, count = 
pack_base5_codes(q) + packed[name] = packed_bytes + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + shapes[name] = tuple(int(v) for v in t.shape) + n_codes[name] = count + stats["int8_payload_bytes"] += len(packed_bytes) + tensor_nbytes(s) + + obj: dict[str, object] = { + "__quant_format__": "int5_odd_clean_per_row_v1", + "packed": packed, + "scales": scales, + "dtypes": dtypes, + "shapes": shapes, + "n_codes": n_codes, + "passthrough": passthrough, + } + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats + +def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + qmeta = obj.get("qmeta", {}) + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, q in obj["quantized"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + s = obj["scales"][name] + if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: + s = s.to(dtype=torch.float32) + # Broadcast the saved row scale back across trailing dimensions. + out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() + else: + scale = float(s.item()) + out[name] = (q.float() * scale).to(dtype=dtype).contiguous() + for name, t in obj["passthrough"].items(): + # Restore small tensors, undoing the temporary fp16 storage cast if needed. 
+ out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out + + +def dequantize_state_dict_int5_odd(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, packed in obj["packed"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + shape = tuple(int(v) for v in obj["shapes"][name]) + q = unpack_base5_codes(packed, int(obj["n_codes"][name])).to(torch.float32).reshape(shape) - 2.0 + s = obj["scales"][name].to(dtype=torch.float32) + out[name] = (q * s.view(shape[0], *([1] * (len(shape) - 1)))).to(dtype=dtype).contiguous() + for name, t in obj["passthrough"].items(): + out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out + + +def quantize_state_dict(state_dict: dict[str, Tensor], quant_scheme: str): + if quant_scheme == "int8": + return quantize_state_dict_int8(state_dict) + if quant_scheme == "int5_odd": + return quantize_state_dict_int5_odd(state_dict) + raise ValueError(f"Unsupported QUANT_SCHEME={quant_scheme}") + + +def dequantize_state_dict(obj: dict[str, object]) -> dict[str, Tensor]: + quant_format = obj.get("__quant_format__") + if quant_format == "int8_clean_per_row_v1": + return dequantize_state_dict_int8(obj) + if quant_format == "int5_odd_clean_per_row_v1": + return dequantize_state_dict_int5_odd(obj) + raise ValueError(f"Unsupported quant format {quant_format}") + + +def compress_blob(data: bytes, compressor: str, level: int) -> bytes: + if compressor == "zlib": + return zlib.compress(data, level=max(0, min(level, 9))) + if compressor == "zstd": + if not _HAS_ZSTD: + raise RuntimeError("COMPRESSOR=zstd requires zstandard 
to be installed")
+        return zstd.ZstdCompressor(level=level).compress(data)
+    raise ValueError(f"Unsupported COMPRESSOR={compressor}")
+
+
+def decompress_blob(data: bytes, compressor: str) -> bytes:
+    if compressor == "zlib":
+        return zlib.decompress(data)
+    if compressor == "zstd":
+        if not _HAS_ZSTD:
+            raise RuntimeError("COMPRESSOR=zstd requires zstandard to be installed")
+        return zstd.ZstdDecompressor().decompress(data)
+    raise ValueError(f"Unsupported COMPRESSOR={compressor}")
+
+
+def should_rate_regularize_tensor(name: str, p: Tensor, target: str) -> bool:
+    if not p.requires_grad or not p.is_floating_point() or p.ndim < 2:
+        return False
+    if p.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL:
+        return False
+    if any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS):
+        return False
+    if target == "mlp_only":
+        return ".mlp." in name
+    if target == "all_matrices":
+        return True
+    raise ValueError(f"Unsupported RATE_TARGET_TENSORS={target}")
+
+
+def estimate_tensor_rate_proxy(t: Tensor, temperature: float) -> tuple[Tensor, Tensor, Tensor]:
+    t2 = t.float()
+    if t2.ndim > 2:
+        t2 = t2.reshape(t2.shape[0], -1)
+    elif t2.ndim == 1:
+        t2 = t2.reshape(1, -1)
+    row_scale = t2.abs().mean(dim=1, keepdim=True).clamp_min(1e-6)
+    normalized = torch.clamp(t2 / row_scale, -2.5, 2.5)
+    centers = t2.new_tensor(INT5_ODD_LEVELS)
+    logits = -temperature * (normalized.unsqueeze(-1) - centers) ** 2
+    probs = torch.softmax(logits, dim=-1)
+    hist = probs.mean(dim=(0, 1))
+    entropy_nats = -(hist * (hist + 1e-8).log()).sum()
+    entropy_bits = entropy_nats / math.log(2.0)
+    scale_proxy = row_scale.mean()
+    est_bits = entropy_bits * float(t2.numel())
+    return entropy_bits, scale_proxy, est_bits
+
+
+# -----------------------------
+# DATA LOADING
+# -----------------------------
+
+def load_data_shard(file: Path) -> Tensor:
+    # Shard layout (assumed to match the preprocessing pipeline): a 256-entry
+    # little-endian int32 header, with header[2] holding the token count,
+    # followed by a flat little-endian uint16 token payload.
+    header_bytes = 256 * np.dtype("<i4").itemsize
+    with file.open("rb") as f:
+        header = np.frombuffer(f.read(header_bytes), dtype="<i4")
+        num_tokens = int(header[2])
+        tokens = np.frombuffer(f.read(2 * num_tokens), dtype="<u2")
+    return torch.from_numpy(tokens.astype(np.int32))
+
+
+class TokenStream:
+    # Walks the train shards in sorted order, cycling back to the first shard
+    # when the stream runs out.
+    def __init__(self, pattern: str):
+        self.files = [Path(p) for p in sorted(glob.glob(pattern))]
+        if not self.files:
+            raise FileNotFoundError(f"No files found for pattern: {pattern}")
+        self.file_idx = 0
+        self.tokens = load_data_shard(self.files[self.file_idx])
+        self.pos = 0
+
+    def _advance_file(self) -> None:
+        self.file_idx = (self.file_idx + 1) % len(self.files)
+        self.tokens = load_data_shard(self.files[self.file_idx])
+        
self.pos = 0 + + def take(self, n: int) -> Tensor: + chunks: list[Tensor] = [] + remaining = n + while remaining > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: + self._advance_file() + continue + k = min(remaining, avail) + chunks.append(self.tokens[self.pos : self.pos + k]) + self.pos += k + remaining -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) + + +class DistributedTokenLoader: + # Each call consumes a contiguous chunk from the shared token stream, then slices out + # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting. + def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank = rank + self.world_size = world_size + self.device = device + self.stream = TokenStream(pattern) + + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + chunk = self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start : start + per_rank_span].to(dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + +# ----------------------------- +# TRANSFORMER MODULES +# ----------------------------- + +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + + +class CastedLinear(nn.Linear): + # Keep weights in fp32 for optimizer/state quality, cast at matmul time for bf16 compute. 
+ def forward(self, x: Tensor) -> Tensor: + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, self.weight.to(x.dtype), bias) + + +class StructuredLinear(nn.Module): + # Lightweight TT-style structured matrix used only in the MLP stack. + def __init__( + self, + in_features: int, + out_features: int, + bias: bool, + linear_type: str, + btt_rank: int, + in_shape_spec: str, + out_shape_spec: str, + ): + super().__init__() + self.in_features = in_features + self.out_features = out_features + self.linear_type = linear_type + self.rank = max(int(btt_rank), 1) + if linear_type == "dense": + self.impl = CastedLinear(in_features, out_features, bias=bias) + self.register_parameter("core1", None) + self.register_parameter("core2", None) + return + if linear_type != "btt_mlp": + raise ValueError(f"Unsupported STRUCTURED_LINEAR_TYPE={linear_type}") + self.impl = None + self.in_shape = parse_factor_shape(in_shape_spec, in_features) + self.out_shape = parse_factor_shape(out_shape_spec, out_features) + n1, n2 = self.in_shape + m1, m2 = self.out_shape + self.core1 = nn.Parameter(torch.empty(n1, m1, self.rank)) + self.core2 = nn.Parameter(torch.empty(m1, self.rank, n2, m2)) + if bias: + self.bias = nn.Parameter(torch.zeros(out_features)) + else: + self.register_parameter("bias", None) + self.reset_parameters() + + def reset_parameters(self) -> None: + if self.impl is not None: + self.impl.reset_parameters() + return + std = 1.0 / math.sqrt(self.in_features) + nn.init.normal_(self.core1, mean=0.0, std=std / math.sqrt(self.rank)) + nn.init.normal_(self.core2, mean=0.0, std=std) + if self.bias is not None: + nn.init.zeros_(self.bias) + + def forward(self, x: Tensor) -> Tensor: + if self.impl is not None: + return self.impl(x) + x_dtype = x.dtype + orig_shape = x.shape[:-1] + x2 = x.reshape(-1, self.in_shape[0], self.in_shape[1]) + core1 = self.core1.to(dtype=x_dtype) + core2 = self.core2.to(dtype=x_dtype) + tmp = torch.einsum("buv,umr->bmrv", x2, 
core1) + y = torch.einsum("bmrv,mrwn->bnw", tmp, core2) + y = y.reshape(*orig_shape, self.out_features) + if self.bias is not None: + y = y + self.bias.to(dtype=y.dtype) + return y + + +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + # Keep small/control parameters in fp32 even when the model body runs in bf16. + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() + + +class Rotary(nn.Module): + # Caches cos/sin tables per sequence length on the current device. + def __init__(self, dim: int, base: float = 10000.0): + super().__init__() + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached != seq_len + or self._cos_cached.device != device + ): + t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) + freqs = torch.outer(t, self.inv_freq.to(device)) + self._cos_cached = freqs.cos()[None, None, :, :] + self._sin_cached = freqs.sin()[None, None, :, :] + self._seq_len_cached = seq_len + return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) + + +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + + +class CausalSelfAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + if dim % num_heads != 0: + 
raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + kv_dim = self.num_kv_heads * self.head_dim + self.c_q = CastedLinear(dim, dim, bias=False) + self.c_k = CastedLinear(dim, kv_dim, bias=False) + self.c_v = CastedLinear(dim, kv_dim, bias=False) + self.proj = CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rotary = Rotary(self.head_dim, base=rope_base) + + def forward(self, x: Tensor) -> Tensor: + bsz, seqlen, dim = x.shape + q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) + k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin) + k = apply_rotary_emb(k, cos, sin) + q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] + y = F.scaled_dot_product_attention( + q, + k, + v, + attn_mask=None, + is_causal=True, + enable_gqa=(self.num_kv_heads != self.num_heads), + ) + y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) + return self.proj(y) + + +class MLP(nn.Module): + # relu^2 MLP from the original modded-nanogpt setup + def __init__( + self, + dim: int, + mlp_mult: int, + structured_mlp: bool, + structured_linear_type: str, + btt_rank: int, + btt_in_shape: str, + btt_out_shape: str, + ): + super().__init__() + hidden = mlp_mult * dim + if structured_mlp: + self.fc = StructuredLinear( + dim, + hidden, + bias=False, + 
linear_type=structured_linear_type, + btt_rank=btt_rank, + in_shape_spec=btt_in_shape, + out_shape_spec=btt_out_shape, + ) + self.proj = StructuredLinear( + hidden, + dim, + bias=False, + linear_type=structured_linear_type, + btt_rank=btt_rank, + in_shape_spec=btt_out_shape, + out_shape_spec=btt_in_shape, + ) + else: + self.fc = CastedLinear(dim, hidden, bias=False) + self.proj = CastedLinear(hidden, dim, bias=False) + self.proj._zero_init = True + + def forward(self, x: Tensor) -> Tensor: + x = torch.relu(self.fc(x)) + return self.proj(x.square()) + + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + rope_base: float, + qk_gain_init: float, + structured_mlp: bool, + structured_linear_type: str, + btt_rank: int, + btt_in_shape: str, + btt_out_shape: str, + ): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) + self.mlp = MLP(dim, mlp_mult, structured_mlp, structured_linear_type, btt_rank, btt_in_shape, btt_out_shape) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + + def forward(self, x: Tensor, x0: Tensor) -> Tensor: + mix = self.resid_mix.to(dtype=x.dtype) + x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + attn_out = self.attn(self.attn_norm(x)) + x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out + x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) + return x + + +class GPT(nn.Module): + def __init__( + self, + vocab_size: int, + num_layers: int, + model_dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + tie_embeddings: bool, + tied_embed_init_std: float, + logit_softcap: float, + rope_base: float, + qk_gain_init: float, 
+ structured_mlp: bool, + structured_linear_type: str, + btt_rank: int, + btt_in_shape: str, + btt_out_shape: str, + rate_target_tensors: str, + rate_temperature: float, + ): + super().__init__() + if logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") + self.tie_embeddings = tie_embeddings + self.tied_embed_init_std = tied_embed_init_std + self.logit_softcap = logit_softcap + self.rate_target_tensors = rate_target_tensors + self.rate_temperature = rate_temperature + self.tok_emb = nn.Embedding(vocab_size, model_dim) + self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + self.blocks = nn.ModuleList( + [ + Block( + model_dim, + num_heads, + num_kv_heads, + mlp_mult, + rope_base, + qk_gain_init, + structured_mlp, + structured_linear_type, + btt_rank, + btt_in_shape, + btt_out_shape, + ) + for i in range(num_layers) + ] + ) + self.final_norm = RMSNorm() + self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) + if self.lm_head is not None: + self.lm_head._zero_init = True + self._init_weights() + + def _init_weights(self) -> None: + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + for module in self.modules(): + if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + if isinstance(module, StructuredLinear) and getattr(module, "_zero_init", False): + if module.impl is not None: + nn.init.zeros_(module.impl.weight) + else: + nn.init.zeros_(module.core2) + + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + x = self.tok_emb(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x0 = x + skips: list[Tensor] = [] + + # First half stores 
skips; second half reuses them in reverse order. + for i in range(self.num_encoder_layers): + x = self.blocks[i](x, x0) + skips.append(x) + for i in range(self.num_decoder_layers): + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + x = self.blocks[self.num_encoder_layers + i](x, x0) + + x = self.final_norm(x).reshape(-1, x.size(-1)) + targets = target_ids.reshape(-1) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(x) + logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + return F.cross_entropy(logits.float(), targets, reduction="mean") + + def rate_regularizer(self) -> tuple[Tensor, Tensor, Tensor]: + total_symbols = 0.0 + weighted_rate = None + weighted_scale = None + estimated_bits = None + for name, param in self.named_parameters(): + if not should_rate_regularize_tensor(name, param, self.rate_target_tensors): + continue + rate_proxy, scale_proxy, tensor_bits = estimate_tensor_rate_proxy(param, self.rate_temperature) + weight = float(param.numel()) + total_symbols += weight + if weighted_rate is None: + weighted_rate = rate_proxy * weight + weighted_scale = scale_proxy * weight + estimated_bits = tensor_bits + else: + weighted_rate = weighted_rate + rate_proxy * weight + weighted_scale = weighted_scale + scale_proxy * weight + estimated_bits = estimated_bits + tensor_bits + if weighted_rate is None or weighted_scale is None or estimated_bits is None: + zero = self.tok_emb.weight.new_zeros(()) + return zero, zero, zero + denom = max(total_symbols, 1.0) + return weighted_rate / denom, weighted_scale / denom, estimated_bits + + +# ----------------------------- +# TRAINING +# ----------------------------- + +def main() -> None: + global zeropower_via_newtonschulz5 + + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() 
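The numel-weighted reduction in `rate_regularizer()` above can be sketched in isolation. The `(numel, bits_per_symbol)` pairs below are made-up stand-ins for what `estimate_tensor_rate_proxy` produces, not real values:

```python
# Toy sketch of rate_regularizer's reduction: average per-tensor rate proxies,
# weighting each tensor by its element count so large matrices dominate.
# The (numel, bits_per_symbol) pairs are hypothetical stand-ins.
def weighted_rate_mean(tensor_proxies):
    total_symbols = 0.0
    weighted_rate = 0.0
    for numel, rate_proxy in tensor_proxies:
        weighted_rate += rate_proxy * numel
        total_symbols += numel
    if total_symbols == 0.0:
        return 0.0  # mirrors the zero-tensor early return above
    return weighted_rate / max(total_symbols, 1.0)

# A 4-element tensor at 2.0 bits/symbol and a 12-element one at 4.0 bits/symbol
# average to (2*4 + 4*12) / 16 = 3.5 bits/symbol.
print(weighted_rate_mean([(4, 2.0), (12, 4.0)]))  # 3.5
```

The same weighting is why restricting `rate_target_tensors` to `mlp_only` changes the denominator as well as the numerator: only the selected tensors contribute symbols.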
+ if args.enable_compile: + zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) + + # ----------------------------- + # DISTRIBUTED + CUDA SETUP + # ----------------------------- + + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if 8 % world_size != 0: + raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") + grad_accum_steps = 8 // world_size + grad_scale = 1.0 / grad_accum_steps + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 0 + + # Fast math knobs + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + + logfile = None + if master_process: + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" + print(logfile) + + def log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch {torch.__version__}", console=False) + log0( + subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, + console=False, + ) + log0("=" * 
100, console=False) + + # ----------------------------- + # TOKENIZER + VALIDATION METRIC SETUP + # ----------------------------- + + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) + if args.val_token_limit > 0: + usable = min(((args.val_token_limit) // args.train_seq_len) * args.train_seq_len, val_tokens.numel() - 1) + if usable <= 0: + raise ValueError(f"VAL_TOKEN_LIMIT={args.val_token_limit} is too small for TRAIN_SEQ_LEN={args.train_seq_len}") + val_tokens = val_tokens[: usable + 1].contiguous() + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + if args.val_token_limit > 0: + log0(f"val_token_limit:{args.val_token_limit}") + + # ----------------------------- + # MODEL + OPTIMIZER SETUP + # ----------------------------- + + base_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, + 
tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, + structured_mlp=args.structured_mlp, + structured_linear_type=args.structured_linear_type if args.structured_mlp else "dense", + btt_rank=args.btt_rank, + btt_in_shape=args.btt_in_shape, + btt_out_shape=args.btt_out_shape, + rate_target_tensors=args.rate_target_tensors, + rate_temperature=args.rate_temperature, + ).to(device).bfloat16() + for module in base_model.modules(): + if isinstance(module, (CastedLinear, StructuredLinear)): + module.float() + restore_low_dim_params_to_fp32(base_model) + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) if args.enable_compile else base_model + model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + + # Optimizer split: + # - token embedding (Adam) uses EMBED_LR + # - untied lm_head (Adam) uses HEAD_LR + # - matrix params in transformer blocks use MATRIX_LR via Muon + # - vectors/scalars use SCALAR_LR via Adam + block_named_params = list(base_model.blocks.named_parameters()) + matrix_params = [ + p + for name, p in block_named_params + if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + scalar_params = [ + p + for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + optimizer_tok = torch.optim.Adam( + [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizer_muon = Muon( + matrix_params, + lr=args.matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + ) + for group in 
optimizer_muon.param_groups: + group["base_lr"] = args.matrix_lr + optimizer_scalar = torch.optim.Adam( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers.insert(1, optimizer_head) + + n_params = sum(p.numel() for p in base_model.parameters()) + log0(f"model_params:{n_params}") + log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") + log0(f"enable_compile:{args.enable_compile}") + log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") + log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") + log0( + f"quant_scheme:{args.quant_scheme} compressor:{args.compressor} level:{args.compressor_level} " + f"structured_mlp:{args.structured_mlp} structured_linear_type:{args.structured_linear_type}" + ) + log0( + f"rate_lambda:{args.rate_lambda} scale_lambda:{args.scale_lambda} " + f"rate_warmup_start_fraction:{args.rate_warmup_start_fraction} rate_full_fraction:{args.rate_full_fraction} " + f"rate_target_tensors:{args.rate_target_tensors}" + ) + log0( + f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " + f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " + f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" + ) + log0( + f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " + f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " + f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log0(f"seed:{args.seed}") + + # ----------------------------- + # DATA LOADER & MODEL WARMUP + # 
----------------------------- + + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + + def lr_mul(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + + def rate_schedule(step: int) -> float: + if args.rate_lambda <= 0.0 and args.scale_lambda <= 0.0: + return 0.0 + if args.iterations <= 0: + return 1.0 + frac = step / max(args.iterations, 1) + start = args.rate_warmup_start_fraction + full = max(args.rate_full_fraction, start + 1e-6) + if frac <= start: + return 0.0 + if frac >= full: + return 1.0 + return (frac - start) / (full - start) + + # Warmup primes the compiled forward/backward/optimizer paths, then we restore the + # initial weights/optimizer state so measured training starts from the true init. 
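The prime-then-restore warmup described in the comment above reduces to a simple pattern; `train_step` here is a hypothetical stand-in for the real forward/backward/optimizer step:

```python
import copy

# Sketch of the warmup pattern: run a few throwaway steps to trigger lazy
# initialization / compilation, then restore the snapshot so timed training
# starts from the true initial state. train_step is a hypothetical stand-in.
def primed_start(state, train_step, warmup_steps):
    snapshot = copy.deepcopy(state)      # like the state_dict() copies above
    for _ in range(warmup_steps):
        train_step(state)                # mutates weights, warms caches
    state.clear()
    state.update(snapshot)               # roll back to the saved init

weights = {"w": 0.0}
primed_start(weights, lambda s: s.update(w=s["w"] + 1.0), warmup_steps=3)
print(weights["w"])  # 0.0 — warmup side effects on the weights are undone
```

Note that, just as above, anything not captured in the snapshot (data loader position, RNG state) must be re-created separately, which is why the loader is rebuilt after the restore.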
+ if args.warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(args.warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + for opt in optimizers: + opt.step() + zero_grad_all() + if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: + log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + if distributed: + model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + # ----------------------------- + # MAIN TRAINING LOOP + # ----------------------------- + + training_time_ms = 0.0 + stop_after_step: int | None = None + torch.cuda.synchronize() + t0 = time.perf_counter() + + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + log0( + 
f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" + ) + break + + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + zero_grad_all() + train_loss = torch.zeros((), device=device) + train_ce_loss = torch.zeros((), device=device) + train_rate_proxy = torch.zeros((), device=device) + train_scale_proxy = torch.zeros((), device=device) + train_est_bits = torch.zeros((), device=device) + rate_mul = rate_schedule(step) + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + ce_loss = model(x, y) + total_loss = ce_loss + if rate_mul > 0.0: + rate_proxy, scale_proxy, est_bits = base_model.rate_regularizer() + total_loss = total_loss + rate_mul * (args.rate_lambda * rate_proxy + args.scale_lambda * scale_proxy) + train_rate_proxy += rate_proxy.detach() + train_scale_proxy += scale_proxy.detach() + train_est_bits += est_bits.detach() + train_ce_loss += ce_loss.detach() + train_loss += total_loss.detach() + (total_loss * grad_scale).backward() + train_loss /= grad_accum_steps + train_ce_loss /= grad_accum_steps + train_rate_proxy /= grad_accum_steps + train_scale_proxy /= grad_accum_steps + train_est_bits /= grad_accum_steps + + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * 
args.muon_momentum + for group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + for opt in optimizers: + opt.step() + zero_grad_all() + + step += 1 + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + should_log_train = ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + log0( + f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} ce_loss:{train_ce_loss.item():.4f} " + f"rate_mul:{rate_mul:.3f} rate_proxy_bits:{train_rate_proxy.item():.4f} " + f"scale_proxy:{train_scale_proxy.item():.6f} est_model_bits:{train_est_bits.item():.0f} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" + ) + + # Needed to sync whether we've reached the wallclock cap. + reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + + log0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + ) + + # ----------------------------- + # SERIALIZATION + ROUNDTRIP VALIDATION + # ----------------------------- + # Save the raw state (useful for debugging/loading in PyTorch directly), then always produce + # the compressed int8+zlib artifact and validate the round-tripped weights. 
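As a self-contained illustration of the quantize-compress-roundtrip idea (the real packing lives in `quantize_state_dict`/`compress_blob`, which handle per-tensor scales and torch serialization), a toy 5-level odd-grid quantizer with zlib might look like:

```python
import zlib

# Toy analogue of the int5_odd + zlib export: clamp scaled weights to the odd
# 5-bin grid {-2,-1,0,1,2}, compress, and verify the roundtrip. The real
# quantize_state_dict / compress_blob in train_gpt.py operate on tensors.
def quantize_int5_odd(values, scale):
    return [max(-2, min(2, round(v / scale))) for v in values]

def export_blob(values, scale):
    symbols = bytes(q + 2 for q in quantize_int5_odd(values, scale))  # shift to 0..4
    return zlib.compress(symbols, level=9)

def import_blob(blob, scale):
    return [(b - 2) * scale for b in zlib.decompress(blob)]

blob = export_blob([0.5, -1.1, 0.0, 3.0], scale=0.5)
print(import_blob(blob, scale=0.5))  # [0.5, -1.0, 0.0, 1.0] (3.0 clips to the grid edge)
```

Validating on the decompressed weights, as the code below does, is what guarantees the reported `val_bpb` reflects the artifact that is actually shipped rather than the in-memory fp32 state.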
+ + if master_process: + torch.save(base_model.state_dict(), "final_model.pt") + model_bytes = os.path.getsize("final_model.pt") + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model: {model_bytes} bytes") + log0(f"Code size: {code_bytes} bytes") + log0(f"Total submission size: {model_bytes + code_bytes} bytes") + + quant_obj, quant_stats = quantize_state_dict(base_model.state_dict(), args.quant_scheme) + quant_buf = io.BytesIO() + torch.save(quant_obj, quant_buf) + quant_raw = quant_buf.getvalue() + quant_blob = compress_blob(quant_raw, args.compressor, args.compressor_level) + quant_raw_bytes = len(quant_raw) + quant_filename = f"final_model.{args.quant_scheme}.{args.compressor}.ptc" + if master_process: + with open(quant_filename, "wb") as f: + f.write(quant_blob) + quant_file_bytes = os.path.getsize(quant_filename) + code_bytes = len(code.encode("utf-8")) + ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1) + log0( + f"Serialized model {args.quant_scheme}+{args.compressor}: {quant_file_bytes} bytes " + f"(payload:{quant_stats['int8_payload_bytes']} raw_torch:{quant_raw_bytes} payload_ratio:{ratio:.2f}x)" + ) + log0(f"Total submission size {args.quant_scheme}+{args.compressor}: {quant_file_bytes + code_bytes} bytes") + + if distributed: + dist.barrier() + with open(quant_filename, "rb") as f: + quant_blob_disk = f.read() + quant_state = torch.load(io.BytesIO(decompress_blob(quant_blob_disk, args.compressor)), map_location="cpu") + base_model.load_state_dict(dequantize_state_dict(quant_state), strict=True) + torch.cuda.synchronize() + t_qeval = time.perf_counter() + q_val_loss, q_val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"final_{args.quant_scheme}_{args.compressor}_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + 
f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" + ) + log0(f"final_{args.quant_scheme}_{args.compressor}_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + + if distributed: + dist.destroy_process_group() + + +if __name__ == "__main__": + main() + +==================================================================================================== +Running Python 3.10.12 (main, Mar 3 2026, 11:56:32) [GCC 11.4.0] +Running PyTorch 2.11.0+cu130 +Tue Mar 31 02:49:12 2026 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 590.48.01 Driver Version: 590.48.01 CUDA Version: 13.1 | ++-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. | +|=========================================+========================+======================| +| 0 NVIDIA GeForce RTX 3060 Off | 00000000:01:00.0 On | N/A | +| 30% 36C P8 19W / 170W | 214MiB / 12288MiB | 8% Default | +| | | N/A | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| 0 N/A N/A 975 G /usr/lib/xorg/Xorg 89MiB | +| 0 N/A N/A 4107 C python3 104MiB | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=/home/hz/Desktop/parameter-golf-main/data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:1 +val_loader:shards 
pattern=/home/hz/Desktop/parameter-golf-main/data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:1048576
+val_token_limit:1048576
+model_params:17059912
+world_size:1 grad_accum_steps:8
+enable_compile:False
+sdp_backends:cudnn=False flash=True mem_efficient=False math=False
+attention_mode:gqa num_heads:8 num_kv_heads:4
+quant_scheme:int5_odd compressor:zlib level:9 structured_mlp:False structured_linear_type:btt_mlp
+rate_lambda:0.002 scale_lambda:0.0002 rate_warmup_start_fraction:0.1 rate_full_fraction:0.4 rate_target_tensors:mlp_only
+tie_embeddings:True embed_lr:0.03 head_lr:0.0 matrix_lr:0.02 scalar_lr:0.02
+train_batch_tokens:16384 train_seq_len:256 iterations:40 warmup_steps:0 max_wallclock_seconds:600.000
+seed:1337
+step:0/40 val_loss:6.9354 val_bpb:4.1553 train_time:0ms step_avg:0.02ms
+"""
+The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good launching-off points for new participants, not SOTA configs. We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder.
+
+Hard stop: to keep things readable for newcomers, let's make sure `train_gpt.py` and `train_gpt_mlx.py` never grow longer than 1500 lines.
+""" + +from __future__ import annotations + +import copy +import glob +import io +import math +import os +import random +import subprocess +import sys +import time +import uuid +import zlib +from pathlib import Path + +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP + +try: + import zstandard as zstd + + _HAS_ZSTD = True +except ImportError: + zstd = None + _HAS_ZSTD = False + +# ----------------------------- +# HYPERPARAMETERS +# ----------------------------- +# Default Simple Baseline run: +# - 9 transformer blocks at width 512 +# - 8 attention heads with 4 KV heads (GQA) and 2x MLP expansion +# - vocab size 1024, sequence length 1024, tied embeddings +# - 524,288 train tokens per step for 20,000 iterations with a ~10 minute cap + +class Hyperparameters: + # Data paths are shard globs produced by the existing preprocessing pipeline. + data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + seed = int(os.environ.get("SEED", 1337)) + + # Validation cadence and batch size. Validation always uses the full fineweb_val split. + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 200)) + val_token_limit = int(os.environ.get("VAL_TOKEN_LIMIT", 0)) + + # Training length. 
+ iterations = int(os.environ.get("ITERATIONS", 20000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 1200)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + enable_compile = bool(int(os.environ.get("ENABLE_COMPILE", "1"))) + + # Model shape. + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers = int(os.environ.get("NUM_LAYERS", 9)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = int(os.environ.get("MLP_MULT", 2)) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + + # Optimizer hyperparameters. + embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.05)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.04)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.04)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.0)) + + # Export / compression. 
+ quant_scheme = os.environ.get("QUANT_SCHEME", "int5_odd") + compressor = os.environ.get("COMPRESSOR", "zlib") + compressor_level = int(os.environ.get("COMPRESSOR_LEVEL", 9)) + + # Entropy-aware rate regularization. + rate_lambda = float(os.environ.get("RATE_LAMBDA", 0.02)) + scale_lambda = float(os.environ.get("SCALE_LAMBDA", 0.0005)) + rate_temperature = float(os.environ.get("RATE_TEMPERATURE", 6.0)) + rate_warmup_start_fraction = float(os.environ.get("RATE_WARMUP_START_FRACTION", 0.2)) + rate_full_fraction = float(os.environ.get("RATE_FULL_FRACTION", 0.6)) + rate_target_tensors = os.environ.get("RATE_TARGET_TENSORS", "mlp_only") + + # Structured MLP. + structured_mlp = bool(int(os.environ.get("STRUCTURED_MLP", "0"))) + structured_linear_type = os.environ.get("STRUCTURED_LINEAR_TYPE", "btt_mlp") + btt_rank = int(os.environ.get("BTT_RANK", 32)) + btt_in_shape = os.environ.get("BTT_IN_SHAPE", "") + btt_out_shape = os.environ.get("BTT_OUT_SHAPE", "") + +# ----------------------------- +# MUON OPTIMIZER +# ----------------------------- +# +# As borrowed from modded-nanogpt +# Background on Muon: https://kellerjordan.github.io/posts/muon/ + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + # Orthogonalize a 2D update matrix with a fast Newton-Schulz iteration. + # Muon uses this to normalize matrix-shaped gradients before applying them. 
+ a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X + + +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True): + super().__init__( + params, + dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov), + ) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + + for group in self.param_groups: + params = group["params"] + if not params: + continue + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + + total_params = sum(int(p.numel()) for p in params) + updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) + + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: + g = g.add(buf, alpha=momentum) + g = zeropower_via_newtonschulz5(g, steps=backend_steps) + # Scale correction from Muon reference implementations. 
g *= max(1, g.size(0) / g.size(1)) ** 0.5
+                    updates_flat[curr : curr + p.numel()] = g.reshape(-1)
+                curr += p.numel()
+
+            if distributed:
+                dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM)
+
+            curr = 0
+            for p in params:
+                g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype)
+                p.add_(g, alpha=-lr)
+                curr += p.numel()
+
+        return loss
+
+
+# -----------------------------
+# TOKENIZER-AGNOSTIC EVALUATION SETUP
+# -----------------------------
+#
+# It's common for small models to have a large fraction of their parameters in the embeddings, since the 2 * d_model * d_vocab embedding and head matrices can be gigantic.
+# Instead of locking the tokenizer, we let you bring your own and compute our validation metric as the average compression of the validation set.
+# We report BPB (bits-per-byte) instead of validation loss, so we need a way to count how many bytes each token decodes to under the tokenizer.
+# Note: Submissions that edit the tokenizer will be examined more carefully, since screwing this up might unjustly improve your score.
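+
+# Minimal sketch of how the reduced counters combine into BPB. This helper is
+# illustrative only (eval_val() does the same arithmetic inline) and assumes
+# `val_loss_nats` is the token-averaged natural-log cross-entropy.
+def _bpb_from_counts(val_loss_nats: float, token_count: float, byte_count: float) -> float:
+    bits_per_token = val_loss_nats / math.log(2.0)
+    return bits_per_token * (token_count / byte_count)
+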
+ +def build_sentencepiece_luts( + sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device +) -> tuple[Tensor, Tensor, Tensor]: + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("▁"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) + + +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*. 
+ tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] + + +def eval_val( + args: Hyperparameters, + model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + grad_accum_steps: int, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, +) -> tuple[float, float]: + # Validation computes two metrics: + # - val_loss: token cross-entropy (natural log) + # - val_bpb: tokenizer-agnostic compression metric used by the challenge + local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) + if local_batch_tokens < args.train_seq_len: + raise ValueError( + "VAL_BATCH_SIZE must provide at least one sequence per rank; " + f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " + f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}" + ) + local_batch_seqs = local_batch_tokens // args.train_seq_len + total_seqs = (val_tokens.numel() - 1) // args.train_seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + + model.eval() + with torch.inference_mode(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * args.train_seq_len + raw_end = batch_seq_end * args.train_seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, args.train_seq_len) + y = local[1:].reshape(-1, args.train_seq_len) + with 
torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True):
+                batch_loss = model(x, y).detach()
+            batch_token_count = float(y.numel())
+            val_loss_sum += batch_loss.to(torch.float64) * batch_token_count
+            val_token_count += batch_token_count
+            prev_ids = x.reshape(-1)
+            tgt_ids = y.reshape(-1)
+            token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16)
+            token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16)
+            val_byte_count += token_bytes.to(torch.float64).sum()
+
+    if dist.is_available() and dist.is_initialized():
+        dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM)
+        dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM)
+        dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM)
+
+    val_loss = val_loss_sum / val_token_count
+    bits_per_token = val_loss.item() / math.log(2.0)
+    tokens_per_byte = val_token_count.item() / val_byte_count.item()
+    model.train()
+    return float(val_loss.item()), float(bits_per_token * tokens_per_byte)
+
+# -----------------------------
+# POST-TRAINING QUANTIZATION
+# -----------------------------
+#
+# It would be wasteful to export the model at the bf16/fp32 precision it was trained in.
+# Instead, we get approximately the same model (with a small quality hit) by quantizing the weights to int8 or the int5_odd grid and zlib-compressing the result.
+# We can then decompress the model and run it at higher precision for evaluation, once the compressed artifact comes in under the size limit.
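+
+# Shape of the quantize/dequantize round trip as a minimal sketch: per-tensor
+# symmetric int8 for clarity, whereas the exporters below add per-row scales and
+# percentile clipping. Illustrative only; nothing in the export path calls it.
+def _int8_roundtrip_sketch(t: Tensor) -> Tensor:
+    scale = (float(t.abs().max()) / 127.0) or 1.0  # avoid a zero scale for all-zero tensors
+    q = torch.clamp(torch.round(t / scale), -127, 127).to(torch.int8)
+    return q.to(torch.float32) * scale  # ~= t, within half a quantization step
+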
+ +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights", + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 +INT5_ODD_LEVELS = (-2, -1, 0, 1, 2) + +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) + + +def parse_factor_shape(spec: str, dim: int) -> tuple[int, int]: + if spec: + vals = tuple(int(v.strip()) for v in spec.split(",") if v.strip()) + if len(vals) != 2 or vals[0] * vals[1] != dim: + raise ValueError(f"Invalid factor shape '{spec}' for dim={dim}") + return vals + for a in range(int(math.isqrt(dim)), 0, -1): + if dim % a == 0: + return a, dim // a + return 1, dim + +def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: + if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): + return t.float().contiguous() + if t.dtype in {torch.float32, torch.bfloat16}: + passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") + return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() + return t + +def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + # Matrices get one scale per row, which usually tracks output-channel + # ranges much better than a single tensor-wide scale. 
+ clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) + q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() + return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + + # Vectors / scalars use a simpler per-tensor scale. + clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() + return q, scale + + +def pack_base5_codes(q: Tensor) -> tuple[bytes, int]: + vals = q.detach().to("cpu", dtype=torch.uint8).reshape(-1).numpy() + n_vals = int(vals.size) + pad = (3 - (n_vals % 3)) % 3 + if pad: + vals = np.pad(vals, (0, pad)) + groups = vals.reshape(-1, 3).astype(np.uint8) + packed = groups[:, 0] + 5 * groups[:, 1] + 25 * groups[:, 2] + return packed.tobytes(), n_vals + + +def unpack_base5_codes(data: bytes, n_vals: int) -> Tensor: + packed = np.frombuffer(data, dtype=np.uint8).astype(np.int16) + out = np.zeros((packed.size, 3), dtype=np.uint8) + for i in range(3): + out[:, i] = packed % 5 + packed //= 5 + flat = out.reshape(-1)[:n_vals] + return torch.from_numpy(flat.astype(np.uint8, copy=False)) + + +def quantize_float_tensor_int5_odd(t: Tensor) -> tuple[Tensor, Tensor]: + if t.ndim != 2: + raise ValueError("int5_odd quantization only supports 2D tensors") + t32 = t.float() + clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clip_abs = clip_abs.clamp_min(1e-6) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 
2.0).clamp_min(1.0 / 2.0e6) + q = torch.clamp(torch.round(clipped / scale[:, None]), -2, 2).to(torch.int8).contiguous() + return (q + 2).to(torch.uint8).contiguous(), scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + +def quantize_state_dict_int8(state_dict: dict[str, Tensor]): + # Single supported clean-script export format: + # - per-row int8 for 2D float tensors + # - per-tensor int8 for other float tensors + # - exact passthrough for non-floats + # - passthrough for small float tensors, stored as fp16 to save bytes + quantized: dict[str, Tensor] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + qmeta: dict[str, dict[str, object]] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + + # Small float tensors are cheap enough to keep directly. We still downcast + # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size. 
+ if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + + stats["num_float_tensors"] += 1 + q, s = quantize_float_tensor(t) + if s.ndim > 0: + qmeta[name] = {"scheme": "per_row", "axis": 0} + quantized[name] = q + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) + + obj: dict[str, object] = { + "__quant_format__": "int8_clean_per_row_v1", + "quantized": quantized, + "scales": scales, + "dtypes": dtypes, + "passthrough": passthrough, + } + if qmeta: + obj["qmeta"] = qmeta + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats + + +def quantize_state_dict_int5_odd(state_dict: dict[str, Tensor]): + packed: dict[str, bytes] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + shapes: dict[str, tuple[int, ...]] = {} + n_codes: dict[str, int] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + + if t.ndim != 2 or t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + + stats["num_float_tensors"] += 1 + q, s = quantize_float_tensor_int5_odd(t) + packed_bytes, count = 
pack_base5_codes(q) + packed[name] = packed_bytes + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + shapes[name] = tuple(int(v) for v in t.shape) + n_codes[name] = count + stats["int8_payload_bytes"] += len(packed_bytes) + tensor_nbytes(s) + + obj: dict[str, object] = { + "__quant_format__": "int5_odd_clean_per_row_v1", + "packed": packed, + "scales": scales, + "dtypes": dtypes, + "shapes": shapes, + "n_codes": n_codes, + "passthrough": passthrough, + } + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats + +def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + qmeta = obj.get("qmeta", {}) + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, q in obj["quantized"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + s = obj["scales"][name] + if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: + s = s.to(dtype=torch.float32) + # Broadcast the saved row scale back across trailing dimensions. + out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() + else: + scale = float(s.item()) + out[name] = (q.float() * scale).to(dtype=dtype).contiguous() + for name, t in obj["passthrough"].items(): + # Restore small tensors, undoing the temporary fp16 storage cast if needed. 
+ out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out + + +def dequantize_state_dict_int5_odd(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, packed in obj["packed"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + shape = tuple(int(v) for v in obj["shapes"][name]) + q = unpack_base5_codes(packed, int(obj["n_codes"][name])).to(torch.float32).reshape(shape) - 2.0 + s = obj["scales"][name].to(dtype=torch.float32) + out[name] = (q * s.view(shape[0], *([1] * (len(shape) - 1)))).to(dtype=dtype).contiguous() + for name, t in obj["passthrough"].items(): + out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out + + +def quantize_state_dict(state_dict: dict[str, Tensor], quant_scheme: str): + if quant_scheme == "int8": + return quantize_state_dict_int8(state_dict) + if quant_scheme == "int5_odd": + return quantize_state_dict_int5_odd(state_dict) + raise ValueError(f"Unsupported QUANT_SCHEME={quant_scheme}") + + +def dequantize_state_dict(obj: dict[str, object]) -> dict[str, Tensor]: + quant_format = obj.get("__quant_format__") + if quant_format == "int8_clean_per_row_v1": + return dequantize_state_dict_int8(obj) + if quant_format == "int5_odd_clean_per_row_v1": + return dequantize_state_dict_int5_odd(obj) + raise ValueError(f"Unsupported quant format {quant_format}") + + +def compress_blob(data: bytes, compressor: str, level: int) -> bytes: + if compressor == "zlib": + return zlib.compress(data, level=max(0, min(level, 9))) + if compressor == "zstd": + if not _HAS_ZSTD: + raise RuntimeError("COMPRESSOR=zstd requires zstandard 
to be installed") + return zstd.ZstdCompressor(level=level).compress(data) + raise ValueError(f"Unsupported COMPRESSOR={compressor}") + + +def decompress_blob(data: bytes, compressor: str) -> bytes: + if compressor == "zlib": + return zlib.decompress(data) + if compressor == "zstd": + if not _HAS_ZSTD: + raise RuntimeError("COMPRESSOR=zstd requires zstandard to be installed") + return zstd.ZstdDecompressor().decompress(data) + raise ValueError(f"Unsupported COMPRESSOR={compressor}") + + +def should_rate_regularize_tensor(name: str, p: Tensor, target: str) -> bool: + if not p.requires_grad or not p.is_floating_point() or p.ndim < 2: + return False + if p.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + return False + if any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS): + return False + if target == "mlp_only": + return ".mlp." in name + if target == "all_matrices": + return True + raise ValueError(f"Unsupported RATE_TARGET_TENSORS={target}") + + +def estimate_tensor_rate_proxy(t: Tensor, temperature: float) -> tuple[Tensor, Tensor, Tensor]: + t2 = t.float() + if t2.ndim > 2: + t2 = t2.reshape(t2.shape[0], -1) + elif t2.ndim == 1: + t2 = t2.reshape(1, -1) + row_scale = t2.abs().mean(dim=1, keepdim=True).clamp_min(1e-6) + normalized = torch.clamp(t2 / row_scale, -2.5, 2.5) + centers = t2.new_tensor(INT5_ODD_LEVELS) + logits = -temperature * (normalized.unsqueeze(-1) - centers) ** 2 + probs = torch.softmax(logits, dim=-1) + hist = probs.mean(dim=(0, 1)) + entropy_nats = -(hist * (hist + 1e-8).log()).sum() + entropy_bits = entropy_nats / math.log(2.0) + scale_proxy = row_scale.mean() + est_bits = entropy_bits * float(t2.numel()) + return entropy_bits, scale_proxy, est_bits + + +# ----------------------------- +# DATA LOADING +# ----------------------------- + +def load_data_shard(file: Path) -> Tensor: + header_bytes = 256 * np.dtype(" None: + self.file_idx = (self.file_idx + 1) % len(self.files) + self.tokens = load_data_shard(self.files[self.file_idx]) + 
self.pos = 0 + + def take(self, n: int) -> Tensor: + chunks: list[Tensor] = [] + remaining = n + while remaining > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: + self._advance_file() + continue + k = min(remaining, avail) + chunks.append(self.tokens[self.pos : self.pos + k]) + self.pos += k + remaining -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) + + +class DistributedTokenLoader: + # Each call consumes a contiguous chunk from the shared token stream, then slices out + # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting. + def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank = rank + self.world_size = world_size + self.device = device + self.stream = TokenStream(pattern) + + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + chunk = self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start : start + per_rank_span].to(dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + +# ----------------------------- +# TRANSFORMER MODULES +# ----------------------------- + +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + + +class CastedLinear(nn.Linear): + # Keep weights in fp32 for optimizer/state quality, cast at matmul time for bf16 compute. 
+ def forward(self, x: Tensor) -> Tensor: + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, self.weight.to(x.dtype), bias) + + +class StructuredLinear(nn.Module): + # Lightweight TT-style structured matrix used only in the MLP stack. + def __init__( + self, + in_features: int, + out_features: int, + bias: bool, + linear_type: str, + btt_rank: int, + in_shape_spec: str, + out_shape_spec: str, + ): + super().__init__() + self.in_features = in_features + self.out_features = out_features + self.linear_type = linear_type + self.rank = max(int(btt_rank), 1) + if linear_type == "dense": + self.impl = CastedLinear(in_features, out_features, bias=bias) + self.register_parameter("core1", None) + self.register_parameter("core2", None) + return + if linear_type != "btt_mlp": + raise ValueError(f"Unsupported STRUCTURED_LINEAR_TYPE={linear_type}") + self.impl = None + self.in_shape = parse_factor_shape(in_shape_spec, in_features) + self.out_shape = parse_factor_shape(out_shape_spec, out_features) + n1, n2 = self.in_shape + m1, m2 = self.out_shape + self.core1 = nn.Parameter(torch.empty(n1, m1, self.rank)) + self.core2 = nn.Parameter(torch.empty(m1, self.rank, n2, m2)) + if bias: + self.bias = nn.Parameter(torch.zeros(out_features)) + else: + self.register_parameter("bias", None) + self.reset_parameters() + + def reset_parameters(self) -> None: + if self.impl is not None: + self.impl.reset_parameters() + return + std = 1.0 / math.sqrt(self.in_features) + nn.init.normal_(self.core1, mean=0.0, std=std / math.sqrt(self.rank)) + nn.init.normal_(self.core2, mean=0.0, std=std) + if self.bias is not None: + nn.init.zeros_(self.bias) + + def forward(self, x: Tensor) -> Tensor: + if self.impl is not None: + return self.impl(x) + x_dtype = x.dtype + orig_shape = x.shape[:-1] + x2 = x.reshape(-1, self.in_shape[0], self.in_shape[1]) + core1 = self.core1.to(dtype=x_dtype) + core2 = self.core2.to(dtype=x_dtype) + tmp = torch.einsum("buv,umr->bmrv", x2, 
core1)
+        # Contract the rank and the remaining input factor; y is (batch, m1, m2),
+        # which flattens to out_features = m1 * m2.
+        y = torch.einsum("bmrv,mrvn->bmn", tmp, core2)
+        y = y.reshape(*orig_shape, self.out_features)
+        if self.bias is not None:
+            y = y + self.bias.to(dtype=y.dtype)
+        return y
+
+
+def restore_low_dim_params_to_fp32(module: nn.Module) -> None:
+    # Keep small/control parameters in fp32 even when the model body runs in bf16.
+    with torch.no_grad():
+        for name, param in module.named_parameters():
+            if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32:
+                param.data = param.data.float()
+
+
+class Rotary(nn.Module):
+    # Caches cos/sin tables per sequence length on the current device.
+    def __init__(self, dim: int, base: float = 10000.0):
+        super().__init__()
+        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
+        self.register_buffer("inv_freq", inv_freq, persistent=False)
+        self._seq_len_cached = 0
+        self._cos_cached: Tensor | None = None
+        self._sin_cached: Tensor | None = None
+
+    def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]:
+        if (
+            self._cos_cached is None
+            or self._sin_cached is None
+            or self._seq_len_cached != seq_len
+            or self._cos_cached.device != device
+        ):
+            t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype)
+            freqs = torch.outer(t, self.inv_freq.to(device))
+            self._cos_cached = freqs.cos()[None, None, :, :]
+            self._sin_cached = freqs.sin()[None, None, :, :]
+            self._seq_len_cached = seq_len
+        cos = self._cos_cached.to(dtype=dtype)
+        sin = self._sin_cached.to(dtype=dtype)
+        if not torch.is_inference_mode_enabled():
+            cos = cos.clone()
+            sin = sin.clone()
+        return cos, sin
+
+
+def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor:
+    half = x.size(-1) // 2
+    x1, x2 = x[..., :half], x[..., half:]
+    return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1)
+
+
+class CausalSelfAttention(nn.Module):
+    def __init__(
+        self,
+        dim: int,
+        num_heads: int,
+
num_kv_heads: int, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + kv_dim = self.num_kv_heads * self.head_dim + self.c_q = CastedLinear(dim, dim, bias=False) + self.c_k = CastedLinear(dim, kv_dim, bias=False) + self.c_v = CastedLinear(dim, kv_dim, bias=False) + self.proj = CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rotary = Rotary(self.head_dim, base=rope_base) + + def forward(self, x: Tensor) -> Tensor: + bsz, seqlen, dim = x.shape + q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) + k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin) + k = apply_rotary_emb(k, cos, sin) + q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] + y = F.scaled_dot_product_attention( + q, + k, + v, + attn_mask=None, + is_causal=True, + enable_gqa=(self.num_kv_heads != self.num_heads), + ) + y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) + return self.proj(y) + + +class MLP(nn.Module): + # relu^2 MLP from the original modded-nanogpt setup + def __init__( + self, + dim: int, + mlp_mult: int, + structured_mlp: bool, + structured_linear_type: str, + btt_rank: int, + btt_in_shape: str, + btt_out_shape: str, + ): + super().__init__() + hidden = mlp_mult * 
dim + if structured_mlp: + self.fc = StructuredLinear( + dim, + hidden, + bias=False, + linear_type=structured_linear_type, + btt_rank=btt_rank, + in_shape_spec=btt_in_shape, + out_shape_spec=btt_out_shape, + ) + self.proj = StructuredLinear( + hidden, + dim, + bias=False, + linear_type=structured_linear_type, + btt_rank=btt_rank, + in_shape_spec=btt_out_shape, + out_shape_spec=btt_in_shape, + ) + else: + self.fc = CastedLinear(dim, hidden, bias=False) + self.proj = CastedLinear(hidden, dim, bias=False) + self.proj._zero_init = True + + def forward(self, x: Tensor) -> Tensor: + x = torch.relu(self.fc(x)) + return self.proj(x.square()) + + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + rope_base: float, + qk_gain_init: float, + structured_mlp: bool, + structured_linear_type: str, + btt_rank: int, + btt_in_shape: str, + btt_out_shape: str, + ): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) + self.mlp = MLP(dim, mlp_mult, structured_mlp, structured_linear_type, btt_rank, btt_in_shape, btt_out_shape) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + + def forward(self, x: Tensor, x0: Tensor) -> Tensor: + mix = self.resid_mix.to(dtype=x.dtype) + x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + attn_out = self.attn(self.attn_norm(x)) + x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out + x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) + return x + + +class GPT(nn.Module): + def __init__( + self, + vocab_size: int, + num_layers: int, + model_dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + tie_embeddings: bool, + 
tied_embed_init_std: float, + logit_softcap: float, + rope_base: float, + qk_gain_init: float, + structured_mlp: bool, + structured_linear_type: str, + btt_rank: int, + btt_in_shape: str, + btt_out_shape: str, + rate_target_tensors: str, + rate_temperature: float, + ): + super().__init__() + if logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") + self.tie_embeddings = tie_embeddings + self.tied_embed_init_std = tied_embed_init_std + self.logit_softcap = logit_softcap + self.rate_target_tensors = rate_target_tensors + self.rate_temperature = rate_temperature + self.tok_emb = nn.Embedding(vocab_size, model_dim) + self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + self.blocks = nn.ModuleList( + [ + Block( + model_dim, + num_heads, + num_kv_heads, + mlp_mult, + rope_base, + qk_gain_init, + structured_mlp, + structured_linear_type, + btt_rank, + btt_in_shape, + btt_out_shape, + ) + for i in range(num_layers) + ] + ) + self.final_norm = RMSNorm() + self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) + if self.lm_head is not None: + self.lm_head._zero_init = True + self._init_weights() + + def _init_weights(self) -> None: + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + for module in self.modules(): + if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + if isinstance(module, StructuredLinear) and getattr(module, "_zero_init", False): + if module.impl is not None: + nn.init.zeros_(module.impl.weight) + else: + nn.init.zeros_(module.core2) + + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + x = self.tok_emb(input_ids) + x 
= F.rms_norm(x, (x.size(-1),)) + x0 = x + skips: list[Tensor] = [] + + # First half stores skips; second half reuses them in reverse order. + for i in range(self.num_encoder_layers): + x = self.blocks[i](x, x0) + skips.append(x) + for i in range(self.num_decoder_layers): + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + x = self.blocks[self.num_encoder_layers + i](x, x0) + + x = self.final_norm(x).reshape(-1, x.size(-1)) + targets = target_ids.reshape(-1) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(x) + logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + return F.cross_entropy(logits.float(), targets, reduction="mean") + + def rate_regularizer(self) -> tuple[Tensor, Tensor, Tensor]: + total_symbols = 0.0 + weighted_rate = None + weighted_scale = None + estimated_bits = None + for name, param in self.named_parameters(): + if not should_rate_regularize_tensor(name, param, self.rate_target_tensors): + continue + rate_proxy, scale_proxy, tensor_bits = estimate_tensor_rate_proxy(param, self.rate_temperature) + weight = float(param.numel()) + total_symbols += weight + if weighted_rate is None: + weighted_rate = rate_proxy * weight + weighted_scale = scale_proxy * weight + estimated_bits = tensor_bits + else: + weighted_rate = weighted_rate + rate_proxy * weight + weighted_scale = weighted_scale + scale_proxy * weight + estimated_bits = estimated_bits + tensor_bits + if weighted_rate is None or weighted_scale is None or estimated_bits is None: + zero = self.tok_emb.weight.new_zeros(()) + return zero, zero, zero + denom = max(total_symbols, 1.0) + return weighted_rate / denom, weighted_scale / denom, estimated_bits + + +# ----------------------------- +# TRAINING +# ----------------------------- + +def main() -> None: + global 
zeropower_via_newtonschulz5 + + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + if args.enable_compile: + zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) + + # ----------------------------- + # DISTRIBUTED + CUDA SETUP + # ----------------------------- + + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if 8 % world_size != 0: + raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") + grad_accum_steps = 8 // world_size + grad_scale = 1.0 / grad_accum_steps + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 0 + + # Fast math knobs + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + + logfile = None + if master_process: + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" + print(logfile) + + def log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch {torch.__version__}", console=False) + log0( + subprocess.run(["nvidia-smi"], 
stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, + console=False, + ) + log0("=" * 100, console=False) + + # ----------------------------- + # TOKENIZER + VALIDATION METRIC SETUP + # ----------------------------- + + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) + if args.val_token_limit > 0: + usable = min(((args.val_token_limit) // args.train_seq_len) * args.train_seq_len, val_tokens.numel() - 1) + if usable <= 0: + raise ValueError(f"VAL_TOKEN_LIMIT={args.val_token_limit} is too small for TRAIN_SEQ_LEN={args.train_seq_len}") + val_tokens = val_tokens[: usable + 1].contiguous() + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + if args.val_token_limit > 0: + log0(f"val_token_limit:{args.val_token_limit}") + + # ----------------------------- + # MODEL + OPTIMIZER SETUP + # ----------------------------- + + base_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + model_dim=args.model_dim, + num_heads=args.num_heads, + 
num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, + structured_mlp=args.structured_mlp, + structured_linear_type=args.structured_linear_type if args.structured_mlp else "dense", + btt_rank=args.btt_rank, + btt_in_shape=args.btt_in_shape, + btt_out_shape=args.btt_out_shape, + rate_target_tensors=args.rate_target_tensors, + rate_temperature=args.rate_temperature, + ).to(device).bfloat16() + for module in base_model.modules(): + if isinstance(module, (CastedLinear, StructuredLinear)): + module.float() + restore_low_dim_params_to_fp32(base_model) + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) if args.enable_compile else base_model + model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + + # Optimizer split: + # - token embedding (Adam) uses EMBED_LR + # - untied lm_head (Adam) uses HEAD_LR + # - matrix params in transformer blocks use MATRIX_LR via Muon + # - vectors/scalars use SCALAR_LR via Adam + block_named_params = list(base_model.blocks.named_parameters()) + matrix_params = [ + p + for name, p in block_named_params + if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + scalar_params = [ + p + for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + optimizer_tok = torch.optim.Adam( + [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizer_muon = Muon( + matrix_params, + lr=args.matrix_lr, + 
momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + ) + for group in optimizer_muon.param_groups: + group["base_lr"] = args.matrix_lr + optimizer_scalar = torch.optim.Adam( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers.insert(1, optimizer_head) + + n_params = sum(p.numel() for p in base_model.parameters()) + log0(f"model_params:{n_params}") + log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") + log0(f"enable_compile:{args.enable_compile}") + log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") + log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") + log0( + f"quant_scheme:{args.quant_scheme} compressor:{args.compressor} level:{args.compressor_level} " + f"structured_mlp:{args.structured_mlp} structured_linear_type:{args.structured_linear_type}" + ) + log0( + f"rate_lambda:{args.rate_lambda} scale_lambda:{args.scale_lambda} " + f"rate_warmup_start_fraction:{args.rate_warmup_start_fraction} rate_full_fraction:{args.rate_full_fraction} " + f"rate_target_tensors:{args.rate_target_tensors}" + ) + log0( + f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " + f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " + f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" + ) + log0( + f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " + f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " + f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + 
log0(f"seed:{args.seed}") + + # ----------------------------- + # DATA LOADER & MODEL WARMUP + # ----------------------------- + + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + + def lr_mul(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + + def rate_schedule(step: int) -> float: + if args.rate_lambda <= 0.0 and args.scale_lambda <= 0.0: + return 0.0 + if args.iterations <= 0: + return 1.0 + frac = step / max(args.iterations, 1) + start = args.rate_warmup_start_fraction + full = max(args.rate_full_fraction, start + 1e-6) + if frac <= start: + return 0.0 + if frac >= full: + return 1.0 + return (frac - start) / (full - start) + + # Warmup primes the compiled forward/backward/optimizer paths, then we restore the + # initial weights/optimizer state so measured training starts from the true init. 
+ if args.warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(args.warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + for opt in optimizers: + opt.step() + zero_grad_all() + if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: + log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + if distributed: + model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + # ----------------------------- + # MAIN TRAINING LOOP + # ----------------------------- + + training_time_ms = 0.0 + stop_after_step: int | None = None + torch.cuda.synchronize() + t0 = time.perf_counter() + + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + log0( + 
f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" + ) + break + + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + zero_grad_all() + train_loss = torch.zeros((), device=device) + train_ce_loss = torch.zeros((), device=device) + train_rate_proxy = torch.zeros((), device=device) + train_scale_proxy = torch.zeros((), device=device) + train_est_bits = torch.zeros((), device=device) + rate_mul = rate_schedule(step) + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + ce_loss = model(x, y) + total_loss = ce_loss + if rate_mul > 0.0: + rate_proxy, scale_proxy, est_bits = base_model.rate_regularizer() + total_loss = total_loss + rate_mul * (args.rate_lambda * rate_proxy + args.scale_lambda * scale_proxy) + train_rate_proxy += rate_proxy.detach() + train_scale_proxy += scale_proxy.detach() + train_est_bits += est_bits.detach() + train_ce_loss += ce_loss.detach() + train_loss += total_loss.detach() + (total_loss * grad_scale).backward() + train_loss /= grad_accum_steps + train_ce_loss /= grad_accum_steps + train_rate_proxy /= grad_accum_steps + train_scale_proxy /= grad_accum_steps + train_est_bits /= grad_accum_steps + + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * 
args.muon_momentum + for group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + for opt in optimizers: + opt.step() + zero_grad_all() + + step += 1 + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + should_log_train = ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + log0( + f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} ce_loss:{train_ce_loss.item():.4f} " + f"rate_mul:{rate_mul:.3f} rate_proxy_bits:{train_rate_proxy.item():.4f} " + f"scale_proxy:{train_scale_proxy.item():.6f} est_model_bits:{train_est_bits.item():.0f} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" + ) + + # Needed to sync whether we've reached the wallclock cap. + reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + + log0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + ) + + # ----------------------------- + # SERIALIZATION + ROUNDTRIP VALIDATION + # ----------------------------- + # Save the raw state (useful for debugging/loading in PyTorch directly), then always produce + # the compressed int8+zlib artifact and validate the round-tripped weights. 
+ + if master_process: + torch.save(base_model.state_dict(), "final_model.pt") + model_bytes = os.path.getsize("final_model.pt") + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model: {model_bytes} bytes") + log0(f"Code size: {code_bytes} bytes") + log0(f"Total submission size: {model_bytes + code_bytes} bytes") + + quant_obj, quant_stats = quantize_state_dict(base_model.state_dict(), args.quant_scheme) + quant_buf = io.BytesIO() + torch.save(quant_obj, quant_buf) + quant_raw = quant_buf.getvalue() + quant_blob = compress_blob(quant_raw, args.compressor, args.compressor_level) + quant_raw_bytes = len(quant_raw) + quant_filename = f"final_model.{args.quant_scheme}.{args.compressor}.ptc" + if master_process: + with open(quant_filename, "wb") as f: + f.write(quant_blob) + quant_file_bytes = os.path.getsize(quant_filename) + code_bytes = len(code.encode("utf-8")) + ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1) + log0( + f"Serialized model {args.quant_scheme}+{args.compressor}: {quant_file_bytes} bytes " + f"(payload:{quant_stats['int8_payload_bytes']} raw_torch:{quant_raw_bytes} payload_ratio:{ratio:.2f}x)" + ) + log0(f"Total submission size {args.quant_scheme}+{args.compressor}: {quant_file_bytes + code_bytes} bytes") + + if distributed: + dist.barrier() + with open(quant_filename, "rb") as f: + quant_blob_disk = f.read() + quant_state = torch.load(io.BytesIO(decompress_blob(quant_blob_disk, args.compressor)), map_location="cpu") + base_model.load_state_dict(dequantize_state_dict(quant_state), strict=True) + torch.cuda.synchronize() + t_qeval = time.perf_counter() + q_val_loss, q_val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"final_{args.quant_scheme}_{args.compressor}_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + 
f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" + ) + log0(f"final_{args.quant_scheme}_{args.compressor}_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + + if distributed: + dist.destroy_process_group() + + +if __name__ == "__main__": + main() + +==================================================================================================== +Running Python 3.10.12 (main, Mar 3 2026, 11:56:32) [GCC 11.4.0] +Running PyTorch 2.11.0+cu130 +Tue Mar 31 02:49:53 2026 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 590.48.01 Driver Version: 590.48.01 CUDA Version: 13.1 | ++-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. | +|=========================================+========================+======================| +| 0 NVIDIA GeForce RTX 3060 Off | 00000000:01:00.0 On | N/A | +| 30% 34C P5 20W / 170W | 214MiB / 12288MiB | 36% Default | +| | | N/A | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| 0 N/A N/A 975 G /usr/lib/xorg/Xorg 89MiB | +| 0 N/A N/A 4130 C python3 104MiB | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=/home/hz/Desktop/parameter-golf-main/data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:1 +val_loader:shards 
pattern=/home/hz/Desktop/parameter-golf-main/data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:1048576 +val_token_limit:1048576 +model_params:17059912 +world_size:1 grad_accum_steps:8 +enable_compile:False +sdp_backends:cudnn=False flash=True mem_efficient=False math=False +attention_mode:gqa num_heads:8 num_kv_heads:4 +quant_scheme:int5_odd compressor:zlib level:9 structured_mlp:False structured_linear_type:btt_mlp +rate_lambda:0.002 scale_lambda:0.0002 rate_warmup_start_fraction:0.1 rate_full_fraction:0.4 rate_target_tensors:mlp_only +tie_embeddings:True embed_lr:0.03 head_lr:0.0 matrix_lr:0.02 scalar_lr:0.02 +train_batch_tokens:16384 train_seq_len:256 iterations:40 warmup_steps:0 max_wallclock_seconds:600.000 +seed:1337 +step:0/40 val_loss:6.9354 val_bpb:4.1553 train_time:0ms step_avg:0.02ms +step:1/40 train_loss:6.9416 ce_loss:6.9416 rate_mul:0.000 rate_proxy_bits:0.0000 scale_proxy:0.000000 est_model_bits:0 train_time:528ms step_avg:527.52ms +step:2/40 train_loss:12.9378 ce_loss:12.9378 rate_mul:0.000 rate_proxy_bits:0.0000 scale_proxy:0.000000 est_model_bits:0 train_time:944ms step_avg:472.13ms +step:3/40 train_loss:9.3795 ce_loss:9.3795 rate_mul:0.000 rate_proxy_bits:0.0000 scale_proxy:0.000000 est_model_bits:0 train_time:1358ms step_avg:452.77ms +step:4/40 train_loss:7.0969 ce_loss:7.0969 rate_mul:0.000 rate_proxy_bits:0.0000 scale_proxy:0.000000 est_model_bits:0 train_time:1766ms step_avg:441.44ms +step:5/40 train_loss:6.7238 ce_loss:6.7238 rate_mul:0.000 rate_proxy_bits:0.0000 scale_proxy:0.000000 est_model_bits:0 train_time:2178ms step_avg:435.64ms +step:6/40 train_loss:6.7849 ce_loss:6.7845 rate_mul:0.083 rate_proxy_bits:2.2322 scale_proxy:0.012163 est_model_bits:21065706 train_time:2855ms step_avg:475.79ms +step:7/40 train_loss:6.6324 ce_loss:6.6317 rate_mul:0.167 rate_proxy_bits:2.2312 scale_proxy:0.012415 est_model_bits:21055872 train_time:3475ms step_avg:496.49ms +step:8/40 train_loss:6.5181 ce_loss:6.5170 rate_mul:0.250 
rate_proxy_bits:2.2301 scale_proxy:0.012661 est_model_bits:21045870 train_time:4093ms step_avg:511.59ms +step:9/40 train_loss:6.3127 ce_loss:6.3113 rate_mul:0.333 rate_proxy_bits:2.2291 scale_proxy:0.012896 est_model_bits:21036376 train_time:4709ms step_avg:523.24ms +step:10/40 train_loss:6.2998 ce_loss:6.2979 rate_mul:0.417 rate_proxy_bits:2.2281 scale_proxy:0.013118 est_model_bits:21026974 train_time:5334ms step_avg:533.37ms +step:20/40 train_loss:5.3163 ce_loss:5.3118 rate_mul:1.000 rate_proxy_bits:2.2187 scale_proxy:0.014803 est_model_bits:20937854 train_time:11510ms step_avg:575.51ms +step:30/40 train_loss:4.8916 ce_loss:4.8872 rate_mul:1.000 rate_proxy_bits:2.2128 scale_proxy:0.015605 est_model_bits:20882642 train_time:17688ms step_avg:589.61ms +step:40/40 train_loss:4.5637 ce_loss:4.5593 rate_mul:1.000 rate_proxy_bits:2.2093 scale_proxy:0.016088 est_model_bits:20849622 train_time:23804ms step_avg:595.09ms +step:40/40 val_loss:4.5863 val_bpb:2.7479 train_time:23804ms step_avg:595.10ms +peak memory allocated: 1298 MiB reserved: 2002 MiB +Serialized model: 67224983 bytes +Code size: 65509 bytes +Total submission size: 67290492 bytes +Serialized model int5_odd+zlib: 3781677 bytes (payload:5819363 raw_torch:5850255 payload_ratio:11.55x) +Total submission size int5_odd+zlib: 3847186 bytes +final_int5_odd_zlib_roundtrip val_loss:5.2035 val_bpb:3.1177 eval_time:4547ms +final_int5_odd_zlib_roundtrip_exact val_loss:5.20348889 val_bpb:3.11765336 diff --git a/records/track_non_record_16mb/2026-03-31_EntropyAware_Int5Odd_BTTMLP_3060/requirements.txt b/records/track_non_record_16mb/2026-03-31_EntropyAware_Int5Odd_BTTMLP_3060/requirements.txt new file mode 100644 index 000000000..e2690ed15 --- /dev/null +++ b/records/track_non_record_16mb/2026-03-31_EntropyAware_Int5Odd_BTTMLP_3060/requirements.txt @@ -0,0 +1,7 @@ +numpy +sentencepiece +torch +datasets +huggingface-hub +tqdm +zstandard diff --git 
a/records/track_non_record_16mb/2026-03-31_EntropyAware_Int5Odd_BTTMLP_3060/run_8xh100.sh b/records/track_non_record_16mb/2026-03-31_EntropyAware_Int5Odd_BTTMLP_3060/run_8xh100.sh new file mode 100644 index 000000000..f25b7d025 --- /dev/null +++ b/records/track_non_record_16mb/2026-03-31_EntropyAware_Int5Odd_BTTMLP_3060/run_8xh100.sh @@ -0,0 +1,31 @@ +#!/usr/bin/env bash +set -euo pipefail + +# Single-node 8xH100 launch template. +# Override any variable below from the shell as needed. + +: "${NPROC_PER_NODE:=8}" +: "${RUN_ID:=cloud_8xh100}" +: "${TRAIN_BATCH_TOKENS:=131072}" +: "${TRAIN_MICROBATCH_TOKENS:=8192}" +: "${VAL_BATCH_SIZE:=1048576}" +: "${VAL_TOKEN_LIMIT:=0}" +: "${LR_SCALE_METHOD:=sqrt}" +: "${LR_REFERENCE_BATCH_TOKENS:=16384}" +: "${WARMUP_SCALE_METHOD:=linear}" +: "${WARMUP_REFERENCE_BATCH_TOKENS:=16384}" +: "${WARMUP_STEPS:=1}" +: "${COMPILE_STRUCTURED_MLP:=1}" +: "${ENABLE_COMPILE:=0}" + +# Default-assigned variables above are shell-local; export them so the torchrun-spawned train_gpt.py workers can read them. +export RUN_ID TRAIN_BATCH_TOKENS TRAIN_MICROBATCH_TOKENS VAL_BATCH_SIZE VAL_TOKEN_LIMIT LR_SCALE_METHOD LR_REFERENCE_BATCH_TOKENS WARMUP_SCALE_METHOD WARMUP_REFERENCE_BATCH_TOKENS WARMUP_STEPS COMPILE_STRUCTURED_MLP ENABLE_COMPILE + +export OMP_NUM_THREADS="${OMP_NUM_THREADS:-1}" +export TORCH_NCCL_ASYNC_ERROR_HANDLING="${TORCH_NCCL_ASYNC_ERROR_HANDLING:-1}" +export NCCL_DEBUG="${NCCL_DEBUG:-WARN}" + +torchrun \ + --standalone \ + --nproc_per_node="${NPROC_PER_NODE}" \ + train_gpt.py diff --git a/records/track_non_record_16mb/2026-03-31_EntropyAware_Int5Odd_BTTMLP_3060/run_capacity_scout.sh b/records/track_non_record_16mb/2026-03-31_EntropyAware_Int5Odd_BTTMLP_3060/run_capacity_scout.sh new file mode 100644 index 000000000..5341119ac --- /dev/null +++ b/records/track_non_record_16mb/2026-03-31_EntropyAware_Int5Odd_BTTMLP_3060/run_capacity_scout.sh @@ -0,0 +1,31 @@ +#!/usr/bin/env bash +set -euo pipefail + +configs=( + "R128_L12 BTT_RANK=128 NUM_LAYERS=12" + "R256_L16 BTT_RANK=256 NUM_LAYERS=16" + "R512_L20 BTT_RANK=512 NUM_LAYERS=20" + "R1024_L20 BTT_RANK=1024 NUM_LAYERS=20" + "R1024_L24 BTT_RANK=1024 NUM_LAYERS=24" +) + +for cfg in "${configs[@]}"; do + set -- $cfg + name=$1 + shift + echo "=== $name ===" + env \ + PYTHONUNBUFFERED=1 \ + RUN_ID="scout_${name}" \ + ITERATIONS=0 \ +
COMPILE_STRUCTURED_MLP=0 \ + ENABLE_COMPILE=0 \ + VAL_TOKEN_LIMIT=2048 \ + VAL_BATCH_SIZE=2048 \ + TRAIN_BATCH_TOKENS=4096 \ + TRAIN_SEQ_LEN=256 \ + RATE_LAMBDA=0.0002 \ + SCALE_LAMBDA=0.0002 \ + "$@" \ + python3 train_gpt.py +done diff --git a/records/track_non_record_16mb/2026-03-31_EntropyAware_Int5Odd_BTTMLP_3060/run_compile_bench.sh b/records/track_non_record_16mb/2026-03-31_EntropyAware_Int5Odd_BTTMLP_3060/run_compile_bench.sh new file mode 100644 index 000000000..450f91d00 --- /dev/null +++ b/records/track_non_record_16mb/2026-03-31_EntropyAware_Int5Odd_BTTMLP_3060/run_compile_bench.sh @@ -0,0 +1,25 @@ +#!/usr/bin/env bash +set -euo pipefail + +for compile_flag in 0 1; do + if [[ "$compile_flag" == "1" ]]; then + name="compile_on_bench" + else + name="compile_off_bench" + fi + echo "=== $name ===" + env \ + RUN_ID="$name" \ + ITERATIONS=8 \ + WARMUP_STEPS=1 \ + VAL_LOSS_EVERY=0 \ + COMPILE_STRUCTURED_MLP="$compile_flag" \ + ENABLE_COMPILE=0 \ + VAL_TOKEN_LIMIT=65536 \ + VAL_BATCH_SIZE=65536 \ + TRAIN_BATCH_TOKENS=4096 \ + TRAIN_SEQ_LEN=256 \ + RATE_LAMBDA=0.00002 \ + SCALE_LAMBDA=0.0002 \ + python3 train_gpt.py +done diff --git a/records/track_non_record_16mb/2026-03-31_EntropyAware_Int5Odd_BTTMLP_3060/run_extreme_accum.sh b/records/track_non_record_16mb/2026-03-31_EntropyAware_Int5Odd_BTTMLP_3060/run_extreme_accum.sh new file mode 100644 index 000000000..f5c38860d --- /dev/null +++ b/records/track_non_record_16mb/2026-03-31_EntropyAware_Int5Odd_BTTMLP_3060/run_extreme_accum.sh @@ -0,0 +1,13 @@ +#!/usr/bin/env bash +set -euo pipefail + +env \ + RUN_ID="${RUN_ID:-extreme_accum}" \ + ITERATIONS="${ITERATIONS:-1}" \ + WARMUP_STEPS="${WARMUP_STEPS:-0}" \ + TRAIN_BATCH_TOKENS="${TRAIN_BATCH_TOKENS:-262144}" \ + TRAIN_MICROBATCH_TOKENS="${TRAIN_MICROBATCH_TOKENS:-4096}" \ + VAL_LOSS_EVERY="${VAL_LOSS_EVERY:-0}" \ + VAL_TOKEN_LIMIT="${VAL_TOKEN_LIMIT:-2048}" \ + VAL_BATCH_SIZE="${VAL_BATCH_SIZE:-2048}" \ + python3 train_gpt.py diff --git 
a/records/track_non_record_16mb/2026-03-31_EntropyAware_Int5Odd_BTTMLP_3060/run_init_ablation.sh b/records/track_non_record_16mb/2026-03-31_EntropyAware_Int5Odd_BTTMLP_3060/run_init_ablation.sh new file mode 100644 index 000000000..34bf7c461 --- /dev/null +++ b/records/track_non_record_16mb/2026-03-31_EntropyAware_Int5Odd_BTTMLP_3060/run_init_ablation.sh @@ -0,0 +1,21 @@ +#!/usr/bin/env bash +set -euo pipefail + +for init in mup xavier; do + echo "=== BTT_INIT=$init ===" + env \ + RUN_ID="init_${init}_bench" \ + ITERATIONS=12 \ + WARMUP_STEPS=1 \ + VAL_LOSS_EVERY=0 \ + COMPILE_STRUCTURED_MLP=1 \ + ENABLE_COMPILE=0 \ + BTT_INIT="$init" \ + VAL_TOKEN_LIMIT=65536 \ + VAL_BATCH_SIZE=65536 \ + TRAIN_BATCH_TOKENS=4096 \ + TRAIN_SEQ_LEN=256 \ + RATE_LAMBDA=0.00002 \ + SCALE_LAMBDA=0.0002 \ + python3 train_gpt.py +done diff --git a/records/track_non_record_16mb/2026-03-31_EntropyAware_Int5Odd_BTTMLP_3060/run_lambda_sweep.sh b/records/track_non_record_16mb/2026-03-31_EntropyAware_Int5Odd_BTTMLP_3060/run_lambda_sweep.sh new file mode 100644 index 000000000..fa5224446 --- /dev/null +++ b/records/track_non_record_16mb/2026-03-31_EntropyAware_Int5Odd_BTTMLP_3060/run_lambda_sweep.sh @@ -0,0 +1,22 @@ +#!/usr/bin/env bash +set -euo pipefail + +lambdas=(0.002 0.0002 0.00002) + +for lambda in "${lambdas[@]}"; do + name=${lambda//./p} + echo "=== RATE_LAMBDA=$lambda ===" + env \ + RUN_ID="lambda_${name}" \ + ITERATIONS=12 \ + WARMUP_STEPS=1 \ + COMPILE_STRUCTURED_MLP=1 \ + ENABLE_COMPILE=0 \ + VAL_TOKEN_LIMIT=65536 \ + VAL_BATCH_SIZE=65536 \ + TRAIN_BATCH_TOKENS=4096 \ + TRAIN_SEQ_LEN=256 \ + RATE_LAMBDA="$lambda" \ + SCALE_LAMBDA=0.0002 \ + python3 train_gpt.py +done diff --git a/records/track_non_record_16mb/2026-03-31_EntropyAware_Int5Odd_BTTMLP_3060/run_local_tuning.sh b/records/track_non_record_16mb/2026-03-31_EntropyAware_Int5Odd_BTTMLP_3060/run_local_tuning.sh new file mode 100644 index 000000000..7c37cb489 --- /dev/null +++ 
b/records/track_non_record_16mb/2026-03-31_EntropyAware_Int5Odd_BTTMLP_3060/run_local_tuning.sh @@ -0,0 +1,7 @@ +#!/usr/bin/env bash +set -euo pipefail + +bash run_compile_bench.sh +bash run_init_ablation.sh +bash run_capacity_scout.sh +bash run_lambda_sweep.sh diff --git a/records/track_non_record_16mb/2026-03-31_EntropyAware_Int5Odd_BTTMLP_3060/run_mock_8xh100_math.sh b/records/track_non_record_16mb/2026-03-31_EntropyAware_Int5Odd_BTTMLP_3060/run_mock_8xh100_math.sh new file mode 100644 index 000000000..a563b374b --- /dev/null +++ b/records/track_non_record_16mb/2026-03-31_EntropyAware_Int5Odd_BTTMLP_3060/run_mock_8xh100_math.sh @@ -0,0 +1,20 @@ +#!/usr/bin/env bash +set -euo pipefail + +env \ + RUN_ID="${RUN_ID:-mock_8xh100_math}" \ + DRY_RUN_INIT_ONLY=1 \ + SIMULATED_WORLD_SIZE="${SIMULATED_WORLD_SIZE:-8}" \ + DEBUG_DATA_SHARDING_STEPS="${DEBUG_DATA_SHARDING_STEPS:-1}" \ + NUM_LAYERS="${NUM_LAYERS:-24}" \ + BTT_RANK="${BTT_RANK:-1024}" \ + TRAIN_BATCH_TOKENS="${TRAIN_BATCH_TOKENS:-131072}" \ + TRAIN_MICROBATCH_TOKENS="${TRAIN_MICROBATCH_TOKENS:-8192}" \ + LR_SCALE_METHOD="${LR_SCALE_METHOD:-sqrt}" \ + LR_REFERENCE_BATCH_TOKENS="${LR_REFERENCE_BATCH_TOKENS:-16384}" \ + WARMUP_SCALE_METHOD="${WARMUP_SCALE_METHOD:-linear}" \ + WARMUP_REFERENCE_BATCH_TOKENS="${WARMUP_REFERENCE_BATCH_TOKENS:-16384}" \ + WARMUP_STEPS="${WARMUP_STEPS:-1}" \ + VAL_TOKEN_LIMIT="${VAL_TOKEN_LIMIT:-2048}" \ + VAL_BATCH_SIZE="${VAL_BATCH_SIZE:-2048}" \ + python3 train_gpt.py diff --git a/records/track_non_record_16mb/2026-03-31_EntropyAware_Int5Odd_BTTMLP_3060/submission.json b/records/track_non_record_16mb/2026-03-31_EntropyAware_Int5Odd_BTTMLP_3060/submission.json new file mode 100644 index 000000000..0c4bcfe53 --- /dev/null +++ b/records/track_non_record_16mb/2026-03-31_EntropyAware_Int5Odd_BTTMLP_3060/submission.json @@ -0,0 +1,20 @@ +{ + "author": "", + "github_id": "", + "name": "Entropy-Aware Int5-Odd + BTT-MLP (1xH100 materialized eval)", + "blurb": "Non-record exploratory 
submission combining an entropy-aware 5-bin odd quantization objective with a TT/BTT-inspired structured MLP and an evaluation-time materialization path that contracts the structured MLP into dense weights for fast validation. This package is based on the best fully recovered 1xH100 run with a 1,048,576-token validation cap, not a full-validation leaderboard run.", + "date": "2026-03-31T00:00:00Z", + "track": "non-record-16mb", + "val_loss": 9.82734489, + "val_bpb": 5.88802157, + "pre_quant_val_loss": 8.9222, + "pre_quant_val_bpb": 5.3457, + "step_stop": 31, + "wallclock_seconds": 603.522, + "bytes_total": 5267667, + "bytes_code": 83124, + "bytes_model_int8_zlib": null, + "bytes_model_int5_odd_zlib": 5184543, + "gpu": "1xH100 80GB HBM3", + "notes": "This non-record package is anchored to the recovered H100 run copied into train.log. It uses 80 FineWeb train shards and VAL_TOKEN_LIMIT=1048576, so the reported val_bpb is not directly comparable to full-validation challenge runs. The materialized eval path reduced quantized validation on the 1,048,576-token cap to 854 ms. A later 3300s H100 run advanced further but was interrupted by RunPod credit exhaustion before the final artifact and complete log could be recovered, so it is intentionally excluded from these metrics." 
+} diff --git a/records/track_non_record_16mb/2026-03-31_EntropyAware_Int5Odd_BTTMLP_3060/train.log b/records/track_non_record_16mb/2026-03-31_EntropyAware_Int5Odd_BTTMLP_3060/train.log new file mode 100644 index 000000000..2d659424c --- /dev/null +++ b/records/track_non_record_16mb/2026-03-31_EntropyAware_Int5Odd_BTTMLP_3060/train.log @@ -0,0 +1,1923 @@ +logs/h100_real_r256_l16_seq1024_mb2048_materialized.txt +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=/workspace/parameter-golf/data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:80 +val_loader:shards pattern=/workspace/parameter-golf/data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:1048576 +val_token_limit:1048576 +model_params:25727104 +world_size:1 grad_accum_steps:8 train_microbatch_tokens:2048 effective_train_batch_tokens:16384 +enable_compile:False +sdp_backends:cudnn=False flash=True mem_efficient=False math=False +attention_mode:gqa num_heads:8 num_kv_heads:4 +quant_scheme:int5_odd compressor:zlib level:9 structured_mlp:True structured_linear_type:btt_mlp btt_rank:256 btt_init:mup compile_structured_mlp:False structured_chunk_ranks:64 structured_eval_chunk_ranks:1 structured_compile_mode:reduce-overhead +rate_lambda:2e-05 scale_lambda:0.0002 rate_warmup_start_fraction:0.1 rate_full_fraction:0.4 rate_target_tensors:mlp_only rate_schedule_mode:auto max_expected_steps:0 +tie_embeddings:True embed_lr:0.03 head_lr:0.0 matrix_lr:0.02 scalar_lr:0.02 +train_batch_tokens_target:16384 effective_train_batch_tokens:16384 train_seq_len:1024 iterations:20000 warmup_steps:1 max_wallclock_seconds:600.000 +lr_scale_method:none lr_reference_batch_tokens:16384 batch_multiplier:1.0000 warmup_scale_method:none warmup_reference_batch_tokens:16384 +seed:1337 +warmup_step:1/1 +step:0/20000 val_loss:6.9379 val_bpb:4.1568 train_time:0ms step_avg:0.03ms +step:1/20000 train_loss:6.9404 ce_loss:6.9404 rate_mul:0.000 rate_proxy_bits:0.0000 scale_proxy:0.000000 
est_model_bits:0 train_time:18556ms step_avg:18556.11ms +step:2/20000 train_loss:12.7072 ce_loss:12.7072 rate_mul:0.000 rate_proxy_bits:0.0000 scale_proxy:0.000000 est_model_bits:0 train_time:37651ms step_avg:18825.71ms +step:3/20000 train_loss:12.5900 ce_loss:12.5900 rate_mul:0.000 rate_proxy_bits:0.0000 scale_proxy:0.000000 est_model_bits:0 train_time:56956ms step_avg:18985.42ms +step:4/20000 train_loss:12.3686 ce_loss:12.3686 rate_mul:0.000 rate_proxy_bits:0.0000 scale_proxy:0.000000 est_model_bits:0 train_time:76316ms step_avg:19079.05ms +step:5/20000 train_loss:12.1946 ce_loss:12.1946 rate_mul:0.097 rate_proxy_bits:1.4982 scale_proxy:0.059731 est_model_bits:18851774 train_time:96008ms step_avg:19201.52ms +step:6/20000 train_loss:11.8978 ce_loss:11.8978 rate_mul:0.204 rate_proxy_bits:1.4982 scale_proxy:0.059731 est_model_bits:18851774 train_time:115743ms step_avg:19290.47ms +step:7/20000 train_loss:11.7299 ce_loss:11.7299 rate_mul:0.312 rate_proxy_bits:1.4982 scale_proxy:0.059731 est_model_bits:18851774 train_time:135345ms step_avg:19334.94ms +step:8/20000 train_loss:11.5707 ce_loss:11.5707 rate_mul:0.419 rate_proxy_bits:1.4982 scale_proxy:0.059731 est_model_bits:18851774 train_time:154822ms step_avg:19352.71ms +step:9/20000 train_loss:11.1523 ce_loss:11.1523 rate_mul:0.527 rate_proxy_bits:1.4982 scale_proxy:0.059731 est_model_bits:18851774 train_time:174289ms step_avg:19365.47ms +step:10/20000 train_loss:11.0726 ce_loss:11.0726 rate_mul:0.667 rate_proxy_bits:1.4982 scale_proxy:0.059731 est_model_bits:18851774 train_time:194115ms step_avg:19411.48ms +step:20/20000 train_loss:9.4721 ce_loss:9.4721 rate_mul:1.000 rate_proxy_bits:1.4982 scale_proxy:0.059731 est_model_bits:18851774 train_time:390448ms step_avg:19522.42ms +step:30/20000 train_loss:8.9153 ce_loss:8.9152 rate_mul:1.000 rate_proxy_bits:1.4982 scale_proxy:0.059731 est_model_bits:18851774 train_time:584228ms step_avg:19474.25ms +step:31/20000 val_loss:8.9222 val_bpb:5.3457 train_time:603522ms 
step_avg:19468.45ms +stopping_early: wallclock_cap train_time:603522ms step:31/20000 +peak memory allocated: 50290 MiB reserved: 78558 MiB +Serialized model: 101929819 bytes +Code size: 83124 bytes +Total submission size: 102012943 bytes +Serialized model int5_odd+zlib: 5184543 bytes (payload:8780523 raw_torch:8851579 payload_ratio:11.60x) +Total submission size int5_odd+zlib: 5267667 bytes +final_int5_odd_zlib_roundtrip val_loss:9.8273 val_bpb:5.8880 eval_time:854ms +final_int5_odd_zlib_roundtrip_exact val_loss:9.82734489 val_bpb:5.88802157 +", 0.85)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.0)) + lr_scale_method = os.environ.get("LR_SCALE_METHOD", "none") + lr_reference_batch_tokens = int(os.environ.get("LR_REFERENCE_BATCH_TOKENS", 16_384)) + warmup_scale_method = os.environ.get("WARMUP_SCALE_METHOD", "none") + warmup_reference_batch_tokens = int(os.environ.get("WARMUP_REFERENCE_BATCH_TOKENS", 16_384)) + simulated_world_size = int(os.environ.get("SIMULATED_WORLD_SIZE", 0)) + simulated_rank = int(os.environ.get("SIMULATED_RANK", 0)) + dry_run_init_only = bool(int(os.environ.get("DRY_RUN_INIT_ONLY", "0"))) + debug_data_sharding_steps = int(os.environ.get("DEBUG_DATA_SHARDING_STEPS", 0)) + + # Export / compression. + quant_scheme = os.environ.get("QUANT_SCHEME", "int5_odd") + compressor = os.environ.get("COMPRESSOR", "zlib") + compressor_level = int(os.environ.get("COMPRESSOR_LEVEL", 9)) + + # Entropy-aware rate regularization. 
+ rate_lambda = float(os.environ.get("RATE_LAMBDA", 0.00002)) + scale_lambda = float(os.environ.get("SCALE_LAMBDA", 0.0002)) + rate_temperature = float(os.environ.get("RATE_TEMPERATURE", 6.0)) + rate_warmup_start_fraction = float(os.environ.get("RATE_WARMUP_START_FRACTION", 0.1)) + rate_full_fraction = float(os.environ.get("RATE_FULL_FRACTION", 0.4)) + rate_target_tensors = os.environ.get("RATE_TARGET_TENSORS", "mlp_only") + rate_schedule_mode = os.environ.get("RATE_SCHEDULE_MODE", "auto") + max_expected_steps = int(os.environ.get("MAX_EXPECTED_STEPS", 0)) + + # Structured MLP. + structured_mlp = bool(int(os.environ.get("STRUCTURED_MLP", "1"))) + structured_linear_type = os.environ.get("STRUCTURED_LINEAR_TYPE", "btt_mlp") + btt_rank = int(os.environ.get("BTT_RANK", 32)) + btt_in_shape = os.environ.get("BTT_IN_SHAPE", "") + btt_out_shape = os.environ.get("BTT_OUT_SHAPE", "") + btt_init = os.environ.get("BTT_INIT", "mup") + btt_init_gain = float(os.environ.get("BTT_INIT_GAIN", 1.0)) + compile_structured_mlp = bool(int(os.environ.get("COMPILE_STRUCTURED_MLP", "1"))) + structured_chunk_ranks = int(os.environ.get("STRUCTURED_CHUNK_RANKS", 64)) + structured_eval_chunk_ranks = int(os.environ.get("STRUCTURED_EVAL_CHUNK_RANKS", 1)) + structured_compile_mode = os.environ.get("STRUCTURED_COMPILE_MODE", "reduce-overhead") + +# ----------------------------- +# MUON OPTIMIZER +# ----------------------------- +# +# As borrowed from modded-nanogpt +# Background on Muon: https://kellerjordan.github.io/posts/muon/ + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + # Orthogonalize a 2D update matrix with a fast Newton-Schulz iteration. + # Muon uses this to normalize matrix-shaped gradients before applying them. 
+ a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X + + +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True): + super().__init__( + params, + dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov), + ) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + + for group in self.param_groups: + params = group["params"] + if not params: + continue + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + + total_params = sum(int(p.numel()) for p in params) + updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) + + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: + g = g.add(buf, alpha=momentum) + g = zeropower_via_newtonschulz5(g, steps=backend_steps) + # Scale correction from Muon reference implementations. 
+ g *= max(1, g.size(0) / g.size(1)) ** 0.5 + updates_flat[curr : curr + p.numel()] = g.reshape(-1) + curr += p.numel() + + if distributed: + dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) + + curr = 0 + for p in params: + g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) + p.add_(g, alpha=-lr) + curr += p.numel() + + return loss + + +# ----------------------------- +# TOKENIZER-AGNOSTIC EVALUATION SETUP +# ----------------------------- +# +# It's common for small models to have a large fraction of their parameters be embeddings, since the 2 * d_model * d_vocab vectors can be gigantic. +# Instead of locking the tokenizer, we let you bring your own and calculate our validation metrics on the average compression of the validation set. +# We calculate BPB (bits-per-byte) instead of validation loss, so we need methods to count the number of bytes each token decodes to under the tokenizer. +# Note: Submissions that edit the tokenizer will be examined more carefully, since screwing this up might unjustly improve your score.
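The metric described in the comment block above reduces to one identity: BPB is bits-per-token (cross-entropy divided by ln 2) rescaled by the tokenizer's tokens-per-byte ratio. A minimal standalone sketch of that arithmetic (a hypothetical helper for illustration, not a function in `train_gpt.py`):

```python
import math

def bits_per_byte(val_loss_nats: float, token_count: float, byte_count: float) -> float:
    # Mean cross-entropy in nats -> bits per token, then rescale by the
    # tokenizer's compression ratio (tokens per UTF-8 byte of the val set).
    bits_per_token = val_loss_nats / math.log(2.0)
    return bits_per_token * (token_count / byte_count)

# A tokenizer emitting one token per byte gives bpb equal to bits per token;
# one that packs two bytes per token halves the bpb at the same loss.
```

This is why the harness counts bytes instead of fixing the vocabulary: a tokenizer with better compression lowers bpb at the same cross-entropy, and the metric stays comparable across vocabularies.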
+ +def build_sentencepiece_luts( + sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device +) -> tuple[Tensor, Tensor, Tensor]: + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("▁"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) + + +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*. 
+ tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] + + +def eval_val( + args: Hyperparameters, + model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + grad_accum_steps: int, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, +) -> tuple[float, float]: + # Validation computes two metrics: + # - val_loss: token cross-entropy (natural log) + # - val_bpb: tokenizer-agnostic compression metric used by the challenge + local_batch_tokens = args.val_batch_size // world_size + if local_batch_tokens < args.train_seq_len: + raise ValueError( + "VAL_BATCH_SIZE must provide at least one sequence per rank; " + f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " + f"TRAIN_SEQ_LEN={args.train_seq_len}" + ) + local_batch_seqs = local_batch_tokens // args.train_seq_len + total_seqs = (val_tokens.numel() - 1) // args.train_seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + + model.eval() + eval_dtype = torch.bfloat16 if device.type == "cuda" else torch.float32 + eval_module = model.module if isinstance(model, DDP) else model + for module in eval_module.modules(): + if isinstance(module, StructuredLinear): + module.materialize(eval_dtype, device) + with torch.inference_mode(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * args.train_seq_len + raw_end = batch_seq_end * args.train_seq_len + 1 
+ local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, args.train_seq_len) + y = local[1:].reshape(-1, args.train_seq_len) + maybe_mark_compile_step_begin(args.enable_compile or args.compile_structured_mlp) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + batch_loss = model(x, y).detach() + batch_token_count = float(y.numel()) + val_loss_sum += batch_loss.to(torch.float64) * batch_token_count + val_token_count += batch_token_count + prev_ids = x.reshape(-1) + tgt_ids = y.reshape(-1) + token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + + val_loss = val_loss_sum / val_token_count + bits_per_token = val_loss.item() / math.log(2.0) + tokens_per_byte = val_token_count.item() / val_byte_count.item() + model.train() + return float(val_loss.item()), float(bits_per_token * tokens_per_byte) + +# ----------------------------- +# POST-TRAINING QUANTIZATION +# ----------------------------- +# +# It's silly to export our model, which is trained in bf16 and fp32, at that same precision. +# Instead, we get approximately the same model (with a small hit) by quantizing the model to int8 and zlib-compressing it. +# We can then decompress the model and run in higher precision for evaluation, after coming in under the size limit.
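Before the full export helpers, here is an illustrative NumPy-only sketch of that quantize-compress-decompress round trip under the simplest per-row int8 scheme. It is a toy under stated assumptions (plain max-abs scales, fp32 scale storage, no passthrough handling), deliberately omitting the percentile clipping and fp16 scale storage the real `quantize_state_dict_int8` path uses:

```python
import zlib
import numpy as np

def quantize_per_row(w: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
    # One scale per output row: rows with different ranges get different step sizes,
    # which tracks output-channel ranges better than a single tensor-wide scale.
    scale = np.maximum(np.abs(w).max(axis=1, keepdims=True) / 127.0, 1e-8)
    q = np.clip(np.round(w / scale), -127, 127).astype(np.int8)
    return q, scale.astype(np.float32)

def dequantize_per_row(q: np.ndarray, scale: np.ndarray) -> np.ndarray:
    # Broadcast each row's scale back across its columns.
    return q.astype(np.float32) * scale

rng = np.random.default_rng(0)
w = rng.standard_normal((8, 64)).astype(np.float32)
q, scale = quantize_per_row(w)
blob = zlib.compress(q.tobytes(), level=9)  # the bytes that would ship in the artifact
q_back = np.frombuffer(zlib.decompress(blob), dtype=np.int8).reshape(q.shape)
w_hat = dequantize_per_row(q_back, scale)
# Per-element reconstruction error is bounded by half a quantization step of that row.
```

The round trip is lossy only at the quantization step; zlib itself is exact, so the "small hit" mentioned above comes entirely from rounding each weight to its nearest of 255 per-row levels.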
+ +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights", + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 +INT5_ODD_LEVELS = (-2, -1, 0, 1, 2) + +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) + + +def maybe_mark_compile_step_begin(enabled: bool) -> None: + if not enabled: + return + mark_step = getattr(torch.compiler, "cudagraph_mark_step_begin", None) + if mark_step is not None: + mark_step() + + +def scale_from_batch_multiplier(value: float, multiplier: float, method: str) -> float: + if method == "none": + return value + if method == "sqrt": + return value * math.sqrt(multiplier) + if method == "linear": + return value * multiplier + raise ValueError(f"Unsupported scaling method: {method}") + + +def scaled_warmup_steps(base_steps: int, multiplier: float, method: str) -> int: + if base_steps <= 0: + return 0 + scaled = scale_from_batch_multiplier(float(base_steps), multiplier, method) + return max(int(math.ceil(scaled)), 1) + + +def parse_factor_shape(spec: str, dim: int) -> tuple[int, int]: + if spec: + vals = tuple(int(v.strip()) for v in spec.split(",") if v.strip()) + if len(vals) != 2 or vals[0] * vals[1] != dim: + raise ValueError(f"Invalid factor shape '{spec}' for dim={dim}") + return vals + for a in range(int(math.isqrt(dim)), 0, -1): + if dim % a == 0: + return a, dim // a + return 1, dim + +def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> 
Tensor: + if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): + return t.float().contiguous() + if t.dtype in {torch.float32, torch.bfloat16}: + passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") + return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() + return t + +def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + # Matrices get one scale per row, which usually tracks output-channel + # ranges much better than a single tensor-wide scale. + clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) + q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() + return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + + # Vectors / scalars use a simpler per-tensor scale. 
+ clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() + return q, scale + + +def pack_base5_codes(q: Tensor) -> tuple[bytes, int]: + vals = q.detach().to("cpu", dtype=torch.uint8).reshape(-1).numpy() + n_vals = int(vals.size) + pad = (3 - (n_vals % 3)) % 3 + if pad: + vals = np.pad(vals, (0, pad)) + groups = vals.reshape(-1, 3).astype(np.uint8) + packed = groups[:, 0] + 5 * groups[:, 1] + 25 * groups[:, 2] + return packed.tobytes(), n_vals + + +def unpack_base5_codes(data: bytes, n_vals: int) -> Tensor: + packed = np.frombuffer(data, dtype=np.uint8).astype(np.int16) + out = np.zeros((packed.size, 3), dtype=np.uint8) + for i in range(3): + out[:, i] = packed % 5 + packed //= 5 + flat = out.reshape(-1)[:n_vals] + return torch.from_numpy(flat.astype(np.uint8, copy=False)) + + +def quantize_float_tensor_int5_odd(t: Tensor) -> tuple[Tensor, Tensor]: + if t.ndim != 2: + raise ValueError("int5_odd quantization only supports 2D tensors") + t32 = t.float() + clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clip_abs = clip_abs.clamp_min(1e-6) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 2.0).clamp_min(1.0 / 2.0e6) + q = torch.clamp(torch.round(clipped / scale[:, None]), -2, 2).to(torch.int8).contiguous() + return (q + 2).to(torch.uint8).contiguous(), scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + +def quantize_state_dict_int8(state_dict: dict[str, Tensor]): + # Single supported clean-script export format: + # - per-row int8 for 2D float tensors + # - per-tensor int8 for other float tensors + # - exact passthrough for non-floats + # - passthrough for small float tensors, 
stored as fp16 to save bytes + quantized: dict[str, Tensor] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + qmeta: dict[str, dict[str, object]] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + + # Small float tensors are cheap enough to keep directly. We still downcast + # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size. + if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + + stats["num_float_tensors"] += 1 + q, s = quantize_float_tensor(t) + if s.ndim > 0: + qmeta[name] = {"scheme": "per_row", "axis": 0} + quantized[name] = q + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) + + obj: dict[str, object] = { + "__quant_format__": "int8_clean_per_row_v1", + "quantized": quantized, + "scales": scales, + "dtypes": dtypes, + "passthrough": passthrough, + } + if qmeta: + obj["qmeta"] = qmeta + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats + + +def quantize_state_dict_int5_odd(state_dict: dict[str, Tensor]): + packed: dict[str, bytes] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + shapes: dict[str, tuple[int, ...]] = {} + orig_shapes: dict[str, 
tuple[int, ...]] = {} + n_codes: dict[str, int] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + + quantize_this = t.ndim >= 2 and (t.numel() > INT8_KEEP_FLOAT_MAX_NUMEL or ".mlp." in name) + if not quantize_this: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + + stats["num_float_tensors"] += 1 + t2 = t.reshape(t.shape[0], -1) if t.ndim > 2 else t + q, s = quantize_float_tensor_int5_odd(t2) + packed_bytes, count = pack_base5_codes(q) + packed[name] = packed_bytes + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + shapes[name] = tuple(int(v) for v in t2.shape) + orig_shapes[name] = tuple(int(v) for v in t.shape) + n_codes[name] = count + stats["int8_payload_bytes"] += len(packed_bytes) + tensor_nbytes(s) + + obj: dict[str, object] = { + "__quant_format__": "int5_odd_clean_per_row_v1", + "packed": packed, + "scales": scales, + "dtypes": dtypes, + "shapes": shapes, + "orig_shapes": orig_shapes, + "n_codes": n_codes, + "passthrough": passthrough, + } + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats + +def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + qmeta = obj.get("qmeta", {}) + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, q in 
obj["quantized"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + s = obj["scales"][name] + if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: + s = s.to(dtype=torch.float32) + # Broadcast the saved row scale back across trailing dimensions. + out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() + else: + scale = float(s.item()) + out[name] = (q.float() * scale).to(dtype=dtype).contiguous() + for name, t in obj["passthrough"].items(): + # Restore small tensors, undoing the temporary fp16 storage cast if needed. + out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out + + +def dequantize_state_dict_int5_odd(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, packed in obj["packed"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + shape = tuple(int(v) for v in obj["shapes"][name]) + orig_shape = tuple(int(v) for v in obj.get("orig_shapes", {}).get(name, shape)) + q = unpack_base5_codes(packed, int(obj["n_codes"][name])).to(torch.float32).reshape(shape) - 2.0 + s = obj["scales"][name].to(dtype=torch.float32) + deq = (q * s.view(shape[0], *([1] * (len(shape) - 1)))).to(dtype=dtype) + out[name] = deq.reshape(orig_shape).contiguous() + for name, t in obj["passthrough"].items(): + out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out + + +def quantize_state_dict(state_dict: dict[str, Tensor], quant_scheme: str): + if quant_scheme == "int8": + return quantize_state_dict_int8(state_dict) + if quant_scheme == "int5_odd": + return 
quantize_state_dict_int5_odd(state_dict) + raise ValueError(f"Unsupported QUANT_SCHEME={quant_scheme}") + + +def dequantize_state_dict(obj: dict[str, object]) -> dict[str, Tensor]: + quant_format = obj.get("__quant_format__") + if quant_format == "int8_clean_per_row_v1": + return dequantize_state_dict_int8(obj) + if quant_format == "int5_odd_clean_per_row_v1": + return dequantize_state_dict_int5_odd(obj) + raise ValueError(f"Unsupported quant format {quant_format}") + + +def compress_blob(data: bytes, compressor: str, level: int) -> bytes: + if compressor == "zlib": + return zlib.compress(data, level=max(0, min(level, 9))) + if compressor == "zstd": + if not _HAS_ZSTD: + raise RuntimeError("COMPRESSOR=zstd requires zstandard to be installed") + return zstd.ZstdCompressor(level=level).compress(data) + raise ValueError(f"Unsupported COMPRESSOR={compressor}") + + +def decompress_blob(data: bytes, compressor: str) -> bytes: + if compressor == "zlib": + return zlib.decompress(data) + if compressor == "zstd": + if not _HAS_ZSTD: + raise RuntimeError("COMPRESSOR=zstd requires zstandard to be installed") + return zstd.ZstdDecompressor().decompress(data) + raise ValueError(f"Unsupported COMPRESSOR={compressor}") + + +def should_rate_regularize_tensor(name: str, p: Tensor, target: str) -> bool: + if not p.requires_grad or not p.is_floating_point() or p.ndim < 2: + return False + if any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS): + return False + if target == "mlp_only": + return ".mlp." in name + if target == "all_matrices": + return p.numel() > INT8_KEEP_FLOAT_MAX_NUMEL or ".mlp." 
in name + raise ValueError(f"Unsupported RATE_TARGET_TENSORS={target}") + + +def estimate_tensor_rate_proxy(t: Tensor, temperature: float) -> tuple[Tensor, Tensor, Tensor]: + t2 = t.float() + if t2.ndim > 2: + t2 = t2.reshape(t2.shape[0], -1) + elif t2.ndim == 1: + t2 = t2.reshape(1, -1) + row_scale = t2.abs().mean(dim=1, keepdim=True).clamp_min(1e-6) + normalized = torch.clamp(t2 / row_scale, -2.5, 2.5) + centers = t2.new_tensor(INT5_ODD_LEVELS) + logits = -temperature * (normalized.unsqueeze(-1) - centers) ** 2 + probs = torch.softmax(logits, dim=-1) + hist = probs.mean(dim=(0, 1)) + entropy_nats = -(hist * (hist + 1e-8).log()).sum() + entropy_bits = entropy_nats / math.log(2.0) + scale_proxy = row_scale.mean() + est_bits = entropy_bits * float(t2.numel()) + return entropy_bits, scale_proxy, est_bits + + +# ----------------------------- +# DATA LOADING +# ----------------------------- + +def load_data_shard(file: Path) -> Tensor: + header_bytes = 256 * np.dtype(" None: + self.file_idx = (self.file_idx + 1) % len(self.files) + self.tokens = load_data_shard(self.files[self.file_idx]) + self.pos = 0 + + def take(self, n: int) -> Tensor: + chunks: list[Tensor] = [] + remaining = n + while remaining > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: + self._advance_file() + continue + k = min(remaining, avail) + chunks.append(self.tokens[self.pos : self.pos + k]) + self.pos += k + remaining -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) + + +class DistributedTokenLoader: + # Each call consumes a contiguous chunk from the shared token stream, then slices out + # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting. 
+ def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank = rank + self.world_size = world_size + self.device = device + self.stream = TokenStream(pattern) + + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + chunk = self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start : start + per_rank_span].to(dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + + +def debug_log_data_sharding( + pattern: str, + world_size: int, + global_tokens: int, + seq_len: int, + grad_accum_steps: int, + steps: int, + log0, +) -> None: + stream = TokenStream(pattern) + local_tokens = global_tokens // (world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + global_cursor = 0 + log0( + f"debug_sharding:world_size:{world_size} global_tokens:{global_tokens} " + f"grad_accum_steps:{grad_accum_steps} local_tokens_per_rank:{local_tokens} per_rank_span:{per_rank_span}" + ) + for step_idx in range(steps): + chunk = stream.take(per_rank_span * world_size) + chunk_start = global_cursor + chunk_end = global_cursor + chunk.numel() + log0(f"debug_sharding_step:{step_idx} shared_chunk_token_range:[{chunk_start},{chunk_end})") + for shard_rank in range(world_size): + start = shard_rank * per_rank_span + end = start + per_rank_span + local = chunk[start:end] + preview = ",".join(str(int(t)) for t in local[: min(6, local.numel())].tolist()) + log0( + f"debug_shard rank:{shard_rank} token_range:[{chunk_start + start},{chunk_start + end}) " + f"preview:{preview}" + ) + global_cursor = chunk_end + +# ----------------------------- +# TRANSFORMER MODULES +# ----------------------------- + +class RMSNorm(nn.Module): + def 
__init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + + +class CastedLinear(nn.Linear): + # Keep weights in fp32 for optimizer/state quality, cast at matmul time for bf16 compute. + def __init__(self, in_features: int, out_features: int, bias: bool = True): + super().__init__(in_features, out_features, bias=bias) + self._cached_weight: Tensor | None = None + self._cached_bias: Tensor | None = None + self._cached_key: tuple[torch.dtype, torch.device, int, int | None] | None = None + + def _cast_params(self, x: Tensor) -> tuple[Tensor, Tensor | None]: + bias_version = self.bias._version if self.bias is not None else None + cache_key = (x.dtype, x.device, self.weight._version, bias_version) + if torch.is_inference_mode_enabled() and self._cached_key == cache_key: + return self._cached_weight, self._cached_bias + weight = self.weight.to(dtype=x.dtype) + bias = self.bias.to(x.dtype) if self.bias is not None else None + if torch.is_inference_mode_enabled(): + self._cached_weight = weight + self._cached_bias = bias + self._cached_key = cache_key + return weight, bias + + def train(self, mode: bool = True): + if mode: + self._cached_weight = None + self._cached_bias = None + self._cached_key = None + return super().train(mode) + + def forward(self, x: Tensor) -> Tensor: + weight, bias = self._cast_params(x) + return F.linear(x, weight, bias) + + +class StructuredLinear(nn.Module): + # Lightweight TT-style structured matrix used only in the MLP stack. 
+ def __init__( + self, + in_features: int, + out_features: int, + bias: bool, + linear_type: str, + btt_rank: int, + in_shape_spec: str, + out_shape_spec: str, + btt_init: str, + btt_init_gain: float, + chunk_ranks: int, + eval_chunk_ranks: int, + ): + super().__init__() + self.in_features = in_features + self.out_features = out_features + self.linear_type = linear_type + self.rank = max(int(btt_rank), 1) + self.chunk_ranks = max(int(chunk_ranks), 1) + self.eval_chunk_ranks = max(int(eval_chunk_ranks), 1) + self.btt_init = btt_init + self.btt_init_gain = btt_init_gain + self._cached_core1: Tensor | None = None + self._cached_core2: Tensor | None = None + self._cached_bias: Tensor | None = None + self._cached_key: tuple[torch.dtype, torch.device, int, int, int | None] | None = None + self._dense_eval_weight: Tensor | None = None + self._dense_eval_bias: Tensor | None = None + self._dense_eval_key: tuple[torch.dtype, torch.device, int, int, int | None] | None = None + if linear_type == "dense": + self.impl = CastedLinear(in_features, out_features, bias=bias) + self.register_parameter("core1", None) + self.register_parameter("core2", None) + return + if linear_type != "btt_mlp": + raise ValueError(f"Unsupported STRUCTURED_LINEAR_TYPE={linear_type}") + self.impl = None + self.in_shape = parse_factor_shape(in_shape_spec, in_features) + self.out_shape = parse_factor_shape(out_shape_spec, out_features) + n1, n2 = self.in_shape + m1, m2 = self.out_shape + self.core1 = nn.Parameter(torch.empty(n1, m1, self.rank)) + self.core2 = nn.Parameter(torch.empty(self.rank, n2, m2)) + if bias: + self.bias = nn.Parameter(torch.zeros(out_features)) + else: + self.register_parameter("bias", None) + self.reset_parameters() + + def reset_parameters(self) -> None: + if self.impl is not None: + self.impl.reset_parameters() + return + n1, n2 = self.in_shape + if self.btt_init == "mup": + # muP-inspired scaling: keep the per-rank branch update stable while + # preserving O(1) output variance 
under the sum over TT ranks. + core1_std = self.btt_init_gain / math.sqrt(max(n1, 1)) + core2_std = self.btt_init_gain / math.sqrt(max(self.rank * n2, 1)) + elif self.btt_init == "xavier": + std = self.btt_init_gain / math.sqrt(max(self.in_features, 1)) + core1_std = std / math.sqrt(self.rank) + core2_std = std + else: + raise ValueError(f"Unsupported BTT_INIT={self.btt_init}") + nn.init.normal_(self.core1, mean=0.0, std=core1_std) + nn.init.normal_(self.core2, mean=0.0, std=core2_std) + if self.bias is not None: + nn.init.zeros_(self.bias) + + def _cast_factors(self, x: Tensor) -> tuple[Tensor, Tensor, Tensor | None]: + bias_version = self.bias._version if self.bias is not None else None + cache_key = (x.dtype, x.device, self.core1._version, self.core2._version, bias_version) + if torch.is_inference_mode_enabled() and self._cached_key == cache_key: + return self._cached_core1, self._cached_core2, self._cached_bias + core1 = self.core1.to(dtype=x.dtype).permute(2, 0, 1).contiguous() + core2 = self.core2.to(dtype=x.dtype).contiguous() + bias = self.bias.to(dtype=x.dtype) if self.bias is not None else None + if torch.is_inference_mode_enabled(): + self._cached_core1 = core1 + self._cached_core2 = core2 + self._cached_bias = bias + self._cached_key = cache_key + return core1, core2, bias + + def materialize(self, dtype: torch.dtype, device: torch.device) -> None: + if self.impl is not None: + return + bias_version = self.bias._version if self.bias is not None else None + cache_key = (dtype, device, self.core1._version, self.core2._version, bias_version) + if self._dense_eval_key == cache_key and self._dense_eval_weight is not None: + return + core1 = self.core1.to(device=device, dtype=dtype) + core2 = self.core2.to(device=device, dtype=dtype) + dense = torch.einsum("umr,rvn->mnuv", core1, core2).reshape(self.out_features, self.in_features).contiguous() + self._dense_eval_weight = dense + self._dense_eval_bias = self.bias.to(device=device, dtype=dtype) if self.bias is 
not None else None + self._dense_eval_key = cache_key + + def train(self, mode: bool = True): + if mode: + self._cached_core1 = None + self._cached_core2 = None + self._cached_bias = None + self._cached_key = None + self._dense_eval_weight = None + self._dense_eval_bias = None + self._dense_eval_key = None + return super().train(mode) + + def forward(self, x: Tensor) -> Tensor: + if self.impl is not None: + return self.impl(x) + if not self.training: + self.materialize(x.dtype, x.device) + return F.linear(x, self._dense_eval_weight, self._dense_eval_bias) + orig_shape = x.shape[:-1] + x2 = x.reshape(-1, self.in_shape[0], self.in_shape[1]) + x2_t = x2.transpose(1, 2).contiguous() + core1, core2, bias = self._cast_factors(x) + y = x2.new_zeros((x2.shape[0], self.out_shape[0], self.out_shape[1])) + for start in range(0, self.rank, self.chunk_ranks): + end = min(start + self.chunk_ranks, self.rank) + for offset, core2_r in enumerate(core2[start:end], start=start): + left = torch.matmul(x2_t, core1[offset]) + y = y + torch.matmul(left.transpose(1, 2), core2_r) + y = y.reshape(*orig_shape, self.out_features) + if bias is not None: + y = y + bias + return y + + +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + # Keep small/control parameters in fp32 even when the model body runs in bf16. + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() + + +class Rotary(nn.Module): + # Caches cos/sin tables per sequence length on the current device. 
+ def __init__(self, dim: int, base: float = 10000.0): + super().__init__() + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached != seq_len + or self._cos_cached.device != device + ): + t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) + freqs = torch.outer(t, self.inv_freq.to(device)) + self._cos_cached = freqs.cos()[None, None, :, :] + self._sin_cached = freqs.sin()[None, None, :, :] + self._seq_len_cached = seq_len + cos = self._cos_cached.to(dtype=dtype) + sin = self._sin_cached.to(dtype=dtype) + if not torch.is_inference_mode_enabled(): + cos = cos.clone() + sin = sin.clone() + return cos, sin + + +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + + +class CausalSelfAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + kv_dim = self.num_kv_heads * self.head_dim + self.c_q = CastedLinear(dim, dim, bias=False) + self.c_k = CastedLinear(dim, kv_dim, bias=False) + self.c_v = CastedLinear(dim, kv_dim, bias=False) + self.proj 
= CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rotary = Rotary(self.head_dim, base=rope_base) + + def forward(self, x: Tensor) -> Tensor: + bsz, seqlen, dim = x.shape + q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) + k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin) + k = apply_rotary_emb(k, cos, sin) + q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] + y = F.scaled_dot_product_attention( + q, + k, + v, + attn_mask=None, + is_causal=True, + enable_gqa=(self.num_kv_heads != self.num_heads), + ) + y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) + return self.proj(y) + + +class MLP(nn.Module): + # relu^2 MLP from the original modded-nanogpt setup + def __init__( + self, + dim: int, + mlp_mult: int, + structured_mlp: bool, + structured_linear_type: str, + btt_rank: int, + btt_in_shape: str, + btt_out_shape: str, + btt_init: str, + btt_init_gain: float, + compile_structured_mlp: bool, + structured_chunk_ranks: int, + structured_eval_chunk_ranks: int, + structured_compile_mode: str, + ): + super().__init__() + hidden = mlp_mult * dim + if structured_mlp: + self.fc = StructuredLinear( + dim, + hidden, + bias=False, + linear_type=structured_linear_type, + btt_rank=btt_rank, + in_shape_spec=btt_in_shape, + out_shape_spec=btt_out_shape, + btt_init=btt_init, + btt_init_gain=btt_init_gain, + chunk_ranks=structured_chunk_ranks, + eval_chunk_ranks=structured_eval_chunk_ranks, + ) + self.proj = StructuredLinear( + hidden, + dim, + bias=False, + linear_type=structured_linear_type, + btt_rank=btt_rank, + 
in_shape_spec=btt_out_shape, + out_shape_spec=btt_in_shape, + btt_init=btt_init, + btt_init_gain=btt_init_gain, + chunk_ranks=structured_chunk_ranks, + eval_chunk_ranks=structured_eval_chunk_ranks, + ) + if compile_structured_mlp: + compile_kwargs = dict(fullgraph=False, dynamic=False) + if structured_compile_mode == "reduce-overhead": + compile_options = dict(torch._inductor.list_mode_options("reduce-overhead")) + compile_options["triton.cudagraphs"] = False + compile_kwargs["options"] = compile_options + else: + compile_kwargs["mode"] = structured_compile_mode + self.fc = torch.compile( + self.fc, + **compile_kwargs, + ) + self.proj = torch.compile( + self.proj, + **compile_kwargs, + ) + else: + self.fc = CastedLinear(dim, hidden, bias=False) + self.proj = CastedLinear(hidden, dim, bias=False) + self.proj._zero_init = True + + def forward(self, x: Tensor) -> Tensor: + x = torch.relu(self.fc(x)) + return self.proj(x.square()) + + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + rope_base: float, + qk_gain_init: float, + structured_mlp: bool, + structured_linear_type: str, + btt_rank: int, + btt_in_shape: str, + btt_out_shape: str, + btt_init: str, + btt_init_gain: float, + compile_structured_mlp: bool, + structured_chunk_ranks: int, + structured_eval_chunk_ranks: int, + structured_compile_mode: str, + ): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) + self.mlp = MLP( + dim, + mlp_mult, + structured_mlp, + structured_linear_type, + btt_rank, + btt_in_shape, + btt_out_shape, + btt_init, + btt_init_gain, + compile_structured_mlp, + structured_chunk_ranks, + structured_eval_chunk_ranks, + structured_compile_mode, + ) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = 
nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + + def forward(self, x: Tensor, x0: Tensor) -> Tensor: + mix = self.resid_mix.to(dtype=x.dtype) + x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + attn_out = self.attn(self.attn_norm(x)) + x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out + x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) + return x + + +class GPT(nn.Module): + def __init__( + self, + vocab_size: int, + num_layers: int, + model_dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + tie_embeddings: bool, + tied_embed_init_std: float, + logit_softcap: float, + rope_base: float, + qk_gain_init: float, + structured_mlp: bool, + structured_linear_type: str, + btt_rank: int, + btt_in_shape: str, + btt_out_shape: str, + btt_init: str, + btt_init_gain: float, + compile_structured_mlp: bool, + structured_chunk_ranks: int, + structured_eval_chunk_ranks: int, + structured_compile_mode: str, + rate_target_tensors: str, + rate_temperature: float, + ): + super().__init__() + if logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") + self.tie_embeddings = tie_embeddings + self.tied_embed_init_std = tied_embed_init_std + self.logit_softcap = logit_softcap + self.rate_target_tensors = rate_target_tensors + self.rate_temperature = rate_temperature + self.tok_emb = nn.Embedding(vocab_size, model_dim) + self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + self.blocks = nn.ModuleList( + [ + Block( + model_dim, + num_heads, + num_kv_heads, + mlp_mult, + rope_base, + qk_gain_init, + structured_mlp, + structured_linear_type, + btt_rank, + btt_in_shape, + btt_out_shape, + btt_init, + btt_init_gain, 
+ compile_structured_mlp, + structured_chunk_ranks, + structured_eval_chunk_ranks, + structured_compile_mode, + ) + for i in range(num_layers) + ] + ) + self.final_norm = RMSNorm() + self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) + if self.lm_head is not None: + self.lm_head._zero_init = True + self._init_weights() + + def _init_weights(self) -> None: + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + for module in self.modules(): + if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + if isinstance(module, StructuredLinear) and getattr(module, "_zero_init", False): + if module.impl is not None: + nn.init.zeros_(module.impl.weight) + else: + nn.init.zeros_(module.core2) + + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + x = self.tok_emb(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x0 = x + skips: list[Tensor] = [] + + # First half stores skips; second half reuses them in reverse order. 
+ for i in range(self.num_encoder_layers): + x = self.blocks[i](x, x0) + skips.append(x) + for i in range(self.num_decoder_layers): + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + x = self.blocks[self.num_encoder_layers + i](x, x0) + + x = self.final_norm(x).reshape(-1, x.size(-1)) + targets = target_ids.reshape(-1) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(x) + logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + return F.cross_entropy(logits.float(), targets, reduction="mean") + + def rate_regularizer(self) -> tuple[Tensor, Tensor, Tensor]: + total_symbols = 0.0 + weighted_rate = None + weighted_scale = None + estimated_bits = None + for name, param in self.named_parameters(): + if not should_rate_regularize_tensor(name, param, self.rate_target_tensors): + continue + rate_proxy, scale_proxy, tensor_bits = estimate_tensor_rate_proxy(param, self.rate_temperature) + weight = float(param.numel()) + total_symbols += weight + if weighted_rate is None: + weighted_rate = rate_proxy * weight + weighted_scale = scale_proxy * weight + estimated_bits = tensor_bits + else: + weighted_rate = weighted_rate + rate_proxy * weight + weighted_scale = weighted_scale + scale_proxy * weight + estimated_bits = estimated_bits + tensor_bits + if weighted_rate is None or weighted_scale is None or estimated_bits is None: + zero = self.tok_emb.weight.new_zeros(()) + return zero, zero, zero + denom = max(total_symbols, 1.0) + return weighted_rate / denom, weighted_scale / denom, estimated_bits + + +# ----------------------------- +# TRAINING +# ----------------------------- + +def main() -> None: + global zeropower_via_newtonschulz5 + + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + if args.enable_compile: + 
zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) + + # ----------------------------- + # DISTRIBUTED + CUDA SETUP + # ----------------------------- + + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + simulated_world_size = args.simulated_world_size if not distributed and args.simulated_world_size > 0 else 0 + if simulated_world_size > 0: + world_size = simulated_world_size + rank = args.simulated_rank + local_rank = 0 + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if rank < 0 or rank >= world_size: + raise ValueError(f"RANK={rank} must satisfy 0 <= RANK < WORLD_SIZE={world_size}") + if simulated_world_size > 1 and not args.dry_run_init_only: + raise ValueError("SIMULATED_WORLD_SIZE>1 is only supported together with DRY_RUN_INIT_ONLY=1") + if args.train_batch_tokens <= 0: + raise ValueError(f"TRAIN_BATCH_TOKENS must be positive, got {args.train_batch_tokens}") + train_microbatch_tokens = args.train_microbatch_tokens if args.train_microbatch_tokens > 0 else max( + args.train_seq_len, args.train_batch_tokens // 8 + ) + if train_microbatch_tokens <= 0: + raise ValueError(f"TRAIN_MICROBATCH_TOKENS must be positive, got {train_microbatch_tokens}") + if train_microbatch_tokens % args.train_seq_len != 0: + raise ValueError( + f"TRAIN_MICROBATCH_TOKENS={train_microbatch_tokens} must be divisible by TRAIN_SEQ_LEN={args.train_seq_len}" + ) + tokens_per_microstep = world_size * train_microbatch_tokens + if args.grad_accum_steps > 0: + grad_accum_steps = args.grad_accum_steps + effective_train_batch_tokens = tokens_per_microstep * grad_accum_steps + else: + if args.train_batch_tokens % tokens_per_microstep != 0: + raise ValueError( + "TRAIN_BATCH_TOKENS must be divisible by WORLD_SIZE * TRAIN_MICROBATCH_TOKENS " + f"unless GRAD_ACCUM_STEPS is 
set explicitly; got TRAIN_BATCH_TOKENS={args.train_batch_tokens}, " + f"WORLD_SIZE={world_size}, TRAIN_MICROBATCH_TOKENS={train_microbatch_tokens}" + ) + grad_accum_steps = args.train_batch_tokens // tokens_per_microstep + effective_train_batch_tokens = args.train_batch_tokens + if grad_accum_steps <= 0: + raise ValueError(f"GRAD_ACCUM_STEPS must be positive, got {grad_accum_steps}") + grad_scale = 1.0 / grad_accum_steps + batch_multiplier = effective_train_batch_tokens / max(args.lr_reference_batch_tokens, 1) + effective_warmup_steps = scaled_warmup_steps( + args.warmup_steps, + effective_train_batch_tokens / max(args.warmup_reference_batch_tokens, 1), + args.warmup_scale_method, + ) + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 0 + + # Fast math knobs + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + + logfile = None + if master_process: + if args.run_id == "train": + logfile = "train.log" + else: + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" + print(logfile) + + def log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch {torch.__version__}", console=False) + log0( + subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, 
check=False).stdout, + console=False, + ) + log0("=" * 100, console=False) + + # ----------------------------- + # TOKENIZER + VALIDATION METRIC SETUP + # ----------------------------- + + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) + if args.val_token_limit > 0: + usable = min(((args.val_token_limit) // args.train_seq_len) * args.train_seq_len, val_tokens.numel() - 1) + if usable <= 0: + raise ValueError(f"VAL_TOKEN_LIMIT={args.val_token_limit} is too small for TRAIN_SEQ_LEN={args.train_seq_len}") + val_tokens = val_tokens[: usable + 1].contiguous() + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + if args.val_token_limit > 0: + log0(f"val_token_limit:{args.val_token_limit}") + + # ----------------------------- + # MODEL + OPTIMIZER SETUP + # ----------------------------- + + base_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + 
tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, + structured_mlp=args.structured_mlp, + structured_linear_type=args.structured_linear_type if args.structured_mlp else "dense", + btt_rank=args.btt_rank, + btt_in_shape=args.btt_in_shape, + btt_out_shape=args.btt_out_shape, + btt_init=args.btt_init, + btt_init_gain=args.btt_init_gain, + compile_structured_mlp=args.compile_structured_mlp, + structured_chunk_ranks=args.structured_chunk_ranks, + structured_eval_chunk_ranks=args.structured_eval_chunk_ranks, + structured_compile_mode=args.structured_compile_mode, + rate_target_tensors=args.rate_target_tensors, + rate_temperature=args.rate_temperature, + ).to(device).bfloat16() + for module in base_model.modules(): + if isinstance(module, (CastedLinear, StructuredLinear)): + module.float() + restore_low_dim_params_to_fp32(base_model) + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) if args.enable_compile else base_model + model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + + # Optimizer split: + # - token embedding (Adam) uses EMBED_LR + # - untied lm_head (Adam) uses HEAD_LR + # - matrix params in transformer blocks use MATRIX_LR via Muon + # - vectors/scalars use SCALAR_LR via Adam + block_named_params = list(base_model.blocks.named_parameters()) + matrix_params = [ + p + for name, p in block_named_params + if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + scalar_params = [ + p + for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + scaled_token_lr = 
scale_from_batch_multiplier(token_lr, batch_multiplier, args.lr_scale_method) + scaled_head_lr = scale_from_batch_multiplier(args.head_lr, batch_multiplier, args.lr_scale_method) + scaled_matrix_lr = scale_from_batch_multiplier(args.matrix_lr, batch_multiplier, args.lr_scale_method) + scaled_scalar_lr = scale_from_batch_multiplier(args.scalar_lr, batch_multiplier, args.lr_scale_method) + optimizer_tok = torch.optim.Adam( + [{"params": [base_model.tok_emb.weight], "lr": scaled_token_lr, "base_lr": scaled_token_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizer_muon = Muon( + matrix_params, + lr=scaled_matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + ) + for group in optimizer_muon.param_groups: + group["base_lr"] = scaled_matrix_lr + optimizer_scalar = torch.optim.Adam( + [{"params": scalar_params, "lr": scaled_scalar_lr, "base_lr": scaled_scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": scaled_head_lr, "base_lr": scaled_head_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers.insert(1, optimizer_head) + + n_params = sum(p.numel() for p in base_model.parameters()) + log0(f"model_params:{n_params}") + log0( + f"world_size:{world_size} grad_accum_steps:{grad_accum_steps} " + f"train_microbatch_tokens:{train_microbatch_tokens} effective_train_batch_tokens:{effective_train_batch_tokens}" + ) + if simulated_world_size > 0: + log0(f"simulated_world_size:{simulated_world_size} simulated_rank:{rank}") + log0(f"enable_compile:{args.enable_compile}") + log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") + log0(f"attention_mode:gqa num_heads:{args.num_heads} 
num_kv_heads:{args.num_kv_heads}") + log0( + f"quant_scheme:{args.quant_scheme} compressor:{args.compressor} level:{args.compressor_level} " + f"structured_mlp:{args.structured_mlp} structured_linear_type:{args.structured_linear_type} " + f"btt_rank:{args.btt_rank} btt_init:{args.btt_init} compile_structured_mlp:{args.compile_structured_mlp} " + f"structured_chunk_ranks:{args.structured_chunk_ranks} structured_eval_chunk_ranks:{args.structured_eval_chunk_ranks} " + f"structured_compile_mode:{args.structured_compile_mode}" + ) + log0( + f"rate_lambda:{args.rate_lambda} scale_lambda:{args.scale_lambda} " + f"rate_warmup_start_fraction:{args.rate_warmup_start_fraction} rate_full_fraction:{args.rate_full_fraction} " + f"rate_target_tensors:{args.rate_target_tensors} rate_schedule_mode:{args.rate_schedule_mode} " + f"max_expected_steps:{args.max_expected_steps}" + ) + log0( + f"tie_embeddings:{args.tie_embeddings} embed_lr:{scaled_token_lr} " + f"head_lr:{scaled_head_lr if base_model.lm_head is not None else 0.0} " + f"matrix_lr:{scaled_matrix_lr} scalar_lr:{scaled_scalar_lr}" + ) + log0( + f"train_batch_tokens_target:{args.train_batch_tokens} effective_train_batch_tokens:{effective_train_batch_tokens} " + f"train_seq_len:{args.train_seq_len} iterations:{args.iterations} warmup_steps:{effective_warmup_steps} " + f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log0( + f"lr_scale_method:{args.lr_scale_method} lr_reference_batch_tokens:{args.lr_reference_batch_tokens} " + f"batch_multiplier:{batch_multiplier:.4f} warmup_scale_method:{args.warmup_scale_method} " + f"warmup_reference_batch_tokens:{args.warmup_reference_batch_tokens}" + ) + log0(f"seed:{args.seed}") + + if args.debug_data_sharding_steps > 0: + debug_log_data_sharding( + args.train_files, + world_size, + effective_train_batch_tokens, + args.train_seq_len, + grad_accum_steps, + args.debug_data_sharding_steps, + log0, + ) + if args.dry_run_init_only: + log0("dry_run_init_only: exiting after 
init/logging") + return + + # ----------------------------- + # DATA LOADER & MODEL WARMUP + # ----------------------------- + + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + + def lr_mul(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + + def rate_progress(step: int, elapsed_ms: float) -> float: + if args.rate_schedule_mode == "steps": + if args.iterations <= 0: + return 1.0 + return step / max(args.iterations, 1) + if args.rate_schedule_mode == "expected_steps": + if args.max_expected_steps <= 0: + raise ValueError("MAX_EXPECTED_STEPS must be positive when RATE_SCHEDULE_MODE=expected_steps") + return step / max(args.max_expected_steps, 1) + if args.rate_schedule_mode == "wallclock": + if max_wallclock_ms is None: + raise ValueError("RATE_SCHEDULE_MODE=wallclock requires MAX_WALLCLOCK_SECONDS > 0") + return elapsed_ms / max(max_wallclock_ms, 1e-9) + if args.rate_schedule_mode != "auto": + raise ValueError(f"Unsupported RATE_SCHEDULE_MODE={args.rate_schedule_mode}") + if args.max_expected_steps > 0: + return step / max(args.max_expected_steps, 1) + if max_wallclock_ms is not None and step > 0: + step_ms = elapsed_ms / max(step, 1) + expected_steps = max(int(max_wallclock_ms / max(step_ms, 1e-9)), 1) + return step / expected_steps + if args.iterations <= 0: + 
return 1.0 + return step / max(args.iterations, 1) + + def rate_schedule(step: int, elapsed_ms: float) -> float: + if args.rate_lambda <= 0.0 and args.scale_lambda <= 0.0: + return 0.0 + frac = rate_progress(step, elapsed_ms) + start = args.rate_warmup_start_fraction + full = max(args.rate_full_fraction, start + 1e-6) + if frac <= start: + return 0.0 + if frac >= full: + return 1.0 + return (frac - start) / (full - start) + + # Warmup primes the compiled forward/backward/optimizer paths, then we restore the + # initial weights/optimizer state so measured training starts from the true init. + if effective_warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(effective_warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(effective_train_batch_tokens, args.train_seq_len, grad_accum_steps) + maybe_mark_compile_step_begin(args.enable_compile or args.compile_structured_mlp) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + for opt in optimizers: + opt.step() + zero_grad_all() + if ( + effective_warmup_steps <= 20 + or (warmup_step + 1) % 10 == 0 + or warmup_step + 1 == effective_warmup_steps + ): + log0(f"warmup_step:{warmup_step + 1}/{effective_warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + if distributed: + model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + # ----------------------------- + # 
MAIN TRAINING LOOP + # ----------------------------- + + training_time_ms = 0.0 + stop_after_step: int | None = None + torch.cuda.synchronize() + t0 = time.perf_counter() + + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + log0( + f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" + ) + break + + model.train() + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + zero_grad_all() + train_loss = torch.zeros((), device=device) + train_ce_loss = torch.zeros((), device=device) + train_rate_proxy = torch.zeros((), device=device) + train_scale_proxy = torch.zeros((), device=device) + train_est_bits = torch.zeros((), device=device) + rate_mul = rate_schedule(step, elapsed_ms) + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(effective_train_batch_tokens, args.train_seq_len, grad_accum_steps) + maybe_mark_compile_step_begin(args.enable_compile or args.compile_structured_mlp) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + ce_loss = model(x, y) + 
total_loss = ce_loss + if rate_mul > 0.0: + rate_proxy, scale_proxy, est_bits = base_model.rate_regularizer() + total_loss = total_loss + rate_mul * (args.rate_lambda * rate_proxy + args.scale_lambda * scale_proxy) + train_rate_proxy += rate_proxy.detach() + train_scale_proxy += scale_proxy.detach() + train_est_bits += est_bits.detach() + train_ce_loss += ce_loss.detach() + train_loss += total_loss.detach() + (total_loss * grad_scale).backward() + train_loss /= grad_accum_steps + train_ce_loss /= grad_accum_steps + train_rate_proxy /= grad_accum_steps + train_scale_proxy /= grad_accum_steps + train_est_bits /= grad_accum_steps + + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + for opt in optimizers: + opt.step() + zero_grad_all() + + step += 1 + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + should_log_train = ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + log0( + f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} ce_loss:{train_ce_loss.item():.4f} " + f"rate_mul:{rate_mul:.3f} rate_proxy_bits:{train_rate_proxy.item():.4f} " + f"scale_proxy:{train_scale_proxy.item():.6f} est_model_bits:{train_est_bits.item():.0f} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" + ) + + # Needed to sync whether we've reached the wallclock cap. 
+ reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + + log0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + ) + + # ----------------------------- + # SERIALIZATION + ROUNDTRIP VALIDATION + # ----------------------------- + # Save the raw state (useful for debugging/loading in PyTorch directly), then always produce + # the compressed int8+zlib artifact and validate the round-tripped weights. + + if master_process: + torch.save(base_model.state_dict(), "final_model.pt") + model_bytes = os.path.getsize("final_model.pt") + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model: {model_bytes} bytes") + log0(f"Code size: {code_bytes} bytes") + log0(f"Total submission size: {model_bytes + code_bytes} bytes") + + quant_obj, quant_stats = quantize_state_dict(base_model.state_dict(), args.quant_scheme) + quant_buf = io.BytesIO() + torch.save(quant_obj, quant_buf) + quant_raw = quant_buf.getvalue() + quant_blob = compress_blob(quant_raw, args.compressor, args.compressor_level) + quant_raw_bytes = len(quant_raw) + quant_filename = f"final_model.{args.quant_scheme}.{args.compressor}.ptc" + if master_process: + with open(quant_filename, "wb") as f: + f.write(quant_blob) + quant_file_bytes = os.path.getsize(quant_filename) + code_bytes = len(code.encode("utf-8")) + ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1) + log0( + f"Serialized model {args.quant_scheme}+{args.compressor}: {quant_file_bytes} bytes " + f"(payload:{quant_stats['int8_payload_bytes']} raw_torch:{quant_raw_bytes} 
payload_ratio:{ratio:.2f}x)" + ) + log0(f"Total submission size {args.quant_scheme}+{args.compressor}: {quant_file_bytes + code_bytes} bytes") + + if distributed: + dist.barrier() + with open(quant_filename, "rb") as f: + quant_blob_disk = f.read() + quant_state = torch.load(io.BytesIO(decompress_blob(quant_blob_disk, args.compressor)), map_location="cpu") + base_model.load_state_dict(dequantize_state_dict(quant_state), strict=True) + torch.cuda.synchronize() + t_qeval = time.perf_counter() + q_val_loss, q_val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"final_{args.quant_scheme}_{args.compressor}_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" + ) + log0(f"final_{args.quant_scheme}_{args.compressor}_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + + if distributed: + dist.destroy_process_group() + + +if __name__ == "__main__": + main() + +==================================================================================================== +Running Python 3.12.3 (main, Nov 6 2025, 13:44:16) [GCC 13.3.0] +Running PyTorch 2.9.1+cu128 +Tue Mar 31 19:03:36 2026 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 580.126.09 Driver Version: 580.126.09 CUDA Version: 13.0 | ++-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. 
| +|=========================================+========================+======================| +| 0 NVIDIA H100 80GB HBM3 On | 00000000:CB:00.0 Off | 0 | +| N/A 33C P0 76W / 700W | 527MiB / 81559MiB | 2% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| No running processes found | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=/workspace/parameter-golf/data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:80 +val_loader:shards pattern=/workspace/parameter-golf/data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:1048576 +val_token_limit:1048576 +model_params:25727104 +world_size:1 grad_accum_steps:8 train_microbatch_tokens:2048 effective_train_batch_tokens:16384 +enable_compile:False +sdp_backends:cudnn=False flash=True mem_efficient=False math=False +attention_mode:gqa num_heads:8 num_kv_heads:4 +quant_scheme:int5_odd compressor:zlib level:9 structured_mlp:True structured_linear_type:btt_mlp btt_rank:256 btt_init:mup compile_structured_mlp:False structured_chunk_ranks:64 structured_eval_chunk_ranks:1 structured_compile_mode:reduce-overhead +rate_lambda:2e-05 scale_lambda:0.0002 rate_warmup_start_fraction:0.1 rate_full_fraction:0.4 rate_target_tensors:mlp_only rate_schedule_mode:auto max_expected_steps:0 +tie_embeddings:True embed_lr:0.03 head_lr:0.0 matrix_lr:0.02 scalar_lr:0.02 +train_batch_tokens_target:16384 effective_train_batch_tokens:16384 train_seq_len:1024 
iterations:20000 warmup_steps:1 max_wallclock_seconds:600.000 +lr_scale_method:none lr_reference_batch_tokens:16384 batch_multiplier:1.0000 warmup_scale_method:none warmup_reference_batch_tokens:16384 +seed:1337 +warmup_step:1/1 +step:0/20000 val_loss:6.9379 val_bpb:4.1568 train_time:0ms step_avg:0.03ms +step:1/20000 train_loss:6.9404 ce_loss:6.9404 rate_mul:0.000 rate_proxy_bits:0.0000 scale_proxy:0.000000 est_model_bits:0 train_time:18556ms step_avg:18556.11ms +step:2/20000 train_loss:12.7072 ce_loss:12.7072 rate_mul:0.000 rate_proxy_bits:0.0000 scale_proxy:0.000000 est_model_bits:0 train_time:37651ms step_avg:18825.71ms +step:3/20000 train_loss:12.5900 ce_loss:12.5900 rate_mul:0.000 rate_proxy_bits:0.0000 scale_proxy:0.000000 est_model_bits:0 train_time:56956ms step_avg:18985.42ms +step:4/20000 train_loss:12.3686 ce_loss:12.3686 rate_mul:0.000 rate_proxy_bits:0.0000 scale_proxy:0.000000 est_model_bits:0 train_time:76316ms step_avg:19079.05ms +step:5/20000 train_loss:12.1946 ce_loss:12.1946 rate_mul:0.097 rate_proxy_bits:1.4982 scale_proxy:0.059731 est_model_bits:18851774 train_time:96008ms step_avg:19201.52ms +step:6/20000 train_loss:11.8978 ce_loss:11.8978 rate_mul:0.204 rate_proxy_bits:1.4982 scale_proxy:0.059731 est_model_bits:18851774 train_time:115743ms step_avg:19290.47ms +step:7/20000 train_loss:11.7299 ce_loss:11.7299 rate_mul:0.312 rate_proxy_bits:1.4982 scale_proxy:0.059731 est_model_bits:18851774 train_time:135345ms step_avg:19334.94ms +step:8/20000 train_loss:11.5707 ce_loss:11.5707 rate_mul:0.419 rate_proxy_bits:1.4982 scale_proxy:0.059731 est_model_bits:18851774 train_time:154822ms step_avg:19352.71ms +step:9/20000 train_loss:11.1523 ce_loss:11.1523 rate_mul:0.527 rate_proxy_bits:1.4982 scale_proxy:0.059731 est_model_bits:18851774 train_time:174289ms step_avg:19365.47ms +step:10/20000 train_loss:11.0726 ce_loss:11.0726 rate_mul:0.667 rate_proxy_bits:1.4982 scale_proxy:0.059731 est_model_bits:18851774 train_time:194115ms step_avg:19411.48ms 
+step:20/20000 train_loss:9.4721 ce_loss:9.4721 rate_mul:1.000 rate_proxy_bits:1.4982 scale_proxy:0.059731 est_model_bits:18851774 train_time:390448ms step_avg:19522.42ms +step:30/20000 train_loss:8.9153 ce_loss:8.9152 rate_mul:1.000 rate_proxy_bits:1.4982 scale_proxy:0.059731 est_model_bits:18851774 train_time:584228ms step_avg:19474.25ms +step:31/20000 val_loss:8.9222 val_bpb:5.3457 train_time:603522ms step_avg:19468.45ms +stopping_early: wallclock_cap train_time:603522ms step:31/20000 +peak memory allocated: 50290 MiB reserved: 78558 MiB +Serialized model: 101929819 bytes +Code size: 83124 bytes +Total submission size: 102012943 bytes +Serialized model int5_odd+zlib: 5184543 bytes (payload:8780523 raw_torch:8851579 payload_ratio:11.60x) +Total submission size int5_odd+zlib: 5267667 bytes +final_int5_odd_zlib_roundtrip val_loss:9.8273 val_bpb:5.8880 eval_time:854ms +final_int5_odd_zlib_roundtrip_exact val_loss:9.82734489 val_bpb:5.88802157 diff --git a/records/track_non_record_16mb/2026-03-31_EntropyAware_Int5Odd_BTTMLP_3060/train_gpt.py b/records/track_non_record_16mb/2026-03-31_EntropyAware_Int5Odd_BTTMLP_3060/train_gpt.py new file mode 100644 index 000000000..dd62eb5ca --- /dev/null +++ b/records/track_non_record_16mb/2026-03-31_EntropyAware_Int5Odd_BTTMLP_3060/train_gpt.py @@ -0,0 +1,1913 @@ +""" +The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good launching-off points for new participants, not SOTA configs. We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder. + +Hard stop: To keep things readable for newcomers, let's make sure `train_gpt.py` and `train_gpt_mlx.py` are never longer than 1500 lines. 
+""" + +from __future__ import annotations + +import copy +import glob +import io +import math +import os +import random +import subprocess +import sys +import time +import uuid +import zlib +from pathlib import Path + +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP + +try: + import zstandard as zstd + + _HAS_ZSTD = True +except ImportError: + zstd = None + _HAS_ZSTD = False + +REPO_ROOT = Path(__file__).resolve().parents[3] + +# ----------------------------- +# HYPERPARAMETERS +# ----------------------------- +# Default Simple Baseline run: +# - 9 transformer blocks at width 512 +# - 8 attention heads with 4 KV heads (GQA) and 2x MLP expansion +# - vocab size 1024, sequence length 1024, tied embeddings +# - 524,288 train tokens per step for 20,000 iterations with a ~10 minute cap + +class Hyperparameters: + # Data paths are shard globs produced by the existing preprocessing pipeline. + data_path = os.environ.get("DATA_PATH", str(REPO_ROOT / "data/datasets/fineweb10B_sp1024")) + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", str(REPO_ROOT / "data/tokenizers/fineweb_1024_bpe.model")) + run_id = os.environ.get("RUN_ID", "train") + seed = int(os.environ.get("SEED", 1337)) + + # Validation cadence and batch size. Validation always uses the full fineweb_val split. + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 32_768)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 10)) + val_token_limit = int(os.environ.get("VAL_TOKEN_LIMIT", 1_048_576)) + + # Training length. 
+ iterations = int(os.environ.get("ITERATIONS", 40)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 1200)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 1)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 16_384)) + train_microbatch_tokens = int(os.environ.get("TRAIN_MICROBATCH_TOKENS", 0)) + grad_accum_steps = int(os.environ.get("GRAD_ACCUM_STEPS", 0)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 256)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + enable_compile = bool(int(os.environ.get("ENABLE_COMPILE", "0"))) + + # Model shape. + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers = int(os.environ.get("NUM_LAYERS", 9)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = int(os.environ.get("MLP_MULT", 2)) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + + # Optimizer hyperparameters. 
+ embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.03)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.02)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.02)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.0)) + lr_scale_method = os.environ.get("LR_SCALE_METHOD", "none") + lr_reference_batch_tokens = int(os.environ.get("LR_REFERENCE_BATCH_TOKENS", 16_384)) + warmup_scale_method = os.environ.get("WARMUP_SCALE_METHOD", "none") + warmup_reference_batch_tokens = int(os.environ.get("WARMUP_REFERENCE_BATCH_TOKENS", 16_384)) + simulated_world_size = int(os.environ.get("SIMULATED_WORLD_SIZE", 0)) + simulated_rank = int(os.environ.get("SIMULATED_RANK", 0)) + dry_run_init_only = bool(int(os.environ.get("DRY_RUN_INIT_ONLY", "0"))) + debug_data_sharding_steps = int(os.environ.get("DEBUG_DATA_SHARDING_STEPS", 0)) + + # Export / compression. + quant_scheme = os.environ.get("QUANT_SCHEME", "int5_odd") + compressor = os.environ.get("COMPRESSOR", "zlib") + compressor_level = int(os.environ.get("COMPRESSOR_LEVEL", 9)) + + # Entropy-aware rate regularization. 
+ rate_lambda = float(os.environ.get("RATE_LAMBDA", 0.00002)) + scale_lambda = float(os.environ.get("SCALE_LAMBDA", 0.0002)) + rate_temperature = float(os.environ.get("RATE_TEMPERATURE", 6.0)) + rate_warmup_start_fraction = float(os.environ.get("RATE_WARMUP_START_FRACTION", 0.1)) + rate_full_fraction = float(os.environ.get("RATE_FULL_FRACTION", 0.4)) + rate_target_tensors = os.environ.get("RATE_TARGET_TENSORS", "mlp_only") + rate_schedule_mode = os.environ.get("RATE_SCHEDULE_MODE", "auto") + max_expected_steps = int(os.environ.get("MAX_EXPECTED_STEPS", 0)) + + # Structured MLP. + structured_mlp = bool(int(os.environ.get("STRUCTURED_MLP", "1"))) + structured_linear_type = os.environ.get("STRUCTURED_LINEAR_TYPE", "btt_mlp") + btt_rank = int(os.environ.get("BTT_RANK", 32)) + btt_in_shape = os.environ.get("BTT_IN_SHAPE", "") + btt_out_shape = os.environ.get("BTT_OUT_SHAPE", "") + btt_init = os.environ.get("BTT_INIT", "mup") + btt_init_gain = float(os.environ.get("BTT_INIT_GAIN", 1.0)) + compile_structured_mlp = bool(int(os.environ.get("COMPILE_STRUCTURED_MLP", "1"))) + structured_chunk_ranks = int(os.environ.get("STRUCTURED_CHUNK_RANKS", 64)) + structured_eval_chunk_ranks = int(os.environ.get("STRUCTURED_EVAL_CHUNK_RANKS", 1)) + structured_compile_mode = os.environ.get("STRUCTURED_COMPILE_MODE", "reduce-overhead") + +# ----------------------------- +# MUON OPTIMIZER +# ----------------------------- +# +# As borrowed from modded-nanogpt +# Background on Muon: https://kellerjordan.github.io/posts/muon/ + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + # Orthogonalize a 2D update matrix with a fast Newton-Schulz iteration. + # Muon uses this to normalize matrix-shaped gradients before applying them. 
+ a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X + + +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True): + super().__init__( + params, + dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov), + ) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + + for group in self.param_groups: + params = group["params"] + if not params: + continue + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + + total_params = sum(int(p.numel()) for p in params) + updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) + + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: + g = g.add(buf, alpha=momentum) + g = zeropower_via_newtonschulz5(g, steps=backend_steps) + # Scale correction from Muon reference implementations. 
+ g *= max(1, g.size(0) / g.size(1)) ** 0.5 + updates_flat[curr : curr + p.numel()] = g.reshape(-1) + curr += p.numel() + + if distributed: + dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) + + curr = 0 + for p in params: + g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) + p.add_(g, alpha=-lr) + curr += p.numel() + + return loss + + +# ----------------------------- +# TOKENIZER-AGNOSTIC EVALUATION SETUP +# ----------------------------- +# +# It's common for small models to have a large fraction of their parameters in embeddings, since the 2 * d_model * d_vocab vectors can be gigantic. +# Instead of locking the tokenizer, we let you bring your own and calculate our validation metrics on the average compression of the validation set. +# We calculate BPB (bits-per-byte) instead of validation loss, so we need methods to count the number of bits per token in the tokenizer. +# Note: Submissions that edit the tokenizer will be examined more carefully, since screwing this up might unjustly improve your score. 
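The loss-to-BPB conversion described in the comment block above can be sketched in isolation. This is a hedged illustration with an invented helper name (`bits_per_byte`); the real `eval_val` below computes the same quantity but derives the byte counts from the SentencePiece lookup tables rather than taking them as arguments:

```python
import math

def bits_per_byte(mean_ce_nats: float, n_tokens: int, n_bytes: int) -> float:
    # Token-level cross-entropy (natural log) -> bits per token,
    # then renormalized by tokens-per-byte so the metric is
    # comparable across different tokenizers.
    bits_per_token = mean_ce_nats / math.log(2.0)
    return bits_per_token * (n_tokens / n_bytes)
```

A tokenizer that packs more bytes into each token lowers BPB for the same token-level loss, which is exactly why the metric is robust to bring-your-own-tokenizer submissions.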
+ +def build_sentencepiece_luts( + sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device +) -> tuple[Tensor, Tensor, Tensor]: + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("▁"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) + + +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*. 
+ tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] + + +def eval_val( + args: Hyperparameters, + model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + grad_accum_steps: int, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, +) -> tuple[float, float]: + # Validation computes two metrics: + # - val_loss: token cross-entropy (natural log) + # - val_bpb: tokenizer-agnostic compression metric used by the challenge + local_batch_tokens = args.val_batch_size // world_size + if local_batch_tokens < args.train_seq_len: + raise ValueError( + "VAL_BATCH_SIZE must provide at least one sequence per rank; " + f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " + f"TRAIN_SEQ_LEN={args.train_seq_len}" + ) + local_batch_seqs = local_batch_tokens // args.train_seq_len + total_seqs = (val_tokens.numel() - 1) // args.train_seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + + model.eval() + eval_dtype = torch.bfloat16 if device.type == "cuda" else torch.float32 + eval_module = model.module if isinstance(model, DDP) else model + for module in eval_module.modules(): + if isinstance(module, StructuredLinear): + module.materialize(eval_dtype, device) + with torch.inference_mode(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * args.train_seq_len + raw_end = batch_seq_end * args.train_seq_len + 1 
+ local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, args.train_seq_len) + y = local[1:].reshape(-1, args.train_seq_len) + maybe_mark_compile_step_begin(args.enable_compile or args.compile_structured_mlp) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + batch_loss = model(x, y).detach() + batch_token_count = float(y.numel()) + val_loss_sum += batch_loss.to(torch.float64) * batch_token_count + val_token_count += batch_token_count + prev_ids = x.reshape(-1) + tgt_ids = y.reshape(-1) + token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + + val_loss = val_loss_sum / val_token_count + bits_per_token = val_loss.item() / math.log(2.0) + tokens_per_byte = val_token_count.item() / val_byte_count.item() + model.train() + return float(val_loss.item()), float(bits_per_token * tokens_per_byte) + +# ----------------------------- +# POST-TRAINING QUANTIZATION +# ----------------------------- +# +# It's silly to export our model, which is trained in bf16 and fp32, at that same precision. +# Instead, we get approximately the same model (with a small hit) by quantizing the model to int8 & zlib compressing. +# We can then decompress the model and run in higher precision for evaluation, after closing in under the size limit. 
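The export idea in the comment block above (quantize per row, compress the payload, then decompress and dequantize for evaluation) can be sketched with a minimal NumPy roundtrip. The names here (`quantize_rows_int8`, `export_roundtrip`) are illustrative, not the script's helpers, which operate on torch tensors and add percentile clipping and dtype bookkeeping:

```python
import zlib
import numpy as np

def quantize_rows_int8(w: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
    # One scale per output row; the floor keeps all-zero rows representable.
    scale = np.maximum(np.abs(w).max(axis=1) / 127.0, 1.0 / 127.0)
    q = np.clip(np.rint(w / scale[:, None]), -127, 127).astype(np.int8)
    return q, scale.astype(np.float32)

def export_roundtrip(w: np.ndarray) -> np.ndarray:
    q, scale = quantize_rows_int8(w)
    blob = zlib.compress(q.tobytes(), level=9)  # the on-disk payload
    restored = np.frombuffer(zlib.decompress(blob), dtype=np.int8).reshape(q.shape)
    return restored.astype(np.float32) * scale[:, None]
```

The roundtripped matrix differs from the original by at most half a quantization step per row, which is the "small hit" the comment refers to.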
+
+CONTROL_TENSOR_NAME_PATTERNS = tuple(
+    pattern
+    for pattern in os.environ.get(
+        "CONTROL_TENSOR_NAME_PATTERNS",
+        "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights",
+    ).split(",")
+    if pattern
+)
+INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple(
+    pattern
+    for pattern in os.environ.get(
+        "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS",
+        ",".join(CONTROL_TENSOR_NAME_PATTERNS),
+    ).split(",")
+    if pattern
+)
+INT8_KEEP_FLOAT_MAX_NUMEL = 65_536
+INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16
+INT8_PER_ROW_SCALE_DTYPE = torch.float16
+INT8_CLIP_PERCENTILE = 99.99984
+INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0
+INT5_ODD_LEVELS = (-2, -1, 0, 1, 2)
+
+
+def tensor_nbytes(t: Tensor) -> int:
+    return int(t.numel()) * int(t.element_size())
+
+
+def maybe_mark_compile_step_begin(enabled: bool) -> None:
+    if not enabled:
+        return
+    mark_step = getattr(torch.compiler, "cudagraph_mark_step_begin", None)
+    if mark_step is not None:
+        mark_step()
+
+
+def scale_from_batch_multiplier(value: float, multiplier: float, method: str) -> float:
+    if method == "none":
+        return value
+    if method == "sqrt":
+        return value * math.sqrt(multiplier)
+    if method == "linear":
+        return value * multiplier
+    raise ValueError(f"Unsupported scaling method: {method}")
+
+
+def scaled_warmup_steps(base_steps: int, multiplier: float, method: str) -> int:
+    if base_steps <= 0:
+        return 0
+    scaled = scale_from_batch_multiplier(float(base_steps), multiplier, method)
+    return max(int(math.ceil(scaled)), 1)
+
+
+def parse_factor_shape(spec: str, dim: int) -> tuple[int, int]:
+    if spec:
+        vals = tuple(int(v.strip()) for v in spec.split(",") if v.strip())
+        if len(vals) != 2 or vals[0] * vals[1] != dim:
+            raise ValueError(f"Invalid factor shape '{spec}' for dim={dim}")
+        return vals
+    # Default: the factorization closest to square, e.g. dim=768 -> (24, 32).
+    for a in range(int(math.isqrt(dim)), 0, -1):
+        if dim % a == 0:
+            return a, dim // a
+    return 1, dim
+
+
+def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor:
+    if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS):
+        return t.float().contiguous()
+    if t.dtype in {torch.float32, torch.bfloat16}:
+        passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.")
+        return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous()
+    return t
+
+
+def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]:
+    t32 = t.float()
+    if t32.ndim == 2:
+        # Matrices get one scale per row, which usually tracks output-channel
+        # ranges much better than a single tensor-wide scale.
+        clip_abs = (
+            torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1)
+            if t32.numel()
+            # zeros, not torch.empty: uninitialized memory would leak garbage
+            # scales into the artifact for degenerate zero-width tensors.
+            else torch.zeros((t32.shape[0],), dtype=torch.float32)
+        )
+        clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None])
+        scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0)
+        q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous()
+        return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous()
+
+    # Vectors / scalars use a simpler per-tensor scale.
+    clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0
+    scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32)
+    q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous()
+    return q, scale
+
+
+def pack_base5_codes(q: Tensor) -> tuple[bytes, int]:
+    # Three base-5 codes fit in one byte: c0 + 5*c1 + 25*c2 <= 4 + 20 + 100 = 124 < 256,
+    # so each 5-level weight costs ~2.67 bits before zlib instead of 8.
+    vals = q.detach().to("cpu", dtype=torch.uint8).reshape(-1).numpy()
+    n_vals = int(vals.size)
+    pad = (3 - (n_vals % 3)) % 3
+    if pad:
+        vals = np.pad(vals, (0, pad))
+    groups = vals.reshape(-1, 3).astype(np.uint8)
+    packed = groups[:, 0] + 5 * groups[:, 1] + 25 * groups[:, 2]
+    return packed.tobytes(), n_vals
+
+
+def unpack_base5_codes(data: bytes, n_vals: int) -> Tensor:
+    packed = np.frombuffer(data, dtype=np.uint8).astype(np.int16)
+    out = np.zeros((packed.size, 3), dtype=np.uint8)
+    for i in range(3):
+        out[:, i] = packed % 5
+        packed //= 5
+    flat = out.reshape(-1)[:n_vals]
+    return torch.from_numpy(flat.astype(np.uint8, copy=False))
+
+
+def quantize_float_tensor_int5_odd(t: Tensor) -> tuple[Tensor, Tensor]:
+    if t.ndim != 2:
+        raise ValueError("int5_odd quantization only supports 2D tensors")
+    t32 = t.float()
+    clip_abs = (
+        torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1)
+        if t32.numel()
+        else torch.zeros((t32.shape[0],), dtype=torch.float32)
+    )
+    clip_abs = clip_abs.clamp_min(1e-6)
+    clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None])
+    scale = (clip_abs / 2.0).clamp_min(1.0 / 2.0e6)
+    q = torch.clamp(torch.round(clipped / scale[:, None]), -2, 2).to(torch.int8).contiguous()
+    return (q + 2).to(torch.uint8).contiguous(), scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous()
+
+
+def quantize_state_dict_int8(state_dict: dict[str, Tensor]):
+    # Single supported clean-script export format:
+    #   - per-row int8 for 2D float tensors
+    #   - per-tensor int8 for other float tensors
+    #   - exact passthrough for non-floats
+    #   - passthrough for small float tensors, stored as fp16 to save bytes
+    quantized: dict[str, Tensor] = {}
+    scales: dict[str, Tensor] = {}
+    dtypes: dict[str, str] = {}
+    passthrough: dict[str, Tensor] = {}
+    passthrough_orig_dtypes: dict[str, str] = {}
+    qmeta: dict[str, dict[str, object]] = {}
+    stats = dict.fromkeys(
+        ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"),
+        0,
+    )
+
+    for name, tensor in state_dict.items():
+        t = tensor.detach().to("cpu").contiguous()
+        stats["param_count"] += int(t.numel())
+        stats["num_tensors"] += 1
+        stats["baseline_tensor_bytes"] += tensor_nbytes(t)
+
+        if not t.is_floating_point():
+            stats["num_nonfloat_tensors"] += 1
+            passthrough[name] = t
+            stats["int8_payload_bytes"] += tensor_nbytes(t)
+            continue
+
+        # Small float tensors are cheap enough to keep directly. We still downcast
+        # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size.
+        if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL:
+            kept = keep_float_tensor(name, t, passthrough_orig_dtypes)
+            passthrough[name] = kept
+            stats["int8_payload_bytes"] += tensor_nbytes(kept)
+            continue
+
+        stats["num_float_tensors"] += 1
+        q, s = quantize_float_tensor(t)
+        if s.ndim > 0:
+            qmeta[name] = {"scheme": "per_row", "axis": 0}
+        quantized[name] = q
+        scales[name] = s
+        dtypes[name] = str(t.dtype).removeprefix("torch.")
+        stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s)
+
+    obj: dict[str, object] = {
+        "__quant_format__": "int8_clean_per_row_v1",
+        "quantized": quantized,
+        "scales": scales,
+        "dtypes": dtypes,
+        "passthrough": passthrough,
+    }
+    if qmeta:
+        obj["qmeta"] = qmeta
+    if passthrough_orig_dtypes:
+        obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes
+    return obj, stats
+
+
+def quantize_state_dict_int5_odd(state_dict: dict[str, Tensor]):
+    packed: dict[str, bytes] = {}
+    scales: dict[str, Tensor] = {}
+    dtypes: dict[str, str] = {}
+    shapes: dict[str, tuple[int, ...]] = {}
+    orig_shapes: dict[str, tuple[int, ...]] = {}
+    n_codes: dict[str, int] = {}
+    passthrough: dict[str, Tensor] = {}
+    passthrough_orig_dtypes: dict[str, str] = {}
+    stats = dict.fromkeys(
+        ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"),
+        0,
+    )
+
+    for name, tensor in state_dict.items():
+        t = tensor.detach().to("cpu").contiguous()
+        stats["param_count"] += int(t.numel())
+        stats["num_tensors"] += 1
+        stats["baseline_tensor_bytes"] += tensor_nbytes(t)
+
+        if not t.is_floating_point():
+            stats["num_nonfloat_tensors"] += 1
+            passthrough[name] = t
+            stats["int8_payload_bytes"] += tensor_nbytes(t)
+            continue
+
+        quantize_this = t.ndim >= 2 and (t.numel() > INT8_KEEP_FLOAT_MAX_NUMEL or ".mlp." in name)
+        if not quantize_this:
+            kept = keep_float_tensor(name, t, passthrough_orig_dtypes)
+            passthrough[name] = kept
+            stats["int8_payload_bytes"] += tensor_nbytes(kept)
+            continue
+
+        stats["num_float_tensors"] += 1
+        t2 = t.reshape(t.shape[0], -1) if t.ndim > 2 else t
+        q, s = quantize_float_tensor_int5_odd(t2)
+        packed_bytes, count = pack_base5_codes(q)
+        packed[name] = packed_bytes
+        scales[name] = s
+        dtypes[name] = str(t.dtype).removeprefix("torch.")
+        shapes[name] = tuple(int(v) for v in t2.shape)
+        orig_shapes[name] = tuple(int(v) for v in t.shape)
+        n_codes[name] = count
+        stats["int8_payload_bytes"] += len(packed_bytes) + tensor_nbytes(s)
+
+    obj: dict[str, object] = {
+        "__quant_format__": "int5_odd_clean_per_row_v1",
+        "packed": packed,
+        "scales": scales,
+        "dtypes": dtypes,
+        "shapes": shapes,
+        "orig_shapes": orig_shapes,
+        "n_codes": n_codes,
+        "passthrough": passthrough,
+    }
+    if passthrough_orig_dtypes:
+        obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes
+    return obj, stats
+
+
+def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]:
+    out: dict[str, Tensor] = {}
+    qmeta = obj.get("qmeta", {})
+    passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {})
+    for name, q in obj["quantized"].items():
+        dtype = getattr(torch, obj["dtypes"][name])
+        s = obj["scales"][name]
+        if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0:
+            s = s.to(dtype=torch.float32)
+            # Broadcast the saved row scale back across trailing dimensions.
+            out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous()
+        else:
+            scale = float(s.item())
+            out[name] = (q.float() * scale).to(dtype=dtype).contiguous()
+    for name, t in obj["passthrough"].items():
+        # Restore small tensors, undoing the temporary fp16 storage cast if needed.
+        out_t = t.detach().to("cpu").contiguous()
+        orig_dtype = passthrough_orig_dtypes.get(name)
+        if isinstance(orig_dtype, str):
+            out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous()
+        out[name] = out_t
+    return out
+
+
+def dequantize_state_dict_int5_odd(obj: dict[str, object]) -> dict[str, Tensor]:
+    out: dict[str, Tensor] = {}
+    passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {})
+    for name, packed in obj["packed"].items():
+        dtype = getattr(torch, obj["dtypes"][name])
+        shape = tuple(int(v) for v in obj["shapes"][name])
+        orig_shape = tuple(int(v) for v in obj.get("orig_shapes", {}).get(name, shape))
+        # Codes 0..4 map back to the odd grid {-2,-1,0,1,2} via the -2 offset.
+        q = unpack_base5_codes(packed, int(obj["n_codes"][name])).to(torch.float32).reshape(shape) - 2.0
+        s = obj["scales"][name].to(dtype=torch.float32)
+        deq = (q * s.view(shape[0], *([1] * (len(shape) - 1)))).to(dtype=dtype)
+        out[name] = deq.reshape(orig_shape).contiguous()
+    for name, t in obj["passthrough"].items():
+        out_t = t.detach().to("cpu").contiguous()
+        orig_dtype = passthrough_orig_dtypes.get(name)
+        if isinstance(orig_dtype, str):
+            out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous()
+        out[name] = out_t
+    return out
+
+
+def quantize_state_dict(state_dict: dict[str, Tensor], quant_scheme: str):
+    if quant_scheme == "int8":
+        return quantize_state_dict_int8(state_dict)
+    if quant_scheme == "int5_odd":
+        return quantize_state_dict_int5_odd(state_dict)
+    raise ValueError(f"Unsupported QUANT_SCHEME={quant_scheme}")
+
+
+def dequantize_state_dict(obj: dict[str, object]) -> dict[str, Tensor]:
+    quant_format = obj.get("__quant_format__")
+    if quant_format == "int8_clean_per_row_v1":
+        return dequantize_state_dict_int8(obj)
+    if quant_format == "int5_odd_clean_per_row_v1":
+        return dequantize_state_dict_int5_odd(obj)
+    raise ValueError(f"Unsupported quant format {quant_format}")
+
+
+def compress_blob(data: bytes, compressor: str, level: int) -> bytes:
+    if compressor == "zlib":
+        return zlib.compress(data, level=max(0, min(level, 9)))
+    if compressor == "zstd":
+        if not _HAS_ZSTD:
+            raise RuntimeError("COMPRESSOR=zstd requires zstandard to be installed")
+        return zstd.ZstdCompressor(level=level).compress(data)
+    raise ValueError(f"Unsupported COMPRESSOR={compressor}")
+
+
+def decompress_blob(data: bytes, compressor: str) -> bytes:
+    if compressor == "zlib":
+        return zlib.decompress(data)
+    if compressor == "zstd":
+        if not _HAS_ZSTD:
+            raise RuntimeError("COMPRESSOR=zstd requires zstandard to be installed")
+        return zstd.ZstdDecompressor().decompress(data)
+    raise ValueError(f"Unsupported COMPRESSOR={compressor}")
+
+
+def should_rate_regularize_tensor(name: str, p: Tensor, target: str) -> bool:
+    if not p.requires_grad or not p.is_floating_point() or p.ndim < 2:
+        return False
+    if any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS):
+        return False
+    if target == "mlp_only":
+        return ".mlp." in name
+    if target == "all_matrices":
+        return p.numel() > INT8_KEEP_FLOAT_MAX_NUMEL or ".mlp." in name
+    raise ValueError(f"Unsupported RATE_TARGET_TENSORS={target}")
+
+
+def estimate_tensor_rate_proxy(t: Tensor, temperature: float) -> tuple[Tensor, Tensor, Tensor]:
+    t2 = t.float()
+    if t2.ndim > 2:
+        t2 = t2.reshape(t2.shape[0], -1)
+    elif t2.ndim == 1:
+        t2 = t2.reshape(1, -1)
+    row_scale = t2.abs().mean(dim=1, keepdim=True).clamp_min(1e-6)
+    normalized = torch.clamp(t2 / row_scale, -2.5, 2.5)
+    centers = t2.new_tensor(INT5_ODD_LEVELS)
+    # Soft-assign each weight to the 5 grid levels; the entropy of the mean
+    # assignment is a differentiable proxy for post-quantization bits/weight.
+    logits = -temperature * (normalized.unsqueeze(-1) - centers) ** 2
+    probs = torch.softmax(logits, dim=-1)
+    hist = probs.mean(dim=(0, 1))
+    entropy_nats = -(hist * (hist + 1e-8).log()).sum()
+    entropy_bits = entropy_nats / math.log(2.0)
+    scale_proxy = row_scale.mean()
+    est_bits = entropy_bits * float(t2.numel())
+    return entropy_bits, scale_proxy, est_bits
+
+
+# -----------------------------
+# DATA LOADING
+# -----------------------------
+
+def load_data_shard(file: Path) -> Tensor:
+    header_bytes = 256 * np.dtype("<i4").itemsize
+    with file.open("rb") as f:
+        header = np.frombuffer(f.read(header_bytes), dtype="<i4")
+        num_tokens = int(header[2])
+        tokens = np.frombuffer(f.read(2 * num_tokens), dtype="<u2")
+    if tokens.size != num_tokens:
+        raise ValueError(f"Shard {file} truncated: expected {num_tokens} tokens, got {tokens.size}")
+    return torch.from_numpy(tokens.astype(np.int32))
+
+
+class TokenStream:
+    # Wraps all shard files matching `pattern` as one endless flat token stream.
+    def __init__(self, pattern: str):
+        path = Path(pattern)
+        self.files = sorted(path.parent.glob(path.name))
+        if not self.files:
+            raise FileNotFoundError(f"No data shards match pattern {pattern}")
+        self.file_idx = 0
+        self.tokens = load_data_shard(self.files[self.file_idx])
+        self.pos = 0
+
+    def _advance_file(self) -> None:
+        self.file_idx = (self.file_idx + 1) % len(self.files)
+        self.tokens = load_data_shard(self.files[self.file_idx])
+        self.pos = 0
+
+    def take(self, n: int) -> Tensor:
+        chunks: list[Tensor] = []
+        remaining = n
+        while remaining > 0:
+            avail = self.tokens.numel() - self.pos
+            if avail <= 0:
+                self._advance_file()
+                continue
+            k = min(remaining, avail)
+            chunks.append(self.tokens[self.pos : self.pos + k])
+            self.pos += k
+            remaining -= k
+        return chunks[0] if len(chunks) == 1 else torch.cat(chunks)
+
+
+class DistributedTokenLoader:
+    # Each call consumes a contiguous chunk from the shared token stream, then slices out
+    # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting.
+    # Example: world_size=8, global_tokens=131072, grad_accum_steps=1 gives 16384 tokens
+    # per rank; rank r reads span [r * 16385, (r + 1) * 16385) of the shared chunk.
+    def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device):
+        self.rank = rank
+        self.world_size = world_size
+        self.device = device
+        self.stream = TokenStream(pattern)
+
+    def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]:
+        local_tokens = global_tokens // (self.world_size * grad_accum_steps)
+        per_rank_span = local_tokens + 1
+        chunk = self.stream.take(per_rank_span * self.world_size)
+        start = self.rank * per_rank_span
+        local = chunk[start : start + per_rank_span].to(dtype=torch.int64)
+        x = local[:-1].reshape(-1, seq_len)
+        y = local[1:].reshape(-1, seq_len)
+        return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True)
+
+
+def debug_log_data_sharding(
+    pattern: str,
+    world_size: int,
+    global_tokens: int,
+    seq_len: int,
+    grad_accum_steps: int,
+    steps: int,
+    log0,
+) -> None:
+    stream = TokenStream(pattern)
+    local_tokens = global_tokens // (world_size * grad_accum_steps)
+    per_rank_span = local_tokens + 1
+    global_cursor = 0
+    log0(
+        f"debug_sharding:world_size:{world_size} global_tokens:{global_tokens} "
+        f"grad_accum_steps:{grad_accum_steps} local_tokens_per_rank:{local_tokens} per_rank_span:{per_rank_span}"
+    )
+    for step_idx in range(steps):
+        chunk = stream.take(per_rank_span * world_size)
+        chunk_start = global_cursor
+        chunk_end = global_cursor + chunk.numel()
+        log0(f"debug_sharding_step:{step_idx} shared_chunk_token_range:[{chunk_start},{chunk_end})")
+        for shard_rank in range(world_size):
+            start = shard_rank * per_rank_span
+            end = start + per_rank_span
+            local = chunk[start:end]
+            preview = ",".join(str(int(t)) for t in local[: min(6, local.numel())].tolist())
+            log0(
+                f"debug_shard rank:{shard_rank} token_range:[{chunk_start + start},{chunk_start + end}) "
+                f"preview:{preview}"
+            )
+        global_cursor = chunk_end
+
+
+# -----------------------------
+# TRANSFORMER MODULES
+# -----------------------------
+
+class RMSNorm(nn.Module):
+    def __init__(self, eps: float | None = None):
+        super().__init__()
+        self.eps = eps
+
+    def forward(self, x: Tensor) -> Tensor:
+        return F.rms_norm(x, (x.size(-1),), eps=self.eps)
+
+
+class CastedLinear(nn.Linear):
+    # Keep weights in fp32 for optimizer/state quality, cast at matmul time for bf16 compute.
+    def __init__(self, in_features: int, out_features: int, bias: bool = True):
+        super().__init__(in_features, out_features, bias=bias)
+        self._cached_weight: Tensor | None = None
+        self._cached_bias: Tensor | None = None
+        self._cached_key: tuple[torch.dtype, torch.device, int, int | None] | None = None
+
+    def _cast_params(self, x: Tensor) -> tuple[Tensor, Tensor | None]:
+        bias_version = self.bias._version if self.bias is not None else None
+        cache_key = (x.dtype, x.device, self.weight._version, bias_version)
+        if torch.is_inference_mode_enabled() and self._cached_key == cache_key:
+            return self._cached_weight, self._cached_bias
+        weight = self.weight.to(dtype=x.dtype)
+        bias = self.bias.to(x.dtype) if self.bias is not None else None
+        if torch.is_inference_mode_enabled():
+            self._cached_weight = weight
+            self._cached_bias = bias
+            self._cached_key = cache_key
+        return weight, bias
+
+    def train(self, mode: bool = True):
+        if mode:
+            self._cached_weight = None
+            self._cached_bias = None
+            self._cached_key = None
+        return super().train(mode)
+
+    def forward(self, x: Tensor) -> Tensor:
+        weight, bias = self._cast_params(x)
+        return F.linear(x, weight, bias)
+
+
+class StructuredLinear(nn.Module):
+    # Lightweight TT-style structured matrix used only in the MLP stack.
+    # With in_shape (n1, n2), out_shape (m1, m2) and rank r, the two cores hold
+    # r * (n1 * m1 + n2 * m2) parameters versus n1 * n2 * m1 * m2 dense, e.g.
+    # 256 -> 1024 with shapes (16, 16) / (32, 32) and r=64 is 65536 vs 262144.
+    def __init__(
+        self,
+        in_features: int,
+        out_features: int,
+        bias: bool,
+        linear_type: str,
+        btt_rank: int,
+        in_shape_spec: str,
+        out_shape_spec: str,
+        btt_init: str,
+        btt_init_gain: float,
+        chunk_ranks: int,
+        eval_chunk_ranks: int,
+    ):
+        super().__init__()
+        self.in_features = in_features
+        self.out_features = out_features
+        self.linear_type = linear_type
+        self.rank = max(int(btt_rank), 1)
+        self.chunk_ranks = max(int(chunk_ranks), 1)
+        self.eval_chunk_ranks = max(int(eval_chunk_ranks), 1)
+        self.btt_init = btt_init
+        self.btt_init_gain = btt_init_gain
+        self._cached_core1: Tensor | None = None
+        self._cached_core2: Tensor | None = None
+        self._cached_bias: Tensor | None = None
+        self._cached_key: tuple[torch.dtype, torch.device, int, int, int | None] | None = None
+        self._dense_eval_weight: Tensor | None = None
+        self._dense_eval_bias: Tensor | None = None
+        self._dense_eval_key: tuple[torch.dtype, torch.device, int, int, int | None] | None = None
+        if linear_type == "dense":
+            self.impl = CastedLinear(in_features, out_features, bias=bias)
+            self.register_parameter("core1", None)
+            self.register_parameter("core2", None)
+            return
+        if linear_type != "btt_mlp":
+            raise ValueError(f"Unsupported STRUCTURED_LINEAR_TYPE={linear_type}")
+        self.impl = None
+        self.in_shape = parse_factor_shape(in_shape_spec, in_features)
+        self.out_shape = parse_factor_shape(out_shape_spec, out_features)
+        n1, n2 = self.in_shape
+        m1, m2 = self.out_shape
+        self.core1 = nn.Parameter(torch.empty(n1, m1, self.rank))
+        self.core2 = nn.Parameter(torch.empty(self.rank, n2, m2))
+        if bias:
+            self.bias = nn.Parameter(torch.zeros(out_features))
+        else:
+            self.register_parameter("bias", None)
+        self.reset_parameters()
+
+    def reset_parameters(self) -> None:
+        if self.impl is not None:
+            self.impl.reset_parameters()
+            return
+        n1, n2 = self.in_shape
+        if self.btt_init == "mup":
+            # muP-inspired scaling: keep the per-rank branch update stable while
+            # preserving O(1) output variance under the sum over TT ranks.
+            core1_std = self.btt_init_gain / math.sqrt(max(n1, 1))
+            core2_std = self.btt_init_gain / math.sqrt(max(self.rank * n2, 1))
+        elif self.btt_init == "xavier":
+            std = self.btt_init_gain / math.sqrt(max(self.in_features, 1))
+            core1_std = std / math.sqrt(self.rank)
+            core2_std = std
+        else:
+            raise ValueError(f"Unsupported BTT_INIT={self.btt_init}")
+        nn.init.normal_(self.core1, mean=0.0, std=core1_std)
+        nn.init.normal_(self.core2, mean=0.0, std=core2_std)
+        if self.bias is not None:
+            nn.init.zeros_(self.bias)
+
+    def _cast_factors(self, x: Tensor) -> tuple[Tensor, Tensor, Tensor | None]:
+        bias_version = self.bias._version if self.bias is not None else None
+        cache_key = (x.dtype, x.device, self.core1._version, self.core2._version, bias_version)
+        if torch.is_inference_mode_enabled() and self._cached_key == cache_key:
+            return self._cached_core1, self._cached_core2, self._cached_bias
+        core1 = self.core1.to(dtype=x.dtype).permute(2, 0, 1).contiguous()
+        core2 = self.core2.to(dtype=x.dtype).contiguous()
+        bias = self.bias.to(dtype=x.dtype) if self.bias is not None else None
+        if torch.is_inference_mode_enabled():
+            self._cached_core1 = core1
+            self._cached_core2 = core2
+            self._cached_bias = bias
+            self._cached_key = cache_key
+        return core1, core2, bias
+
+    def materialize(self, dtype: torch.dtype, device: torch.device) -> None:
+        if self.impl is not None:
+            return
+        bias_version = self.bias._version if self.bias is not None else None
+        cache_key = (dtype, device, self.core1._version, self.core2._version, bias_version)
+        if self._dense_eval_key == cache_key and self._dense_eval_weight is not None:
+            return
+        core1 = self.core1.to(device=device, dtype=dtype)
+        core2 = self.core2.to(device=device, dtype=dtype)
+        # dense[m1, m2, n1, n2] = sum_r core1[n1, m1, r] * core2[r, n2, m2],
+        # reshaped into the (out_features, in_features) layout F.linear expects.
+        dense = torch.einsum("umr,rvn->mnuv", core1, core2).reshape(self.out_features, self.in_features).contiguous()
+        self._dense_eval_weight = dense
+        self._dense_eval_bias = self.bias.to(device=device, dtype=dtype) if self.bias is not None else None
+        self._dense_eval_key = cache_key
+
+    def train(self, mode: bool = True):
+        if mode:
+            self._cached_core1 = None
+            self._cached_core2 = None
+            self._cached_bias = None
+            self._cached_key = None
+            self._dense_eval_weight = None
+            self._dense_eval_bias = None
+            self._dense_eval_key = None
+        return super().train(mode)
+
+    def forward(self, x: Tensor) -> Tensor:
+        if self.impl is not None:
+            return self.impl(x)
+        if not self.training:
+            self.materialize(x.dtype, x.device)
+            return F.linear(x, self._dense_eval_weight, self._dense_eval_bias)
+        orig_shape = x.shape[:-1]
+        x2 = x.reshape(-1, self.in_shape[0], self.in_shape[1])
+        x2_t = x2.transpose(1, 2).contiguous()
+        core1, core2, bias = self._cast_factors(x)
+        y = x2.new_zeros((x2.shape[0], self.out_shape[0], self.out_shape[1]))
+        for start in range(0, self.rank, self.chunk_ranks):
+            end = min(start + self.chunk_ranks, self.rank)
+            for offset, core2_r in enumerate(core2[start:end], start=start):
+                left = torch.matmul(x2_t, core1[offset])
+                y = y + torch.matmul(left.transpose(1, 2), core2_r)
+        y = y.reshape(*orig_shape, self.out_features)
+        if bias is not None:
+            y = y + bias
+        return y
+
+
+def restore_low_dim_params_to_fp32(module: nn.Module) -> None:
+    # Keep small/control parameters in fp32 even when the model body runs in bf16.
+    with torch.no_grad():
+        for name, param in module.named_parameters():
+            if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32:
+                param.data = param.data.float()
+
+
+class Rotary(nn.Module):
+    # Caches cos/sin tables per sequence length on the current device.
+    def __init__(self, dim: int, base: float = 10000.0):
+        super().__init__()
+        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
+        self.register_buffer("inv_freq", inv_freq, persistent=False)
+        self._seq_len_cached = 0
+        self._cos_cached: Tensor | None = None
+        self._sin_cached: Tensor | None = None
+
+    def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]:
+        if (
+            self._cos_cached is None
+            or self._sin_cached is None
+            or self._seq_len_cached != seq_len
+            or self._cos_cached.device != device
+        ):
+            t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype)
+            freqs = torch.outer(t, self.inv_freq.to(device))
+            self._cos_cached = freqs.cos()[None, None, :, :]
+            self._sin_cached = freqs.sin()[None, None, :, :]
+            self._seq_len_cached = seq_len
+        cos = self._cos_cached.to(dtype=dtype)
+        sin = self._sin_cached.to(dtype=dtype)
+        if not torch.is_inference_mode_enabled():
+            cos = cos.clone()
+            sin = sin.clone()
+        return cos, sin
+
+
+def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor:
+    half = x.size(-1) // 2
+    x1, x2 = x[..., :half], x[..., half:]
+    return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1)
+
+
+class CausalSelfAttention(nn.Module):
+    def __init__(
+        self,
+        dim: int,
+        num_heads: int,
+        num_kv_heads: int,
+        rope_base: float,
+        qk_gain_init: float,
+    ):
+        super().__init__()
+        if dim % num_heads != 0:
+            raise ValueError("model_dim must be divisible by num_heads")
+        if num_heads % num_kv_heads != 0:
+            raise ValueError("num_heads must be divisible by num_kv_heads")
+        self.num_heads = num_heads
+        self.num_kv_heads = num_kv_heads
+        self.head_dim = dim // num_heads
+        if self.head_dim % 2 != 0:
+            raise ValueError("head_dim must be even for RoPE")
+        kv_dim = self.num_kv_heads * self.head_dim
+        self.c_q = CastedLinear(dim, dim, bias=False)
+        self.c_k = CastedLinear(dim, kv_dim, bias=False)
+        self.c_v = CastedLinear(dim, kv_dim, bias=False)
+        self.proj = CastedLinear(dim, dim, bias=False)
+        self.proj._zero_init = True
+        self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32))
+        self.rotary = Rotary(self.head_dim, base=rope_base)
+
+    def forward(self, x: Tensor) -> Tensor:
+        bsz, seqlen, dim = x.shape
+        q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2)
+        k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2)
+        v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2)
+        q = F.rms_norm(q, (q.size(-1),))
+        k = F.rms_norm(k, (k.size(-1),))
+        cos, sin = self.rotary(seqlen, x.device, q.dtype)
+        q = apply_rotary_emb(q, cos, sin)
+        k = apply_rotary_emb(k, cos, sin)
+        q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None]
+        y = F.scaled_dot_product_attention(
+            q,
+            k,
+            v,
+            attn_mask=None,
+            is_causal=True,
+            enable_gqa=(self.num_kv_heads != self.num_heads),
+        )
+        y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim)
+        return self.proj(y)
+
+
+class MLP(nn.Module):
+    # relu^2 MLP from the original modded-nanogpt setup.
+    def __init__(
+        self,
+        dim: int,
+        mlp_mult: int,
+        structured_mlp: bool,
+        structured_linear_type: str,
+        btt_rank: int,
+        btt_in_shape: str,
+        btt_out_shape: str,
+        btt_init: str,
+        btt_init_gain: float,
+        compile_structured_mlp: bool,
+        structured_chunk_ranks: int,
+        structured_eval_chunk_ranks: int,
+        structured_compile_mode: str,
+    ):
+        super().__init__()
+        hidden = mlp_mult * dim
+        if structured_mlp:
+            self.fc = StructuredLinear(
+                dim,
+                hidden,
+                bias=False,
+                linear_type=structured_linear_type,
+                btt_rank=btt_rank,
+                in_shape_spec=btt_in_shape,
+                out_shape_spec=btt_out_shape,
+                btt_init=btt_init,
+                btt_init_gain=btt_init_gain,
+                chunk_ranks=structured_chunk_ranks,
+                eval_chunk_ranks=structured_eval_chunk_ranks,
+            )
+            self.proj = StructuredLinear(
+                hidden,
+                dim,
+                bias=False,
+                linear_type=structured_linear_type,
+                btt_rank=btt_rank,
+                in_shape_spec=btt_out_shape,
+                out_shape_spec=btt_in_shape,
+                btt_init=btt_init,
+                btt_init_gain=btt_init_gain,
+                chunk_ranks=structured_chunk_ranks,
+                eval_chunk_ranks=structured_eval_chunk_ranks,
+            )
+            # Mark before any torch.compile wrapping: _init_weights walks modules()
+            # looking for this flag on the StructuredLinear itself, and setting it on
+            # the compiled wrapper instead would silently skip the zero init.
+            self.proj._zero_init = True
+            if compile_structured_mlp:
+                compile_kwargs = dict(fullgraph=False, dynamic=False)
+                if structured_compile_mode == "reduce-overhead":
+                    compile_options = dict(torch._inductor.list_mode_options("reduce-overhead"))
+                    compile_options["triton.cudagraphs"] = False
+                    compile_kwargs["options"] = compile_options
+                else:
+                    compile_kwargs["mode"] = structured_compile_mode
+                self.fc = torch.compile(self.fc, **compile_kwargs)
+                self.proj = torch.compile(self.proj, **compile_kwargs)
+        else:
+            self.fc = CastedLinear(dim, hidden, bias=False)
+            self.proj = CastedLinear(hidden, dim, bias=False)
+            self.proj._zero_init = True
+
+    def forward(self, x: Tensor) -> Tensor:
+        x = torch.relu(self.fc(x))
+        return self.proj(x.square())
+
+
+class Block(nn.Module):
+    def __init__(
+        self,
+        dim: int,
+        num_heads: int,
+        num_kv_heads: int,
+        mlp_mult: int,
+        rope_base: float,
+        qk_gain_init: float,
+        structured_mlp: bool,
+        structured_linear_type: str,
+        btt_rank: int,
+        btt_in_shape: str,
+        btt_out_shape: str,
+        btt_init: str,
+        btt_init_gain: float,
+        compile_structured_mlp: bool,
+        structured_chunk_ranks: int,
+        structured_eval_chunk_ranks: int,
+        structured_compile_mode: str,
+    ):
+        super().__init__()
+        self.attn_norm = RMSNorm()
+        self.mlp_norm = RMSNorm()
+        self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init)
+        self.mlp = MLP(
+            dim,
+            mlp_mult,
+            structured_mlp,
+            structured_linear_type,
+            btt_rank,
+            btt_in_shape,
+            btt_out_shape,
+            btt_init,
+            btt_init_gain,
+            compile_structured_mlp,
+            structured_chunk_ranks,
+            structured_eval_chunk_ranks,
+            structured_compile_mode,
+        )
+        self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32))
+        self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32))
+        self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float())
+
+    def forward(self, x: Tensor, x0: Tensor) -> Tensor:
+        mix = self.resid_mix.to(dtype=x.dtype)
+        x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0
+        attn_out = self.attn(self.attn_norm(x))
+        x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out
+        x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x))
+        return x
+
+
+class GPT(nn.Module):
+    def __init__(
+        self,
+        vocab_size: int,
+        num_layers: int,
+        model_dim: int,
+        num_heads: int,
+        num_kv_heads: int,
+        mlp_mult: int,
+        tie_embeddings: bool,
+        tied_embed_init_std: float,
+        logit_softcap: float,
+        rope_base: float,
+        qk_gain_init: float,
+        structured_mlp: bool,
+        structured_linear_type: str,
+        btt_rank: int,
+        btt_in_shape: str,
+        btt_out_shape: str,
+        btt_init: str,
+        btt_init_gain: float,
+        compile_structured_mlp: bool,
+        structured_chunk_ranks: int,
+        structured_eval_chunk_ranks: int,
+        structured_compile_mode: str,
+        rate_target_tensors: str,
+        rate_temperature: float,
+    ):
+        super().__init__()
+        if logit_softcap <= 0.0:
+            raise ValueError(f"logit_softcap must be positive, got {logit_softcap}")
+        self.tie_embeddings = tie_embeddings
+        self.tied_embed_init_std = tied_embed_init_std
+        self.logit_softcap = logit_softcap
+        self.rate_target_tensors = rate_target_tensors
+        self.rate_temperature = rate_temperature
+        self.tok_emb = nn.Embedding(vocab_size, model_dim)
+        self.num_encoder_layers = num_layers // 2
+        self.num_decoder_layers = num_layers - self.num_encoder_layers
+        self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers)
+        self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32))
+        self.blocks = nn.ModuleList(
+            [
+                Block(
+                    model_dim,
+                    num_heads,
+                    num_kv_heads,
+                    mlp_mult,
+                    rope_base,
+                    qk_gain_init,
+                    structured_mlp,
+                    structured_linear_type,
+                    btt_rank,
+                    btt_in_shape,
+                    btt_out_shape,
+                    btt_init,
+                    btt_init_gain,
+                    compile_structured_mlp,
+                    structured_chunk_ranks,
+                    structured_eval_chunk_ranks,
+                    structured_compile_mode,
+                )
+                for _ in range(num_layers)
+            ]
+        )
+        self.final_norm = RMSNorm()
+        self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False)
+        if self.lm_head is not None:
+            self.lm_head._zero_init = True
+        self._init_weights()
+
+    def _init_weights(self) -> None:
+        if self.tie_embeddings:
+            nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std)
+        for module in self.modules():
+            if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False):
+                nn.init.zeros_(module.weight)
+            if isinstance(module, StructuredLinear) and getattr(module, "_zero_init", False):
+                if module.impl is not None:
+                    nn.init.zeros_(module.impl.weight)
+                else:
+                    nn.init.zeros_(module.core2)
+
+    def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor:
+        x = self.tok_emb(input_ids)
+        x = F.rms_norm(x, (x.size(-1),))
+        x0 = x
+        skips: list[Tensor] = []
+
+        # First half stores skips; second half reuses them in reverse order, so
+        # decoder layer 0 is paired with the deepest encoder layer (U-Net layout).
+        for i in range(self.num_encoder_layers):
+            x = self.blocks[i](x, x0)
+            skips.append(x)
+        for i in range(self.num_decoder_layers):
+            if skips:
+                x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop()
+            x = self.blocks[self.num_encoder_layers + i](x, x0)
+
+        x = self.final_norm(x).reshape(-1, x.size(-1))
+        targets = target_ids.reshape(-1)
+        if self.tie_embeddings:
+            logits_proj = F.linear(x, self.tok_emb.weight)
+        else:
+            if self.lm_head is None:
+                raise RuntimeError("lm_head is required when tie_embeddings=False")
+            logits_proj = self.lm_head(x)
+        logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap)
+        return F.cross_entropy(logits.float(), targets, reduction="mean")
+
+    def rate_regularizer(self) -> tuple[Tensor, Tensor, Tensor]:
+        total_symbols = 0.0
+        weighted_rate = None
+        weighted_scale = None
+        estimated_bits = None
+        for name, param in self.named_parameters():
+            if not should_rate_regularize_tensor(name, param, self.rate_target_tensors):
+                continue
+            rate_proxy, scale_proxy, tensor_bits = estimate_tensor_rate_proxy(param, self.rate_temperature)
+            weight = float(param.numel())
+            total_symbols += weight
+            if weighted_rate is None:
+                weighted_rate = rate_proxy * weight
+                weighted_scale = scale_proxy * weight
+                estimated_bits = tensor_bits
+            else:
+                weighted_rate = weighted_rate + rate_proxy * weight
+                weighted_scale = weighted_scale + scale_proxy * weight
+                estimated_bits = estimated_bits + tensor_bits
+        if weighted_rate is None or weighted_scale is None or estimated_bits is None:
+            zero = self.tok_emb.weight.new_zeros(())
+            return zero, zero, zero
+        denom = max(total_symbols, 1.0)
+        return weighted_rate / denom, weighted_scale / denom, estimated_bits
+
+
+# -----------------------------
+# TRAINING
+# -----------------------------
+
+def main() -> None:
+    global zeropower_via_newtonschulz5
+
+    code = Path(__file__).read_text(encoding="utf-8")
+    args = Hyperparameters()
+    if args.enable_compile:
+        zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5)
+
+    # -----------------------------
+    # DISTRIBUTED + CUDA SETUP
+    # -----------------------------
+
+    distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ
+    rank = int(os.environ.get("RANK", "0"))
+    world_size = int(os.environ.get("WORLD_SIZE", "1"))
+    local_rank = int(os.environ.get("LOCAL_RANK", "0"))
+    simulated_world_size = args.simulated_world_size if not distributed and args.simulated_world_size > 0 else 0
+    if simulated_world_size > 0:
+        world_size = simulated_world_size
+        rank = args.simulated_rank
+        local_rank = 0
+    if world_size <= 0:
+        raise ValueError(f"WORLD_SIZE must be positive, got {world_size}")
+    if rank < 0 or rank >= world_size:
+        raise ValueError(f"RANK={rank} must satisfy 0 <= RANK < WORLD_SIZE={world_size}")
+    if simulated_world_size > 1 and not args.dry_run_init_only:
+        raise ValueError("SIMULATED_WORLD_SIZE>1 is only supported together with DRY_RUN_INIT_ONLY=1")
+    if args.train_batch_tokens <= 0:
+        raise ValueError(f"TRAIN_BATCH_TOKENS must be positive, got {args.train_batch_tokens}")
+    train_microbatch_tokens = args.train_microbatch_tokens if args.train_microbatch_tokens > 0 else max(
+        args.train_seq_len, args.train_batch_tokens // 8
+    )
+    if train_microbatch_tokens <= 0:
+        raise ValueError(f"TRAIN_MICROBATCH_TOKENS must be positive, got {train_microbatch_tokens}")
+    if train_microbatch_tokens % args.train_seq_len != 0:
+        raise ValueError(
+            f"TRAIN_MICROBATCH_TOKENS={train_microbatch_tokens} must be divisible by TRAIN_SEQ_LEN={args.train_seq_len}"
+        )
+    tokens_per_microstep = world_size * train_microbatch_tokens
+    if args.grad_accum_steps > 0:
+        grad_accum_steps = args.grad_accum_steps
+        effective_train_batch_tokens = tokens_per_microstep * grad_accum_steps
+    else:
+        if args.train_batch_tokens % tokens_per_microstep != 0:
+            raise ValueError(
+                "TRAIN_BATCH_TOKENS must be divisible by WORLD_SIZE * TRAIN_MICROBATCH_TOKENS "
+                f"unless GRAD_ACCUM_STEPS is 
set explicitly; got TRAIN_BATCH_TOKENS={args.train_batch_tokens}, " + f"WORLD_SIZE={world_size}, TRAIN_MICROBATCH_TOKENS={train_microbatch_tokens}" + ) + grad_accum_steps = args.train_batch_tokens // tokens_per_microstep + effective_train_batch_tokens = args.train_batch_tokens + if grad_accum_steps <= 0: + raise ValueError(f"GRAD_ACCUM_STEPS must be positive, got {grad_accum_steps}") + grad_scale = 1.0 / grad_accum_steps + batch_multiplier = effective_train_batch_tokens / max(args.lr_reference_batch_tokens, 1) + effective_warmup_steps = scaled_warmup_steps( + args.warmup_steps, + effective_train_batch_tokens / max(args.warmup_reference_batch_tokens, 1), + args.warmup_scale_method, + ) + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 0 + + # Fast math knobs + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + + logfile = None + if master_process: + if args.run_id == "train": + logfile = "train.log" + else: + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" + print(logfile) + + def log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch {torch.__version__}", console=False) + log0( + subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, 
check=False).stdout, + console=False, + ) + log0("=" * 100, console=False) + + # ----------------------------- + # TOKENIZER + VALIDATION METRIC SETUP + # ----------------------------- + + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) + if args.val_token_limit > 0: + usable = min(((args.val_token_limit) // args.train_seq_len) * args.train_seq_len, val_tokens.numel() - 1) + if usable <= 0: + raise ValueError(f"VAL_TOKEN_LIMIT={args.val_token_limit} is too small for TRAIN_SEQ_LEN={args.train_seq_len}") + val_tokens = val_tokens[: usable + 1].contiguous() + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + if args.val_token_limit > 0: + log0(f"val_token_limit:{args.val_token_limit}") + + # ----------------------------- + # MODEL + OPTIMIZER SETUP + # ----------------------------- + + base_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + 
tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, + structured_mlp=args.structured_mlp, + structured_linear_type=args.structured_linear_type if args.structured_mlp else "dense", + btt_rank=args.btt_rank, + btt_in_shape=args.btt_in_shape, + btt_out_shape=args.btt_out_shape, + btt_init=args.btt_init, + btt_init_gain=args.btt_init_gain, + compile_structured_mlp=args.compile_structured_mlp, + structured_chunk_ranks=args.structured_chunk_ranks, + structured_eval_chunk_ranks=args.structured_eval_chunk_ranks, + structured_compile_mode=args.structured_compile_mode, + rate_target_tensors=args.rate_target_tensors, + rate_temperature=args.rate_temperature, + ).to(device).bfloat16() + for module in base_model.modules(): + if isinstance(module, (CastedLinear, StructuredLinear)): + module.float() + restore_low_dim_params_to_fp32(base_model) + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) if args.enable_compile else base_model + model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + + # Optimizer split: + # - token embedding (Adam) uses EMBED_LR + # - untied lm_head (Adam) uses HEAD_LR + # - matrix params in transformer blocks use MATRIX_LR via Muon + # - vectors/scalars use SCALAR_LR via Adam + block_named_params = list(base_model.blocks.named_parameters()) + matrix_params = [ + p + for name, p in block_named_params + if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + scalar_params = [ + p + for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + scaled_token_lr = 
scale_from_batch_multiplier(token_lr, batch_multiplier, args.lr_scale_method) + scaled_head_lr = scale_from_batch_multiplier(args.head_lr, batch_multiplier, args.lr_scale_method) + scaled_matrix_lr = scale_from_batch_multiplier(args.matrix_lr, batch_multiplier, args.lr_scale_method) + scaled_scalar_lr = scale_from_batch_multiplier(args.scalar_lr, batch_multiplier, args.lr_scale_method) + optimizer_tok = torch.optim.Adam( + [{"params": [base_model.tok_emb.weight], "lr": scaled_token_lr, "base_lr": scaled_token_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizer_muon = Muon( + matrix_params, + lr=scaled_matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + ) + for group in optimizer_muon.param_groups: + group["base_lr"] = scaled_matrix_lr + optimizer_scalar = torch.optim.Adam( + [{"params": scalar_params, "lr": scaled_scalar_lr, "base_lr": scaled_scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": scaled_head_lr, "base_lr": scaled_head_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers.insert(1, optimizer_head) + + n_params = sum(p.numel() for p in base_model.parameters()) + log0(f"model_params:{n_params}") + log0( + f"world_size:{world_size} grad_accum_steps:{grad_accum_steps} " + f"train_microbatch_tokens:{train_microbatch_tokens} effective_train_batch_tokens:{effective_train_batch_tokens}" + ) + if simulated_world_size > 0: + log0(f"simulated_world_size:{simulated_world_size} simulated_rank:{rank}") + log0(f"enable_compile:{args.enable_compile}") + log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") + log0(f"attention_mode:gqa num_heads:{args.num_heads} 
num_kv_heads:{args.num_kv_heads}") + log0( + f"quant_scheme:{args.quant_scheme} compressor:{args.compressor} level:{args.compressor_level} " + f"structured_mlp:{args.structured_mlp} structured_linear_type:{args.structured_linear_type} " + f"btt_rank:{args.btt_rank} btt_init:{args.btt_init} compile_structured_mlp:{args.compile_structured_mlp} " + f"structured_chunk_ranks:{args.structured_chunk_ranks} structured_eval_chunk_ranks:{args.structured_eval_chunk_ranks} " + f"structured_compile_mode:{args.structured_compile_mode}" + ) + log0( + f"rate_lambda:{args.rate_lambda} scale_lambda:{args.scale_lambda} " + f"rate_warmup_start_fraction:{args.rate_warmup_start_fraction} rate_full_fraction:{args.rate_full_fraction} " + f"rate_target_tensors:{args.rate_target_tensors} rate_schedule_mode:{args.rate_schedule_mode} " + f"max_expected_steps:{args.max_expected_steps}" + ) + log0( + f"tie_embeddings:{args.tie_embeddings} embed_lr:{scaled_token_lr} " + f"head_lr:{scaled_head_lr if base_model.lm_head is not None else 0.0} " + f"matrix_lr:{scaled_matrix_lr} scalar_lr:{scaled_scalar_lr}" + ) + log0( + f"train_batch_tokens_target:{args.train_batch_tokens} effective_train_batch_tokens:{effective_train_batch_tokens} " + f"train_seq_len:{args.train_seq_len} iterations:{args.iterations} warmup_steps:{effective_warmup_steps} " + f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log0( + f"lr_scale_method:{args.lr_scale_method} lr_reference_batch_tokens:{args.lr_reference_batch_tokens} " + f"batch_multiplier:{batch_multiplier:.4f} warmup_scale_method:{args.warmup_scale_method} " + f"warmup_reference_batch_tokens:{args.warmup_reference_batch_tokens}" + ) + log0(f"seed:{args.seed}") + + if args.debug_data_sharding_steps > 0: + debug_log_data_sharding( + args.train_files, + world_size, + effective_train_batch_tokens, + args.train_seq_len, + grad_accum_steps, + args.debug_data_sharding_steps, + log0, + ) + if args.dry_run_init_only: + log0("dry_run_init_only: exiting after 
init/logging") + return + + # ----------------------------- + # DATA LOADER & MODEL WARMUP + # ----------------------------- + + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + + def lr_mul(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + + def rate_progress(step: int, elapsed_ms: float) -> float: + if args.rate_schedule_mode == "steps": + if args.iterations <= 0: + return 1.0 + return step / max(args.iterations, 1) + if args.rate_schedule_mode == "expected_steps": + if args.max_expected_steps <= 0: + raise ValueError("MAX_EXPECTED_STEPS must be positive when RATE_SCHEDULE_MODE=expected_steps") + return step / max(args.max_expected_steps, 1) + if args.rate_schedule_mode == "wallclock": + if max_wallclock_ms is None: + raise ValueError("RATE_SCHEDULE_MODE=wallclock requires MAX_WALLCLOCK_SECONDS > 0") + return elapsed_ms / max(max_wallclock_ms, 1e-9) + if args.rate_schedule_mode != "auto": + raise ValueError(f"Unsupported RATE_SCHEDULE_MODE={args.rate_schedule_mode}") + if args.max_expected_steps > 0: + return step / max(args.max_expected_steps, 1) + if max_wallclock_ms is not None and step > 0: + step_ms = elapsed_ms / max(step, 1) + expected_steps = max(int(max_wallclock_ms / max(step_ms, 1e-9)), 1) + return step / expected_steps + if args.iterations <= 0: + 
return 1.0 + return step / max(args.iterations, 1) + + def rate_schedule(step: int, elapsed_ms: float) -> float: + if args.rate_lambda <= 0.0 and args.scale_lambda <= 0.0: + return 0.0 + frac = rate_progress(step, elapsed_ms) + start = args.rate_warmup_start_fraction + full = max(args.rate_full_fraction, start + 1e-6) + if frac <= start: + return 0.0 + if frac >= full: + return 1.0 + return (frac - start) / (full - start) + + # Warmup primes the compiled forward/backward/optimizer paths, then we restore the + # initial weights/optimizer state so measured training starts from the true init. + if effective_warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(effective_warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(effective_train_batch_tokens, args.train_seq_len, grad_accum_steps) + maybe_mark_compile_step_begin(args.enable_compile or args.compile_structured_mlp) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + for opt in optimizers: + opt.step() + zero_grad_all() + if ( + effective_warmup_steps <= 20 + or (warmup_step + 1) % 10 == 0 + or warmup_step + 1 == effective_warmup_steps + ): + log0(f"warmup_step:{warmup_step + 1}/{effective_warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + if distributed: + model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + # ----------------------------- + # 
MAIN TRAINING LOOP + # ----------------------------- + + training_time_ms = 0.0 + stop_after_step: int | None = None + torch.cuda.synchronize() + t0 = time.perf_counter() + + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + log0( + f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" + ) + break + + model.train() + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + zero_grad_all() + train_loss = torch.zeros((), device=device) + train_ce_loss = torch.zeros((), device=device) + train_rate_proxy = torch.zeros((), device=device) + train_scale_proxy = torch.zeros((), device=device) + train_est_bits = torch.zeros((), device=device) + rate_mul = rate_schedule(step, elapsed_ms) + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(effective_train_batch_tokens, args.train_seq_len, grad_accum_steps) + maybe_mark_compile_step_begin(args.enable_compile or args.compile_structured_mlp) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + ce_loss = model(x, y) + 
total_loss = ce_loss + if rate_mul > 0.0: + rate_proxy, scale_proxy, est_bits = base_model.rate_regularizer() + total_loss = total_loss + rate_mul * (args.rate_lambda * rate_proxy + args.scale_lambda * scale_proxy) + train_rate_proxy += rate_proxy.detach() + train_scale_proxy += scale_proxy.detach() + train_est_bits += est_bits.detach() + train_ce_loss += ce_loss.detach() + train_loss += total_loss.detach() + (total_loss * grad_scale).backward() + train_loss /= grad_accum_steps + train_ce_loss /= grad_accum_steps + train_rate_proxy /= grad_accum_steps + train_scale_proxy /= grad_accum_steps + train_est_bits /= grad_accum_steps + + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + for opt in optimizers: + opt.step() + zero_grad_all() + + step += 1 + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + should_log_train = ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + log0( + f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} ce_loss:{train_ce_loss.item():.4f} " + f"rate_mul:{rate_mul:.3f} rate_proxy_bits:{train_rate_proxy.item():.4f} " + f"scale_proxy:{train_scale_proxy.item():.6f} est_model_bits:{train_est_bits.item():.0f} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" + ) + + # Needed to sync whether we've reached the wallclock cap. 
+ reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + + log0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + ) + + # ----------------------------- + # SERIALIZATION + ROUNDTRIP VALIDATION + # ----------------------------- + # Save the raw state (useful for debugging/loading in PyTorch directly), then always produce + # the quantized, compressed artifact selected by QUANT_SCHEME and COMPRESSOR (int5_odd + zlib + # for this submission) and validate the round-tripped weights. + + if master_process: + torch.save(base_model.state_dict(), "final_model.pt") + model_bytes = os.path.getsize("final_model.pt") + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model: {model_bytes} bytes") + log0(f"Code size: {code_bytes} bytes") + log0(f"Total submission size: {model_bytes + code_bytes} bytes") + + quant_obj, quant_stats = quantize_state_dict(base_model.state_dict(), args.quant_scheme) + quant_buf = io.BytesIO() + torch.save(quant_obj, quant_buf) + quant_raw = quant_buf.getvalue() + quant_blob = compress_blob(quant_raw, args.compressor, args.compressor_level) + quant_raw_bytes = len(quant_raw) + quant_filename = f"final_model.{args.quant_scheme}.{args.compressor}.ptc" + if master_process: + with open(quant_filename, "wb") as f: + f.write(quant_blob) + quant_file_bytes = os.path.getsize(quant_filename) + code_bytes = len(code.encode("utf-8")) + ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1) + log0( + f"Serialized model {args.quant_scheme}+{args.compressor}: {quant_file_bytes} bytes " + f"(payload:{quant_stats['int8_payload_bytes']} raw_torch:{quant_raw_bytes} 
payload_ratio:{ratio:.2f}x)" + ) + log0(f"Total submission size {args.quant_scheme}+{args.compressor}: {quant_file_bytes + code_bytes} bytes") + + if distributed: + dist.barrier() + with open(quant_filename, "rb") as f: + quant_blob_disk = f.read() + quant_state = torch.load(io.BytesIO(decompress_blob(quant_blob_disk, args.compressor)), map_location="cpu") + base_model.load_state_dict(dequantize_state_dict(quant_state), strict=True) + torch.cuda.synchronize() + t_qeval = time.perf_counter() + q_val_loss, q_val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"final_{args.quant_scheme}_{args.compressor}_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" + ) + log0(f"final_{args.quant_scheme}_{args.compressor}_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + + if distributed: + dist.destroy_process_group() + + +if __name__ == "__main__": + main()
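
The batch-accounting and LR-scaling rules that `main()` applies above (infer `grad_accum_steps` from `TRAIN_BATCH_TOKENS` when `GRAD_ACCUM_STEPS` is unset, then rescale optimizer LRs from the reference batch) can be sketched in isolation. This is an illustrative sketch: the helper names below are invented here (they are not the script's own `scale_from_batch_multiplier`/`scaled_warmup_steps`), and the sqrt/linear semantics are assumed from the README's description of `LR_SCALE_METHOD`.

```python
import math

# Hypothetical helpers mirroring the accounting in main(); names invented here.
# Each microstep processes WORLD_SIZE * TRAIN_MICROBATCH_TOKENS tokens, so the
# number of accumulation steps follows from the global TRAIN_BATCH_TOKENS.
def infer_grad_accum_steps(train_batch_tokens: int, world_size: int, microbatch_tokens: int) -> int:
    tokens_per_microstep = world_size * microbatch_tokens
    if train_batch_tokens % tokens_per_microstep != 0:
        raise ValueError("TRAIN_BATCH_TOKENS must be divisible by WORLD_SIZE * TRAIN_MICROBATCH_TOKENS")
    return train_batch_tokens // tokens_per_microstep

# Assumed LR_SCALE_METHOD semantics: scale the base LR by the ratio of the
# effective batch to LR_REFERENCE_BATCH_TOKENS (or by its square root).
def scale_lr(base_lr: float, batch_multiplier: float, method: str) -> float:
    if method == "none":
        return base_lr
    if method == "sqrt":
        return base_lr * math.sqrt(batch_multiplier)
    if method == "linear":
        return base_lr * batch_multiplier
    raise ValueError(f"unknown LR_SCALE_METHOD={method!r}")

# run_8xh100.sh settings: 131072 global tokens over 8 GPUs at 8192 tokens/GPU
# gives 2 microsteps per optimizer step; sqrt scaling from the local
# 16384-token reference multiplies every base LR by sqrt(8).
print(infer_grad_accum_steps(131072, 8, 8192))  # 2
print(131072 / 16384)                           # batch_multiplier = 8.0
```

With these numbers each rank accumulates two 8192-token microbatches before the synchronized optimizer step, matching the `grad_accum_steps` the launcher log reports.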