diff --git a/records/track_10min_16mb/2026-04-01_random_mlp/README.md b/records/track_10min_16mb/2026-04-01_random_mlp/README.md
new file mode 100644
index 0000000000..47b22a2a0b
--- /dev/null
+++ b/records/track_10min_16mb/2026-04-01_random_mlp/README.md
@@ -0,0 +1,118 @@
+# Non-Record: Partially Random MLP
+
+Not a record run, but a proof of concept that partially random MLP layers are competitive — and likely more so with a cleaner implementation stack than the one I had time to build.
+
+**val_bpb: 1.1527** (3-seed mean, std 0.0021) | ~**15.95MB** under zlib | 8xH100 SXM, 600s | No TTT
+
+One oversight worth calling out: only one of my three submissions uses zstd for compression. Switching from `zlib` to `zstd` is worth roughly 1.49MB of artifact headroom - enough to fit one or two additional fully learned layers. The average compression figure above uses the larger zlib artifacts for fairness, but this is low-hanging fruit for anyone building on this.
+
+## Results
+
+| Seed     | Steps | ms/step | Pre-quant BPB | **Sliding BPB** | Artifact   |
+|----------|-------|---------|---------------|-----------------|------------|
+| 1337     | 4,962 | 120.93  | 1.1661        | **1.1516**      | 15,979,644 |
+| 2026     | 4,973 | 120.66  | 1.1656        | **1.1508**      | 15,979,644 |
+| 999      | 4,598 | 130.50  | 1.1687        | **1.1558**      | 14,454,547 |
+| **Mean** |       |         |               | **1.1527**      |            |
+
+---
+
+## Core Idea
+Any parameter that is computable at initialization time is effectively free — it costs nothing in the artifact budget.
This observation is obvious in hindsight, but it took a detour through HRM and TRM experimentation to arrive at it clearly.\
+While it's much harder to guess correct initialization values for the attention subsystem, the MLP blocks are better understood, in the sense that we have at least some intuition about what they're doing - some mix of content-addressable memory and feature extraction.\
+Since the up-projection of the MLP is what determines the features to be detected, there's less of an argument that it needs to be fully learned from scratch.
+Any well-chosen random basis should work as a fixed feature extractor, leaving the down-projection and the surrounding machinery to do the adaptation.
+The resulting construction is simple: at initialization, selected MLP up-projections are sampled from a random matrix and frozen. Only the seed is stored - the weights are recomputed at load time and never saved. Each frozen layer additionally gets a learnable per-feature gain vector (initialized to all ones), which gives the model a cheap learned scaling on top of the fixed basis.
+The weight saving anecdotally ends up around 0.7MB per random layer, depending on training progress, and a working `zstd` stack may be able to squeeze in one or two more fully trainable layers.
+The freed parameter budget was reinvested in model depth, landing at 12 layers total: 5 random, 7 learned.
+
+Note that this does not reduce compute — it trades parameter storage for additional depth under a fixed budget.
+
+
+### Initializing the random projections
+I experimented with several initialization schemes: scaled normal, Rademacher, and QR.
+QR won consistently in my local 3090 iso-step ablations, and it's simple to construct: sample a random matrix using a fixed seed (or a generator when doing this over multiple layers), compute the QR decomposition, scale the resulting Q by `sqrt(d_in)`, and use it as the up-projection.
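The construction above can be sketched in a few lines of PyTorch. This is a minimal illustration, not the module used in the actual runs - `FrozenRandomUp` and its argument names are made up here - but it shows the three ingredients: seed-only storage, a frozen QR basis, and the learnable per-feature gain.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class FrozenRandomUp(nn.Module):
    """Frozen QR-initialized up-projection: only `seed` needs to be stored.

    The weight lives in a non-persistent buffer, so it never enters the
    state dict; the per-feature gain is the only trainable tensor here.
    """

    def __init__(self, d_in: int, d_hidden: int, seed: int):
        super().__init__()
        g = torch.Generator().manual_seed(seed)
        a = torch.randn(d_hidden, d_in, generator=g)
        # Reduced QR: for d_hidden >= d_in, Q has shape (d_hidden, d_in)
        # with orthonormal columns. Scale by sqrt(d_in) as described above.
        q, _ = torch.linalg.qr(a)
        self.register_buffer("weight", q * d_in ** 0.5, persistent=False)
        self.gain = nn.Parameter(torch.ones(d_hidden))  # learnable per-feature gain

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return F.linear(x, self.weight) * self.gain
```

At load time, reconstructing the layer from the stored seed reproduces the identical weight, which is what makes the projection free in the artifact budget.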
+My intuition for this is structural: QR yields an orthogonal basis in all but pathological cases (and you could still use rejection sampling to guarantee it), meaning the random features are well behaved - maximally space-covering, without any redundancy. That's the right prior for a feature extractor you lock in at init and never update. And since the down-projection is still learned, mixing between features remains fully learnable.
+
+One interesting observation from local iso-step ablations on my 3090: in early training, between roughly steps 400–1000 (bsz=16k, after initial loss settling), models with random MLP layers *temporarily outperform* fully learned models. My interpretation is that the frozen projections provide a stable feature basis that the rest of the model can organize around quickly - learned layers then route through this fixed scaffold rather than having to discover a useful basis from scratch. The advantage narrows as the learned layers catch up, and eventually both variants end up with quite similar loss curves - but it doesn't necessarily hurt.
+
+### Additional constructions around random layers
+During early experimentation, without QR-init, I worried that the randomly initialized matrices might be rank-deficient, or ill-conditioned in the sense that some random features could be nearly collinear. To address this, I added what I call a "mini-MoE" construction, where multiple random up-projections are computed and the model learns a token-dependent router that adjusts their relative weights. This construction did improve performance (even with QR-init), but I removed it for my final H100 runs because it didn't help under an iso-wallclock setting. If someone can reduce the throughput cost, this remains a viable direction.
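A mini-MoE over frozen random up-projections could look roughly like the sketch below. This is illustrative only - class and argument names are mine, and the actual implementation in this submission differs - but it captures the idea: a bank of frozen random experts, reconstructible from a seed, mixed by a small learned router.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class MiniMoEUp(nn.Module):
    """Several frozen random up-projections mixed by a learned token router."""

    def __init__(self, d_in: int, d_hidden: int, n_experts: int, seed: int):
        super().__init__()
        g = torch.Generator().manual_seed(seed)
        # Frozen expert bank, reconstructible from the seed (persistent=False
        # keeps it out of the saved state dict).
        experts = torch.randn(n_experts, d_hidden, d_in, generator=g) / d_in ** 0.5
        self.register_buffer("experts", experts, persistent=False)
        self.router = nn.Linear(d_in, n_experts)  # learned, token-dependent

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gates = F.softmax(self.router(x), dim=-1)             # (..., K)
        h = torch.einsum("...i,koi->...ko", x, self.experts)  # per-expert features
        return torch.einsum("...k,...ho->...o".replace("h", "k"), gates, h)  # router-weighted mix
```

With a single expert the softmax gate is identically one, so the construction collapses to a plain random projection with no routing overhead - which matches how the router disappears at `MINI_MOE_EXPERTS=1` in the final runs.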
+
+## How I got here: HRM/TRM
+Before arriving at random reservoirs, I spent time experimenting with [HRM](https://arxiv.org/abs/2506.21734)- and [TRM](https://arxiv.org/abs/2510.04871)-inspired looped/repeated layer constructions, motivated by interest in what happens when you push the number of inner settling steps high enough that the model stops doing fixed-point iteration and starts exploring the latent space more freely.
+
+**TLDR**: they didn't work well for this challenge, and I don't think that's surprising in retrospect. Looped constructions seem to favor settings where algorithmic, per-token "reasoning" effort pays off.
+The canonical example from the HRM paper is something like Sudoku, where the entire problem space fits into the model's latent dimension and iterative refinement of a pre-existing solution works out. General language modeling, however, does not fit well into this structure.
+
+The underlying issue here, and why I wanted to experiment with decoupled settling steps in the first place, may be gradient flow.
+In traditional looped language model constructions, even outside of HRM/TRM - e.g. [Ouro](https://arxiv.org/abs/2510.25741), [MobileLLM](https://arxiv.org/abs/2402.14905) or the original [Universal Transformers](https://arxiv.org/abs/1807.03819) - gradients flow through the repeated layers, and there's optimization pressure that pushes intermediate solutions directionally towards the final solution.
+I believe this yields convergent behavior - in the case of HRMs, explicitly fixed-point-iteration-like, with steadily decaying residuals. To be fair, I don't know whether this aspect applies to looped language models as well, but it does to HRMs.
+Pushing the number of gradient-free settling steps may be a way of forcing the model into a compositional instead of an iterative regime, resulting in actual latent-space exploration beyond the pressure for intermediate steps to directionally align with the intended target.
+I didn't have the compute budget to test this seriously, and beyond that I'm still somewhat sceptical whether this is really the way to go about it - you do need to allow intermediate layers more freedom, but having this arise from repetition requires "traces of composability" to already be hidden within the model, and those traces have to be strong enough that gradient descent can lock onto and amplify them.
+
+Cross-attention coupling (as an alternative to the elementwise-additive mixing in HRM/TRM) did improve performance over additive coupling on my local 3090 - but showed no clear advantage over a dense baseline at the scales relevant to this competition. The experiments were useful for ruling out that direction quickly.
+
+---
+
+## Other components
+Beyond the random MLP architecture, I ported a number of tricks from [PR #414](https://github.com/openai/parameter-golf/pull/414), the current SOTA submission at the time of my experimentation. In short:
+
+- MLP 3x
+- efficient XSA on the last 4 layers, which tested better than XSA-2 and XSA-3
+- partial RoPE
+- LN Scale
+- VE 128 on the last 3 layers, no significant difference seen vs. the last 4
+- BigramHash(2048) + Smeargate
+- int6 QAT with STE
+- EMA + late SWA
+
+---
+
+## TTT
+
+Legal test-time training was enabled but did not improve post-quantization performance. This appears to be a known issue - the new SOTA submission at the time of writing, [PR #1019](https://github.com/openai/parameter-golf/pull/1019), mentions similar behavior.
+It's unclear whether they saw the same catastrophic results as I did: in my experiments the loss increased substantially, and while it showed signs of decreasing, it still ended up higher than at TTT start (bpb jumped from 1.186077 to 1.428367).
+There's still time budget remaining in eval, so this is likely fixable for someone willing to debug it carefully.
If it's useful to someone, I can share the final non-quantized and quantized models for further experimentation.
+
+## Run Command
+
+```bash
+ITERATIONS=20000 \
+TRAIN_BATCH_TOKENS=786432 \
+TRAIN_SEQ_LEN=2048 \
+VAL_LOSS_EVERY=0 \
+VAL_BATCH_SIZE=786432 \
+WARMDOWN_ITERS=3500 \
+NUM_LAYERS=12 \
+TRAIN_LOG_EVERY=10 \
+MLP_MULT=3 \
+RAND_PROJ_LAYERS="0,1,2,3,4" \
+RAND_GAIN=1 \
+RAND_INIT_QR=1 \
+MINI_MOE_EXPERTS=1 \
+VE_LAYERS="9,10,11" \
+VE_DIM=128 \
+XSA_LAYERS="8,9,10,11" \
+BIGRAM_VOCAB_SIZE=2048 \
+SEED=1337 \
+torchrun --standalone --nproc_per_node=8 train_gpt.py
+```
+
+An explanation of the unusual hyperparameters:
+- RAND_PROJ_LAYERS - configures which layers use the random up-projection
+- RAND_GAIN - enables (1) or disables (0) the learnable per-feature weighting after the projection
+- RAND_INIT_QR - enables/disables the QR-init; falls back to normal init otherwise
+- MINI_MOE_EXPERTS - the number of up-projections to generate per random MLP; at 1 it removes the router and expert gating altogether
+
+
+## Future directions
+The most natural extension of this idea is going all in on the `Learning adapters on random linear maps` angle that this idea falls into.
+My current bet is that random up- **and** down-projections may work for some layers if you stack the feature weighting and potentially some LoRA-style adapters on top of them - cheap, potentially expressive early feature detectors with little to no learned parameters.
+A potential constraint that makes this interesting: you could compute the up- and down-projections to be pseudo-inverses of each other, then learn a diagonal scaling (the current random-gain, per-feature weighting) and a low-rank correction on top.
+
+
+Aside from this, there's TTT debugging and further exploration of the mini-MoE idea.
+If anyone wants to debug the TTT further, I can share the trained model checkpoints.
That should make it possible to isolate whether the failure is in the quantization interaction, the random layer gradient issue, or something else entirely. + +Since I likely won't have the time to run more experiments (and running experiments on 8xH100s is quite expensive), feel free to expand and build off of the ideas here! diff --git a/records/track_10min_16mb/2026-04-01_random_mlp/baseline_no_moe_xsa4_ve3_12l_5r7l_seed1337.txt b/records/track_10min_16mb/2026-04-01_random_mlp/baseline_no_moe_xsa4_ve3_12l_5r7l_seed1337.txt new file mode 100644 index 0000000000..053c636f05 --- /dev/null +++ b/records/track_10min_16mb/2026-04-01_random_mlp/baseline_no_moe_xsa4_ve3_12l_5r7l_seed1337.txt @@ -0,0 +1,2812 @@ +from __future__ import annotations + +import copy +import glob +import io +import math +import os +import random +import subprocess +import sys +import time +import uuid +import zlib +from pathlib import Path + +try: + import zstandard + _COMPRESSOR = "zstd" +except ImportError: + _COMPRESSOR = "zlib" + +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP + +from flash_attn_interface import flash_attn_func +# from flash_attn import flash_attn_func + +# make dynamo less complainy +import torch._dynamo +torch._dynamo.config.cache_size_limit = 64 + +# ----------------------------- +# HYPERPARAMETERS +# ----------------------------- +# Default Simple Baseline run: +# - 9 transformer blocks at width 512 +# - 8 attention heads with 4 KV heads (GQA) and 2x MLP expansion +# - vocab size 1024, sequence length 1024, tied embeddings +# - 524,288 train tokens per step for 20,000 iterations with a ~10 minute cap + +class Hyperparameters: + # Data paths are shard globs produced by the existing preprocessing pipeline. 
+ data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + seed = int(os.environ.get("SEED", 1337)) + + # Validation cadence and batch size. Validation always uses the full fineweb_val split. + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 200)) + + # Training length. + iterations = int(os.environ.get("ITERATIONS", 20000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 1200)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024)) + eval_seq_len = int(os.environ.get("EVAL_SEQ_LEN", 2048)) + eval_stride = int(os.environ.get("EVAL_STRIDE", 64)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + + # Model shape. 
+ vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers = int(os.environ.get("NUM_LAYERS", 9)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = int(os.environ.get("MLP_MULT", 2)) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + + # random projection control + rand_proj_layers = [int(x) for x in os.environ.get("RAND_PROJ_LAYERS", "").split(",") if x] + rand_gain = bool(int(os.environ.get("RAND_GAIN", "0"))) + mini_moe_experts = int(os.environ.get("MINI_MOE_EXPERTS", 1)) + rand_init_qr = bool(int(os.environ.get("RAND_INIT_QR", "1"))) + + # xsa control + xsa_layers = [int(x) for x in os.environ.get("XSA_LAYERS", "").split(",") if x] + + # quant + late_qat_threshold = float(os.environ.get("LATE_QAT_THRESHOLD", 0.15)) + + # partial rope + rope_dims = int(os.environ.get("ROPE_DIMS", 16)) + + # ln scale + ln_scale = bool(int(os.environ.get("LN_SCALE", "1"))) + + # bigram control + bigram_vocab_size = int(os.environ.get("BIGRAM_VOCAB_SIZE", 2048)) + bigram_dim = int(os.environ.get("BIGRAM_DIM", 128)) + + # value embedding control + ve_dim = int(os.environ.get("VE_DIM", 128)) + ve_layers = [int(x) for x in os.environ.get("VE_LAYERS", "9,10").split(",") if x] + + # ttt + ttt_enabled = bool(int(os.environ.get("TTT_ENABLED", "1"))) + ttt_lr = float(os.environ.get("TTT_LR", 0.002)) + ttt_epochs = int(os.environ.get("TTT_EPOCHS", 3)) + ttt_chunk_tokens = int(os.environ.get("TTT_CHUNK_TOKENS", 32768)) + ttt_freeze_blocks = int(os.environ.get("TTT_FREEZE_BLOCKS", 0)) + ttt_momentum = float(os.environ.get("TTT_MOMENTUM", 0.9)) + ttt_batch_seqs = int(os.environ.get("TTT_BATCH_SEQS", 32)) + ttt_grad_clip = float(os.environ.get("TTT_GRAD_CLIP", 1.0)) + + # swa + swa_enabled = bool(int(os.environ.get("SWA_ENABLED", 
"1"))) + swa_every = int(os.environ.get("SWA_EVERY", 50)) + + # Optimizer hyperparameters. + embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.035)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.025)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.025)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.99)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.92)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 1500)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.3)) + muon_wd = float(os.environ.get("MUON_WD", 0.04)) + adam_wd = float(os.environ.get("ADAM_WD", 0.04)) + +# ----------------------------- +# MUON OPTIMIZER +# ----------------------------- +# +# As borrowed from modded-nanogpt +# Background on Muon: https://kellerjordan.github.io/posts/muon/ + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + # Orthogonalize a 2D update matrix with a fast Newton-Schulz iteration. + # Muon uses this to normalize matrix-shaped gradients before applying them. + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X + + +class Muon(torch.optim.Optimizer): + """Parallel Muon: post-backward reduce-scatter -> local NS5 -> all-gather. + + No DDP for bank params. After backward, this optimizer: + 1. 
Launches async reduce-scatter for all banks (biggest first) + 2. Returns control so Adam can step on small params while RS is in-flight + 3. Waits for each RS, runs local NS5 on the shard, launches async all-gather + 4. Each all-gather overlaps with next bank's NS5 + """ + def __init__(self, params, lr: float, momentum: float, backend_steps: int, + nesterov: bool = True, weight_decay: float = 0.0): + super().__init__( + params, + dict(lr=lr, momentum=momentum, backend_steps=backend_steps, + nesterov=nesterov, weight_decay=weight_decay), + ) + self._built = False + + def _build(self): + self._distributed = dist.is_available() and dist.is_initialized() + self._world_size = dist.get_world_size() if self._distributed else 1 + self._rank = dist.get_rank() if self._distributed else 0 + ws = self._world_size + + self._bank_meta = [] + for group in self.param_groups: + for p in group["params"]: + B = p.shape[0] + padded_B = ((B + ws - 1) // ws) * ws + shard_B = padded_B // ws + tail = p.shape[1:] + dev = p.device + self._bank_meta.append({ + 'p': p, + 'B': B, + 'padded_grad': torch.zeros(padded_B, *tail, device=dev, dtype=torch.bfloat16), + 'shard': torch.zeros(shard_B, *tail, device=dev, dtype=torch.bfloat16), + 'shard_mom': torch.zeros(shard_B, *tail, device=dev, dtype=torch.bfloat16), + 'full_update': torch.zeros(padded_B, *tail, device=dev, dtype=torch.bfloat16), + 'scale': max(1, p.shape[-2] / p.shape[-1]) ** 0.5, + }) + # Sort by size descending -- launch biggest reduce-scatters first + self._bank_meta.sort(key=lambda m: -m['p'].numel()) + self._built = True + + def launch_reduce_scatters(self): + """Phase 1: launch async reduce-scatter for all banks. 
Call right after backward.""" + if not self._built: + self._build() + if not self._distributed: + return + self._rs_futures = [] + for m in self._bank_meta: + p = m['p'] + if p.grad is None: + self._rs_futures.append(None) + continue + pg = m['padded_grad'] + pg[:m['B']].copy_(p.grad.bfloat16()) + if pg.shape[0] > m['B']: + pg[m['B']:].zero_() + fut = dist.reduce_scatter_tensor(m['shard'], pg, op=dist.ReduceOp.AVG, async_op=True) + self._rs_futures.append(fut) + + @torch.no_grad() + def step(self, closure=None): + """Phase 3: wait for RS, local NS5, all-gather. Call AFTER Adam steps.""" + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + if not self._built: + self._build() + + for group in self.param_groups: + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + wd = group.get("weight_decay", 0.0) + + prev_ag_handle = None + prev_m = None + + sharded = self._distributed and hasattr(self, '_rs_futures') + + for i, m in enumerate(self._bank_meta): + p = m['p'] + if p.grad is None: + continue + + if prev_ag_handle is not None: + prev_ag_handle.wait() + pp = prev_m['p'] + upd = prev_m['full_update'][:prev_m['B']] + if wd > 0.0: + pp.data.mul_(1.0 - lr * wd) + pp.add_(upd.to(dtype=pp.dtype), alpha=-lr * prev_m['scale']) + + if sharded and self._rs_futures[i] is not None: + self._rs_futures[i].wait() + g = m['shard'] + buf = m['shard_mom'] + else: + g = p.grad.bfloat16() + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + + buf.mul_(momentum).add_(g) + if nesterov: + update = g.add(buf, alpha=momentum) + else: + update = buf + + update = zeropower_via_newtonschulz5(update, steps=backend_steps) + + if sharded: + prev_ag_handle = dist.all_gather_into_tensor( + m['full_update'], update, async_op=True) + prev_m = m + else: + if wd > 0.0: + p.data.mul_(1.0 - lr * wd) + 
p.add_(update.to(dtype=p.dtype), alpha=-lr * m['scale'])
+
+            if prev_ag_handle is not None:
+                prev_ag_handle.wait()
+                pp = prev_m['p']
+                upd = prev_m['full_update'][:prev_m['B']]
+                if wd > 0.0:
+                    pp.data.mul_(1.0 - lr * wd)
+                pp.add_(upd.to(dtype=pp.dtype), alpha=-lr * prev_m['scale'])
+
+        if hasattr(self, '_rs_futures'):
+            del self._rs_futures
+
+        return loss
+
+# -----------------------------
+# TOKENIZER-AGNOSTIC EVALUATION SETUP
+# -----------------------------
+#
+# It's common for small models to have a large fraction of their parameters be embeddings, since the 2 * d_model * d_vocab vectors can be gigantic.
+# Instead of locking the tokenizer, we let you bring your own and calculate our validation metrics on the average compression of the validation set.
+# We calculate BPB (bits-per-byte) instead of validation loss, so we need methods to count the number of bytes per token in the tokenizer.
+# Note: Submissions that edit the tokenizer will be examined more carefully, since screwing this up might unjustly improve your score.
+ +def build_sentencepiece_luts( + sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device +) -> tuple[Tensor, Tensor, Tensor]: + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("▁"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) + + +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*. 
+ tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] + + +def eval_val( + args: Hyperparameters, + model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + grad_accum_steps: int, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, +) -> tuple[float, float]: + # Validation computes two metrics: + # - val_loss: token cross-entropy (natural log) + # - val_bpb: tokenizer-agnostic compression metric used by the challenge + local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) + if local_batch_tokens < args.train_seq_len: + raise ValueError( + "VAL_BATCH_SIZE must provide at least one sequence per rank; " + f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " + f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}" + ) + local_batch_seqs = local_batch_tokens // args.train_seq_len + total_seqs = (val_tokens.numel() - 1) // args.train_seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + + model.eval() + with torch.inference_mode(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * args.train_seq_len + raw_end = batch_seq_end * args.train_seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, args.train_seq_len) + y = local[1:].reshape(-1, args.train_seq_len) + with 
torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + batch_loss = model(x, y).detach() + batch_token_count = float(y.numel()) + val_loss_sum += batch_loss.to(torch.float64) * batch_token_count + val_token_count += batch_token_count + prev_ids = x.reshape(-1) + tgt_ids = y.reshape(-1) + token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + + val_loss = val_loss_sum / val_token_count + bits_per_token = val_loss.item() / math.log(2.0) + tokens_per_byte = val_token_count.item() / val_byte_count.item() + model.train() + return float(val_loss.item()), float(bits_per_token * tokens_per_byte) + +# ----------------------------- +# POST-TRAINING QUANTIZATION +# ----------------------------- +# +# It's silly to export our model, which is trained in bf16 and fp32, at that same precision. +# Instead, we get approximately the same model (with a small hit) by quantizing the model to int8 & zlib compressing. +# We can then decompress the model and run in higher precision for evaluation, after closing in under the size limit. 
+ +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights", + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 + +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) + +def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: + if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): + return t.float().contiguous() + if t.dtype in {torch.float32, torch.bfloat16}: + passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") + return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() + return t + +def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + # Matrices get one scale per row, which usually tracks output-channel + # ranges much better than a single tensor-wide scale. + clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) + q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() + return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + + # Vectors / scalars use a simpler per-tensor scale. 
+ clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() + return q, scale + +def quantize_state_dict_int8(state_dict: dict[str, Tensor]): + # Single supported clean-script export format: + # - per-row int8 for 2D float tensors + # - per-tensor int8 for other float tensors + # - exact passthrough for non-floats + # - passthrough for small float tensors, stored as fp16 to save bytes + quantized: dict[str, Tensor] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + qmeta: dict[str, dict[str, object]] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + + # Small float tensors are cheap enough to keep directly. We still downcast + # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size. 
+ if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + + stats["num_float_tensors"] += 1 + q, s = quantize_float_tensor(t) + if s.ndim > 0: + qmeta[name] = {"scheme": "per_row", "axis": 0} + quantized[name] = q + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) + + obj: dict[str, object] = { + "__quant_format__": "int8_clean_per_row_v1", + "quantized": quantized, + "scales": scales, + "dtypes": dtypes, + "passthrough": passthrough, + } + if qmeta: + obj["qmeta"] = qmeta + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats + +def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + qmeta = obj.get("qmeta", {}) + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, q in obj["quantized"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + s = obj["scales"][name] + if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: + s = s.to(dtype=torch.float32) + # Broadcast the saved row scale back across trailing dimensions. + out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() + else: + scale = float(s.item()) + out[name] = (q.float() * scale).to(dtype=dtype).contiguous() + for name, t in obj["passthrough"].items(): + # Restore small tensors, undoing the temporary fp16 storage cast if needed. 
+ out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out + + +# --- sliding window eval --- + +def eval_val_sliding( + args: Hyperparameters, + base_model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, + stride: int, + batch_seqs: int = 32, + eval_seq_len: int | None = None, +) -> tuple[float, float]: + """Sliding window evaluation: each token scored with maximum context.""" + seq_len = eval_seq_len or args.train_seq_len + total_tokens = val_tokens.numel() - 1 + window_starts = [ws for ws in range(0, total_tokens, stride) + if min(ws + seq_len, total_tokens) - ws >= 1] + total_windows = len(window_starts) + my_s = (total_windows * rank) // world_size + my_e = (total_windows * (rank + 1)) // world_size + my_windows = window_starts[my_s:my_e] + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + base_model.eval() + compiled_logits = torch.compile(base_model.forward_logits, dynamic=False, fullgraph=True) + with torch.inference_mode(): + for bi in range(0, len(my_windows), batch_seqs): + batch_ws = my_windows[bi:bi + batch_seqs] + bsz = len(batch_ws) + x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + wlens: list[int] = [] + for i, ws in enumerate(batch_ws): + end = min(ws + seq_len, total_tokens) + wlen = end - ws + wlens.append(wlen) + chunk = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) + x_batch[i, :wlen] = chunk[:-1] + y_batch[i, :wlen] = chunk[1:] + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + 
logits = compiled_logits(x_batch) + nll = F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), + y_batch.reshape(-1), + reduction="none", + ).reshape(bsz, seq_len) + for i, ws in enumerate(batch_ws): + wlen = wlens[i] + s = 0 if ws == 0 else max(wlen - stride, 0) + scored_nll = nll[i, s:wlen].to(torch.float64) + loss_sum += scored_nll.sum() + token_count += float(wlen - s) + tgt = y_batch[i, s:wlen] + prev = x_batch[i, s:wlen] + tb = base_bytes_lut[tgt].to(torch.float64) + tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64) + byte_count += tb.sum() + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) + val_loss = (loss_sum / token_count).item() + bits_per_token = val_loss / math.log(2.0) + tokens_per_byte = token_count.item() / byte_count.item() + base_model.train() + return val_loss, bits_per_token * tokens_per_byte + + +def eval_val_sliding_ttt( + args: Hyperparameters, base_model: nn.Module, rank: int, world_size: int, + device: torch.device, val_tokens: Tensor, base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, is_boundary_token_lut: Tensor, + stride: int, batch_seqs: int = 32, log0=print, +) -> tuple[float, float]: + """Legal score-first TTT (PR #461 recipe): score each chunk with sliding windows, + then train on it. 
Every token scored BEFORE any update that could use it.""" + seq_len = args.train_seq_len + total_tokens = val_tokens.numel() - 1 + ttt_chunk = args.ttt_chunk_tokens + + # Pre-compute all window starts + window_starts = [ws for ws in range(0, total_tokens, stride) + if min(ws + seq_len, total_tokens) - ws >= stride or ws == 0] + + # Assign each window to a chunk based on the first token it scores + num_chunks = (total_tokens + ttt_chunk - 1) // ttt_chunk + chunk_windows: list[list[int]] = [[] for _ in range(num_chunks)] + for ws in window_starts: + end = min(ws + seq_len, total_tokens) + wlen = end - ws + s = 0 if ws == 0 else max(wlen - stride, 0) + scored_start = ws + s + ci = min(scored_start // ttt_chunk, num_chunks - 1) + chunk_windows[ci].append(ws) + + log0(f"ttt_sliding:start chunks={num_chunks} chunk_tokens={ttt_chunk} " + f"total_windows={len(window_starts)} stride={stride} " + f"ttt_lr={args.ttt_lr} ttt_epochs={args.ttt_epochs} " + f"freeze_blocks={args.ttt_freeze_blocks}") + + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + + # Freeze first N blocks + frozen_block_ids = set(range(min(args.ttt_freeze_blocks, len(base_model.blocks)))) + ttt_params = [] + for name, p in base_model.named_parameters(): + freeze = False + for bi in frozen_block_ids: + if f"blocks.{bi}." 
in name: + freeze = True + break + if freeze: + p.requires_grad_(False) + else: + p.requires_grad_(True) + ttt_params.append(p) + + log0(f"ttt_sliding:params unfrozen={sum(p.numel() for p in ttt_params)} " + f"frozen={sum(p.numel() for p in base_model.parameters() if not p.requires_grad)}") + + optimizer = torch.optim.SGD(ttt_params, lr=args.ttt_lr, momentum=args.ttt_momentum) + t0 = time.perf_counter() + + for ci in range(num_chunks): + windows = chunk_windows[ci] + if not windows: + continue + chunk_start = ci * ttt_chunk + chunk_end = min((ci + 1) * ttt_chunk, total_tokens) + + # --- Phase 1: SCORE this chunk's windows (inference_mode) --- + my_s = (len(windows) * rank) // world_size + my_e = (len(windows) * (rank + 1)) // world_size + my_windows = windows[my_s:my_e] + + base_model.eval() + with torch.inference_mode(): + for bi in range(0, len(my_windows), batch_seqs): + batch_ws = my_windows[bi:bi + batch_seqs] + bsz = len(batch_ws) + x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + wlens: list[int] = [] + for i, ws in enumerate(batch_ws): + end = min(ws + seq_len, total_tokens) + wlen = end - ws + wlens.append(wlen) + chunk_tok = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) + x_batch[i, :wlen] = chunk_tok[:-1] + y_batch[i, :wlen] = chunk_tok[1:] + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + logits = base_model.forward_logits(x_batch) + nll = F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), + y_batch.reshape(-1), reduction="none", + ).reshape(bsz, seq_len) + for i, ws in enumerate(batch_ws): + wlen = wlens[i] + s = 0 if ws == 0 else max(wlen - stride, 0) + scored_nll = nll[i, s:wlen].to(torch.float64) + loss_sum += scored_nll.sum() + token_count += float(wlen - s) + tgt, prev = y_batch[i, s:wlen], x_batch[i, s:wlen] + tb = base_bytes_lut[tgt].to(torch.float64) + tb += (has_leading_space_lut[tgt] & 
~is_boundary_token_lut[prev]).to(torch.float64) + byte_count += tb.sum() + + # workaround for caching issues with sin/cos in rotary layers: + for block in base_model.blocks: + block.attn.rotary._cos_cached = None + block.attn.rotary._sin_cached = None + + # --- Phase 2: TRAIN on this chunk (already scored = legal) --- + is_last_chunk = (ci == num_chunks - 1) + if not is_last_chunk and args.ttt_epochs > 0: + base_model.train() + chunk_seqs = (chunk_end - chunk_start) // seq_len + if chunk_seqs > 0: + cos_lr = args.ttt_lr * 0.5 * (1.0 + math.cos(math.pi * ci / max(num_chunks - 1, 1))) + for pg in optimizer.param_groups: + pg['lr'] = cos_lr + my_seq_s = (chunk_seqs * rank) // world_size + my_seq_e = (chunk_seqs * (rank + 1)) // world_size + my_chunk_seqs = my_seq_e - my_seq_s + for _ep in range(args.ttt_epochs): + for bs in range(0, my_chunk_seqs, args.ttt_batch_seqs): + be = min(bs + args.ttt_batch_seqs, my_chunk_seqs) + actual_bs = my_seq_s + bs + start_tok = chunk_start + actual_bs * seq_len + end_tok = chunk_start + (my_seq_s + be) * seq_len + 1 + if end_tok > val_tokens.numel(): + continue + local = val_tokens[start_tok:end_tok].to(device=device, dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + optimizer.zero_grad(set_to_none=True) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + loss = base_model(x, y) + loss.backward() + if world_size > 1: + for p in ttt_params: + if p.grad is not None: + dist.all_reduce(p.grad, op=dist.ReduceOp.AVG) + torch.nn.utils.clip_grad_norm_(ttt_params, args.ttt_grad_clip) + optimizer.step() + + if rank == 0 and (ci % 10 == 0 or ci == num_chunks - 1): + elapsed = time.perf_counter() - t0 + rl = loss_sum.item() / max(token_count.item(), 1) + rbpb = rl / math.log(2.0) * (token_count.item() / max(byte_count.item(), 1)) if token_count.item() > 0 else 0.0 + log0(f" ttt_chunk [{ci+1}/{num_chunks}] bpb={rbpb:.6f} time={elapsed:.1f}s") + + if dist.is_available() and 
dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) + + val_loss = (loss_sum / token_count).item() + val_bpb = val_loss / math.log(2.0) * (token_count.item() / byte_count.item()) + + for p in base_model.parameters(): + p.requires_grad_(True) + base_model.eval() + + log0(f"ttt_sliding:done val_loss={val_loss:.6f} val_bpb={val_bpb:.6f} " + f"elapsed={time.perf_counter() - t0:.1f}s") + return val_loss, val_bpb + +# --- int6 quant as per #414 --- +def _classify_param(name: str) -> str: + if "tok_emb" in name or "lm_head" in name: + return "embed" + # TODO: not all mlp params likely want int6, e.g. scales and gates + if ".mlp." in name: + return "mlp" + if ".attn." in name or (".proj." in name and ".mlp." not in name): + return "attn" + return "other" + +def quantize_int6_per_row(t: Tensor, clip_range: int = 31) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + best_q, best_s, best_err = None, None, float('inf') + for pct in [0.9990, 0.9995, 0.9999, 0.99999, 1.0]: + if pct < 1.0: + row_clip = torch.quantile(t32.abs(), pct, dim=1) + else: + row_clip = t32.abs().amax(dim=1) + s = (row_clip / clip_range).clamp_min(1.0 / clip_range).to(torch.float16) + q = torch.clamp(torch.round(t32 / s.float()[:, None]), -clip_range, clip_range).to(torch.int8) + recon = q.float() * s.float()[:, None] + err = (t32 - recon).pow(2).mean().item() + if err < best_err: + best_q, best_s, best_err = q, s, err + return best_q, best_s + amax = t32.abs().max().item() + scale = torch.tensor(amax / clip_range if amax > 0 else 1.0, dtype=torch.float16) + q = torch.clamp(torch.round(t32 / scale.float()), -clip_range, clip_range).to(torch.int8) + return q, scale + +def mixed_quantize_int6( + state_dict: dict[str, Tensor], + int6_cats: set[str] +): + num_layers_total = max( + (int(k.split(".")[1]) for k in state_dict if k.startswith("blocks.")), + default=0, + ) + 1 
+    late_k_layers = set(range(num_layers_total - 2, num_layers_total))
+    result: dict[str, Tensor] = {}
+    meta: dict[str, object] = {}
+    for name, tensor in state_dict.items():
+        t = tensor.detach().cpu().contiguous()
+        cat = _classify_param(name)
+        if not t.is_floating_point() or t.numel() <= 65536:
+            result[name] = t.to(torch.float16) if t.is_floating_point() else t
+            meta[name] = "passthrough"
+            continue
+        if any(p in name for p in CONTROL_TENSOR_NAME_PATTERNS):
+            result[name] = t.float()
+            meta[name] = "passthrough_ctrl"
+            continue
+        if cat in int6_cats and t.ndim >= 1:
+            q, s = quantize_int6_per_row(t)
+            result[name + ".q"] = q
+            result[name + ".scale"] = s
+            meta[name] = {"type": "int6"}
+        else:
+            q, s = quantize_float_tensor(t)
+            result[name + ".q"] = q
+            result[name + ".scale"] = s
+            meta[name] = {"type": "int8"}
+    return result, meta
+
+def dequantize_mixed_int6(
+    result: dict[str, Tensor],
+    meta: dict[str, object],
+    template_sd: dict[str, Tensor]
+) -> dict[str, Tensor]:
+    out: dict[str, Tensor] = {}
+    for name, orig in template_sd.items():
+        info = meta.get(name)
+        if info is None:
+            continue
+        orig_dtype = orig.dtype
+        if info in ("passthrough", "passthrough_ctrl", "passthrough_fp16"):
+            t = result[name]
+            if t.dtype == torch.float16 and orig_dtype in (torch.float32, torch.bfloat16):
+                t = t.to(orig_dtype)
+            out[name] = t
+            continue
+        q, s = result[name + ".q"], result[name + ".scale"]
+        if s.ndim > 0:
+            out[name] = (q.float() * s.float().view(q.shape[0], *([1] * (q.ndim - 1)))).to(orig_dtype)
+        else:
+            out[name] = (q.float() * float(s.item())).to(orig_dtype)
+    return out
+
+# -----------------------------
+# DATA LOADING
+# -----------------------------
+
+def load_data_shard(file: Path) -> Tensor:
+    # llm.c-style shard layout: 256-entry int32 header (magic, version, num_tokens), then uint16 tokens.
+    header_bytes = 256 * np.dtype("<i4").itemsize
+    header = np.fromfile(file, dtype="<i4", count=256)
+    num_tokens = int(header[2])
+    tokens = np.fromfile(file, dtype="<u2", count=num_tokens, offset=header_bytes)
+    return torch.from_numpy(tokens.astype(np.int32))
+
+
+class TokenStream:
+    # Round-robins over the shard files matching `pattern`, handing out contiguous token runs.
+    def __init__(self, pattern: str):
+        p = Path(pattern)
+        self.files = sorted(p.parent.glob(p.name))
+        if not self.files:
+            raise FileNotFoundError(f"no data shards match pattern: {pattern}")
+        self.file_idx = 0
+        self.tokens = load_data_shard(self.files[self.file_idx])
+        self.pos = 0
+
+    def _advance_file(self) -> None:
+        self.file_idx = (self.file_idx + 1) % len(self.files)
+        self.tokens = load_data_shard(self.files[self.file_idx])
+        self.pos = 0
+
+    def take(self, n: int) -> Tensor:
+        chunks: list[Tensor] = []
+        remaining = n
while remaining > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: + self._advance_file() + continue + k = min(remaining, avail) + chunks.append(self.tokens[self.pos : self.pos + k]) + self.pos += k + remaining -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) + + +class DistributedTokenLoader: + # Each call consumes a contiguous chunk from the shared token stream, then slices out + # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting. + def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank = rank + self.world_size = world_size + self.device = device + self.stream = TokenStream(pattern) + + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + chunk = self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start : start + per_rank_span].to(dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + +# ----------------------------- +# TRANSFORMER MODULES +# ----------------------------- + +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + + +class CastedLinear(nn.Linear): + # straight-through qat as per #137 + qat_ste: bool = False + def forward(self, x: Tensor) -> Tensor: + w = self.weight + if self.qat_ste and self.training: + # STE fake int6: quantize to [-32,31] and dequantize, gradient flows through + w32 = w.float() + row_max = w32.abs().amax(dim=1).clamp_min(1e-8) + scale = row_max / 31.0 + w_q = torch.clamp(torch.round(w32 / scale[:, None]), -32, 31) * scale[:, None] + w = w + 
(w_q.to(w.dtype) - w).detach() + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, w.to(x.dtype), bias) + + +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + # Keep small/control parameters in fp32 even when the model body runs in bf16. + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() + + +class Rotary(nn.Module): + def __init__(self, dim: int, base: float = 10000.0, train_seq_len: int = 1024, rope_dims: int = 0): + super().__init__() + self.dim = dim + self.base = base + self.train_seq_len = train_seq_len + self.rope_dims = rope_dims if rope_dims > 0 else dim + inv_freq = 1.0 / (base ** (torch.arange(0, self.rope_dims, 2, dtype=torch.float32) / self.rope_dims)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached != seq_len + or self._cos_cached.device != device + ): + rd = self.rope_dims + if seq_len > self.train_seq_len: + scale = seq_len / self.train_seq_len + new_base = self.base * (scale ** (rd / (rd - 2))) + inv_freq = 1.0 / (new_base ** (torch.arange(0, rd, 2, dtype=torch.float32, device=device) / rd)) + else: + inv_freq = self.inv_freq.to(device) + t = torch.arange(seq_len, device=device, dtype=inv_freq.dtype) + freqs = torch.outer(t, inv_freq) + self._cos_cached = freqs.cos()[None, :, None, :] + self._sin_cached = freqs.sin()[None, :, None, :] + self._seq_len_cached = seq_len + return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) + +# using partial RoPE as per #414 +def apply_rotary_emb(x: Tensor, cos: 
Tensor, sin: Tensor, rope_dims: int = 0) -> Tensor: + if rope_dims > 0 and rope_dims < x.size(-1): + x_rope, x_pass = x[..., :rope_dims], x[..., rope_dims:] + half = rope_dims // 2 + x1, x2 = x_rope[..., :half], x_rope[..., half:] + x_rope = torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + return torch.cat((x_rope, x_pass), dim=-1) + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + +class CausalSelfAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + rope_base: float, + qk_gain_init: float, + use_xsa: bool = False, + train_seq_len: int = 1024, + rope_dims: int | None = None + ): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + kv_dim = self.num_kv_heads * self.head_dim + self.c_q = CastedLinear(dim, dim, bias=False) + self.c_k = CastedLinear(dim, kv_dim, bias=False) + self.c_v = CastedLinear(dim, kv_dim, bias=False) + self.proj = CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rotary = Rotary(self.head_dim, base=rope_base, train_seq_len=train_seq_len, rope_dims=rope_dims) + self.use_xsa = use_xsa + + # efficient XSA as per + # https://github.com/unnir/parameter-golf/blob/a81f85bd7f632e3e48ef6b1da0017b81d25998a7/records/track_10min_16mb/2026-03-20_11L_EfficientPartialXSA_FA3_SWA120/train_gpt.py + def _xsa_efficient(self, y: Tensor, v: Tensor) -> Tensor: + """Efficient XSA: subtract self-value projection via GQA-aware reshape (no repeat_interleave). 
+ y: [B, T, H, D], v: [B, T, Hkv, D]. H must be divisible by Hkv.""" + B, T, H, D = y.shape + Hkv = v.size(-2) + group = H // Hkv + # Reshape y into KV head groups — free view, no memory alloc + y_g = y.reshape(B, T, Hkv, group, D) # [B, T, Hkv, group, D] + vn = F.normalize(v, dim=-1).unsqueeze(-2) # [B, T, Hkv, 1, D] — broadcast ready + # Project out self-value component per KV head group + proj = (y_g * vn).sum(dim=-1, keepdim=True) * vn + return (y_g - proj).reshape(B, T, H, D) + + def forward(self, x: Tensor, v_embed: Tensor | None = None) -> Tensor: + bsz, seqlen, dim = x.shape + q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim) + k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) + v = self.c_v(x) + if v_embed is not None: + v = v + v_embed + v = v.reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin, self.rotary.rope_dims) + k = apply_rotary_emb(k, cos, sin, self.rotary.rope_dims) + q = q * self.q_gain.to(dtype=q.dtype)[None, None, :, None] + y = flash_attn_func(q, k, v, causal=True) + + # efficient xsa as per https://github.com/openai/parameter-golf/pull/265 + if self.use_xsa: + y = self._xsa_efficient(y, v) + + y = y.reshape(bsz, seqlen, dim) + return self.proj(y) + +# stolen from #414 +class SmearGate(nn.Module): + def __init__(self, dim: int): + super().__init__() + self.gate = nn.Parameter(torch.zeros(dim, dtype=torch.float32)) + def forward(self, x: Tensor) -> Tensor: + g = torch.sigmoid(self.gate.to(dtype=x.dtype))[None, None, :] + x_prev = torch.cat([torch.zeros_like(x[:, :1]), x[:, :-1]], dim=1) + return (1 - g) * x + g * x_prev + +# stolen from #414 +class BigramHashEmbedding(nn.Module): + def __init__(self, bigram_vocab_size: int, bigram_dim: int, model_dim: int): + super().__init__() + self.bigram_vocab_size = bigram_vocab_size + self.embed = 
nn.Embedding(bigram_vocab_size, bigram_dim) + nn.init.zeros_(self.embed.weight) + self.proj = CastedLinear(bigram_dim, model_dim, bias=False) if bigram_dim != model_dim else None + if self.proj is not None: + nn.init.zeros_(self.proj.weight) + self.scale = nn.Parameter(torch.tensor(0.05, dtype=torch.float32)) + def bigram_hash(self, tokens: Tensor) -> Tensor: + t = tokens.to(torch.int32) + mod = self.bigram_vocab_size - 1 + out = torch.empty_like(t) + out[..., 0] = mod + out[..., 1:] = torch.bitwise_xor(36313 * t[..., 1:], 27191 * t[..., :-1]) % mod + return out.long() + def forward(self, token_ids: Tensor) -> Tensor: + h = self.embed(self.bigram_hash(token_ids)) + if self.proj is not None: + h = self.proj(h) + return h * self.scale.to(dtype=h.dtype) + +# stolen from #414 +class ValueEmbedding(nn.Module): + """Reinject token identity into attention values at specific layers. + Each table maps vocab tokens to a low-dim embedding, projected to model_dim.""" + def __init__(self, vocab_size: int, ve_dim: int, model_dim: int): + super().__init__() + self.embed = nn.Embedding(vocab_size, ve_dim) + nn.init.normal_(self.embed.weight, std=0.01) + self.proj = CastedLinear(ve_dim, model_dim, bias=False) if ve_dim != model_dim else None + if self.proj is not None: + nn.init.zeros_(self.proj.weight) + self.scale = nn.Parameter(torch.tensor(0.1, dtype=torch.float32)) + def forward(self, token_ids: Tensor) -> Tensor: + h = self.embed(token_ids) + if self.proj is not None: + h = self.proj(h) + return h * self.scale.to(dtype=h.dtype) + +class MLP(nn.Module): + def __init__( + self, + dim: int, + mlp_mult: int, + rng_up: bool = False, + use_rng_gain: bool = False, + mini_moe_experts: int = 1, + ): + super().__init__() + self.hidden = mlp_mult * dim + self.mini_moe_experts = mini_moe_experts + self.use_rng_gain = use_rng_gain + self.fc: CastedLinear | None = None + self.rand_gain: nn.Parameter | None = None + self.moe_router: CastedLinear | None = None + if not rng_up: + self.fc = 
CastedLinear(dim, self.hidden, bias=False)
+        else:
+            # Non-persistent buffer: deterministically regenerated from the seed at load time,
+            # kept in bf16 since it is never quantized anyway.
+            self.register_buffer("fc_w", torch.empty((mini_moe_experts, self.hidden, dim), dtype=torch.bfloat16), persistent=False)
+        if self.use_rng_gain:
+            self.rand_gain = nn.Parameter(torch.ones(mini_moe_experts, self.hidden, dtype=torch.float32)) if rng_up else None
+        if self.mini_moe_experts > 1:
+            self.moe_router = CastedLinear(dim, mini_moe_experts, bias=False)
+            self.moe_router._zero_init = True
+        self.proj = CastedLinear(self.hidden, dim, bias=False)
+        self.proj._zero_init = True
+
+    def forward(self, x: Tensor) -> Tensor:
+        if self.fc is not None:
+            x = self.fc(x)
+        else:
+            if self.mini_moe_experts == 1:
+                x = F.linear(x, self.fc_w[0].to(dtype=x.dtype))
+                if self.rand_gain is not None:
+                    x = x * self.rand_gain[0].to(dtype=x.dtype)[None, :]
+            else:
+                # compute router weights
+                moe_weights = F.softmax(self.moe_router(x), dim=-1)
+                # compute individual expert outputs and weighted sum
+                # x: (bsz, seqlen, dim), fc_w: (mini_moe_experts, hidden, dim)
+                exout = torch.einsum("bsd,ehd->bseh", x, self.fc_w.to(dtype=x.dtype))  # (bsz, seqlen, mini_moe_experts, hidden)
+                if self.rand_gain is not None:
+                    exout = exout * self.rand_gain.to(dtype=x.dtype)[None, None, :, :]
+                x = moe_weights.unsqueeze(-1) * exout
+                x = x.sum(dim=2)  # (bsz, seqlen, hidden)
+        # x = torch.relu(x)
+        x = F.leaky_relu(x, negative_slope=0.5)
+        return self.proj(x.square())
+
+class Block(nn.Module):
+    def __init__(
+        self,
+        dim: int,
+        num_heads: int,
+        num_kv_heads: int,
+        mlp_mult: int,
+        rope_base: float,
+        qk_gain_init: float,
+        rng_up: bool = False,
+        use_rng_gain: bool = False,
+        mini_moe_experts: int = 1,
+        train_seq_len: int = 1024,
+        use_xsa: bool = False,
+        rope_dims: int | None = None,
+        ln_scale: bool = False,
+        layer_idx: int = 0,
+    ):
+        super().__init__()
+        self.attn_norm = RMSNorm()
+        self.mlp_norm = RMSNorm()
+        self.attn = CausalSelfAttention(dim,
num_heads, num_kv_heads, rope_base, qk_gain_init, use_xsa, train_seq_len, rope_dims=rope_dims) + self.mlp = MLP(dim, mlp_mult, rng_up, use_rng_gain, mini_moe_experts) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + self.ln_scale_factor = 1.0 / math.sqrt(layer_idx + 1) if ln_scale else 1.0 + + def forward(self, x: Tensor, x0: Tensor, v_embed: Tensor | None = None) -> Tensor: + mix = self.resid_mix.to(dtype=x.dtype) + x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + attn_out = self.attn(self.attn_norm(x) * self.ln_scale_factor, v_embed=v_embed) + x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out + x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x) * self.ln_scale_factor) + return x + +class GPT(nn.Module): + def __init__( + self, + vocab_size: int, + num_layers: int, + model_dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + tie_embeddings: bool, + tied_embed_init_std: float, + logit_softcap: float, + rope_base: float, + qk_gain_init: float, + rand_proj_layers: list[int], + use_rng_gain: bool = False, + mini_moe_experts: int = 1, + rand_init_seed: int = 42, + rand_init_qr: bool = False, + xsa_layers: list[int] = [], + rope_dims: int | None = None, + ln_scale: bool = False, + bigram_vocab_size: int = 0, + bigram_dim: int = 128, + ve_dim: int = 128, + ve_layers: list[int] = [], + train_seq_len: int = 1024, + ): + super().__init__() + self.rand_init_seed = rand_init_seed + self.rand_init_qr = rand_init_qr + if logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") + self.tie_embeddings = tie_embeddings + self.tied_embed_init_std = tied_embed_init_std + self.logit_softcap = logit_softcap + self.tok_emb = nn.Embedding(vocab_size, model_dim) + self.smear = 
SmearGate(model_dim) + self.bigram = BigramHashEmbedding(bigram_vocab_size, bigram_dim, model_dim) if bigram_vocab_size > 0 else None + self.num_layers = num_layers + self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + self.blocks = nn.ModuleList( + [ + Block( + model_dim, + num_heads, + num_kv_heads, + mlp_mult, + rope_base, + qk_gain_init, + rng_up=(i in rand_proj_layers), + use_rng_gain=use_rng_gain, + mini_moe_experts=mini_moe_experts, + train_seq_len=train_seq_len, + use_xsa=(i in xsa_layers), + rope_dims=rope_dims, + ln_scale=ln_scale, + layer_idx=i, + ) + for i in range(num_layers) + ] + ) + + # value embeddings + self.ve_layer_indices = ve_layers + self.ve_target_dim = num_kv_heads * (model_dim // num_heads) + if self.ve_layer_indices: + self.ve_shared = ValueEmbedding(vocab_size, ve_dim, self.ve_target_dim) + self.ve_layer_scales = nn.ParameterList( + [nn.Parameter(torch.ones(1, dtype=torch.float32)) for _ in self.ve_layer_indices] + ) + else: + self.ve_shared = None + self.ve_layer_scales = nn.ParameterList() + self.value_embeds = nn.ModuleList() # keep empty for compat + + self.final_norm = RMSNorm() + self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) + if self.lm_head is not None: + self.lm_head._zero_init = True + self._init_weights() + + def _init_weights(self) -> None: + gen = torch.Generator() + gen.manual_seed(self.rand_init_seed) + + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + for name, module in self.named_modules(): + if isinstance(module, nn.Linear): + if getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + elif module.weight.ndim == 2 and module.weight.shape[0] >= 64 and 
module.weight.shape[1] >= 64:
+                    nn.init.orthogonal_(module.weight, gain=1.0)
+                    if ".proj." in name or name.endswith(".proj"):
+                        with torch.no_grad():
+                            module.weight.mul_(1.0 / math.sqrt(2 * self.num_layers))
+
+            if isinstance(module, MLP) and getattr(module, "fc_w", None) is not None:
+                # perform seeded random init for the frozen random fc weights
+                n_experts, d_out, d_in = module.fc_w.shape
+                if self.rand_init_qr:
+                    for e in range(n_experts):
+                        G = torch.randn((d_out, d_in), generator=gen)
+                        q, _ = torch.linalg.qr(G)
+                        module.fc_w[e].copy_(q * math.sqrt(d_in))  # scale Q by sqrt(d_in), as described in the README
+                else:
+                    nn.init.normal_(module.fc_w, mean=0.0, std=1.0 / math.sqrt(d_in), generator=gen)
+                    # module.fc_w.bernoulli_(0.5, generator=gen).mul_(2).sub_(1).mul_(1.0 / math.sqrt(d_in))
+
+    # value-embeddings as per #414
+    def _get_ve(self, layer_idx: int, input_ids: Tensor, ve_cache: dict[str, Tensor] | None = None) -> Tensor | None:
+        """Get value embedding for a specific layer using shared table + per-layer scale."""
+        if self.ve_shared is None or layer_idx not in self.ve_layer_indices:
+            return None
+        if ve_cache is not None and 've' not in ve_cache:
+            ve_cache['ve'] = self.ve_shared(input_ids)
+        ve_base = ve_cache['ve'] if ve_cache is not None else self.ve_shared(input_ids)
+        ve_idx = self.ve_layer_indices.index(layer_idx)
+        return ve_base * self.ve_layer_scales[ve_idx].to(dtype=ve_base.dtype)
+
+    def forward_logits(self, input_ids: Tensor) -> Tensor:
+        x = self.tok_emb(input_ids)
+        if self.bigram is not None:
+            x = x + self.bigram(input_ids)
+        x = F.rms_norm(x, (x.size(-1),))
+        x = self.smear(x)
+        x0 = x
+        skips: list[Tensor] = []
+
+        # First half stores skips; second half reuses them in reverse order.
+ ve_cache: dict[str, Tensor] = {} + for i in range(self.num_encoder_layers): + ve = self._get_ve(i, input_ids, ve_cache) + x = self.blocks[i](x, x0, v_embed=ve) + skips.append(x) + for i in range(self.num_decoder_layers): + bi = self.num_encoder_layers + i + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + ve = self._get_ve(bi, input_ids, ve_cache) + x = self.blocks[bi](x, x0, v_embed=ve) + + x = self.final_norm(x) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(x) + logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + return logits + + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + logits = self.forward_logits(input_ids) + logits = logits.reshape(-1, logits.size(-1)) + targets = target_ids.reshape(-1) + return F.cross_entropy(logits.float(), targets, reduction="mean") + + +# ----------------------------- +# TRAINING +# ----------------------------- + +def main() -> None: + global zeropower_via_newtonschulz5 + + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) + + # ----------------------------- + # DISTRIBUTED + CUDA SETUP + # ----------------------------- + + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if 8 % world_size != 0: + raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") + grad_accum_steps = 8 // world_size + grad_scale = 1.0 / grad_accum_steps + if not torch.cuda.is_available(): + raise RuntimeError("CUDA 
is required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 0 + + # Fast math knobs + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + + logfile = None + if master_process: + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" + print(logfile) + + def log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch {torch.__version__}", console=False) + log0( + subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, + console=False, + ) + log0("=" * 100, console=False) + + # ----------------------------- + # TOKENIZER + VALIDATION METRIC SETUP + # ----------------------------- + + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + 
effective_eval_seq_len = args.eval_seq_len if args.eval_seq_len > 0 else args.train_seq_len + val_seq_len = max(args.train_seq_len, effective_eval_seq_len) + val_tokens = load_validation_tokens(args.val_files, val_seq_len) + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + + # ----------------------------- + # MODEL + OPTIMIZER SETUP + # ----------------------------- + + base_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, + rand_proj_layers=args.rand_proj_layers, + use_rng_gain=args.rand_gain, + mini_moe_experts=args.mini_moe_experts, + train_seq_len=args.train_seq_len, + xsa_layers=args.xsa_layers, + rope_dims=args.rope_dims, + ln_scale=args.ln_scale, + bigram_vocab_size=args.bigram_vocab_size, + bigram_dim=args.bigram_dim, + ve_dim=args.ve_dim, + ve_layers=args.ve_layers, + ).to(device).bfloat16() + for module in base_model.modules(): + if isinstance(module, CastedLinear): + module.float() + restore_low_dim_params_to_fp32(base_model) + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) + # compiled_model = base_model + model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + + # Optimizer split: + # - token embedding (Adam) uses EMBED_LR + # - untied lm_head (Adam) uses HEAD_LR + # - matrix params in transformer blocks 
use MATRIX_LR via Muon + # - vectors/scalars use SCALAR_LR via Adam + block_named_params = list(base_model.blocks.named_parameters()) + matrix_params = [ + p + for name, p in block_named_params + if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + scalar_params = [ + p + for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) + scalar_params.append(base_model.smear.gate) + if base_model.bigram is not None: + scalar_params.append(base_model.bigram.scale) + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + tok_params = [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}] + if base_model.bigram is not None: + tok_params.append({"params": [base_model.bigram.embed.weight], "lr": token_lr, "base_lr": token_lr}) + if base_model.bigram.proj is not None: + matrix_params.append(base_model.bigram.proj.weight) + if base_model.ve_shared is not None: + tok_params.append({"params": [base_model.ve_shared.embed.weight], "lr": token_lr, "base_lr": token_lr}) + if base_model.ve_shared.proj is not None: + matrix_params.append(base_model.ve_shared.proj.weight) + scalar_params.append(base_model.ve_shared.scale) + for s in base_model.ve_layer_scales: + scalar_params.append(s) + optimizer_tok = torch.optim.AdamW( + tok_params, + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + weight_decay=args.adam_wd, + fused=True, + ) + optimizer_muon = Muon( + matrix_params, + lr=args.matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + weight_decay=args.muon_wd, + ) + for group in optimizer_muon.param_groups: + group["base_lr"] = args.matrix_lr + optimizer_scalar = torch.optim.AdamW( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + 
weight_decay=args.adam_wd, + fused=True, + ) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers.insert(1, optimizer_head) + + n_params = sum(p.numel() for p in base_model.parameters()) + log0(f"model_params:{n_params}") + log0(f"num_layers:{args.num_layers} model_dim:{args.model_dim} num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads} mlp_mult:{args.mlp_mult}") + log0(f"rand_proj_layers:{args.rand_proj_layers} rand_gain:{args.rand_gain} mini_moe_experts:{args.mini_moe_experts}") + log0(f"bigram_vocab_size:{args.bigram_vocab_size} bigram_dim:{args.bigram_dim} ve_dim:{args.ve_dim} ve_layers:{args.ve_layers}") + log0(f"tie_embeddings:{args.tie_embeddings} tied_embed_init_std:{args.tied_embed_init_std} logit_softcap:{args.logit_softcap}") + log0(f"rope_base:{args.rope_base} qk_gain_init:{args.qk_gain_init} rope_dims:{args.rope_dims} xsa_layers:{args.xsa_layers} ln_scale:{args.ln_scale}") + log0(f"xsa_layers:{args.xsa_layers} ve_layers:{args.ve_layers} ve_dim:{args.ve_dim}") + log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") + log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") + log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") + log0(f"mlp_mode:{'rng_up' if args.rand_proj_layers else 'standard'} mlp_mult:{args.mlp_mult} mini_moe_experts:{args.mini_moe_experts} use_rng_gain:{args.rand_gain}") + log0( + f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " + f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " + f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" + ) + log0( + f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " + 
f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " + f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log0(f"seed:{args.seed}") + + # ----------------------------- + # DATA LOADER & MODEL WARMUP + # ----------------------------- + + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + + def lr_mul(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + + # Warmup primes the compiled forward/backward/optimizer paths, then we restore the + # initial weights/optimizer state so measured training starts from the true init. 
+ if args.warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(args.warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + for opt in optimizers: + opt.step() + zero_grad_all() + if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: + log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") + # enable qat during warmup so we don't pay the compilation tax later + if warmup_step == 2: + CastedLinear.qat_ste = True + if warmup_step == 4: + CastedLinear.qat_ste = False + model.eval() + if warmup_step == 6: + CastedLinear.qat_ste = True + if warmup_step == 8: + model.train() + CastedLinear.qat_ste = False + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + if distributed: + model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + # ----------------------------- + # MAIN TRAINING LOOP + # ----------------------------- + + # SWA as per #414 + swa_state: dict[str, Tensor] | None = None + swa_count = 0 + + # EMA as per #414 + ema_state = {name: t.detach().float().clone() for name, t in base_model.state_dict().items()} + ema_decay = 0.997 + + val_bpb = float("inf") + + training_time_ms = 0.0 + stop_after_step: int | None = None + torch.cuda.synchronize() + t0 = 
time.perf_counter() + + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + log0( + f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" + ) + break + + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + + # late QAT as per #414 + if args.late_qat_threshold > 0 and scale < args.late_qat_threshold and not CastedLinear.qat_ste: + CastedLinear.qat_ste = True + log0(f"late_qat:enabled step:{step} scale:{scale:.4f}") + + zero_grad_all() + train_loss = torch.zeros((), device=device) + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + loss = model(x, y) + train_loss += loss.detach() + (loss * grad_scale).backward() + train_loss /= grad_accum_steps + + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for 
group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + for opt in optimizers: + opt.step() + zero_grad_all() + + # EMA update as per #414 + with torch.no_grad(): + for name, t in base_model.state_dict().items(): + ema_state[name].mul_(ema_decay).add_(t.detach().float(), alpha=1.0 - ema_decay) + + step += 1 + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + + # SWA as per #414 + if args.swa_enabled and scale < 0.2 and step % args.swa_every == 0: + if swa_state is None: + swa_state = {name: t.detach().cpu().clone() for name, t in base_model.state_dict().items()} + swa_count = 1 + log0(f"swa:start step:{step}") + else: + for name, t in base_model.state_dict().items(): + swa_state[name] += t.detach().cpu() + swa_count += 1 + + should_log_train = ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + log0( + f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" + ) + + # Needed to sync whether we've reached the wallclock cap. 
reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + + log0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + ) + + # Apply EMA weights (better than SWA alone per PR#401) + log0("ema:applying EMA weights") + current_state = base_model.state_dict() + avg_state = {name: t.to(dtype=current_state[name].dtype) for name, t in ema_state.items()} + base_model.load_state_dict(avg_state, strict=True) + torch.cuda.synchronize() + t_diag = time.perf_counter() + diag_val_loss, diag_val_bpb = eval_val( + args, compiled_model, rank, world_size, device, grad_accum_steps, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"DIAGNOSTIC post_ema val_loss:{diag_val_loss:.4f} val_bpb:{diag_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_diag):.0f}ms" + ) + + if diag_val_bpb > val_bpb: + log0( + f"EMA did not improve val_bpb ({val_bpb:.4f} -> {diag_val_bpb:.4f}), restoring pre-EMA weights for final serialization" + ) + base_model.load_state_dict(current_state, strict=True) + else: + log0( + f"EMA improved val_bpb ({val_bpb:.4f} -> {diag_val_bpb:.4f}), keeping EMA weights for final serialization" + ) + + # ----------------------------- + # SERIALIZATION + ROUNDTRIP VALIDATION + # ----------------------------- + # Save the raw state (useful for debugging/loading in PyTorch directly), then always produce + # the compressed int6 artifact (zstd or zlib, per _COMPRESSOR) and validate the round-tripped weights.
+ + if master_process: + torch.save(base_model.state_dict(), "final_model.pt") + model_bytes = os.path.getsize("final_model.pt") + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model: {model_bytes} bytes") + log0(f"Code size: {code_bytes} bytes") + log0(f"Total submission size: {model_bytes + code_bytes} bytes") + + # --- quant --- + sd_cpu = {k: v.cpu() for k, v in base_model.state_dict().items()} + # TODO: think a/b keeping routers in fp32 or higher? + quant_result, quant_meta = mixed_quantize_int6(sd_cpu, {"mlp", "attn"}) + quant_buf = io.BytesIO() + torch.save({"w": quant_result, "m": quant_meta}, quant_buf) + quant_raw = quant_buf.getvalue() + quant_blob = zstandard.ZstdCompressor(level=22).compress(quant_raw) if _COMPRESSOR == "zstd" else zlib.compress(quant_raw, 9) + # quant_raw_bytes = len(quant_raw) + if master_process: + with open("final_model.int6.ptz", "wb") as f: + f.write(quant_blob) + quant_file_bytes = len(quant_blob) + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model int6+{_COMPRESSOR}: {quant_file_bytes} bytes") + log0(f"Total submission size int6+{_COMPRESSOR}: {quant_file_bytes + code_bytes} bytes") + log0(f"Total submission size int8+zlib: {quant_file_bytes + code_bytes} bytes") + + if distributed: + dist.barrier() + with open("final_model.int6.ptz", "rb") as f: + quant_blob_disk = f.read() + quant_state = torch.load( + io.BytesIO(zstandard.ZstdDecompressor().decompress(quant_blob_disk) if _COMPRESSOR == "zstd" else zlib.decompress(quant_blob_disk)), + map_location="cpu", + ) + deq_state = dequantize_mixed_int6(quant_state["w"], quant_state["m"], sd_cpu) + eval_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + 
qk_gain_init=args.qk_gain_init, + rand_proj_layers=args.rand_proj_layers, + use_rng_gain=args.rand_gain, + mini_moe_experts=args.mini_moe_experts, + train_seq_len=args.train_seq_len, + xsa_layers=args.xsa_layers, + rope_dims=args.rope_dims, + ln_scale=args.ln_scale, + bigram_vocab_size=args.bigram_vocab_size, + bigram_dim=args.bigram_dim, + ve_dim=args.ve_dim, + ve_layers=args.ve_layers, + ).to(device).bfloat16() + for m in eval_model.modules(): + if isinstance(m, CastedLinear): + m.float() + restore_low_dim_params_to_fp32(eval_model) + eval_model.load_state_dict(deq_state, strict=True) + compiled_eval = torch.compile(eval_model, dynamic=False, fullgraph=True) + torch.cuda.synchronize() + t_qeval = time.perf_counter() + q_val_loss, q_val_bpb = eval_val( + args, + compiled_eval, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"final_int6_zlib_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" + ) + log0(f"final_int6_zlib_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + + # --- sliding window + TTT eval --- + + sw_seq_len = effective_eval_seq_len + if args.eval_stride > 0 and args.eval_stride < sw_seq_len: + torch.cuda.synchronize() + t_slide = time.perf_counter() + sw_val_loss, sw_val_bpb = eval_val_sliding( + args, eval_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride=args.eval_stride, + eval_seq_len=sw_seq_len, + ) + torch.cuda.synchronize() + log0( + f"final_int6_sliding_window val_loss:{sw_val_loss:.4f} val_bpb:{sw_val_bpb:.4f} " + f"stride:{args.eval_stride} eval_time:{1000.0 * (time.perf_counter() - t_slide):.0f}ms" + ) + log0(f"final_int6_sliding_window_exact val_loss:{sw_val_loss:.8f} val_bpb:{sw_val_bpb:.8f}") + log0(f"final_int8_zlib_roundtrip_exact 
val_loss:{sw_val_loss:.8f} val_bpb:{sw_val_bpb:.8f}") + if args.eval_stride != 64 and 64 < sw_seq_len: + torch.cuda.synchronize() + t_slide64 = time.perf_counter() + sw64_val_loss, sw64_val_bpb = eval_val_sliding( + args, eval_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride=64, + eval_seq_len=sw_seq_len, + ) + torch.cuda.synchronize() + log0( + f"final_int6_sliding_window_s64 val_loss:{sw64_val_loss:.4f} val_bpb:{sw64_val_bpb:.4f} " + f"stride:64 eval_time:{1000.0 * (time.perf_counter() - t_slide64):.0f}ms" + ) + log0(f"final_int6_sliding_window_s64_exact val_loss:{sw64_val_loss:.8f} val_bpb:{sw64_val_bpb:.8f}") + log0(f"final_int8_zlib_roundtrip_exact val_loss:{sw64_val_loss:.8f} val_bpb:{sw64_val_bpb:.8f}") + # Legal score-first TTT (PR #461 recipe) + if args.ttt_enabled: + # disable QAT ste for TTT + CastedLinear.qat_ste = False + torch.cuda.synchronize() + t_ttt = time.perf_counter() + ttt_loss, ttt_bpb = eval_val_sliding_ttt( + args, eval_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride=args.eval_stride, log0=log0, + ) + torch.cuda.synchronize() + log0(f"legal_ttt val_loss:{ttt_loss:.4f} val_bpb:{ttt_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_ttt):.0f}ms") + log0(f"legal_ttt_exact val_loss:{ttt_loss:.8f} val_bpb:{ttt_bpb:.8f}") + + if distributed: + dist.destroy_process_group() + + +if __name__ == "__main__": + main() + +==================================================================================================== +Running Python 3.12.3 (main, Nov 6 2025, 13:44:16) [GCC 13.3.0] +Running PyTorch 2.9.1+cu128 +Mon Mar 30 19:51:36 2026 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 580.126.09 Driver Version: 580.126.09 CUDA Version: 13.0 | ++-----------------------------------------+------------------------+----------------------+ +| GPU Name 
Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. | +|=========================================+========================+======================| +| 0 NVIDIA H100 80GB HBM3 On | 00000000:19:00.0 Off | 0 | +| N/A 34C P0 116W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 1 NVIDIA H100 80GB HBM3 On | 00000000:3B:00.0 Off | 0 | +| N/A 31C P0 116W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 2 NVIDIA H100 80GB HBM3 On | 00000000:4C:00.0 Off | 0 | +| N/A 32C P0 119W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 3 NVIDIA H100 80GB HBM3 On | 00000000:5D:00.0 Off | 0 | +| N/A 33C P0 117W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 4 NVIDIA H100 80GB HBM3 On | 00000000:9B:00.0 Off | 0 | +| N/A 36C P0 118W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 5 NVIDIA H100 80GB HBM3 On | 00000000:BB:00.0 Off | 0 | +| N/A 32C P0 117W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 6 NVIDIA H100 80GB HBM3 On | 00000000:CB:00.0 Off | 0 | +| N/A 33C P0 121W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 7 NVIDIA H100 80GB HBM3 On | 00000000:DB:00.0 Off | 0 | +| N/A 31C P0 116W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | 
++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| No running processes found | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:10 +val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 +model_params:25431141 +num_layers:12 model_dim:512 num_heads:8 num_kv_heads:4 mlp_mult:3 +rand_proj_layers:[0, 1, 2, 3, 4] rand_gain:True mini_moe_experts:1 +bigram_vocab_size:2048 bigram_dim:128 ve_dim:128 ve_layers:[9, 10, 11] +tie_embeddings:True tied_embed_init_std:0.005 logit_softcap:30.0 +rope_base:10000.0 qk_gain_init:1.5 rope_dims:16 xsa_layers:[8, 9, 10, 11] ln_scale:True +xsa_layers:[8, 9, 10, 11] ve_layers:[9, 10, 11] ve_dim:128 +world_size:8 grad_accum_steps:1 +sdp_backends:cudnn=False flash=True mem_efficient=False math=False +attention_mode:gqa num_heads:8 num_kv_heads:4 +mlp_mode:rng_up mlp_mult:3 mini_moe_experts:1 use_rng_gain:True +tie_embeddings:True embed_lr:0.035 head_lr:0.0 matrix_lr:0.025 scalar_lr:0.025 +train_batch_tokens:786432 train_seq_len:2048 iterations:20000 warmup_steps:20 max_wallclock_seconds:600.000 +seed:1337 +warmup_step:1/20 +warmup_step:2/20 +warmup_step:3/20 +warmup_step:4/20 +warmup_step:5/20 +warmup_step:6/20 +warmup_step:7/20 +warmup_step:8/20 +warmup_step:9/20 +warmup_step:10/20 +warmup_step:11/20 +warmup_step:12/20 +warmup_step:13/20 +warmup_step:14/20 +warmup_step:15/20 +warmup_step:16/20 
+warmup_step:17/20 +warmup_step:18/20 +warmup_step:19/20 +warmup_step:20/20 +step:1/20000 train_loss:6.9321 train_time:233ms step_avg:233.49ms +step:2/20000 train_loss:8.5083 train_time:354ms step_avg:176.84ms +step:3/20000 train_loss:8.0173 train_time:474ms step_avg:157.84ms +step:4/20000 train_loss:7.1221 train_time:594ms step_avg:148.38ms +step:5/20000 train_loss:6.9875 train_time:713ms step_avg:142.52ms +step:6/20000 train_loss:6.8634 train_time:832ms step_avg:138.68ms +step:7/20000 train_loss:6.7422 train_time:951ms step_avg:135.91ms +step:8/20000 train_loss:6.7197 train_time:1071ms step_avg:133.89ms +step:9/20000 train_loss:6.4320 train_time:1191ms step_avg:132.33ms +step:10/20000 train_loss:6.1091 train_time:1310ms step_avg:131.03ms +step:20/20000 train_loss:4.9422 train_time:2508ms step_avg:125.40ms +step:30/20000 train_loss:4.1990 train_time:3701ms step_avg:123.38ms +step:40/20000 train_loss:3.8905 train_time:4904ms step_avg:122.61ms +step:50/20000 train_loss:3.6870 train_time:6099ms step_avg:121.98ms +step:60/20000 train_loss:3.4308 train_time:7295ms step_avg:121.58ms +step:70/20000 train_loss:3.4507 train_time:8491ms step_avg:121.30ms +step:80/20000 train_loss:3.3128 train_time:9699ms step_avg:121.24ms +step:90/20000 train_loss:3.1735 train_time:10896ms step_avg:121.06ms +step:100/20000 train_loss:3.1978 train_time:12091ms step_avg:120.91ms +step:110/20000 train_loss:3.0955 train_time:13296ms step_avg:120.87ms +step:120/20000 train_loss:3.1315 train_time:14491ms step_avg:120.76ms +step:130/20000 train_loss:3.0049 train_time:15748ms step_avg:121.14ms +step:140/20000 train_loss:2.9982 train_time:16955ms step_avg:121.11ms +step:150/20000 train_loss:2.9293 train_time:18153ms step_avg:121.02ms +step:160/20000 train_loss:2.9073 train_time:19358ms step_avg:120.99ms +step:170/20000 train_loss:2.7222 train_time:20556ms step_avg:120.92ms +step:180/20000 train_loss:2.7521 train_time:21763ms step_avg:120.90ms +step:190/20000 train_loss:2.7224 train_time:22958ms 
step_avg:120.83ms
step:200/20000 train_loss:2.4021 train_time:24158ms step_avg:120.79ms
step:210/20000 train_loss:2.6482 train_time:25354ms step_avg:120.73ms
step:220/20000 train_loss:2.6830 train_time:26550ms step_avg:120.68ms
step:230/20000 train_loss:2.6344 train_time:27756ms step_avg:120.68ms
step:240/20000 train_loss:2.6085 train_time:28953ms step_avg:120.64ms
step:250/20000 train_loss:2.4635 train_time:30158ms step_avg:120.63ms
step:260/20000 train_loss:2.5475 train_time:31426ms step_avg:120.87ms
step:270/20000 train_loss:2.4509 train_time:32619ms step_avg:120.81ms
step:280/20000 train_loss:2.5965 train_time:33813ms step_avg:120.76ms
step:290/20000 train_loss:2.4859 train_time:35025ms step_avg:120.78ms
step:300/20000 train_loss:2.5389 train_time:36220ms step_avg:120.73ms
step:310/20000 train_loss:2.5535 train_time:37413ms step_avg:120.69ms
step:320/20000 train_loss:2.4611 train_time:38611ms step_avg:120.66ms
step:330/20000 train_loss:2.4752 train_time:39807ms step_avg:120.63ms
step:340/20000 train_loss:2.5767 train_time:41002ms step_avg:120.59ms
step:350/20000 train_loss:2.5290 train_time:42198ms step_avg:120.57ms
step:360/20000 train_loss:2.2990 train_time:43404ms step_avg:120.57ms
step:370/20000 train_loss:2.5048 train_time:44601ms step_avg:120.54ms
step:380/20000 train_loss:2.3640 train_time:45808ms step_avg:120.55ms
step:390/20000 train_loss:2.4907 train_time:47073ms step_avg:120.70ms
step:400/20000 train_loss:2.4037 train_time:48282ms step_avg:120.71ms
step:410/20000 train_loss:2.3951 train_time:49488ms step_avg:120.70ms
step:420/20000 train_loss:2.3410 train_time:50685ms step_avg:120.68ms
step:430/20000 train_loss:2.3470 train_time:51887ms step_avg:120.67ms
step:440/20000 train_loss:2.3927 train_time:53084ms step_avg:120.65ms
step:450/20000 train_loss:2.3560 train_time:54291ms step_avg:120.65ms
step:460/20000 train_loss:2.4055 train_time:55500ms step_avg:120.65ms
step:470/20000 train_loss:2.4003 train_time:56700ms step_avg:120.64ms
step:480/20000 train_loss:2.2338 train_time:57899ms step_avg:120.62ms
step:490/20000 train_loss:2.3907 train_time:59107ms step_avg:120.63ms
step:500/20000 train_loss:2.3915 train_time:60305ms step_avg:120.61ms
step:510/20000 train_loss:2.3067 train_time:61570ms step_avg:120.73ms
step:520/20000 train_loss:2.3492 train_time:62769ms step_avg:120.71ms
step:530/20000 train_loss:2.2407 train_time:63973ms step_avg:120.70ms
step:540/20000 train_loss:2.3169 train_time:65174ms step_avg:120.69ms
step:550/20000 train_loss:2.3336 train_time:66383ms step_avg:120.70ms
step:560/20000 train_loss:2.3328 train_time:67580ms step_avg:120.68ms
step:570/20000 train_loss:2.3812 train_time:68785ms step_avg:120.68ms
step:580/20000 train_loss:2.4266 train_time:69986ms step_avg:120.67ms
step:590/20000 train_loss:2.2543 train_time:71194ms step_avg:120.67ms
step:600/20000 train_loss:2.3283 train_time:72394ms step_avg:120.66ms
step:610/20000 train_loss:2.3295 train_time:73603ms step_avg:120.66ms
step:620/20000 train_loss:2.3056 train_time:74801ms step_avg:120.65ms
step:630/20000 train_loss:2.2860 train_time:76001ms step_avg:120.64ms
step:640/20000 train_loss:2.8208 train_time:77266ms step_avg:120.73ms
step:650/20000 train_loss:2.3294 train_time:78473ms step_avg:120.73ms
step:660/20000 train_loss:2.5264 train_time:79673ms step_avg:120.72ms
step:670/20000 train_loss:2.2652 train_time:80872ms step_avg:120.70ms
step:680/20000 train_loss:2.2623 train_time:82071ms step_avg:120.69ms
step:690/20000 train_loss:2.2902 train_time:83273ms step_avg:120.69ms
step:700/20000 train_loss:2.3454 train_time:84473ms step_avg:120.68ms
step:710/20000 train_loss:2.2834 train_time:85680ms step_avg:120.68ms
step:720/20000 train_loss:2.3744 train_time:86878ms step_avg:120.66ms
step:730/20000 train_loss:2.1655 train_time:88084ms step_avg:120.66ms
step:740/20000 train_loss:2.2663 train_time:89292ms step_avg:120.66ms
step:750/20000 train_loss:2.3274 train_time:90491ms step_avg:120.65ms
step:760/20000 train_loss:2.4145 train_time:91693ms step_avg:120.65ms
step:770/20000 train_loss:2.2813 train_time:92963ms step_avg:120.73ms
step:780/20000 train_loss:2.2392 train_time:94172ms step_avg:120.73ms
step:790/20000 train_loss:2.1940 train_time:95372ms step_avg:120.72ms
step:800/20000 train_loss:2.2404 train_time:96573ms step_avg:120.72ms
step:810/20000 train_loss:2.1933 train_time:97772ms step_avg:120.71ms
step:820/20000 train_loss:2.2355 train_time:98971ms step_avg:120.70ms
step:830/20000 train_loss:2.1913 train_time:100170ms step_avg:120.69ms
step:840/20000 train_loss:2.3435 train_time:101370ms step_avg:120.68ms
step:850/20000 train_loss:2.2339 train_time:102569ms step_avg:120.67ms
step:860/20000 train_loss:2.0451 train_time:103767ms step_avg:120.66ms
step:870/20000 train_loss:2.2862 train_time:104976ms step_avg:120.66ms
step:880/20000 train_loss:2.2235 train_time:106175ms step_avg:120.65ms
step:890/20000 train_loss:2.2499 train_time:107384ms step_avg:120.66ms
step:900/20000 train_loss:2.1290 train_time:108658ms step_avg:120.73ms
step:910/20000 train_loss:2.1949 train_time:109857ms step_avg:120.72ms
step:920/20000 train_loss:2.2292 train_time:111057ms step_avg:120.71ms
step:930/20000 train_loss:2.2812 train_time:112266ms step_avg:120.72ms
step:940/20000 train_loss:2.4211 train_time:113464ms step_avg:120.71ms
step:950/20000 train_loss:2.2228 train_time:114662ms step_avg:120.70ms
step:960/20000 train_loss:2.1699 train_time:115872ms step_avg:120.70ms
step:970/20000 train_loss:2.4006 train_time:117072ms step_avg:120.69ms
step:980/20000 train_loss:2.2212 train_time:118273ms step_avg:120.69ms
step:990/20000 train_loss:2.2542 train_time:119481ms step_avg:120.69ms
step:1000/20000 train_loss:2.2793 train_time:120680ms step_avg:120.68ms
step:1010/20000 train_loss:2.0561 train_time:121879ms step_avg:120.67ms
step:1020/20000 train_loss:2.0969 train_time:123155ms step_avg:120.74ms
step:1030/20000 train_loss:2.2064 train_time:124360ms step_avg:120.74ms
step:1040/20000 train_loss:2.2515 train_time:125560ms step_avg:120.73ms
step:1050/20000 train_loss:2.2242 train_time:126759ms step_avg:120.72ms
step:1060/20000 train_loss:2.3277 train_time:127959ms step_avg:120.72ms
step:1070/20000 train_loss:2.2484 train_time:129160ms step_avg:120.71ms
step:1080/20000 train_loss:2.2077 train_time:130361ms step_avg:120.70ms
step:1090/20000 train_loss:2.1688 train_time:131561ms step_avg:120.70ms
step:1100/20000 train_loss:2.3262 train_time:132762ms step_avg:120.69ms
step:1110/20000 train_loss:2.2356 train_time:133962ms step_avg:120.69ms
step:1120/20000 train_loss:2.0995 train_time:135163ms step_avg:120.68ms
step:1130/20000 train_loss:2.3416 train_time:136363ms step_avg:120.67ms
step:1140/20000 train_loss:2.2512 train_time:137563ms step_avg:120.67ms
step:1150/20000 train_loss:2.2515 train_time:138832ms step_avg:120.72ms
step:1160/20000 train_loss:2.1469 train_time:140035ms step_avg:120.72ms
step:1170/20000 train_loss:2.1791 train_time:141237ms step_avg:120.72ms
step:1180/20000 train_loss:2.1436 train_time:142448ms step_avg:120.72ms
step:1190/20000 train_loss:2.2480 train_time:143648ms step_avg:120.71ms
step:1200/20000 train_loss:2.3574 train_time:144849ms step_avg:120.71ms
step:1210/20000 train_loss:2.2035 train_time:146051ms step_avg:120.70ms
step:1220/20000 train_loss:2.2319 train_time:147257ms step_avg:120.70ms
step:1230/20000 train_loss:2.1670 train_time:148458ms step_avg:120.70ms
step:1240/20000 train_loss:2.1669 train_time:149659ms step_avg:120.69ms
step:1250/20000 train_loss:2.2555 train_time:150868ms step_avg:120.69ms
step:1260/20000 train_loss:2.1738 train_time:152070ms step_avg:120.69ms
step:1270/20000 train_loss:2.1886 train_time:153272ms step_avg:120.69ms
step:1280/20000 train_loss:2.1415 train_time:154539ms step_avg:120.73ms
step:1290/20000 train_loss:2.1619 train_time:155747ms step_avg:120.73ms
step:1300/20000 train_loss:2.3647 train_time:156948ms step_avg:120.73ms
step:1310/20000 train_loss:2.0310 train_time:158149ms step_avg:120.72ms
step:1320/20000 train_loss:2.1884 train_time:159359ms step_avg:120.73ms
step:1330/20000 train_loss:2.1647 train_time:160567ms step_avg:120.73ms
step:1340/20000 train_loss:2.2551 train_time:161775ms step_avg:120.73ms
step:1350/20000 train_loss:2.1778 train_time:162976ms step_avg:120.72ms
step:1360/20000 train_loss:2.2520 train_time:164186ms step_avg:120.72ms
step:1370/20000 train_loss:2.2867 train_time:165396ms step_avg:120.73ms
step:1380/20000 train_loss:2.1020 train_time:166596ms step_avg:120.72ms
step:1390/20000 train_loss:2.1518 train_time:167811ms step_avg:120.73ms
step:1400/20000 train_loss:2.2208 train_time:169072ms step_avg:120.77ms
step:1410/20000 train_loss:2.1803 train_time:170275ms step_avg:120.76ms
step:1420/20000 train_loss:2.1281 train_time:171480ms step_avg:120.76ms
step:1430/20000 train_loss:2.1271 train_time:172681ms step_avg:120.76ms
step:1440/20000 train_loss:2.2142 train_time:173893ms step_avg:120.76ms
step:1450/20000 train_loss:2.2020 train_time:175094ms step_avg:120.75ms
step:1460/20000 train_loss:2.1460 train_time:176293ms step_avg:120.75ms
step:1470/20000 train_loss:2.2745 train_time:177493ms step_avg:120.74ms
step:1480/20000 train_loss:2.1214 train_time:178697ms step_avg:120.74ms
step:1490/20000 train_loss:2.1815 train_time:179904ms step_avg:120.74ms
step:1500/20000 train_loss:2.1952 train_time:181101ms step_avg:120.73ms
step:1510/20000 train_loss:2.3477 train_time:182300ms step_avg:120.73ms
step:1520/20000 train_loss:1.9990 train_time:183498ms step_avg:120.72ms
step:1530/20000 train_loss:2.0283 train_time:184769ms step_avg:120.76ms
step:1540/20000 train_loss:2.1454 train_time:185969ms step_avg:120.76ms
step:1550/20000 train_loss:2.1859 train_time:187170ms step_avg:120.75ms
step:1560/20000 train_loss:2.2475 train_time:188381ms step_avg:120.76ms
step:1570/20000 train_loss:2.1640 train_time:189580ms step_avg:120.75ms
step:1580/20000 train_loss:2.0497 train_time:190780ms step_avg:120.75ms
step:1590/20000 train_loss:2.1348 train_time:191991ms step_avg:120.75ms
step:1600/20000 train_loss:2.1949 train_time:193198ms step_avg:120.75ms
step:1610/20000 train_loss:2.2096 train_time:194398ms step_avg:120.74ms
step:1620/20000 train_loss:2.1812 train_time:195603ms step_avg:120.74ms
step:1630/20000 train_loss:2.3915 train_time:196803ms step_avg:120.74ms
step:1640/20000 train_loss:2.2098 train_time:198014ms step_avg:120.74ms
step:1650/20000 train_loss:1.9914 train_time:199214ms step_avg:120.74ms
step:1660/20000 train_loss:2.2167 train_time:200482ms step_avg:120.77ms
step:1670/20000 train_loss:1.9825 train_time:201691ms step_avg:120.77ms
step:1680/20000 train_loss:2.1957 train_time:202893ms step_avg:120.77ms
step:1690/20000 train_loss:2.1899 train_time:204095ms step_avg:120.77ms
step:1700/20000 train_loss:2.1870 train_time:205296ms step_avg:120.76ms
step:1710/20000 train_loss:2.2445 train_time:206497ms step_avg:120.76ms
step:1720/20000 train_loss:2.2613 train_time:207707ms step_avg:120.76ms
step:1730/20000 train_loss:2.1974 train_time:208907ms step_avg:120.76ms
step:1740/20000 train_loss:2.1481 train_time:210106ms step_avg:120.75ms
step:1750/20000 train_loss:2.1213 train_time:211307ms step_avg:120.75ms
step:1760/20000 train_loss:2.1146 train_time:212506ms step_avg:120.74ms
step:1770/20000 train_loss:2.1287 train_time:213711ms step_avg:120.74ms
step:1780/20000 train_loss:2.0292 train_time:214912ms step_avg:120.74ms
step:1790/20000 train_loss:2.1666 train_time:216179ms step_avg:120.77ms
step:1800/20000 train_loss:2.1349 train_time:217389ms step_avg:120.77ms
step:1810/20000 train_loss:2.0713 train_time:218589ms step_avg:120.77ms
step:1820/20000 train_loss:2.0272 train_time:219789ms step_avg:120.76ms
step:1830/20000 train_loss:2.1589 train_time:220999ms step_avg:120.76ms
step:1840/20000 train_loss:2.2459 train_time:222210ms step_avg:120.77ms
step:1850/20000 train_loss:2.1501 train_time:223410ms step_avg:120.76ms
step:1860/20000 train_loss:2.0500 train_time:224614ms step_avg:120.76ms
step:1870/20000 train_loss:2.1037 train_time:225815ms step_avg:120.76ms
step:1880/20000 train_loss:2.0506 train_time:227015ms step_avg:120.75ms
step:1890/20000 train_loss:2.1249 train_time:228223ms step_avg:120.75ms
step:1900/20000 train_loss:2.1889 train_time:229433ms step_avg:120.75ms
step:1910/20000 train_loss:2.1674 train_time:230703ms step_avg:120.79ms
step:1920/20000 train_loss:2.1676 train_time:231903ms step_avg:120.78ms
step:1930/20000 train_loss:2.2080 train_time:233106ms step_avg:120.78ms
step:1940/20000 train_loss:2.0890 train_time:234306ms step_avg:120.78ms
step:1950/20000 train_loss:2.1233 train_time:235504ms step_avg:120.77ms
step:1960/20000 train_loss:2.0426 train_time:236714ms step_avg:120.77ms
step:1970/20000 train_loss:2.0665 train_time:237915ms step_avg:120.77ms
step:1980/20000 train_loss:2.1531 train_time:239116ms step_avg:120.77ms
step:1990/20000 train_loss:2.1106 train_time:240317ms step_avg:120.76ms
step:2000/20000 train_loss:2.1593 train_time:241527ms step_avg:120.76ms
step:2010/20000 train_loss:1.9517 train_time:242725ms step_avg:120.76ms
step:2020/20000 train_loss:2.1353 train_time:243925ms step_avg:120.75ms
step:2030/20000 train_loss:2.0430 train_time:245124ms step_avg:120.75ms
step:2040/20000 train_loss:2.2523 train_time:246398ms step_avg:120.78ms
step:2050/20000 train_loss:2.0840 train_time:247597ms step_avg:120.78ms
step:2060/20000 train_loss:2.1210 train_time:248797ms step_avg:120.78ms
step:2070/20000 train_loss:2.0760 train_time:249997ms step_avg:120.77ms
step:2080/20000 train_loss:2.1092 train_time:251204ms step_avg:120.77ms
step:2090/20000 train_loss:2.1557 train_time:252410ms step_avg:120.77ms
step:2100/20000 train_loss:2.0548 train_time:253606ms step_avg:120.76ms
step:2110/20000 train_loss:2.1472 train_time:254807ms step_avg:120.76ms
step:2120/20000 train_loss:2.1271 train_time:256008ms step_avg:120.76ms
step:2130/20000 train_loss:2.0711 train_time:257204ms step_avg:120.75ms
step:2140/20000 train_loss:2.1831 train_time:258400ms step_avg:120.75ms
step:2150/20000 train_loss:2.0276 train_time:259597ms step_avg:120.74ms
step:2160/20000 train_loss:2.0716 train_time:260803ms step_avg:120.74ms
step:2170/20000 train_loss:1.9844 train_time:262072ms step_avg:120.77ms
step:2180/20000 train_loss:2.1796 train_time:263273ms step_avg:120.77ms
step:2190/20000 train_loss:2.1989 train_time:264473ms step_avg:120.76ms
step:2200/20000 train_loss:2.1564 train_time:265683ms step_avg:120.76ms
step:2210/20000 train_loss:2.1156 train_time:266882ms step_avg:120.76ms
step:2220/20000 train_loss:1.8772 train_time:268082ms step_avg:120.76ms
step:2230/20000 train_loss:2.1177 train_time:269288ms step_avg:120.76ms
step:2240/20000 train_loss:1.8778 train_time:270489ms step_avg:120.75ms
step:2250/20000 train_loss:2.0941 train_time:271689ms step_avg:120.75ms
step:2260/20000 train_loss:2.1127 train_time:272890ms step_avg:120.75ms
step:2270/20000 train_loss:2.1667 train_time:274136ms step_avg:120.76ms
step:2280/20000 train_loss:2.1199 train_time:275339ms step_avg:120.76ms
step:2290/20000 train_loss:2.1768 train_time:276616ms step_avg:120.79ms
step:2300/20000 train_loss:2.0908 train_time:277825ms step_avg:120.79ms
step:2310/20000 train_loss:2.1151 train_time:279034ms step_avg:120.79ms
step:2320/20000 train_loss:2.0250 train_time:280236ms step_avg:120.79ms
step:2330/20000 train_loss:2.1939 train_time:281442ms step_avg:120.79ms
step:2340/20000 train_loss:2.0568 train_time:282653ms step_avg:120.79ms
step:2350/20000 train_loss:2.1278 train_time:283856ms step_avg:120.79ms
step:2360/20000 train_loss:2.0949 train_time:285057ms step_avg:120.79ms
step:2370/20000 train_loss:2.0273 train_time:286266ms step_avg:120.79ms
step:2380/20000 train_loss:2.1292 train_time:287468ms step_avg:120.78ms
step:2390/20000 train_loss:2.1119 train_time:288664ms step_avg:120.78ms
step:2400/20000 train_loss:2.1522 train_time:289862ms step_avg:120.78ms
step:2410/20000 train_loss:2.0631 train_time:291066ms step_avg:120.77ms
step:2420/20000 train_loss:2.0724 train_time:292330ms step_avg:120.80ms
step:2430/20000 train_loss:3.2393 train_time:293531ms step_avg:120.79ms
step:2440/20000 train_loss:2.0423 train_time:294739ms step_avg:120.79ms
step:2450/20000 train_loss:2.1341 train_time:295940ms step_avg:120.79ms
step:2460/20000 train_loss:2.1500 train_time:297138ms step_avg:120.79ms
step:2470/20000 train_loss:2.0909 train_time:298344ms step_avg:120.79ms
step:2480/20000 train_loss:2.0364 train_time:299543ms step_avg:120.78ms
step:2490/20000 train_loss:2.0531 train_time:300738ms step_avg:120.78ms
step:2500/20000 train_loss:2.0432 train_time:301933ms step_avg:120.77ms
step:2510/20000 train_loss:1.9904 train_time:303141ms step_avg:120.77ms
step:2520/20000 train_loss:2.1164 train_time:304335ms step_avg:120.77ms
step:2530/20000 train_loss:2.0198 train_time:305532ms step_avg:120.76ms
step:2540/20000 train_loss:2.0499 train_time:306734ms step_avg:120.76ms
step:2550/20000 train_loss:2.1298 train_time:307994ms step_avg:120.78ms
step:2560/20000 train_loss:2.1002 train_time:309191ms step_avg:120.78ms
step:2570/20000 train_loss:2.0456 train_time:310385ms step_avg:120.77ms
step:2580/20000 train_loss:2.1641 train_time:311580ms step_avg:120.77ms
step:2590/20000 train_loss:2.1015 train_time:312785ms step_avg:120.77ms
step:2600/20000 train_loss:2.1248 train_time:313979ms step_avg:120.76ms
step:2610/20000 train_loss:2.1477 train_time:315175ms step_avg:120.76ms
step:2620/20000 train_loss:2.0727 train_time:316381ms step_avg:120.76ms
step:2630/20000 train_loss:2.3665 train_time:317578ms step_avg:120.75ms
step:2640/20000 train_loss:2.0290 train_time:318787ms step_avg:120.75ms
step:2650/20000 train_loss:2.0423 train_time:319983ms step_avg:120.75ms
step:2660/20000 train_loss:2.0266 train_time:321189ms step_avg:120.75ms
step:2670/20000 train_loss:2.1560 train_time:322387ms step_avg:120.74ms
step:2680/20000 train_loss:1.9302 train_time:323654ms step_avg:120.77ms
step:2690/20000 train_loss:2.2089 train_time:324860ms step_avg:120.77ms
step:2700/20000 train_loss:2.0759 train_time:326054ms step_avg:120.76ms
step:2710/20000 train_loss:2.0536 train_time:327249ms step_avg:120.76ms
step:2720/20000 train_loss:2.0935 train_time:328451ms step_avg:120.75ms
step:2730/20000 train_loss:2.0412 train_time:329652ms step_avg:120.75ms
step:2740/20000 train_loss:2.0890 train_time:330856ms step_avg:120.75ms
step:2750/20000 train_loss:2.1152 train_time:332058ms step_avg:120.75ms
step:2760/20000 train_loss:2.0861 train_time:333263ms step_avg:120.75ms
step:2770/20000 train_loss:2.0097 train_time:334457ms step_avg:120.74ms
step:2780/20000 train_loss:2.3655 train_time:335659ms step_avg:120.74ms
step:2790/20000 train_loss:2.0154 train_time:336854ms step_avg:120.74ms
step:2800/20000 train_loss:2.0924 train_time:338121ms step_avg:120.76ms
step:2810/20000 train_loss:2.0220 train_time:339322ms step_avg:120.76ms
step:2820/20000 train_loss:2.1150 train_time:340523ms step_avg:120.75ms
step:2830/20000 train_loss:1.9728 train_time:341721ms step_avg:120.75ms
step:2840/20000 train_loss:2.0381 train_time:342917ms step_avg:120.75ms
step:2850/20000 train_loss:2.0900 train_time:344119ms step_avg:120.74ms
step:2860/20000 train_loss:2.0946 train_time:345315ms step_avg:120.74ms
step:2870/20000 train_loss:2.0380 train_time:346515ms step_avg:120.74ms
step:2880/20000 train_loss:2.0377 train_time:347720ms step_avg:120.74ms
step:2890/20000 train_loss:2.1454 train_time:348917ms step_avg:120.73ms
step:2900/20000 train_loss:2.0253 train_time:350131ms step_avg:120.73ms
step:2910/20000 train_loss:2.0836 train_time:351358ms step_avg:120.74ms
step:2920/20000 train_loss:2.0936 train_time:352585ms step_avg:120.75ms
step:2930/20000 train_loss:2.0827 train_time:353847ms step_avg:120.77ms
step:2940/20000 train_loss:1.8942 train_time:355040ms step_avg:120.76ms
step:2950/20000 train_loss:2.1026 train_time:356258ms step_avg:120.77ms
step:2960/20000 train_loss:2.0532 train_time:357455ms step_avg:120.76ms
step:2970/20000 train_loss:2.0006 train_time:358661ms step_avg:120.76ms
step:2980/20000 train_loss:2.0457 train_time:359857ms step_avg:120.76ms
step:2990/20000 train_loss:2.1263 train_time:361054ms step_avg:120.75ms
step:3000/20000 train_loss:2.1075 train_time:362261ms step_avg:120.75ms
step:3010/20000 train_loss:2.0638 train_time:363455ms step_avg:120.75ms
step:3020/20000 train_loss:2.0169 train_time:364661ms step_avg:120.75ms
step:3030/20000 train_loss:1.9519 train_time:365860ms step_avg:120.75ms
step:3040/20000 train_loss:2.0419 train_time:367063ms step_avg:120.74ms
step:3050/20000 train_loss:2.0399 train_time:368259ms step_avg:120.74ms
step:3060/20000 train_loss:2.0446 train_time:369521ms step_avg:120.76ms
step:3070/20000 train_loss:2.0022 train_time:370727ms step_avg:120.76ms
step:3080/20000 train_loss:2.0213 train_time:371923ms step_avg:120.75ms
step:3090/20000 train_loss:1.9809 train_time:373130ms step_avg:120.75ms
step:3100/20000 train_loss:2.0478 train_time:374329ms step_avg:120.75ms
step:3110/20000 train_loss:2.4472 train_time:375533ms step_avg:120.75ms
step:3120/20000 train_loss:2.1454 train_time:376731ms step_avg:120.75ms
step:3130/20000 train_loss:2.1483 train_time:377932ms step_avg:120.74ms
step:3140/20000 train_loss:1.9950 train_time:379189ms step_avg:120.76ms
step:3150/20000 train_loss:2.0819 train_time:380383ms step_avg:120.76ms
step:3160/20000 train_loss:1.8993 train_time:381588ms step_avg:120.76ms
step:3170/20000 train_loss:2.0292 train_time:382781ms step_avg:120.75ms
step:3180/20000 train_loss:2.0525 train_time:384053ms step_avg:120.77ms
step:3190/20000 train_loss:2.0493 train_time:385249ms step_avg:120.77ms
step:3200/20000 train_loss:1.8500 train_time:386456ms step_avg:120.77ms
step:3210/20000 train_loss:2.1784 train_time:387663ms step_avg:120.77ms
step:3220/20000 train_loss:2.2719 train_time:388859ms step_avg:120.76ms
step:3230/20000 train_loss:1.9973 train_time:390054ms step_avg:120.76ms
step:3240/20000 train_loss:1.9610 train_time:391260ms step_avg:120.76ms
step:3250/20000 train_loss:2.0616 train_time:392455ms step_avg:120.76ms
step:3260/20000 train_loss:1.9345 train_time:393661ms step_avg:120.75ms
step:3270/20000 train_loss:2.0981 train_time:394856ms step_avg:120.75ms
step:3280/20000 train_loss:2.0251 train_time:396057ms step_avg:120.75ms
step:3290/20000 train_loss:2.0168 train_time:397258ms step_avg:120.75ms
step:3300/20000 train_loss:1.9943 train_time:398455ms step_avg:120.74ms
step:3310/20000 train_loss:2.1960 train_time:399719ms step_avg:120.76ms
step:3320/20000 train_loss:2.0457 train_time:400915ms step_avg:120.76ms
step:3330/20000 train_loss:1.8609 train_time:402121ms step_avg:120.76ms
step:3340/20000 train_loss:2.0608 train_time:403319ms step_avg:120.75ms
step:3350/20000 train_loss:1.9928 train_time:404525ms step_avg:120.75ms
step:3360/20000 train_loss:1.9660 train_time:405723ms step_avg:120.75ms
step:3370/20000 train_loss:1.9977 train_time:406928ms step_avg:120.75ms
step:3380/20000 train_loss:2.1859 train_time:408122ms step_avg:120.75ms
step:3390/20000 train_loss:2.0402 train_time:409325ms step_avg:120.74ms
step:3400/20000 train_loss:2.0721 train_time:410529ms step_avg:120.74ms
step:3410/20000 train_loss:2.2227 train_time:411735ms step_avg:120.74ms
step:3420/20000 train_loss:1.8656 train_time:412931ms step_avg:120.74ms
step:3430/20000 train_loss:2.0699 train_time:414131ms step_avg:120.74ms
step:3440/20000 train_loss:2.0218 train_time:415393ms step_avg:120.75ms
step:3450/20000 train_loss:2.0654 train_time:416590ms step_avg:120.75ms
step:3460/20000 train_loss:1.9967 train_time:417794ms step_avg:120.75ms
step:3470/20000 train_loss:1.9987 train_time:418993ms step_avg:120.75ms
step:3480/20000 train_loss:2.0870 train_time:420189ms step_avg:120.74ms
step:3490/20000 train_loss:2.0321 train_time:421382ms step_avg:120.74ms
step:3500/20000 train_loss:2.0347 train_time:422588ms step_avg:120.74ms
step:3510/20000 train_loss:2.1216 train_time:423782ms step_avg:120.74ms
step:3520/20000 train_loss:1.9934 train_time:424977ms step_avg:120.73ms
step:3530/20000 train_loss:2.0692 train_time:426179ms step_avg:120.73ms
step:3540/20000 train_loss:1.9890 train_time:427378ms step_avg:120.73ms
step:3550/20000 train_loss:2.0349 train_time:428574ms step_avg:120.73ms
step:3560/20000 train_loss:2.0641 train_time:429780ms step_avg:120.72ms
step:3570/20000 train_loss:2.0646 train_time:431052ms step_avg:120.74ms
step:3580/20000 train_loss:2.0270 train_time:432248ms step_avg:120.74ms
step:3590/20000 train_loss:2.0566 train_time:433441ms step_avg:120.74ms
step:3600/20000 train_loss:2.0307 train_time:434641ms step_avg:120.73ms
step:3610/20000 train_loss:1.9836 train_time:435835ms step_avg:120.73ms
step:3620/20000 train_loss:2.1683 train_time:437031ms step_avg:120.73ms
step:3630/20000 train_loss:2.0673 train_time:438231ms step_avg:120.72ms
step:3640/20000 train_loss:2.0330 train_time:439435ms step_avg:120.72ms
step:3650/20000 train_loss:1.9754 train_time:440636ms step_avg:120.72ms
step:3660/20000 train_loss:1.9971 train_time:441832ms step_avg:120.72ms
step:3670/20000 train_loss:2.1167 train_time:443032ms step_avg:120.72ms
step:3680/20000 train_loss:1.9703 train_time:444235ms step_avg:120.72ms
step:3690/20000 train_loss:2.0019 train_time:445499ms step_avg:120.73ms
step:3700/20000 train_loss:1.9779 train_time:446696ms step_avg:120.73ms
step:3710/20000 train_loss:2.0258 train_time:447899ms step_avg:120.73ms
step:3720/20000 train_loss:2.0760 train_time:449102ms step_avg:120.73ms
step:3730/20000 train_loss:2.0500 train_time:450308ms step_avg:120.73ms
step:3740/20000 train_loss:2.0901 train_time:451502ms step_avg:120.72ms
step:3750/20000 train_loss:2.0907 train_time:452708ms step_avg:120.72ms
step:3760/20000 train_loss:2.0312 train_time:453912ms step_avg:120.72ms
step:3770/20000 train_loss:2.0882 train_time:455108ms step_avg:120.72ms
step:3780/20000 train_loss:1.9869 train_time:456313ms step_avg:120.72ms
step:3790/20000 train_loss:2.0314 train_time:457518ms step_avg:120.72ms
step:3800/20000 train_loss:2.0143 train_time:458722ms step_avg:120.72ms
step:3810/20000 train_loss:1.9614 train_time:459926ms step_avg:120.72ms
step:3820/20000 train_loss:2.1117 train_time:461181ms step_avg:120.73ms
step:3830/20000 train_loss:2.0134 train_time:462385ms step_avg:120.73ms
step:3840/20000 train_loss:1.8490 train_time:463586ms step_avg:120.73ms
step:3850/20000 train_loss:2.0751 train_time:464780ms step_avg:120.72ms
step:3860/20000 train_loss:2.0387 train_time:465981ms step_avg:120.72ms
step:3870/20000 train_loss:2.0122 train_time:467177ms step_avg:120.72ms
step:3880/20000 train_loss:2.0210 train_time:468373ms step_avg:120.71ms
step:3890/20000 train_loss:2.0237 train_time:469568ms step_avg:120.71ms
step:3900/20000 train_loss:2.0066 train_time:470761ms step_avg:120.71ms
step:3910/20000 train_loss:2.0599 train_time:471959ms step_avg:120.71ms
step:3920/20000 train_loss:1.9477 train_time:473157ms step_avg:120.70ms
step:3930/20000 train_loss:1.9670 train_time:474353ms step_avg:120.70ms
step:3940/20000 train_loss:1.9760 train_time:475557ms step_avg:120.70ms
step:3950/20000 train_loss:1.9936 train_time:476819ms step_avg:120.71ms
step:3960/20000 train_loss:2.0142 train_time:478013ms step_avg:120.71ms
step:3970/20000 train_loss:2.0750 train_time:479218ms step_avg:120.71ms
step:3980/20000 train_loss:2.0187 train_time:480414ms step_avg:120.71ms
step:3990/20000 train_loss:2.0352 train_time:481609ms step_avg:120.70ms
step:4000/20000 train_loss:1.9401 train_time:482812ms step_avg:120.70ms
step:4010/20000 train_loss:2.3012 train_time:484006ms step_avg:120.70ms
step:4020/20000 train_loss:1.9739 train_time:485200ms step_avg:120.70ms
step:4030/20000 train_loss:2.0471 train_time:486400ms step_avg:120.69ms
step:4040/20000 train_loss:1.9891 train_time:487594ms step_avg:120.69ms
step:4050/20000 train_loss:2.0526 train_time:488789ms step_avg:120.69ms
step:4060/20000 train_loss:2.0840 train_time:489993ms step_avg:120.69ms
step:4070/20000 train_loss:2.0265 train_time:491256ms step_avg:120.70ms
step:4080/20000 train_loss:1.9212 train_time:492456ms step_avg:120.70ms
step:4090/20000 train_loss:1.9882 train_time:493662ms step_avg:120.70ms
step:4100/20000 train_loss:1.9008 train_time:494857ms step_avg:120.70ms
step:4110/20000 train_loss:2.0764 train_time:496053ms step_avg:120.69ms
step:4120/20000 train_loss:1.8386 train_time:497247ms step_avg:120.69ms
step:4130/20000 train_loss:1.9578 train_time:498441ms step_avg:120.69ms
step:4140/20000 train_loss:1.8470 train_time:499636ms step_avg:120.69ms
step:4150/20000 train_loss:2.0601 train_time:500840ms step_avg:120.68ms
step:4160/20000 train_loss:2.0093 train_time:502035ms step_avg:120.68ms
step:4170/20000 train_loss:1.9424 train_time:503236ms step_avg:120.68ms
step:4180/20000 train_loss:2.0089 train_time:504432ms step_avg:120.68ms
step:4190/20000 train_loss:2.0830 train_time:505625ms step_avg:120.67ms
step:4200/20000 train_loss:2.0398 train_time:506889ms step_avg:120.69ms
step:4210/20000 train_loss:1.9866 train_time:508084ms step_avg:120.68ms
step:4220/20000 train_loss:1.9667 train_time:509278ms step_avg:120.68ms
step:4230/20000 train_loss:1.8234 train_time:510479ms step_avg:120.68ms
step:4240/20000 train_loss:2.0178 train_time:511676ms step_avg:120.68ms
step:4250/20000 train_loss:2.0141 train_time:512872ms step_avg:120.68ms
step:4260/20000 train_loss:2.0242 train_time:514078ms step_avg:120.68ms
step:4270/20000 train_loss:2.0041 train_time:515286ms step_avg:120.68ms
step:4280/20000 train_loss:2.0351 train_time:516495ms step_avg:120.68ms
step:4290/20000 train_loss:1.9462 train_time:517702ms step_avg:120.68ms
swa:start step:4300
step:4300/20000 train_loss:1.9955 train_time:518898ms step_avg:120.67ms
step:4310/20000 train_loss:1.9575 train_time:520190ms step_avg:120.69ms
step:4320/20000 train_loss:2.1270 train_time:521390ms step_avg:120.69ms
step:4330/20000 train_loss:2.0155 train_time:522662ms step_avg:120.71ms
step:4340/20000 train_loss:2.0275 train_time:523868ms step_avg:120.71ms
step:4350/20000 train_loss:1.9136 train_time:525064ms step_avg:120.70ms
step:4360/20000 train_loss:1.9823 train_time:526300ms step_avg:120.71ms
step:4370/20000 train_loss:2.0220 train_time:527508ms step_avg:120.71ms
step:4380/20000 train_loss:1.8875 train_time:528704ms step_avg:120.71ms
step:4390/20000 train_loss:2.0774 train_time:529909ms step_avg:120.71ms
step:4400/20000 train_loss:2.0437 train_time:531105ms step_avg:120.71ms
step:4410/20000 train_loss:1.9209 train_time:532339ms step_avg:120.71ms
step:4420/20000 train_loss:1.9114 train_time:533536ms step_avg:120.71ms
step:4430/20000 train_loss:2.0377 train_time:534742ms step_avg:120.71ms
step:4440/20000 train_loss:2.0259 train_time:535944ms step_avg:120.71ms
late_qat:enabled step:4446 scale:0.1499
step:4450/20000 train_loss:1.9236 train_time:537156ms step_avg:120.71ms
step:4460/20000 train_loss:1.9013 train_time:538462ms step_avg:120.73ms
step:4470/20000 train_loss:2.0934 train_time:539681ms step_avg:120.73ms
step:4480/20000 train_loss:1.8464 train_time:540891ms step_avg:120.73ms
step:4490/20000 train_loss:1.9002 train_time:542110ms step_avg:120.74ms
step:4500/20000 train_loss:1.9883 train_time:543319ms step_avg:120.74ms
step:4510/20000 train_loss:2.0415 train_time:544570ms step_avg:120.75ms
step:4520/20000 train_loss:1.9408 train_time:545779ms step_avg:120.75ms
step:4530/20000 train_loss:1.8627 train_time:546997ms step_avg:120.75ms
step:4540/20000 train_loss:1.9853 train_time:548213ms step_avg:120.75ms
step:4550/20000 train_loss:1.8948 train_time:549433ms step_avg:120.75ms
step:4560/20000 train_loss:1.9165 train_time:550679ms step_avg:120.76ms
step:4570/20000 train_loss:1.9725 train_time:551888ms step_avg:120.76ms
step:4580/20000 train_loss:2.0280 train_time:553164ms step_avg:120.78ms
step:4590/20000 train_loss:1.9778 train_time:554384ms step_avg:120.78ms
step:4600/20000 train_loss:1.9294 train_time:555593ms step_avg:120.78ms
step:4610/20000 train_loss:2.0455 train_time:556834ms step_avg:120.79ms
step:4620/20000 train_loss:1.9799 train_time:558044ms step_avg:120.79ms
step:4630/20000 train_loss:2.0138 train_time:559255ms step_avg:120.79ms
step:4640/20000 train_loss:1.8814 train_time:560463ms step_avg:120.79ms
step:4650/20000 train_loss:2.0249 train_time:561673ms step_avg:120.79ms
step:4660/20000 train_loss:2.0011 train_time:562927ms step_avg:120.80ms
step:4670/20000 train_loss:1.8810 train_time:564136ms step_avg:120.80ms
step:4680/20000 train_loss:1.9926 train_time:565365ms step_avg:120.80ms
step:4690/20000 train_loss:2.0220 train_time:566575ms step_avg:120.80ms
step:4700/20000 train_loss:1.9570 train_time:567793ms step_avg:120.81ms
step:4710/20000 train_loss:1.8901 train_time:569117ms step_avg:120.83ms
step:4720/20000 train_loss:1.9250 train_time:570336ms step_avg:120.83ms
step:4730/20000 train_loss:1.9111 train_time:571556ms step_avg:120.84ms
step:4740/20000 train_loss:2.0232 train_time:572765ms step_avg:120.84ms
step:4750/20000 train_loss:1.9736 train_time:573986ms step_avg:120.84ms
step:4760/20000 train_loss:1.9235 train_time:575228ms step_avg:120.85ms
step:4770/20000 train_loss:1.7661 train_time:576446ms step_avg:120.85ms
step:4780/20000 train_loss:1.9751 train_time:577660ms step_avg:120.85ms
step:4790/20000 train_loss:1.9310 train_time:578880ms step_avg:120.85ms
step:4800/20000 train_loss:1.9972 train_time:580097ms step_avg:120.85ms
step:4810/20000 train_loss:1.9871 train_time:581356ms step_avg:120.86ms
step:4820/20000 train_loss:2.0182 train_time:582567ms step_avg:120.86ms
step:4830/20000 train_loss:1.9717 train_time:583778ms step_avg:120.86ms
step:4840/20000 train_loss:1.9956 train_time:585056ms step_avg:120.88ms
step:4850/20000 train_loss:2.0103 train_time:586270ms step_avg:120.88ms
step:4860/20000 train_loss:1.8983 train_time:587531ms step_avg:120.89ms
step:4870/20000 train_loss:2.0072 train_time:588742ms step_avg:120.89ms
step:4880/20000 train_loss:1.8950 train_time:589962ms step_avg:120.89ms
step:4890/20000 train_loss:1.8808 train_time:591175ms step_avg:120.89ms
step:4900/20000 train_loss:1.9638 train_time:592384ms step_avg:120.89ms
step:4910/20000 train_loss:2.0156 train_time:593636ms step_avg:120.90ms
step:4920/20000 train_loss:1.9993 train_time:594856ms step_avg:120.91ms
step:4930/20000 train_loss:2.0534 train_time:596075ms step_avg:120.91ms
step:4940/20000 train_loss:1.9125 train_time:597285ms step_avg:120.91ms
step:4950/20000 train_loss:1.9280 train_time:598497ms step_avg:120.91ms
step:4960/20000 train_loss:2.0460 train_time:599827ms step_avg:120.93ms
step:4962/20000 val_loss:1.9725 val_bpb:1.1682 train_time:600071ms step_avg:120.93ms
stopping_early: wallclock_cap train_time:600071ms step:4962/20000
peak memory allocated: 22059 MiB reserved: 22132 MiB
ema:applying EMA weights
DIAGNOSTIC post_ema val_loss:1.9689 val_bpb:1.1661 eval_time:2144ms
EMA improved val_bpb (1.1682 -> 1.1661), keeping EMA weights for final serialization
Serialized model: 99922207 bytes
Code size: 87961 bytes
Total submission size: 100010168 bytes
Serialized model int6+zlib: 15841381 bytes
Total submission size int6+zlib: 15929342 bytes
Total submission size int8+zlib: 15929342 bytes
final_int6_zlib_roundtrip val_loss:1.9848 val_bpb:1.1755 eval_time:43630ms
final_int6_zlib_roundtrip_exact val_loss:1.98479406 val_bpb:1.17550684
final_int6_sliding_window val_loss:1.9446 val_bpb:1.1517 stride:64 eval_time:102933ms
final_int6_sliding_window_exact val_loss:1.94455924 val_bpb:1.15168056
final_int8_zlib_roundtrip_exact val_loss:1.94455924 val_bpb:1.15168056
ttt_sliding:start chunks=1893 chunk_tokens=32768 total_windows=969088 stride=64 ttt_lr=0.002 ttt_epochs=3 freeze_blocks=0
ttt_sliding:params unfrozen=25431141 frozen=0
 ttt_chunk [1/1893] bpb=1.186635 time=0.5s
 ttt_chunk [11/1893] bpb=1.510736 time=3.0s
 ttt_chunk [21/1893] bpb=1.491659 time=5.5s
 ttt_chunk [31/1893] bpb=1.478977 time=8.0s
 ttt_chunk [41/1893] bpb=1.460968 time=10.5s
 ttt_chunk [51/1893] bpb=1.455660 time=13.0s
 ttt_chunk [61/1893] bpb=1.462596 time=15.5s
 ttt_chunk [71/1893] bpb=1.456687 time=17.9s
 ttt_chunk [81/1893] bpb=1.456238 time=20.4s
 ttt_chunk [91/1893] bpb=1.455966 time=22.9s
 ttt_chunk [101/1893] bpb=1.458559 time=25.4s
 ttt_chunk [111/1893] bpb=1.459783 time=27.9s
 ttt_chunk [121/1893] bpb=1.450680 time=30.4s
 ttt_chunk [131/1893] bpb=1.450762 time=32.9s
 ttt_chunk [141/1893] bpb=1.456189 time=35.4s
 ttt_chunk [151/1893] bpb=1.455988 time=37.9s
 ttt_chunk [161/1893] bpb=1.455805 time=40.4s
 ttt_chunk [171/1893] bpb=1.459179 time=42.9s
 ttt_chunk [181/1893] bpb=1.461658 time=45.4s
 ttt_chunk [191/1893] bpb=1.469154 time=47.9s
 ttt_chunk [201/1893] bpb=1.467515 time=50.3s
 ttt_chunk [211/1893] bpb=1.464178 time=52.8s
 ttt_chunk [221/1893] bpb=1.464447 time=55.3s
 ttt_chunk [231/1893] bpb=1.462888 time=57.8s
 ttt_chunk [241/1893] bpb=1.462667 time=60.3s
 ttt_chunk [251/1893] bpb=1.462432 time=62.7s
 ttt_chunk [261/1893] bpb=1.457694 time=65.2s
 ttt_chunk [271/1893] bpb=1.456021 time=67.7s
 ttt_chunk [281/1893] bpb=1.456076 time=70.2s
 ttt_chunk [291/1893] bpb=1.457298 time=72.7s
 ttt_chunk [301/1893] bpb=1.456878 time=75.2s
 ttt_chunk [311/1893] bpb=1.458146 time=77.7s
 ttt_chunk [321/1893] bpb=1.458600 time=80.1s
 ttt_chunk [331/1893] bpb=1.457343 time=82.6s
 ttt_chunk [341/1893] bpb=1.455029 time=85.1s
 ttt_chunk [351/1893] bpb=1.456486 time=87.6s
 ttt_chunk [361/1893] bpb=1.456748 time=90.1s
 ttt_chunk [371/1893] bpb=1.454680 time=92.5s
 ttt_chunk [381/1893] bpb=1.454024 time=95.0s
 ttt_chunk [391/1893] bpb=1.453026 time=97.5s
 ttt_chunk [401/1893] bpb=1.449979 time=100.0s
 ttt_chunk [411/1893] bpb=1.448228 time=102.4s
 ttt_chunk [421/1893] bpb=1.446324 time=104.9s
 ttt_chunk [431/1893] bpb=1.445412 time=107.4s
 ttt_chunk [441/1893] bpb=1.444663 time=109.9s
 ttt_chunk [451/1893] bpb=1.444041 time=112.3s
 ttt_chunk [461/1893] bpb=1.441921 time=114.8s
 ttt_chunk [471/1893] bpb=1.441280 time=117.3s
 ttt_chunk [481/1893] bpb=1.440558 time=119.8s
 ttt_chunk [491/1893] bpb=1.438342 time=122.2s
 ttt_chunk [501/1893] bpb=1.437166 time=124.7s
 ttt_chunk [511/1893] bpb=1.435976 time=127.2s
 ttt_chunk [521/1893] bpb=1.432976 time=129.7s
 ttt_chunk [531/1893] bpb=1.433320 time=132.2s
 ttt_chunk [541/1893] bpb=1.432989 time=134.7s
 ttt_chunk [551/1893] bpb=1.430960 time=137.2s
 ttt_chunk [561/1893] bpb=1.430692 time=139.6s
 ttt_chunk [571/1893] bpb=1.428738 time=142.1s
 ttt_chunk [581/1893] bpb=1.427276 time=144.6s
 ttt_chunk [591/1893] bpb=1.425885 time=147.1s
 ttt_chunk [601/1893] bpb=1.425639 time=149.6s
 ttt_chunk [611/1893] bpb=1.424930 time=152.1s
 ttt_chunk [621/1893] bpb=1.424178 time=154.6s
 ttt_chunk [631/1893] bpb=1.424111 time=157.0s
 ttt_chunk [641/1893] bpb=1.423384 time=159.5s
 ttt_chunk [651/1893] bpb=1.422759 time=162.0s
 ttt_chunk [661/1893] bpb=1.421890 time=164.5s
 ttt_chunk [671/1893] bpb=1.421777 time=166.9s
 ttt_chunk [681/1893] bpb=1.421713 time=169.4s
 ttt_chunk [691/1893] bpb=1.422271 time=171.9s
 ttt_chunk [701/1893] bpb=1.421263 time=174.4s
 ttt_chunk [711/1893] bpb=1.421055 time=176.8s
 ttt_chunk [721/1893] bpb=1.420490 time=179.3s
 ttt_chunk [731/1893] bpb=1.420375 time=181.8s
 ttt_chunk [741/1893] bpb=1.420192 time=184.3s
 ttt_chunk [751/1893] bpb=1.419433 time=186.7s
 ttt_chunk [761/1893] bpb=1.418992 time=189.2s
 ttt_chunk [771/1893] bpb=1.418390 time=191.7s
 ttt_chunk [781/1893] bpb=1.419178 time=194.2s
 ttt_chunk [791/1893] bpb=1.418560 time=196.6s
 ttt_chunk [801/1893] bpb=1.418538 time=199.1s
 ttt_chunk [811/1893] bpb=1.418283 time=201.6s
 ttt_chunk [821/1893] bpb=1.417886 time=204.0s
 ttt_chunk [831/1893] bpb=1.417504 time=206.5s
 ttt_chunk [841/1893] bpb=1.416522 time=209.0s
 ttt_chunk [851/1893] bpb=1.416160 time=211.5s
 ttt_chunk [861/1893] bpb=1.415770 time=213.9s
 ttt_chunk [871/1893] bpb=1.415759 time=216.4s
 ttt_chunk [881/1893] bpb=1.415884 time=218.9s
 ttt_chunk [891/1893] bpb=1.415395 time=221.4s
 ttt_chunk [901/1893] bpb=1.414898 time=223.8s
 ttt_chunk [911/1893] bpb=1.414846 time=226.3s
 ttt_chunk [921/1893] bpb=1.415014 time=228.8s
 ttt_chunk [931/1893] bpb=1.414961 time=231.2s
 ttt_chunk [941/1893] bpb=1.414362 time=233.7s
 ttt_chunk [951/1893] bpb=1.414658 time=236.2s
 ttt_chunk [961/1893] bpb=1.414237 time=238.7s
 ttt_chunk [971/1893] bpb=1.414942 time=241.1s
 ttt_chunk [981/1893] bpb=1.414661 time=243.6s
 ttt_chunk [991/1893] bpb=1.414343 time=246.1s
 ttt_chunk [1001/1893] bpb=1.414014 time=248.6s
 ttt_chunk [1011/1893] bpb=1.413565 time=251.0s
 ttt_chunk [1021/1893] bpb=1.413605 time=253.5s
 ttt_chunk [1031/1893] bpb=1.413654 time=256.0s
 ttt_chunk [1041/1893] bpb=1.412909 time=258.4s
 ttt_chunk [1051/1893] bpb=1.412255 time=260.9s
 ttt_chunk [1061/1893] bpb=1.411957 time=263.4s
 ttt_chunk [1071/1893] bpb=1.412437 time=265.9s
 ttt_chunk [1081/1893] bpb=1.412390 time=268.3s
 ttt_chunk [1091/1893] bpb=1.412759 time=270.8s
 ttt_chunk [1101/1893] bpb=1.412413 time=273.3s
 ttt_chunk [1111/1893] bpb=1.411873 time=275.7s
 ttt_chunk [1121/1893] bpb=1.411391 time=278.2s
 ttt_chunk [1131/1893] bpb=1.410964 time=280.7s
 ttt_chunk [1141/1893] bpb=1.410398 time=283.1s
 ttt_chunk [1151/1893] bpb=1.410071 time=285.6s
 ttt_chunk [1161/1893] bpb=1.409416 time=288.2s
 ttt_chunk [1171/1893] bpb=1.409367 time=290.6s
 ttt_chunk [1181/1893] bpb=1.408218 time=293.1s
 ttt_chunk [1191/1893] bpb=1.407861 time=295.6s
 ttt_chunk [1201/1893] bpb=1.407953 time=298.1s
 ttt_chunk [1211/1893] bpb=1.407140 time=300.5s
 ttt_chunk [1221/1893] bpb=1.406522
time=303.0s + ttt_chunk [1231/1893] bpb=1.405836 time=305.5s + ttt_chunk [1241/1893] bpb=1.405150 time=307.9s + ttt_chunk [1251/1893] bpb=1.404232 time=310.4s + ttt_chunk [1261/1893] bpb=1.404016 time=312.8s + ttt_chunk [1271/1893] bpb=1.403364 time=315.3s + ttt_chunk [1281/1893] bpb=1.402838 time=317.8s + ttt_chunk [1291/1893] bpb=1.402381 time=320.2s + ttt_chunk [1301/1893] bpb=1.401478 time=322.7s + ttt_chunk [1311/1893] bpb=1.400772 time=325.1s + ttt_chunk [1321/1893] bpb=1.400138 time=327.6s + ttt_chunk [1331/1893] bpb=1.399756 time=330.1s + ttt_chunk [1341/1893] bpb=1.399360 time=332.6s + ttt_chunk [1351/1893] bpb=1.399159 time=335.0s + ttt_chunk [1361/1893] bpb=1.399008 time=337.5s + ttt_chunk [1371/1893] bpb=1.398654 time=340.0s + ttt_chunk [1381/1893] bpb=1.398677 time=342.4s + ttt_chunk [1391/1893] bpb=1.398041 time=344.9s + ttt_chunk [1401/1893] bpb=1.397921 time=347.4s + ttt_chunk [1411/1893] bpb=1.397866 time=349.8s + ttt_chunk [1421/1893] bpb=1.397919 time=352.3s + ttt_chunk [1431/1893] bpb=1.397577 time=354.7s + ttt_chunk [1441/1893] bpb=1.397960 time=357.2s + ttt_chunk [1451/1893] bpb=1.398109 time=359.6s + ttt_chunk [1461/1893] bpb=1.397509 time=362.1s + ttt_chunk [1471/1893] bpb=1.398487 time=364.6s + ttt_chunk [1481/1893] bpb=1.397902 time=367.0s + ttt_chunk [1491/1893] bpb=1.397633 time=369.5s + ttt_chunk [1501/1893] bpb=1.397541 time=372.0s + ttt_chunk [1511/1893] bpb=1.397444 time=374.4s + ttt_chunk [1521/1893] bpb=1.397294 time=376.9s + ttt_chunk [1531/1893] bpb=1.396689 time=379.4s + ttt_chunk [1541/1893] bpb=1.396322 time=381.8s + ttt_chunk [1551/1893] bpb=1.396568 time=384.3s + ttt_chunk [1561/1893] bpb=1.396498 time=386.8s + ttt_chunk [1571/1893] bpb=1.396152 time=389.3s + ttt_chunk [1581/1893] bpb=1.396181 time=391.7s + ttt_chunk [1591/1893] bpb=1.395829 time=394.2s + ttt_chunk [1601/1893] bpb=1.395859 time=396.7s + ttt_chunk [1611/1893] bpb=1.395580 time=399.2s + ttt_chunk [1621/1893] bpb=1.394943 time=401.7s + ttt_chunk [1631/1893] 
bpb=1.395133 time=404.1s + ttt_chunk [1641/1893] bpb=1.395024 time=406.6s + ttt_chunk [1651/1893] bpb=1.394785 time=409.1s + ttt_chunk [1661/1893] bpb=1.394526 time=411.5s + ttt_chunk [1671/1893] bpb=1.394887 time=414.0s + ttt_chunk [1681/1893] bpb=1.394905 time=416.5s + ttt_chunk [1691/1893] bpb=1.394447 time=418.9s + ttt_chunk [1701/1893] bpb=1.394409 time=421.4s + ttt_chunk [1711/1893] bpb=1.394158 time=423.8s + ttt_chunk [1721/1893] bpb=1.394007 time=426.3s + ttt_chunk [1731/1893] bpb=1.393693 time=428.7s + ttt_chunk [1741/1893] bpb=1.393402 time=431.2s + ttt_chunk [1751/1893] bpb=1.393004 time=433.7s + ttt_chunk [1761/1893] bpb=1.393047 time=436.1s + ttt_chunk [1771/1893] bpb=1.392803 time=438.6s + ttt_chunk [1781/1893] bpb=1.392723 time=441.1s + ttt_chunk [1791/1893] bpb=1.392058 time=443.5s + ttt_chunk [1801/1893] bpb=1.391805 time=446.0s + ttt_chunk [1811/1893] bpb=1.391550 time=448.4s + ttt_chunk [1821/1893] bpb=1.391412 time=450.9s + ttt_chunk [1831/1893] bpb=1.390547 time=453.4s + ttt_chunk [1841/1893] bpb=1.390480 time=455.9s + ttt_chunk [1851/1893] bpb=1.390157 time=458.3s + ttt_chunk [1861/1893] bpb=1.389574 time=460.8s + ttt_chunk [1871/1893] bpb=1.389364 time=463.3s + ttt_chunk [1881/1893] bpb=1.388742 time=465.8s + ttt_chunk [1891/1893] bpb=1.388359 time=468.2s + ttt_chunk [1893/1893] bpb=1.388431 time=468.5s +ttt_sliding:done val_loss=2.342569 val_bpb=1.387405 elapsed=468.6s +legal_ttt val_loss:2.3426 val_bpb:1.3874 eval_time:469104ms +legal_ttt_exact val_loss:2.34256927 val_bpb:1.38740515 diff --git a/records/track_10min_16mb/2026-04-01_random_mlp/baseline_no_moe_xsa4_ve3_12l_5r7l_seed2026.txt b/records/track_10min_16mb/2026-04-01_random_mlp/baseline_no_moe_xsa4_ve3_12l_5r7l_seed2026.txt new file mode 100644 index 0000000000..22b71c5442 --- /dev/null +++ b/records/track_10min_16mb/2026-04-01_random_mlp/baseline_no_moe_xsa4_ve3_12l_5r7l_seed2026.txt @@ -0,0 +1,2366 @@ +from __future__ import annotations + +import copy +import glob +import io 
+import math +import os +import random +import subprocess +import sys +import time +import uuid +import zlib +from pathlib import Path + +try: + import zstandard + _COMPRESSOR = "zstd" +except ImportError: + _COMPRESSOR = "zlib" + +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP + +from flash_attn_interface import flash_attn_func +# from flash_attn import flash_attn_func + +# make dynamo less complainy +import torch._dynamo +torch._dynamo.config.cache_size_limit = 64 + +# ----------------------------- +# HYPERPARAMETERS +# ----------------------------- +# Default Simple Baseline run: +# - 9 transformer blocks at width 512 +# - 8 attention heads with 4 KV heads (GQA) and 2x MLP expansion +# - vocab size 1024, sequence length 1024, tied embeddings +# - 524,288 train tokens per step for 20,000 iterations with a ~10 minute cap + +class Hyperparameters: + # Data paths are shard globs produced by the existing preprocessing pipeline. + data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + seed = int(os.environ.get("SEED", 1337)) + + # Validation cadence and batch size. Validation always uses the full fineweb_val split. + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 200)) + + # Training length. 
+ iterations = int(os.environ.get("ITERATIONS", 20000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 1200)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024)) + eval_seq_len = int(os.environ.get("EVAL_SEQ_LEN", 2048)) + eval_stride = int(os.environ.get("EVAL_STRIDE", 64)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + + # Model shape. + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers = int(os.environ.get("NUM_LAYERS", 9)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = int(os.environ.get("MLP_MULT", 2)) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + + # random projection control + rand_proj_layers = [int(x) for x in os.environ.get("RAND_PROJ_LAYERS", "").split(",") if x] + rand_gain = bool(int(os.environ.get("RAND_GAIN", "0"))) + mini_moe_experts = int(os.environ.get("MINI_MOE_EXPERTS", 1)) + rand_init_qr = bool(int(os.environ.get("RAND_INIT_QR", "1"))) + + # xsa control + xsa_layers = [int(x) for x in os.environ.get("XSA_LAYERS", "").split(",") if x] + + # quant + late_qat_threshold = float(os.environ.get("LATE_QAT_THRESHOLD", 0.15)) + + # partial rope + rope_dims = int(os.environ.get("ROPE_DIMS", 16)) + + # ln scale + ln_scale = bool(int(os.environ.get("LN_SCALE", "1"))) + + # bigram control + bigram_vocab_size = int(os.environ.get("BIGRAM_VOCAB_SIZE", 2048)) + bigram_dim = int(os.environ.get("BIGRAM_DIM", 128)) + + # value embedding control + ve_dim = int(os.environ.get("VE_DIM", 128)) + ve_layers = [int(x) for x in 
os.environ.get("VE_LAYERS", "9,10").split(",") if x] + + # ttt + ttt_enabled = bool(int(os.environ.get("TTT_ENABLED", "1"))) + ttt_lr = float(os.environ.get("TTT_LR", 0.002)) + ttt_epochs = int(os.environ.get("TTT_EPOCHS", 3)) + ttt_chunk_tokens = int(os.environ.get("TTT_CHUNK_TOKENS", 32768)) + ttt_freeze_blocks = int(os.environ.get("TTT_FREEZE_BLOCKS", 0)) + ttt_momentum = float(os.environ.get("TTT_MOMENTUM", 0.9)) + ttt_batch_seqs = int(os.environ.get("TTT_BATCH_SEQS", 32)) + ttt_grad_clip = float(os.environ.get("TTT_GRAD_CLIP", 1.0)) + + # swa + swa_enabled = bool(int(os.environ.get("SWA_ENABLED", "1"))) + swa_every = int(os.environ.get("SWA_EVERY", 50)) + + # Optimizer hyperparameters. + embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.035)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.025)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.025)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.99)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.92)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 1500)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.3)) + muon_wd = float(os.environ.get("MUON_WD", 0.04)) + adam_wd = float(os.environ.get("ADAM_WD", 0.04)) + +# ----------------------------- +# MUON OPTIMIZER +# ----------------------------- +# +# As borrowed from modded-nanogpt +# Background on Muon: https://kellerjordan.github.io/posts/muon/ + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + # Orthogonalize a 2D update matrix with 
a fast Newton-Schulz iteration. + # Muon uses this to normalize matrix-shaped gradients before applying them. + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X + + +class Muon(torch.optim.Optimizer): + """Parallel Muon: post-backward reduce-scatter -> local NS5 -> all-gather. + + No DDP for bank params. After backward, this optimizer: + 1. Launches async reduce-scatter for all banks (biggest first) + 2. Returns control so Adam can step on small params while RS is in-flight + 3. Waits for each RS, runs local NS5 on the shard, launches async all-gather + 4. Each all-gather overlaps with next bank's NS5 + """ + def __init__(self, params, lr: float, momentum: float, backend_steps: int, + nesterov: bool = True, weight_decay: float = 0.0): + super().__init__( + params, + dict(lr=lr, momentum=momentum, backend_steps=backend_steps, + nesterov=nesterov, weight_decay=weight_decay), + ) + self._built = False + + def _build(self): + self._distributed = dist.is_available() and dist.is_initialized() + self._world_size = dist.get_world_size() if self._distributed else 1 + self._rank = dist.get_rank() if self._distributed else 0 + ws = self._world_size + + self._bank_meta = [] + for group in self.param_groups: + for p in group["params"]: + B = p.shape[0] + padded_B = ((B + ws - 1) // ws) * ws + shard_B = padded_B // ws + tail = p.shape[1:] + dev = p.device + self._bank_meta.append({ + 'p': p, + 'B': B, + 'padded_grad': torch.zeros(padded_B, *tail, device=dev, dtype=torch.bfloat16), + 'shard': torch.zeros(shard_B, *tail, device=dev, dtype=torch.bfloat16), + 'shard_mom': torch.zeros(shard_B, *tail, device=dev, dtype=torch.bfloat16), + 'full_update': torch.zeros(padded_B, *tail, device=dev, dtype=torch.bfloat16), + 'scale': max(1, p.shape[-2] / p.shape[-1]) ** 0.5, + }) + # 
Sort by size descending -- launch biggest reduce-scatters first + self._bank_meta.sort(key=lambda m: -m['p'].numel()) + self._built = True + + def launch_reduce_scatters(self): + """Phase 1: launch async reduce-scatter for all banks. Call right after backward.""" + if not self._built: + self._build() + if not self._distributed: + return + self._rs_futures = [] + for m in self._bank_meta: + p = m['p'] + if p.grad is None: + self._rs_futures.append(None) + continue + pg = m['padded_grad'] + pg[:m['B']].copy_(p.grad.bfloat16()) + if pg.shape[0] > m['B']: + pg[m['B']:].zero_() + fut = dist.reduce_scatter_tensor(m['shard'], pg, op=dist.ReduceOp.AVG, async_op=True) + self._rs_futures.append(fut) + + @torch.no_grad() + def step(self, closure=None): + """Phase 3: wait for RS, local NS5, all-gather. Call AFTER Adam steps.""" + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + if not self._built: + self._build() + + for group in self.param_groups: + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + wd = group.get("weight_decay", 0.0) + + prev_ag_handle = None + prev_m = None + + sharded = self._distributed and hasattr(self, '_rs_futures') + + for i, m in enumerate(self._bank_meta): + p = m['p'] + if p.grad is None: + continue + + if prev_ag_handle is not None: + prev_ag_handle.wait() + pp = prev_m['p'] + upd = prev_m['full_update'][:prev_m['B']] + if wd > 0.0: + pp.data.mul_(1.0 - lr * wd) + pp.add_(upd.to(dtype=pp.dtype), alpha=-lr * prev_m['scale']) + + if sharded and self._rs_futures[i] is not None: + self._rs_futures[i].wait() + g = m['shard'] + buf = m['shard_mom'] + else: + g = p.grad.bfloat16() + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + + buf.mul_(momentum).add_(g) + if nesterov: + update = g.add(buf, alpha=momentum) + else: + update = buf + + update = 
zeropower_via_newtonschulz5(update, steps=backend_steps) + + if sharded: + prev_ag_handle = dist.all_gather_into_tensor( + m['full_update'], update, async_op=True) + prev_m = m + else: + if wd > 0.0: + p.data.mul_(1.0 - lr * wd) + p.add_(update.to(dtype=p.dtype), alpha=-lr * m['scale']) + + if prev_ag_handle is not None: + prev_ag_handle.wait() + pp = prev_m['p'] + upd = prev_m['full_update'][:prev_m['B']] + if wd > 0.0: + pp.data.mul_(1.0 - lr * wd) + pp.add_(upd.to(dtype=pp.dtype), alpha=-lr * prev_m['scale']) + + if hasattr(self, '_rs_futures'): + del self._rs_futures + + return loss + +# ----------------------------- +# TOKENIZER-AGNOSTIC EVALUATION SETUP +# ----------------------------- +# +# It's common for small models to have a large fraction of their parameters be embeddings, since the 2 * d_model * d_vocab embedding vectors can be gigantic. +# Instead of locking the tokenizer, we let you bring your own and calculate our validation metrics on the average compression of the validation set. +# We calculate BPB (bits-per-byte) instead of validation loss, so we need methods to count the number of bytes per token in the tokenizer. +# Note: Submissions that edit the tokenizer will be examined more carefully, since screwing this up might unjustly improve your score. 
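The BPB conversion the comment block above describes reduces to two factors: mean cross-entropy in nats converted to bits per token, then rescaled by how many bytes each token covers on average. A minimal sketch — the loss and bytes-per-token values below are illustrative placeholders, not numbers measured from this run:

```python
import math

# Illustrative values only -- not taken from this run's logs.
val_loss_nats = 1.97     # mean token cross-entropy (natural log)
bytes_per_token = 2.4    # total validation bytes / total validation tokens

bits_per_token = val_loss_nats / math.log(2.0)  # nats -> bits
bpb = bits_per_token / bytes_per_token          # bits per byte
```

This matches the shape of the return value in `eval_val`: `bits_per_token * tokens_per_byte` is the same quantity, since tokens-per-byte is the reciprocal of bytes-per-token.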
+ +def build_sentencepiece_luts( + sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device +) -> tuple[Tensor, Tensor, Tensor]: + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("▁"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) + + +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*. 
+ tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] + + +def eval_val( + args: Hyperparameters, + model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + grad_accum_steps: int, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, +) -> tuple[float, float]: + # Validation computes two metrics: + # - val_loss: token cross-entropy (natural log) + # - val_bpb: tokenizer-agnostic compression metric used by the challenge + local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) + if local_batch_tokens < args.train_seq_len: + raise ValueError( + "VAL_BATCH_SIZE must provide at least one sequence per rank; " + f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " + f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}" + ) + local_batch_seqs = local_batch_tokens // args.train_seq_len + total_seqs = (val_tokens.numel() - 1) // args.train_seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + + model.eval() + with torch.inference_mode(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * args.train_seq_len + raw_end = batch_seq_end * args.train_seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, args.train_seq_len) + y = local[1:].reshape(-1, args.train_seq_len) + with 
torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + batch_loss = model(x, y).detach() + batch_token_count = float(y.numel()) + val_loss_sum += batch_loss.to(torch.float64) * batch_token_count + val_token_count += batch_token_count + prev_ids = x.reshape(-1) + tgt_ids = y.reshape(-1) + token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + + val_loss = val_loss_sum / val_token_count + bits_per_token = val_loss.item() / math.log(2.0) + tokens_per_byte = val_token_count.item() / val_byte_count.item() + model.train() + return float(val_loss.item()), float(bits_per_token * tokens_per_byte) + +# ----------------------------- +# POST-TRAINING QUANTIZATION +# ----------------------------- +# +# It's silly to export our model, which is trained in bf16 and fp32, at that same precision. +# Instead, we get approximately the same model (with a small hit) by quantizing the model to int8 and zlib-compressing it. +# We can then decompress the model and run in higher precision for evaluation, after coming in under the size limit. 
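The exporter that follows uses symmetric per-row int8 with a high-percentile clip. The core roundtrip can be sketched in a few lines — a simplified version using a plain abs-max row scale instead of the `torch.quantile` clip, so it is not byte-identical to the real export path:

```python
import torch

torch.manual_seed(0)
w = torch.randn(4, 8)  # stand-in for a 2D weight matrix

# One symmetric scale per row (abs-max here; the real exporter clips at a
# high percentile first, trading a little clipping error for tighter scales).
scale = (w.abs().amax(dim=1) / 127.0).clamp_min(1.0 / 127.0)
q = torch.clamp(torch.round(w / scale[:, None]), -127, 127).to(torch.int8)

# Dequantize: broadcast the row scale back across the columns.
w_hat = q.float() * scale[:, None]

# With an abs-max scale nothing is clipped, so per-element roundtrip error
# is bounded by half a quantization step.
assert (w - w_hat).abs().max() <= scale.max() / 2 + 1e-6
```

Per-row scales are what make this cheap and accurate at once: the scale tensor adds only one fp16 value per output row, while tracking per-row weight ranges far better than a single tensor-wide scale would.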
+ +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights", + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 + +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) + +def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: + if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): + return t.float().contiguous() + if t.dtype in {torch.float32, torch.bfloat16}: + passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") + return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() + return t + +def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + # Matrices get one scale per row, which usually tracks output-channel + # ranges much better than a single tensor-wide scale. + clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) + q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() + return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + + # Vectors / scalars use a simpler per-tensor scale. 
+ clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() + return q, scale + +def quantize_state_dict_int8(state_dict: dict[str, Tensor]): + # Single supported clean-script export format: + # - per-row int8 for 2D float tensors + # - per-tensor int8 for other float tensors + # - exact passthrough for non-floats + # - passthrough for small float tensors, stored as fp16 to save bytes + quantized: dict[str, Tensor] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + qmeta: dict[str, dict[str, object]] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + + # Small float tensors are cheap enough to keep directly. We still downcast + # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size. 
+ if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + + stats["num_float_tensors"] += 1 + q, s = quantize_float_tensor(t) + if s.ndim > 0: + qmeta[name] = {"scheme": "per_row", "axis": 0} + quantized[name] = q + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) + + obj: dict[str, object] = { + "__quant_format__": "int8_clean_per_row_v1", + "quantized": quantized, + "scales": scales, + "dtypes": dtypes, + "passthrough": passthrough, + } + if qmeta: + obj["qmeta"] = qmeta + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats + +def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + qmeta = obj.get("qmeta", {}) + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, q in obj["quantized"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + s = obj["scales"][name] + if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: + s = s.to(dtype=torch.float32) + # Broadcast the saved row scale back across trailing dimensions. + out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() + else: + scale = float(s.item()) + out[name] = (q.float() * scale).to(dtype=dtype).contiguous() + for name, t in obj["passthrough"].items(): + # Restore small tensors, undoing the temporary fp16 storage cast if needed. 
+ out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out + + +# --- sliding window eval --- + +def eval_val_sliding( + args: Hyperparameters, + base_model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, + stride: int, + batch_seqs: int = 32, + eval_seq_len: int | None = None, +) -> tuple[float, float]: + """Sliding window evaluation: each token scored with maximum context.""" + seq_len = eval_seq_len or args.train_seq_len + total_tokens = val_tokens.numel() - 1 + window_starts = [ws for ws in range(0, total_tokens, stride) + if min(ws + seq_len, total_tokens) - ws >= 1] + total_windows = len(window_starts) + my_s = (total_windows * rank) // world_size + my_e = (total_windows * (rank + 1)) // world_size + my_windows = window_starts[my_s:my_e] + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + base_model.eval() + compiled_logits = torch.compile(base_model.forward_logits, dynamic=False, fullgraph=True) + with torch.inference_mode(): + for bi in range(0, len(my_windows), batch_seqs): + batch_ws = my_windows[bi:bi + batch_seqs] + bsz = len(batch_ws) + x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + wlens: list[int] = [] + for i, ws in enumerate(batch_ws): + end = min(ws + seq_len, total_tokens) + wlen = end - ws + wlens.append(wlen) + chunk = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) + x_batch[i, :wlen] = chunk[:-1] + y_batch[i, :wlen] = chunk[1:] + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + 
logits = compiled_logits(x_batch) + nll = F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), + y_batch.reshape(-1), + reduction="none", + ).reshape(bsz, seq_len) + for i, ws in enumerate(batch_ws): + wlen = wlens[i] + s = 0 if ws == 0 else max(wlen - stride, 0) + scored_nll = nll[i, s:wlen].to(torch.float64) + loss_sum += scored_nll.sum() + token_count += float(wlen - s) + tgt = y_batch[i, s:wlen] + prev = x_batch[i, s:wlen] + tb = base_bytes_lut[tgt].to(torch.float64) + tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64) + byte_count += tb.sum() + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) + val_loss = (loss_sum / token_count).item() + bits_per_token = val_loss / math.log(2.0) + tokens_per_byte = token_count.item() / byte_count.item() + base_model.train() + return val_loss, bits_per_token * tokens_per_byte + + +def eval_val_sliding_ttt( + args: Hyperparameters, base_model: nn.Module, rank: int, world_size: int, + device: torch.device, val_tokens: Tensor, base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, is_boundary_token_lut: Tensor, + stride: int, batch_seqs: int = 32, log0=print, +) -> tuple[float, float]: + """Legal score-first TTT (PR #461 recipe): score each chunk with sliding windows, + then train on it. 
Every token scored BEFORE any update that could use it.""" + seq_len = args.train_seq_len + total_tokens = val_tokens.numel() - 1 + ttt_chunk = args.ttt_chunk_tokens + + # Pre-compute all window starts + window_starts = [ws for ws in range(0, total_tokens, stride) + if min(ws + seq_len, total_tokens) - ws >= stride or ws == 0] + + # Assign each window to a chunk based on the first token it scores + num_chunks = (total_tokens + ttt_chunk - 1) // ttt_chunk + chunk_windows: list[list[int]] = [[] for _ in range(num_chunks)] + for ws in window_starts: + end = min(ws + seq_len, total_tokens) + wlen = end - ws + s = 0 if ws == 0 else max(wlen - stride, 0) + scored_start = ws + s + ci = min(scored_start // ttt_chunk, num_chunks - 1) + chunk_windows[ci].append(ws) + + log0(f"ttt_sliding:start chunks={num_chunks} chunk_tokens={ttt_chunk} " + f"total_windows={len(window_starts)} stride={stride} " + f"ttt_lr={args.ttt_lr} ttt_epochs={args.ttt_epochs} " + f"freeze_blocks={args.ttt_freeze_blocks}") + + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + + # Freeze first N blocks + frozen_block_ids = set(range(min(args.ttt_freeze_blocks, len(base_model.blocks)))) + ttt_params = [] + for name, p in base_model.named_parameters(): + freeze = False + for bi in frozen_block_ids: + if f"blocks.{bi}." 
in name: + freeze = True + break + if freeze: + p.requires_grad_(False) + else: + p.requires_grad_(True) + ttt_params.append(p) + + log0(f"ttt_sliding:params unfrozen={sum(p.numel() for p in ttt_params)} " + f"frozen={sum(p.numel() for p in base_model.parameters() if not p.requires_grad)}") + + optimizer = torch.optim.SGD(ttt_params, lr=args.ttt_lr, momentum=args.ttt_momentum) + t0 = time.perf_counter() + + for ci in range(num_chunks): + windows = chunk_windows[ci] + if not windows: + continue + chunk_start = ci * ttt_chunk + chunk_end = min((ci + 1) * ttt_chunk, total_tokens) + + # --- Phase 1: SCORE this chunk's windows (inference_mode) --- + my_s = (len(windows) * rank) // world_size + my_e = (len(windows) * (rank + 1)) // world_size + my_windows = windows[my_s:my_e] + + base_model.eval() + with torch.inference_mode(): + for bi in range(0, len(my_windows), batch_seqs): + batch_ws = my_windows[bi:bi + batch_seqs] + bsz = len(batch_ws) + x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + wlens: list[int] = [] + for i, ws in enumerate(batch_ws): + end = min(ws + seq_len, total_tokens) + wlen = end - ws + wlens.append(wlen) + chunk_tok = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) + x_batch[i, :wlen] = chunk_tok[:-1] + y_batch[i, :wlen] = chunk_tok[1:] + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + logits = base_model.forward_logits(x_batch) + nll = F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), + y_batch.reshape(-1), reduction="none", + ).reshape(bsz, seq_len) + for i, ws in enumerate(batch_ws): + wlen = wlens[i] + s = 0 if ws == 0 else max(wlen - stride, 0) + scored_nll = nll[i, s:wlen].to(torch.float64) + loss_sum += scored_nll.sum() + token_count += float(wlen - s) + tgt, prev = y_batch[i, s:wlen], x_batch[i, s:wlen] + tb = base_bytes_lut[tgt].to(torch.float64) + tb += (has_leading_space_lut[tgt] & 
~is_boundary_token_lut[prev]).to(torch.float64) + byte_count += tb.sum() + + # workaround for caching issues with sin/cos in rotary layers: + for block in base_model.blocks: + block.attn.rotary._cos_cached = None + block.attn.rotary._sin_cached = None + + # --- Phase 2: TRAIN on this chunk (already scored = legal) --- + is_last_chunk = (ci == num_chunks - 1) + if not is_last_chunk and args.ttt_epochs > 0: + base_model.train() + chunk_seqs = (chunk_end - chunk_start) // seq_len + if chunk_seqs > 0: + cos_lr = args.ttt_lr * 0.5 * (1.0 + math.cos(math.pi * ci / max(num_chunks - 1, 1))) + for pg in optimizer.param_groups: + pg['lr'] = cos_lr + my_seq_s = (chunk_seqs * rank) // world_size + my_seq_e = (chunk_seqs * (rank + 1)) // world_size + my_chunk_seqs = my_seq_e - my_seq_s + for _ep in range(args.ttt_epochs): + for bs in range(0, my_chunk_seqs, args.ttt_batch_seqs): + be = min(bs + args.ttt_batch_seqs, my_chunk_seqs) + actual_bs = my_seq_s + bs + start_tok = chunk_start + actual_bs * seq_len + end_tok = chunk_start + (my_seq_s + be) * seq_len + 1 + if end_tok > val_tokens.numel(): + continue + local = val_tokens[start_tok:end_tok].to(device=device, dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + optimizer.zero_grad(set_to_none=True) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + loss = base_model(x, y) + loss.backward() + if world_size > 1: + for p in ttt_params: + if p.grad is not None: + dist.all_reduce(p.grad, op=dist.ReduceOp.AVG) + torch.nn.utils.clip_grad_norm_(ttt_params, args.ttt_grad_clip) + optimizer.step() + + if rank == 0 and (ci % 10 == 0 or ci == num_chunks - 1): + elapsed = time.perf_counter() - t0 + rl = loss_sum.item() / max(token_count.item(), 1) + rbpb = rl / math.log(2.0) * (token_count.item() / max(byte_count.item(), 1)) if token_count.item() > 0 else 0.0 + log0(f" ttt_chunk [{ci+1}/{num_chunks}] bpb={rbpb:.6f} time={elapsed:.1f}s") + + if dist.is_available() and 
dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) + + val_loss = (loss_sum / token_count).item() + val_bpb = val_loss / math.log(2.0) * (token_count.item() / byte_count.item()) + + for p in base_model.parameters(): + p.requires_grad_(True) + base_model.eval() + + log0(f"ttt_sliding:done val_loss={val_loss:.6f} val_bpb={val_bpb:.6f} " + f"elapsed={time.perf_counter() - t0:.1f}s") + return val_loss, val_bpb + +# --- int6 quant as per #414 --- +def _classify_param(name: str) -> str: + if "tok_emb" in name or "lm_head" in name: + return "embed" + # TODO: not all mlp params likely want int6, e.g. scales and gates + if ".mlp." in name: + return "mlp" + if ".attn." in name or (".proj." in name and ".mlp." not in name): + return "attn" + return "other" + +def quantize_int6_per_row(t: Tensor, clip_range: int = 31) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + best_q, best_s, best_err = None, None, float('inf') + for pct in [0.9990, 0.9995, 0.9999, 0.99999, 1.0]: + if pct < 1.0: + row_clip = torch.quantile(t32.abs(), pct, dim=1) + else: + row_clip = t32.abs().amax(dim=1) + s = (row_clip / clip_range).clamp_min(1.0 / clip_range).to(torch.float16) + q = torch.clamp(torch.round(t32 / s.float()[:, None]), -clip_range, clip_range).to(torch.int8) + recon = q.float() * s.float()[:, None] + err = (t32 - recon).pow(2).mean().item() + if err < best_err: + best_q, best_s, best_err = q, s, err + return best_q, best_s + amax = t32.abs().max().item() + scale = torch.tensor(amax / clip_range if amax > 0 else 1.0, dtype=torch.float16) + q = torch.clamp(torch.round(t32 / scale.float()), -clip_range, clip_range).to(torch.int8) + return q, scale + +def mixed_quantize_int6( + state_dict: dict[str, Tensor], + int6_cats: set[str] +): + num_layers_total = max( + (int(k.split(".")[1]) for k in state_dict if k.startswith("blocks.")), + default=0, + ) + 1 
+ late_k_layers = set(range(num_layers_total - 2, num_layers_total))
+ result: dict[str, Tensor] = {}
+ meta: dict[str, object] = {}
+ for name, tensor in state_dict.items():
+ t = tensor.detach().cpu().contiguous()
+ cat = _classify_param(name)
+ if not t.is_floating_point() or t.numel() <= 65536:
+ result[name] = t.to(torch.float16) if t.is_floating_point() else t
+ meta[name] = "passthrough"
+ continue
+ if any(p in name for p in CONTROL_TENSOR_NAME_PATTERNS):
+ result[name] = t.float()
+ meta[name] = "passthrough_ctrl"
+ continue
+ if cat in int6_cats and t.ndim >= 1:
+ q, s = quantize_int6_per_row(t)
+ result[name + ".q"] = q
+ result[name + ".scale"] = s
+ meta[name] = {"type": "int6"}
+ else:
+ q, s = quantize_float_tensor(t)
+ result[name + ".q"] = q
+ result[name + ".scale"] = s
+ meta[name] = {"type": "int8"}
+ return result, meta
+
+def dequantize_mixed_int6(
+ result: dict[str, Tensor],
+ meta: dict[str, object],
+ template_sd: dict[str, Tensor]
+) -> dict[str, Tensor]:
+ out: dict[str, Tensor] = {}
+ for name, orig in template_sd.items():
+ info = meta.get(name)
+ if info is None:
+ continue
+ orig_dtype = orig.dtype
+ if info in ("passthrough", "passthrough_ctrl", "passthrough_fp16"):
+ t = result[name]
+ if t.dtype == torch.float16 and orig_dtype in (torch.float32, torch.bfloat16):
+ t = t.to(orig_dtype)
+ out[name] = t
+ continue
+ q, s = result[name + ".q"], result[name + ".scale"]
+ if s.ndim > 0:
+ out[name] = (q.float() * s.float().view(q.shape[0], *([1] * (q.ndim - 1)))).to(orig_dtype)
+ else:
+ out[name] = (q.float() * float(s.item())).to(orig_dtype)
+ return out
+
+# -----------------------------
+# DATA LOADING
+# -----------------------------
+
+def load_data_shard(file: Path) -> Tensor:
+ # Assumes the modded-nanogpt .bin shard layout: 256 little-endian int32 header
+ # words (with the token count at index 2), followed by uint16 token ids.
+ header_bytes = 256 * np.dtype("<i4").itemsize
+ with file.open("rb") as f:
+ header = np.frombuffer(f.read(header_bytes), dtype="<i4")
+ num_tokens = int(header[2])
+ tokens = np.frombuffer(f.read(num_tokens * 2), dtype="<u2")
+ return torch.from_numpy(tokens.copy())
+
+
+class TokenStream:
+ # Endless token stream over the sorted shard files matching `pattern`,
+ # wrapping back to the first shard once the last one is exhausted.
+ def __init__(self, pattern: str):
+ from glob import glob # local import so this block stays self-contained
+ self.files = sorted(Path(p) for p in glob(pattern))
+ if not self.files:
+ raise FileNotFoundError(f"no data shards match pattern: {pattern}")
+ self.file_idx = 0
+ self.tokens = load_data_shard(self.files[self.file_idx])
+ self.pos = 0
+
+ def _advance_file(self) -> None:
+ self.file_idx = (self.file_idx + 1) % len(self.files)
+ self.tokens = load_data_shard(self.files[self.file_idx])
+ self.pos = 0
+
+ def take(self, n: int) -> Tensor:
+ chunks: list[Tensor] = []
+ remaining = n
+ 
while remaining > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: + self._advance_file() + continue + k = min(remaining, avail) + chunks.append(self.tokens[self.pos : self.pos + k]) + self.pos += k + remaining -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) + + +class DistributedTokenLoader: + # Each call consumes a contiguous chunk from the shared token stream, then slices out + # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting. + def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank = rank + self.world_size = world_size + self.device = device + self.stream = TokenStream(pattern) + + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + chunk = self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start : start + per_rank_span].to(dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + +# ----------------------------- +# TRANSFORMER MODULES +# ----------------------------- + +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + + +class CastedLinear(nn.Linear): + # straight-through qat as per #137 + qat_ste: bool = False + def forward(self, x: Tensor) -> Tensor: + w = self.weight + if self.qat_ste and self.training: + # STE fake int6: quantize to [-32,31] and dequantize, gradient flows through + w32 = w.float() + row_max = w32.abs().amax(dim=1).clamp_min(1e-8) + scale = row_max / 31.0 + w_q = torch.clamp(torch.round(w32 / scale[:, None]), -32, 31) * scale[:, None] + w = w + 
(w_q.to(w.dtype) - w).detach() + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, w.to(x.dtype), bias) + + +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + # Keep small/control parameters in fp32 even when the model body runs in bf16. + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() + + +class Rotary(nn.Module): + def __init__(self, dim: int, base: float = 10000.0, train_seq_len: int = 1024, rope_dims: int = 0): + super().__init__() + self.dim = dim + self.base = base + self.train_seq_len = train_seq_len + self.rope_dims = rope_dims if rope_dims > 0 else dim + inv_freq = 1.0 / (base ** (torch.arange(0, self.rope_dims, 2, dtype=torch.float32) / self.rope_dims)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached != seq_len + or self._cos_cached.device != device + ): + rd = self.rope_dims + if seq_len > self.train_seq_len: + scale = seq_len / self.train_seq_len + new_base = self.base * (scale ** (rd / (rd - 2))) + inv_freq = 1.0 / (new_base ** (torch.arange(0, rd, 2, dtype=torch.float32, device=device) / rd)) + else: + inv_freq = self.inv_freq.to(device) + t = torch.arange(seq_len, device=device, dtype=inv_freq.dtype) + freqs = torch.outer(t, inv_freq) + self._cos_cached = freqs.cos()[None, :, None, :] + self._sin_cached = freqs.sin()[None, :, None, :] + self._seq_len_cached = seq_len + return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) + +# using partial RoPE as per #414 +def apply_rotary_emb(x: Tensor, cos: 
Tensor, sin: Tensor, rope_dims: int = 0) -> Tensor: + if rope_dims > 0 and rope_dims < x.size(-1): + x_rope, x_pass = x[..., :rope_dims], x[..., rope_dims:] + half = rope_dims // 2 + x1, x2 = x_rope[..., :half], x_rope[..., half:] + x_rope = torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + return torch.cat((x_rope, x_pass), dim=-1) + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + +class CausalSelfAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + rope_base: float, + qk_gain_init: float, + use_xsa: bool = False, + train_seq_len: int = 1024, + rope_dims: int | None = None + ): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + kv_dim = self.num_kv_heads * self.head_dim + self.c_q = CastedLinear(dim, dim, bias=False) + self.c_k = CastedLinear(dim, kv_dim, bias=False) + self.c_v = CastedLinear(dim, kv_dim, bias=False) + self.proj = CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rotary = Rotary(self.head_dim, base=rope_base, train_seq_len=train_seq_len, rope_dims=rope_dims) + self.use_xsa = use_xsa + + # efficient XSA as per + # https://github.com/unnir/parameter-golf/blob/a81f85bd7f632e3e48ef6b1da0017b81d25998a7/records/track_10min_16mb/2026-03-20_11L_EfficientPartialXSA_FA3_SWA120/train_gpt.py + def _xsa_efficient(self, y: Tensor, v: Tensor) -> Tensor: + """Efficient XSA: subtract self-value projection via GQA-aware reshape (no repeat_interleave). 
+ y: [B, T, H, D], v: [B, T, Hkv, D]. H must be divisible by Hkv.""" + B, T, H, D = y.shape + Hkv = v.size(-2) + group = H // Hkv + # Reshape y into KV head groups — free view, no memory alloc + y_g = y.reshape(B, T, Hkv, group, D) # [B, T, Hkv, group, D] + vn = F.normalize(v, dim=-1).unsqueeze(-2) # [B, T, Hkv, 1, D] — broadcast ready + # Project out self-value component per KV head group + proj = (y_g * vn).sum(dim=-1, keepdim=True) * vn + return (y_g - proj).reshape(B, T, H, D) + + def forward(self, x: Tensor, v_embed: Tensor | None = None) -> Tensor: + bsz, seqlen, dim = x.shape + q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim) + k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) + v = self.c_v(x) + if v_embed is not None: + v = v + v_embed + v = v.reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin, self.rotary.rope_dims) + k = apply_rotary_emb(k, cos, sin, self.rotary.rope_dims) + q = q * self.q_gain.to(dtype=q.dtype)[None, None, :, None] + y = flash_attn_func(q, k, v, causal=True) + + # efficient xsa as per https://github.com/openai/parameter-golf/pull/265 + if self.use_xsa: + y = self._xsa_efficient(y, v) + + y = y.reshape(bsz, seqlen, dim) + return self.proj(y) + +# stolen from #414 +class SmearGate(nn.Module): + def __init__(self, dim: int): + super().__init__() + self.gate = nn.Parameter(torch.zeros(dim, dtype=torch.float32)) + def forward(self, x: Tensor) -> Tensor: + g = torch.sigmoid(self.gate.to(dtype=x.dtype))[None, None, :] + x_prev = torch.cat([torch.zeros_like(x[:, :1]), x[:, :-1]], dim=1) + return (1 - g) * x + g * x_prev + +# stolen from #414 +class BigramHashEmbedding(nn.Module): + def __init__(self, bigram_vocab_size: int, bigram_dim: int, model_dim: int): + super().__init__() + self.bigram_vocab_size = bigram_vocab_size + self.embed = 
nn.Embedding(bigram_vocab_size, bigram_dim) + nn.init.zeros_(self.embed.weight) + self.proj = CastedLinear(bigram_dim, model_dim, bias=False) if bigram_dim != model_dim else None + if self.proj is not None: + nn.init.zeros_(self.proj.weight) + self.scale = nn.Parameter(torch.tensor(0.05, dtype=torch.float32)) + def bigram_hash(self, tokens: Tensor) -> Tensor: + t = tokens.to(torch.int32) + mod = self.bigram_vocab_size - 1 + out = torch.empty_like(t) + out[..., 0] = mod + out[..., 1:] = torch.bitwise_xor(36313 * t[..., 1:], 27191 * t[..., :-1]) % mod + return out.long() + def forward(self, token_ids: Tensor) -> Tensor: + h = self.embed(self.bigram_hash(token_ids)) + if self.proj is not None: + h = self.proj(h) + return h * self.scale.to(dtype=h.dtype) + +# stolen from #414 +class ValueEmbedding(nn.Module): + """Reinject token identity into attention values at specific layers. + Each table maps vocab tokens to a low-dim embedding, projected to model_dim.""" + def __init__(self, vocab_size: int, ve_dim: int, model_dim: int): + super().__init__() + self.embed = nn.Embedding(vocab_size, ve_dim) + nn.init.normal_(self.embed.weight, std=0.01) + self.proj = CastedLinear(ve_dim, model_dim, bias=False) if ve_dim != model_dim else None + if self.proj is not None: + nn.init.zeros_(self.proj.weight) + self.scale = nn.Parameter(torch.tensor(0.1, dtype=torch.float32)) + def forward(self, token_ids: Tensor) -> Tensor: + h = self.embed(token_ids) + if self.proj is not None: + h = self.proj(h) + return h * self.scale.to(dtype=h.dtype) + +class MLP(nn.Module): + def __init__( + self, + dim: int, + mlp_mult: int, + rng_up: bool = False, + use_rng_gain: bool = False, + mini_moe_experts: int = 1, + ): + super().__init__() + self.hidden = mlp_mult * dim + self.mini_moe_experts = mini_moe_experts + self.use_rng_gain = use_rng_gain + self.fc: CastedLinear | None = None + self.rand_gain: nn.Parameter | None = None + self.moe_router: CastedLinear | None = None + if not rng_up: + self.fc = 
CastedLinear(dim, self.hidden, bias=False)
+ else:
+ # non-persistent buffer, because deterministic random init, kept in bf16 because not quant'd anyway
+ self.register_buffer("fc_w", torch.empty((mini_moe_experts, self.hidden, dim), dtype=torch.bfloat16), persistent=False)
+ if self.use_rng_gain:
+ self.rand_gain = nn.Parameter(torch.ones(mini_moe_experts, self.hidden, dtype=torch.float32)) if rng_up else None
+ if self.mini_moe_experts > 1:
+ self.moe_router = CastedLinear(dim, mini_moe_experts, bias=False)
+ self.moe_router._zero_init = True
+ self.proj = CastedLinear(self.hidden, dim, bias=False)
+ self.proj._zero_init = True
+
+ def forward(self, x: Tensor) -> Tensor:
+ if self.fc is not None:
+ x = self.fc(x)
+ else:
+ if self.mini_moe_experts == 1:
+ x = F.linear(x, self.fc_w[0].to(dtype=x.dtype))
+ if self.rand_gain is not None:
+ x = x * self.rand_gain[0].to(dtype=x.dtype)[None, :]
+ else:
+ # compute router weights over experts
+ moe_weights = F.softmax(self.moe_router(x), dim=-1)
+ # compute individual expert outputs and weighted sum
+ # x: (bsz, seqlen, dim), fc_w: (mini_moe_experts, hidden, dim)
+ exout = torch.einsum("bsd,ehd->bseh", x, self.fc_w.to(dtype=x.dtype)) # (bsz, seqlen, mini_moe_experts, hidden)
+ if self.rand_gain is not None:
+ exout = exout * self.rand_gain.to(dtype=x.dtype)[None, None, :, :]
+ x = moe_weights.unsqueeze(-1) * exout
+ x = x.sum(dim=2) # (bsz, seqlen, hidden)
+ # x = torch.relu(x)
+ x = F.leaky_relu(x, negative_slope=0.5)
+ return self.proj(x.square())
+
+class Block(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ num_heads: int,
+ num_kv_heads: int,
+ mlp_mult: int,
+ rope_base: float,
+ qk_gain_init: float,
+ rng_up: bool = False,
+ use_rng_gain: bool = False,
+ mini_moe_experts: int = 1,
+ train_seq_len: int = 1024,
+ use_xsa: bool = False,
+ rope_dims: int | None = None,
+ ln_scale: bool = False,
+ layer_idx: int = 0,
+ ):
+ super().__init__()
+ self.attn_norm = RMSNorm()
+ self.mlp_norm = RMSNorm()
+ self.attn = CausalSelfAttention(dim, 
num_heads, num_kv_heads, rope_base, qk_gain_init, use_xsa, train_seq_len, rope_dims=rope_dims) + self.mlp = MLP(dim, mlp_mult, rng_up, use_rng_gain, mini_moe_experts) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + self.ln_scale_factor = 1.0 / math.sqrt(layer_idx + 1) if ln_scale else 1.0 + + def forward(self, x: Tensor, x0: Tensor, v_embed: Tensor | None = None) -> Tensor: + mix = self.resid_mix.to(dtype=x.dtype) + x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + attn_out = self.attn(self.attn_norm(x) * self.ln_scale_factor, v_embed=v_embed) + x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out + x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x) * self.ln_scale_factor) + return x + +class GPT(nn.Module): + def __init__( + self, + vocab_size: int, + num_layers: int, + model_dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + tie_embeddings: bool, + tied_embed_init_std: float, + logit_softcap: float, + rope_base: float, + qk_gain_init: float, + rand_proj_layers: list[int], + use_rng_gain: bool = False, + mini_moe_experts: int = 1, + rand_init_seed: int = 42, + rand_init_qr: bool = False, + xsa_layers: list[int] = [], + rope_dims: int | None = None, + ln_scale: bool = False, + bigram_vocab_size: int = 0, + bigram_dim: int = 128, + ve_dim: int = 128, + ve_layers: list[int] = [], + train_seq_len: int = 1024, + ): + super().__init__() + self.rand_init_seed = rand_init_seed + self.rand_init_qr = rand_init_qr + if logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") + self.tie_embeddings = tie_embeddings + self.tied_embed_init_std = tied_embed_init_std + self.logit_softcap = logit_softcap + self.tok_emb = nn.Embedding(vocab_size, model_dim) + self.smear = 
SmearGate(model_dim) + self.bigram = BigramHashEmbedding(bigram_vocab_size, bigram_dim, model_dim) if bigram_vocab_size > 0 else None + self.num_layers = num_layers + self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + self.blocks = nn.ModuleList( + [ + Block( + model_dim, + num_heads, + num_kv_heads, + mlp_mult, + rope_base, + qk_gain_init, + rng_up=(i in rand_proj_layers), + use_rng_gain=use_rng_gain, + mini_moe_experts=mini_moe_experts, + train_seq_len=train_seq_len, + use_xsa=(i in xsa_layers), + rope_dims=rope_dims, + ln_scale=ln_scale, + layer_idx=i, + ) + for i in range(num_layers) + ] + ) + + # value embeddings + self.ve_layer_indices = ve_layers + self.ve_target_dim = num_kv_heads * (model_dim // num_heads) + if self.ve_layer_indices: + self.ve_shared = ValueEmbedding(vocab_size, ve_dim, self.ve_target_dim) + self.ve_layer_scales = nn.ParameterList( + [nn.Parameter(torch.ones(1, dtype=torch.float32)) for _ in self.ve_layer_indices] + ) + else: + self.ve_shared = None + self.ve_layer_scales = nn.ParameterList() + self.value_embeds = nn.ModuleList() # keep empty for compat + + self.final_norm = RMSNorm() + self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) + if self.lm_head is not None: + self.lm_head._zero_init = True + self._init_weights() + + def _init_weights(self) -> None: + gen = torch.Generator() + gen.manual_seed(self.rand_init_seed) + + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + for name, module in self.named_modules(): + if isinstance(module, nn.Linear): + if getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + elif module.weight.ndim == 2 and module.weight.shape[0] >= 64 and 
module.weight.shape[1] >= 64:
+ nn.init.orthogonal_(module.weight, gain=1.0)
+ if ".proj." in name or name.endswith(".proj"):
+ with torch.no_grad():
+ module.weight.mul_(1.0 / math.sqrt(2 * self.num_layers))
+
+ if isinstance(module, MLP) and getattr(module, "fc_w", None) is not None:
+ # perform seeded random init for the random fc weights
+ n_experts, d_out, d_in = module.fc_w.shape
+ if self.rand_init_qr:
+ for e in range(n_experts):
+ G = torch.randn((d_out, d_in), generator=gen)
+ q, _ = torch.linalg.qr(G)
+ module.fc_w[e].copy_(q / math.sqrt(d_in))
+ else:
+ nn.init.normal_(module.fc_w, mean=0.0, std=1.0/math.sqrt(d_in), generator=gen)
+ # module.fc_w.bernoulli_(0.5, generator=gen).mul_(2).sub_(1).mul_(1.0 / math.sqrt(d_in))
+
+ # value embeddings as per #414
+ def _get_ve(self, layer_idx: int, input_ids: Tensor, ve_cache: dict[str, Tensor] | None = None) -> Tensor | None:
+ """Get value embedding for a specific layer using shared table + per-layer scale."""
+ if self.ve_shared is None or layer_idx not in self.ve_layer_indices:
+ return None
+ if ve_cache is not None and 've' not in ve_cache:
+ ve_cache['ve'] = self.ve_shared(input_ids)
+ ve_base = ve_cache['ve'] if ve_cache is not None else self.ve_shared(input_ids)
+ ve_idx = self.ve_layer_indices.index(layer_idx)
+ return ve_base * self.ve_layer_scales[ve_idx].to(dtype=ve_base.dtype)
+
+ def forward_logits(self, input_ids: Tensor) -> Tensor:
+ x = self.tok_emb(input_ids)
+ if self.bigram is not None:
+ x = x + self.bigram(input_ids)
+ x = F.rms_norm(x, (x.size(-1),))
+ x = self.smear(x)
+ x0 = x
+ skips: list[Tensor] = []
+
+ # First half stores skips; second half reuses them in reverse order. 
+ ve_cache: dict[str, Tensor] = {} + for i in range(self.num_encoder_layers): + ve = self._get_ve(i, input_ids, ve_cache) + x = self.blocks[i](x, x0, v_embed=ve) + skips.append(x) + for i in range(self.num_decoder_layers): + bi = self.num_encoder_layers + i + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + ve = self._get_ve(bi, input_ids, ve_cache) + x = self.blocks[bi](x, x0, v_embed=ve) + + x = self.final_norm(x) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(x) + logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + return logits + + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + logits = self.forward_logits(input_ids) + logits = logits.reshape(-1, logits.size(-1)) + targets = target_ids.reshape(-1) + return F.cross_entropy(logits.float(), targets, reduction="mean") + + +# ----------------------------- +# TRAINING +# ----------------------------- + +def main() -> None: + global zeropower_via_newtonschulz5 + + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) + + # ----------------------------- + # DISTRIBUTED + CUDA SETUP + # ----------------------------- + + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if 8 % world_size != 0: + raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") + grad_accum_steps = 8 // world_size + grad_scale = 1.0 / grad_accum_steps + if not torch.cuda.is_available(): + raise RuntimeError("CUDA 
is required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 0 + + # Fast math knobs + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + + logfile = None + if master_process: + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" + print(logfile) + + def log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch {torch.__version__}", console=False) + log0( + subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, + console=False, + ) + log0("=" * 100, console=False) + + # ----------------------------- + # TOKENIZER + VALIDATION METRIC SETUP + # ----------------------------- + + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + 
effective_eval_seq_len = args.eval_seq_len if args.eval_seq_len > 0 else args.train_seq_len + val_seq_len = max(args.train_seq_len, effective_eval_seq_len) + val_tokens = load_validation_tokens(args.val_files, val_seq_len) + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + + # ----------------------------- + # MODEL + OPTIMIZER SETUP + # ----------------------------- + + base_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, + rand_proj_layers=args.rand_proj_layers, + use_rng_gain=args.rand_gain, + mini_moe_experts=args.mini_moe_experts, + train_seq_len=args.train_seq_len, + xsa_layers=args.xsa_layers, + rope_dims=args.rope_dims, + ln_scale=args.ln_scale, + bigram_vocab_size=args.bigram_vocab_size, + bigram_dim=args.bigram_dim, + ve_dim=args.ve_dim, + ve_layers=args.ve_layers, + ).to(device).bfloat16() + for module in base_model.modules(): + if isinstance(module, CastedLinear): + module.float() + restore_low_dim_params_to_fp32(base_model) + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) + # compiled_model = base_model + model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + + # Optimizer split: + # - token embedding (Adam) uses EMBED_LR + # - untied lm_head (Adam) uses HEAD_LR + # - matrix params in transformer blocks 
use MATRIX_LR via Muon + # - vectors/scalars use SCALAR_LR via Adam + block_named_params = list(base_model.blocks.named_parameters()) + matrix_params = [ + p + for name, p in block_named_params + if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + scalar_params = [ + p + for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) + scalar_params.append(base_model.smear.gate) + if base_model.bigram is not None: + scalar_params.append(base_model.bigram.scale) + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + tok_params = [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}] + if base_model.bigram is not None: + tok_params.append({"params": [base_model.bigram.embed.weight], "lr": token_lr, "base_lr": token_lr}) + if base_model.bigram.proj is not None: + matrix_params.append(base_model.bigram.proj.weight) + if base_model.ve_shared is not None: + tok_params.append({"params": [base_model.ve_shared.embed.weight], "lr": token_lr, "base_lr": token_lr}) + if base_model.ve_shared.proj is not None: + matrix_params.append(base_model.ve_shared.proj.weight) + scalar_params.append(base_model.ve_shared.scale) + for s in base_model.ve_layer_scales: + scalar_params.append(s) + optimizer_tok = torch.optim.AdamW( + tok_params, + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + weight_decay=args.adam_wd, + fused=True, + ) + optimizer_muon = Muon( + matrix_params, + lr=args.matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + weight_decay=args.muon_wd, + ) + for group in optimizer_muon.param_groups: + group["base_lr"] = args.matrix_lr + optimizer_scalar = torch.optim.AdamW( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + 
weight_decay=args.adam_wd, + fused=True, + ) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers.insert(1, optimizer_head) + + n_params = sum(p.numel() for p in base_model.parameters()) + log0(f"model_params:{n_params}") + log0(f"num_layers:{args.num_layers} model_dim:{args.model_dim} num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads} mlp_mult:{args.mlp_mult}") + log0(f"rand_proj_layers:{args.rand_proj_layers} rand_gain:{args.rand_gain} mini_moe_experts:{args.mini_moe_experts}") + log0(f"bigram_vocab_size:{args.bigram_vocab_size} bigram_dim:{args.bigram_dim} ve_dim:{args.ve_dim} ve_layers:{args.ve_layers}") + log0(f"tie_embeddings:{args.tie_embeddings} tied_embed_init_std:{args.tied_embed_init_std} logit_softcap:{args.logit_softcap}") + log0(f"rope_base:{args.rope_base} qk_gain_init:{args.qk_gain_init} rope_dims:{args.rope_dims} xsa_layers:{args.xsa_layers} ln_scale:{args.ln_scale}") + log0(f"xsa_layers:{args.xsa_layers} ve_layers:{args.ve_layers} ve_dim:{args.ve_dim}") + log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") + log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") + log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") + log0(f"mlp_mode:{'rng_up' if args.rand_proj_layers else 'standard'} mlp_mult:{args.mlp_mult} mini_moe_experts:{args.mini_moe_experts} use_rng_gain:{args.rand_gain}") + log0( + f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " + f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " + f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" + ) + log0( + f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " + 
f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " + f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log0(f"seed:{args.seed}") + + # ----------------------------- + # DATA LOADER & MODEL WARMUP + # ----------------------------- + + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + + def lr_mul(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + + # Warmup primes the compiled forward/backward/optimizer paths, then we restore the + # initial weights/optimizer state so measured training starts from the true init. 
+ if args.warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(args.warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + for opt in optimizers: + opt.step() + zero_grad_all() + if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: + log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") + # enable qat during warmup so we don't pay the compilation tax later + if warmup_step == 2: + CastedLinear.qat_ste = True + if warmup_step == 4: + CastedLinear.qat_ste = False + model.eval() + if warmup_step == 6: + CastedLinear.qat_ste = True + if warmup_step == 8: + model.train() + CastedLinear.qat_ste = False + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + if distributed: + model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + # ----------------------------- + # MAIN TRAINING LOOP + # ----------------------------- + + # SWA as per #414 + swa_state: dict[str, Tensor] | None = None + swa_count = 0 + + # EMA as per #414 + ema_state = {name: t.detach().float().clone() for name, t in base_model.state_dict().items()} + ema_decay = 0.997 + + val_bpb = float("inf") + + training_time_ms = 0.0 + stop_after_step: int | None = None + torch.cuda.synchronize() + t0 = 
time.perf_counter() + + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + log0( + f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" + ) + break + + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + + # late QAT as per #414 + if args.late_qat_threshold > 0 and scale < args.late_qat_threshold and not CastedLinear.qat_ste: + CastedLinear.qat_ste = True + log0(f"late_qat:enabled step:{step} scale:{scale:.4f}") + + zero_grad_all() + train_loss = torch.zeros((), device=device) + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + loss = model(x, y) + train_loss += loss.detach() + (loss * grad_scale).backward() + train_loss /= grad_accum_steps + + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for 
group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + for opt in optimizers: + opt.step() + zero_grad_all() + + # EMA update as per #414 + with torch.no_grad(): + for name, t in base_model.state_dict().items(): + ema_state[name].mul_(ema_decay).add_(t.detach().float(), alpha=1.0 - ema_decay) + + step += 1 + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + + # SWA as per #414 + if args.swa_enabled and scale < 0.2 and step % args.swa_every == 0: + if swa_state is None: + swa_state = {name: t.detach().cpu().clone() for name, t in base_model.state_dict().items()} + swa_count = 1 + log0(f"swa:start step:{step}") + else: + for name, t in base_model.state_dict().items(): + swa_state[name] += t.detach().cpu() + swa_count += 1 + + should_log_train = ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + log0( + f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" + ) + + # Needed to sync whether we've reached the wallclock cap. 
+ reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + + log0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + ) + + # Apply EMA weights (better than SWA alone per PR#401) + log0("ema:applying EMA weights") + current_state = base_model.state_dict() + avg_state = {name: t.to(dtype=current_state[name].dtype) for name, t in ema_state.items()} + base_model.load_state_dict(avg_state, strict=True) + torch.cuda.synchronize() + t_diag = time.perf_counter() + diag_val_loss, diag_val_bpb = eval_val( + args, compiled_model, rank, world_size, device, grad_accum_steps, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"DIAGNOSTIC post_ema val_loss:{diag_val_loss:.4f} val_bpb:{diag_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_diag):.0f}ms" + ) + + if diag_val_bpb > val_bpb: + log0( + f"EMA did not improve val_bpb ({val_bpb:.4f} -> {diag_val_bpb:.4f}), restoring pre-EMA weights for final serialization" + ) + base_model.load_state_dict(current_state, strict=True) + else: + log0( + f"EMA improved val_bpb ({val_bpb:.4f} -> {diag_val_bpb:.4f}), keeping EMA weights for final serialization" + ) + + # ----------------------------- + # SERIALIZATION + ROUNDTRIP VALIDATION + # ----------------------------- + # Save the raw state (useful for debugging/loading in PyTorch directly), then always produce + # the compressed int6 artifact (zstd or zlib, per _COMPRESSOR) and validate the round-tripped weights.
+ + if master_process: + torch.save(base_model.state_dict(), "final_model.pt") + model_bytes = os.path.getsize("final_model.pt") + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model: {model_bytes} bytes") + log0(f"Code size: {code_bytes} bytes") + log0(f"Total submission size: {model_bytes + code_bytes} bytes") + + # --- quant --- + sd_cpu = {k: v.cpu() for k, v in base_model.state_dict().items()} + # TODO: consider keeping the routers in fp32 or higher precision + quant_result, quant_meta = mixed_quantize_int6(sd_cpu, {"mlp", "attn"}) + quant_buf = io.BytesIO() + torch.save({"w": quant_result, "m": quant_meta}, quant_buf) + quant_raw = quant_buf.getvalue() + quant_blob = zstandard.ZstdCompressor(level=22).compress(quant_raw) if _COMPRESSOR == "zstd" else zlib.compress(quant_raw, 9) + # quant_raw_bytes = len(quant_raw) + if master_process: + with open("final_model.int6.ptz", "wb") as f: + f.write(quant_blob) + quant_file_bytes = len(quant_blob) + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model int6+{_COMPRESSOR}: {quant_file_bytes} bytes") + log0(f"Total submission size int6+{_COMPRESSOR}: {quant_file_bytes + code_bytes} bytes") + # same artifact size, repeated under the int8+zlib label + log0(f"Total submission size int8+zlib: {quant_file_bytes + code_bytes} bytes") + + if distributed: + dist.barrier() + with open("final_model.int6.ptz", "rb") as f: + quant_blob_disk = f.read() + quant_state = torch.load( + io.BytesIO(zstandard.ZstdDecompressor().decompress(quant_blob_disk) if _COMPRESSOR == "zstd" else zlib.decompress(quant_blob_disk)), + map_location="cpu", + ) + deq_state = dequantize_mixed_int6(quant_state["w"], quant_state["m"], sd_cpu) + eval_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, +
qk_gain_init=args.qk_gain_init, + rand_proj_layers=args.rand_proj_layers, + use_rng_gain=args.rand_gain, + mini_moe_experts=args.mini_moe_experts, + train_seq_len=args.train_seq_len, + xsa_layers=args.xsa_layers, + rope_dims=args.rope_dims, + ln_scale=args.ln_scale, + bigram_vocab_size=args.bigram_vocab_size, + bigram_dim=args.bigram_dim, + ve_dim=args.ve_dim, + ve_layers=args.ve_layers, + ).to(device).bfloat16() + for m in eval_model.modules(): + if isinstance(m, CastedLinear): + m.float() + restore_low_dim_params_to_fp32(eval_model) + eval_model.load_state_dict(deq_state, strict=True) + compiled_eval = torch.compile(eval_model, dynamic=False, fullgraph=True) + torch.cuda.synchronize() + t_qeval = time.perf_counter() + q_val_loss, q_val_bpb = eval_val( + args, + compiled_eval, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"final_int6_zlib_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" + ) + log0(f"final_int6_zlib_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + + # --- sliding window + TTT eval --- + + sw_seq_len = effective_eval_seq_len + if args.eval_stride > 0 and args.eval_stride < sw_seq_len: + torch.cuda.synchronize() + t_slide = time.perf_counter() + sw_val_loss, sw_val_bpb = eval_val_sliding( + args, eval_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride=args.eval_stride, + eval_seq_len=sw_seq_len, + ) + torch.cuda.synchronize() + log0( + f"final_int6_sliding_window val_loss:{sw_val_loss:.4f} val_bpb:{sw_val_bpb:.4f} " + f"stride:{args.eval_stride} eval_time:{1000.0 * (time.perf_counter() - t_slide):.0f}ms" + ) + log0(f"final_int6_sliding_window_exact val_loss:{sw_val_loss:.8f} val_bpb:{sw_val_bpb:.8f}") + log0(f"final_int8_zlib_roundtrip_exact 
val_loss:{sw_val_loss:.8f} val_bpb:{sw_val_bpb:.8f}") + if args.eval_stride != 64 and 64 < sw_seq_len: + torch.cuda.synchronize() + t_slide64 = time.perf_counter() + sw64_val_loss, sw64_val_bpb = eval_val_sliding( + args, eval_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride=64, + eval_seq_len=sw_seq_len, + ) + torch.cuda.synchronize() + log0( + f"final_int6_sliding_window_s64 val_loss:{sw64_val_loss:.4f} val_bpb:{sw64_val_bpb:.4f} " + f"stride:64 eval_time:{1000.0 * (time.perf_counter() - t_slide64):.0f}ms" + ) + log0(f"final_int6_sliding_window_s64_exact val_loss:{sw64_val_loss:.8f} val_bpb:{sw64_val_bpb:.8f}") + log0(f"final_int8_zlib_roundtrip_exact val_loss:{sw64_val_loss:.8f} val_bpb:{sw64_val_bpb:.8f}") + # Legal score-first TTT (PR #461 recipe) + if args.ttt_enabled: + # disable QAT ste for TTT + CastedLinear.qat_ste = False + torch.cuda.synchronize() + t_ttt = time.perf_counter() + ttt_loss, ttt_bpb = eval_val_sliding_ttt( + args, eval_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride=args.eval_stride, log0=log0, + ) + torch.cuda.synchronize() + log0(f"legal_ttt val_loss:{ttt_loss:.4f} val_bpb:{ttt_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_ttt):.0f}ms") + log0(f"legal_ttt_exact val_loss:{ttt_loss:.8f} val_bpb:{ttt_bpb:.8f}") + + if distributed: + dist.destroy_process_group() + + +if __name__ == "__main__": + main() + +==================================================================================================== +Running Python 3.12.3 (main, Nov 6 2025, 13:44:16) [GCC 13.3.0] +Running PyTorch 2.9.1+cu128 +Mon Mar 30 20:33:53 2026 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 580.126.09 Driver Version: 580.126.09 CUDA Version: 13.0 | ++-----------------------------------------+------------------------+----------------------+ +| GPU Name 
Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. | +|=========================================+========================+======================| +| 0 NVIDIA H100 80GB HBM3 On | 00000000:19:00.0 Off | 0 | +| N/A 49C P0 125W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 1 NVIDIA H100 80GB HBM3 On | 00000000:3B:00.0 Off | 0 | +| N/A 38C P0 121W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 2 NVIDIA H100 80GB HBM3 On | 00000000:4C:00.0 Off | 0 | +| N/A 38C P0 122W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 3 NVIDIA H100 80GB HBM3 On | 00000000:5D:00.0 Off | 0 | +| N/A 48C P0 127W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 4 NVIDIA H100 80GB HBM3 On | 00000000:9B:00.0 Off | 0 | +| N/A 50C P0 127W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 5 NVIDIA H100 80GB HBM3 On | 00000000:BB:00.0 Off | 0 | +| N/A 39C P0 123W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 6 NVIDIA H100 80GB HBM3 On | 00000000:CB:00.0 Off | 0 | +| N/A 49C P0 133W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 7 NVIDIA H100 80GB HBM3 On | 00000000:DB:00.0 Off | 0 | +| N/A 37C P0 120W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | 
++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| No running processes found | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:10 +val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 +model_params:25431141 +num_layers:12 model_dim:512 num_heads:8 num_kv_heads:4 mlp_mult:3 +rand_proj_layers:[0, 1, 2, 3, 4] rand_gain:True mini_moe_experts:1 +bigram_vocab_size:2048 bigram_dim:128 ve_dim:128 ve_layers:[9, 10, 11] +tie_embeddings:True tied_embed_init_std:0.005 logit_softcap:30.0 +rope_base:10000.0 qk_gain_init:1.5 rope_dims:16 xsa_layers:[8, 9, 10, 11] ln_scale:True +xsa_layers:[8, 9, 10, 11] ve_layers:[9, 10, 11] ve_dim:128 +world_size:8 grad_accum_steps:1 +sdp_backends:cudnn=False flash=True mem_efficient=False math=False +attention_mode:gqa num_heads:8 num_kv_heads:4 +mlp_mode:rng_up mlp_mult:3 mini_moe_experts:1 use_rng_gain:True +tie_embeddings:True embed_lr:0.035 head_lr:0.0 matrix_lr:0.025 scalar_lr:0.025 +train_batch_tokens:786432 train_seq_len:2048 iterations:20000 warmup_steps:20 max_wallclock_seconds:600.000 +seed:2026 +warmup_step:1/20 +warmup_step:2/20 +warmup_step:3/20 +warmup_step:4/20 +warmup_step:5/20 +warmup_step:6/20 +warmup_step:7/20 +warmup_step:8/20 +warmup_step:9/20 +warmup_step:10/20 +warmup_step:11/20 +warmup_step:12/20 +warmup_step:13/20 +warmup_step:14/20 +warmup_step:15/20 +warmup_step:16/20 
+warmup_step:17/20 +warmup_step:18/20 +warmup_step:19/20 +warmup_step:20/20 +step:1/20000 train_loss:6.9330 train_time:240ms step_avg:239.59ms +step:2/20000 train_loss:8.6662 train_time:360ms step_avg:179.87ms +step:3/20000 train_loss:8.2115 train_time:479ms step_avg:159.75ms +step:4/20000 train_loss:7.2434 train_time:598ms step_avg:149.61ms +step:5/20000 train_loss:7.0199 train_time:718ms step_avg:143.51ms +step:6/20000 train_loss:6.8514 train_time:837ms step_avg:139.47ms +step:7/20000 train_loss:6.7069 train_time:956ms step_avg:136.61ms +step:8/20000 train_loss:6.6860 train_time:1076ms step_avg:134.46ms +step:9/20000 train_loss:6.4556 train_time:1195ms step_avg:132.81ms +step:10/20000 train_loss:6.1956 train_time:1314ms step_avg:131.45ms +step:100/20000 train_loss:3.2119 train_time:12088ms step_avg:120.88ms +step:200/20000 train_loss:2.3928 train_time:24132ms step_avg:120.66ms +step:300/20000 train_loss:2.5374 train_time:36166ms step_avg:120.55ms +step:400/20000 train_loss:2.4052 train_time:48184ms step_avg:120.46ms +step:500/20000 train_loss:2.3910 train_time:60166ms step_avg:120.33ms +step:600/20000 train_loss:2.3300 train_time:72220ms step_avg:120.37ms +step:700/20000 train_loss:2.3436 train_time:84276ms step_avg:120.39ms +step:800/20000 train_loss:2.2402 train_time:96329ms step_avg:120.41ms +step:900/20000 train_loss:2.1304 train_time:108390ms step_avg:120.43ms +step:1000/20000 train_loss:2.2766 train_time:120393ms step_avg:120.39ms +step:1100/20000 train_loss:2.3236 train_time:132443ms step_avg:120.40ms +step:1200/20000 train_loss:2.3585 train_time:144485ms step_avg:120.40ms +step:1300/20000 train_loss:2.3587 train_time:156548ms step_avg:120.42ms +step:1400/20000 train_loss:2.2218 train_time:168564ms step_avg:120.40ms +step:1500/20000 train_loss:2.1926 train_time:180535ms step_avg:120.36ms +step:1600/20000 train_loss:2.1969 train_time:192592ms step_avg:120.37ms +step:1700/20000 train_loss:2.1871 train_time:204670ms step_avg:120.39ms +step:1800/20000 
train_loss:2.1334 train_time:216719ms step_avg:120.40ms +step:1900/20000 train_loss:2.1860 train_time:228696ms step_avg:120.37ms +step:2000/20000 train_loss:2.1603 train_time:240760ms step_avg:120.38ms +step:2100/20000 train_loss:2.0530 train_time:252858ms step_avg:120.41ms +step:2200/20000 train_loss:2.1563 train_time:264930ms step_avg:120.42ms +step:2300/20000 train_loss:2.0922 train_time:276977ms step_avg:120.42ms +step:2400/20000 train_loss:2.1519 train_time:288962ms step_avg:120.40ms +step:2500/20000 train_loss:2.0424 train_time:301008ms step_avg:120.40ms +step:2600/20000 train_loss:2.1272 train_time:313120ms step_avg:120.43ms +step:2700/20000 train_loss:2.0748 train_time:325176ms step_avg:120.44ms +step:2800/20000 train_loss:2.0922 train_time:337260ms step_avg:120.45ms +step:2900/20000 train_loss:2.0237 train_time:349249ms step_avg:120.43ms +step:3000/20000 train_loss:2.1100 train_time:361293ms step_avg:120.43ms +step:3100/20000 train_loss:2.0472 train_time:373323ms step_avg:120.43ms +step:3200/20000 train_loss:1.8523 train_time:385358ms step_avg:120.42ms +step:3300/20000 train_loss:1.9944 train_time:397319ms step_avg:120.40ms +step:3400/20000 train_loss:2.0715 train_time:409356ms step_avg:120.40ms +step:3500/20000 train_loss:2.0340 train_time:421406ms step_avg:120.40ms +step:3600/20000 train_loss:2.0315 train_time:433465ms step_avg:120.41ms +step:3700/20000 train_loss:1.9747 train_time:445528ms step_avg:120.41ms +step:3800/20000 train_loss:2.0188 train_time:457540ms step_avg:120.41ms +step:3900/20000 train_loss:2.0057 train_time:469591ms step_avg:120.41ms +step:4000/20000 train_loss:1.9374 train_time:481622ms step_avg:120.41ms +step:4100/20000 train_loss:1.8996 train_time:493673ms step_avg:120.41ms +step:4200/20000 train_loss:2.0376 train_time:505704ms step_avg:120.41ms +swa:start step:4300 +step:4300/20000 train_loss:1.9905 train_time:517677ms step_avg:120.39ms +step:4400/20000 train_loss:2.0445 train_time:529910ms step_avg:120.43ms +late_qat:enabled 
step:4457 scale:0.1497 +step:4500/20000 train_loss:1.9850 train_time:542085ms step_avg:120.46ms +step:4600/20000 train_loss:1.9301 train_time:554339ms step_avg:120.51ms +step:4700/20000 train_loss:1.9571 train_time:566543ms step_avg:120.54ms +step:4800/20000 train_loss:1.9957 train_time:578803ms step_avg:120.58ms +step:4900/20000 train_loss:1.9666 train_time:591061ms step_avg:120.62ms +step:4973/20000 val_loss:1.9715 val_bpb:1.1677 train_time:600053ms step_avg:120.66ms +stopping_early: wallclock_cap train_time:600053ms step:4973/20000 +peak memory allocated: 22052 MiB reserved: 22134 MiB +ema:applying EMA weights +DIAGNOSTIC post_ema val_loss:1.9681 val_bpb:1.1656 eval_time:2142ms +EMA improved val_bpb (1.1677 -> 1.1656), keeping EMA weights for final serialization +Serialized model: 99922207 bytes +Code size: 87961 bytes +Total submission size: 100010168 bytes +Serialized model int6+zlib: 15891683 bytes +Total submission size int6+zlib: 15979644 bytes +Total submission size int8+zlib: 15979644 bytes +final_int6_zlib_roundtrip val_loss:1.9835 val_bpb:1.1747 eval_time:7370ms +final_int6_zlib_roundtrip_exact val_loss:1.98346448 val_bpb:1.17471939 +final_int6_sliding_window val_loss:1.9432 val_bpb:1.1509 stride:64 eval_time:81964ms +final_int6_sliding_window_exact val_loss:1.94323720 val_bpb:1.15089758 +final_int8_zlib_roundtrip_exact val_loss:1.94323720 val_bpb:1.15089758 +ttt_sliding:start chunks=1893 chunk_tokens=32768 total_windows=969088 stride=64 ttt_lr=0.0002 ttt_epochs=1 freeze_blocks=5 +ttt_sliding:params unfrozen=17548861 frozen=7882280 + ttt_chunk [1/1893] bpb=1.182291 time=0.4s + ttt_chunk [11/1893] bpb=1.536242 time=1.9s + ttt_chunk [21/1893] bpb=1.547969 time=3.5s + ttt_chunk [31/1893] bpb=1.549086 time=5.1s + ttt_chunk [41/1893] bpb=1.539530 time=6.7s + ttt_chunk [51/1893] bpb=1.540337 time=8.3s + ttt_chunk [61/1893] bpb=1.552123 time=9.9s + ttt_chunk [71/1893] bpb=1.549265 time=11.6s + ttt_chunk [81/1893] bpb=1.549080 time=13.2s + ttt_chunk [91/1893] 
bpb=1.551192 time=14.8s + ttt_chunk [101/1893] bpb=1.555231 time=16.4s + ttt_chunk [111/1893] bpb=1.557703 time=18.0s + ttt_chunk [121/1893] bpb=1.548644 time=19.6s + ttt_chunk [131/1893] bpb=1.549539 time=21.3s + ttt_chunk [141/1893] bpb=1.555257 time=22.9s + ttt_chunk [151/1893] bpb=1.555049 time=24.5s + ttt_chunk [161/1893] bpb=1.554721 time=26.1s + ttt_chunk [171/1893] bpb=1.557701 time=27.7s + ttt_chunk [181/1893] bpb=1.560871 time=29.3s + ttt_chunk [191/1893] bpb=1.567501 time=30.9s + ttt_chunk [201/1893] bpb=1.565622 time=32.5s + ttt_chunk [211/1893] bpb=1.562531 time=34.1s + ttt_chunk [221/1893] bpb=1.562729 time=35.7s + ttt_chunk [231/1893] bpb=1.560938 time=37.3s + ttt_chunk [241/1893] bpb=1.560797 time=38.9s + ttt_chunk [251/1893] bpb=1.560173 time=40.5s + ttt_chunk [261/1893] bpb=1.555366 time=42.1s + ttt_chunk [271/1893] bpb=1.553500 time=43.7s + ttt_chunk [281/1893] bpb=1.552997 time=45.3s + ttt_chunk [291/1893] bpb=1.553969 time=46.9s + ttt_chunk [301/1893] bpb=1.553312 time=48.5s + ttt_chunk [311/1893] bpb=1.554407 time=50.1s + ttt_chunk [321/1893] bpb=1.554637 time=51.7s + ttt_chunk [331/1893] bpb=1.552987 time=53.3s + ttt_chunk [341/1893] bpb=1.550549 time=54.9s + ttt_chunk [351/1893] bpb=1.551897 time=56.5s + ttt_chunk [361/1893] bpb=1.551867 time=58.2s + ttt_chunk [371/1893] bpb=1.549514 time=59.8s + ttt_chunk [381/1893] bpb=1.548510 time=61.4s + ttt_chunk [391/1893] bpb=1.547422 time=63.0s + ttt_chunk [401/1893] bpb=1.544236 time=64.6s + ttt_chunk [411/1893] bpb=1.542261 time=66.2s + ttt_chunk [421/1893] bpb=1.540202 time=67.8s + ttt_chunk [431/1893] bpb=1.539180 time=69.4s + ttt_chunk [441/1893] bpb=1.538316 time=71.0s + ttt_chunk [451/1893] bpb=1.537601 time=72.6s + ttt_chunk [461/1893] bpb=1.535286 time=74.2s + ttt_chunk [471/1893] bpb=1.534506 time=75.8s + ttt_chunk [481/1893] bpb=1.533686 time=77.4s + ttt_chunk [491/1893] bpb=1.531321 time=79.0s + ttt_chunk [501/1893] bpb=1.530120 time=80.6s + ttt_chunk [511/1893] bpb=1.528726 time=82.2s + 
ttt_chunk [521/1893] bpb=1.525526 time=83.9s + ttt_chunk [531/1893] bpb=1.525756 time=85.5s + ttt_chunk [541/1893] bpb=1.525293 time=87.1s + ttt_chunk [551/1893] bpb=1.523244 time=88.7s + ttt_chunk [561/1893] bpb=1.522783 time=90.3s + ttt_chunk [571/1893] bpb=1.520734 time=91.9s + ttt_chunk [581/1893] bpb=1.519001 time=93.5s + ttt_chunk [591/1893] bpb=1.517419 time=95.1s + ttt_chunk [601/1893] bpb=1.517000 time=96.7s + ttt_chunk [611/1893] bpb=1.516181 time=98.3s + ttt_chunk [621/1893] bpb=1.515314 time=99.9s + ttt_chunk [631/1893] bpb=1.515219 time=101.5s + ttt_chunk [641/1893] bpb=1.514381 time=103.1s + ttt_chunk [651/1893] bpb=1.513585 time=104.7s + ttt_chunk [661/1893] bpb=1.512481 time=106.3s + ttt_chunk [671/1893] bpb=1.512230 time=107.9s + ttt_chunk [681/1893] bpb=1.512048 time=109.6s + ttt_chunk [691/1893] bpb=1.512510 time=111.2s + ttt_chunk [701/1893] bpb=1.511547 time=112.8s + ttt_chunk [711/1893] bpb=1.511315 time=114.4s + ttt_chunk [721/1893] bpb=1.510723 time=116.0s + ttt_chunk [731/1893] bpb=1.510614 time=117.6s + ttt_chunk [741/1893] bpb=1.510462 time=119.2s + ttt_chunk [751/1893] bpb=1.509628 time=120.8s + ttt_chunk [761/1893] bpb=1.509150 time=122.4s + ttt_chunk [771/1893] bpb=1.508553 time=124.0s + ttt_chunk [781/1893] bpb=1.509482 time=125.6s + ttt_chunk [791/1893] bpb=1.508839 time=127.2s + ttt_chunk [801/1893] bpb=1.508692 time=128.8s + ttt_chunk [811/1893] bpb=1.508513 time=130.4s + ttt_chunk [821/1893] bpb=1.508230 time=132.0s + ttt_chunk [831/1893] bpb=1.507899 time=133.6s + ttt_chunk [841/1893] bpb=1.506979 time=135.3s + ttt_chunk [851/1893] bpb=1.506716 time=136.9s + ttt_chunk [861/1893] bpb=1.506377 time=138.5s + ttt_chunk [871/1893] bpb=1.506417 time=140.1s + ttt_chunk [881/1893] bpb=1.506610 time=141.7s + ttt_chunk [891/1893] bpb=1.506043 time=143.3s + ttt_chunk [901/1893] bpb=1.505560 time=144.9s + ttt_chunk [911/1893] bpb=1.505561 time=146.5s + ttt_chunk [921/1893] bpb=1.505699 time=148.1s + ttt_chunk [931/1893] bpb=1.505666 
time=149.7s + ttt_chunk [941/1893] bpb=1.505081 time=151.3s + ttt_chunk [951/1893] bpb=1.505466 time=152.9s + ttt_chunk [961/1893] bpb=1.505029 time=154.5s + ttt_chunk [971/1893] bpb=1.505748 time=156.1s + ttt_chunk [981/1893] bpb=1.505532 time=157.7s + ttt_chunk [991/1893] bpb=1.505225 time=159.4s + ttt_chunk [1001/1893] bpb=1.504936 time=161.0s + ttt_chunk [1011/1893] bpb=1.504545 time=162.6s + ttt_chunk [1021/1893] bpb=1.504636 time=164.2s + ttt_chunk [1031/1893] bpb=1.504662 time=165.8s + ttt_chunk [1041/1893] bpb=1.503947 time=167.4s + ttt_chunk [1051/1893] bpb=1.503324 time=169.0s + ttt_chunk [1061/1893] bpb=1.503006 time=170.6s + ttt_chunk [1071/1893] bpb=1.503482 time=172.2s + ttt_chunk [1081/1893] bpb=1.503362 time=173.8s + ttt_chunk [1091/1893] bpb=1.503634 time=175.4s + ttt_chunk [1101/1893] bpb=1.503354 time=177.0s + ttt_chunk [1111/1893] bpb=1.502759 time=178.6s + ttt_chunk [1121/1893] bpb=1.502257 time=180.2s + ttt_chunk [1131/1893] bpb=1.501787 time=181.8s + ttt_chunk [1141/1893] bpb=1.501264 time=183.4s + ttt_chunk [1151/1893] bpb=1.500961 time=185.0s + ttt_chunk [1161/1893] bpb=1.500287 time=186.6s + ttt_chunk [1171/1893] bpb=1.500188 time=188.3s + ttt_chunk [1181/1893] bpb=1.499028 time=189.9s + ttt_chunk [1191/1893] bpb=1.498700 time=191.5s + ttt_chunk [1201/1893] bpb=1.498755 time=193.1s + ttt_chunk [1211/1893] bpb=1.497950 time=194.7s + ttt_chunk [1221/1893] bpb=1.497290 time=196.3s + ttt_chunk [1231/1893] bpb=1.496611 time=197.9s + ttt_chunk [1241/1893] bpb=1.495955 time=199.6s + ttt_chunk [1251/1893] bpb=1.495043 time=201.2s + ttt_chunk [1261/1893] bpb=1.494812 time=202.8s + ttt_chunk [1271/1893] bpb=1.494172 time=204.4s + ttt_chunk [1281/1893] bpb=1.493554 time=206.0s + ttt_chunk [1291/1893] bpb=1.493135 time=207.6s + ttt_chunk [1301/1893] bpb=1.492224 time=209.2s + ttt_chunk [1311/1893] bpb=1.491491 time=210.8s + ttt_chunk [1321/1893] bpb=1.490865 time=212.4s + ttt_chunk [1331/1893] bpb=1.490465 time=214.0s + ttt_chunk [1341/1893] 
bpb=1.490089 time=215.6s + ttt_chunk [1351/1893] bpb=1.489921 time=217.2s + ttt_chunk [1361/1893] bpb=1.489811 time=218.8s + ttt_chunk [1371/1893] bpb=1.489484 time=220.4s + ttt_chunk [1381/1893] bpb=1.489504 time=222.0s + ttt_chunk [1391/1893] bpb=1.488838 time=223.6s + ttt_chunk [1401/1893] bpb=1.488771 time=225.2s + ttt_chunk [1411/1893] bpb=1.488761 time=226.8s + ttt_chunk [1421/1893] bpb=1.488849 time=228.4s + ttt_chunk [1431/1893] bpb=1.488647 time=230.0s + ttt_chunk [1441/1893] bpb=1.489077 time=231.6s + ttt_chunk [1451/1893] bpb=1.489271 time=233.2s + ttt_chunk [1461/1893] bpb=1.488752 time=234.8s + ttt_chunk [1471/1893] bpb=1.489740 time=236.4s + ttt_chunk [1481/1893] bpb=1.489209 time=238.0s + ttt_chunk [1491/1893] bpb=1.488977 time=239.6s + ttt_chunk [1501/1893] bpb=1.488910 time=241.2s + ttt_chunk [1511/1893] bpb=1.488854 time=242.8s + ttt_chunk [1521/1893] bpb=1.488748 time=244.4s + ttt_chunk [1531/1893] bpb=1.488163 time=246.0s + ttt_chunk [1541/1893] bpb=1.487831 time=247.6s + ttt_chunk [1551/1893] bpb=1.488139 time=249.2s + ttt_chunk [1561/1893] bpb=1.488163 time=250.8s + ttt_chunk [1571/1893] bpb=1.487846 time=252.4s + ttt_chunk [1581/1893] bpb=1.487871 time=254.0s + ttt_chunk [1591/1893] bpb=1.487548 time=255.6s + ttt_chunk [1601/1893] bpb=1.487614 time=257.2s + ttt_chunk [1611/1893] bpb=1.487335 time=258.8s + ttt_chunk [1621/1893] bpb=1.486692 time=260.4s + ttt_chunk [1631/1893] bpb=1.486896 time=262.0s + ttt_chunk [1641/1893] bpb=1.486831 time=263.7s + ttt_chunk [1651/1893] bpb=1.486657 time=265.3s + ttt_chunk [1661/1893] bpb=1.486404 time=266.9s + ttt_chunk [1671/1893] bpb=1.486747 time=268.5s + ttt_chunk [1681/1893] bpb=1.486781 time=270.1s + ttt_chunk [1691/1893] bpb=1.486317 time=271.7s + ttt_chunk [1701/1893] bpb=1.486259 time=273.3s + ttt_chunk [1711/1893] bpb=1.486010 time=274.9s + ttt_chunk [1721/1893] bpb=1.485854 time=276.5s + ttt_chunk [1731/1893] bpb=1.485563 time=278.1s + ttt_chunk [1741/1893] bpb=1.485335 time=279.7s + ttt_chunk 
[1751/1893] bpb=1.484979 time=281.3s + ttt_chunk [1761/1893] bpb=1.485059 time=282.9s + ttt_chunk [1771/1893] bpb=1.484817 time=284.5s + ttt_chunk [1781/1893] bpb=1.484768 time=286.0s + ttt_chunk [1791/1893] bpb=1.484081 time=287.7s + ttt_chunk [1801/1893] bpb=1.483855 time=289.2s + ttt_chunk [1811/1893] bpb=1.483649 time=290.8s + ttt_chunk [1821/1893] bpb=1.483512 time=292.4s + ttt_chunk [1831/1893] bpb=1.482663 time=294.0s + ttt_chunk [1841/1893] bpb=1.482573 time=295.6s + ttt_chunk [1851/1893] bpb=1.482265 time=297.2s + ttt_chunk [1861/1893] bpb=1.481645 time=298.8s + ttt_chunk [1871/1893] bpb=1.481437 time=300.4s + ttt_chunk [1881/1893] bpb=1.480811 time=302.0s + ttt_chunk [1891/1893] bpb=1.480443 time=303.6s + ttt_chunk [1893/1893] bpb=1.480523 time=303.9s +ttt_sliding:done val_loss=2.497682 val_bpb=1.479272 elapsed=303.9s +legal_ttt val_loss:2.4977 val_bpb:1.4793 eval_time:304429ms +legal_ttt_exact val_loss:2.49768189 val_bpb:1.47927182 diff --git a/records/track_10min_16mb/2026-04-01_random_mlp/baseline_no_moe_xsa4_ve3_12l_5r7l_seed999.txt b/records/track_10min_16mb/2026-04-01_random_mlp/baseline_no_moe_xsa4_ve3_12l_5r7l_seed999.txt new file mode 100644 index 0000000000..58c6c2d1e5 --- /dev/null +++ b/records/track_10min_16mb/2026-04-01_random_mlp/baseline_no_moe_xsa4_ve3_12l_5r7l_seed999.txt @@ -0,0 +1,2605 @@ +from __future__ import annotations + +import copy +import glob +import io +import math +import os +import random +import subprocess +import sys +import time +import uuid +import zlib +from pathlib import Path + +try: + import zstandard + _COMPRESSOR = "zstd" +except ImportError: + _COMPRESSOR = "zlib" + +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP + +from flash_attn_interface import flash_attn_func +# from flash_attn import flash_attn_func + +# make dynamo less complainy 
+import torch._dynamo +torch._dynamo.config.cache_size_limit = 64 + +# ----------------------------- +# HYPERPARAMETERS +# ----------------------------- +# Default Simple Baseline run: +# - 9 transformer blocks at width 512 +# - 8 attention heads with 4 KV heads (GQA) and 2x MLP expansion +# - vocab size 1024, sequence length 1024, tied embeddings +# - 524,288 train tokens per step for 20,000 iterations with a ~10 minute cap + +class Hyperparameters: + # Data paths are shard globs produced by the existing preprocessing pipeline. + data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + seed = int(os.environ.get("SEED", 1337)) + + # Validation cadence and batch size. Validation always uses the full fineweb_val split. + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 200)) + + # Training length. + iterations = int(os.environ.get("ITERATIONS", 20000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 1200)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024)) + eval_seq_len = int(os.environ.get("EVAL_SEQ_LEN", 2048)) + eval_stride = int(os.environ.get("EVAL_STRIDE", 64)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + + # Model shape. 
+ vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers = int(os.environ.get("NUM_LAYERS", 9)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = int(os.environ.get("MLP_MULT", 2)) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + + # random projection control + rand_proj_layers = [int(x) for x in os.environ.get("RAND_PROJ_LAYERS", "").split(",") if x] + rand_gain = bool(int(os.environ.get("RAND_GAIN", "0"))) + mini_moe_experts = int(os.environ.get("MINI_MOE_EXPERTS", 1)) + rand_init_qr = bool(int(os.environ.get("RAND_INIT_QR", "1"))) + + # xsa control + xsa_layers = [int(x) for x in os.environ.get("XSA_LAYERS", "").split(",") if x] + + # quant + late_qat_threshold = float(os.environ.get("LATE_QAT_THRESHOLD", 0.15)) + + # partial rope + rope_dims = int(os.environ.get("ROPE_DIMS", 16)) + + # ln scale + ln_scale = bool(int(os.environ.get("LN_SCALE", "1"))) + + # bigram control + bigram_vocab_size = int(os.environ.get("BIGRAM_VOCAB_SIZE", 2048)) + bigram_dim = int(os.environ.get("BIGRAM_DIM", 128)) + + # value embedding control + ve_dim = int(os.environ.get("VE_DIM", 128)) + ve_layers = [int(x) for x in os.environ.get("VE_LAYERS", "9,10").split(",") if x] + + # ttt + ttt_enabled = bool(int(os.environ.get("TTT_ENABLED", "1"))) + ttt_lr = float(os.environ.get("TTT_LR", 0.002)) + ttt_epochs = int(os.environ.get("TTT_EPOCHS", 3)) + ttt_chunk_tokens = int(os.environ.get("TTT_CHUNK_TOKENS", 32768)) + ttt_freeze_blocks = int(os.environ.get("TTT_FREEZE_BLOCKS", 0)) + ttt_momentum = float(os.environ.get("TTT_MOMENTUM", 0.9)) + ttt_batch_seqs = int(os.environ.get("TTT_BATCH_SEQS", 32)) + ttt_grad_clip = float(os.environ.get("TTT_GRAD_CLIP", 1.0)) + + # swa + swa_enabled = bool(int(os.environ.get("SWA_ENABLED", 
"1"))) + swa_every = int(os.environ.get("SWA_EVERY", 50)) + + # Optimizer hyperparameters. + embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.035)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.025)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.025)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.99)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.92)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 1500)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.3)) + muon_wd = float(os.environ.get("MUON_WD", 0.04)) + adam_wd = float(os.environ.get("ADAM_WD", 0.04)) + +# ----------------------------- +# MUON OPTIMIZER +# ----------------------------- +# +# As borrowed from modded-nanogpt +# Background on Muon: https://kellerjordan.github.io/posts/muon/ + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + # Orthogonalize a 2D update matrix with a fast Newton-Schulz iteration. + # Muon uses this to normalize matrix-shaped gradients before applying them. + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X + + +class Muon(torch.optim.Optimizer): + """Parallel Muon: post-backward reduce-scatter -> local NS5 -> all-gather. + + No DDP for bank params. After backward, this optimizer: + 1. 
Launches async reduce-scatter for all banks (biggest first) + 2. Returns control so Adam can step on small params while RS is in-flight + 3. Waits for each RS, runs local NS5 on the shard, launches async all-gather + 4. Each all-gather overlaps with next bank's NS5 + """ + def __init__(self, params, lr: float, momentum: float, backend_steps: int, + nesterov: bool = True, weight_decay: float = 0.0): + super().__init__( + params, + dict(lr=lr, momentum=momentum, backend_steps=backend_steps, + nesterov=nesterov, weight_decay=weight_decay), + ) + self._built = False + + def _build(self): + self._distributed = dist.is_available() and dist.is_initialized() + self._world_size = dist.get_world_size() if self._distributed else 1 + self._rank = dist.get_rank() if self._distributed else 0 + ws = self._world_size + + self._bank_meta = [] + for group in self.param_groups: + for p in group["params"]: + B = p.shape[0] + padded_B = ((B + ws - 1) // ws) * ws + shard_B = padded_B // ws + tail = p.shape[1:] + dev = p.device + self._bank_meta.append({ + 'p': p, + 'B': B, + 'padded_grad': torch.zeros(padded_B, *tail, device=dev, dtype=torch.bfloat16), + 'shard': torch.zeros(shard_B, *tail, device=dev, dtype=torch.bfloat16), + 'shard_mom': torch.zeros(shard_B, *tail, device=dev, dtype=torch.bfloat16), + 'full_update': torch.zeros(padded_B, *tail, device=dev, dtype=torch.bfloat16), + 'scale': max(1, p.shape[-2] / p.shape[-1]) ** 0.5, + }) + # Sort by size descending -- launch biggest reduce-scatters first + self._bank_meta.sort(key=lambda m: -m['p'].numel()) + self._built = True + + def launch_reduce_scatters(self): + """Phase 1: launch async reduce-scatter for all banks. 
Call right after backward.""" + if not self._built: + self._build() + if not self._distributed: + return + self._rs_futures = [] + for m in self._bank_meta: + p = m['p'] + if p.grad is None: + self._rs_futures.append(None) + continue + pg = m['padded_grad'] + pg[:m['B']].copy_(p.grad.bfloat16()) + if pg.shape[0] > m['B']: + pg[m['B']:].zero_() + fut = dist.reduce_scatter_tensor(m['shard'], pg, op=dist.ReduceOp.AVG, async_op=True) + self._rs_futures.append(fut) + + @torch.no_grad() + def step(self, closure=None): + """Phase 3: wait for RS, local NS5, all-gather. Call AFTER Adam steps.""" + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + if not self._built: + self._build() + + for group in self.param_groups: + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + wd = group.get("weight_decay", 0.0) + + prev_ag_handle = None + prev_m = None + + sharded = self._distributed and hasattr(self, '_rs_futures') + + for i, m in enumerate(self._bank_meta): + p = m['p'] + if p.grad is None: + continue + + if prev_ag_handle is not None: + prev_ag_handle.wait() + pp = prev_m['p'] + upd = prev_m['full_update'][:prev_m['B']] + if wd > 0.0: + pp.data.mul_(1.0 - lr * wd) + pp.add_(upd.to(dtype=pp.dtype), alpha=-lr * prev_m['scale']) + + if sharded and self._rs_futures[i] is not None: + self._rs_futures[i].wait() + g = m['shard'] + buf = m['shard_mom'] + else: + g = p.grad.bfloat16() + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + + buf.mul_(momentum).add_(g) + if nesterov: + update = g.add(buf, alpha=momentum) + else: + update = buf + + update = zeropower_via_newtonschulz5(update, steps=backend_steps) + + if sharded: + prev_ag_handle = dist.all_gather_into_tensor( + m['full_update'], update, async_op=True) + prev_m = m + else: + if wd > 0.0: + p.data.mul_(1.0 - lr * wd) + 
p.add_(update.to(dtype=p.dtype), alpha=-lr * m['scale']) + + if prev_ag_handle is not None: + prev_ag_handle.wait() + pp = prev_m['p'] + upd = prev_m['full_update'][:prev_m['B']] + if wd > 0.0: + pp.data.mul_(1.0 - lr * wd) + pp.add_(upd.to(dtype=pp.dtype), alpha=-lr * prev_m['scale']) + + if hasattr(self, '_rs_futures'): + del self._rs_futures + + return loss + +# ----------------------------- +# TOKENIZER-AGNOSTIC EVALUATION SETUP +# ----------------------------- +# +# It's common for small models to have a large fraction of their parameters be embeddings, since the 2 * d_model * d_vocab embedding vectors can be gigantic. +# Instead of locking the tokenizer, we let you bring your own and compute validation metrics as the average compression of the validation set. +# We calculate BPB (bits-per-byte) instead of validation loss, so we need methods to count the number of bytes each token decodes to under the tokenizer. +# Note: Submissions that edit the tokenizer will be examined more carefully, since screwing this up might unjustly improve your score. 
+ +def build_sentencepiece_luts( + sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device +) -> tuple[Tensor, Tensor, Tensor]: + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("▁"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) + + +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*. 
+ tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] + + +def eval_val( + args: Hyperparameters, + model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + grad_accum_steps: int, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, +) -> tuple[float, float]: + # Validation computes two metrics: + # - val_loss: token cross-entropy (natural log) + # - val_bpb: tokenizer-agnostic compression metric used by the challenge + local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) + if local_batch_tokens < args.train_seq_len: + raise ValueError( + "VAL_BATCH_SIZE must provide at least one sequence per rank; " + f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " + f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}" + ) + local_batch_seqs = local_batch_tokens // args.train_seq_len + total_seqs = (val_tokens.numel() - 1) // args.train_seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + + model.eval() + with torch.inference_mode(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * args.train_seq_len + raw_end = batch_seq_end * args.train_seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, args.train_seq_len) + y = local[1:].reshape(-1, args.train_seq_len) + with 
torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + batch_loss = model(x, y).detach() + batch_token_count = float(y.numel()) + val_loss_sum += batch_loss.to(torch.float64) * batch_token_count + val_token_count += batch_token_count + prev_ids = x.reshape(-1) + tgt_ids = y.reshape(-1) + token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + + val_loss = val_loss_sum / val_token_count + bits_per_token = val_loss.item() / math.log(2.0) + tokens_per_byte = val_token_count.item() / val_byte_count.item() + model.train() + return float(val_loss.item()), float(bits_per_token * tokens_per_byte) + +# ----------------------------- +# POST-TRAINING QUANTIZATION +# ----------------------------- +# +# It's silly to export our model, which is trained in bf16 and fp32, at that same precision. +# Instead, we get approximately the same model (with a small hit) by quantizing the model to int8 and zlib-compressing it. +# We can then decompress the model and run in higher precision for evaluation, after coming in under the size limit. 
+ +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights", + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 + +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) + +def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: + if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): + return t.float().contiguous() + if t.dtype in {torch.float32, torch.bfloat16}: + passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") + return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() + return t + +def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + # Matrices get one scale per row, which usually tracks output-channel + # ranges much better than a single tensor-wide scale. + clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) + q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() + return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + + # Vectors / scalars use a simpler per-tensor scale. 
+ clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() + return q, scale + +def quantize_state_dict_int8(state_dict: dict[str, Tensor]): + # Single supported clean-script export format: + # - per-row int8 for 2D float tensors + # - per-tensor int8 for other float tensors + # - exact passthrough for non-floats + # - passthrough for small float tensors, stored as fp16 to save bytes + quantized: dict[str, Tensor] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + qmeta: dict[str, dict[str, object]] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + + # Small float tensors are cheap enough to keep directly. We still downcast + # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size. 
+ if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + + stats["num_float_tensors"] += 1 + q, s = quantize_float_tensor(t) + if s.ndim > 0: + qmeta[name] = {"scheme": "per_row", "axis": 0} + quantized[name] = q + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) + + obj: dict[str, object] = { + "__quant_format__": "int8_clean_per_row_v1", + "quantized": quantized, + "scales": scales, + "dtypes": dtypes, + "passthrough": passthrough, + } + if qmeta: + obj["qmeta"] = qmeta + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats + +def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + qmeta = obj.get("qmeta", {}) + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, q in obj["quantized"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + s = obj["scales"][name] + if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: + s = s.to(dtype=torch.float32) + # Broadcast the saved row scale back across trailing dimensions. + out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() + else: + scale = float(s.item()) + out[name] = (q.float() * scale).to(dtype=dtype).contiguous() + for name, t in obj["passthrough"].items(): + # Restore small tensors, undoing the temporary fp16 storage cast if needed. 
+ out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out + + +# --- sliding window eval --- + +def eval_val_sliding( + args: Hyperparameters, + base_model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, + stride: int, + batch_seqs: int = 32, + eval_seq_len: int | None = None, +) -> tuple[float, float]: + """Sliding window evaluation: each token scored with maximum context.""" + seq_len = eval_seq_len or args.train_seq_len + total_tokens = val_tokens.numel() - 1 + window_starts = [ws for ws in range(0, total_tokens, stride) + if min(ws + seq_len, total_tokens) - ws >= 1] + total_windows = len(window_starts) + my_s = (total_windows * rank) // world_size + my_e = (total_windows * (rank + 1)) // world_size + my_windows = window_starts[my_s:my_e] + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + base_model.eval() + compiled_logits = torch.compile(base_model.forward_logits, dynamic=False, fullgraph=True) + with torch.inference_mode(): + for bi in range(0, len(my_windows), batch_seqs): + batch_ws = my_windows[bi:bi + batch_seqs] + bsz = len(batch_ws) + x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + wlens: list[int] = [] + for i, ws in enumerate(batch_ws): + end = min(ws + seq_len, total_tokens) + wlen = end - ws + wlens.append(wlen) + chunk = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) + x_batch[i, :wlen] = chunk[:-1] + y_batch[i, :wlen] = chunk[1:] + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + 
logits = compiled_logits(x_batch) + nll = F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), + y_batch.reshape(-1), + reduction="none", + ).reshape(bsz, seq_len) + for i, ws in enumerate(batch_ws): + wlen = wlens[i] + s = 0 if ws == 0 else max(wlen - stride, 0) + scored_nll = nll[i, s:wlen].to(torch.float64) + loss_sum += scored_nll.sum() + token_count += float(wlen - s) + tgt = y_batch[i, s:wlen] + prev = x_batch[i, s:wlen] + tb = base_bytes_lut[tgt].to(torch.float64) + tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64) + byte_count += tb.sum() + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) + val_loss = (loss_sum / token_count).item() + bits_per_token = val_loss / math.log(2.0) + tokens_per_byte = token_count.item() / byte_count.item() + base_model.train() + return val_loss, bits_per_token * tokens_per_byte + + +def eval_val_sliding_ttt( + args: Hyperparameters, base_model: nn.Module, rank: int, world_size: int, + device: torch.device, val_tokens: Tensor, base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, is_boundary_token_lut: Tensor, + stride: int, batch_seqs: int = 32, log0=print, +) -> tuple[float, float]: + """Legal score-first TTT (PR #461 recipe): score each chunk with sliding windows, + then train on it. 
Every token scored BEFORE any update that could use it.""" + seq_len = args.train_seq_len + total_tokens = val_tokens.numel() - 1 + ttt_chunk = args.ttt_chunk_tokens + + # Pre-compute all window starts + window_starts = [ws for ws in range(0, total_tokens, stride) + if min(ws + seq_len, total_tokens) - ws >= stride or ws == 0] + + # Assign each window to a chunk based on the first token it scores + num_chunks = (total_tokens + ttt_chunk - 1) // ttt_chunk + chunk_windows: list[list[int]] = [[] for _ in range(num_chunks)] + for ws in window_starts: + end = min(ws + seq_len, total_tokens) + wlen = end - ws + s = 0 if ws == 0 else max(wlen - stride, 0) + scored_start = ws + s + ci = min(scored_start // ttt_chunk, num_chunks - 1) + chunk_windows[ci].append(ws) + + log0(f"ttt_sliding:start chunks={num_chunks} chunk_tokens={ttt_chunk} " + f"total_windows={len(window_starts)} stride={stride} " + f"ttt_lr={args.ttt_lr} ttt_epochs={args.ttt_epochs} " + f"freeze_blocks={args.ttt_freeze_blocks}") + + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + + # Freeze first N blocks + frozen_block_ids = set(range(min(args.ttt_freeze_blocks, len(base_model.blocks)))) + # ttt_params = [] + # for name, p in base_model.named_parameters(): + # freeze = False + # for bi in frozen_block_ids: + # if f"blocks.{bi}." in name: + # freeze = True + # break + # if freeze: + # p.requires_grad_(False) + # else: + # p.requires_grad_(True) + # ttt_params.append(p) + ttt_params = [ + p for name, p in base_model.named_parameters() + if p.ndim == 2 and not any(pat in name for pat in CONTROL_TENSOR_NAME_PATTERNS) + and not any(f"blocks.{bi}." 
in name for bi in frozen_block_ids)] + + log0(f"ttt_sliding:params unfrozen={sum(p.numel() for p in ttt_params)} " + f"frozen={sum(p.numel() for p in base_model.parameters() if not p.requires_grad)}") + + # optimizer = torch.optim.SGD(ttt_params, lr=args.ttt_lr, momentum=args.ttt_momentum) + optimizer = Muon(ttt_params, lr=args.ttt_lr, momentum=0.0, backend_steps=5, weight_decay=0.0) + t0 = time.perf_counter() + + for ci in range(num_chunks): + windows = chunk_windows[ci] + if not windows: + continue + chunk_start = ci * ttt_chunk + chunk_end = min((ci + 1) * ttt_chunk, total_tokens) + + # --- Phase 1: SCORE this chunk's windows (inference_mode) --- + my_s = (len(windows) * rank) // world_size + my_e = (len(windows) * (rank + 1)) // world_size + my_windows = windows[my_s:my_e] + + base_model.eval() + with torch.inference_mode(): + for bi in range(0, len(my_windows), batch_seqs): + batch_ws = my_windows[bi:bi + batch_seqs] + bsz = len(batch_ws) + x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + wlens: list[int] = [] + for i, ws in enumerate(batch_ws): + end = min(ws + seq_len, total_tokens) + wlen = end - ws + wlens.append(wlen) + chunk_tok = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) + x_batch[i, :wlen] = chunk_tok[:-1] + y_batch[i, :wlen] = chunk_tok[1:] + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + logits = base_model.forward_logits(x_batch) + nll = F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), + y_batch.reshape(-1), reduction="none", + ).reshape(bsz, seq_len) + for i, ws in enumerate(batch_ws): + wlen = wlens[i] + s = 0 if ws == 0 else max(wlen - stride, 0) + scored_nll = nll[i, s:wlen].to(torch.float64) + loss_sum += scored_nll.sum() + token_count += float(wlen - s) + tgt, prev = y_batch[i, s:wlen], x_batch[i, s:wlen] + tb = base_bytes_lut[tgt].to(torch.float64) + tb += (has_leading_space_lut[tgt] & 
~is_boundary_token_lut[prev]).to(torch.float64) + byte_count += tb.sum() + + # workaround for caching issues with sin/cos in rotary layers: + for block in base_model.blocks: + block.attn.rotary._cos_cached = None + block.attn.rotary._sin_cached = None + + # --- Phase 2: TRAIN on this chunk (already scored = legal) --- + is_last_chunk = (ci == num_chunks - 1) + if not is_last_chunk and args.ttt_epochs > 0: + base_model.train() + chunk_seqs = (chunk_end - chunk_start) // seq_len + if chunk_seqs > 0: + cos_lr = args.ttt_lr * 0.5 * (1.0 + math.cos(math.pi * ci / max(num_chunks - 1, 1))) + for pg in optimizer.param_groups: + pg['lr'] = cos_lr + my_seq_s = (chunk_seqs * rank) // world_size + my_seq_e = (chunk_seqs * (rank + 1)) // world_size + my_chunk_seqs = my_seq_e - my_seq_s + for _ep in range(args.ttt_epochs): + for bs in range(0, my_chunk_seqs, args.ttt_batch_seqs): + be = min(bs + args.ttt_batch_seqs, my_chunk_seqs) + actual_bs = my_seq_s + bs + start_tok = chunk_start + actual_bs * seq_len + end_tok = chunk_start + (my_seq_s + be) * seq_len + 1 + if end_tok > val_tokens.numel(): + continue + local = val_tokens[start_tok:end_tok].to(device=device, dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + optimizer.zero_grad(set_to_none=True) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + loss = base_model(x, y) + loss.backward() + if world_size > 1: + for p in ttt_params: + if p.grad is not None: + dist.all_reduce(p.grad, op=dist.ReduceOp.AVG) + torch.nn.utils.clip_grad_norm_(ttt_params, args.ttt_grad_clip) + optimizer.step() + + if rank == 0 and (ci % 10 == 0 or ci == num_chunks - 1): + elapsed = time.perf_counter() - t0 + rl = loss_sum.item() / max(token_count.item(), 1) + rbpb = rl / math.log(2.0) * (token_count.item() / max(byte_count.item(), 1)) if token_count.item() > 0 else 0.0 + log0(f" ttt_chunk [{ci+1}/{num_chunks}] bpb={rbpb:.6f} time={elapsed:.1f}s") + + if dist.is_available() and 
dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) + + val_loss = (loss_sum / token_count).item() + val_bpb = val_loss / math.log(2.0) * (token_count.item() / byte_count.item()) + + for p in base_model.parameters(): + p.requires_grad_(True) + base_model.eval() + + log0(f"ttt_sliding:done val_loss={val_loss:.6f} val_bpb={val_bpb:.6f} " + f"elapsed={time.perf_counter() - t0:.1f}s") + return val_loss, val_bpb + +# --- int6 quant as per #414 --- +def _classify_param(name: str) -> str: + if "tok_emb" in name or "lm_head" in name: + return "embed" + # TODO: not all mlp params likely want int6, e.g. scales and gates + if ".mlp." in name: + return "mlp" + if ".attn." in name or (".proj." in name and ".mlp." not in name): + return "attn" + return "other" + +def quantize_int6_per_row(t: Tensor, clip_range: int = 31) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + best_q, best_s, best_err = None, None, float('inf') + for pct in [0.9990, 0.9995, 0.9999, 0.99999, 1.0]: + if pct < 1.0: + row_clip = torch.quantile(t32.abs(), pct, dim=1) + else: + row_clip = t32.abs().amax(dim=1) + s = (row_clip / clip_range).clamp_min(1.0 / clip_range).to(torch.float16) + q = torch.clamp(torch.round(t32 / s.float()[:, None]), -clip_range, clip_range).to(torch.int8) + recon = q.float() * s.float()[:, None] + err = (t32 - recon).pow(2).mean().item() + if err < best_err: + best_q, best_s, best_err = q, s, err + return best_q, best_s + amax = t32.abs().max().item() + scale = torch.tensor(amax / clip_range if amax > 0 else 1.0, dtype=torch.float16) + q = torch.clamp(torch.round(t32 / scale.float()), -clip_range, clip_range).to(torch.int8) + return q, scale + +def mixed_quantize_int6( + state_dict: dict[str, Tensor], + int6_cats: set[str] +): + num_layers_total = max( + (int(k.split(".")[1]) for k in state_dict if k.startswith("blocks.")), + default=0, + ) + 1 
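The per-row int6 path can be exercised standalone. The sketch below is a sanity check, not part of the training script: it keeps only the `amax` clipping variant of `quantize_int6_per_row` (the percentile search over clip thresholds is omitted for brevity) and round-trips a random weight matrix through quantize/dequantize.

```python
import torch

def int6_per_row(t: torch.Tensor, clip_range: int = 31):
    # Symmetric per-row quantization to [-31, 31] with one fp16 scale per row,
    # mirroring the amax branch of quantize_int6_per_row above.
    t32 = t.float()
    row_clip = t32.abs().amax(dim=1)
    s = (row_clip / clip_range).clamp_min(1.0 / clip_range).to(torch.float16)
    q = torch.clamp(torch.round(t32 / s.float()[:, None]), -clip_range, clip_range).to(torch.int8)
    return q, s

w = torch.randn(16, 256)
q, s = int6_per_row(w)
recon = q.float() * s.float()[:, None]     # same dequant as dequantize_mixed_int6
max_err = (w - recon).abs().max().item()   # bounded by ~half a quantization step per row
print(f"max reconstruction error: {max_err:.4f}")
```

Note that `q` is stored as int8 but only uses the 6-bit range [-31, 31]; the on-disk size win comes from the entropy coder (zlib/zstd) squeezing out the unused bits.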
+    late_k_layers = set(range(num_layers_total - 2, num_layers_total))
+    result: dict[str, Tensor] = {}
+    meta: dict[str, object] = {}
+    for name, tensor in state_dict.items():
+        t = tensor.detach().cpu().contiguous()
+        cat = _classify_param(name)
+        if not t.is_floating_point() or t.numel() <= 65536:
+            result[name] = t.to(torch.float16) if t.is_floating_point() else t
+            meta[name] = "passthrough"
+            continue
+        if any(p in name for p in CONTROL_TENSOR_NAME_PATTERNS):
+            result[name] = t.float()
+            meta[name] = "passthrough_ctrl"
+            continue
+        if cat in int6_cats and t.ndim >= 1:
+            q, s = quantize_int6_per_row(t)
+            result[name + ".q"] = q
+            result[name + ".scale"] = s
+            meta[name] = {"type": "int6"}
+        else:
+            q, s = quantize_float_tensor(t)
+            result[name + ".q"] = q
+            result[name + ".scale"] = s
+            meta[name] = {"type": "int8"}
+    return result, meta
+
+def dequantize_mixed_int6(
+    result: dict[str, Tensor],
+    meta: dict[str, object],
+    template_sd: dict[str, Tensor]
+) -> dict[str, Tensor]:
+    out: dict[str, Tensor] = {}
+    for name, orig in template_sd.items():
+        info = meta.get(name)
+        if info is None:
+            continue
+        orig_dtype = orig.dtype
+        if info in ("passthrough", "passthrough_ctrl", "passthrough_fp16"):
+            t = result[name]
+            if t.dtype == torch.float16 and orig_dtype in (torch.float32, torch.bfloat16):
+                t = t.to(orig_dtype)
+            out[name] = t
+            continue
+        q, s = result[name + ".q"], result[name + ".scale"]
+        if s.ndim > 0:
+            out[name] = (q.float() * s.float().view(q.shape[0], *([1] * (q.ndim - 1)))).to(orig_dtype)
+        else:
+            out[name] = (q.float() * float(s.item())).to(orig_dtype)
+    return out
+
+# -----------------------------
+# DATA LOADING
+# -----------------------------
+
+def load_data_shard(file: Path) -> Tensor:
+    header_bytes = 256 * np.dtype("<i4").itemsize
+    # Shard layout assumed to follow the standard fineweb .bin format:
+    # a 256-word little-endian int32 header with header[2] = token count,
+    # followed by the tokens as little-endian uint16.
+    with file.open("rb") as f:
+        header = np.frombuffer(f.read(header_bytes), dtype="<i4")
+        num_tokens = int(header[2])
+        tokens = np.frombuffer(f.read(num_tokens * np.dtype("<u2").itemsize), dtype="<u2")
+    if tokens.size != num_tokens:
+        raise RuntimeError(f"shard {file} is truncated: expected {num_tokens} tokens, got {tokens.size}")
+    return torch.from_numpy(tokens.copy())  # uint16 tensor; callers upcast to int64 as needed
+
+
+class TokenStream:
+    # Sequential reader over all shards matching a glob pattern, wrapping back
+    # to the first shard once the last one is exhausted.
+    def __init__(self, pattern: str):
+        self.files = sorted(Path.cwd().glob(pattern))  # pattern assumed relative to the working directory
+        if not self.files:
+            raise FileNotFoundError(f"no token shards match pattern: {pattern}")
+        self.file_idx = 0
+        self.tokens = load_data_shard(self.files[self.file_idx])
+        self.pos = 0
+
+    def _advance_file(self) -> None:
+        self.file_idx = (self.file_idx + 1) % len(self.files)
+        self.tokens = load_data_shard(self.files[self.file_idx])
+        self.pos = 0
+
+    def take(self, n: int) -> Tensor:
+        chunks: list[Tensor] = []
+        remaining = n
+
while remaining > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: + self._advance_file() + continue + k = min(remaining, avail) + chunks.append(self.tokens[self.pos : self.pos + k]) + self.pos += k + remaining -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) + + +class DistributedTokenLoader: + # Each call consumes a contiguous chunk from the shared token stream, then slices out + # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting. + def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank = rank + self.world_size = world_size + self.device = device + self.stream = TokenStream(pattern) + + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + chunk = self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start : start + per_rank_span].to(dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + +# ----------------------------- +# TRANSFORMER MODULES +# ----------------------------- + +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + + +class CastedLinear(nn.Linear): + # straight-through qat as per #137 + qat_ste: bool = False + def forward(self, x: Tensor) -> Tensor: + w = self.weight + if self.qat_ste and self.training: + # STE fake int6: quantize to [-32,31] and dequantize, gradient flows through + w32 = w.float() + row_max = w32.abs().amax(dim=1).clamp_min(1e-8) + scale = row_max / 31.0 + w_q = torch.clamp(torch.round(w32 / scale[:, None]), -32, 31) * scale[:, None] + w = w + 
(w_q.to(w.dtype) - w).detach() + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, w.to(x.dtype), bias) + + +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + # Keep small/control parameters in fp32 even when the model body runs in bf16. + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() + + +class Rotary(nn.Module): + def __init__(self, dim: int, base: float = 10000.0, train_seq_len: int = 1024, rope_dims: int = 0): + super().__init__() + self.dim = dim + self.base = base + self.train_seq_len = train_seq_len + self.rope_dims = rope_dims if rope_dims > 0 else dim + inv_freq = 1.0 / (base ** (torch.arange(0, self.rope_dims, 2, dtype=torch.float32) / self.rope_dims)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached != seq_len + or self._cos_cached.device != device + ): + rd = self.rope_dims + if seq_len > self.train_seq_len: + scale = seq_len / self.train_seq_len + new_base = self.base * (scale ** (rd / (rd - 2))) + inv_freq = 1.0 / (new_base ** (torch.arange(0, rd, 2, dtype=torch.float32, device=device) / rd)) + else: + inv_freq = self.inv_freq.to(device) + t = torch.arange(seq_len, device=device, dtype=inv_freq.dtype) + freqs = torch.outer(t, inv_freq) + self._cos_cached = freqs.cos()[None, :, None, :] + self._sin_cached = freqs.sin()[None, :, None, :] + self._seq_len_cached = seq_len + return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) + +# using partial RoPE as per #414 +def apply_rotary_emb(x: Tensor, cos: 
Tensor, sin: Tensor, rope_dims: int = 0) -> Tensor: + if rope_dims > 0 and rope_dims < x.size(-1): + x_rope, x_pass = x[..., :rope_dims], x[..., rope_dims:] + half = rope_dims // 2 + x1, x2 = x_rope[..., :half], x_rope[..., half:] + x_rope = torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + return torch.cat((x_rope, x_pass), dim=-1) + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + +class CausalSelfAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + rope_base: float, + qk_gain_init: float, + use_xsa: bool = False, + train_seq_len: int = 1024, + rope_dims: int | None = None + ): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + kv_dim = self.num_kv_heads * self.head_dim + self.c_q = CastedLinear(dim, dim, bias=False) + self.c_k = CastedLinear(dim, kv_dim, bias=False) + self.c_v = CastedLinear(dim, kv_dim, bias=False) + self.proj = CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rotary = Rotary(self.head_dim, base=rope_base, train_seq_len=train_seq_len, rope_dims=rope_dims) + self.use_xsa = use_xsa + + # efficient XSA as per + # https://github.com/unnir/parameter-golf/blob/a81f85bd7f632e3e48ef6b1da0017b81d25998a7/records/track_10min_16mb/2026-03-20_11L_EfficientPartialXSA_FA3_SWA120/train_gpt.py + def _xsa_efficient(self, y: Tensor, v: Tensor) -> Tensor: + """Efficient XSA: subtract self-value projection via GQA-aware reshape (no repeat_interleave). 
+ y: [B, T, H, D], v: [B, T, Hkv, D]. H must be divisible by Hkv.""" + B, T, H, D = y.shape + Hkv = v.size(-2) + group = H // Hkv + # Reshape y into KV head groups — free view, no memory alloc + y_g = y.reshape(B, T, Hkv, group, D) # [B, T, Hkv, group, D] + vn = F.normalize(v, dim=-1).unsqueeze(-2) # [B, T, Hkv, 1, D] — broadcast ready + # Project out self-value component per KV head group + proj = (y_g * vn).sum(dim=-1, keepdim=True) * vn + return (y_g - proj).reshape(B, T, H, D) + + def forward(self, x: Tensor, v_embed: Tensor | None = None) -> Tensor: + bsz, seqlen, dim = x.shape + q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim) + k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) + v = self.c_v(x) + if v_embed is not None: + v = v + v_embed + v = v.reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin, self.rotary.rope_dims) + k = apply_rotary_emb(k, cos, sin, self.rotary.rope_dims) + q = q * self.q_gain.to(dtype=q.dtype)[None, None, :, None] + y = flash_attn_func(q, k, v, causal=True) + + # efficient xsa as per https://github.com/openai/parameter-golf/pull/265 + if self.use_xsa: + y = self._xsa_efficient(y, v) + + y = y.reshape(bsz, seqlen, dim) + return self.proj(y) + +# stolen from #414 +class SmearGate(nn.Module): + def __init__(self, dim: int): + super().__init__() + self.gate = nn.Parameter(torch.zeros(dim, dtype=torch.float32)) + def forward(self, x: Tensor) -> Tensor: + g = torch.sigmoid(self.gate.to(dtype=x.dtype))[None, None, :] + x_prev = torch.cat([torch.zeros_like(x[:, :1]), x[:, :-1]], dim=1) + return (1 - g) * x + g * x_prev + +# stolen from #414 +class BigramHashEmbedding(nn.Module): + def __init__(self, bigram_vocab_size: int, bigram_dim: int, model_dim: int): + super().__init__() + self.bigram_vocab_size = bigram_vocab_size + self.embed = 
nn.Embedding(bigram_vocab_size, bigram_dim) + nn.init.zeros_(self.embed.weight) + self.proj = CastedLinear(bigram_dim, model_dim, bias=False) if bigram_dim != model_dim else None + if self.proj is not None: + nn.init.zeros_(self.proj.weight) + self.scale = nn.Parameter(torch.tensor(0.05, dtype=torch.float32)) + def bigram_hash(self, tokens: Tensor) -> Tensor: + t = tokens.to(torch.int32) + mod = self.bigram_vocab_size - 1 + out = torch.empty_like(t) + out[..., 0] = mod + out[..., 1:] = torch.bitwise_xor(36313 * t[..., 1:], 27191 * t[..., :-1]) % mod + return out.long() + def forward(self, token_ids: Tensor) -> Tensor: + h = self.embed(self.bigram_hash(token_ids)) + if self.proj is not None: + h = self.proj(h) + return h * self.scale.to(dtype=h.dtype) + +# stolen from #414 +class ValueEmbedding(nn.Module): + """Reinject token identity into attention values at specific layers. + Each table maps vocab tokens to a low-dim embedding, projected to model_dim.""" + def __init__(self, vocab_size: int, ve_dim: int, model_dim: int): + super().__init__() + self.embed = nn.Embedding(vocab_size, ve_dim) + nn.init.normal_(self.embed.weight, std=0.01) + self.proj = CastedLinear(ve_dim, model_dim, bias=False) if ve_dim != model_dim else None + if self.proj is not None: + nn.init.zeros_(self.proj.weight) + self.scale = nn.Parameter(torch.tensor(0.1, dtype=torch.float32)) + def forward(self, token_ids: Tensor) -> Tensor: + h = self.embed(token_ids) + if self.proj is not None: + h = self.proj(h) + return h * self.scale.to(dtype=h.dtype) + +class MLP(nn.Module): + def __init__( + self, + dim: int, + mlp_mult: int, + rng_up: bool = False, + use_rng_gain: bool = False, + mini_moe_experts: int = 1, + ): + super().__init__() + self.hidden = mlp_mult * dim + self.mini_moe_experts = mini_moe_experts + self.use_rng_gain = use_rng_gain + self.fc: CastedLinear | None = None + self.rand_gain: nn.Parameter | None = None + self.moe_router: CastedLinear | None = None + if not rng_up: + self.fc = 
CastedLinear(dim, self.hidden, bias=False)
+        else:
+            # non-persistent buffer: deterministically re-created from the seed at load
+            # time, kept in bf16 because it is never quantized anyway
+            self.register_buffer("fc_w", torch.empty((mini_moe_experts, self.hidden, dim), dtype=torch.bfloat16), persistent=False)
+        if self.use_rng_gain and rng_up:
+            self.rand_gain = nn.Parameter(torch.ones(mini_moe_experts, self.hidden, dtype=torch.float32))
+        if self.mini_moe_experts > 1:
+            self.moe_router = CastedLinear(dim, mini_moe_experts, bias=False)
+            self.moe_router._zero_init = True
+        self.proj = CastedLinear(self.hidden, dim, bias=False)
+        self.proj._zero_init = True
+
+    def forward(self, x: Tensor) -> Tensor:
+        if self.fc is not None:
+            x = self.fc(x)
+        else:
+            if self.mini_moe_experts == 1:
+                x = F.linear(x, self.fc_w[0].to(dtype=x.dtype))
+                if self.rand_gain is not None:
+                    x = x * self.rand_gain[0].to(dtype=x.dtype)[None, :]
+            else:
+                # compute router
+                moe_weights = F.softmax(self.moe_router(x), dim=-1)
+                # compute individual expert outputs and weighted sum
+                # x: (bsz, seqlen, dim), fc_w: (mini_moe_experts, hidden, dim)
+                exout = torch.einsum("bsd,ehd->bseh", x, self.fc_w.to(dtype=x.dtype))  # (bsz, seqlen, mini_moe_experts, hidden)
+                if self.rand_gain is not None:
+                    exout = exout * self.rand_gain.to(dtype=x.dtype)[None, None, :, :]
+                x = moe_weights.unsqueeze(-1) * exout
+                x = x.sum(dim=2)  # (bsz, seqlen, hidden)
+        # x = torch.relu(x)
+        x = F.leaky_relu(x, negative_slope=0.5)
+        return self.proj(x.square())
+
+class Block(nn.Module):
+    def __init__(
+        self,
+        dim: int,
+        num_heads: int,
+        num_kv_heads: int,
+        mlp_mult: int,
+        rope_base: float,
+        qk_gain_init: float,
+        rng_up: bool = False,
+        use_rng_gain: bool = False,
+        mini_moe_experts: int = 1,
+        train_seq_len: int = 1024,
+        use_xsa: bool = False,
+        rope_dims: int | None = None,
+        ln_scale: bool = False,
+        layer_idx: int = 0,
+    ):
+        super().__init__()
+        self.attn_norm = RMSNorm()
+        self.mlp_norm = RMSNorm()
+        self.attn = CausalSelfAttention(dim,
num_heads, num_kv_heads, rope_base, qk_gain_init, use_xsa, train_seq_len, rope_dims=rope_dims) + self.mlp = MLP(dim, mlp_mult, rng_up, use_rng_gain, mini_moe_experts) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + self.ln_scale_factor = 1.0 / math.sqrt(layer_idx + 1) if ln_scale else 1.0 + + def forward(self, x: Tensor, x0: Tensor, v_embed: Tensor | None = None) -> Tensor: + mix = self.resid_mix.to(dtype=x.dtype) + x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + attn_out = self.attn(self.attn_norm(x) * self.ln_scale_factor, v_embed=v_embed) + x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out + x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x) * self.ln_scale_factor) + return x + +class GPT(nn.Module): + def __init__( + self, + vocab_size: int, + num_layers: int, + model_dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + tie_embeddings: bool, + tied_embed_init_std: float, + logit_softcap: float, + rope_base: float, + qk_gain_init: float, + rand_proj_layers: list[int], + use_rng_gain: bool = False, + mini_moe_experts: int = 1, + rand_init_seed: int = 42, + rand_init_qr: bool = False, + xsa_layers: list[int] = [], + rope_dims: int | None = None, + ln_scale: bool = False, + bigram_vocab_size: int = 0, + bigram_dim: int = 128, + ve_dim: int = 128, + ve_layers: list[int] = [], + train_seq_len: int = 1024, + ): + super().__init__() + self.rand_init_seed = rand_init_seed + self.rand_init_qr = rand_init_qr + if logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") + self.tie_embeddings = tie_embeddings + self.tied_embed_init_std = tied_embed_init_std + self.logit_softcap = logit_softcap + self.tok_emb = nn.Embedding(vocab_size, model_dim) + self.smear = 
SmearGate(model_dim) + self.bigram = BigramHashEmbedding(bigram_vocab_size, bigram_dim, model_dim) if bigram_vocab_size > 0 else None + self.num_layers = num_layers + self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + self.blocks = nn.ModuleList( + [ + Block( + model_dim, + num_heads, + num_kv_heads, + mlp_mult, + rope_base, + qk_gain_init, + rng_up=(i in rand_proj_layers), + use_rng_gain=use_rng_gain, + mini_moe_experts=mini_moe_experts, + train_seq_len=train_seq_len, + use_xsa=(i in xsa_layers), + rope_dims=rope_dims, + ln_scale=ln_scale, + layer_idx=i, + ) + for i in range(num_layers) + ] + ) + + # value embeddings + self.ve_layer_indices = ve_layers + self.ve_target_dim = num_kv_heads * (model_dim // num_heads) + if self.ve_layer_indices: + self.ve_shared = ValueEmbedding(vocab_size, ve_dim, self.ve_target_dim) + self.ve_layer_scales = nn.ParameterList( + [nn.Parameter(torch.ones(1, dtype=torch.float32)) for _ in self.ve_layer_indices] + ) + else: + self.ve_shared = None + self.ve_layer_scales = nn.ParameterList() + self.value_embeds = nn.ModuleList() # keep empty for compat + + self.final_norm = RMSNorm() + self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) + if self.lm_head is not None: + self.lm_head._zero_init = True + self._init_weights() + + def _init_weights(self) -> None: + gen = torch.Generator() + gen.manual_seed(self.rand_init_seed) + + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + for name, module in self.named_modules(): + if isinstance(module, nn.Linear): + if getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + elif module.weight.ndim == 2 and module.weight.shape[0] >= 64 and 
module.weight.shape[1] >= 64:
+                    nn.init.orthogonal_(module.weight, gain=1.0)
+                if ".proj." in name or name.endswith(".proj"):
+                    with torch.no_grad():
+                        module.weight.mul_(1.0 / math.sqrt(2 * self.num_layers))
+
+            if isinstance(module, MLP) and getattr(module, "fc_w", None) is not None:
+                # perform seeded random init for the random fc weights
+                n_experts, d_out, d_in = module.fc_w.shape
+                if self.rand_init_qr:
+                    for e in range(n_experts):
+                        G = torch.randn((d_out, d_in), generator=gen)
+                        # reduced QR: Q has orthonormal columns (d_out >= d_in),
+                        # so the fixed up-projection is norm-preserving as-is
+                        q, _ = torch.linalg.qr(G)
+                        module.fc_w[e].copy_(q)
+                else:
+                    nn.init.normal_(module.fc_w, mean=0.0, std=1.0 / math.sqrt(d_in), generator=gen)
+                    # module.fc_w.bernoulli_(0.5, generator=gen).mul_(2).sub_(1).mul_(1.0 / math.sqrt(d_in))
+
+    # value-embeddings as per #414
+    def _get_ve(self, layer_idx: int, input_ids: Tensor, ve_cache: dict[str, Tensor] | None = None) -> Tensor | None:
+        """Get value embedding for a specific layer using shared table + per-layer scale."""
+        if self.ve_shared is None or layer_idx not in self.ve_layer_indices:
+            return None
+        if ve_cache is not None and 've' not in ve_cache:
+            ve_cache['ve'] = self.ve_shared(input_ids)
+        ve_base = ve_cache['ve'] if ve_cache is not None else self.ve_shared(input_ids)
+        ve_idx = self.ve_layer_indices.index(layer_idx)
+        return ve_base * self.ve_layer_scales[ve_idx].to(dtype=ve_base.dtype)
+
+    def forward_logits(self, input_ids: Tensor) -> Tensor:
+        x = self.tok_emb(input_ids)
+        if self.bigram is not None:
+            x = x + self.bigram(input_ids)
+        x = F.rms_norm(x, (x.size(-1),))
+        x = self.smear(x)
+        x0 = x
+        skips: list[Tensor] = []
+
+        # First half stores skips; second half reuses them in reverse order.
+ ve_cache: dict[str, Tensor] = {} + for i in range(self.num_encoder_layers): + ve = self._get_ve(i, input_ids, ve_cache) + x = self.blocks[i](x, x0, v_embed=ve) + skips.append(x) + for i in range(self.num_decoder_layers): + bi = self.num_encoder_layers + i + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + ve = self._get_ve(bi, input_ids, ve_cache) + x = self.blocks[bi](x, x0, v_embed=ve) + + x = self.final_norm(x) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(x) + logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + return logits + + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + logits = self.forward_logits(input_ids) + logits = logits.reshape(-1, logits.size(-1)) + targets = target_ids.reshape(-1) + return F.cross_entropy(logits.float(), targets, reduction="mean") + + +# ----------------------------- +# TRAINING +# ----------------------------- + +def main() -> None: + global zeropower_via_newtonschulz5 + + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) + + # ----------------------------- + # DISTRIBUTED + CUDA SETUP + # ----------------------------- + + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if 8 % world_size != 0: + raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") + grad_accum_steps = 8 // world_size + grad_scale = 1.0 / grad_accum_steps + if not torch.cuda.is_available(): + raise RuntimeError("CUDA 
is required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 0 + + # Fast math knobs + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + + logfile = None + if master_process: + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" + print(logfile) + + def log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch {torch.__version__}", console=False) + log0( + subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, + console=False, + ) + log0("=" * 100, console=False) + + # ----------------------------- + # TOKENIZER + VALIDATION METRIC SETUP + # ----------------------------- + + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + 
effective_eval_seq_len = args.eval_seq_len if args.eval_seq_len > 0 else args.train_seq_len + val_seq_len = max(args.train_seq_len, effective_eval_seq_len) + val_tokens = load_validation_tokens(args.val_files, val_seq_len) + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + + # ----------------------------- + # MODEL + OPTIMIZER SETUP + # ----------------------------- + + base_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, + rand_proj_layers=args.rand_proj_layers, + use_rng_gain=args.rand_gain, + mini_moe_experts=args.mini_moe_experts, + train_seq_len=args.train_seq_len, + xsa_layers=args.xsa_layers, + rope_dims=args.rope_dims, + ln_scale=args.ln_scale, + bigram_vocab_size=args.bigram_vocab_size, + bigram_dim=args.bigram_dim, + ve_dim=args.ve_dim, + ve_layers=args.ve_layers, + ).to(device).bfloat16() + for module in base_model.modules(): + if isinstance(module, CastedLinear): + module.float() + restore_low_dim_params_to_fp32(base_model) + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) + # compiled_model = base_model + model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + + # Optimizer split: + # - token embedding (Adam) uses EMBED_LR + # - untied lm_head (Adam) uses HEAD_LR + # - matrix params in transformer blocks 
use MATRIX_LR via Muon + # - vectors/scalars use SCALAR_LR via Adam + block_named_params = list(base_model.blocks.named_parameters()) + matrix_params = [ + p + for name, p in block_named_params + if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + scalar_params = [ + p + for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) + scalar_params.append(base_model.smear.gate) + if base_model.bigram is not None: + scalar_params.append(base_model.bigram.scale) + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + tok_params = [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}] + if base_model.bigram is not None: + tok_params.append({"params": [base_model.bigram.embed.weight], "lr": token_lr, "base_lr": token_lr}) + if base_model.bigram.proj is not None: + matrix_params.append(base_model.bigram.proj.weight) + if base_model.ve_shared is not None: + tok_params.append({"params": [base_model.ve_shared.embed.weight], "lr": token_lr, "base_lr": token_lr}) + if base_model.ve_shared.proj is not None: + matrix_params.append(base_model.ve_shared.proj.weight) + scalar_params.append(base_model.ve_shared.scale) + for s in base_model.ve_layer_scales: + scalar_params.append(s) + optimizer_tok = torch.optim.AdamW( + tok_params, + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + weight_decay=args.adam_wd, + fused=True, + ) + optimizer_muon = Muon( + matrix_params, + lr=args.matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + weight_decay=args.muon_wd, + ) + for group in optimizer_muon.param_groups: + group["base_lr"] = args.matrix_lr + optimizer_scalar = torch.optim.AdamW( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + 
weight_decay=args.adam_wd, + fused=True, + ) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers.insert(1, optimizer_head) + + n_params = sum(p.numel() for p in base_model.parameters()) + log0(f"model_params:{n_params}") + log0(f"num_layers:{args.num_layers} model_dim:{args.model_dim} num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads} mlp_mult:{args.mlp_mult}") + log0(f"rand_proj_layers:{args.rand_proj_layers} rand_gain:{args.rand_gain} mini_moe_experts:{args.mini_moe_experts}") + log0(f"bigram_vocab_size:{args.bigram_vocab_size} bigram_dim:{args.bigram_dim} ve_dim:{args.ve_dim} ve_layers:{args.ve_layers}") + log0(f"tie_embeddings:{args.tie_embeddings} tied_embed_init_std:{args.tied_embed_init_std} logit_softcap:{args.logit_softcap}") + log0(f"rope_base:{args.rope_base} qk_gain_init:{args.qk_gain_init} rope_dims:{args.rope_dims} xsa_layers:{args.xsa_layers} ln_scale:{args.ln_scale}") + log0(f"xsa_layers:{args.xsa_layers} ve_layers:{args.ve_layers} ve_dim:{args.ve_dim}") + log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") + log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") + log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") + log0(f"mlp_mode:{'rng_up' if args.rand_proj_layers else 'standard'} mlp_mult:{args.mlp_mult} mini_moe_experts:{args.mini_moe_experts} use_rng_gain:{args.rand_gain}") + log0( + f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " + f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " + f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" + ) + log0( + f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " + 
f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " + f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log0(f"seed:{args.seed}") + + # ----------------------------- + # DATA LOADER & MODEL WARMUP + # ----------------------------- + + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + + def lr_mul(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + + # Warmup primes the compiled forward/backward/optimizer paths, then we restore the + # initial weights/optimizer state so measured training starts from the true init. 
+ if args.warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(args.warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + for opt in optimizers: + opt.step() + zero_grad_all() + if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: + log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") + # enable qat during warmup so we don't pay the compilation tax later + if warmup_step == 2: + CastedLinear.qat_ste = True + if warmup_step == 4: + CastedLinear.qat_ste = False + model.eval() + if warmup_step == 6: + CastedLinear.qat_ste = True + if warmup_step == 8: + model.train() + CastedLinear.qat_ste = False + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + if distributed: + model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + # ----------------------------- + # MAIN TRAINING LOOP + # ----------------------------- + + # SWA as per #414 + swa_state: dict[str, Tensor] | None = None + swa_count = 0 + + # EMA as per #414 + ema_state = {name: t.detach().float().clone() for name, t in base_model.state_dict().items()} + ema_decay = 0.997 + + val_bpb = float("inf") + + training_time_ms = 0.0 + stop_after_step: int | None = None + torch.cuda.synchronize() + t0 = 
time.perf_counter() + + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + log0( + f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" + ) + break + + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + + # late QAT as per #414 + if args.late_qat_threshold > 0 and scale < args.late_qat_threshold and not CastedLinear.qat_ste: + CastedLinear.qat_ste = True + log0(f"late_qat:enabled step:{step} scale:{scale:.4f}") + + zero_grad_all() + train_loss = torch.zeros((), device=device) + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + loss = model(x, y) + train_loss += loss.detach() + (loss * grad_scale).backward() + train_loss /= grad_accum_steps + + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for 
group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + for opt in optimizers: + opt.step() + zero_grad_all() + + # EMA update as per #414 + with torch.no_grad(): + for name, t in base_model.state_dict().items(): + ema_state[name].mul_(ema_decay).add_(t.detach().float(), alpha=1.0 - ema_decay) + + step += 1 + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + + # SWA as per #414 + if args.swa_enabled and scale < 0.2 and step % args.swa_every == 0: + if swa_state is None: + swa_state = {name: t.detach().cpu().clone() for name, t in base_model.state_dict().items()} + swa_count = 1 + log0(f"swa:start step:{step}") + else: + for name, t in base_model.state_dict().items(): + swa_state[name] += t.detach().cpu() + swa_count += 1 + + should_log_train = ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + log0( + f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" + ) + + # Needed to sync whether we've reached the wallclock cap. 
+ reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms
+ if distributed and max_wallclock_ms is not None:
+ reached_cap_tensor = torch.tensor(int(reached_cap), device=device)
+ dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX)
+ reached_cap = bool(reached_cap_tensor.item())
+ if stop_after_step is None and reached_cap:
+ stop_after_step = step
+
+ log0(
+ f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB "
+ f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB"
+ )
+
+ # Apply EMA weights (better than SWA alone per PR#401)
+ log0("ema:applying EMA weights")
+ current_state = base_model.state_dict()
+ avg_state = {name: t.to(dtype=current_state[name].dtype) for name, t in ema_state.items()}
+ base_model.load_state_dict(avg_state, strict=True)
+ torch.cuda.synchronize()
+ t_diag = time.perf_counter()
+ diag_val_loss, diag_val_bpb = eval_val(
+ args, compiled_model, rank, world_size, device, grad_accum_steps,
+ val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut,
+ )
+ torch.cuda.synchronize()
+ log0(
+ f"DIAGNOSTIC post_ema val_loss:{diag_val_loss:.4f} val_bpb:{diag_val_bpb:.4f} "
+ f"eval_time:{1000.0 * (time.perf_counter() - t_diag):.0f}ms"
+ )
+
+ if diag_val_bpb > val_bpb:
+ log0(
+ f"EMA did not improve val_bpb ({val_bpb:.4f} -> {diag_val_bpb:.4f}), restoring pre-EMA weights for final serialization"
+ )
+ base_model.load_state_dict(current_state, strict=True)
+ else:
+ log0(
+ f"EMA improved val_bpb ({val_bpb:.4f} -> {diag_val_bpb:.4f}), keeping EMA weights for final serialization"
+ )
+
+ # -----------------------------
+ # SERIALIZATION + ROUNDTRIP VALIDATION
+ # -----------------------------
+ # Save the raw state (useful for debugging/loading in PyTorch directly), then always produce
+ # the compressed int6 artifact (zlib or zstd, per _COMPRESSOR) and validate the round-tripped weights.
+
+ if master_process:
+ torch.save(base_model.state_dict(), "final_model.pt")
+ model_bytes = os.path.getsize("final_model.pt")
+ code_bytes = len(code.encode("utf-8"))
+ log0(f"Serialized model: {model_bytes} bytes")
+ log0(f"Code size: {code_bytes} bytes")
+ log0(f"Total submission size: {model_bytes + code_bytes} bytes")
+
+ # --- quant ---
+ sd_cpu = {k: v.cpu() for k, v in base_model.state_dict().items()}
+ # TODO: think about keeping routers in fp32 or higher precision?
+ quant_result, quant_meta = mixed_quantize_int6(sd_cpu, {"mlp", "attn"})
+ quant_buf = io.BytesIO()
+ torch.save({"w": quant_result, "m": quant_meta}, quant_buf)
+ quant_raw = quant_buf.getvalue()
+ quant_blob = zstandard.ZstdCompressor(level=22).compress(quant_raw) if _COMPRESSOR == "zstd" else zlib.compress(quant_raw, 9)
+ if master_process:
+ with open("final_model.int6.ptz", "wb") as f:
+ f.write(quant_blob)
+ quant_file_bytes = len(quant_blob)
+ code_bytes = len(code.encode("utf-8"))
+ log0(f"Serialized model int6+{_COMPRESSOR}: {quant_file_bytes} bytes")
+ log0(f"Total submission size int6+{_COMPRESSOR}: {quant_file_bytes + code_bytes} bytes")
+ # same size repeated under the int8+zlib key (the artifact itself is int6)
+ log0(f"Total submission size int8+zlib: {quant_file_bytes + code_bytes} bytes")
+
+ if distributed:
+ dist.barrier()
+ with open("final_model.int6.ptz", "rb") as f:
+ quant_blob_disk = f.read()
+ quant_state = torch.load(
+ io.BytesIO(zstandard.ZstdDecompressor().decompress(quant_blob_disk) if _COMPRESSOR == "zstd" else zlib.decompress(quant_blob_disk)),
+ map_location="cpu",
+ )
+ deq_state = dequantize_mixed_int6(quant_state["w"], quant_state["m"], sd_cpu)
+ eval_model = GPT(
+ vocab_size=args.vocab_size,
+ num_layers=args.num_layers,
+ model_dim=args.model_dim,
+ num_heads=args.num_heads,
+ num_kv_heads=args.num_kv_heads,
+ mlp_mult=args.mlp_mult,
+ tie_embeddings=args.tie_embeddings,
+ tied_embed_init_std=args.tied_embed_init_std,
+ logit_softcap=args.logit_softcap,
+ rope_base=args.rope_base,
+
qk_gain_init=args.qk_gain_init, + rand_proj_layers=args.rand_proj_layers, + use_rng_gain=args.rand_gain, + mini_moe_experts=args.mini_moe_experts, + train_seq_len=args.train_seq_len, + xsa_layers=args.xsa_layers, + rope_dims=args.rope_dims, + ln_scale=args.ln_scale, + bigram_vocab_size=args.bigram_vocab_size, + bigram_dim=args.bigram_dim, + ve_dim=args.ve_dim, + ve_layers=args.ve_layers, + ).to(device).bfloat16() + for m in eval_model.modules(): + if isinstance(m, CastedLinear): + m.float() + restore_low_dim_params_to_fp32(eval_model) + eval_model.load_state_dict(deq_state, strict=True) + compiled_eval = torch.compile(eval_model, dynamic=False, fullgraph=True) + torch.cuda.synchronize() + t_qeval = time.perf_counter() + q_val_loss, q_val_bpb = eval_val( + args, + compiled_eval, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"final_int6_zlib_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" + ) + log0(f"final_int6_zlib_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + + # --- sliding window + TTT eval --- + + sw_seq_len = effective_eval_seq_len + if args.eval_stride > 0 and args.eval_stride < sw_seq_len: + torch.cuda.synchronize() + t_slide = time.perf_counter() + sw_val_loss, sw_val_bpb = eval_val_sliding( + args, eval_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride=args.eval_stride, + eval_seq_len=sw_seq_len, + ) + torch.cuda.synchronize() + log0( + f"final_int6_sliding_window val_loss:{sw_val_loss:.4f} val_bpb:{sw_val_bpb:.4f} " + f"stride:{args.eval_stride} eval_time:{1000.0 * (time.perf_counter() - t_slide):.0f}ms" + ) + log0(f"final_int6_sliding_window_exact val_loss:{sw_val_loss:.8f} val_bpb:{sw_val_bpb:.8f}") + log0(f"final_int8_zlib_roundtrip_exact 
val_loss:{sw_val_loss:.8f} val_bpb:{sw_val_bpb:.8f}") + if args.eval_stride != 64 and 64 < sw_seq_len: + torch.cuda.synchronize() + t_slide64 = time.perf_counter() + sw64_val_loss, sw64_val_bpb = eval_val_sliding( + args, eval_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride=64, + eval_seq_len=sw_seq_len, + ) + torch.cuda.synchronize() + log0( + f"final_int6_sliding_window_s64 val_loss:{sw64_val_loss:.4f} val_bpb:{sw64_val_bpb:.4f} " + f"stride:64 eval_time:{1000.0 * (time.perf_counter() - t_slide64):.0f}ms" + ) + log0(f"final_int6_sliding_window_s64_exact val_loss:{sw64_val_loss:.8f} val_bpb:{sw64_val_bpb:.8f}") + log0(f"final_int8_zlib_roundtrip_exact val_loss:{sw64_val_loss:.8f} val_bpb:{sw64_val_bpb:.8f}") + # Legal score-first TTT (PR #461 recipe) + if args.ttt_enabled: + # disable QAT ste for TTT + CastedLinear.qat_ste = False + torch.cuda.synchronize() + t_ttt = time.perf_counter() + ttt_loss, ttt_bpb = eval_val_sliding_ttt( + args, eval_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride=args.eval_stride, log0=log0, + ) + torch.cuda.synchronize() + log0(f"legal_ttt val_loss:{ttt_loss:.4f} val_bpb:{ttt_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_ttt):.0f}ms") + log0(f"legal_ttt_exact val_loss:{ttt_loss:.8f} val_bpb:{ttt_bpb:.8f}") + + if distributed: + dist.destroy_process_group() + + +if __name__ == "__main__": + main() + +==================================================================================================== +Running Python 3.12.3 (main, Nov 6 2025, 13:44:16) [GCC 13.3.0] +Running PyTorch 2.9.1+cu128 +Mon Mar 30 22:19:20 2026 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 570.195.03 Driver Version: 570.195.03 CUDA Version: 12.8 | +|-----------------------------------------+------------------------+----------------------+ +| GPU Name 
Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. | +|=========================================+========================+======================| +| 0 NVIDIA H100 80GB HBM3 On | 00000000:18:00.0 Off | 0 | +| N/A 24C P0 112W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 1 NVIDIA H100 80GB HBM3 On | 00000000:2A:00.0 Off | 0 | +| N/A 26C P0 117W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 2 NVIDIA H100 80GB HBM3 On | 00000000:3A:00.0 Off | 0 | +| N/A 28C P0 118W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 3 NVIDIA H100 80GB HBM3 On | 00000000:5D:00.0 Off | 0 | +| N/A 24C P0 119W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 4 NVIDIA H100 80GB HBM3 On | 00000000:9A:00.0 Off | 0 | +| N/A 23C P0 113W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 5 NVIDIA H100 80GB HBM3 On | 00000000:AB:00.0 Off | 0 | +| N/A 28C P0 119W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 6 NVIDIA H100 80GB HBM3 On | 00000000:BA:00.0 Off | 0 | +| N/A 27C P0 119W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 7 NVIDIA H100 80GB HBM3 On | 00000000:DB:00.0 Off | 0 | +| N/A 25C P0 114W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | 
++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| 0 N/A N/A 232870 C /usr/local/bin/python 1510MiB | +| 1 N/A N/A 232871 C /usr/local/bin/python 1510MiB | +| 2 N/A N/A 232872 C /usr/local/bin/python 1510MiB | +| 3 N/A N/A 232873 C /usr/local/bin/python 1510MiB | +| 4 N/A N/A 232874 C /usr/local/bin/python 1510MiB | +| 5 N/A N/A 232875 C /usr/local/bin/python 1510MiB | +| 6 N/A N/A 232876 C /usr/local/bin/python 1510MiB | +| 7 N/A N/A 232877 C /usr/local/bin/python 1510MiB | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:10 +val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 +model_params:25431141 +num_layers:12 model_dim:512 num_heads:8 num_kv_heads:4 mlp_mult:3 +rand_proj_layers:[0, 1, 2, 3, 4] rand_gain:True mini_moe_experts:1 +bigram_vocab_size:2048 bigram_dim:128 ve_dim:128 ve_layers:[9, 10, 11] +tie_embeddings:True tied_embed_init_std:0.005 logit_softcap:30.0 +rope_base:10000.0 qk_gain_init:1.5 rope_dims:16 xsa_layers:[8, 9, 10, 11] ln_scale:True +xsa_layers:[8, 9, 10, 11] ve_layers:[9, 10, 11] ve_dim:128 +world_size:8 grad_accum_steps:1 +sdp_backends:cudnn=False flash=True mem_efficient=False math=False +attention_mode:gqa num_heads:8 num_kv_heads:4 +mlp_mode:rng_up mlp_mult:3 mini_moe_experts:1 use_rng_gain:True +tie_embeddings:True embed_lr:0.035 head_lr:0.0 matrix_lr:0.025 scalar_lr:0.025 
+train_batch_tokens:786432 train_seq_len:2048 iterations:20000 warmup_steps:20 max_wallclock_seconds:600.000 +seed:999 +warmup_step:1/20 +warmup_step:2/20 +warmup_step:3/20 +warmup_step:4/20 +warmup_step:5/20 +warmup_step:6/20 +warmup_step:7/20 +warmup_step:8/20 +warmup_step:9/20 +warmup_step:10/20 +warmup_step:11/20 +warmup_step:12/20 +warmup_step:13/20 +warmup_step:14/20 +warmup_step:15/20 +warmup_step:16/20 +warmup_step:17/20 +warmup_step:18/20 +warmup_step:19/20 +warmup_step:20/20 +step:1/20000 train_loss:6.9337 train_time:170ms step_avg:169.51ms +step:2/20000 train_loss:8.5387 train_time:303ms step_avg:151.31ms +step:3/20000 train_loss:7.9708 train_time:428ms step_avg:142.81ms +step:4/20000 train_loss:7.1324 train_time:554ms step_avg:138.51ms +step:5/20000 train_loss:6.9860 train_time:680ms step_avg:136.01ms +step:6/20000 train_loss:6.8984 train_time:811ms step_avg:135.15ms +step:7/20000 train_loss:6.7294 train_time:936ms step_avg:133.75ms +step:8/20000 train_loss:6.6772 train_time:1062ms step_avg:132.75ms +step:9/20000 train_loss:6.4003 train_time:1204ms step_avg:133.79ms +step:10/20000 train_loss:6.0868 train_time:1339ms step_avg:133.90ms +step:20/20000 train_loss:4.8528 train_time:2590ms step_avg:129.51ms +step:30/20000 train_loss:4.1874 train_time:3784ms step_avg:126.12ms +step:40/20000 train_loss:3.8836 train_time:4899ms step_avg:122.48ms +step:50/20000 train_loss:3.6802 train_time:6006ms step_avg:120.12ms +step:60/20000 train_loss:3.4392 train_time:7269ms step_avg:121.15ms +step:70/20000 train_loss:3.4630 train_time:8549ms step_avg:122.13ms +step:80/20000 train_loss:3.3108 train_time:9688ms step_avg:121.10ms +step:90/20000 train_loss:3.1737 train_time:10802ms step_avg:120.02ms +step:100/20000 train_loss:3.1981 train_time:11910ms step_avg:119.10ms +step:110/20000 train_loss:3.0824 train_time:13016ms step_avg:118.33ms +step:120/20000 train_loss:3.1065 train_time:14123ms step_avg:117.69ms +step:130/20000 train_loss:2.9621 train_time:19084ms 
step_avg:146.80ms +step:140/20000 train_loss:2.9464 train_time:20361ms step_avg:145.43ms +step:150/20000 train_loss:2.8831 train_time:21634ms step_avg:144.23ms +step:160/20000 train_loss:2.8487 train_time:22904ms step_avg:143.15ms +step:170/20000 train_loss:2.6756 train_time:24174ms step_avg:142.20ms +step:180/20000 train_loss:2.6915 train_time:25335ms step_avg:140.75ms +step:190/20000 train_loss:2.6732 train_time:26443ms step_avg:139.17ms +step:200/20000 train_loss:2.3587 train_time:27553ms step_avg:137.76ms +step:210/20000 train_loss:2.6162 train_time:28681ms step_avg:136.58ms +step:220/20000 train_loss:2.6544 train_time:29821ms step_avg:135.55ms +step:230/20000 train_loss:2.6091 train_time:30938ms step_avg:134.51ms +step:240/20000 train_loss:2.5729 train_time:32052ms step_avg:133.55ms +step:250/20000 train_loss:2.4423 train_time:33165ms step_avg:132.66ms +step:260/20000 train_loss:2.5231 train_time:38374ms step_avg:147.59ms +step:270/20000 train_loss:2.4277 train_time:39647ms step_avg:146.84ms +step:280/20000 train_loss:2.5793 train_time:40909ms step_avg:146.10ms +step:290/20000 train_loss:2.4683 train_time:42174ms step_avg:145.43ms +step:300/20000 train_loss:2.5265 train_time:43448ms step_avg:144.83ms +step:310/20000 train_loss:2.5401 train_time:44711ms step_avg:144.23ms +step:320/20000 train_loss:2.4494 train_time:45971ms step_avg:143.66ms +step:330/20000 train_loss:2.4601 train_time:47251ms step_avg:143.18ms +step:340/20000 train_loss:2.5555 train_time:48512ms step_avg:142.68ms +step:350/20000 train_loss:2.5237 train_time:49775ms step_avg:142.21ms +step:360/20000 train_loss:2.2931 train_time:51050ms step_avg:141.81ms +step:370/20000 train_loss:2.4947 train_time:52315ms step_avg:141.39ms +step:380/20000 train_loss:2.3551 train_time:53578ms step_avg:140.99ms +step:390/20000 train_loss:2.4819 train_time:58388ms step_avg:149.71ms +step:400/20000 train_loss:2.3965 train_time:59648ms step_avg:149.12ms +step:410/20000 train_loss:2.3904 train_time:60906ms 
step_avg:148.55ms +step:420/20000 train_loss:2.3320 train_time:62164ms step_avg:148.01ms +step:430/20000 train_loss:2.3401 train_time:63422ms step_avg:147.49ms +step:440/20000 train_loss:2.3901 train_time:64694ms step_avg:147.03ms +step:450/20000 train_loss:2.3452 train_time:65967ms step_avg:146.59ms +step:460/20000 train_loss:2.3846 train_time:67260ms step_avg:146.22ms +step:470/20000 train_loss:2.3875 train_time:68532ms step_avg:145.81ms +step:480/20000 train_loss:2.2230 train_time:69802ms step_avg:145.42ms +step:490/20000 train_loss:2.3783 train_time:71077ms step_avg:145.06ms +step:500/20000 train_loss:2.3819 train_time:72351ms step_avg:144.70ms +step:510/20000 train_loss:2.2936 train_time:78482ms step_avg:153.89ms +step:520/20000 train_loss:2.3440 train_time:79774ms step_avg:153.41ms +step:530/20000 train_loss:2.2254 train_time:81066ms step_avg:152.96ms +step:540/20000 train_loss:2.3109 train_time:82359ms step_avg:152.52ms +step:550/20000 train_loss:2.3173 train_time:83646ms step_avg:152.08ms +step:560/20000 train_loss:2.3229 train_time:84912ms step_avg:151.63ms +step:570/20000 train_loss:2.3717 train_time:86181ms step_avg:151.20ms +step:580/20000 train_loss:2.4150 train_time:87467ms step_avg:150.81ms +step:590/20000 train_loss:2.2475 train_time:88756ms step_avg:150.43ms +step:600/20000 train_loss:2.3204 train_time:90027ms step_avg:150.04ms +step:610/20000 train_loss:2.3161 train_time:91281ms step_avg:149.64ms +step:620/20000 train_loss:2.2863 train_time:92544ms step_avg:149.26ms +step:630/20000 train_loss:2.2727 train_time:93798ms step_avg:148.89ms +step:640/20000 train_loss:2.8132 train_time:99004ms step_avg:154.69ms +step:650/20000 train_loss:2.3153 train_time:100278ms step_avg:154.27ms +step:660/20000 train_loss:2.5050 train_time:101571ms step_avg:153.90ms +step:670/20000 train_loss:2.2540 train_time:102860ms step_avg:153.52ms +step:680/20000 train_loss:2.2403 train_time:104134ms step_avg:153.14ms +step:690/20000 train_loss:2.2754 train_time:105407ms 
step_avg:152.76ms +step:700/20000 train_loss:2.3343 train_time:106679ms step_avg:152.40ms +step:710/20000 train_loss:2.2633 train_time:107966ms step_avg:152.06ms +step:720/20000 train_loss:2.3573 train_time:109262ms step_avg:151.75ms +step:730/20000 train_loss:2.1494 train_time:110519ms step_avg:151.40ms +step:740/20000 train_loss:2.2519 train_time:111780ms step_avg:151.05ms +step:750/20000 train_loss:2.3141 train_time:113045ms step_avg:150.73ms +step:760/20000 train_loss:2.3974 train_time:114317ms step_avg:150.42ms +step:770/20000 train_loss:2.2688 train_time:119291ms step_avg:154.92ms +step:780/20000 train_loss:2.2214 train_time:120567ms step_avg:154.57ms +step:790/20000 train_loss:2.1708 train_time:121857ms step_avg:154.25ms +step:800/20000 train_loss:2.2227 train_time:123130ms step_avg:153.91ms +step:810/20000 train_loss:2.1785 train_time:124405ms step_avg:153.59ms +step:820/20000 train_loss:2.2208 train_time:125681ms step_avg:153.27ms +step:830/20000 train_loss:2.1770 train_time:126971ms step_avg:152.98ms +step:840/20000 train_loss:2.3274 train_time:128266ms step_avg:152.70ms +step:850/20000 train_loss:2.2193 train_time:129539ms step_avg:152.40ms +step:860/20000 train_loss:2.0266 train_time:130812ms step_avg:152.11ms +step:870/20000 train_loss:2.2722 train_time:132087ms step_avg:151.82ms +step:880/20000 train_loss:2.2078 train_time:133371ms step_avg:151.56ms +step:890/20000 train_loss:2.2412 train_time:134661ms step_avg:151.30ms +step:900/20000 train_loss:2.1144 train_time:139392ms step_avg:154.88ms +step:910/20000 train_loss:2.1742 train_time:140666ms step_avg:154.58ms +step:920/20000 train_loss:2.2154 train_time:141956ms step_avg:154.30ms +step:930/20000 train_loss:2.2615 train_time:143210ms step_avg:153.99ms +step:940/20000 train_loss:2.4053 train_time:144472ms step_avg:153.69ms +step:950/20000 train_loss:2.2053 train_time:145745ms step_avg:153.42ms +step:960/20000 train_loss:2.1561 train_time:147008ms step_avg:153.13ms +step:970/20000 train_loss:2.3790 
train_time:148268ms step_avg:152.85ms +step:980/20000 train_loss:2.2053 train_time:149566ms step_avg:152.62ms +step:990/20000 train_loss:2.2454 train_time:150840ms step_avg:152.36ms +step:1000/20000 train_loss:2.2627 train_time:152109ms step_avg:152.11ms +step:1010/20000 train_loss:2.0403 train_time:153381ms step_avg:151.86ms +step:1020/20000 train_loss:2.0776 train_time:158335ms step_avg:155.23ms +step:1030/20000 train_loss:2.1913 train_time:159608ms step_avg:154.96ms +step:1040/20000 train_loss:2.2396 train_time:160883ms step_avg:154.69ms +step:1050/20000 train_loss:2.2069 train_time:162169ms step_avg:154.45ms +step:1060/20000 train_loss:2.3110 train_time:163457ms step_avg:154.20ms +step:1070/20000 train_loss:2.2274 train_time:164733ms step_avg:153.96ms +step:1080/20000 train_loss:2.1856 train_time:166010ms step_avg:153.71ms +step:1090/20000 train_loss:2.1540 train_time:167286ms step_avg:153.47ms +step:1100/20000 train_loss:2.3043 train_time:168577ms step_avg:153.25ms +step:1110/20000 train_loss:2.2147 train_time:169864ms step_avg:153.03ms +step:1120/20000 train_loss:2.0838 train_time:171156ms step_avg:152.82ms +step:1130/20000 train_loss:2.3209 train_time:172493ms step_avg:152.65ms +step:1140/20000 train_loss:2.2319 train_time:173765ms step_avg:152.43ms +step:1150/20000 train_loss:2.2296 train_time:178987ms step_avg:155.64ms +step:1160/20000 train_loss:2.1251 train_time:180268ms step_avg:155.40ms +step:1170/20000 train_loss:2.1596 train_time:181555ms step_avg:155.18ms +step:1180/20000 train_loss:2.1258 train_time:182825ms step_avg:154.94ms +step:1190/20000 train_loss:2.2243 train_time:184093ms step_avg:154.70ms +step:1200/20000 train_loss:2.3365 train_time:185359ms step_avg:154.47ms +step:1210/20000 train_loss:2.1845 train_time:186623ms step_avg:154.23ms +step:1220/20000 train_loss:2.2092 train_time:187884ms step_avg:154.00ms +step:1230/20000 train_loss:2.1450 train_time:189159ms step_avg:153.79ms +step:1240/20000 train_loss:2.1478 train_time:190431ms 
step_avg:153.57ms +step:1250/20000 train_loss:2.2337 train_time:191708ms step_avg:153.37ms +step:1260/20000 train_loss:2.1536 train_time:192983ms step_avg:153.16ms +step:1270/20000 train_loss:2.1711 train_time:194272ms step_avg:152.97ms +step:1280/20000 train_loss:2.1153 train_time:195653ms step_avg:152.85ms +step:1290/20000 train_loss:2.1420 train_time:196924ms step_avg:152.65ms +step:1300/20000 train_loss:2.3420 train_time:198193ms step_avg:152.46ms +step:1310/20000 train_loss:2.0109 train_time:199471ms step_avg:152.27ms +step:1320/20000 train_loss:2.1667 train_time:200765ms step_avg:152.09ms +step:1330/20000 train_loss:2.1410 train_time:202035ms step_avg:151.91ms +step:1340/20000 train_loss:2.2390 train_time:203304ms step_avg:151.72ms +step:1350/20000 train_loss:2.1521 train_time:204575ms step_avg:151.54ms +step:1360/20000 train_loss:2.2338 train_time:205862ms step_avg:151.37ms +step:1370/20000 train_loss:2.2686 train_time:207134ms step_avg:151.19ms +step:1380/20000 train_loss:2.0869 train_time:208408ms step_avg:151.02ms +step:1390/20000 train_loss:2.1330 train_time:209677ms step_avg:150.85ms +step:1400/20000 train_loss:2.2007 train_time:211056ms step_avg:150.75ms +step:1410/20000 train_loss:2.1559 train_time:212328ms step_avg:150.59ms +step:1420/20000 train_loss:2.1021 train_time:213599ms step_avg:150.42ms +step:1430/20000 train_loss:2.1044 train_time:214860ms step_avg:150.25ms +step:1440/20000 train_loss:2.1919 train_time:216120ms step_avg:150.08ms +step:1450/20000 train_loss:2.1793 train_time:217380ms step_avg:149.92ms +step:1460/20000 train_loss:2.1276 train_time:218649ms step_avg:149.76ms +step:1470/20000 train_loss:2.2560 train_time:219929ms step_avg:149.61ms +step:1480/20000 train_loss:2.1037 train_time:221220ms step_avg:149.47ms +step:1490/20000 train_loss:2.1587 train_time:222479ms step_avg:149.32ms +step:1500/20000 train_loss:2.1710 train_time:223774ms step_avg:149.18ms +step:1510/20000 train_loss:2.3290 train_time:225073ms step_avg:149.05ms 
+step:1520/20000 train_loss:1.9798 train_time:226372ms step_avg:148.93ms +step:1530/20000 train_loss:2.0093 train_time:227771ms step_avg:148.87ms +step:1540/20000 train_loss:2.1183 train_time:229078ms step_avg:148.75ms +step:1550/20000 train_loss:2.1692 train_time:230411ms step_avg:148.65ms +step:1560/20000 train_loss:2.2297 train_time:231688ms step_avg:148.52ms +step:1570/20000 train_loss:2.1444 train_time:232966ms step_avg:148.39ms +step:1580/20000 train_loss:2.0304 train_time:234263ms step_avg:148.27ms +step:1590/20000 train_loss:2.1167 train_time:235539ms step_avg:148.14ms +step:1600/20000 train_loss:2.1749 train_time:236822ms step_avg:148.01ms +step:1610/20000 train_loss:2.1929 train_time:238107ms step_avg:147.89ms +step:1620/20000 train_loss:2.1639 train_time:239369ms step_avg:147.76ms +step:1630/20000 train_loss:2.3770 train_time:240637ms step_avg:147.63ms +step:1640/20000 train_loss:2.1947 train_time:241895ms step_avg:147.50ms +step:1650/20000 train_loss:1.9785 train_time:243154ms step_avg:147.37ms +step:1660/20000 train_loss:2.2001 train_time:244508ms step_avg:147.29ms +step:1670/20000 train_loss:1.9668 train_time:245769ms step_avg:147.17ms +step:1680/20000 train_loss:2.1801 train_time:247038ms step_avg:147.05ms +step:1690/20000 train_loss:2.1735 train_time:248296ms step_avg:146.92ms +step:1700/20000 train_loss:2.1697 train_time:249555ms step_avg:146.80ms +step:1710/20000 train_loss:2.2293 train_time:250814ms step_avg:146.68ms +step:1720/20000 train_loss:2.2430 train_time:252073ms step_avg:146.55ms +step:1730/20000 train_loss:2.1770 train_time:253367ms step_avg:146.45ms +step:1740/20000 train_loss:2.1328 train_time:254659ms step_avg:146.36ms +step:1750/20000 train_loss:2.1091 train_time:255934ms step_avg:146.25ms +step:1760/20000 train_loss:2.0995 train_time:257186ms step_avg:146.13ms +step:1770/20000 train_loss:2.1144 train_time:258447ms step_avg:146.02ms +step:1780/20000 train_loss:2.0146 train_time:259705ms step_avg:145.90ms +step:1790/20000 
train_loss:2.1473 train_time:261077ms step_avg:145.85ms +step:1800/20000 train_loss:2.1173 train_time:262370ms step_avg:145.76ms +step:1810/20000 train_loss:2.0612 train_time:263669ms step_avg:145.67ms +step:1820/20000 train_loss:2.0169 train_time:264961ms step_avg:145.58ms +step:1830/20000 train_loss:2.1456 train_time:266237ms step_avg:145.48ms +step:1840/20000 train_loss:2.2337 train_time:267513ms step_avg:145.39ms +step:1850/20000 train_loss:2.1402 train_time:268791ms step_avg:145.29ms +step:1860/20000 train_loss:2.0418 train_time:270068ms step_avg:145.20ms +step:1870/20000 train_loss:2.0904 train_time:271319ms step_avg:145.09ms +step:1880/20000 train_loss:2.0326 train_time:272434ms step_avg:144.91ms +step:1890/20000 train_loss:2.1133 train_time:273549ms step_avg:144.74ms +step:1900/20000 train_loss:2.1735 train_time:274665ms step_avg:144.56ms +step:1910/20000 train_loss:2.1528 train_time:275875ms step_avg:144.44ms +step:1920/20000 train_loss:2.1527 train_time:277009ms step_avg:144.28ms +step:1930/20000 train_loss:2.1923 train_time:278123ms step_avg:144.11ms +step:1940/20000 train_loss:2.0738 train_time:279399ms step_avg:144.02ms +step:1950/20000 train_loss:2.1113 train_time:280674ms step_avg:143.94ms +step:1960/20000 train_loss:2.0269 train_time:281970ms step_avg:143.86ms +step:1970/20000 train_loss:2.0533 train_time:283264ms step_avg:143.79ms +step:1980/20000 train_loss:2.1410 train_time:284537ms step_avg:143.71ms +step:1990/20000 train_loss:2.0927 train_time:285808ms step_avg:143.62ms +step:2000/20000 train_loss:2.1502 train_time:287081ms step_avg:143.54ms +step:2010/20000 train_loss:1.9393 train_time:288367ms step_avg:143.47ms +step:2020/20000 train_loss:2.1212 train_time:289657ms step_avg:143.39ms +step:2030/20000 train_loss:2.0296 train_time:290919ms step_avg:143.31ms +step:2040/20000 train_loss:2.2407 train_time:292268ms step_avg:143.27ms +step:2050/20000 train_loss:2.0738 train_time:293538ms step_avg:143.19ms +step:2060/20000 train_loss:2.1068 
train_time:294799ms step_avg:143.11ms +step:2070/20000 train_loss:2.0603 train_time:296059ms step_avg:143.02ms +step:2080/20000 train_loss:2.0936 train_time:297319ms step_avg:142.94ms +step:2090/20000 train_loss:2.1443 train_time:298579ms step_avg:142.86ms +step:2100/20000 train_loss:2.0397 train_time:299849ms step_avg:142.79ms +step:2110/20000 train_loss:2.1287 train_time:301108ms step_avg:142.71ms +step:2120/20000 train_loss:2.1180 train_time:302367ms step_avg:142.63ms +step:2130/20000 train_loss:2.0520 train_time:303638ms step_avg:142.55ms +step:2140/20000 train_loss:2.1721 train_time:304897ms step_avg:142.48ms +step:2150/20000 train_loss:2.0140 train_time:306157ms step_avg:142.40ms +step:2160/20000 train_loss:2.0556 train_time:307416ms step_avg:142.32ms +step:2170/20000 train_loss:1.9755 train_time:308782ms step_avg:142.30ms +step:2180/20000 train_loss:2.1652 train_time:310049ms step_avg:142.22ms +step:2190/20000 train_loss:2.1895 train_time:311311ms step_avg:142.15ms +step:2200/20000 train_loss:2.1469 train_time:312571ms step_avg:142.08ms +step:2210/20000 train_loss:2.1021 train_time:313843ms step_avg:142.01ms +step:2220/20000 train_loss:1.8619 train_time:315103ms step_avg:141.94ms +step:2230/20000 train_loss:2.1050 train_time:316363ms step_avg:141.87ms +step:2240/20000 train_loss:1.8631 train_time:317623ms step_avg:141.80ms +step:2250/20000 train_loss:2.0841 train_time:318884ms step_avg:141.73ms +step:2260/20000 train_loss:2.1008 train_time:320149ms step_avg:141.66ms +step:2270/20000 train_loss:2.1518 train_time:321408ms step_avg:141.59ms +step:2280/20000 train_loss:2.1040 train_time:322668ms step_avg:141.52ms +step:2290/20000 train_loss:2.1625 train_time:324039ms step_avg:141.50ms +step:2300/20000 train_loss:2.0787 train_time:325300ms step_avg:141.43ms +step:2310/20000 train_loss:2.1013 train_time:326562ms step_avg:141.37ms +step:2320/20000 train_loss:2.0110 train_time:327821ms step_avg:141.30ms +step:2330/20000 train_loss:2.1864 train_time:329081ms 
step_avg:141.24ms +step:2340/20000 train_loss:2.0447 train_time:330349ms step_avg:141.17ms +step:2350/20000 train_loss:2.1134 train_time:331608ms step_avg:141.11ms +step:2360/20000 train_loss:2.0805 train_time:332867ms step_avg:141.05ms +step:2370/20000 train_loss:2.0161 train_time:333986ms step_avg:140.92ms +step:2380/20000 train_loss:2.1127 train_time:335113ms step_avg:140.80ms +step:2390/20000 train_loss:2.1007 train_time:336230ms step_avg:140.68ms +step:2400/20000 train_loss:2.1424 train_time:337345ms step_avg:140.56ms +step:2410/20000 train_loss:2.0481 train_time:338496ms step_avg:140.45ms +step:2420/20000 train_loss:2.0593 train_time:339730ms step_avg:140.38ms +step:2430/20000 train_loss:3.2297 train_time:340845ms step_avg:140.27ms +step:2440/20000 train_loss:2.0278 train_time:341961ms step_avg:140.15ms +step:2450/20000 train_loss:2.1238 train_time:343099ms step_avg:140.04ms +step:2460/20000 train_loss:2.1396 train_time:344216ms step_avg:139.93ms +step:2470/20000 train_loss:2.0780 train_time:345331ms step_avg:139.81ms +step:2480/20000 train_loss:2.0228 train_time:346447ms step_avg:139.70ms +step:2490/20000 train_loss:2.0405 train_time:347561ms step_avg:139.58ms +step:2500/20000 train_loss:2.0289 train_time:348676ms step_avg:139.47ms +step:2510/20000 train_loss:1.9817 train_time:349813ms step_avg:139.37ms +step:2520/20000 train_loss:2.1049 train_time:350933ms step_avg:139.26ms +step:2530/20000 train_loss:2.0087 train_time:352048ms step_avg:139.15ms +step:2540/20000 train_loss:2.0378 train_time:353165ms step_avg:139.04ms +step:2550/20000 train_loss:2.1159 train_time:354378ms step_avg:138.97ms +step:2560/20000 train_loss:2.0873 train_time:355510ms step_avg:138.87ms +step:2570/20000 train_loss:2.0338 train_time:356626ms step_avg:138.76ms +step:2580/20000 train_loss:2.1520 train_time:357743ms step_avg:138.66ms +step:2590/20000 train_loss:2.0895 train_time:358859ms step_avg:138.56ms +step:2600/20000 train_loss:2.1125 train_time:359975ms step_avg:138.45ms 
+step:2610/20000 train_loss:2.1355 train_time:361108ms step_avg:138.36ms +step:2620/20000 train_loss:2.0603 train_time:362222ms step_avg:138.25ms +step:2630/20000 train_loss:2.3572 train_time:363340ms step_avg:138.15ms +step:2640/20000 train_loss:2.0177 train_time:364456ms step_avg:138.05ms +step:2650/20000 train_loss:2.0288 train_time:365572ms step_avg:137.95ms +step:2660/20000 train_loss:2.0133 train_time:366710ms step_avg:137.86ms +step:2670/20000 train_loss:2.1415 train_time:367823ms step_avg:137.76ms +step:2680/20000 train_loss:1.9202 train_time:369038ms step_avg:137.70ms +step:2690/20000 train_loss:2.1967 train_time:370154ms step_avg:137.60ms +step:2700/20000 train_loss:2.0655 train_time:371271ms step_avg:137.51ms +step:2710/20000 train_loss:2.0430 train_time:372404ms step_avg:137.42ms +step:2720/20000 train_loss:2.0831 train_time:373519ms step_avg:137.32ms +step:2730/20000 train_loss:2.0291 train_time:374635ms step_avg:137.23ms +step:2740/20000 train_loss:2.0784 train_time:375752ms step_avg:137.14ms +step:2750/20000 train_loss:2.1046 train_time:376870ms step_avg:137.04ms +step:2760/20000 train_loss:2.0795 train_time:378016ms step_avg:136.96ms +step:2770/20000 train_loss:1.9967 train_time:379267ms step_avg:136.92ms +step:2780/20000 train_loss:2.3550 train_time:380399ms step_avg:136.83ms +step:2790/20000 train_loss:2.0034 train_time:381515ms step_avg:136.74ms +step:2800/20000 train_loss:2.0785 train_time:382728ms step_avg:136.69ms +step:2810/20000 train_loss:2.0112 train_time:383843ms step_avg:136.60ms +step:2820/20000 train_loss:2.0976 train_time:384999ms step_avg:136.52ms +step:2830/20000 train_loss:1.9629 train_time:386221ms step_avg:136.47ms +step:2840/20000 train_loss:2.0268 train_time:387337ms step_avg:136.39ms +step:2850/20000 train_loss:2.0734 train_time:388453ms step_avg:136.30ms +step:2860/20000 train_loss:2.0838 train_time:389568ms step_avg:136.21ms +step:2870/20000 train_loss:2.0248 train_time:390700ms step_avg:136.13ms +step:2880/20000 
train_loss:2.0259 train_time:391817ms step_avg:136.05ms +step:2890/20000 train_loss:2.1369 train_time:392934ms step_avg:135.96ms +step:2900/20000 train_loss:2.0122 train_time:394051ms step_avg:135.88ms +step:2910/20000 train_loss:2.0710 train_time:395167ms step_avg:135.80ms +step:2920/20000 train_loss:2.0825 train_time:396298ms step_avg:135.72ms +step:2930/20000 train_loss:2.0682 train_time:397507ms step_avg:135.67ms +step:2940/20000 train_loss:1.8827 train_time:398623ms step_avg:135.59ms +step:2950/20000 train_loss:2.0911 train_time:399741ms step_avg:135.51ms +step:2960/20000 train_loss:2.0379 train_time:400856ms step_avg:135.42ms +step:2970/20000 train_loss:1.9925 train_time:401970ms step_avg:135.34ms +step:2980/20000 train_loss:2.0343 train_time:403096ms step_avg:135.27ms +step:2990/20000 train_loss:2.1138 train_time:404212ms step_avg:135.19ms +step:3000/20000 train_loss:2.1016 train_time:405331ms step_avg:135.11ms +step:3010/20000 train_loss:2.0520 train_time:406448ms step_avg:135.03ms +step:3020/20000 train_loss:2.0041 train_time:407565ms step_avg:134.96ms +step:3030/20000 train_loss:1.9418 train_time:408681ms step_avg:134.88ms +step:3040/20000 train_loss:2.0355 train_time:409809ms step_avg:134.81ms +step:3050/20000 train_loss:2.0261 train_time:410924ms step_avg:134.73ms +step:3060/20000 train_loss:2.0362 train_time:412139ms step_avg:134.69ms +step:3070/20000 train_loss:1.9926 train_time:413254ms step_avg:134.61ms +step:3080/20000 train_loss:2.0112 train_time:414370ms step_avg:134.54ms +step:3090/20000 train_loss:1.9650 train_time:415504ms step_avg:134.47ms +step:3100/20000 train_loss:2.0321 train_time:416638ms step_avg:134.40ms +step:3110/20000 train_loss:2.4338 train_time:417830ms step_avg:134.35ms +step:3120/20000 train_loss:2.1333 train_time:419088ms step_avg:134.32ms +step:3130/20000 train_loss:2.1353 train_time:420347ms step_avg:134.30ms +step:3140/20000 train_loss:1.9825 train_time:421605ms step_avg:134.27ms +step:3150/20000 train_loss:2.0657 
train_time:422863ms step_avg:134.24ms +step:3160/20000 train_loss:1.8884 train_time:424121ms step_avg:134.22ms +step:3170/20000 train_loss:2.0172 train_time:425379ms step_avg:134.19ms +step:3180/20000 train_loss:2.0417 train_time:426747ms step_avg:134.20ms +step:3190/20000 train_loss:2.0342 train_time:428019ms step_avg:134.18ms +step:3200/20000 train_loss:1.8355 train_time:429282ms step_avg:134.15ms +step:3210/20000 train_loss:2.1660 train_time:430549ms step_avg:134.13ms +step:3220/20000 train_loss:2.2583 train_time:431810ms step_avg:134.10ms +step:3230/20000 train_loss:1.9891 train_time:433072ms step_avg:134.08ms +step:3240/20000 train_loss:1.9494 train_time:434347ms step_avg:134.06ms +step:3250/20000 train_loss:2.0494 train_time:435610ms step_avg:134.03ms +step:3260/20000 train_loss:1.9255 train_time:436870ms step_avg:134.01ms +step:3270/20000 train_loss:2.0852 train_time:438142ms step_avg:133.99ms +step:3280/20000 train_loss:2.0120 train_time:439404ms step_avg:133.96ms +step:3290/20000 train_loss:2.0056 train_time:440664ms step_avg:133.94ms +step:3300/20000 train_loss:1.9840 train_time:441926ms step_avg:133.92ms +step:3310/20000 train_loss:2.1784 train_time:443285ms step_avg:133.92ms +step:3320/20000 train_loss:2.0324 train_time:444549ms step_avg:133.90ms +step:3330/20000 train_loss:1.8488 train_time:445808ms step_avg:133.88ms +step:3340/20000 train_loss:2.0446 train_time:447067ms step_avg:133.85ms +step:3350/20000 train_loss:1.9841 train_time:448339ms step_avg:133.83ms +step:3360/20000 train_loss:1.9562 train_time:449596ms step_avg:133.81ms +step:3370/20000 train_loss:1.9870 train_time:450824ms step_avg:133.78ms +step:3380/20000 train_loss:2.1699 train_time:451940ms step_avg:133.71ms +step:3390/20000 train_loss:2.0295 train_time:453055ms step_avg:133.64ms +step:3400/20000 train_loss:2.0581 train_time:454169ms step_avg:133.58ms +step:3410/20000 train_loss:2.2045 train_time:455301ms step_avg:133.52ms +step:3420/20000 train_loss:1.8546 train_time:456414ms 
step_avg:133.45ms +step:3430/20000 train_loss:2.0583 train_time:457528ms step_avg:133.39ms +step:3440/20000 train_loss:2.0117 train_time:458770ms step_avg:133.36ms +step:3450/20000 train_loss:2.0549 train_time:459901ms step_avg:133.30ms +step:3460/20000 train_loss:1.9803 train_time:461016ms step_avg:133.24ms +step:3470/20000 train_loss:1.9863 train_time:462132ms step_avg:133.18ms +step:3480/20000 train_loss:2.0757 train_time:463245ms step_avg:133.12ms +step:3490/20000 train_loss:2.0208 train_time:464376ms step_avg:133.06ms +step:3500/20000 train_loss:2.0201 train_time:465584ms step_avg:133.02ms +step:3510/20000 train_loss:2.1087 train_time:466844ms step_avg:133.00ms +step:3520/20000 train_loss:1.9798 train_time:468102ms step_avg:132.98ms +step:3530/20000 train_loss:2.0557 train_time:469361ms step_avg:132.96ms +step:3540/20000 train_loss:1.9774 train_time:470620ms step_avg:132.94ms +step:3550/20000 train_loss:2.0229 train_time:471878ms step_avg:132.92ms +step:3560/20000 train_loss:2.0483 train_time:473147ms step_avg:132.91ms +step:3570/20000 train_loss:2.0519 train_time:474520ms step_avg:132.92ms +step:3580/20000 train_loss:2.0125 train_time:475777ms step_avg:132.90ms +step:3590/20000 train_loss:2.0453 train_time:477049ms step_avg:132.88ms +step:3600/20000 train_loss:2.0185 train_time:478308ms step_avg:132.86ms +step:3610/20000 train_loss:1.9704 train_time:479567ms step_avg:132.84ms +step:3620/20000 train_loss:2.1524 train_time:480836ms step_avg:132.83ms +step:3630/20000 train_loss:2.0498 train_time:482096ms step_avg:132.81ms +step:3640/20000 train_loss:2.0200 train_time:483354ms step_avg:132.79ms +step:3650/20000 train_loss:1.9644 train_time:484612ms step_avg:132.77ms +step:3660/20000 train_loss:1.9816 train_time:485869ms step_avg:132.75ms +step:3670/20000 train_loss:2.1027 train_time:487136ms step_avg:132.73ms +step:3680/20000 train_loss:1.9530 train_time:488395ms step_avg:132.72ms +step:3690/20000 train_loss:1.9901 train_time:489759ms step_avg:132.73ms 
+step:3700/20000 train_loss:1.9605 train_time:491019ms step_avg:132.71ms +step:3710/20000 train_loss:2.0129 train_time:492277ms step_avg:132.69ms +step:3720/20000 train_loss:2.0617 train_time:493547ms step_avg:132.67ms +step:3730/20000 train_loss:2.0320 train_time:494806ms step_avg:132.66ms +step:3740/20000 train_loss:2.0740 train_time:496065ms step_avg:132.64ms +step:3750/20000 train_loss:2.0737 train_time:497324ms step_avg:132.62ms +step:3760/20000 train_loss:2.0172 train_time:498583ms step_avg:132.60ms +step:3770/20000 train_loss:2.0753 train_time:499846ms step_avg:132.59ms +step:3780/20000 train_loss:1.9744 train_time:501104ms step_avg:132.57ms +step:3790/20000 train_loss:2.0200 train_time:502362ms step_avg:132.55ms +step:3800/20000 train_loss:2.0042 train_time:503621ms step_avg:132.53ms +step:3810/20000 train_loss:1.9442 train_time:504880ms step_avg:132.51ms +step:3820/20000 train_loss:2.0970 train_time:506268ms step_avg:132.53ms +step:3830/20000 train_loss:1.9935 train_time:507537ms step_avg:132.52ms +step:3840/20000 train_loss:1.8335 train_time:508798ms step_avg:132.50ms +swa:start step:3850 +step:3850/20000 train_loss:2.0576 train_time:510057ms step_avg:132.48ms +step:3860/20000 train_loss:2.0242 train_time:511417ms step_avg:132.49ms +step:3870/20000 train_loss:1.9921 train_time:512676ms step_avg:132.47ms +step:3880/20000 train_loss:2.0061 train_time:513949ms step_avg:132.46ms +step:3890/20000 train_loss:2.0048 train_time:515248ms step_avg:132.45ms +step:3900/20000 train_loss:1.9876 train_time:516476ms step_avg:132.43ms +step:3910/20000 train_loss:2.0445 train_time:517620ms step_avg:132.38ms +step:3920/20000 train_loss:1.9340 train_time:518734ms step_avg:132.33ms +step:3930/20000 train_loss:1.9473 train_time:519850ms step_avg:132.28ms +step:3940/20000 train_loss:1.9607 train_time:520966ms step_avg:132.22ms +step:3950/20000 train_loss:1.9777 train_time:522225ms step_avg:132.21ms +step:3960/20000 train_loss:1.9965 train_time:523499ms step_avg:132.20ms 
+step:3970/20000 train_loss:2.0572 train_time:524748ms step_avg:132.18ms +step:3980/20000 train_loss:2.0013 train_time:525997ms step_avg:132.16ms +step:3990/20000 train_loss:2.0213 train_time:527246ms step_avg:132.14ms +step:4000/20000 train_loss:1.9190 train_time:528495ms step_avg:132.12ms +step:4010/20000 train_loss:2.2828 train_time:529770ms step_avg:132.11ms +late_qat:enabled step:4017 scale:0.1499 +step:4020/20000 train_loss:1.9599 train_time:531067ms step_avg:132.11ms +step:4030/20000 train_loss:2.0330 train_time:532331ms step_avg:132.09ms +step:4040/20000 train_loss:1.9742 train_time:533580ms step_avg:132.07ms +step:4050/20000 train_loss:2.0379 train_time:534724ms step_avg:132.03ms +step:4060/20000 train_loss:2.0698 train_time:535873ms step_avg:131.99ms +step:4070/20000 train_loss:2.0130 train_time:537123ms step_avg:131.97ms +step:4080/20000 train_loss:1.9053 train_time:538253ms step_avg:131.92ms +step:4090/20000 train_loss:1.9751 train_time:539380ms step_avg:131.88ms +step:4100/20000 train_loss:1.8848 train_time:540524ms step_avg:131.84ms +step:4110/20000 train_loss:2.0617 train_time:541677ms step_avg:131.79ms +step:4120/20000 train_loss:1.8273 train_time:542821ms step_avg:131.75ms +step:4130/20000 train_loss:1.9445 train_time:543950ms step_avg:131.71ms +step:4140/20000 train_loss:1.8352 train_time:545083ms step_avg:131.66ms +step:4150/20000 train_loss:2.0500 train_time:546221ms step_avg:131.62ms +step:4160/20000 train_loss:1.9947 train_time:547378ms step_avg:131.58ms +step:4170/20000 train_loss:1.9284 train_time:548524ms step_avg:131.54ms +step:4180/20000 train_loss:1.9951 train_time:549654ms step_avg:131.50ms +step:4190/20000 train_loss:2.0682 train_time:550784ms step_avg:131.45ms +step:4200/20000 train_loss:2.0261 train_time:552046ms step_avg:131.44ms +step:4210/20000 train_loss:1.9748 train_time:553203ms step_avg:131.40ms +step:4220/20000 train_loss:1.9507 train_time:554332ms step_avg:131.36ms +step:4230/20000 train_loss:1.8089 train_time:555461ms 
step_avg:131.31ms +step:4240/20000 train_loss:1.9999 train_time:556591ms step_avg:131.27ms +step:4250/20000 train_loss:1.9946 train_time:557759ms step_avg:131.24ms +step:4260/20000 train_loss:2.0113 train_time:559085ms step_avg:131.24ms +step:4270/20000 train_loss:1.9907 train_time:560384ms step_avg:131.24ms +step:4280/20000 train_loss:2.0199 train_time:561684ms step_avg:131.23ms +step:4290/20000 train_loss:1.9337 train_time:562985ms step_avg:131.23ms +step:4300/20000 train_loss:1.9795 train_time:564282ms step_avg:131.23ms +step:4310/20000 train_loss:1.9398 train_time:565597ms step_avg:131.23ms +step:4320/20000 train_loss:2.1163 train_time:566891ms step_avg:131.22ms +step:4330/20000 train_loss:2.0011 train_time:568292ms step_avg:131.25ms +step:4340/20000 train_loss:2.0117 train_time:569521ms step_avg:131.23ms +step:4350/20000 train_loss:1.8956 train_time:570699ms step_avg:131.20ms +step:4360/20000 train_loss:1.9639 train_time:571894ms step_avg:131.17ms +step:4370/20000 train_loss:2.0074 train_time:573032ms step_avg:131.13ms +step:4380/20000 train_loss:1.8752 train_time:574169ms step_avg:131.09ms +step:4390/20000 train_loss:2.0619 train_time:575319ms step_avg:131.05ms +step:4400/20000 train_loss:2.0275 train_time:576456ms step_avg:131.01ms +step:4410/20000 train_loss:1.9045 train_time:577629ms step_avg:130.98ms +step:4420/20000 train_loss:1.8925 train_time:578760ms step_avg:130.94ms +step:4430/20000 train_loss:2.0232 train_time:579891ms step_avg:130.90ms +step:4440/20000 train_loss:2.0123 train_time:581026ms step_avg:130.86ms +step:4450/20000 train_loss:1.9080 train_time:582158ms step_avg:130.82ms +step:4460/20000 train_loss:1.8835 train_time:583422ms step_avg:130.81ms +step:4470/20000 train_loss:2.0703 train_time:584553ms step_avg:130.77ms +step:4480/20000 train_loss:1.8280 train_time:585772ms step_avg:130.75ms +step:4490/20000 train_loss:1.8831 train_time:587050ms step_avg:130.75ms +step:4500/20000 train_loss:1.9692 train_time:588315ms step_avg:130.74ms 
+step:4510/20000 train_loss:2.0228 train_time:589632ms step_avg:130.74ms +step:4520/20000 train_loss:1.9199 train_time:590897ms step_avg:130.73ms +step:4530/20000 train_loss:1.8477 train_time:592163ms step_avg:130.72ms +step:4540/20000 train_loss:1.9657 train_time:593312ms step_avg:130.69ms +step:4550/20000 train_loss:1.8784 train_time:594443ms step_avg:130.65ms +step:4560/20000 train_loss:1.9013 train_time:595629ms step_avg:130.62ms +step:4570/20000 train_loss:1.9556 train_time:596764ms step_avg:130.58ms +step:4580/20000 train_loss:2.0113 train_time:598010ms step_avg:130.57ms +step:4590/20000 train_loss:1.9622 train_time:599142ms step_avg:130.53ms +step:4598/20000 val_loss:1.9768 val_bpb:1.1708 train_time:600050ms step_avg:130.50ms +stopping_early: wallclock_cap train_time:600050ms step:4598/20000 +peak memory allocated: 22052 MiB reserved: 22132 MiB +ema:applying EMA weights +DIAGNOSTIC post_ema val_loss:1.9734 val_bpb:1.1687 eval_time:2105ms +EMA improved val_bpb (1.1708 -> 1.1687), keeping EMA weights for final serialization +Serialized model: 99922207 bytes +Code size: 88321 bytes +Total submission size: 100010528 bytes +Serialized model int6+zstd: 14366226 bytes +Total submission size int6+zstd: 14454547 bytes +Total submission size int8+zlib: 14454547 bytes +final_int6_zlib_roundtrip val_loss:1.9917 val_bpb:1.1796 eval_time:6520ms +final_int6_zlib_roundtrip_exact val_loss:1.99171974 val_bpb:1.17960862 +final_int6_sliding_window val_loss:1.9515 val_bpb:1.1558 stride:64 eval_time:79684ms +final_int6_sliding_window_exact val_loss:1.95153499 val_bpb:1.15581201 +final_int8_zlib_roundtrip_exact val_loss:1.95153499 val_bpb:1.15581201 +ttt_sliding:start chunks=1893 chunk_tokens=32768 total_windows=969088 stride=64 ttt_lr=0.002 ttt_epochs=3 freeze_blocks=0 +ttt_sliding:params unfrozen=25402880 frozen=0 + ttt_chunk [1/1893] bpb=1.186077 time=0.6s + ttt_chunk [11/1893] bpb=1.428367 time=3.8s + ttt_chunk [21/1893] bpb=1.410195 time=7.9s + ttt_chunk [31/1893] 
bpb=1.395789 time=11.9s + ttt_chunk [41/1893] bpb=1.373849 time=15.8s + ttt_chunk [51/1893] bpb=1.365005 time=19.0s + ttt_chunk [61/1893] bpb=1.367875 time=22.3s + ttt_chunk [71/1893] bpb=1.359849 time=25.5s + ttt_chunk [81/1893] bpb=1.355835 time=28.8s + ttt_chunk [91/1893] bpb=1.352681 time=32.0s + ttt_chunk [101/1893] bpb=1.353660 time=35.2s + ttt_chunk [111/1893] bpb=1.352968 time=38.5s diff --git a/records/track_10min_16mb/2026-04-01_random_mlp/run.sh b/records/track_10min_16mb/2026-04-01_random_mlp/run.sh new file mode 100755 index 0000000000..9083b6aad5 --- /dev/null +++ b/records/track_10min_16mb/2026-04-01_random_mlp/run.sh @@ -0,0 +1,20 @@ +RUN_ID=baseline_no_moe_xsa4_ve3_12l_5r7l \ +ITERATIONS=20000 \ +TRAIN_BATCH_TOKENS=786432 \ +TRAIN_SEQ_LEN=2048 \ +VAL_LOSS_EVERY=0 \ +VAL_BATCH_SIZE=786432 \ +WARMDOWN_ITERS=3500 \ +NUM_LAYERS=12 \ +TRAIN_LOG_EVERY=10 \ +MLP_MULT=3 \ +RAND_PROJ_LAYERS="0,1,2,3,4" \ +RAND_GAIN=1 \ +RAND_INIT_QR=1 \ +MINI_MOE_EXPERTS=1 \ +VE_LAYERS="9,10,11" \ +VE_DIM=128 \ +XSA_LAYERS="8,9,10,11" \ +BIGRAM_VOCAB_SIZE=2048 \ +SEED=1337 \ +torchrun --standalone --nproc_per_node=8 train_gpt.py \ No newline at end of file diff --git a/records/track_10min_16mb/2026-04-01_random_mlp/train_gpt.py b/records/track_10min_16mb/2026-04-01_random_mlp/train_gpt.py new file mode 100644 index 0000000000..091c12005e --- /dev/null +++ b/records/track_10min_16mb/2026-04-01_random_mlp/train_gpt.py @@ -0,0 +1,2006 @@ +from __future__ import annotations + +import copy +import glob +import io +import math +import os +import random +import subprocess +import sys +import time +import uuid +import zlib +from pathlib import Path + +try: + import zstandard + _COMPRESSOR = "zstd" +except ImportError: + _COMPRESSOR = "zlib" + +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP + +from 
flash_attn_interface import flash_attn_func +# from flash_attn import flash_attn_func + +# make dynamo less complainy +import torch._dynamo +torch._dynamo.config.cache_size_limit = 64 + +# ----------------------------- +# HYPERPARAMETERS +# ----------------------------- +# Default Simple Baseline run: +# - 9 transformer blocks at width 512 +# - 8 attention heads with 4 KV heads (GQA) and 2x MLP expansion +# - vocab size 1024, sequence length 1024, tied embeddings +# - 524,288 train tokens per step for 20,000 iterations with a ~10 minute cap + +class Hyperparameters: + # Data paths are shard globs produced by the existing preprocessing pipeline. + data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + seed = int(os.environ.get("SEED", 1337)) + + # Validation cadence and batch size. Validation always uses the full fineweb_val split. + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 200)) + + # Training length. + iterations = int(os.environ.get("ITERATIONS", 20000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 1200)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024)) + eval_seq_len = int(os.environ.get("EVAL_SEQ_LEN", 2048)) + eval_stride = int(os.environ.get("EVAL_STRIDE", 64)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + + # Model shape. 
+ vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers = int(os.environ.get("NUM_LAYERS", 9)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = int(os.environ.get("MLP_MULT", 2)) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + + # random projection control + rand_proj_layers = [int(x) for x in os.environ.get("RAND_PROJ_LAYERS", "").split(",") if x] + rand_gain = bool(int(os.environ.get("RAND_GAIN", "0"))) + mini_moe_experts = int(os.environ.get("MINI_MOE_EXPERTS", 1)) + rand_init_qr = bool(int(os.environ.get("RAND_INIT_QR", "1"))) + + # xsa control + xsa_layers = [int(x) for x in os.environ.get("XSA_LAYERS", "").split(",") if x] + + # quant + late_qat_threshold = float(os.environ.get("LATE_QAT_THRESHOLD", 0.15)) + + # partial rope + rope_dims = int(os.environ.get("ROPE_DIMS", 16)) + + # ln scale + ln_scale = bool(int(os.environ.get("LN_SCALE", "1"))) + + # bigram control + bigram_vocab_size = int(os.environ.get("BIGRAM_VOCAB_SIZE", 2048)) + bigram_dim = int(os.environ.get("BIGRAM_DIM", 128)) + + # value embedding control + ve_dim = int(os.environ.get("VE_DIM", 128)) + ve_layers = [int(x) for x in os.environ.get("VE_LAYERS", "9,10").split(",") if x] + + # ttt + ttt_enabled = bool(int(os.environ.get("TTT_ENABLED", "1"))) + ttt_lr = float(os.environ.get("TTT_LR", 0.002)) + ttt_epochs = int(os.environ.get("TTT_EPOCHS", 3)) + ttt_chunk_tokens = int(os.environ.get("TTT_CHUNK_TOKENS", 32768)) + ttt_freeze_blocks = int(os.environ.get("TTT_FREEZE_BLOCKS", 0)) + ttt_momentum = float(os.environ.get("TTT_MOMENTUM", 0.9)) + ttt_batch_seqs = int(os.environ.get("TTT_BATCH_SEQS", 32)) + ttt_grad_clip = float(os.environ.get("TTT_GRAD_CLIP", 1.0)) + + # swa + swa_enabled = bool(int(os.environ.get("SWA_ENABLED", 
"1"))) + swa_every = int(os.environ.get("SWA_EVERY", 50)) + + # Optimizer hyperparameters. + embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.035)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.025)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.025)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.99)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.92)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 1500)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.3)) + muon_wd = float(os.environ.get("MUON_WD", 0.04)) + adam_wd = float(os.environ.get("ADAM_WD", 0.04)) + +# ----------------------------- +# MUON OPTIMIZER +# ----------------------------- +# +# As borrowed from modded-nanogpt +# Background on Muon: https://kellerjordan.github.io/posts/muon/ + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + # Orthogonalize a 2D update matrix with a fast Newton-Schulz iteration. + # Muon uses this to normalize matrix-shaped gradients before applying them. + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X + + +class Muon(torch.optim.Optimizer): + """Parallel Muon: post-backward reduce-scatter -> local NS5 -> all-gather. + + No DDP for bank params. After backward, this optimizer: + 1. 
Launches async reduce-scatter for all banks (biggest first) + 2. Returns control so Adam can step on small params while RS is in-flight + 3. Waits for each RS, runs local NS5 on the shard, launches async all-gather + 4. Each all-gather overlaps with next bank's NS5 + """ + def __init__(self, params, lr: float, momentum: float, backend_steps: int, + nesterov: bool = True, weight_decay: float = 0.0): + super().__init__( + params, + dict(lr=lr, momentum=momentum, backend_steps=backend_steps, + nesterov=nesterov, weight_decay=weight_decay), + ) + self._built = False + + def _build(self): + self._distributed = dist.is_available() and dist.is_initialized() + self._world_size = dist.get_world_size() if self._distributed else 1 + self._rank = dist.get_rank() if self._distributed else 0 + ws = self._world_size + + self._bank_meta = [] + for group in self.param_groups: + for p in group["params"]: + B = p.shape[0] + padded_B = ((B + ws - 1) // ws) * ws + shard_B = padded_B // ws + tail = p.shape[1:] + dev = p.device + self._bank_meta.append({ + 'p': p, + 'B': B, + 'padded_grad': torch.zeros(padded_B, *tail, device=dev, dtype=torch.bfloat16), + 'shard': torch.zeros(shard_B, *tail, device=dev, dtype=torch.bfloat16), + 'shard_mom': torch.zeros(shard_B, *tail, device=dev, dtype=torch.bfloat16), + 'full_update': torch.zeros(padded_B, *tail, device=dev, dtype=torch.bfloat16), + 'scale': max(1, p.shape[-2] / p.shape[-1]) ** 0.5, + }) + # Sort by size descending -- launch biggest reduce-scatters first + self._bank_meta.sort(key=lambda m: -m['p'].numel()) + self._built = True + + def launch_reduce_scatters(self): + """Phase 1: launch async reduce-scatter for all banks. 
Call right after backward.""" + if not self._built: + self._build() + if not self._distributed: + return + self._rs_futures = [] + for m in self._bank_meta: + p = m['p'] + if p.grad is None: + self._rs_futures.append(None) + continue + pg = m['padded_grad'] + pg[:m['B']].copy_(p.grad.bfloat16()) + if pg.shape[0] > m['B']: + pg[m['B']:].zero_() + fut = dist.reduce_scatter_tensor(m['shard'], pg, op=dist.ReduceOp.AVG, async_op=True) + self._rs_futures.append(fut) + + @torch.no_grad() + def step(self, closure=None): + """Phase 3: wait for RS, local NS5, all-gather. Call AFTER Adam steps.""" + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + if not self._built: + self._build() + + for group in self.param_groups: + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + wd = group.get("weight_decay", 0.0) + + prev_ag_handle = None + prev_m = None + + sharded = self._distributed and hasattr(self, '_rs_futures') + + for i, m in enumerate(self._bank_meta): + p = m['p'] + if p.grad is None: + continue + + if prev_ag_handle is not None: + prev_ag_handle.wait() + pp = prev_m['p'] + upd = prev_m['full_update'][:prev_m['B']] + if wd > 0.0: + pp.data.mul_(1.0 - lr * wd) + pp.add_(upd.to(dtype=pp.dtype), alpha=-lr * prev_m['scale']) + + if sharded and self._rs_futures[i] is not None: + self._rs_futures[i].wait() + g = m['shard'] + buf = m['shard_mom'] + else: + g = p.grad.bfloat16() + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + + buf.mul_(momentum).add_(g) + if nesterov: + update = g.add(buf, alpha=momentum) + else: + update = buf + + update = zeropower_via_newtonschulz5(update, steps=backend_steps) + + if sharded: + prev_ag_handle = dist.all_gather_into_tensor( + m['full_update'], update, async_op=True) + prev_m = m + else: + if wd > 0.0: + p.data.mul_(1.0 - lr * wd) + 
p.add_(update.to(dtype=p.dtype), alpha=-lr * m['scale'])
+
+ if prev_ag_handle is not None:
+ prev_ag_handle.wait()
+ pp = prev_m['p']
+ upd = prev_m['full_update'][:prev_m['B']]
+ if wd > 0.0:
+ pp.data.mul_(1.0 - lr * wd)
+ pp.add_(upd.to(dtype=pp.dtype), alpha=-lr * prev_m['scale'])
+
+ if hasattr(self, '_rs_futures'):
+ del self._rs_futures
+
+ return loss
+
+# -----------------------------
+# TOKENIZER-AGNOSTIC EVALUATION SETUP
+# -----------------------------
+#
+# It's common for small models to have a large fraction of their parameters in embeddings, since the 2 * d_model * d_vocab embedding and unembedding parameters can be gigantic.
+# Instead of locking the tokenizer, we let you bring your own and calculate our validation metrics on the average compression of the validation set.
+# We calculate BPB (bits-per-byte) instead of validation loss, so we need lookup tables that count how many bytes each token in the tokenizer decodes to.
+# Note: Submissions that edit the tokenizer will be examined more carefully, since screwing this up might unjustly improve your score.
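The BPB arithmetic described above reduces to two factors: the mean per-token loss converted from nats to bits, times the number of tokens per byte of raw text. A minimal sketch (the helper name is illustrative, not part of this script):

```python
import math

def bits_per_byte(nll_sum_nats: float, num_tokens: int, num_bytes: int) -> float:
    # nats -> bits for the mean per-token cross-entropy...
    bits_per_token = (nll_sum_nats / num_tokens) / math.log(2.0)
    # ...then re-weight by tokens per byte, making the metric tokenizer-agnostic
    return bits_per_token * (num_tokens / num_bytes)

# 1 bit per token at 2 tokens per byte -> 2 bits per byte
example = bits_per_byte(math.log(2.0) * 100, 100, 50)
```

This is the same quantity `eval_val` returns: a tokenizer with longer pieces only lowers BPB if its per-token loss does not rise proportionally.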
+ +def build_sentencepiece_luts( + sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device +) -> tuple[Tensor, Tensor, Tensor]: + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("▁"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) + + +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*. 
+ tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] + + +def eval_val( + args: Hyperparameters, + model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + grad_accum_steps: int, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, +) -> tuple[float, float]: + # Validation computes two metrics: + # - val_loss: token cross-entropy (natural log) + # - val_bpb: tokenizer-agnostic compression metric used by the challenge + local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) + if local_batch_tokens < args.train_seq_len: + raise ValueError( + "VAL_BATCH_SIZE must provide at least one sequence per rank; " + f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " + f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}" + ) + local_batch_seqs = local_batch_tokens // args.train_seq_len + total_seqs = (val_tokens.numel() - 1) // args.train_seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + + model.eval() + with torch.inference_mode(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * args.train_seq_len + raw_end = batch_seq_end * args.train_seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, args.train_seq_len) + y = local[1:].reshape(-1, args.train_seq_len) + with 
torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True):
+ batch_loss = model(x, y).detach()
+ batch_token_count = float(y.numel())
+ val_loss_sum += batch_loss.to(torch.float64) * batch_token_count
+ val_token_count += batch_token_count
+ prev_ids = x.reshape(-1)
+ tgt_ids = y.reshape(-1)
+ token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16)
+ token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16)
+ val_byte_count += token_bytes.to(torch.float64).sum()
+
+ if dist.is_available() and dist.is_initialized():
+ dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM)
+ dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM)
+ dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM)
+
+ val_loss = val_loss_sum / val_token_count
+ bits_per_token = val_loss.item() / math.log(2.0)
+ tokens_per_byte = val_token_count.item() / val_byte_count.item()
+ model.train()
+ return float(val_loss.item()), float(bits_per_token * tokens_per_byte)
+
+# -----------------------------
+# POST-TRAINING QUANTIZATION
+# -----------------------------
+#
+# It's silly to export our model, which is trained in bf16 and fp32, at that same precision.
+# Instead, we get approximately the same model (with a small accuracy hit) by quantizing it to int8 and zlib-compressing the result.
+# We can then decompress the model and run it at higher precision for evaluation, after coming in under the size limit.
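The quantize-compress-decompress round trip can be demonstrated end to end. A small numpy sketch under the same per-row symmetric int8 scheme (helper name and shapes are illustrative; the real export path in `quantize_float_tensor` additionally clips at a high percentile and stores scales in fp16):

```python
import zlib
import numpy as np

def int8_roundtrip(w: np.ndarray) -> tuple[np.ndarray, int, int]:
    # One scale per output row, symmetric around zero.
    scale = np.maximum(np.abs(w).max(axis=1, keepdims=True), 1e-8) / 127.0
    q = np.clip(np.round(w / scale), -127, 127).astype(np.int8)
    # The int8 payload is what actually gets zlib-compressed in the artifact.
    raw = q.tobytes()
    packed = zlib.compress(raw, level=9)
    # Load path: decompress, then dequantize back to float for evaluation.
    q2 = np.frombuffer(zlib.decompress(packed), dtype=np.int8).reshape(w.shape)
    return q2.astype(np.float32) * scale, len(raw), len(packed)

rng = np.random.default_rng(0)
w = rng.standard_normal((256, 128)).astype(np.float32)
w_hat, raw_len, packed_len = int8_roundtrip(w)
```

With a row-max scale the largest entry of each row maps exactly to ±127, so all error here is rounding error; percentile clipping trades a little clipping error for finer resolution on the bulk of the row.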
+ +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights", + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 + +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) + +def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: + if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): + return t.float().contiguous() + if t.dtype in {torch.float32, torch.bfloat16}: + passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") + return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() + return t + +def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + # Matrices get one scale per row, which usually tracks output-channel + # ranges much better than a single tensor-wide scale. + clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) + q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() + return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + + # Vectors / scalars use a simpler per-tensor scale. 
+ clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() + return q, scale + +def quantize_state_dict_int8(state_dict: dict[str, Tensor]): + # Single supported clean-script export format: + # - per-row int8 for 2D float tensors + # - per-tensor int8 for other float tensors + # - exact passthrough for non-floats + # - passthrough for small float tensors, stored as fp16 to save bytes + quantized: dict[str, Tensor] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + qmeta: dict[str, dict[str, object]] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + + # Small float tensors are cheap enough to keep directly. We still downcast + # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size. 
+ if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + + stats["num_float_tensors"] += 1 + q, s = quantize_float_tensor(t) + if s.ndim > 0: + qmeta[name] = {"scheme": "per_row", "axis": 0} + quantized[name] = q + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) + + obj: dict[str, object] = { + "__quant_format__": "int8_clean_per_row_v1", + "quantized": quantized, + "scales": scales, + "dtypes": dtypes, + "passthrough": passthrough, + } + if qmeta: + obj["qmeta"] = qmeta + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats + +def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + qmeta = obj.get("qmeta", {}) + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, q in obj["quantized"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + s = obj["scales"][name] + if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: + s = s.to(dtype=torch.float32) + # Broadcast the saved row scale back across trailing dimensions. + out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() + else: + scale = float(s.item()) + out[name] = (q.float() * scale).to(dtype=dtype).contiguous() + for name, t in obj["passthrough"].items(): + # Restore small tensors, undoing the temporary fp16 storage cast if needed. 
+ out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out + + +# --- sliding window eval --- + +def eval_val_sliding( + args: Hyperparameters, + base_model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, + stride: int, + batch_seqs: int = 32, + eval_seq_len: int | None = None, +) -> tuple[float, float]: + """Sliding window evaluation: each token scored with maximum context.""" + seq_len = eval_seq_len or args.train_seq_len + total_tokens = val_tokens.numel() - 1 + window_starts = [ws for ws in range(0, total_tokens, stride) + if min(ws + seq_len, total_tokens) - ws >= 1] + total_windows = len(window_starts) + my_s = (total_windows * rank) // world_size + my_e = (total_windows * (rank + 1)) // world_size + my_windows = window_starts[my_s:my_e] + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + base_model.eval() + compiled_logits = torch.compile(base_model.forward_logits, dynamic=False, fullgraph=True) + with torch.inference_mode(): + for bi in range(0, len(my_windows), batch_seqs): + batch_ws = my_windows[bi:bi + batch_seqs] + bsz = len(batch_ws) + x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + wlens: list[int] = [] + for i, ws in enumerate(batch_ws): + end = min(ws + seq_len, total_tokens) + wlen = end - ws + wlens.append(wlen) + chunk = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) + x_batch[i, :wlen] = chunk[:-1] + y_batch[i, :wlen] = chunk[1:] + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + 
logits = compiled_logits(x_batch) + nll = F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), + y_batch.reshape(-1), + reduction="none", + ).reshape(bsz, seq_len) + for i, ws in enumerate(batch_ws): + wlen = wlens[i] + s = 0 if ws == 0 else max(wlen - stride, 0) + scored_nll = nll[i, s:wlen].to(torch.float64) + loss_sum += scored_nll.sum() + token_count += float(wlen - s) + tgt = y_batch[i, s:wlen] + prev = x_batch[i, s:wlen] + tb = base_bytes_lut[tgt].to(torch.float64) + tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64) + byte_count += tb.sum() + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) + val_loss = (loss_sum / token_count).item() + bits_per_token = val_loss / math.log(2.0) + tokens_per_byte = token_count.item() / byte_count.item() + base_model.train() + return val_loss, bits_per_token * tokens_per_byte + + +def eval_val_sliding_ttt( + args: Hyperparameters, base_model: nn.Module, rank: int, world_size: int, + device: torch.device, val_tokens: Tensor, base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, is_boundary_token_lut: Tensor, + stride: int, batch_seqs: int = 32, log0=print, +) -> tuple[float, float]: + """Legal score-first TTT (PR #461 recipe): score each chunk with sliding windows, + then train on it. 
Every token scored BEFORE any update that could use it.""" + seq_len = args.train_seq_len + total_tokens = val_tokens.numel() - 1 + ttt_chunk = args.ttt_chunk_tokens + + # Pre-compute all window starts + window_starts = [ws for ws in range(0, total_tokens, stride) + if min(ws + seq_len, total_tokens) - ws >= stride or ws == 0] + + # Assign each window to a chunk based on the first token it scores + num_chunks = (total_tokens + ttt_chunk - 1) // ttt_chunk + chunk_windows: list[list[int]] = [[] for _ in range(num_chunks)] + for ws in window_starts: + end = min(ws + seq_len, total_tokens) + wlen = end - ws + s = 0 if ws == 0 else max(wlen - stride, 0) + scored_start = ws + s + ci = min(scored_start // ttt_chunk, num_chunks - 1) + chunk_windows[ci].append(ws) + + log0(f"ttt_sliding:start chunks={num_chunks} chunk_tokens={ttt_chunk} " + f"total_windows={len(window_starts)} stride={stride} " + f"ttt_lr={args.ttt_lr} ttt_epochs={args.ttt_epochs} " + f"freeze_blocks={args.ttt_freeze_blocks}") + + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + + # Freeze first N blocks + frozen_block_ids = set(range(min(args.ttt_freeze_blocks, len(base_model.blocks)))) + # ttt_params = [] + # for name, p in base_model.named_parameters(): + # freeze = False + # for bi in frozen_block_ids: + # if f"blocks.{bi}." in name: + # freeze = True + # break + # if freeze: + # p.requires_grad_(False) + # else: + # p.requires_grad_(True) + # ttt_params.append(p) + ttt_params = [ + p for name, p in base_model.named_parameters() + if p.ndim == 2 and not any(pat in name for pat in CONTROL_TENSOR_NAME_PATTERNS) + and not any(f"blocks.{bi}." 
in name for bi in frozen_block_ids)] + + log0(f"ttt_sliding:params unfrozen={sum(p.numel() for p in ttt_params)} " + f"frozen={sum(p.numel() for p in base_model.parameters() if not p.requires_grad)}") + + # optimizer = torch.optim.SGD(ttt_params, lr=args.ttt_lr, momentum=args.ttt_momentum) + optimizer = Muon(ttt_params, lr=args.ttt_lr, momentum=0.0, backend_steps=5, weight_decay=0.0) + t0 = time.perf_counter() + + for ci in range(num_chunks): + windows = chunk_windows[ci] + if not windows: + continue + chunk_start = ci * ttt_chunk + chunk_end = min((ci + 1) * ttt_chunk, total_tokens) + + # --- Phase 1: SCORE this chunk's windows (inference_mode) --- + my_s = (len(windows) * rank) // world_size + my_e = (len(windows) * (rank + 1)) // world_size + my_windows = windows[my_s:my_e] + + base_model.eval() + with torch.inference_mode(): + for bi in range(0, len(my_windows), batch_seqs): + batch_ws = my_windows[bi:bi + batch_seqs] + bsz = len(batch_ws) + x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + wlens: list[int] = [] + for i, ws in enumerate(batch_ws): + end = min(ws + seq_len, total_tokens) + wlen = end - ws + wlens.append(wlen) + chunk_tok = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) + x_batch[i, :wlen] = chunk_tok[:-1] + y_batch[i, :wlen] = chunk_tok[1:] + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + logits = base_model.forward_logits(x_batch) + nll = F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), + y_batch.reshape(-1), reduction="none", + ).reshape(bsz, seq_len) + for i, ws in enumerate(batch_ws): + wlen = wlens[i] + s = 0 if ws == 0 else max(wlen - stride, 0) + scored_nll = nll[i, s:wlen].to(torch.float64) + loss_sum += scored_nll.sum() + token_count += float(wlen - s) + tgt, prev = y_batch[i, s:wlen], x_batch[i, s:wlen] + tb = base_bytes_lut[tgt].to(torch.float64) + tb += (has_leading_space_lut[tgt] & 
~is_boundary_token_lut[prev]).to(torch.float64) + byte_count += tb.sum() + + # workaround for caching issues with sin/cos in rotary layers: + for block in base_model.blocks: + block.attn.rotary._cos_cached = None + block.attn.rotary._sin_cached = None + + # --- Phase 2: TRAIN on this chunk (already scored = legal) --- + is_last_chunk = (ci == num_chunks - 1) + if not is_last_chunk and args.ttt_epochs > 0: + base_model.train() + chunk_seqs = (chunk_end - chunk_start) // seq_len + if chunk_seqs > 0: + cos_lr = args.ttt_lr * 0.5 * (1.0 + math.cos(math.pi * ci / max(num_chunks - 1, 1))) + for pg in optimizer.param_groups: + pg['lr'] = cos_lr + my_seq_s = (chunk_seqs * rank) // world_size + my_seq_e = (chunk_seqs * (rank + 1)) // world_size + my_chunk_seqs = my_seq_e - my_seq_s + for _ep in range(args.ttt_epochs): + for bs in range(0, my_chunk_seqs, args.ttt_batch_seqs): + be = min(bs + args.ttt_batch_seqs, my_chunk_seqs) + actual_bs = my_seq_s + bs + start_tok = chunk_start + actual_bs * seq_len + end_tok = chunk_start + (my_seq_s + be) * seq_len + 1 + if end_tok > val_tokens.numel(): + continue + local = val_tokens[start_tok:end_tok].to(device=device, dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + optimizer.zero_grad(set_to_none=True) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + loss = base_model(x, y) + loss.backward() + if world_size > 1: + for p in ttt_params: + if p.grad is not None: + dist.all_reduce(p.grad, op=dist.ReduceOp.AVG) + torch.nn.utils.clip_grad_norm_(ttt_params, args.ttt_grad_clip) + optimizer.step() + + if rank == 0 and (ci % 10 == 0 or ci == num_chunks - 1): + elapsed = time.perf_counter() - t0 + rl = loss_sum.item() / max(token_count.item(), 1) + rbpb = rl / math.log(2.0) * (token_count.item() / max(byte_count.item(), 1)) if token_count.item() > 0 else 0.0 + log0(f" ttt_chunk [{ci+1}/{num_chunks}] bpb={rbpb:.6f} time={elapsed:.1f}s") + + if dist.is_available() and 
dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) + + val_loss = (loss_sum / token_count).item() + val_bpb = val_loss / math.log(2.0) * (token_count.item() / byte_count.item()) + + for p in base_model.parameters(): + p.requires_grad_(True) + base_model.eval() + + log0(f"ttt_sliding:done val_loss={val_loss:.6f} val_bpb={val_bpb:.6f} " + f"elapsed={time.perf_counter() - t0:.1f}s") + return val_loss, val_bpb + +# --- int6 quant as per #414 --- +def _classify_param(name: str) -> str: + if "tok_emb" in name or "lm_head" in name: + return "embed" + # TODO: not all mlp params likely want int6, e.g. scales and gates + if ".mlp." in name: + return "mlp" + if ".attn." in name or (".proj." in name and ".mlp." not in name): + return "attn" + return "other" + +def quantize_int6_per_row(t: Tensor, clip_range: int = 31) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + best_q, best_s, best_err = None, None, float('inf') + for pct in [0.9990, 0.9995, 0.9999, 0.99999, 1.0]: + if pct < 1.0: + row_clip = torch.quantile(t32.abs(), pct, dim=1) + else: + row_clip = t32.abs().amax(dim=1) + s = (row_clip / clip_range).clamp_min(1.0 / clip_range).to(torch.float16) + q = torch.clamp(torch.round(t32 / s.float()[:, None]), -clip_range, clip_range).to(torch.int8) + recon = q.float() * s.float()[:, None] + err = (t32 - recon).pow(2).mean().item() + if err < best_err: + best_q, best_s, best_err = q, s, err + return best_q, best_s + amax = t32.abs().max().item() + scale = torch.tensor(amax / clip_range if amax > 0 else 1.0, dtype=torch.float16) + q = torch.clamp(torch.round(t32 / scale.float()), -clip_range, clip_range).to(torch.int8) + return q, scale + +def mixed_quantize_int6( + state_dict: dict[str, Tensor], + int6_cats: set[str] +): + num_layers_total = max( + (int(k.split(".")[1]) for k in state_dict if k.startswith("blocks.")), + default=0, + ) + 1 
+ late_k_layers = set(range(num_layers_total - 2, num_layers_total))
+ result: dict[str, Tensor] = {}
+ meta: dict[str, object] = {}
+ for name, tensor in state_dict.items():
+ t = tensor.detach().cpu().contiguous()
+ cat = _classify_param(name)
+ if not t.is_floating_point() or t.numel() <= 65536:
+ result[name] = t.to(torch.float16) if t.is_floating_point() else t
+ meta[name] = "passthrough"
+ continue
+ if any(p in name for p in CONTROL_TENSOR_NAME_PATTERNS):
+ result[name] = t.float()
+ meta[name] = "passthrough_ctrl"
+ continue
+ if cat in int6_cats and t.ndim >= 1:
+ q, s = quantize_int6_per_row(t)
+ result[name + ".q"] = q
+ result[name + ".scale"] = s
+ meta[name] = {"type": "int6"}
+ else:
+ q, s = quantize_float_tensor(t)
+ result[name + ".q"] = q
+ result[name + ".scale"] = s
+ meta[name] = {"type": "int8"}
+ return result, meta
+
+def dequantize_mixed_int6(
+ result: dict[str, Tensor],
+ meta: dict[str, object],
+ template_sd: dict[str, Tensor]
+) -> dict[str, Tensor]:
+ out: dict[str, Tensor] = {}
+ for name, orig in template_sd.items():
+ info = meta.get(name)
+ if info is None:
+ continue
+ orig_dtype = orig.dtype
+ if info in ("passthrough", "passthrough_ctrl", "passthrough_fp16"):
+ t = result[name]
+ if t.dtype == torch.float16 and orig_dtype in (torch.float32, torch.bfloat16):
+ t = t.to(orig_dtype)
+ out[name] = t
+ continue
+ q, s = result[name + ".q"], result[name + ".scale"]
+ if s.ndim > 0:
+ out[name] = (q.float() * s.float().view(q.shape[0], *([1] * (q.ndim - 1)))).to(orig_dtype)
+ else:
+ out[name] = (q.float() * float(s.item())).to(orig_dtype)
+ return out
+
+# -----------------------------
+# DATA LOADING
+# -----------------------------
+
+def load_data_shard(file: Path) -> Tensor:
+ header_bytes = 256 * np.dtype("<i4").itemsize
+ with file.open("rb") as f:
+ f.seek(header_bytes) # skip the 256-int32 shard header
+ tokens = np.frombuffer(f.read(), dtype="<u2")
+ return torch.from_numpy(tokens.copy())
+
+
+class TokenStream:
+ # Wraps the sorted shard files as one endless token stream.
+ def __init__(self, pattern: str):
+ self.files = [Path(p) for p in sorted(glob.glob(pattern))]
+ if not self.files:
+ raise FileNotFoundError(f"No files found for pattern: {pattern}")
+ self.file_idx = 0
+ self.tokens = load_data_shard(self.files[0])
+ self.pos = 0
+
+ def _advance_file(self) -> None:
+ self.file_idx = (self.file_idx + 1) % len(self.files)
+ self.tokens = load_data_shard(self.files[self.file_idx])
+ self.pos = 0
+
+ def take(self, n: int) -> Tensor:
+ chunks: list[Tensor] = []
+ remaining = n
+ 
while remaining > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: + self._advance_file() + continue + k = min(remaining, avail) + chunks.append(self.tokens[self.pos : self.pos + k]) + self.pos += k + remaining -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) + + +class DistributedTokenLoader: + # Each call consumes a contiguous chunk from the shared token stream, then slices out + # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting. + def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank = rank + self.world_size = world_size + self.device = device + self.stream = TokenStream(pattern) + + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + chunk = self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start : start + per_rank_span].to(dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + +# ----------------------------- +# TRANSFORMER MODULES +# ----------------------------- + +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + + +class CastedLinear(nn.Linear): + # straight-through qat as per #137 + qat_ste: bool = False + def forward(self, x: Tensor) -> Tensor: + w = self.weight + if self.qat_ste and self.training: + # STE fake int6: quantize to [-32,31] and dequantize, gradient flows through + w32 = w.float() + row_max = w32.abs().amax(dim=1).clamp_min(1e-8) + scale = row_max / 31.0 + w_q = torch.clamp(torch.round(w32 / scale[:, None]), -32, 31) * scale[:, None] + w = w + 
(w_q.to(w.dtype) - w).detach() + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, w.to(x.dtype), bias) + + +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + # Keep small/control parameters in fp32 even when the model body runs in bf16. + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() + + +class Rotary(nn.Module): + def __init__(self, dim: int, base: float = 10000.0, train_seq_len: int = 1024, rope_dims: int = 0): + super().__init__() + self.dim = dim + self.base = base + self.train_seq_len = train_seq_len + self.rope_dims = rope_dims if rope_dims > 0 else dim + inv_freq = 1.0 / (base ** (torch.arange(0, self.rope_dims, 2, dtype=torch.float32) / self.rope_dims)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached != seq_len + or self._cos_cached.device != device + ): + rd = self.rope_dims + if seq_len > self.train_seq_len: + scale = seq_len / self.train_seq_len + new_base = self.base * (scale ** (rd / (rd - 2))) + inv_freq = 1.0 / (new_base ** (torch.arange(0, rd, 2, dtype=torch.float32, device=device) / rd)) + else: + inv_freq = self.inv_freq.to(device) + t = torch.arange(seq_len, device=device, dtype=inv_freq.dtype) + freqs = torch.outer(t, inv_freq) + self._cos_cached = freqs.cos()[None, :, None, :] + self._sin_cached = freqs.sin()[None, :, None, :] + self._seq_len_cached = seq_len + return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) + +# using partial RoPE as per #414 +def apply_rotary_emb(x: Tensor, cos: 
Tensor, sin: Tensor, rope_dims: int = 0) -> Tensor: + if rope_dims > 0 and rope_dims < x.size(-1): + x_rope, x_pass = x[..., :rope_dims], x[..., rope_dims:] + half = rope_dims // 2 + x1, x2 = x_rope[..., :half], x_rope[..., half:] + x_rope = torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + return torch.cat((x_rope, x_pass), dim=-1) + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + +class CausalSelfAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + rope_base: float, + qk_gain_init: float, + use_xsa: bool = False, + train_seq_len: int = 1024, + rope_dims: int | None = None + ): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + kv_dim = self.num_kv_heads * self.head_dim + self.c_q = CastedLinear(dim, dim, bias=False) + self.c_k = CastedLinear(dim, kv_dim, bias=False) + self.c_v = CastedLinear(dim, kv_dim, bias=False) + self.proj = CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rotary = Rotary(self.head_dim, base=rope_base, train_seq_len=train_seq_len, rope_dims=rope_dims) + self.use_xsa = use_xsa + + # efficient XSA as per + # https://github.com/unnir/parameter-golf/blob/a81f85bd7f632e3e48ef6b1da0017b81d25998a7/records/track_10min_16mb/2026-03-20_11L_EfficientPartialXSA_FA3_SWA120/train_gpt.py + def _xsa_efficient(self, y: Tensor, v: Tensor) -> Tensor: + """Efficient XSA: subtract self-value projection via GQA-aware reshape (no repeat_interleave). 
+ y: [B, T, H, D], v: [B, T, Hkv, D]. H must be divisible by Hkv.""" + B, T, H, D = y.shape + Hkv = v.size(-2) + group = H // Hkv + # Reshape y into KV head groups — free view, no memory alloc + y_g = y.reshape(B, T, Hkv, group, D) # [B, T, Hkv, group, D] + vn = F.normalize(v, dim=-1).unsqueeze(-2) # [B, T, Hkv, 1, D] — broadcast ready + # Project out self-value component per KV head group + proj = (y_g * vn).sum(dim=-1, keepdim=True) * vn + return (y_g - proj).reshape(B, T, H, D) + + def forward(self, x: Tensor, v_embed: Tensor | None = None) -> Tensor: + bsz, seqlen, dim = x.shape + q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim) + k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) + v = self.c_v(x) + if v_embed is not None: + v = v + v_embed + v = v.reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin, self.rotary.rope_dims) + k = apply_rotary_emb(k, cos, sin, self.rotary.rope_dims) + q = q * self.q_gain.to(dtype=q.dtype)[None, None, :, None] + y = flash_attn_func(q, k, v, causal=True) + + # efficient xsa as per https://github.com/openai/parameter-golf/pull/265 + if self.use_xsa: + y = self._xsa_efficient(y, v) + + y = y.reshape(bsz, seqlen, dim) + return self.proj(y) + +# stolen from #414 +class SmearGate(nn.Module): + def __init__(self, dim: int): + super().__init__() + self.gate = nn.Parameter(torch.zeros(dim, dtype=torch.float32)) + def forward(self, x: Tensor) -> Tensor: + g = torch.sigmoid(self.gate.to(dtype=x.dtype))[None, None, :] + x_prev = torch.cat([torch.zeros_like(x[:, :1]), x[:, :-1]], dim=1) + return (1 - g) * x + g * x_prev + +# stolen from #414 +class BigramHashEmbedding(nn.Module): + def __init__(self, bigram_vocab_size: int, bigram_dim: int, model_dim: int): + super().__init__() + self.bigram_vocab_size = bigram_vocab_size + self.embed = 
nn.Embedding(bigram_vocab_size, bigram_dim) + nn.init.zeros_(self.embed.weight) + self.proj = CastedLinear(bigram_dim, model_dim, bias=False) if bigram_dim != model_dim else None + if self.proj is not None: + nn.init.zeros_(self.proj.weight) + self.scale = nn.Parameter(torch.tensor(0.05, dtype=torch.float32)) + def bigram_hash(self, tokens: Tensor) -> Tensor: + t = tokens.to(torch.int32) + mod = self.bigram_vocab_size - 1 + out = torch.empty_like(t) + out[..., 0] = mod + out[..., 1:] = torch.bitwise_xor(36313 * t[..., 1:], 27191 * t[..., :-1]) % mod + return out.long() + def forward(self, token_ids: Tensor) -> Tensor: + h = self.embed(self.bigram_hash(token_ids)) + if self.proj is not None: + h = self.proj(h) + return h * self.scale.to(dtype=h.dtype) + +# stolen from #414 +class ValueEmbedding(nn.Module): + """Reinject token identity into attention values at specific layers. + Each table maps vocab tokens to a low-dim embedding, projected to model_dim.""" + def __init__(self, vocab_size: int, ve_dim: int, model_dim: int): + super().__init__() + self.embed = nn.Embedding(vocab_size, ve_dim) + nn.init.normal_(self.embed.weight, std=0.01) + self.proj = CastedLinear(ve_dim, model_dim, bias=False) if ve_dim != model_dim else None + if self.proj is not None: + nn.init.zeros_(self.proj.weight) + self.scale = nn.Parameter(torch.tensor(0.1, dtype=torch.float32)) + def forward(self, token_ids: Tensor) -> Tensor: + h = self.embed(token_ids) + if self.proj is not None: + h = self.proj(h) + return h * self.scale.to(dtype=h.dtype) + +class MLP(nn.Module): + def __init__( + self, + dim: int, + mlp_mult: int, + rng_up: bool = False, + use_rng_gain: bool = False, + mini_moe_experts: int = 1, + ): + super().__init__() + self.hidden = mlp_mult * dim + self.mini_moe_experts = mini_moe_experts + self.use_rng_gain = use_rng_gain + self.fc: CastedLinear | None = None + self.rand_gain: nn.Parameter | None = None + self.moe_router: CastedLinear | None = None + if not rng_up: + self.fc = 
CastedLinear(dim, self.hidden, bias=False) + else: + # non-persistent buffer, because deterministic random init, kept in bf16 because not quant'd anyway + self.register_buffer("fc_w", torch.empty((mini_moe_experts, self.hidden, dim), dtype=torch.bfloat16), persistent=False) + if self.use_rng_gain: + self.rand_gain = nn.Parameter(torch.ones(mini_moe_experts, self.hidden, dtype=torch.float32)) + if self.mini_moe_experts > 1: + self.moe_router = CastedLinear(dim, mini_moe_experts, bias=False) + self.moe_router._zero_init = True + self.proj = CastedLinear(self.hidden, dim, bias=False) + self.proj._zero_init = True + + def forward(self, x: Tensor) -> Tensor: + if self.fc is not None: + x = self.fc(x) + else: + if self.mini_moe_experts == 1: + x = F.linear(x, self.fc_w[0].to(dtype=x.dtype)) + if self.rand_gain is not None: + x = x * self.rand_gain[0].to(dtype=x.dtype)[None, :] + else: + # compute router + moe_weights = F.softmax(self.moe_router(x), dim=-1) + # compute individual expert outputs and weighted sum + # x: (bsz, seqlen, dim), fc_w: (mini_moe_experts, hidden, dim) + exout = torch.einsum("bsd,ehd->bseh", x, self.fc_w.to(dtype=x.dtype)) # (bsz, seqlen, mini_moe_experts, hidden) + if self.rand_gain is not None: + exout = exout * self.rand_gain.to(dtype=x.dtype)[None, None, :, :] + x = moe_weights.unsqueeze(-1) * exout + x = x.sum(dim=2) # (bsz, seqlen, hidden) + # x = torch.relu(x) + x = F.leaky_relu(x, negative_slope=0.5) + return self.proj(x.square()) + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + rope_base: float, + qk_gain_init: float, + rng_up: bool = False, + use_rng_gain: bool = False, + mini_moe_experts: int = 1, + train_seq_len: int = 1024, + use_xsa: bool = False, + rope_dims: int | None = None, + ln_scale: bool = False, + layer_idx: int = 0, + ): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention(dim, 
num_heads, num_kv_heads, rope_base, qk_gain_init, use_xsa, train_seq_len, rope_dims=rope_dims) + self.mlp = MLP(dim, mlp_mult, rng_up, use_rng_gain, mini_moe_experts) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + self.ln_scale_factor = 1.0 / math.sqrt(layer_idx + 1) if ln_scale else 1.0 + + def forward(self, x: Tensor, x0: Tensor, v_embed: Tensor | None = None) -> Tensor: + mix = self.resid_mix.to(dtype=x.dtype) + x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + attn_out = self.attn(self.attn_norm(x) * self.ln_scale_factor, v_embed=v_embed) + x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out + x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x) * self.ln_scale_factor) + return x + +class GPT(nn.Module): + def __init__( + self, + vocab_size: int, + num_layers: int, + model_dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + tie_embeddings: bool, + tied_embed_init_std: float, + logit_softcap: float, + rope_base: float, + qk_gain_init: float, + rand_proj_layers: list[int], + use_rng_gain: bool = False, + mini_moe_experts: int = 1, + rand_init_seed: int = 42, + rand_init_qr: bool = False, + xsa_layers: list[int] = [], + rope_dims: int | None = None, + ln_scale: bool = False, + bigram_vocab_size: int = 0, + bigram_dim: int = 128, + ve_dim: int = 128, + ve_layers: list[int] = [], + train_seq_len: int = 1024, + ): + super().__init__() + self.rand_init_seed = rand_init_seed + self.rand_init_qr = rand_init_qr + if logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") + self.tie_embeddings = tie_embeddings + self.tied_embed_init_std = tied_embed_init_std + self.logit_softcap = logit_softcap + self.tok_emb = nn.Embedding(vocab_size, model_dim) + self.smear = 
SmearGate(model_dim) + self.bigram = BigramHashEmbedding(bigram_vocab_size, bigram_dim, model_dim) if bigram_vocab_size > 0 else None + self.num_layers = num_layers + self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + self.blocks = nn.ModuleList( + [ + Block( + model_dim, + num_heads, + num_kv_heads, + mlp_mult, + rope_base, + qk_gain_init, + rng_up=(i in rand_proj_layers), + use_rng_gain=use_rng_gain, + mini_moe_experts=mini_moe_experts, + train_seq_len=train_seq_len, + use_xsa=(i in xsa_layers), + rope_dims=rope_dims, + ln_scale=ln_scale, + layer_idx=i, + ) + for i in range(num_layers) + ] + ) + + # value embeddings + self.ve_layer_indices = ve_layers + self.ve_target_dim = num_kv_heads * (model_dim // num_heads) + if self.ve_layer_indices: + self.ve_shared = ValueEmbedding(vocab_size, ve_dim, self.ve_target_dim) + self.ve_layer_scales = nn.ParameterList( + [nn.Parameter(torch.ones(1, dtype=torch.float32)) for _ in self.ve_layer_indices] + ) + else: + self.ve_shared = None + self.ve_layer_scales = nn.ParameterList() + self.value_embeds = nn.ModuleList() # keep empty for compat + + self.final_norm = RMSNorm() + self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) + if self.lm_head is not None: + self.lm_head._zero_init = True + self._init_weights() + + def _init_weights(self) -> None: + gen = torch.Generator() + gen.manual_seed(self.rand_init_seed) + + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + for name, module in self.named_modules(): + if isinstance(module, nn.Linear): + if getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + elif module.weight.ndim == 2 and module.weight.shape[0] >= 64 and 
module.weight.shape[1] >= 64: + nn.init.orthogonal_(module.weight, gain=1.0) + if ".proj." in name or name.endswith(".proj"): + with torch.no_grad(): + module.weight.mul_(1.0 / math.sqrt(2 * self.num_layers)) + + if isinstance(module, MLP) and getattr(module, "fc_w", None) is not None: + # perform seeded random init for the random fc weights + n_experts, d_out, d_in = module.fc_w.shape + if self.rand_init_qr: + for e in range(n_experts): + G = torch.randn((d_out, d_in), generator=gen) + q, _ = torch.linalg.qr(G) + # scale Q by sqrt(d_in) inside copy_, so the scale actually lands in the buffer + module.fc_w[e].copy_(q * math.sqrt(d_in)) + else: + nn.init.normal_(module.fc_w, mean=0.0, std=1.0/math.sqrt(d_in), generator=gen) + # module.fc_w.bernoulli_(0.5, generator=gen).mul_(2).sub_(1).mul_(1.0 / math.sqrt(d_in)) + + # value-embeddings as per #414 + def _get_ve(self, layer_idx: int, input_ids: Tensor, ve_cache: dict[str, Tensor] | None = None) -> Tensor | None: + """Get value embedding for a specific layer using shared table + per-layer scale.""" + if self.ve_shared is None or layer_idx not in self.ve_layer_indices: + return None + if ve_cache is not None and 've' not in ve_cache: + ve_cache['ve'] = self.ve_shared(input_ids) + ve_base = ve_cache['ve'] if ve_cache is not None else self.ve_shared(input_ids) + ve_idx = self.ve_layer_indices.index(layer_idx) + return ve_base * self.ve_layer_scales[ve_idx].to(dtype=ve_base.dtype) + + def forward_logits(self, input_ids: Tensor) -> Tensor: + x = self.tok_emb(input_ids) + if self.bigram is not None: + x = x + self.bigram(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x = self.smear(x) + x0 = x + skips: list[Tensor] = [] + + # First half stores skips; second half reuses them in reverse order. 
+ ve_cache: dict[str, Tensor] = {} + for i in range(self.num_encoder_layers): + ve = self._get_ve(i, input_ids, ve_cache) + x = self.blocks[i](x, x0, v_embed=ve) + skips.append(x) + for i in range(self.num_decoder_layers): + bi = self.num_encoder_layers + i + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + ve = self._get_ve(bi, input_ids, ve_cache) + x = self.blocks[bi](x, x0, v_embed=ve) + + x = self.final_norm(x) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(x) + logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + return logits + + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + logits = self.forward_logits(input_ids) + logits = logits.reshape(-1, logits.size(-1)) + targets = target_ids.reshape(-1) + return F.cross_entropy(logits.float(), targets, reduction="mean") + + +# ----------------------------- +# TRAINING +# ----------------------------- + +def main() -> None: + global zeropower_via_newtonschulz5 + + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) + + # ----------------------------- + # DISTRIBUTED + CUDA SETUP + # ----------------------------- + + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if 8 % world_size != 0: + raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") + grad_accum_steps = 8 // world_size + grad_scale = 1.0 / grad_accum_steps + if not torch.cuda.is_available(): + raise RuntimeError("CUDA 
is required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 0 + + # Fast math knobs + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + + logfile = None + if master_process: + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" + print(logfile) + + def log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch {torch.__version__}", console=False) + log0( + subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, + console=False, + ) + log0("=" * 100, console=False) + + # ----------------------------- + # TOKENIZER + VALIDATION METRIC SETUP + # ----------------------------- + + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + 
effective_eval_seq_len = args.eval_seq_len if args.eval_seq_len > 0 else args.train_seq_len + val_seq_len = max(args.train_seq_len, effective_eval_seq_len) + val_tokens = load_validation_tokens(args.val_files, val_seq_len) + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + + # ----------------------------- + # MODEL + OPTIMIZER SETUP + # ----------------------------- + + base_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, + rand_proj_layers=args.rand_proj_layers, + use_rng_gain=args.rand_gain, + mini_moe_experts=args.mini_moe_experts, + train_seq_len=args.train_seq_len, + xsa_layers=args.xsa_layers, + rope_dims=args.rope_dims, + ln_scale=args.ln_scale, + bigram_vocab_size=args.bigram_vocab_size, + bigram_dim=args.bigram_dim, + ve_dim=args.ve_dim, + ve_layers=args.ve_layers, + ).to(device).bfloat16() + for module in base_model.modules(): + if isinstance(module, CastedLinear): + module.float() + restore_low_dim_params_to_fp32(base_model) + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) + # compiled_model = base_model + model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + + # Optimizer split: + # - token embedding (Adam) uses EMBED_LR + # - untied lm_head (Adam) uses HEAD_LR + # - matrix params in transformer blocks 
use MATRIX_LR via Muon + # - vectors/scalars use SCALAR_LR via Adam + block_named_params = list(base_model.blocks.named_parameters()) + matrix_params = [ + p + for name, p in block_named_params + if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + scalar_params = [ + p + for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) + scalar_params.append(base_model.smear.gate) + if base_model.bigram is not None: + scalar_params.append(base_model.bigram.scale) + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + tok_params = [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}] + if base_model.bigram is not None: + tok_params.append({"params": [base_model.bigram.embed.weight], "lr": token_lr, "base_lr": token_lr}) + if base_model.bigram.proj is not None: + matrix_params.append(base_model.bigram.proj.weight) + if base_model.ve_shared is not None: + tok_params.append({"params": [base_model.ve_shared.embed.weight], "lr": token_lr, "base_lr": token_lr}) + if base_model.ve_shared.proj is not None: + matrix_params.append(base_model.ve_shared.proj.weight) + scalar_params.append(base_model.ve_shared.scale) + for s in base_model.ve_layer_scales: + scalar_params.append(s) + optimizer_tok = torch.optim.AdamW( + tok_params, + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + weight_decay=args.adam_wd, + fused=True, + ) + optimizer_muon = Muon( + matrix_params, + lr=args.matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + weight_decay=args.muon_wd, + ) + for group in optimizer_muon.param_groups: + group["base_lr"] = args.matrix_lr + optimizer_scalar = torch.optim.AdamW( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + 
weight_decay=args.adam_wd, + fused=True, + ) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers.insert(1, optimizer_head) + + n_params = sum(p.numel() for p in base_model.parameters()) + log0(f"model_params:{n_params}") + log0(f"num_layers:{args.num_layers} model_dim:{args.model_dim} num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads} mlp_mult:{args.mlp_mult}") + log0(f"rand_proj_layers:{args.rand_proj_layers} rand_gain:{args.rand_gain} mini_moe_experts:{args.mini_moe_experts}") + log0(f"bigram_vocab_size:{args.bigram_vocab_size} bigram_dim:{args.bigram_dim} ve_dim:{args.ve_dim} ve_layers:{args.ve_layers}") + log0(f"tie_embeddings:{args.tie_embeddings} tied_embed_init_std:{args.tied_embed_init_std} logit_softcap:{args.logit_softcap}") + log0(f"rope_base:{args.rope_base} qk_gain_init:{args.qk_gain_init} rope_dims:{args.rope_dims} xsa_layers:{args.xsa_layers} ln_scale:{args.ln_scale}") + log0(f"xsa_layers:{args.xsa_layers} ve_layers:{args.ve_layers} ve_dim:{args.ve_dim}") + log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") + log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") + log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") + log0(f"mlp_mode:{'rng_up' if args.rand_proj_layers else 'standard'} mlp_mult:{args.mlp_mult} mini_moe_experts:{args.mini_moe_experts} use_rng_gain:{args.rand_gain}") + log0( + f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " + f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " + f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" + ) + log0( + f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " + 
f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " + f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log0(f"seed:{args.seed}") + + # ----------------------------- + # DATA LOADER & MODEL WARMUP + # ----------------------------- + + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + + def lr_mul(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + + # Warmup primes the compiled forward/backward/optimizer paths, then we restore the + # initial weights/optimizer state so measured training starts from the true init. 
+ if args.warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(args.warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + for opt in optimizers: + opt.step() + zero_grad_all() + if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: + log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") + # enable qat during warmup so we don't pay the compilation tax later + if warmup_step == 2: + CastedLinear.qat_ste = True + if warmup_step == 4: + CastedLinear.qat_ste = False + model.eval() + if warmup_step == 6: + CastedLinear.qat_ste = True + if warmup_step == 8: + model.train() + CastedLinear.qat_ste = False + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + if distributed: + model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + # ----------------------------- + # MAIN TRAINING LOOP + # ----------------------------- + + # SWA as per #414 + swa_state: dict[str, Tensor] | None = None + swa_count = 0 + + # EMA as per #414 + ema_state = {name: t.detach().float().clone() for name, t in base_model.state_dict().items()} + ema_decay = 0.997 + + val_bpb = float("inf") + + training_time_ms = 0.0 + stop_after_step: int | None = None + torch.cuda.synchronize() + t0 = 
time.perf_counter() + + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + log0( + f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" + ) + break + + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + + # late QAT as per #414 + if args.late_qat_threshold > 0 and scale < args.late_qat_threshold and not CastedLinear.qat_ste: + CastedLinear.qat_ste = True + log0(f"late_qat:enabled step:{step} scale:{scale:.4f}") + + zero_grad_all() + train_loss = torch.zeros((), device=device) + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + loss = model(x, y) + train_loss += loss.detach() + (loss * grad_scale).backward() + train_loss /= grad_accum_steps + + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for 
group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + for opt in optimizers: + opt.step() + zero_grad_all() + + # EMA update as per #414 + with torch.no_grad(): + for name, t in base_model.state_dict().items(): + ema_state[name].mul_(ema_decay).add_(t.detach().float(), alpha=1.0 - ema_decay) + + step += 1 + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + + # SWA as per #414 + if args.swa_enabled and scale < 0.2 and step % args.swa_every == 0: + if swa_state is None: + swa_state = {name: t.detach().cpu().clone() for name, t in base_model.state_dict().items()} + swa_count = 1 + log0(f"swa:start step:{step}") + else: + for name, t in base_model.state_dict().items(): + swa_state[name] += t.detach().cpu() + swa_count += 1 + + should_log_train = ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + log0( + f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" + ) + + # Needed to sync whether we've reached the wallclock cap. 
reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + + log0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + ) + + # Apply EMA weights (better than SWA alone per PR#401) + log0("ema:applying EMA weights") + current_state = base_model.state_dict() + avg_state = {name: t.to(dtype=current_state[name].dtype) for name, t in ema_state.items()} + base_model.load_state_dict(avg_state, strict=True) + torch.cuda.synchronize() + t_diag = time.perf_counter() + diag_val_loss, diag_val_bpb = eval_val( + args, compiled_model, rank, world_size, device, grad_accum_steps, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"DIAGNOSTIC post_ema val_loss:{diag_val_loss:.4f} val_bpb:{diag_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_diag):.0f}ms" + ) + + if diag_val_bpb > val_bpb: + log0( + f"EMA did not improve val_bpb ({val_bpb:.4f} -> {diag_val_bpb:.4f}), restoring pre-EMA weights for final serialization" + ) + base_model.load_state_dict(current_state, strict=True) + else: + log0( + f"EMA improved val_bpb ({val_bpb:.4f} -> {diag_val_bpb:.4f}), keeping EMA weights for final serialization" + ) + + # ----------------------------- + # SERIALIZATION + ROUNDTRIP VALIDATION + # ----------------------------- + # Save the raw state (useful for debugging/loading in PyTorch directly), then always produce + # the compressed int6 (zlib or zstd) artifact and validate the round-tripped weights. 
+ + if master_process: + torch.save(base_model.state_dict(), "final_model.pt") + model_bytes = os.path.getsize("final_model.pt") + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model: {model_bytes} bytes") + log0(f"Code size: {code_bytes} bytes") + log0(f"Total submission size: {model_bytes + code_bytes} bytes") + + # --- quant --- + sd_cpu = {k: v.cpu() for k, v in base_model.state_dict().items()} + # TODO: think a/b keeping routers in fp32 or higher? + quant_result, quant_meta = mixed_quantize_int6(sd_cpu, {"mlp", "attn"}) + quant_buf = io.BytesIO() + torch.save({"w": quant_result, "m": quant_meta}, quant_buf) + quant_raw = quant_buf.getvalue() + quant_blob = zstandard.ZstdCompressor(level=22).compress(quant_raw) if _COMPRESSOR == "zstd" else zlib.compress(quant_raw, 9) + # quant_raw_bytes = len(quant_raw) + if master_process: + with open("final_model.int6.ptz", "wb") as f: + f.write(quant_blob) + quant_file_bytes = len(quant_blob) + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model int6+{_COMPRESSOR}: {quant_file_bytes} bytes") + log0(f"Total submission size int6+{_COMPRESSOR}: {quant_file_bytes + code_bytes} bytes") + log0(f"Total submission size int8+zlib: {quant_file_bytes + code_bytes} bytes") + + if distributed: + dist.barrier() + with open("final_model.int6.ptz", "rb") as f: + quant_blob_disk = f.read() + quant_state = torch.load( + io.BytesIO(zstandard.ZstdDecompressor().decompress(quant_blob_disk) if _COMPRESSOR == "zstd" else zlib.decompress(quant_blob_disk)), + map_location="cpu", + ) + deq_state = dequantize_mixed_int6(quant_state["w"], quant_state["m"], sd_cpu) + eval_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + 
qk_gain_init=args.qk_gain_init, + rand_proj_layers=args.rand_proj_layers, + use_rng_gain=args.rand_gain, + mini_moe_experts=args.mini_moe_experts, + train_seq_len=args.train_seq_len, + xsa_layers=args.xsa_layers, + rope_dims=args.rope_dims, + ln_scale=args.ln_scale, + bigram_vocab_size=args.bigram_vocab_size, + bigram_dim=args.bigram_dim, + ve_dim=args.ve_dim, + ve_layers=args.ve_layers, + ).to(device).bfloat16() + for m in eval_model.modules(): + if isinstance(m, CastedLinear): + m.float() + restore_low_dim_params_to_fp32(eval_model) + eval_model.load_state_dict(deq_state, strict=True) + compiled_eval = torch.compile(eval_model, dynamic=False, fullgraph=True) + torch.cuda.synchronize() + t_qeval = time.perf_counter() + q_val_loss, q_val_bpb = eval_val( + args, + compiled_eval, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"final_int6_zlib_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" + ) + log0(f"final_int6_zlib_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + + # --- sliding window + TTT eval --- + + sw_seq_len = effective_eval_seq_len + if args.eval_stride > 0 and args.eval_stride < sw_seq_len: + torch.cuda.synchronize() + t_slide = time.perf_counter() + sw_val_loss, sw_val_bpb = eval_val_sliding( + args, eval_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride=args.eval_stride, + eval_seq_len=sw_seq_len, + ) + torch.cuda.synchronize() + log0( + f"final_int6_sliding_window val_loss:{sw_val_loss:.4f} val_bpb:{sw_val_bpb:.4f} " + f"stride:{args.eval_stride} eval_time:{1000.0 * (time.perf_counter() - t_slide):.0f}ms" + ) + log0(f"final_int6_sliding_window_exact val_loss:{sw_val_loss:.8f} val_bpb:{sw_val_bpb:.8f}") + log0(f"final_int8_zlib_roundtrip_exact 
val_loss:{sw_val_loss:.8f} val_bpb:{sw_val_bpb:.8f}") + if args.eval_stride != 64 and 64 < sw_seq_len: + torch.cuda.synchronize() + t_slide64 = time.perf_counter() + sw64_val_loss, sw64_val_bpb = eval_val_sliding( + args, eval_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride=64, + eval_seq_len=sw_seq_len, + ) + torch.cuda.synchronize() + log0( + f"final_int6_sliding_window_s64 val_loss:{sw64_val_loss:.4f} val_bpb:{sw64_val_bpb:.4f} " + f"stride:64 eval_time:{1000.0 * (time.perf_counter() - t_slide64):.0f}ms" + ) + log0(f"final_int6_sliding_window_s64_exact val_loss:{sw64_val_loss:.8f} val_bpb:{sw64_val_bpb:.8f}") + log0(f"final_int8_zlib_roundtrip_exact val_loss:{sw64_val_loss:.8f} val_bpb:{sw64_val_bpb:.8f}") + # Legal score-first TTT (PR #461 recipe) + if args.ttt_enabled: + # disable QAT ste for TTT + CastedLinear.qat_ste = False + torch.cuda.synchronize() + t_ttt = time.perf_counter() + ttt_loss, ttt_bpb = eval_val_sliding_ttt( + args, eval_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride=args.eval_stride, log0=log0, + ) + torch.cuda.synchronize() + log0(f"legal_ttt val_loss:{ttt_loss:.4f} val_bpb:{ttt_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_ttt):.0f}ms") + log0(f"legal_ttt_exact val_loss:{ttt_loss:.8f} val_bpb:{ttt_bpb:.8f}") + + if distributed: + dist.destroy_process_group() + + +if __name__ == "__main__": + main()