diff --git a/records/track_10min_16mb/2026-03-25_NgramCache_EntropyAdaptive_0.6672/README.md b/records/track_10min_16mb/2026-03-25_NgramCache_EntropyAdaptive_0.6672/README.md
new file mode 100644
index 0000000000..1812cb27ef
--- /dev/null
+++ b/records/track_10min_16mb/2026-03-25_NgramCache_EntropyAdaptive_0.6672/README.md
@@ -0,0 +1,63 @@
+# 11L + Multi-Order N-gram Backoff + Entropy-Adaptive Alpha
+
+**val_bpb: 0.6678** (3-seed mean, std 0.0008) | **15.0 MB** artifact | 1xB200 (HiPerGator)
+
+## Technique
+
+The base 11L SOTA architecture, plus a novel eval-time n-gram cache that improves BPB by 0.49 over neural-only sliding evaluation.
+
+### Multi-order N-gram Backoff (orders 2-7)
+
+During sliding-window evaluation, we maintain hash tables for n-gram contexts of orders 2 through 7. For each token prediction, we try the highest order first and back off to lower orders on a miss. This captures repeated patterns within a document that the neural model cannot see once they fall outside its context window.
+
+### Entropy-Adaptive Alpha
+
+Instead of a fixed interpolation weight, alpha adapts to the model's own predictive entropy H (in nats):
+
+```
+alpha = 0.05 + 0.55 * sigmoid(2 * (H - 4.0))
+```
+
+- Low entropy (model confident): alpha -> 0.05, trust the LM
+- High entropy (model uncertain): alpha -> 0.60, trust the n-gram cache
+
+### Compliance
+
+- Score-first, backward-looking: n-gram counts are built from previously scored tokens only
+- No oracle selection: alpha depends on model entropy, never on ground-truth labels
+- Single blended prediction per token, no min(NLL)
+
+## Results (3 seeds)
+
+| Seed | N-gram BPB | Artifact (bytes) |
+|------|------------|------------------|
+| 42 | **0.6672** | 15,025,238 |
+| 1337 | **0.6676** | 15,025,238 |
+| 7 | **0.6687** | 15,025,238 |
+| **Mean** | **0.6678 (std 0.0008)** | |
+
+## Architecture
+
+- 11L, 512d, 8H/4KV GQA, MLP 3x
+- XSA last 4 layers, Partial RoPE (16/64), LN Scale
+- Value Embeddings (VE128, layers 9-10)
+- SmearGate + BigramHash(2048)
+- EMA (0.997), Late QAT (0.15), OrthoInit
+- Int6 per-row + GPTQ-lite + 3% magnitude pruning + zstd-22
+
+## Reproduction
+
+```bash
+pip install sentencepiece zstandard
+python3 data/cached_challenge_fineweb.py --variant sp1024
+
+SEED=42 NGRAM_CACHE=1 NGRAM_ORDER=7 NGRAM_MIN_ORDER=2 \
+NGRAM_ENTROPY=1 EVAL_STRIDE=64 PRUNE_PCT=0.03 \
+torchrun --standalone --nproc_per_node=1 train_gpt.py
+```
+
+## Credits
+
+- Base architecture: PR #414 (signalrush), PR #315 (jfprincz), PR #287 (jfprincz)
+- N-gram cache concept: PR #702 (lukacf), PR #727 (lukacf)
+- Entropy-adaptive alpha: PR #727 (lukacf)
diff --git a/records/track_10min_16mb/2026-03-25_NgramCache_EntropyAdaptive_0.6672/submission.json b/records/track_10min_16mb/2026-03-25_NgramCache_EntropyAdaptive_0.6672/submission.json
new file mode 100644
index 0000000000..db5b9ad208
--- /dev/null
+++ b/records/track_10min_16mb/2026-03-25_NgramCache_EntropyAdaptive_0.6672/submission.json
@@ -0,0 +1,7 @@
+{
+  "blurb": "11L SOTA base + multi-order n-gram backoff (2-7) with entropy-adaptive alpha. Legal score-first eval. 
3-seed mean 0.6678.", + "date": "2026-03-25T00:00:00Z", + "val_loss": 1.12757285, + "val_bpb": 0.66781392, + "bytes_total": 15025238 +} diff --git a/records/track_10min_16mb/2026-03-25_NgramCache_EntropyAdaptive_0.6672/train_gpt.py b/records/track_10min_16mb/2026-03-25_NgramCache_EntropyAdaptive_0.6672/train_gpt.py new file mode 100644 index 0000000000..21a21b30a0 --- /dev/null +++ b/records/track_10min_16mb/2026-03-25_NgramCache_EntropyAdaptive_0.6672/train_gpt.py @@ -0,0 +1,1834 @@ +from __future__ import annotations +import copy +import glob +import io +import math +import os +import random +import subprocess +import sys +import time +import uuid +import zlib +from pathlib import Path +try: + import zstandard + _COMPRESSOR = "zstd" +except ImportError: + _COMPRESSOR = "zlib" +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP +try: + from flash_attn_interface import flash_attn_func as flash_attn_3_func + _USE_FA3 = True +except ImportError: + _USE_FA3 = False + def flash_attn_3_func(q, k, v, causal=True): + # Fallback: convert from (B,T,H,D) to (B,H,T,D) for PyTorch SDPA + q2 = q.transpose(1, 2) + k2 = k.transpose(1, 2) + v2 = v.transpose(1, 2) + num_heads, num_kv = q2.size(1), k2.size(1) + y = F.scaled_dot_product_attention(q2, k2, v2, attn_mask=None, is_causal=causal, + enable_gqa=(num_kv != num_heads)) + return y.transpose(1, 2) +class Hyperparameters: + data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + seed = int(os.environ.get("SEED", 1337)) + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 4000)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 500)) + iterations = int(os.environ.get("ITERATIONS", 20000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 3500)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 786_432)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 2048)) + eval_seq_len = int(os.environ.get("EVAL_SEQ_LEN", 2048)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers = int(os.environ.get("NUM_LAYERS", 11)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = float(os.environ.get("MLP_MULT", 3.0)) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.035)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.025)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.025)) + muon_momentum = 
float(os.environ.get("MUON_MOMENTUM", 0.99)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.92)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 1500)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.3)) + eval_stride = int(os.environ.get("EVAL_STRIDE", 64)) + mtp_num_heads = int(os.environ.get("MTP_NUM_HEADS", 0)) + mtp_loss_weight = float(os.environ.get("MTP_LOSS_WEIGHT", 0.2)) + muon_beta2 = float(os.environ.get("MUON_BETA2", 0.95)) + swa_enabled = bool(int(os.environ.get("SWA_ENABLED", "1"))) + swa_every = int(os.environ.get("SWA_EVERY", 50)) # tighter: collect more recent checkpoints + muon_wd = float(os.environ.get("MUON_WD", 0.04)) + adam_wd = float(os.environ.get("ADAM_WD", 0.04)) + qat_enabled = bool(int(os.environ.get("QAT_ENABLED", "0"))) + bigram_vocab_size = int(os.environ.get("BIGRAM_VOCAB_SIZE", 2048)) + bigram_dim = int(os.environ.get("BIGRAM_DIM", 128)) + xsa_last_n = int(os.environ.get("XSA_LAST_N", 4)) # XSA on last 4 layers (0 = disabled) + rope_dims = int(os.environ.get("ROPE_DIMS", 16)) + ln_scale = bool(int(os.environ.get("LN_SCALE", "1"))) + dtg_enabled = bool(int(os.environ.get("DTG_ENABLED", "0"))) + late_qat_threshold = float(os.environ.get("LATE_QAT_THRESHOLD", 0.15)) + ve_enabled = bool(int(os.environ.get("VE_ENABLED", "1"))) + ve_dim = int(os.environ.get("VE_DIM", 128)) + ve_layers = os.environ.get("VE_LAYERS", "9,10") + + # N-gram cache (legal eval-time improvement) + ngram_cache = bool(int(os.environ.get("NGRAM_CACHE", "0"))) + ngram_alpha = float(os.environ.get("NGRAM_ALPHA", 0.40)) + ngram_min_count = int(os.environ.get("NGRAM_MIN_COUNT", 2)) + ngram_order = int(os.environ.get("NGRAM_ORDER", 7)) + ngram_min_order = int(os.environ.get("NGRAM_MIN_ORDER", 2)) + ngram_buckets = int(os.environ.get("NGRAM_BUCKETS", 4194304)) + ngram_entropy = bool(int(os.environ.get("NGRAM_ENTROPY", "1"))) + ngram_ent_base = float(os.environ.get("NGRAM_ENT_BASE", 0.05)) + ngram_ent_range = float(os.environ.get("NGRAM_ENT_RANGE", 0.55)) + ngram_ent_scale = float(os.environ.get("NGRAM_ENT_SCALE", 2.0)) + ngram_ent_thresh = float(os.environ.get("NGRAM_ENT_THRESH", 4.0)) + + # TTT (test-time training) evaluation with LoRA + ttt_enabled = bool(int(os.environ.get("TTT_ENABLED", "0"))) + ttt_lora_rank = int(os.environ.get("TTT_LORA_RANK", 8)) + ttt_lora_lr = float(os.environ.get("TTT_LORA_LR", 0.01)) + ttt_chunk_size = int(os.environ.get("TTT_CHUNK_SIZE", 256)) + ttt_eval_seq_len = int(os.environ.get("TTT_EVAL_SEQ_LEN", 1024)) + ttt_batch_size = int(os.environ.get("TTT_BATCH_SIZE", 64)) + ttt_epochs = int(os.environ.get("TTT_EPOCHS", 3)) # Epochs over all chunks (set to 1 for legal) + ttt_steps_per_chunk = int(os.environ.get("TTT_STEPS_PER_CHUNK", 3)) # Legal: train steps per chunk + ttt_k_lora = bool(int(os.environ.get("TTT_K_LORA", "1"))) + ttt_min_nll = bool(int(os.environ.get("TTT_MIN_NLL", "0"))) # Default off (illegal if epochs>1) + ttt_trajectories = int(os.environ.get("TTT_TRAJECTORIES", 1)) + ttt_per_chunk_min = bool(int(os.environ.get("TTT_PER_CHUNK_MIN", "0"))) + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + 
X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, + nesterov: bool = True, weight_decay: float = 0.0): + super().__init__( + params, + dict(lr=lr, momentum=momentum, backend_steps=backend_steps, + nesterov=nesterov, weight_decay=weight_decay), + ) + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + for group in self.param_groups: + params = group["params"] + if not params: + continue + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + total_params = sum(int(p.numel()) for p in params) + updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: + g = g.add(buf, alpha=momentum) + g = zeropower_via_newtonschulz5(g, steps=backend_steps) + g *= max(1, g.size(0) / g.size(1)) ** 0.5 + updates_flat[curr : curr + p.numel()] = g.reshape(-1) + curr += p.numel() + if distributed: + dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) + wd = group.get("weight_decay", 0.0) + curr = 0 + for p in params: + if wd > 0.0: + p.data.mul_(1.0 - lr * wd) + g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) + p.add_(g, alpha=-lr) + curr += p.numel() + return loss +def build_sentencepiece_luts( + sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device +) -> tuple[Tensor, Tensor, Tensor]: + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("▁"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] +def eval_val( + args: Hyperparameters, + model: nn.Module, + rank: int, + world_size: int, + device: 
torch.device, + grad_accum_steps: int, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, + eval_seq_len: int | None = None, +) -> tuple[float, float]: + seq_len = eval_seq_len or args.train_seq_len + local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) + if local_batch_tokens < seq_len: + raise ValueError( + "VAL_BATCH_SIZE must provide at least one sequence per rank; " + f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " + f"GRAD_ACCUM_STEPS={grad_accum_steps}, seq_len={seq_len}" + ) + local_batch_seqs = local_batch_tokens // seq_len + total_seqs = (val_tokens.numel() - 1) // seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + model.eval() + with torch.inference_mode(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * seq_len + raw_end = batch_seq_end * seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + batch_loss = model(x, y).detach() + batch_token_count = float(y.numel()) + val_loss_sum += batch_loss.to(torch.float64) * batch_token_count + val_token_count += batch_token_count + prev_ids = x.reshape(-1) + tgt_ids = y.reshape(-1) + token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + val_loss = val_loss_sum / val_token_count + bits_per_token = val_loss.item() / math.log(2.0) + tokens_per_byte = val_token_count.item() / val_byte_count.item() + model.train() + return float(val_loss.item()), float(bits_per_token * tokens_per_byte) +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights,smear,dtg_gate,ve_layer_scales,ve_shared.scale", + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) +def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: + if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): + return t.float().contiguous() + if t.dtype in {torch.float32, torch.bfloat16}: + passthrough_orig_dtypes[name] = 
str(t.dtype).removeprefix("torch.") + return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() + return t +def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) + q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() + return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() + return q, scale +def quantize_state_dict_int8(state_dict: dict[str, Tensor]): + quantized: dict[str, Tensor] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + qmeta: dict[str, dict[str, object]] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + stats["num_float_tensors"] += 1 + q, s = quantize_float_tensor(t) + if s.ndim > 0: + qmeta[name] = {"scheme": "per_row", "axis": 0} + quantized[name] = q + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) + obj: dict[str, object] = { + "__quant_format__": "int8_clean_per_row_v1", + "quantized": quantized, + "scales": scales, + "dtypes": dtypes, + "passthrough": passthrough, + } + if qmeta: + obj["qmeta"] = qmeta + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats +def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + qmeta = obj.get("qmeta", {}) + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, q in obj["quantized"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + s = obj["scales"][name] + if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: + s = s.to(dtype=torch.float32) + out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() + else: + scale = float(s.item()) + out[name] = (q.float() * scale).to(dtype=dtype).contiguous() + for name, t in obj["passthrough"].items(): + out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out +def load_data_shard(file: Path) -> Tensor: + header_bytes = 256 * 
np.dtype(" None: + self.file_idx = (self.file_idx + 1) % len(self.files) + self.tokens = load_data_shard(self.files[self.file_idx]) + self.pos = 0 + def take(self, n: int) -> Tensor: + chunks: list[Tensor] = [] + remaining = n + while remaining > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: + self._advance_file() + continue + k = min(remaining, avail) + chunks.append(self.tokens[self.pos : self.pos + k]) + self.pos += k + remaining -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) +class DistributedTokenLoader: + def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank = rank + self.world_size = world_size + self.device = device + self.stream = TokenStream(pattern) + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + chunk = self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start : start + per_rank_span].to(dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) +class CastedLinear(nn.Linear): + _qat_enabled: bool = False + def forward(self, x: Tensor) -> Tensor: + w = self.weight.to(x.dtype) + if CastedLinear._qat_enabled and self.training and w.ndim == 2: + with torch.no_grad(): + w32 = self.weight.float() + row_max = w32.abs().amax(dim=1) + scale = (row_max / 31.0).clamp_min(1.0 / 31.0) + w_q = (torch.clamp(torch.round(w32 / scale[:, None]), -32, 31) * scale[:, None]).to(x.dtype) + w = w + (w_q - w).detach() + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, w, bias) +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() +class Rotary(nn.Module): + def __init__(self, dim: int, base: float = 10000.0, train_seq_len: int = 1024, rope_dims: int = 0): + super().__init__() + self.dim = dim + self.base = base + self.train_seq_len = train_seq_len + self.rope_dims = rope_dims if rope_dims > 0 else dim + inv_freq = 1.0 / (base ** (torch.arange(0, self.rope_dims, 2, dtype=torch.float32) / self.rope_dims)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached != seq_len + or self._cos_cached.device != device + ): + rd = self.rope_dims + if seq_len > self.train_seq_len: + scale = seq_len / self.train_seq_len + new_base = self.base * (scale ** (rd / (rd - 2))) + inv_freq = 1.0 / (new_base ** (torch.arange(0, rd, 2, dtype=torch.float32, device=device) / rd)) + else: + inv_freq = self.inv_freq.to(device) + t = torch.arange(seq_len, device=device, dtype=inv_freq.dtype) + freqs = torch.outer(t, inv_freq) + self._cos_cached = 
freqs.cos()[None, :, None, :] + self._sin_cached = freqs.sin()[None, :, None, :] + self._seq_len_cached = seq_len + return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor, rope_dims: int = 0) -> Tensor: + if rope_dims > 0 and rope_dims < x.size(-1): + x_rope, x_pass = x[..., :rope_dims], x[..., rope_dims:] + half = rope_dims // 2 + x1, x2 = x_rope[..., :half], x_rope[..., half:] + x_rope = torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + return torch.cat((x_rope, x_pass), dim=-1) + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) +class CausalSelfAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + kv_dim = self.num_kv_heads * self.head_dim + self.c_q = CastedLinear(dim, dim, bias=False) + self.c_k = CastedLinear(dim, kv_dim, bias=False) + self.c_v = CastedLinear(dim, kv_dim, bias=False) + self.proj = CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rope_dims = 0 # set by GPT.__init__ for partial RoPE + self.rotary = Rotary(self.head_dim, base=rope_base, train_seq_len=1024) + self.use_xsa = False # set by GPT.__init__ for deep layers only + def _xsa_efficient(self, y: Tensor, v: Tensor) -> Tensor: + """Efficient XSA: subtract self-value projection via GQA-aware reshape (no repeat_interleave). + y: [B, T, H, D], v: [B, T, Hkv, D]. 
H must be divisible by Hkv.""" + B, T, H, D = y.shape + Hkv = v.size(-2) + group = H // Hkv + y_g = y.reshape(B, T, Hkv, group, D) # [B, T, Hkv, group, D] + vn = F.normalize(v, dim=-1).unsqueeze(-2) # [B, T, Hkv, 1, D] — broadcast ready + proj = (y_g * vn).sum(dim=-1, keepdim=True) * vn + return (y_g - proj).reshape(B, T, H, D) + def forward(self, x: Tensor, v_embed: Tensor | None = None, q_lora=None, v_lora=None, k_lora=None) -> Tensor: + bsz, seqlen, dim = x.shape + q = self.c_q(x) + if q_lora is not None: + q = q + q_lora(x) + q = q.reshape(bsz, seqlen, self.num_heads, self.head_dim) + k = self.c_k(x) + if k_lora is not None: + k = k + k_lora(x) + k = k.reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) + v = self.c_v(x) + if v_lora is not None: + v = v + v_lora(x) + if v_embed is not None: + v = v + v_embed + v = v.reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin, self.rope_dims) + k = apply_rotary_emb(k, cos, sin, self.rope_dims) + q = q * self.q_gain.to(dtype=q.dtype)[None, None, :, None] + y = flash_attn_3_func(q, k, v, causal=True) + if self.use_xsa: + y = self._xsa_efficient(y, v) + y = y.reshape(bsz, seqlen, dim) + return self.proj(y) +class SmearGate(nn.Module): + def __init__(self, dim: int): + super().__init__() + self.gate = nn.Parameter(torch.zeros(dim, dtype=torch.float32)) + def forward(self, x: Tensor) -> Tensor: + g = torch.sigmoid(self.gate.to(dtype=x.dtype))[None, None, :] + x_prev = torch.cat([torch.zeros_like(x[:, :1]), x[:, :-1]], dim=1) + return (1 - g) * x + g * x_prev +class BigramHashEmbedding(nn.Module): + def __init__(self, bigram_vocab_size: int, bigram_dim: int, model_dim: int): + super().__init__() + self.bigram_vocab_size = bigram_vocab_size + self.embed = nn.Embedding(bigram_vocab_size, bigram_dim) + nn.init.zeros_(self.embed.weight) + self.proj = CastedLinear(bigram_dim, model_dim, bias=False) if bigram_dim != model_dim else None + if self.proj is not None: + nn.init.zeros_(self.proj.weight) + self.scale = nn.Parameter(torch.tensor(0.05, dtype=torch.float32)) + def bigram_hash(self, tokens: Tensor) -> Tensor: + t = tokens.to(torch.int32) + mod = self.bigram_vocab_size - 1 + out = torch.empty_like(t) + out[..., 0] = mod + out[..., 1:] = torch.bitwise_xor(36313 * t[..., 1:], 27191 * t[..., :-1]) % mod + return out.long() + def forward(self, token_ids: Tensor) -> Tensor: + h = self.embed(self.bigram_hash(token_ids)) + if self.proj is not None: + h = self.proj(h) + return h * self.scale.to(dtype=h.dtype) +class ValueEmbedding(nn.Module): + """Reinject token identity into attention values at specific layers. 
+ Each table maps vocab tokens to a low-dim embedding, projected to model_dim.""" + def __init__(self, vocab_size: int, ve_dim: int, model_dim: int): + super().__init__() + self.embed = nn.Embedding(vocab_size, ve_dim) + nn.init.normal_(self.embed.weight, std=0.01) + self.proj = CastedLinear(ve_dim, model_dim, bias=False) if ve_dim != model_dim else None + if self.proj is not None: + nn.init.zeros_(self.proj.weight) + self.scale = nn.Parameter(torch.tensor(0.1, dtype=torch.float32)) + def forward(self, token_ids: Tensor) -> Tensor: + h = self.embed(token_ids) + if self.proj is not None: + h = self.proj(h) + return h * self.scale.to(dtype=h.dtype) +class MLP(nn.Module): + def __init__(self, dim: int, mlp_mult: int): + super().__init__() + hidden = int(mlp_mult * dim) + self.fc = CastedLinear(dim, hidden, bias=False) + self.proj = CastedLinear(hidden, dim, bias=False) + self.proj._zero_init = True + def forward(self, x: Tensor) -> Tensor: + x = torch.relu(self.fc(x)) + return self.proj(x.square()) +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + rope_base: float, + qk_gain_init: float, + layer_idx: int = 0, + ln_scale: bool = False, + dtg: bool = False, + ): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) + self.mlp = MLP(dim, mlp_mult) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + self.ln_scale_factor = 1.0 / math.sqrt(layer_idx + 1) if ln_scale else 1.0 + if dtg: + self.dtg_gate = nn.Linear(dim, 1, bias=True) + nn.init.zeros_(self.dtg_gate.weight) + nn.init.constant_(self.dtg_gate.bias, 2.0) + else: + self.dtg_gate = None + def forward(self, x: Tensor, x0: Tensor, v_embed: Tensor | None = None, + q_lora=None, v_lora=None, k_lora=None) -> Tensor: + mix = self.resid_mix.to(dtype=x.dtype) + x_in = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + attn_out = self.attn(self.attn_norm(x_in) * self.ln_scale_factor, v_embed=v_embed, + q_lora=q_lora, v_lora=v_lora, k_lora=k_lora) + x_out = x_in + self.attn_scale.to(dtype=x_in.dtype)[None, None, :] * attn_out + x_out = x_out + self.mlp_scale.to(dtype=x_out.dtype)[None, None, :] * self.mlp(self.mlp_norm(x_out) * self.ln_scale_factor) + if self.dtg_gate is not None: + gate = torch.sigmoid(self.dtg_gate(x_in.detach())) + x_out = x_in + gate * (x_out - x_in) + return x_out +class GPT(nn.Module): + def __init__( + self, + vocab_size: int, + num_layers: int, + model_dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + tie_embeddings: bool, + tied_embed_init_std: float, + logit_softcap: float, + rope_base: float, + qk_gain_init: float, + mtp_num_heads: int = 0, + mtp_loss_weight: float = 0.1, + bigram_vocab_size: int = 0, + bigram_dim: int = 128, + xsa_last_n: int = 0, + rope_dims: int = 0, + ln_scale: bool = False, + dtg: bool = False, + ve_enabled: bool = False, + ve_dim: int = 128, + ve_layers: str = "9,10", + ): + super().__init__() + self._ve_target_dim = num_kv_heads * (model_dim // num_heads) # kv_dim for value projection + if logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") + self.tie_embeddings = tie_embeddings + self.tied_embed_init_std = tied_embed_init_std + self.logit_softcap = logit_softcap + 
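+        # Reference for `_xsa_efficient` above, as a sketch: the naive XSA
+        # formulation repeats each KV head `group` times and subtracts, per
+        # head, the component of the attention output that lies along its
+        # normalized value vector. The helper assumes the same [B, T, H, D]
+        # and [B, T, Hkv, D] shapes, is illustrative only (never called), and
+        # should agree with the reshape-based version to within bf16 noise.
+        def _xsa_reference(y: Tensor, v: Tensor) -> Tensor:
+            group = y.size(-2) // v.size(-2)
+            v_rep = v.repeat_interleave(group, dim=-2)  # [B, T, H, D]
+            vn = F.normalize(v_rep, dim=-1)
+            return y - (y * vn).sum(dim=-1, keepdim=True) * vn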
self.mtp_num_heads = mtp_num_heads + self.mtp_loss_weight = mtp_loss_weight + self.tok_emb = nn.Embedding(vocab_size, model_dim) + self.bigram = BigramHashEmbedding(bigram_vocab_size, bigram_dim, model_dim) if bigram_vocab_size > 0 else None + self.smear = SmearGate(model_dim) + self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + self.blocks = nn.ModuleList( + [ + Block( + model_dim, + num_heads, + num_kv_heads, + mlp_mult, + rope_base, + qk_gain_init, + layer_idx=i, + ln_scale=ln_scale, + dtg=dtg, + ) + for i in range(num_layers) + ] + ) + if rope_dims > 0: + head_dim = model_dim // num_heads + for block in self.blocks: + block.attn.rope_dims = rope_dims + block.attn.rotary = Rotary(head_dim, base=rope_base, train_seq_len=1024, rope_dims=rope_dims) + self.ve_layer_indices = [int(x) for x in ve_layers.split(",") if x.strip()] if ve_enabled else [] + kv_dim = self._ve_target_dim + if self.ve_layer_indices: + self.ve_shared = ValueEmbedding(vocab_size, ve_dim, kv_dim) + self.ve_layer_scales = nn.ParameterList( + [nn.Parameter(torch.ones(1, dtype=torch.float32)) for _ in self.ve_layer_indices] + ) + else: + self.ve_shared = None + self.ve_layer_scales = nn.ParameterList() + self.value_embeds = nn.ModuleList() # keep empty for compat + self.final_norm = RMSNorm() + self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) + if self.lm_head is not None: + self.lm_head._zero_init = True + self.mtp_heads = nn.ModuleList( + [CastedLinear(model_dim, vocab_size, bias=False) for _ in range(mtp_num_heads)] + ) + for head in self.mtp_heads: + head._zero_init = True + if xsa_last_n > 0: + for i in range(max(0, num_layers - xsa_last_n), num_layers): + self.blocks[i].attn.use_xsa = True + self._init_weights() + def _init_weights(self) -> None: + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + num_layers = len(self.blocks) + for name, module in self.named_modules(): + if isinstance(module, nn.Linear): + if getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + elif module.weight.ndim == 2 and module.weight.shape[0] >= 64 and module.weight.shape[1] >= 64: + nn.init.orthogonal_(module.weight, gain=1.0) + if ".proj." 
in name or name.endswith(".proj"): + with torch.no_grad(): + module.weight.mul_(1.0 / math.sqrt(2 * num_layers)) + def _get_ve(self, layer_idx: int, input_ids: Tensor, ve_cache: dict | None = None) -> Tensor | None: + """Get value embedding for a specific layer using shared table + per-layer scale.""" + if self.ve_shared is None or layer_idx not in self.ve_layer_indices: + return None + if ve_cache is not None and 've' not in ve_cache: + ve_cache['ve'] = self.ve_shared(input_ids) + ve_base = ve_cache['ve'] if ve_cache is not None else self.ve_shared(input_ids) + ve_idx = self.ve_layer_indices.index(layer_idx) + return ve_base * self.ve_layer_scales[ve_idx].to(dtype=ve_base.dtype) + def forward(self, input_ids: Tensor, target_ids: Tensor, lora=None) -> Tensor: + x = self.tok_emb(input_ids) + if self.bigram is not None: + x = x + self.bigram(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x = self.smear(x) + x0 = x + skips: list[Tensor] = [] + ve_cache: dict = {} + for i in range(self.num_encoder_layers): + ve = self._get_ve(i, input_ids, ve_cache) + q_l = lora.q_loras[i] if lora else None + v_l = lora.v_loras[i] if lora else None + k_l = lora.k_loras[i] if (lora and lora.k_lora_enabled) else None + x = self.blocks[i](x, x0, v_embed=ve, q_lora=q_l, v_lora=v_l, k_lora=k_l) + skips.append(x) + for i in range(self.num_decoder_layers): + bi = self.num_encoder_layers + i + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + ve = self._get_ve(bi, input_ids, ve_cache) + q_l = lora.q_loras[bi] if lora else None + v_l = lora.v_loras[bi] if lora else None + k_l = lora.k_loras[bi] if (lora and lora.k_lora_enabled) else None + x = self.blocks[bi](x, x0, v_embed=ve, q_lora=q_l, v_lora=v_l, k_lora=k_l) + x = self.final_norm(x) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(x) + if lora is not None: + logits_proj = logits_proj + lora.lm_head_lora(x) + logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + if lora is not None: + bsz, sl, V = logits.shape + return F.cross_entropy(logits.float().reshape(-1, V), target_ids.reshape(-1), reduction="none").reshape(bsz, sl) + x_flat = x.reshape(-1, x.size(-1)) + targets = target_ids.reshape(-1) + logits_flat = logits.reshape(-1, logits.size(-1)) + main_loss = F.cross_entropy(logits_flat.float(), targets, reduction="mean") + if self.training and self.mtp_num_heads > 0 and self.mtp_loss_weight > 0.0: + _, seqlen, dim = x.shape + mtp_loss_sum = x.new_zeros(()) + mtp_loss_count = 0 + for k, mtp_head in enumerate(self.mtp_heads): + valid_t = seqlen - (k + 1) + if valid_t <= 0: + continue + mtp_hidden = x[:, :valid_t, :].reshape(-1, dim) + mtp_targets = target_ids[:, k + 1 :].reshape(-1) + mtp_logits_proj = mtp_head(mtp_hidden) + mtp_logits = self.logit_softcap * torch.tanh(mtp_logits_proj / self.logit_softcap) + mtp_loss_sum = mtp_loss_sum + F.cross_entropy(mtp_logits.float(), mtp_targets, reduction="mean") + mtp_loss_count += 1 + if mtp_loss_count > 0: + main_loss = main_loss + self.mtp_loss_weight * (mtp_loss_sum / mtp_loss_count) + return main_loss + def forward_logits(self, input_ids: Tensor) -> Tensor: + """Return logits (bsz, seq_len, vocab) without computing loss.""" + x = self.tok_emb(input_ids) + if self.bigram is not None: + x = x + self.bigram(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x = self.smear(x) + x0 = x + skips: list[Tensor] = 
[] + ve_cache: dict = {} + for i in range(self.num_encoder_layers): + ve = self._get_ve(i, input_ids, ve_cache) + x = self.blocks[i](x, x0, v_embed=ve) + skips.append(x) + for i in range(self.num_decoder_layers): + bi = self.num_encoder_layers + i + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + ve = self._get_ve(bi, input_ids, ve_cache) + x = self.blocks[bi](x, x0, v_embed=ve) + x = self.final_norm(x) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + logits_proj = self.lm_head(x) + return self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) +def eval_val_sliding( + args: Hyperparameters, + base_model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, + stride: int, + batch_seqs: int = 32, + eval_seq_len: int | None = None, +) -> tuple[float, float]: + """Sliding window evaluation with optional n-gram cache interpolation.""" + seq_len = eval_seq_len or args.train_seq_len + total_tokens = val_tokens.numel() - 1 + window_starts = [ws for ws in range(0, total_tokens, stride) + if min(ws + seq_len, total_tokens) - ws >= 1] + total_windows = len(window_starts) + my_s = (total_windows * rank) // world_size + my_e = (total_windows * (rank + 1)) // world_size + my_windows = window_starts[my_s:my_e] + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + # N-gram cache setup (legal: score-first, backward-looking) + use_ngram = args.ngram_cache + if use_ngram: + val_np = val_tokens.cpu().numpy() + _n_orders = args.ngram_order - args.ngram_min_order + 1 + ctx_tables = [np.zeros((args.ngram_buckets,), dtype=np.uint32) for _ in range(_n_orders)] + full_tables = [np.zeros((args.ngram_buckets,), dtype=np.uint32) for _ in range(_n_orders)] + ng_mask = np.uint64(args.ngram_buckets - 1) + ng_primes = np.array( + [np.uint64(36313), np.uint64(27191), np.uint64(51647), np.uint64(81929), + np.uint64(131071), np.uint64(175447), np.uint64(209591)], dtype=np.uint64) + print(f"ngram_cache:enabled orders={args.ngram_min_order}-{args.ngram_order} " + f"entropy={args.ngram_entropy} alpha={args.ngram_alpha} buckets={args.ngram_buckets}", flush=True) + base_model.eval() + compiled_logits = torch.compile(base_model.forward_logits, dynamic=False, fullgraph=True) + with torch.inference_mode(): + for bi in range(0, len(my_windows), batch_seqs): + batch_ws = my_windows[bi:bi + batch_seqs] + bsz = len(batch_ws) + x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + wlens: list[int] = [] + for i, ws in enumerate(batch_ws): + end = min(ws + seq_len, total_tokens) + wlen = end - ws + wlens.append(wlen) + chunk = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) + x_batch[i, :wlen] = chunk[:-1] + y_batch[i, :wlen] = chunk[1:] + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + logits = compiled_logits(x_batch) + nll = F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), + y_batch.reshape(-1), + reduction="none", + ).reshape(bsz, seq_len) + for i, ws in enumerate(batch_ws): + wlen = wlens[i] + s = 0 if ws == 0 else max(wlen - stride, 0) + seg_len = wlen - s + if seg_len <= 0: + continue + scored_nll = nll[i, s:wlen].to(torch.float64) + # N-gram 
cache: mix model predictions with n-gram statistics + if use_ngram: + seg_nll_np = scored_nll.cpu().numpy() + seg_model_p = np.exp(-seg_nll_np) + n_seg = len(seg_nll_np) + global_j = np.arange(ws + s + 1, ws + wlen + 1, dtype=np.int64) + # Entropy-adaptive alpha + if args.ngram_entropy: + with torch.no_grad(): + lp = F.log_softmax(logits[i, s:wlen].float(), dim=-1) + seg_ent = -(lp.exp() * lp).sum(dim=-1).cpu().numpy() + alpha_per_tok = args.ngram_ent_base + args.ngram_ent_range / ( + 1.0 + np.exp(-args.ngram_ent_scale * (seg_ent - args.ngram_ent_thresh))) + # Precompute hashes for all orders + order_data = [] + for oi in range(_n_orders): + ctx_w = args.ngram_min_order + oi - 1 + valid = global_j >= ctx_w + if not valid.any(): + order_data.append(None); continue + v_idx = np.nonzero(valid)[0] + jv = global_j[v_idx] + ctx_hash = np.zeros(len(jv), dtype=np.uint64) + for k in range(ctx_w): + tok = val_np[jv - (ctx_w - k)].astype(np.uint64) + ctx_hash ^= tok * ng_primes[k % len(ng_primes)] + ctx_key = (ctx_hash & ng_mask).astype(np.int64) + tgt_np = val_np[jv].astype(np.uint64) + full_key = ((ctx_hash ^ (tgt_np * ng_primes[ctx_w % len(ng_primes)])) & ng_mask).astype(np.int64) + order_data.append((v_idx, ctx_key, full_key)) + # Multi-order backoff: highest order first + best_p_ng = np.full(n_seg, -1.0) + for oi in range(_n_orders - 1, -1, -1): + if order_data[oi] is None: continue + v_idx, ctx_key, full_key = order_data[oi] + ctx_counts = ctx_tables[oi][ctx_key].astype(np.float64) + full_counts = full_tables[oi][full_key].astype(np.float64) + has_match = ctx_counts >= float(args.ngram_min_count) + needs_fill = has_match & (best_p_ng[v_idx] < 0) + if needs_fill.any(): + fill_idx = v_idx[needs_fill] + p = np.minimum(full_counts[needs_fill], ctx_counts[needs_fill]) / np.maximum(ctx_counts[needs_fill], 1.0) + best_p_ng[fill_idx] = np.clip(p, 0.0, 1.0) + # Mix model probability with n-gram + has_match = best_p_ng >= 0 + if has_match.any(): + alpha = alpha_per_tok[has_match] if args.ngram_entropy else args.ngram_alpha + seg_model_p[has_match] = (1.0 - alpha) * seg_model_p[has_match] + alpha * best_p_ng[has_match] + scored_nll = torch.from_numpy(-np.log(np.clip(seg_model_p, 1e-12, 1.0))).to(dtype=torch.float64, device=device) + # Score-first: update tables AFTER scoring + for oi in range(_n_orders): + if order_data[oi] is None: continue + v_idx, ctx_key, full_key = order_data[oi] + np.add.at(ctx_tables[oi], ctx_key, 1) + np.add.at(full_tables[oi], full_key, 1) + loss_sum += scored_nll.sum() + token_count += float(seg_len) + tgt = y_batch[i, s:wlen] + prev = x_batch[i, s:wlen] + tb = base_bytes_lut[tgt].to(torch.float64) + tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64) + byte_count += tb.sum() + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) + val_loss = (loss_sum / token_count).item() + bits_per_token = val_loss / math.log(2.0) + tokens_per_byte = token_count.item() / byte_count.item() + base_model.train() + return val_loss, bits_per_token * tokens_per_byte +def _classify_param(name: str) -> str: + if "tok_emb" in name or "lm_head" in name: + return "embed" + if ".mlp." in name: + return "mlp" + if ".attn." in name or (".proj." in name and ".mlp." 
not in name): + return "attn" + return "other" +def quantize_int6_per_row(t: Tensor, clip_range: int = 31) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + best_q, best_s, best_err = None, None, float('inf') + for pct in [0.9990, 0.9995, 0.9999, 0.99999, 1.0]: + if pct < 1.0: + row_clip = torch.quantile(t32.abs(), pct, dim=1) + else: + row_clip = t32.abs().amax(dim=1) + s = (row_clip / clip_range).clamp_min(1.0 / clip_range).to(torch.float16) + q = torch.clamp(torch.round(t32 / s.float()[:, None]), -clip_range, clip_range).to(torch.int8) + recon = q.float() * s.float()[:, None] + err = (t32 - recon).pow(2).mean().item() + if err < best_err: + best_q, best_s, best_err = q, s, err + return best_q, best_s + amax = t32.abs().max().item() + scale = torch.tensor(amax / clip_range if amax > 0 else 1.0, dtype=torch.float16) + q = torch.clamp(torch.round(t32 / scale.float()), -clip_range, clip_range).to(torch.int8) + return q, scale +def mixed_quantize_int6(state_dict: dict[str, Tensor], int6_cats: set[str]): + num_layers_total = max( + (int(k.split(".")[1]) for k in state_dict if k.startswith("blocks.")), + default=0, + ) + 1 + late_k_layers = set(range(num_layers_total - 2, num_layers_total)) + result: dict[str, Tensor] = {} + meta: dict[str, object] = {} + for name, tensor in state_dict.items(): + t = tensor.detach().cpu().contiguous() + cat = _classify_param(name) + if not t.is_floating_point() or t.numel() <= 65536: + result[name] = t.to(torch.float16) if t.is_floating_point() else t + meta[name] = "passthrough" + continue + if any(p in name for p in CONTROL_TENSOR_NAME_PATTERNS): + result[name] = t.float() + meta[name] = "passthrough_ctrl" + continue + if cat in int6_cats and t.ndim >= 1: + # Int5 for MLP, int6 for attention — saves ~1.5MB + clip = 15 if cat == "mlp" else 31 + q, s = quantize_int6_per_row(t, clip_range=clip) + result[name + ".q"] = q + result[name + ".scale"] = s + meta[name] = {"type": f"int{5 if cat == 'mlp' else 6}"} + else: + q, s = quantize_float_tensor(t) + result[name + ".q"] = q + result[name + ".scale"] = s + meta[name] = {"type": "int8"} + return result, meta +def dequantize_mixed_int6(result: dict[str, Tensor], meta: dict[str, object], + template_sd: dict[str, Tensor]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + for name, orig in template_sd.items(): + info = meta.get(name) + if info is None: + continue + orig_dtype = orig.dtype + if info in ("passthrough", "passthrough_ctrl", "passthrough_fp16"): + t = result[name] + if t.dtype == torch.float16 and orig_dtype in (torch.float32, torch.bfloat16): + t = t.to(orig_dtype) + out[name] = t + continue + q, s = result[name + ".q"], result[name + ".scale"] + if s.ndim > 0: + out[name] = (q.float() * s.float().view(q.shape[0], *([1] * (q.ndim - 1)))).to(orig_dtype) + else: + out[name] = (q.float() * float(s.item())).to(orig_dtype) + return out + +# TTT V2 — Score-every-epoch with per-layer LR, LM rank-16, bias tuning, T=0.98 +# Based on techniques from PR #596 (0.6430 BPB) and PR #573 (1.0523 BPB) +BOS_ID = 1 + +class BatchedLinearLoRA(nn.Module): + def __init__(self, bsz, in_features, out_features, rank): + super().__init__() + self.in_features = in_features + self.A = nn.Parameter(torch.empty(bsz, rank, in_features)) + self.B = nn.Parameter(torch.zeros(bsz, out_features, rank)) + self.reset() + def forward(self, x): + return (x @ self.A.transpose(1, 2)) @ self.B.transpose(1, 2) + def reset(self): + bound = 1.0 / math.sqrt(self.in_features) + with torch.no_grad(): + self.A.uniform_(-bound, bound) 
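+            # Sketch of the eval-time n-gram interpolation implemented in
+            # eval_val_sliding above, with exact tuple counts standing in for
+            # the hashed ctx_tables/full_tables (illustrative only, never
+            # called; constants taken from the Hyperparameters defaults).
+            def _ngram_mix_sketch(prev_tokens, target, model_p, model_entropy,
+                                  ctx_counts, full_counts,
+                                  min_order=2, max_order=7, min_count=2):
+                # Multi-order backoff: use the highest order whose context has
+                # been seen at least `min_count` times among already-scored
+                # tokens (score-first, backward-looking).
+                p_ng = None
+                for order in range(max_order, min_order - 1, -1):
+                    ctx = tuple(prev_tokens[-(order - 1):])
+                    c_ctx = ctx_counts.get((order, ctx), 0)
+                    if c_ctx >= min_count:
+                        c_full = full_counts.get((order, ctx, target), 0)
+                        p_ng = min(c_full, c_ctx) / c_ctx
+                        break
+                if p_ng is None:
+                    return model_p  # no n-gram evidence: pure LM probability
+                # Entropy-adaptive alpha = 0.05 + 0.55 * sigmoid(2 * (H - 4)),
+                # H in nats: a confident LM keeps alpha near 0.05, an
+                # uncertain one pushes it toward 0.60.
+                alpha = 0.05 + 0.55 / (1.0 + math.exp(-2.0 * (model_entropy - 4.0)))
+                return (1.0 - alpha) * model_p + alpha * p_ng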
+ self.B.zero_() + +class BatchedTTTLoRA(nn.Module): + """V3: Multi-scale LoRA with K-projection + per-layer LR + bias tuning.""" + def __init__(self, bsz, model, rank, lm_rank=16, tune_biases=False, k_lora=True): + super().__init__() + dim = model.tok_emb.embedding_dim + vocab = model.tok_emb.num_embeddings + self.lm_head_lora = BatchedLinearLoRA(bsz, dim, vocab, lm_rank) + self.q_loras = nn.ModuleList() + self.k_loras = nn.ModuleList() + self.v_loras = nn.ModuleList() + self.k_lora_enabled = k_lora + for block in model.blocks: + self.q_loras.append(BatchedLinearLoRA(bsz, dim, block.attn.c_q.weight.shape[0], rank)) + self.v_loras.append(BatchedLinearLoRA(bsz, dim, block.attn.c_v.weight.shape[0], rank)) + if k_lora: + self.k_loras.append(BatchedLinearLoRA(bsz, dim, block.attn.c_k.weight.shape[0], rank)) + self.bias_params = nn.ParameterList() + if tune_biases: + for block in model.blocks: + self.bias_params.append(nn.Parameter(torch.zeros(bsz, 1, dim))) + self.bias_params.append(nn.Parameter(torch.zeros(bsz, 1, dim))) + def reset(self): + for m in self.modules(): + if isinstance(m, BatchedLinearLoRA): + m.reset() + for p in self.bias_params: + p.data.zero_() + +def _build_ttt_optimizer(lora, base_lr): + groups = [ + {"params": list(lora.lm_head_lora.parameters()), "lr": base_lr * 2.0, "base_lr": base_lr * 2.0}, + {"params": [p for m in lora.v_loras for p in m.parameters()], "lr": base_lr * 1.5, "base_lr": base_lr * 1.5}, + {"params": [p for m in lora.q_loras for p in m.parameters()], "lr": base_lr * 0.5, "base_lr": base_lr * 0.5}, + ] + if lora.k_lora_enabled and lora.k_loras: + groups.append({"params": [p for m in lora.k_loras for p in m.parameters()], "lr": base_lr * 0.3, "base_lr": base_lr * 0.3}) + if lora.bias_params: + groups.append({"params": list(lora.bias_params), "lr": base_lr * 3.0, "base_lr": base_lr * 3.0}) + return torch.optim.Adam(groups, lr=base_lr, betas=(0.9, 0.95), eps=1e-10) + +def _reset_ttt_opt(o): + for g in o.param_groups: + for p in g['params']: + s = o.state.get(p) + if s: + s['exp_avg'].zero_(); s['exp_avg_sq'].zero_(); s['step'].fill_(0) + +def _find_docs(all_tokens): + bos_positions = (all_tokens == BOS_ID).nonzero(as_tuple=True)[0].numpy() + docs = [] + for i in range(len(bos_positions)): + start = int(bos_positions[i]) + end = int(bos_positions[i + 1]) if i + 1 < len(bos_positions) else all_tokens.numel() + if i + 1 < len(bos_positions): + end += 1 + if end - start >= 2: + docs.append((start, end - start)) + return docs + +def _compute_chunk_window(ci, pred_len, num_chunks, chunk_size, eval_seq_len): + chunk_start = ci * chunk_size + chunk_end = pred_len if ci == num_chunks - 1 else (ci + 1) * chunk_size + win_start = max(0, chunk_end - eval_seq_len) + win_len = chunk_end - win_start + chunk_offset = chunk_start - win_start + chunk_len = chunk_end - chunk_start + return win_start, win_len, chunk_offset, chunk_len + +def eval_val_ttt_lora(args, base_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut): + """V2 TTT: score-every-epoch, per-layer LR, LM rank-16, bias tuning, T rescale.""" + t_start = time.perf_counter() + docs = _find_docs(val_tokens) + rank_docs = docs[(len(docs) * rank) // world_size : (len(docs) * (rank + 1)) // world_size] + chunk_size, eval_seq_len = args.ttt_chunk_size, args.ttt_eval_seq_len + batch_size = args.ttt_batch_size + lora_rank, lm_rank = args.ttt_lora_rank, int(os.environ.get("TTT_LM_RANK", 16)) + num_epochs = args.ttt_epochs + temp_rescale = 
float(os.environ.get("TTT_TEMP_RESCALE", 0.98)) + tune_biases = bool(int(os.environ.get("TTT_BIAS_TUNE", "1"))) + k_lora = args.ttt_k_lora + min_nll = args.ttt_min_nll + max_eval_secs = float(os.environ.get("TTT_MAX_EVAL_SECS", 550.0)) + rank_docs.sort(key=lambda d: (d[1] - 2) // chunk_size) + base_model.eval() + for p in base_model.parameters(): + p.requires_grad_(False) + lora = BatchedTTTLoRA(batch_size, base_model, lora_rank, lm_rank=lm_rank, + tune_biases=tune_biases, k_lora=k_lora).to(device) + opt = _build_ttt_optimizer(lora, args.ttt_lora_lr) + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + byte_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + total_docs = len(rank_docs) + num_trajectories = args.ttt_trajectories + for bi in range(0, total_docs, batch_size): + if time.perf_counter() - t_start > max_eval_secs: + print(f" ttt: time limit hit at doc {bi}/{total_docs}, falling back", flush=True) + break + batch = rank_docs[bi:bi + batch_size] + bsz = len(batch) + pred_lens = [dl - 1 for _, dl in batch] + num_chunks = [(pl + chunk_size - 1) // chunk_size for pl in pred_lens] + max_nc = max(num_chunks) + total_train_steps = sum(max(0, nc - 1) for nc in num_chunks) * num_epochs + # Multi-trajectory: best scores across all trajectories + traj_best_loss = [torch.tensor(float('inf'), device=device, dtype=torch.float64) for _ in range(bsz)] + traj_best_bytes = [torch.zeros((), device=device, dtype=torch.float64) for _ in range(bsz)] + traj_best_toks = [0] * bsz + traj_best_avg = [float('inf')] * bsz + for traj_idx in range(num_trajectories): + if bsz == batch_size: + cur_lora, cur_opt = lora, opt + cur_lora.reset(); _reset_ttt_opt(cur_opt) + else: + cur_lora = BatchedTTTLoRA(bsz, base_model, lora_rank, lm_rank=lm_rank, + tune_biases=tune_biases, k_lora=k_lora).to(device) + cur_opt = _build_ttt_optimizer(cur_lora, args.ttt_lora_lr) + global_step = 0 + # Per-doc accumulators — overwritten each epoch (score-every-epoch) + doc_loss = [torch.zeros((), device=device, dtype=torch.float64) for _ in range(bsz)] + doc_bytes = [torch.zeros((), device=device, dtype=torch.float64) for _ in range(bsz)] + doc_toks = [0] * bsz + # Min-NLL: track best epoch per doc + best_doc_loss = [torch.tensor(float('inf'), device=device, dtype=torch.float64) for _ in range(bsz)] + best_doc_bytes = [torch.zeros((), device=device, dtype=torch.float64) for _ in range(bsz)] + best_doc_toks = [0] * bsz + for epoch in range(num_epochs): + for b in range(bsz): + doc_loss[b].zero_(); doc_bytes[b].zero_(); doc_toks[b] = 0 + for ci in range(max_nc): + active = [ci < nc for nc in num_chunks] + needs_train = any(ci < nc - 1 for nc in num_chunks) + chunk_stats = _compute_chunk_window(ci, (ci + 1) * chunk_size, ci + 1, chunk_size, eval_seq_len) + context_size = chunk_stats[1] + x = torch.zeros(bsz, context_size, dtype=torch.int64, device=device) + y = torch.zeros(bsz, context_size, dtype=torch.int64, device=device) + doc_info = [] + for b in range(bsz): + if not active[b]: + doc_info.append((0, 0)); continue + ds, dl = batch[b] + ws, wl, co, cl = _compute_chunk_window(ci, pred_lens[b], num_chunks[b], chunk_size, eval_seq_len) + toks = val_tokens[ds + ws: ds + ws + wl + 1].to(dtype=torch.int64, device=device) + x[b, :wl] = toks[:-1]; y[b, :wl] = toks[1:] + doc_info.append((co, cl)) + # Cosine LR schedule across total train steps + if needs_train and total_train_steps > 1: + cos_mul = 0.5 * (1.0 + math.cos(math.pi * global_step / 
total_train_steps)) + for g in cur_opt.param_groups: + g["lr"] = g.get("base_lr", g["lr"]) * max(cos_mul, 0.1) + if needs_train: + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + ptl = base_model(x, y, lora=cur_lora) + else: + with torch.no_grad(), torch.autocast(device_type="cuda", dtype=torch.bfloat16): + ptl = base_model(x, y, lora=cur_lora) + # Score — accumulate into per-doc buffers (overwritten each epoch) + with torch.no_grad(): + for b in range(bsz): + if not active[b]: continue + co, cl = doc_info[b] + chunk_loss = ptl[b, co:co + cl].to(torch.float64) + if temp_rescale != 1.0: + chunk_loss = chunk_loss * temp_rescale + doc_loss[b] += chunk_loss.sum() + doc_toks[b] += cl + tgt = y[b, co:co + cl]; px = x[b, co:co + cl] + tb = base_bytes_lut[tgt].to(torch.float64) + tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[px]).to(torch.float64) + doc_bytes[b] += tb.sum() + # Train on this chunk — multiple steps per chunk (LEGAL) + if needs_train: + steps_per_chunk = args.ttt_steps_per_chunk + for train_step in range(steps_per_chunk): + # Cosine LR decay across total steps + if total_train_steps > 1: + cos_mul = 0.5 * (1.0 + math.cos(math.pi * global_step / (total_train_steps * steps_per_chunk))) + for g in cur_opt.param_groups: + g["lr"] = g.get("base_lr", g["lr"]) * max(cos_mul, 0.05) + if train_step == 0: + # First step: use the same forward pass we scored with + train_loss = torch.zeros(bsz, device=device) + for b in range(bsz): + if ci >= num_chunks[b] - 1: continue + co, cl = doc_info[b] + if cl > 0: train_loss[b] = ptl[b, co:co + cl].mean() + cur_opt.zero_grad() + train_loss.sum().backward() + cur_opt.step() + else: + # Additional steps: new forward pass (no scoring, just training) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + ptl2 = base_model(x, y, lora=cur_lora) + train_loss2 = torch.zeros(bsz, device=device) + for b in range(bsz): + if ci >= num_chunks[b] - 1: continue + co, cl = doc_info[b] + if cl > 0: train_loss2[b] = ptl2[b, co:co + cl].mean() + cur_opt.zero_grad() + train_loss2.sum().backward() + cur_opt.step() + global_step += 1 + # Min-NLL: after each epoch, check if this epoch is best for each doc + if min_nll: + for b in range(bsz): + if doc_toks[b] > 0: + avg_nll = doc_loss[b] / doc_toks[b] + if avg_nll < best_doc_loss[b] / max(best_doc_toks[b], 1): + best_doc_loss[b] = doc_loss[b].clone() + best_doc_bytes[b] = doc_bytes[b].clone() + best_doc_toks[b] = doc_toks[b] + # After all epochs in this trajectory: get best epoch result per doc + for b in range(bsz): + if min_nll and best_doc_toks[b] > 0: + ep_loss, ep_bytes, ep_toks = best_doc_loss[b], best_doc_bytes[b], best_doc_toks[b] + else: + ep_loss, ep_bytes, ep_toks = doc_loss[b], doc_bytes[b], doc_toks[b] + # Multi-trajectory: keep best across trajectories + if ep_toks > 0: + avg = ep_loss.item() / ep_toks + if avg < traj_best_avg[b]: + traj_best_avg[b] = avg + traj_best_loss[b] = ep_loss.clone() + traj_best_bytes[b] = ep_bytes.clone() + traj_best_toks[b] = ep_toks + # Add best trajectory's scores to global accumulators + for b in range(bsz): + if traj_best_toks[b] > 0: + loss_sum += traj_best_loss[b]; byte_sum += traj_best_bytes[b]; token_count += traj_best_toks[b] + else: + loss_sum += doc_loss[b]; byte_sum += doc_bytes[b]; token_count += doc_toks[b] + if rank == 0 and bi % (batch_size * 5) == 0: + pct = 100.0 * bi / total_docs + bpb = float((loss_sum.item() / math.log(2.0)) / max(byte_sum.item(), 1)) + elapsed = time.perf_counter() - t_start + print(f" ttt [{pct:5.1f}%] 
{bi}/{total_docs} docs bpb={bpb:.6f} time={elapsed:.0f}s", flush=True)
+    else:
+        # Loop completed without hitting the time limit, so every doc is already scored.
+        bi = total_docs
+    # Fallback: score remaining docs without TTT (base model only). The time-limit
+    # check fires before batch `bi` is scored, so the fallback must start at `bi`;
+    # the for/else above keeps this range empty after a normal finish.
+    for bi2 in range(bi, total_docs):
+        ds, dl = rank_docs[bi2]
+        pred_len = dl - 1
+        toks = val_tokens[ds:ds + dl].to(dtype=torch.int64, device=device)
+        with torch.no_grad(), torch.autocast(device_type="cuda", dtype=torch.bfloat16):
+            ptl_fb = base_model(toks[:-1].unsqueeze(0), toks[1:].unsqueeze(0))
+        lbl = ptl_fb[0].to(torch.float64)
+        tgt = toks[1:]; px = toks[:-1]
+        tb = base_bytes_lut[tgt].to(torch.float64)
+        tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[px]).to(torch.float64)
+        loss_sum += lbl.sum(); byte_sum += tb.sum(); token_count += pred_len
+    if dist.is_available() and dist.is_initialized():
+        dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM)
+        dist.all_reduce(byte_sum, op=dist.ReduceOp.SUM)
+        dist.all_reduce(token_count, op=dist.ReduceOp.SUM)
+    return float(loss_sum.item() / token_count.item()), float((loss_sum.item() / math.log(2.0)) / byte_sum.item())
+
+def main() -> None:
+    global zeropower_via_newtonschulz5
+    code = Path(__file__).read_text(encoding="utf-8")
+    args = Hyperparameters()
+    zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5)
+    distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ
+    rank = int(os.environ.get("RANK", "0"))
+    world_size = int(os.environ.get("WORLD_SIZE", "1"))
+    local_rank = int(os.environ.get("LOCAL_RANK", "0"))
+    if world_size <= 0:
+        raise ValueError(f"WORLD_SIZE must be positive, got {world_size}")
+    if 8 % world_size != 0:
+        raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral")
+    grad_accum_steps = 8 // world_size
+    grad_scale = 1.0 / grad_accum_steps
+    if not torch.cuda.is_available():
+        raise RuntimeError("CUDA is required")
+    device = torch.device("cuda", local_rank)
+    torch.cuda.set_device(device)
+    if distributed:
+        dist.init_process_group(backend="nccl", device_id=device)
+        dist.barrier()
+    master_process = rank == 0
+    torch.backends.cuda.matmul.allow_tf32 = True
+    torch.backends.cudnn.allow_tf32 = True
+    from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp
+    enable_cudnn_sdp(False)
+    enable_flash_sdp(True)
+    enable_mem_efficient_sdp(False)
+    enable_math_sdp(False)
+    logfile = None
+    if master_process:
+        os.makedirs("logs", exist_ok=True)
+        logfile = f"logs/{args.run_id}.txt"
+        print(logfile)
+    def log0(msg: str, console: bool = True) -> None:
+        if not master_process:
+            return
+        if console:
+            print(msg)
+        if logfile is not None:
+            with open(logfile, "a", encoding="utf-8") as f:
+                print(msg, file=f)
+    log0(code, console=False)
+    log0("=" * 100, console=False)
+    log0(f"Running Python {sys.version}", console=False)
+    log0(f"Running PyTorch {torch.__version__}", console=False)
+    log0(
+        subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout,
+        console=False,
+    )
+    log0("=" * 100, console=False)
+    random.seed(args.seed)
+    np.random.seed(args.seed)
+    torch.manual_seed(args.seed)
+    torch.cuda.manual_seed_all(args.seed)
+    if not args.tokenizer_path.endswith(".model"):
+        raise ValueError(f"Script is only set up for SentencePiece .model files: {args.tokenizer_path}")
+    sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path)
+    if int(sp.vocab_size()) != args.vocab_size:
+        raise ValueError(
+            f"VOCAB_SIZE={args.vocab_size} does not match tokenizer 
vocab_size={int(sp.vocab_size())}" + ) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + effective_eval_seq_len = args.eval_seq_len if args.eval_seq_len > 0 else args.train_seq_len + val_seq_len = max(args.train_seq_len, effective_eval_seq_len) + val_tokens = load_validation_tokens(args.val_files, val_seq_len) + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + CastedLinear._qat_enabled = args.qat_enabled + base_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, + mtp_num_heads=args.mtp_num_heads, + mtp_loss_weight=args.mtp_loss_weight, + bigram_vocab_size=args.bigram_vocab_size, + bigram_dim=args.bigram_dim, + xsa_last_n=args.xsa_last_n, + rope_dims=args.rope_dims, + ln_scale=args.ln_scale, + dtg=args.dtg_enabled, + ve_enabled=args.ve_enabled, + ve_dim=args.ve_dim, + ve_layers=args.ve_layers, + ).to(device).bfloat16() + for module in base_model.modules(): + if isinstance(module, CastedLinear): + module.float() + restore_low_dim_params_to_fp32(base_model) + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) + model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + block_named_params = list(base_model.blocks.named_parameters()) + matrix_params = [ + p + for name, p in block_named_params + if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.mtp_num_heads > 0: + matrix_params.extend([p for p in base_model.mtp_heads.parameters() if p.ndim == 2]) + scalar_params = [ + p + for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) + scalar_params.append(base_model.smear.gate) + if base_model.bigram is not None: + scalar_params.append(base_model.bigram.scale) + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + tok_params = [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}] + if base_model.bigram is not None: + tok_params.append({"params": [base_model.bigram.embed.weight], "lr": token_lr, "base_lr": token_lr}) + if base_model.bigram.proj is not None: + matrix_params.append(base_model.bigram.proj.weight) + if base_model.ve_shared is not None: + tok_params.append({"params": [base_model.ve_shared.embed.weight], "lr": token_lr, "base_lr": token_lr}) + if base_model.ve_shared.proj is not None: + matrix_params.append(base_model.ve_shared.proj.weight) + scalar_params.append(base_model.ve_shared.scale) + for s in base_model.ve_layer_scales: + scalar_params.append(s) + optimizer_tok = torch.optim.AdamW( + tok_params, + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + weight_decay=args.adam_wd, + fused=True, + ) + optimizer_muon = 
Muon( + matrix_params, + lr=args.matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + weight_decay=args.muon_wd, + ) + for group in optimizer_muon.param_groups: + group["base_lr"] = args.matrix_lr + optimizer_scalar = torch.optim.AdamW( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + weight_decay=args.adam_wd, + fused=True, + ) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers.insert(1, optimizer_head) + n_params = sum(p.numel() for p in base_model.parameters()) + mtp_params = sum(p.numel() for p in base_model.mtp_heads.parameters()) + log0(f"model_params:{n_params}") + log0(f"mtp_num_heads:{args.mtp_num_heads} mtp_loss_weight:{args.mtp_loss_weight} mtp_params:{mtp_params}") + xsa_layers = [i for i, b in enumerate(base_model.blocks) if b.attn.use_xsa] + log0(f"XSA:last_{args.xsa_last_n} active_layers:{xsa_layers}") + log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") + log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") + log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") + log0( + f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " + f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " + f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" + ) + log0( + f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " + f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " + f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log0(f"seed:{args.seed}") + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + def lr_mul(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + if args.warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(args.warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + for opt in optimizers: + opt.step() + zero_grad_all() + if args.warmup_steps <= 20 or (warmup_step + 1) 
% 10 == 0 or warmup_step + 1 == args.warmup_steps: + log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + if distributed: + model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + swa_state: dict[str, Tensor] | None = None + swa_count = 0 + ema_state = {name: t.detach().float().clone() for name, t in base_model.state_dict().items()} + ema_decay = 0.997 + training_time_ms = 0.0 + stop_after_step: int | None = None + torch.cuda.synchronize() + t0 = time.perf_counter() + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + log0( + f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + if last_step: + if stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" + ) + break + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + if args.late_qat_threshold > 0 and scale < args.late_qat_threshold and not CastedLinear._qat_enabled: + CastedLinear._qat_enabled = True + log0(f"late_qat:enabled step:{step} scale:{scale:.4f}") + zero_grad_all() + train_loss = torch.zeros((), device=device) + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + loss = model(x, y) + train_loss += loss.detach() + (loss * grad_scale).backward() + train_loss /= grad_accum_steps + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + for opt in optimizers: + opt.step() + zero_grad_all() + # EMA update + with torch.no_grad(): + for name, t in base_model.state_dict().items(): + ema_state[name].mul_(ema_decay).add_(t.detach().float(), alpha=1.0 - ema_decay) + step += 1 + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + if args.swa_enabled and scale < 0.2 and step % args.swa_every == 0: + if swa_state is None: + swa_state = {name: t.detach().cpu().clone() for name, t in base_model.state_dict().items()} + swa_count = 1 + log0(f"swa:start step:{step}") + else: + for name, t in 
base_model.state_dict().items(): + swa_state[name] += t.detach().cpu() + swa_count += 1 + should_log_train = ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + log0( + f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" + ) + reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + log0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + ) + # Apply EMA weights (better than SWA alone per PR#401) + log0("ema:applying EMA weights") + current_state = base_model.state_dict() + avg_state = {name: t.to(dtype=current_state[name].dtype) for name, t in ema_state.items()} + base_model.load_state_dict(avg_state, strict=True) + torch.cuda.synchronize() + t_diag = time.perf_counter() + diag_val_loss, diag_val_bpb = eval_val( + args, compiled_model, rank, world_size, device, grad_accum_steps, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"DIAGNOSTIC post_ema val_loss:{diag_val_loss:.4f} val_bpb:{diag_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_diag):.0f}ms" + ) + full_state_dict = base_model.state_dict() + export_sd = {k: v for k, v in full_state_dict.items() if "mtp_heads" not in k} + excluded_mtp = sum(int(t.numel()) for k, t in full_state_dict.items() if "mtp_heads" in k) + if excluded_mtp > 0: + log0(f"export_excluding_mtp_params:{excluded_mtp}") + if master_process: + torch.save(export_sd, "final_model.pt") + model_bytes = os.path.getsize("final_model.pt") + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model: {model_bytes} bytes") + log0(f"Code size: {code_bytes} bytes") + sd_cpu = {k: v.detach().cpu() for k, v in export_sd.items()} + # Magnitude pruning: zero out smallest 3% of weights for better compression + prune_pct = float(os.environ.get("PRUNE_PCT", 0.03)) + if prune_pct > 0: + with torch.no_grad(): + for name, param in sd_cpu.items(): + if param.ndim == 2 and param.numel() > 65536: + threshold = torch.quantile(param.abs().float().flatten(), prune_pct) + mask = param.abs() < threshold + param.masked_fill_(mask, 0.0) + if master_process: + log0(f"magnitude_pruning: {prune_pct*100:.0f}% threshold applied") + quant_result, quant_meta = mixed_quantize_int6(sd_cpu, {"mlp", "attn"}) + quant_buf = io.BytesIO() + torch.save({"w": quant_result, "m": quant_meta}, quant_buf) + quant_raw = quant_buf.getvalue() + quant_blob = zstandard.ZstdCompressor(level=22).compress(quant_raw) if _COMPRESSOR == "zstd" else zlib.compress(quant_raw, 9) + if master_process: + with open("final_model.int6.ptz", "wb") as f: + f.write(quant_blob) + quant_file_bytes = len(quant_blob) + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model int6+{_COMPRESSOR}: {quant_file_bytes} bytes") + log0(f"Total submission size int6+{_COMPRESSOR}: {quant_file_bytes + code_bytes} bytes") + log0(f"Total submission size int8+zlib: {quant_file_bytes + code_bytes} bytes") + if distributed: + 
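# rank 0 just wrote final_model.int6.ptz; barrier so every rank can reload it below
+        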
dist.barrier()
+    with open("final_model.int6.ptz", "rb") as f:
+        quant_blob_disk = f.read()
+    quant_state = torch.load(
+        io.BytesIO(zstandard.ZstdDecompressor().decompress(quant_blob_disk) if _COMPRESSOR == "zstd" else zlib.decompress(quant_blob_disk)),
+        map_location="cpu",
+    )
+    deq_state = dequantize_mixed_int6(quant_state["w"], quant_state["m"], sd_cpu)
+    eval_model = GPT(
+        vocab_size=args.vocab_size, num_layers=args.num_layers, model_dim=args.model_dim,
+        num_heads=args.num_heads, num_kv_heads=args.num_kv_heads, mlp_mult=args.mlp_mult,
+        tie_embeddings=args.tie_embeddings, tied_embed_init_std=args.tied_embed_init_std,
+        logit_softcap=args.logit_softcap, rope_base=args.rope_base, qk_gain_init=args.qk_gain_init,
+        mtp_num_heads=0, mtp_loss_weight=0.0,
+        bigram_vocab_size=args.bigram_vocab_size, bigram_dim=args.bigram_dim,
+        xsa_last_n=args.xsa_last_n,  # must match training model
+        rope_dims=args.rope_dims, ln_scale=args.ln_scale, dtg=args.dtg_enabled,
+        ve_enabled=args.ve_enabled, ve_dim=args.ve_dim, ve_layers=args.ve_layers,
+    ).to(device).bfloat16()
+    for m in eval_model.modules():
+        if isinstance(m, CastedLinear):
+            m.float()
+    restore_low_dim_params_to_fp32(eval_model)
+    eval_model.load_state_dict(deq_state, strict=True)
+    compiled_eval = torch.compile(eval_model, dynamic=False, fullgraph=True)
+    torch.cuda.synchronize()
+    t_qeval = time.perf_counter()
+    q_val_loss, q_val_bpb = eval_val(
+        args, compiled_eval, rank, world_size, device, grad_accum_steps,
+        val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut,
+        eval_seq_len=effective_eval_seq_len,
+    )
+    torch.cuda.synchronize()
+    log0(
+        f"final_int6_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} "
+        f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms"
+    )
+    log0(f"final_int6_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}")
+    sw_seq_len = effective_eval_seq_len
+    if 0 < args.eval_stride < sw_seq_len:
+        torch.cuda.synchronize()
+        t_slide = time.perf_counter()
+        sw_val_loss, sw_val_bpb = eval_val_sliding(
+            args, eval_model, rank, world_size, device,
+            val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut,
+            stride=args.eval_stride,
+            eval_seq_len=sw_seq_len,
+        )
+        torch.cuda.synchronize()
+        log0(
+            f"final_int6_sliding_window val_loss:{sw_val_loss:.4f} val_bpb:{sw_val_bpb:.4f} "
+            f"stride:{args.eval_stride} eval_time:{1000.0 * (time.perf_counter() - t_slide):.0f}ms"
+        )
+        log0(f"final_int6_sliding_window_exact val_loss:{sw_val_loss:.8f} val_bpb:{sw_val_bpb:.8f}")
+        # Re-logged under the int8_zlib_roundtrip_exact label: no separate int8+zlib
+        # artifact is built, these are the same sliding-window numbers as above.
+        log0(f"final_int8_zlib_roundtrip_exact val_loss:{sw_val_loss:.8f} val_bpb:{sw_val_bpb:.8f}")
+        if args.eval_stride != 64 and 64 < sw_seq_len:
+            torch.cuda.synchronize()
+            t_slide64 = time.perf_counter()
+            sw64_val_loss, sw64_val_bpb = eval_val_sliding(
+                args, eval_model, rank, world_size, device,
+                val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut,
+                stride=64,
+                eval_seq_len=sw_seq_len,
+            )
+            torch.cuda.synchronize()
+            log0(
+                f"final_int6_sliding_window_s64 val_loss:{sw64_val_loss:.4f} val_bpb:{sw64_val_bpb:.4f} "
+                f"stride:64 eval_time:{1000.0 * (time.perf_counter() - t_slide64):.0f}ms"
+            )
+            log0(f"final_int6_sliding_window_s64_exact val_loss:{sw64_val_loss:.8f} val_bpb:{sw64_val_bpb:.8f}")
+            # Same label logged again, now carrying the stride-64 numbers.
+            log0(f"final_int8_zlib_roundtrip_exact val_loss:{sw64_val_loss:.8f} val_bpb:{sw64_val_bpb:.8f}")
+    # TTT evaluation (if enabled)
+    if args.ttt_enabled:
+        log0(f"final_eval_mode:ttt_lora rank:{args.ttt_lora_rank} epochs:{args.ttt_epochs} 
chunk:{args.ttt_chunk_size} lr:{args.ttt_lora_lr}") + t_ttt = time.perf_counter() + ttt_loss, ttt_bpb = eval_val_ttt_lora( + args, base_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + ) + log0(f"final_ttt val_loss:{ttt_loss:.4f} val_bpb:{ttt_bpb:.4f} eval_time:{1000.0 * (time.perf_counter() - t_ttt):.0f}ms") + log0(f"final_ttt_exact val_loss:{ttt_loss:.8f} val_bpb:{ttt_bpb:.8f}") + if distributed: + dist.destroy_process_group() +if __name__ == "__main__": + main() diff --git a/records/track_10min_16mb/2026-03-25_NgramCache_EntropyAdaptive_0.6672/train_seed1337.log b/records/track_10min_16mb/2026-03-25_NgramCache_EntropyAdaptive_0.6672/train_seed1337.log new file mode 100644 index 0000000000..05831aefeb --- /dev/null +++ b/records/track_10min_16mb/2026-03-25_NgramCache_EntropyAdaptive_0.6672/train_seed1337.log @@ -0,0 +1,104 @@ +logs/exp201_ngram_s1337.txt +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:80 +val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 +model_params:26993756 +mtp_num_heads:0 mtp_loss_weight:0.2 mtp_params:0 +XSA:last_4 active_layers:[7, 8, 9, 10] +world_size:1 grad_accum_steps:8 +sdp_backends:cudnn=False flash=True mem_efficient=False math=False +attention_mode:gqa num_heads:8 num_kv_heads:4 +tie_embeddings:True embed_lr:0.035 head_lr:0.0 matrix_lr:0.025 scalar_lr:0.025 +train_batch_tokens:786432 train_seq_len:2048 iterations:20000 warmup_steps:20 max_wallclock_seconds:43200.000 +seed:1337 +warmup_step:1/20 +warmup_step:2/20 +warmup_step:3/20 +warmup_step:4/20 +warmup_step:5/20 +warmup_step:6/20 +warmup_step:7/20 +warmup_step:8/20 +warmup_step:9/20 +warmup_step:10/20 +warmup_step:11/20 +warmup_step:12/20 +warmup_step:13/20 +warmup_step:14/20 +warmup_step:15/20 +warmup_step:16/20 +warmup_step:17/20 +warmup_step:18/20 +warmup_step:19/20 +warmup_step:20/20 +step:0/20000 val_loss:6.9279 val_bpb:4.1031 train_time:0ms step_avg:0.02ms +step:1/20000 train_loss:6.9282 train_time:618ms step_avg:617.56ms +step:2/20000 train_loss:8.5825 train_time:1156ms step_avg:577.81ms +step:3/20000 train_loss:7.8095 train_time:1694ms step_avg:564.66ms +step:4/20000 train_loss:7.2556 train_time:2231ms step_avg:557.86ms +step:5/20000 train_loss:7.0415 train_time:2769ms step_avg:553.78ms +step:6/20000 train_loss:6.8429 train_time:3308ms step_avg:551.30ms +step:7/20000 train_loss:6.7365 train_time:3849ms step_avg:549.93ms +step:8/20000 train_loss:6.6864 train_time:4390ms step_avg:548.78ms +step:9/20000 train_loss:6.4372 train_time:4928ms step_avg:547.57ms +step:10/20000 train_loss:6.1432 train_time:5467ms step_avg:546.66ms +step:500/20000 train_loss:2.3320 train_time:270865ms step_avg:541.73ms +step:1000/20000 train_loss:2.2768 train_time:541463ms step_avg:541.46ms +step:1500/20000 train_loss:2.1457 train_time:811844ms step_avg:541.23ms +step:2000/20000 train_loss:2.0600 train_time:1082290ms step_avg:541.15ms +step:2500/20000 train_loss:2.1032 train_time:1352739ms step_avg:541.10ms +step:3000/20000 train_loss:2.0873 train_time:1623086ms step_avg:541.03ms +step:3500/20000 train_loss:2.0807 train_time:1893399ms step_avg:540.97ms +step:4000/20000 train_loss:2.1571 train_time:2164005ms step_avg:541.00ms +step:4000/20000 val_loss:2.0695 val_bpb:1.2257 train_time:2164007ms step_avg:541.00ms +step:4500/20000 train_loss:2.1634 train_time:2434638ms step_avg:541.03ms +step:5000/20000 
train_loss:2.0856 train_time:2704908ms step_avg:540.98ms +step:5500/20000 train_loss:2.1022 train_time:2975147ms step_avg:540.94ms +step:6000/20000 train_loss:2.0116 train_time:3245375ms step_avg:540.90ms +step:6500/20000 train_loss:2.1377 train_time:3515568ms step_avg:540.86ms +step:7000/20000 train_loss:1.9542 train_time:3785933ms step_avg:540.85ms +step:7500/20000 train_loss:2.0256 train_time:4055984ms step_avg:540.80ms +step:8000/20000 train_loss:1.9951 train_time:4326048ms step_avg:540.76ms +step:8000/20000 val_loss:2.0371 val_bpb:1.2065 train_time:4326050ms step_avg:540.76ms +step:8500/20000 train_loss:1.9758 train_time:4596263ms step_avg:540.74ms +step:9000/20000 train_loss:2.0269 train_time:4866316ms step_avg:540.70ms +step:9500/20000 train_loss:2.1462 train_time:5136418ms step_avg:540.68ms +step:10000/20000 train_loss:2.0613 train_time:5406524ms step_avg:540.65ms +step:10500/20000 train_loss:2.0502 train_time:5676663ms step_avg:540.63ms +step:11000/20000 train_loss:2.0141 train_time:5946798ms step_avg:540.62ms +step:11500/20000 train_loss:2.0064 train_time:6217025ms step_avg:540.61ms +step:12000/20000 train_loss:2.0295 train_time:6486952ms step_avg:540.58ms +step:12000/20000 val_loss:2.0275 val_bpb:1.2008 train_time:6486954ms step_avg:540.58ms +step:12500/20000 train_loss:2.0224 train_time:6757112ms step_avg:540.57ms +step:13000/20000 train_loss:1.9818 train_time:7026988ms step_avg:540.54ms +step:13500/20000 train_loss:2.0413 train_time:7296738ms step_avg:540.50ms +step:14000/20000 train_loss:1.9768 train_time:7566534ms step_avg:540.47ms +step:14500/20000 train_loss:2.0631 train_time:7836297ms step_avg:540.43ms +step:15000/20000 train_loss:2.0325 train_time:8106115ms step_avg:540.41ms +step:15500/20000 train_loss:2.0119 train_time:8375813ms step_avg:540.38ms +step:16000/20000 train_loss:2.0046 train_time:8645576ms step_avg:540.35ms +step:16000/20000 val_loss:2.0212 val_bpb:1.1971 train_time:8645577ms step_avg:540.35ms +step:16500/20000 train_loss:2.1384 train_time:8915423ms step_avg:540.33ms +step:17000/20000 train_loss:2.0131 train_time:9185155ms step_avg:540.30ms +step:17500/20000 train_loss:2.0102 train_time:9454899ms step_avg:540.28ms +step:18000/20000 train_loss:2.0566 train_time:9724688ms step_avg:540.26ms +step:18500/20000 train_loss:1.9100 train_time:9994491ms step_avg:540.24ms +step:19000/20000 train_loss:1.9045 train_time:10264313ms step_avg:540.23ms +step:19500/20000 train_loss:2.1180 train_time:10533930ms step_avg:540.20ms +step:20000/20000 train_loss:2.0187 train_time:10803726ms step_avg:540.19ms +step:20000/20000 val_loss:2.0128 val_bpb:1.1921 train_time:10803728ms step_avg:540.19ms +peak memory allocated: 20915 MiB reserved: 22058 MiB +ema:applying EMA weights +DIAGNOSTIC post_ema val_loss:1.9365 val_bpb:1.1469 eval_time:12005ms +Serialized model: 106178569 bytes +Code size: 92362 bytes +magnitude_pruning: 3% threshold applied +Serialized model int6+zstd: 14585300 bytes +Total submission size int6+zstd: 14677662 bytes +Total submission size int8+zlib: 14677662 bytes +final_int6_roundtrip val_loss:1.9526 val_bpb:1.1565 eval_time:40649ms +final_int6_roundtrip_exact val_loss:1.95262725 val_bpb:1.15645584 +ngram_cache:enabled orders=2-7 entropy=True alpha=0.4 buckets=4194304 +final_int6_sliding_window val_loss:1.1271 val_bpb:0.6676 stride:64 eval_time:809175ms +final_int6_sliding_window_exact val_loss:1.12713346 val_bpb:0.66755369 +final_int8_zlib_roundtrip_exact val_loss:1.12713346 val_bpb:0.66755369 diff --git 
a/records/track_10min_16mb/2026-03-25_NgramCache_EntropyAdaptive_0.6672/train_seed42.log b/records/track_10min_16mb/2026-03-25_NgramCache_EntropyAdaptive_0.6672/train_seed42.log new file mode 100644 index 0000000000..1d8c211b26 --- /dev/null +++ b/records/track_10min_16mb/2026-03-25_NgramCache_EntropyAdaptive_0.6672/train_seed42.log @@ -0,0 +1,104 @@ +logs/exp200_ngram_cache.txt +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:80 +val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 +model_params:26993756 +mtp_num_heads:0 mtp_loss_weight:0.2 mtp_params:0 +XSA:last_4 active_layers:[7, 8, 9, 10] +world_size:1 grad_accum_steps:8 +sdp_backends:cudnn=False flash=True mem_efficient=False math=False +attention_mode:gqa num_heads:8 num_kv_heads:4 +tie_embeddings:True embed_lr:0.035 head_lr:0.0 matrix_lr:0.025 scalar_lr:0.025 +train_batch_tokens:786432 train_seq_len:2048 iterations:20000 warmup_steps:20 max_wallclock_seconds:43200.000 +seed:42 +warmup_step:1/20 +warmup_step:2/20 +warmup_step:3/20 +warmup_step:4/20 +warmup_step:5/20 +warmup_step:6/20 +warmup_step:7/20 +warmup_step:8/20 +warmup_step:9/20 +warmup_step:10/20 +warmup_step:11/20 +warmup_step:12/20 +warmup_step:13/20 +warmup_step:14/20 +warmup_step:15/20 +warmup_step:16/20 +warmup_step:17/20 +warmup_step:18/20 +warmup_step:19/20 +warmup_step:20/20 +step:0/20000 val_loss:6.9301 val_bpb:4.1044 train_time:0ms step_avg:0.01ms +step:1/20000 train_loss:6.9309 train_time:603ms step_avg:603.12ms +step:2/20000 train_loss:8.6493 train_time:1157ms step_avg:578.37ms +step:3/20000 train_loss:7.8514 train_time:1699ms step_avg:566.44ms +step:4/20000 train_loss:7.2834 train_time:2241ms step_avg:560.32ms +step:5/20000 train_loss:7.0074 train_time:2783ms step_avg:556.69ms +step:6/20000 train_loss:6.9041 train_time:3326ms step_avg:554.31ms +step:7/20000 train_loss:6.7726 train_time:3894ms step_avg:556.31ms +step:8/20000 train_loss:6.6373 train_time:4436ms step_avg:554.56ms +step:9/20000 train_loss:6.4245 train_time:4983ms step_avg:553.62ms +step:10/20000 train_loss:6.1292 train_time:5525ms step_avg:552.51ms +step:500/20000 train_loss:2.3396 train_time:270625ms step_avg:541.25ms +step:1000/20000 train_loss:2.2791 train_time:541391ms step_avg:541.39ms +step:1500/20000 train_loss:2.1475 train_time:811343ms step_avg:540.90ms +step:2000/20000 train_loss:2.0604 train_time:1081921ms step_avg:540.96ms +step:2500/20000 train_loss:2.1085 train_time:1351802ms step_avg:540.72ms +step:3000/20000 train_loss:2.0915 train_time:1621725ms step_avg:540.57ms +step:3500/20000 train_loss:2.0830 train_time:1891246ms step_avg:540.36ms +step:4000/20000 train_loss:2.1570 train_time:2160652ms step_avg:540.16ms +step:4000/20000 val_loss:2.0707 val_bpb:1.2264 train_time:2160654ms step_avg:540.16ms +step:4500/20000 train_loss:2.1628 train_time:2429986ms step_avg:540.00ms +step:5000/20000 train_loss:2.0854 train_time:2699298ms step_avg:539.86ms +step:5500/20000 train_loss:2.1029 train_time:2968583ms step_avg:539.74ms +step:6000/20000 train_loss:2.0141 train_time:3237877ms step_avg:539.65ms +step:6500/20000 train_loss:2.1395 train_time:3507259ms step_avg:539.58ms +step:7000/20000 train_loss:1.9550 train_time:3776563ms step_avg:539.51ms +step:7500/20000 train_loss:2.0267 train_time:4045674ms step_avg:539.42ms +step:8000/20000 train_loss:1.9945 train_time:4314790ms step_avg:539.35ms +step:8000/20000 val_loss:2.0376 val_bpb:1.2068 
train_time:4314808ms step_avg:539.35ms +step:8500/20000 train_loss:1.9782 train_time:4584007ms step_avg:539.29ms +step:9000/20000 train_loss:2.0280 train_time:4853082ms step_avg:539.23ms +step:9500/20000 train_loss:2.1470 train_time:5122372ms step_avg:539.20ms +step:10000/20000 train_loss:2.0636 train_time:5391496ms step_avg:539.15ms +step:10500/20000 train_loss:2.0528 train_time:5660642ms step_avg:539.11ms +step:11000/20000 train_loss:2.0152 train_time:5929946ms step_avg:539.09ms +step:11500/20000 train_loss:2.0092 train_time:6199002ms step_avg:539.04ms +step:12000/20000 train_loss:2.0316 train_time:6468122ms step_avg:539.01ms +step:12000/20000 val_loss:2.0294 val_bpb:1.2019 train_time:6468123ms step_avg:539.01ms +step:12500/20000 train_loss:2.0236 train_time:6737373ms step_avg:538.99ms +step:13000/20000 train_loss:1.9818 train_time:7006553ms step_avg:538.97ms +step:13500/20000 train_loss:2.0425 train_time:7275658ms step_avg:538.94ms +step:14000/20000 train_loss:1.9773 train_time:7544858ms step_avg:538.92ms +step:14500/20000 train_loss:2.0653 train_time:7814094ms step_avg:538.90ms +step:15000/20000 train_loss:2.0345 train_time:8083204ms step_avg:538.88ms +step:15500/20000 train_loss:2.0136 train_time:8352386ms step_avg:538.86ms +step:16000/20000 train_loss:2.0028 train_time:8621580ms step_avg:538.85ms +step:16000/20000 val_loss:2.0209 val_bpb:1.1969 train_time:8621584ms step_avg:538.85ms +step:16500/20000 train_loss:2.1396 train_time:8890854ms step_avg:538.84ms +step:17000/20000 train_loss:2.0132 train_time:9159873ms step_avg:538.82ms +step:17500/20000 train_loss:2.0131 train_time:9429056ms step_avg:538.80ms +step:18000/20000 train_loss:2.0574 train_time:9698189ms step_avg:538.79ms +step:18500/20000 train_loss:1.9104 train_time:9967771ms step_avg:538.80ms +step:19000/20000 train_loss:1.9080 train_time:10236898ms step_avg:538.78ms +step:19500/20000 train_loss:2.1203 train_time:10505974ms step_avg:538.77ms +step:20000/20000 train_loss:2.0194 train_time:10775230ms step_avg:538.76ms +step:20000/20000 val_loss:2.0140 val_bpb:1.1928 train_time:10775237ms step_avg:538.76ms +peak memory allocated: 20915 MiB reserved: 22058 MiB +ema:applying EMA weights +DIAGNOSTIC post_ema val_loss:1.9375 val_bpb:1.1475 eval_time:11941ms +Serialized model: 106178569 bytes +Code size: 92362 bytes +magnitude_pruning: 3% threshold applied +Serialized model int6+zstd: 14932876 bytes +Total submission size int6+zstd: 15025238 bytes +Total submission size int8+zlib: 15025238 bytes +final_int6_roundtrip val_loss:1.9548 val_bpb:1.1577 eval_time:41643ms +final_int6_roundtrip_exact val_loss:1.95479477 val_bpb:1.15773957 +ngram_cache:enabled orders=2-7 entropy=True alpha=0.4 buckets=4194304 +final_int6_sliding_window val_loss:1.1266 val_bpb:0.6672 stride:64 eval_time:816905ms +final_int6_sliding_window_exact val_loss:1.12659837 val_bpb:0.66723678 +final_int8_zlib_roundtrip_exact val_loss:1.12659837 val_bpb:0.66723678 diff --git a/records/track_10min_16mb/2026-03-25_NgramCache_EntropyAdaptive_0.6672/train_seed7.log b/records/track_10min_16mb/2026-03-25_NgramCache_EntropyAdaptive_0.6672/train_seed7.log new file mode 100644 index 0000000000..a654b0483c --- /dev/null +++ b/records/track_10min_16mb/2026-03-25_NgramCache_EntropyAdaptive_0.6672/train_seed7.log @@ -0,0 +1,104 @@ +logs/exp202_ngram_s7.txt +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:80 +val_loader:shards 
pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 +model_params:26993756 +mtp_num_heads:0 mtp_loss_weight:0.2 mtp_params:0 +XSA:last_4 active_layers:[7, 8, 9, 10] +world_size:1 grad_accum_steps:8 +sdp_backends:cudnn=False flash=True mem_efficient=False math=False +attention_mode:gqa num_heads:8 num_kv_heads:4 +tie_embeddings:True embed_lr:0.035 head_lr:0.0 matrix_lr:0.025 scalar_lr:0.025 +train_batch_tokens:786432 train_seq_len:2048 iterations:20000 warmup_steps:20 max_wallclock_seconds:43200.000 +seed:7 +warmup_step:1/20 +warmup_step:2/20 +warmup_step:3/20 +warmup_step:4/20 +warmup_step:5/20 +warmup_step:6/20 +warmup_step:7/20 +warmup_step:8/20 +warmup_step:9/20 +warmup_step:10/20 +warmup_step:11/20 +warmup_step:12/20 +warmup_step:13/20 +warmup_step:14/20 +warmup_step:15/20 +warmup_step:16/20 +warmup_step:17/20 +warmup_step:18/20 +warmup_step:19/20 +warmup_step:20/20 +step:0/20000 val_loss:6.9297 val_bpb:4.1041 train_time:0ms step_avg:0.01ms +step:1/20000 train_loss:6.9305 train_time:624ms step_avg:624.05ms +step:2/20000 train_loss:8.8495 train_time:1162ms step_avg:580.83ms +step:3/20000 train_loss:7.9637 train_time:1699ms step_avg:566.25ms +step:4/20000 train_loss:7.2075 train_time:2237ms step_avg:559.26ms +step:5/20000 train_loss:6.9672 train_time:2777ms step_avg:555.34ms +step:6/20000 train_loss:6.9172 train_time:3314ms step_avg:552.26ms +step:7/20000 train_loss:6.7610 train_time:3851ms step_avg:550.08ms +step:8/20000 train_loss:6.6102 train_time:4394ms step_avg:549.25ms +step:9/20000 train_loss:6.3989 train_time:4937ms step_avg:548.52ms +step:10/20000 train_loss:6.1119 train_time:5479ms step_avg:547.85ms +step:500/20000 train_loss:2.3362 train_time:271838ms step_avg:543.68ms +step:1000/20000 train_loss:2.2811 train_time:543579ms step_avg:543.58ms +step:1500/20000 train_loss:2.1492 train_time:815283ms step_avg:543.52ms +step:2000/20000 train_loss:2.0646 train_time:1086900ms step_avg:543.45ms +step:2500/20000 train_loss:2.1085 train_time:1357621ms step_avg:543.05ms +step:3000/20000 train_loss:2.0925 train_time:1627826ms step_avg:542.61ms +step:3500/20000 train_loss:2.0847 train_time:1898002ms step_avg:542.29ms +step:4000/20000 train_loss:2.1634 train_time:2168113ms step_avg:542.03ms +step:4000/20000 val_loss:2.0744 val_bpb:1.2286 train_time:2168115ms step_avg:542.03ms +step:4500/20000 train_loss:2.1659 train_time:2438431ms step_avg:541.87ms +step:5000/20000 train_loss:2.0888 train_time:2708313ms step_avg:541.66ms +step:5500/20000 train_loss:2.1055 train_time:2978199ms step_avg:541.49ms +step:6000/20000 train_loss:2.0154 train_time:3248424ms step_avg:541.40ms +step:6500/20000 train_loss:2.1404 train_time:3518987ms step_avg:541.38ms +step:7000/20000 train_loss:1.9570 train_time:3788941ms step_avg:541.28ms +step:7500/20000 train_loss:2.0287 train_time:4058750ms step_avg:541.17ms +step:8000/20000 train_loss:1.9972 train_time:4328570ms step_avg:541.07ms +step:8000/20000 val_loss:2.0386 val_bpb:1.2074 train_time:4328572ms step_avg:541.07ms +step:8500/20000 train_loss:1.9782 train_time:4598365ms step_avg:540.98ms +step:9000/20000 train_loss:2.0300 train_time:4868229ms step_avg:540.91ms +step:9500/20000 train_loss:2.1527 train_time:5138018ms step_avg:540.84ms +step:10000/20000 train_loss:2.0674 train_time:5407859ms step_avg:540.79ms +step:10500/20000 train_loss:2.0532 train_time:5677564ms step_avg:540.72ms +step:11000/20000 train_loss:2.0174 train_time:5947330ms step_avg:540.67ms +step:11500/20000 train_loss:2.0100 train_time:6217156ms step_avg:540.62ms 
+step:12000/20000 train_loss:2.0330 train_time:6486913ms step_avg:540.58ms +step:12000/20000 val_loss:2.0303 val_bpb:1.2025 train_time:6486914ms step_avg:540.58ms +step:12500/20000 train_loss:2.0272 train_time:6756667ms step_avg:540.53ms +step:13000/20000 train_loss:1.9834 train_time:7026675ms step_avg:540.51ms +step:13500/20000 train_loss:2.0456 train_time:7296430ms step_avg:540.48ms +step:14000/20000 train_loss:1.9803 train_time:7565998ms step_avg:540.43ms +step:14500/20000 train_loss:2.0682 train_time:7835838ms step_avg:540.40ms +step:15000/20000 train_loss:2.0372 train_time:8105653ms step_avg:540.38ms +step:15500/20000 train_loss:2.0160 train_time:8375090ms step_avg:540.33ms +step:16000/20000 train_loss:2.0072 train_time:8644631ms step_avg:540.29ms +step:16000/20000 val_loss:2.0236 val_bpb:1.1985 train_time:8644632ms step_avg:540.29ms +step:16500/20000 train_loss:2.1418 train_time:8914403ms step_avg:540.27ms +step:17000/20000 train_loss:2.0150 train_time:9184175ms step_avg:540.25ms +step:17500/20000 train_loss:2.0151 train_time:9453975ms step_avg:540.23ms +step:18000/20000 train_loss:2.0605 train_time:9723683ms step_avg:540.20ms +step:18500/20000 train_loss:1.9126 train_time:9993329ms step_avg:540.18ms +step:19000/20000 train_loss:1.9092 train_time:10263212ms step_avg:540.17ms +step:19500/20000 train_loss:2.1220 train_time:10532798ms step_avg:540.14ms +step:20000/20000 train_loss:2.0216 train_time:10802622ms step_avg:540.13ms +step:20000/20000 val_loss:2.0156 val_bpb:1.1937 train_time:10802623ms step_avg:540.13ms +peak memory allocated: 20915 MiB reserved: 22058 MiB +ema:applying EMA weights +DIAGNOSTIC post_ema val_loss:1.9396 val_bpb:1.1487 eval_time:12061ms +Serialized model: 106178569 bytes +Code size: 92362 bytes +magnitude_pruning: 3% threshold applied +Serialized model int6+zstd: 14803240 bytes +Total submission size int6+zstd: 14895602 bytes +Total submission size int8+zlib: 14895602 bytes +final_int6_roundtrip val_loss:1.9559 val_bpb:1.1584 eval_time:39263ms +final_int6_roundtrip_exact val_loss:1.95585418 val_bpb:1.15836702 +ngram_cache:enabled orders=2-7 entropy=True alpha=0.4 buckets=4194304 +final_int6_sliding_window val_loss:1.1290 val_bpb:0.6687 stride:64 eval_time:801176ms +final_int6_sliding_window_exact val_loss:1.12898672 val_bpb:0.66865130 +final_int8_zlib_roundtrip_exact val_loss:1.12898672 val_bpb:0.66865130