diff --git a/records/track_10min_16mb/2026-04-05_SP4096_DepthRecur_ParallelResid_Brotli/README.md b/records/track_10min_16mb/2026-04-05_SP4096_DepthRecur_ParallelResid_Brotli/README.md new file mode 100644 index 0000000000..cae4a6ea65 --- /dev/null +++ b/records/track_10min_16mb/2026-04-05_SP4096_DepthRecur_ParallelResid_Brotli/README.md @@ -0,0 +1,140 @@ +# Record: SP4096 + Depth Recurrence + Parallel Residuals + QK-Gain + Brotli + +**val_bpb: 1.1020** (3-seed mean, std 0.0011) | **~15.88 MB max** | 8xH100 SXM, 600s | No TTT + +**Improvement over current merged SOTA ([PR #1019](https://github.com/openai/parameter-golf/pull/1019), 1.1147 BPB):** -0.0127 BPB / -0.0088 nats (Welch t=-18.37, df=2.38, p<0.001) + +## Results + +| Seed | Steps | ms/step | **Sliding Window BPB** | Model Bytes | Total Artifact | +|------|-------|---------|------------------------|-------------|----------------| +| 42 | 5,733 | 104.67 | **1.10327** | 15,748,095 | 15,824,545 | +| 314 | 5,945 | 100.94 | **1.10181** | 15,792,991 | 15,869,441 | +| 999 | 5,936 | 101.10 | **1.10102** | 15,799,271 | 15,875,721 | +| **Mean** | | | **1.10203** | | | + +Spread across seeds: 0.0023 BPB (very tight). All 3 seeds fit under 16MB with >=124KB margin. + +## Tokenizer Change: BPB Correctness Proof + +This submission uses a SentencePiece 4096 BPE tokenizer (`fineweb_4096_bpe.model`) instead of the baseline SP1024. Per competition rules, we provide detailed proof that val_bpb is correctly calculated. + +**How BPB is computed in this script:** + +The `val_bpb` metric is computed by the same `sliding_window_bpb()` function used by all submissions in this repo. The function: + +1. Evaluates cross-entropy loss in nats per token over the full validation set using a sliding window (stride=64) +2. Counts the total number of bytes in the validation text by summing `token_byte_lengths[token_id]` for each token +3. 
Computes `BPB = total_nats / (total_bytes * ln(2))` + +The `token_byte_lengths` lookup table is built by `build_sentencepiece_luts()`, which inspects each token's UTF-8 byte length via `sp.id_to_piece(token_id)`. This is independent of vocabulary size — a token that represents "the" is 3 bytes whether the vocab is 1024 or 4096. + +**Key invariant:** The total byte count of the validation set is identical regardless of tokenizer, because every tokenizer produces a lossless segmentation of the same byte sequence. More tokens (SP1024) or fewer tokens (SP4096) — the bytes sum is the same. Therefore BPB is a fair cross-tokenizer comparison. + +**Verification from logs:** The validation set has `tokens:45508608` SP4096 tokens. At ~3.32 bytes/token average, this covers the same ~151M byte validation set used by SP1024 submissions (which have ~131M tokens at ~1.15 bytes/token). The per-token cross-entropy is higher with SP4096 (2.54 nats vs 1.88 nats) because each token covers more bytes, but the per-byte rate (BPB) is directly comparable. + +--- + +## What Changed vs PR #1019 + +This submission replaces the SP1024 + BigramHash + LZMA stack with a SP4096-native architecture that gets more capacity from the larger vocabulary and recurrent/parallel techniques instead of explicit bigram features. + +### 1. SP4096 Tokenizer + MLP 4x (from SP1024 + MLP 3x) + +Switching to a 4096-token SentencePiece vocabulary with 4x MLP multiplier increases model capacity from ~27M to 34.4M parameters. The larger vocabulary captures more subword patterns natively, eliminating the need for BigramHash (which compresses 3.4x worse per parameter with SP4096). + +### 2. Depth Recurrence (Layers 4-5 from Step 3000) + +After step 3000, layers 4 and 5 are re-executed, effectively giving the model 13 logical layers for the cost of 11 layers' parameters. This adds zero parameters — it's purely a compute-time technique that trades ~10% wall-clock time for improved representation depth. 
Source: [PR #1260](https://github.com/openai/parameter-golf/pull/1260) ablation, estimated -0.0035 BPB. + +### 3. Parallel Residuals (Layer 7+) + +From layer 7 onward, the MLP and attention outputs are merged through a learned `lane_merge` scalar and `resid_mix_mlp` vector per layer (~20KB raw, ~3-5KB compressed). This allows the model to balance attention vs MLP contributions dynamically. Source: [PR #1289](https://github.com/openai/parameter-golf/pull/1289), estimated -0.0035 BPB. + +### 4. QK-Gain 5.0 + +Initializes query and key projections with 5x scale, sharpening attention from the start of training without any parameter cost. Source: [PR #1217](https://github.com/openai/parameter-golf/pull/1217) (45 experiments), estimated -0.001 BPB. + +### 5. MuonEq-R Optimizer + +Normalizes each gradient row to unit L2 norm before the Newton-Schulz iteration in Muon. ~15 lines of code, zero parameter cost, minor but consistent improvement. Source: [PR #1334](https://github.com/openai/parameter-golf/pull/1334). + +### 6. ADAM_WD=0.090 + GPTQ Tuning + +Increased Adam weight decay from 0.02 to 0.090 (matching Muon WD). GPTQ calibration increased from 64 to 128 AR self-generated sequences for denser Hessian estimates with the larger SP4096 model. Dampening factor tuned to 0.01. + +### 7. Brotli Compression (from LZMA) + +SP4096 int6 weights compress better under Brotli than LZMA. Together with the headroom freed by removing BigramHash, this keeps the larger SP4096 artifact under the 16 MB budget.
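+The cross-tokenizer invariant from the BPB proof above can be illustrated with a short, self-contained sketch. The segmentations and helper names below are toy illustrations (not the repo's actual `sliding_window_bpb()` or LUT code): any two lossless segmentations of the same text cover identical bytes, so for a fixed total cross-entropy the per-byte rate is the same even though per-token nats differ.
+
+```python
+import math
+
+# Toy stand-in for the token_byte_lengths LUT built by
+# build_sentencepiece_luts(): UTF-8 byte length of each token's text.
+def total_bytes(tokens):
+    return sum(len(t.encode("utf-8")) for t in tokens)
+
+def bpb(total_nats, n_bytes):
+    # BPB = total_nats / (total_bytes * ln(2))
+    return total_nats / (n_bytes * math.log(2))
+
+text = "the cat sat"
+coarse = ["the ", "cat ", "sat"]  # few long tokens (SP4096-like)
+fine = list(text)                 # many 1-char tokens (SP1024-like)
+
+# Both segmentations are lossless, so the byte totals match...
+assert "".join(coarse) == "".join(fine) == text
+assert total_bytes(coarse) == total_bytes(fine) == 11
+
+# ...and for the same total nats, BPB is identical under either
+# tokenizer, even though nats per token differ (12/3 vs 12/11).
+assert bpb(12.0, total_bytes(coarse)) == bpb(12.0, total_bytes(fine))
+```
+
+The same argument is why the 45.5M-token SP4096 validation run and the ~131M-token SP1024 runs divide their (different) total nats by the same ~151M-byte denominator.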
+ +### Dropped vs PR #1019 + +| Removed | Reason | +|---------|--------| +| BigramHash 3072x112 | Compresses 3.4x worse per param with SP4096, net size-negative | +| TrigramHash | Same compression issue with SP4096 | +| LZMA preset=9 | Brotli compresses SP4096 int6 weights better | +| TTT | Neutral or negative on this stack (25 failed attempts, [PR #756](https://github.com/openai/parameter-golf/pull/756)) | + +--- + +## Architecture + +| Component | Setting | Source | +|-----------|---------|--------| +| Layers | 11 (512d, 8 GQA heads, 4 KV heads) | Baseline | +| MLP | **4x** (2048) with LeakyReLU(0.5)^2 | [#493](https://github.com/openai/parameter-golf/pull/493) @parinzee | +| Tokenizer | **SentencePiece 4096** | [#1334](https://github.com/openai/parameter-golf/pull/1334) | +| Attention | XSA on all 11 layers | [#478](https://github.com/openai/parameter-golf/pull/478) @gowtham0992 | +| Depth Recurrence | **Layers 4-5 from step 3000** | [#1260](https://github.com/openai/parameter-golf/pull/1260) | +| Parallel Residuals | **Layer 7+ with learned merge** | [#1289](https://github.com/openai/parameter-golf/pull/1289) | +| QK-Gain | **5.0** | [#1217](https://github.com/openai/parameter-golf/pull/1217) | +| Optimizer | Parallel Muon + **MuonEq-R** + Parameter Banking | [#399](https://github.com/openai/parameter-golf/pull/399), [#1334](https://github.com/openai/parameter-golf/pull/1334) | +| RoPE | Partial (16/64 dims) | [#315](https://github.com/openai/parameter-golf/pull/315) @jfprincz | +| LN Scale | 1/sqrt(layer+1) | [#315](https://github.com/openai/parameter-golf/pull/315) @jfprincz | +| VE128 | Layers 9-10 | [#374](https://github.com/openai/parameter-golf/pull/374) @unnir | +| SmearGate | Position-mixing gate | [#65](https://github.com/openai/parameter-golf/pull/65) @aquariouseworkman | +| U-Net skips | Encoder-decoder connections | [#289](https://github.com/openai/parameter-golf/pull/289) | +| Weight avg | EMA(0.997) + Tight SWA(every 50) | 
[#401](https://github.com/openai/parameter-golf/pull/401) @newjordan | +| Quantization | Full Hessian GPTQ int6 (AR self-gen, **128 batch**) | [#535](https://github.com/openai/parameter-golf/pull/535) @raahilshah | +| Compression | **Brotli** | **This work** | +| Warmdown | 4000 iterations | [#364](https://github.com/openai/parameter-golf/pull/364) @shikhar1729 | +| Late QAT | STE at LR scale < 0.15 | [#286](https://github.com/openai/parameter-golf/pull/286) @chris-buckley | +| Selective pruning | +/-1 values by reconstruction error | [#609](https://github.com/openai/parameter-golf/pull/609) @saml212 | +| Flash Attention 3 | Hopper warp-specialized kernels | [#122](https://github.com/openai/parameter-golf/pull/122) @mtybadger | + +## Requirements + +**Flash Attention 3 (Hopper) is required.** + +```bash +pip install --break-system-packages flash_attn_3 --find-links https://windreamer.github.io/flash-attention3-wheels/cu128_torch291 +pip install sentencepiece zstandard brotli +python3 -c "from flash_attn_interface import flash_attn_func; import sentencepiece, zstandard, brotli; print('deps OK')" +``` + +## Run Command + +```bash +VOCAB_SIZE=4096 MLP_MULT=4.0 QK_GAIN_INIT=5.0 MUON_EQ_R=1 \ +RECUR_LAYERS="4,5" RECUR_START_STEP=3000 PARALLEL_START_LAYER=7 \ +MUON_WD=0.090 ADAM_WD=0.090 WARMDOWN_ITERS=4000 \ +GPTQ_CALIB_BATCHES=128 GPTQ_DAMP=0.01 \ +BIGRAM_VOCAB_SIZE=0 TRIGRAM=0 TARGET_MB=15.9 SEED=42 \ +torchrun --standalone --nproc_per_node=8 train_gpt.py +``` + +## Lineage + +``` +PR #1019 (Merged SOTA, 1.1147 BPB) -- SP1024 + BigramHash + LZMA + +-- This work replaces with: + +-- SP4096 + MLP 4x (native vocabulary capacity, no bigram needed) + +-- Depth recurrence layers 4-5 from step 3000 (from #1260) + +-- Parallel residuals layer 7+ with learned merge (from #1289) + +-- QK-Gain 5.0 (from #1217) + +-- MuonEq-R optimizer (from #1334) + +-- ADAM_WD=0.090, GPTQ 128-batch calibration, damp=0.01 + +-- Brotli compression (better for SP4096 int6) + +-- Guided by 37 GPU 
runs (~$266) and PR #670 negative results +``` diff --git a/records/track_10min_16mb/2026-04-05_SP4096_DepthRecur_ParallelResid_Brotli/requirements.txt b/records/track_10min_16mb/2026-04-05_SP4096_DepthRecur_ParallelResid_Brotli/requirements.txt new file mode 100644 index 0000000000..0021a93284 --- /dev/null +++ b/records/track_10min_16mb/2026-04-05_SP4096_DepthRecur_ParallelResid_Brotli/requirements.txt @@ -0,0 +1,4 @@ +flash_attn_3 +sentencepiece +zstandard +brotli diff --git a/records/track_10min_16mb/2026-04-05_SP4096_DepthRecur_ParallelResid_Brotli/run_pgolf.sh b/records/track_10min_16mb/2026-04-05_SP4096_DepthRecur_ParallelResid_Brotli/run_pgolf.sh new file mode 100644 index 0000000000..f9e0ee84b2 --- /dev/null +++ b/records/track_10min_16mb/2026-04-05_SP4096_DepthRecur_ParallelResid_Brotli/run_pgolf.sh @@ -0,0 +1,66 @@ +#!/bin/bash +# Parameter Golf — Phase 1b: SP4096 + Depth Recurrence + Parallel Residuals + QK-Gain + Brotli +# 3-seed validated 2026-04-05: mean 1.1020 BPB (seeds 42, 314, 999) +# Usage: bash run_pgolf_phase1b.sh [seed] + +SEED="${1:-42}" + +# === SP4096 ARCHITECTURE === +export VOCAB_SIZE=4096 +export MLP_MULT=4.0 +export XSA_LAST_N=11 + +# === PR #1334 TECHNIQUES === +export QK_GAIN_INIT=5.0 +export MUON_EQ_R=1 +export RECUR_LAYERS="4,5" +export RECUR_START_STEP=3000 +export PARALLEL_START_LAYER=7 + +# === OPTIMIZER === +export MATRIX_LR=0.02 +export SCALAR_LR=0.02 +export TIED_EMBED_LR=0.03 +export MUON_WD=0.090 +export ADAM_WD=0.090 +export WARMDOWN_ITERS=4000 + +# === GPTQ (AR self-gen, Brotli) === +export GPTQ_CALIB_BATCHES=128 +export GPTQ_DAMP=0.01 +export LATE_QAT_THRESHOLD=0.15 + +# === DISABLED (don't fit SP4096 or confirmed dead) === +export BIGRAM_VOCAB_SIZE=0 +export BIGRAM_DIM=0 +export TRIGRAM=0 +export HADAMARD_ROTATION=0 +export SOFT_ROUND_QAT=0 +export PREQUANT_TTT=0 +export MIXED_BITWIDTH=0 +export NGRAM_ENABLED=0 +export TTT_ENABLED=0 + +# === DATA PATHS (SP4096) === +export DATA_PATH=./data/datasets/fineweb10B_sp4096 
+export TOKENIZER_PATH=./data/tokenizers/fineweb_4096_bpe.model + +# === PROVEN STACK === +export SWA_ENABLED=1 +export VE_ENABLED=1 +export LN_SCALE=1 +export TARGET_MB=15.9 +export SEED=$SEED + +# === FULL RUN (10 min wallclock) === +export ITERATIONS=20000 +export MAX_WALLCLOCK_SECONDS=600 +export VAL_LOSS_EVERY=4000 + +echo "=== Phase 1b: SP4096 + DepthRecur + ParallelResid + QK-Gain + Brotli ===" +echo "Seed: $SEED | Vocab: $VOCAB_SIZE | MLP: ${MLP_MULT}x | ADAM_WD: $ADAM_WD" +echo "DepthRecur: layers $RECUR_LAYERS from step $RECUR_START_STEP" +echo "ParallelResid: from layer $PARALLEL_START_LAYER | QK-Gain: $QK_GAIN_INIT" +echo "" + +torchrun --standalone --nproc_per_node=8 train_gpt.py 2>&1 | tee "run_phase1b_seed${SEED}_$(date +%Y%m%d_%H%M%S).log" diff --git a/records/track_10min_16mb/2026-04-05_SP4096_DepthRecur_ParallelResid_Brotli/submission.json b/records/track_10min_16mb/2026-04-05_SP4096_DepthRecur_ParallelResid_Brotli/submission.json new file mode 100644 index 0000000000..8b21f44e60 --- /dev/null +++ b/records/track_10min_16mb/2026-04-05_SP4096_DepthRecur_ParallelResid_Brotli/submission.json @@ -0,0 +1,52 @@ +{ + "author": "Its-Just-Crump", + "github_id": "Its-Just-Crump", + "name": "SP4096 + Depth Recurrence + Parallel Residuals + QK-Gain + AR Self-Gen GPTQ + Brotli", + "blurb": "11L/512d GQA with SP4096 tokenizer, MLP 4x (34.4M params), depth recurrence (layers 4-5 from step 3000), parallel residuals (layer 7+), QK-Gain 5.0, MuonEq-R, Full Hessian GPTQ with AR self-gen calibration (128 seqs x 2048 tokens), Brotli compression. No BigramHash. 
3-seed exact mean: 1.10203473 BPB, beating current SOTA 1.11473509 BPB by 0.01270 BPB / 0.00880 nats (Welch t=-18.37, df=2.38, p<0.001).", + "date": "2026-04-05", + "track": "10min_16mb", + "val_loss": 2.53578534, + "val_bpb": 1.10203473, + "val_loss_std": 0.00263202, + "val_bpb_std": 0.00114386, + "seeds": [42, 314, 999], + "seed_results": { + "42": { + "val_loss": 2.53863636, + "val_bpb": 1.10327377, + "artifact_bytes": 15824545, + "steps": 5733, + "step_avg_ms": 104.67 + }, + "314": { + "val_loss": 2.53527157, + "val_bpb": 1.10181145, + "artifact_bytes": 15869441, + "steps": 5945, + "step_avg_ms": 100.94 + }, + "999": { + "val_loss": 2.53344809, + "val_bpb": 1.10101898, + "artifact_bytes": 15875721, + "steps": 5936, + "step_avg_ms": 101.10 + } + }, + "comparison_baseline_pr": 1019, + "delta_vs_baseline_bpb": -0.01270036, + "delta_vs_baseline_nats": -0.00880322, + "t_statistic": -18.3720, + "welch_df": 2.3794, + "artifact_bytes_mean": 15856569, + "artifact_bytes_max": 15875721, + "bytes_total": 15875721, + "train_steps_mean": 5871.33, + "step_avg_ms_mean": 102.24, + "hardware": "8xH100 80GB SXM", + "pytorch_version": "2.9.1+cu128", + "cuda_version": "12.8", + "flash_attn_version": "2.8.3 (FA3 Hopper kernels)", + "calibration": "AR self-generated (128 seqs x 2048 tokens, temp=0.8, no external data)", + "technique_summary": "SP4096 + MLP4x + Depth Recurrence + Parallel Residuals + QK-Gain 5.0 + MuonEq-R + AR Self-Gen GPTQ (128 batch) + Brotli" +} diff --git a/records/track_10min_16mb/2026-04-05_SP4096_DepthRecur_ParallelResid_Brotli/train_gpt.py b/records/track_10min_16mb/2026-04-05_SP4096_DepthRecur_ParallelResid_Brotli/train_gpt.py new file mode 100644 index 0000000000..92a8d4c1ab --- /dev/null +++ b/records/track_10min_16mb/2026-04-05_SP4096_DepthRecur_ParallelResid_Brotli/train_gpt.py @@ -0,0 +1,1000 @@ +from __future__ import annotations +_A3='passthrough_ctrl' +_A2='jm_lambdas' +_A1='min_count' +_A0='min_order' +_z='ng_primes' +_y='ng_mask' +_x='snap_full' 
+_w='snap_ctx' +_v='passthrough_orig_dtypes' +_u='dtypes' +_t='scales' +_s='quantized' +_r='per_row' +_q='scheme' +_p='torch.' +_o='momentum' +_n='shard_mom' +_m='padded_grad' +_l='fineweb_train_*.bin' +_k='type' +_j='inf' +_i='full_tables' +_h='ctx_tables' +_g='n_orders' +_f='mean' +_e='shard' +_d='0.0001' +_c='9,10' +_b='.scale' +_a='.q' +_Z='mlp_down_bank' +_Y='mlp_up_bank' +_X='kv_bank' +_W='qo_bank' +_V='attn' +_U='mlp' +_T='total_scored' +_S='total_matches' +_R='passthrough' +_Q='scale' +_P='full_update' +_O='tokens_scored' +_N='cpu' +_M='utf-8' +_L='ve' +_K=',' +_J='1' +_I='cuda' +_H='lr' +_G='params' +_F='0' +_E=.0 +_D=1. +_C=False +_B=True +_A=None +import copy,glob,io,lzma,math +try:import brotli;_HAS_BROTLI=_B +except ImportError:_HAS_BROTLI=_C +import os,random,subprocess,sys,time,uuid,zlib +from pathlib import Path +try:import zstandard;_COMPRESSOR='zstd' +except ImportError:_COMPRESSOR='zlib' +import numpy as np,sentencepiece as spm,torch,torch.distributed as dist,torch.nn.functional as F +from torch import Tensor,nn +from torch.nn.parallel import DistributedDataParallel as DDP +try:from flash_attn_interface import flash_attn_func as flash_attn_3_func;HAS_FA3=_B +except ImportError:HAS_FA3=_C +SKIP_QUANTIZE=bool(int(os.environ.get('SKIP_QUANTIZE',_F))) +SKIP_COMPILE=bool(int(os.environ.get('SKIP_COMPILE',_F))) +class Hyperparameters: + 
data_path=os.environ.get('DATA_PATH','./data/datasets/fineweb10B_sp1024');train_files=os.path.join(data_path,_l);val_files=os.path.join(data_path,'fineweb_val_*.bin');tokenizer_path=os.environ.get('TOKENIZER_PATH','./data/tokenizers/fineweb_1024_bpe.model');run_id=os.environ.get('RUN_ID',str(uuid.uuid4()));seed=int(os.environ.get('SEED',1337));val_batch_size=int(os.environ.get('VAL_BATCH_SIZE',524288));val_loss_every=int(os.environ.get('VAL_LOSS_EVERY',4000));train_log_every=int(os.environ.get('TRAIN_LOG_EVERY',500));iterations=int(os.environ.get('ITERATIONS',20000));warmdown_iters=int(os.environ.get('WARMDOWN_ITERS',3500));warmup_steps=int(os.environ.get('WARMUP_STEPS',20));train_batch_tokens=int(os.environ.get('TRAIN_BATCH_TOKENS',786432));train_seq_len=int(os.environ.get('TRAIN_SEQ_LEN',2048));eval_seq_len=int(os.environ.get('EVAL_SEQ_LEN',2048));max_wallclock_seconds=float(os.environ.get('MAX_WALLCLOCK_SECONDS',6e2));qk_gain_init=float(os.environ.get('QK_GAIN_INIT',1.5));vocab_size=int(os.environ.get('VOCAB_SIZE',1024));num_layers=int(os.environ.get('NUM_LAYERS',11));num_kv_heads=int(os.environ.get('NUM_KV_HEADS',4));model_dim=int(os.environ.get('MODEL_DIM',512));num_heads=int(os.environ.get('NUM_HEADS',8));mlp_mult=float(os.environ.get('MLP_MULT',3.));tie_embeddings=bool(int(os.environ.get('TIE_EMBEDDINGS',_J)));rope_base=float(os.environ.get('ROPE_BASE',1e4));logit_softcap=float(os.environ.get('LOGIT_SOFTCAP',3e1));embed_lr=float(os.environ.get('EMBED_LR',.6));head_lr=float(os.environ.get('HEAD_LR',.008));tied_embed_lr=float(os.environ.get('TIED_EMBED_LR',.035));tied_embed_init_std=float(os.environ.get('TIED_EMBED_INIT_STD',.005));matrix_lr=float(os.environ.get('MATRIX_LR',.025));scalar_lr=float(os.environ.get('SCALAR_LR',.025));muon_momentum=float(os.environ.get('MUON_MOMENTUM',.99));muon_backend_steps=int(os.environ.get('MUON_BACKEND_STEPS',5));muon_momentum_warmup_start=float(os.environ.get('MUON_MOMENTUM_WARMUP_START',.92));muon_momentum_warmup_steps=int(o
s.environ.get('MUON_MOMENTUM_WARMUP_STEPS',1500));muon_eq_r=bool(int(os.environ.get('MUON_EQ_R',_J)));recur_layers=os.environ.get('RECUR_LAYERS','');recur_start_step=int(os.environ.get('RECUR_START_STEP',3000));parallel_start_layer=int(os.environ.get('PARALLEL_START_LAYER',-1));beta1=float(os.environ.get('BETA1',.9));beta2=float(os.environ.get('BETA2',.95));adam_eps=float(os.environ.get('ADAM_EPS',1e-08));grad_clip_norm=float(os.environ.get('GRAD_CLIP_NORM',.3));eval_stride=int(os.environ.get('EVAL_STRIDE',64));mtp_num_heads=int(os.environ.get('MTP_NUM_HEADS',0));mtp_loss_weight=float(os.environ.get('MTP_LOSS_WEIGHT',.2));muon_beta2=float(os.environ.get('MUON_BETA2',.95));swa_enabled=bool(int(os.environ.get('SWA_ENABLED',_J)));swa_every=int(os.environ.get('SWA_EVERY',50));lawa_enabled=bool(int(os.environ.get('LAWA_ENABLED',_F)));lawa_k=int(os.environ.get('LAWA_K',10));lawa_freq=int(os.environ.get('LAWA_FREQ',100));muon_wd=float(os.environ.get('MUON_WD',.04));adam_wd=float(os.environ.get('ADAM_WD',.04));qat_enabled=bool(int(os.environ.get('QAT_ENABLED',_F)));bigram_vocab_size=int(os.environ.get('BIGRAM_VOCAB_SIZE',2048));bigram_dim=int(os.environ.get('BIGRAM_DIM',128));trigram_enabled=bool(int(os.environ.get('TRIGRAM',_F)));xsa_last_n=int(os.environ.get('XSA_LAST_N',4));rope_dims=int(os.environ.get('ROPE_DIMS',16));ln_scale=bool(int(os.environ.get('LN_SCALE',_J)));dtg_enabled=bool(int(os.environ.get('DTG_ENABLED',_F)));late_qat_threshold=float(os.environ.get('LATE_QAT_THRESHOLD',.15));ve_enabled=bool(int(os.environ.get('VE_ENABLED',_J)));ve_dim=int(os.environ.get('VE_DIM',128));ve_layers=os.environ.get('VE_LAYERS',_c);gated_attention=bool(int(os.environ.get('GATED_ATTENTION',_F)));value_residual=bool(int(os.environ.get('VALUE_RESIDUAL',_F)));ttt_enabled=bool(int(os.environ.get('TTT_ENABLED',_F)));ttt_lr=float(os.environ.get('TTT_LR',.002));ttt_epochs=int(os.environ.get('TTT_EPOCHS',3));ttt_chunk_tokens=int(os.environ.get('TTT_CHUNK_TOKENS',32768));ttt_freeze_blocks=i
nt(os.environ.get('TTT_FREEZE_BLOCKS',2));ttt_momentum=float(os.environ.get('TTT_MOMENTUM',.9));ttt_batch_seqs=int(os.environ.get('TTT_BATCH_SEQS',32));ttt_grad_clip=float(os.environ.get('TTT_GRAD_CLIP',_D));crown_q_enabled=bool(int(os.environ.get('CROWN_Q_ENABLED',_F)));crown_q_lambda=float(os.environ.get('CROWN_Q_LAMBDA','0.01'));soft_round_qat=bool(int(os.environ.get('SOFT_ROUND_QAT',_F)));ttt_optimizer=os.environ.get('TTT_OPTIMIZER','sgd');ttt_adamw_lr=float(os.environ.get('TTT_ADAMW_LR',_d));ttt_q_only=bool(int(os.environ.get('TTT_Q_ONLY',_F)));ttt_difficulty=bool(int(os.environ.get('TTT_DIFFICULTY',_F)));ttt_hard_epochs=int(os.environ.get('TTT_HARD_EPOCHS','5'));ttt_easy_epochs=int(os.environ.get('TTT_EASY_EPOCHS',_J));ttt_hard_threshold=float(os.environ.get('TTT_HARD_THRESHOLD','1.3'));ttt_easy_threshold=float(os.environ.get('TTT_EASY_THRESHOLD','1.0'));ttt_temperature=float(os.environ.get('TTT_TEMPERATURE','1.0'));ttt_nesterov=bool(int(os.environ.get('TTT_NESTEROV',_J)));ttt_lr_floor=float(os.environ.get('TTT_LR_FLOOR','0.1'));ngram_enabled=bool(int(os.environ.get('NGRAM_ENABLED',_F)));ngram_alpha=float(os.environ.get('NGRAM_ALPHA','0.40'));ngram_order=int(os.environ.get('NGRAM_ORDER','7'));ngram_min_order=int(os.environ.get('NGRAM_MIN_ORDER','2'));ngram_buckets=int(os.environ.get('NGRAM_BUCKETS','4194304')) + if ngram_buckets&ngram_buckets-1!=0:raise ValueError(f"NGRAM_BUCKETS must be a power of 2, got {ngram_buckets}") + 
ngram_min_count=int(os.environ.get('NGRAM_MIN_COUNT','2'));ngram_entropy=bool(int(os.environ.get('NGRAM_ENTROPY',_J)));ngram_ent_base=float(os.environ.get('NGRAM_ENT_BASE','0.05'));ngram_ent_range=float(os.environ.get('NGRAM_ENT_RANGE','0.55'));ngram_ent_scale=float(os.environ.get('NGRAM_ENT_SCALE','2.0'));ngram_ent_thresh=float(os.environ.get('NGRAM_ENT_THRESH','4.0'));ngram_jm=bool(int(os.environ.get('NGRAM_JM',_F)));ngram_per_order_centers=bool(int(os.environ.get('NGRAM_PER_ORDER_CENTERS',_F)));ngram_depth_signal=bool(int(os.environ.get('NGRAM_DEPTH_SIGNAL',_F)));ngram_depth_scale=float(os.environ.get('NGRAM_DEPTH_SCALE',_d));ngram_sync_interval=int(os.environ.get('NGRAM_SYNC_INTERVAL','50'));gptq_full_hessian=bool(int(os.environ.get('GPTQ_FULL_HESSIAN',_J)));gptq_calib_batches=int(os.environ.get('GPTQ_CALIB_BATCHES',256));gptq_block_size=int(os.environ.get('GPTQ_BLOCK_SIZE',128));gptq_damp=float(os.environ.get('GPTQ_DAMP','0.01'));mixed_bitwidth=bool(int(os.environ.get('MIXED_BITWIDTH',_F)));hadamard_rotation=bool(int(os.environ.get('HADAMARD_ROTATION',_F)));prequant_ttt=bool(int(os.environ.get('PREQUANT_TTT',_F)));prequant_ttt_lr=float(os.environ.get('PREQUANT_TTT_LR',_d));prequant_ttt_epochs=int(os.environ.get('PREQUANT_TTT_EPOCHS','3')) +def zeropower_via_newtonschulz5(G,steps=5,eps=1e-07): + E,F,H=3.4445,-4.775,2.0315;C=G.ndim==2 + if C:G=G.unsqueeze(0) + A=G.bfloat16();D=A.size(-2)>A.size(-1) + if D:A=A.mT + A=A/(A.norm(dim=(-2,-1),keepdim=_B)+eps) + for J in range(steps):B=A@A.mT;I=F*B+H*(B@B);A=E*A+I@A + if D:A=A.mT + if C:A=A.squeeze(0) + return A +class Muon(torch.optim.Optimizer): + def __init__(A,params,lr,momentum,backend_steps,nesterov=_B,weight_decay=_E,muon_eq_r=_B):super().__init__(params,dict(lr=lr,momentum=momentum,backend_steps=backend_steps,nesterov=nesterov,weight_decay=weight_decay));A.muon_eq_r=muon_eq_r;A._built=_C + def _build(A): + A._distributed=dist.is_available()and dist.is_initialized();A._world_size=dist.get_world_size()if 
A._distributed else 1;A._rank=dist.get_rank()if A._distributed else 0;C=A._world_size;A._bank_meta=[] + for I in A.param_groups: + for B in I[_G]:G=B.shape[0];F=(G+C-1)//C*C;H=F//C;D=B.shape[1:];E=B.device;A._bank_meta.append({'p':B,'B':G,_m:torch.zeros(F,*D,device=E,dtype=torch.bfloat16),_e:torch.zeros(H,*D,device=E,dtype=torch.bfloat16),_n:torch.zeros(H,*D,device=E,dtype=torch.bfloat16),_P:torch.zeros(F,*D,device=E,dtype=torch.bfloat16),_Q:max(1,B.shape[-2]/B.shape[-1])**.5}) + A._bank_meta.sort(key=lambda m:-m['p'].numel());A._built=_B + def launch_reduce_scatters(A): + if not A._built:A._build() + if not A._distributed:return + A._rs_futures=[] + for B in A._bank_meta: + D=B['p'] + if D.grad is _A:A._rs_futures.append(_A);continue + C=B[_m];C[:B['B']].copy_(D.grad.bfloat16()) + if C.shape[0]>B['B']:C[B['B']:].zero_() + E=dist.reduce_scatter_tensor(B[_e],C,op=dist.ReduceOp.AVG,async_op=_B);A._rs_futures.append(E) + @torch.no_grad() + def step(self,closure=_A): + U='_rs_futures';P=closure;O='momentum_buffer';A=self;Q=_A + if P is not _A: + with torch.enable_grad():Q=P() + if not A._built:A._build() + for I in A.param_groups: + E=I[_H];R=I[_o];V=I['backend_steps'];W=I['nesterov'];F=I.get('weight_decay',_E);J=_A;B=_A;S=A._distributed and hasattr(A,U) + for(T,G)in enumerate(A._bank_meta): + H=G['p'] + if H.grad is _A:continue + if J is not _A: + J.wait();D=B['p'];M=B[_P][:B['B']] + if F>_E:D.data.mul_(_D-E*F) + D.add_(M.to(dtype=D.dtype),alpha=-E*B[_Q]) + if S and A._rs_futures[T]is not _A:A._rs_futures[T].wait();K=G[_e];L=G[_n] + else: + K=H.grad.bfloat16();N=A.state[H] + if O not in N:N[O]=torch.zeros_like(K) + L=N[O] + L.mul_(R).add_(K) + if W:C=K.add(L,alpha=R) + else:C=L + if A.muon_eq_r:X=(C*C).sum(dim=-1,keepdim=_B)+1e-07;C=C/X.sqrt() + C=zeropower_via_newtonschulz5(C,steps=V) + if S:J=dist.all_gather_into_tensor(G[_P],C,async_op=_B);B=G + else: + if F>_E:H.data.mul_(_D-E*F) + H.add_(C.to(dtype=H.dtype),alpha=-E*G[_Q]) + if J is not _A: + 
J.wait();D=B['p'];M=B[_P][:B['B']] + if F>_E:D.data.mul_(_D-E*F) + D.add_(M.to(dtype=D.dtype),alpha=-E*B[_Q]) + if hasattr(A,U):del A._rs_futures + return Q +def build_sentencepiece_luts(sp,vocab_size,device): + D=device;B=sp;G=int(B.vocab_size());E=max(G,vocab_size);F=np.zeros((E,),dtype=np.int16);H=np.zeros((E,),dtype=np.bool_);I=np.ones((E,),dtype=np.bool_) + for A in range(G): + if B.is_control(A)or B.is_unknown(A)or B.is_unused(A):continue + I[A]=_C + if B.is_byte(A):F[A]=1;continue + C=B.id_to_piece(A) + if C.startswith('▁'):H[A]=_B;C=C[1:] + F[A]=len(C.encode(_M)) + return torch.tensor(F,dtype=torch.int16,device=D),torch.tensor(H,dtype=torch.bool,device=D),torch.tensor(I,dtype=torch.bool,device=D) +def load_validation_tokens(pattern,seq_len): + B=pattern;A=seq_len;C=[Path(A)for A in sorted(glob.glob(B))] + if not C:raise FileNotFoundError(f"No files found for pattern: {B}") + D=torch.cat([load_data_shard(A)for A in C]).contiguous();E=(D.numel()-1)//A*A + if E<=0:raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={A}") + return D[:E+1] +def eval_val(args,model,rank,world_size,device,grad_accum_steps,val_tokens,base_bytes_lut,has_leading_space_lut,is_boundary_token_lut,eval_seq_len=_A): + K=val_tokens;J=grad_accum_steps;F=model;E=args;C=device;B=world_size;A=eval_seq_len or E.train_seq_len;L=E.val_batch_size//(B*J) + if L0 else _D,dtype=torch.float32);D=torch.clamp(torch.round(torch.clamp(A,-B,B)/C),-127,127).to(torch.int8).contiguous();return D,C +def quantize_state_dict_int8(state_dict): + S='baseline_tensor_bytes';R='num_nonfloat_tensors';Q='num_float_tensors';P='num_tensors';O='param_count';D='int8_payload_bytes';J={};K={};L={};E={};F={};G={};A=dict.fromkeys((O,P,Q,R,S,D),0) + for(C,T)in state_dict.items(): + B=T.detach().to(_N).contiguous();A[O]+=int(B.numel());A[P]+=1;A[S]+=tensor_nbytes(B) + if not B.is_floating_point():A[R]+=1;E[C]=B;A[D]+=tensor_nbytes(B);continue + if 
B.numel()<=INT8_KEEP_FLOAT_MAX_NUMEL:M=keep_float_tensor(C,B,F);E[C]=M;A[D]+=tensor_nbytes(M);continue + A[Q]+=1;N,H=quantize_float_tensor(B) + if H.ndim>0:G[C]={_q:_r,'axis':0} + J[C]=N;K[C]=H;L[C]=str(B.dtype).removeprefix(_p);A[D]+=tensor_nbytes(N)+tensor_nbytes(H) + I={'__quant_format__':'int8_clean_per_row_v1',_s:J,_t:K,_u:L,_R:E} + if G:I['qmeta']=G + if F:I[_v]=F + return I,A +def dequantize_state_dict_int8(obj): + B=obj;D={};I=B.get('qmeta',{});J=B.get(_v,{}) + for(A,E)in B[_s].items(): + G=getattr(torch,B[_u][A]);C=B[_t][A] + if I.get(A,{}).get(_q)==_r or C.ndim>0:C=C.to(dtype=torch.float32);D[A]=(E.float()*C.view(E.shape[0],*[1]*(E.ndim-1))).to(dtype=G).contiguous() + else:K=float(C.item());D[A]=(E.float()*K).to(dtype=G).contiguous() + for(A,L)in B[_R].items(): + F=L.detach().to(_N).contiguous();H=J.get(A) + if isinstance(H,str):F=F.to(dtype=getattr(torch,H)).contiguous() + D[A]=F + return D +def load_data_shard(file): + H='0: + E=A.tokens.numel()-A.pos + if E<=0:A._advance_file();continue + D=min(C,E);B.append(A.tokens[A.pos:A.pos+D]);A.pos+=D;C-=D + return B[0]if len(B)==1 else torch.cat(B) +class DistributedTokenLoader: + def __init__(A,pattern,rank,world_size,device):A.rank=rank;A.world_size=world_size;A.device=device;A.stream=TokenStream(pattern) + def next_batch(A,global_tokens,seq_len,grad_accum_steps):C=seq_len;F=global_tokens//(A.world_size*grad_accum_steps);B=F+1;G=A.stream.take(B*A.world_size);D=A.rank*B;E=G[D:D+B].to(dtype=torch.int64);H=E[:-1].reshape(-1,C);I=E[1:].reshape(-1,C);return H.to(A.device,non_blocking=_B),I.to(A.device,non_blocking=_B) +class RMSNorm(nn.Module): + def __init__(A,eps=_A):super().__init__();A.eps=eps + def forward(A,x):return F.rms_norm(x,(x.size(-1),),eps=A.eps) +class CastedLinear(nn.Linear): + _qat_enabled=_C;_soft_round_alpha=_E + def forward(B,x): + A=B.weight.to(x.dtype) + if CastedLinear._qat_enabled and B.training and A.ndim==2: + 
C=B.weight.float();J=C.abs().amax(dim=1);D=(J/31.).clamp_min(_D/31.);E=C/D[:,_A] + if CastedLinear._soft_round_alpha>0:H=CastedLinear._soft_round_alpha;I=torch.floor(E+.5);K=E-I;L=torch.tensor(H*.5,device=A.device,dtype=C.dtype);M=I+.5*torch.tanh(H*K)/torch.tanh(L);G=(torch.clamp(M,-31,31)*D[:,_A]).to(x.dtype);A=G + else: + with torch.no_grad():G=(torch.clamp(torch.round(E),-31,31)*D[:,_A]).to(x.dtype) + A=A+(G-A).detach() + N=B.bias.to(x.dtype)if B.bias is not _A else _A;return F.linear(x,A,N) +def restore_low_dim_params_to_fp32(module): + with torch.no_grad(): + for(B,A)in module.named_parameters(): + if(A.ndim<2 or any(A in B for A in CONTROL_TENSOR_NAME_PATTERNS))and A.dtype!=torch.float32:A.data=A.data.float() +class Rotary(nn.Module): + def __init__(A,dim,base=1e4,train_seq_len=1024,rope_dims=0):B=rope_dims;super().__init__();A.dim=dim;A.base=base;A.train_seq_len=train_seq_len;A.rope_dims=B if B>0 else dim;C=_D/base**(torch.arange(0,A.rope_dims,2,dtype=torch.float32)/A.rope_dims);A.register_buffer('inv_freq',C,persistent=_C);A._seq_len_cached=0;A._cos_cached=_A;A._sin_cached=_A + def forward(A,seq_len,device,dtype): + F=dtype;C=device;B=seq_len + if A._cos_cached is _A or A._sin_cached is _A or A._seq_len_cached!=B or A._cos_cached.device!=C: + D=A.rope_dims + if B>A.train_seq_len:H=B/A.train_seq_len;I=A.base*H**(D/(D-2));E=_D/I**(torch.arange(0,D,2,dtype=torch.float32,device=C)/D) + else:E=A.inv_freq.to(C) + J=torch.arange(B,device=C,dtype=E.dtype);G=torch.outer(J,E);A._cos_cached=G.cos()[_A,:,_A,:];A._sin_cached=G.sin()[_A,:,_A,:];A._seq_len_cached=B + return A._cos_cached.to(dtype=F),A._sin_cached.to(dtype=F) +def apply_rotary_emb(x,cos,sin,rope_dims=0): + F=sin;E=cos;A=rope_dims + if A>0 and A0 else 
_A;A.smear=SmearGate(B);A.num_encoder_layers=C//2;A.num_decoder_layers=C-A.num_encoder_layers;A.num_skip_weights=min(A.num_encoder_layers,A.num_decoder_layers);A.skip_weights=nn.Parameter(torch.ones(A.num_skip_weights,B,dtype=torch.float32));I=B//E;T=F*I;R=int(J*B);A.num_layers=C;A.qo_bank=nn.Parameter(torch.empty(2*C,B,B));A.kv_bank=nn.Parameter(torch.empty(2*C,T,B));A.mlp_up_bank=nn.Parameter(torch.empty(C,R,B));A.mlp_down_bank=nn.Parameter(torch.empty(C,B,R));A.blocks=nn.ModuleList([Block(B,E,F,J,L,qk_gain_init,layer_idx=A,ln_scale=ln_scale,dtg=dtg,gated_attention=gated_attention,value_residual=P)for A in range(C)]) + if H>0: + I=B//E + for S in A.blocks:S.attn.rope_dims=H;S.attn.rotary=Rotary(I,base=L,train_seq_len=1024,rope_dims=H) + A.ve_layer_indices=[int(A)for A in ve_layers.split(_K)if A.strip()]if ve_enabled else[];U=A._ve_target_dim + if A.ve_layer_indices:A.ve_shared=ValueEmbedding(D,ve_dim,U);A.ve_layer_scales=nn.ParameterList([nn.Parameter(torch.ones(1,dtype=torch.float32))for A in A.ve_layer_indices]) + else:A.ve_shared=_A;A.ve_layer_scales=nn.ParameterList() + A.value_embeds=nn.ModuleList();A.final_norm=RMSNorm();A.lm_head=_A if K else CastedLinear(B,D,bias=_C) + if A.lm_head is not _A:A.lm_head._zero_init=_B + A.mtp_heads=nn.ModuleList([CastedLinear(B,D,bias=_C)for A in range(M)]) + for V in A.mtp_heads:V._zero_init=_B + if O>0: + for W in range(max(0,C-O),C):A.blocks[W].attn.use_xsa=_B + A._init_weights() + def _init_weights(A): + if A.tie_embeddings:nn.init.normal_(A.tok_emb.weight,mean=_E,std=A.tied_embed_init_std) + D=A.num_layers;E=_D/math.sqrt(2*D) + for B in range(D):nn.init.orthogonal_(A.qo_bank.data[B],gain=_D);nn.init.zeros_(A.qo_bank.data[D+B]);nn.init.orthogonal_(A.kv_bank.data[B],gain=_D);nn.init.orthogonal_(A.kv_bank.data[D+B],gain=_D);nn.init.orthogonal_(A.mlp_up_bank.data[B],gain=_D);nn.init.zeros_(A.mlp_down_bank.data[B]);A.qo_bank.data[D+B].mul_(E);A.mlp_down_bank.data[B].mul_(E) + for(F,C)in A.named_modules(): + if 
isinstance(C,nn.Linear): + if getattr(C,'_zero_init',_C):nn.init.zeros_(C.weight) + elif C.weight.ndim==2 and C.weight.shape[0]>=64 and C.weight.shape[1]>=64:nn.init.orthogonal_(C.weight,gain=_D) + def _get_ve(A,layer_idx,input_ids,ve_cache=_A): + D=input_ids;C=layer_idx;B=ve_cache + if A.ve_shared is _A or C not in A.ve_layer_indices:return + if B is not _A and _L not in B:B[_L]=A.ve_shared(D) + E=B[_L]if B is not _A else A.ve_shared(D);F=A.ve_layer_indices.index(C);return E*A.ve_layer_scales[F].to(dtype=E.dtype) + def _run_block(A,i,x,x0,input_ids,ve_cache,v0,n,recur_scale=_A): + C=recur_scale;D=A._get_ve(i,input_ids,ve_cache);E=A.parallel_start_layer>=0 and i>=A.parallel_start_layer;B,F=A.blocks[i](x,x0,A.qo_bank[i],A.kv_bank[i],A.kv_bank[n+i],A.qo_bank[n+i],A.mlp_up_bank[i],A.mlp_down_bank[i],v_embed=D,v0=v0,parallel=E) + if C is not _A:B=x+C.to(dtype=x.dtype)[_A,_A,:]*(B-x) + return B,F + def forward(A,input_ids,target_ids): + N=target_ids;D=input_ids;G=A.num_layers;B=A.tok_emb(D) + if A.bigram is not _A:B=B+A.bigram(D) + B=F.rms_norm(B,(B.size(-1),));B=A.smear(B);H=B;E=_A;I=[];J={} + for C in range(A.num_encoder_layers): + B,O=A._run_block(C,B,H,D,J,E,G) + if E is _A and O is not _A:E=O + I.append(B) + if A.recur_active and C in A.recur_layer_indices:U=A.recur_layer_indices.index(C);B,P=A._run_block(C,B,H,D,J,E,G,recur_scale=A.recur_scales[U]) + for C in range(A.num_decoder_layers): + V=A.num_encoder_layers+C + if I:B=B+A.skip_weights[C].to(dtype=B.dtype)[_A,_A,:]*I.pop() + B,P=A._run_block(V,B,H,D,J,E,G) + B=A.final_norm(B);Q=B.reshape(-1,B.size(-1));W=N.reshape(-1) + if A.tie_embeddings:R=F.linear(Q,A.tok_emb.weight) + else: + if A.lm_head is _A:raise RuntimeError('lm_head is required when tie_embeddings=False') + R=A.lm_head(Q) + X=A.logit_softcap*torch.tanh(R/A.logit_softcap);K=F.cross_entropy(X.float(),W,reduction=_f) + if A.training and A.mtp_num_heads>0 and A.mtp_loss_weight>_E: + P,Y,Z=B.shape;L=B.new_zeros(());M=0 + for(S,a)in enumerate(A.mtp_heads): 
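+ # Multi-token prediction (MTP) auxiliary loss. Head S is trained to predict the
+ # token S+1 positions beyond the usual next-token target, i.e. head 0 predicts
+ # two tokens ahead. A minimal sketch of the shift logic below (descriptive names
+ # are illustrative, not the minified ones used in this loop):
+ #   keep = seq_len - (S + 1)              # positions that still have a target
+ #   h    = hidden[:, :keep, :]            # trailing positions are dropped
+ #   tgt  = target_ids[:, S + 1:]          # targets shifted S+1 further ahead
+ #   aux += F.cross_entropy(head(h).flatten(0, 1).float(), tgt.flatten())
+ # The per-head losses are averaged and added with weight mtp_loss_weight
+ # (the real code also applies the tanh logit softcap before the cross-entropy).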
+ T=Y-(S+1) + if T<=0:continue + b=B[:,:T,:].reshape(-1,Z);c=N[:,S+1:].reshape(-1);d=a(b);e=A.logit_softcap*torch.tanh(d/A.logit_softcap);L=L+F.cross_entropy(e.float(),c,reduction=_f);M+=1 + if M>0:K=K+A.mtp_loss_weight*(L/M) + return K + def forward_logits(A,input_ids): + D=input_ids;G=A.num_layers;B=A.tok_emb(D) + if A.bigram is not _A:B=B+A.bigram(D) + B=F.rms_norm(B,(B.size(-1),));B=A.smear(B);H=B;E=_A;I=[];J={} + for C in range(A.num_encoder_layers): + B,L=A._run_block(C,B,H,D,J,E,G) + if E is _A and L is not _A:E=L + I.append(B) + if A.recur_active and C in A.recur_layer_indices:N=A.recur_layer_indices.index(C);B,O=A._run_block(C,B,H,D,J,E,G,recur_scale=A.recur_scales[N]) + for C in range(A.num_decoder_layers): + P=A.num_encoder_layers+C + if I:B=B+A.skip_weights[C].to(dtype=B.dtype)[_A,_A,:]*I.pop() + B,O=A._run_block(P,B,H,D,J,E,G) + B=A.final_norm(B) + if A.tie_embeddings:M=F.linear(B,A.tok_emb.weight) + else:M=A.lm_head(B) + K=A.logit_softcap*torch.tanh(M/A.logit_softcap) + if hasattr(A,'_eval_temperature')and A._eval_temperature!=_D:K=K/A._eval_temperature + return K +_NG_PRIMES=np.array([np.uint64(36313),np.uint64(27191),np.uint64(51647),np.uint64(81929),np.uint64(131071),np.uint64(175447),np.uint64(209591)],dtype=np.uint64) +_PER_ORDER_CENTERS=np.array([4.5,4.2,3.8,3.5,3.2,3.],dtype=np.float64) +def _init_ngram_tables(args): + A=args;B=A.ngram_order-A.ngram_min_order+1;E=[np.zeros((A.ngram_buckets,),dtype=np.uint32)for B in range(B)];F=[np.zeros((A.ngram_buckets,),dtype=np.uint32)for B in range(B)];G=np.uint64(A.ngram_buckets-1);C=os.environ.get('NGRAM_JM_LAMBDAS','') + if C:D=np.array([float(A)for A in C.split(_K)],dtype=np.float64) + else:D=np.ones(B,dtype=np.float64)/B + H=[np.zeros((A.ngram_buckets,),dtype=np.uint32)for B in range(B)];I=[np.zeros((A.ngram_buckets,),dtype=np.uint32)for B in range(B)];return{_g:B,_h:E,_i:F,_w:H,_x:I,_y:G,_z:_NG_PRIMES,_A0:A.ngram_min_order,_A1:A.ngram_min_count,_A2:D,_O:0,_S:0,_T:0} +def 
_sync_ngram_tables(ng,device): + A=ng + for B in range(A[_g]): + for(C,D)in[(_h,_w),(_i,_x)]:F=A[C][B]-A[D][B];E=torch.from_numpy(F.view(np.int32).copy()).to(device);dist.all_reduce(E,op=dist.ReduceOp.SUM);G=E.cpu().numpy().view(np.uint32);A[C][B]=A[D][B]+G;A[D][B]=A[C][B].copy() +def _ngram_score_segment(scored_nll,logits_slice,val_np,ws,s,wlen,ng,args,device): + f=val_np;W=scored_nll;D=args;C=ng;g=W.cpu().numpy();O=np.exp(-g);E=len(g);h=np.arange(ws+s+1,ws+wlen+1,dtype=np.int64);I=C[_g];X=C[_h];Y=C[_i];i=C[_y];R=C[_z];q=C[_A0];j=C[_A1] + if D.ngram_entropy: + with torch.no_grad():k=F.log_softmax(logits_slice.float(),dim=-1);S=-(k.exp()*k).sum(dim=-1).cpu().numpy() + else:S=_A + G=[] + for A in range(I): + T=q+A-1;l=h>=T + if not l.any():G.append(_A);continue + H=np.nonzero(l)[0];Z=h[H];a=np.zeros(len(Z),dtype=np.uint64) + for m in range(T):r=f[Z-(T-m)].astype(np.uint64);a^=r*R[m%len(R)] + J=(a&i).astype(np.int64);t=f[Z].astype(np.uint64);K=((a^t*R[T%len(R)])&i).astype(np.int64);G.append((H,J,K)) + if D.ngram_jm: + n=C[_A2];U=np.full((E,I),-_D) + for A in range(I): + if G[A]is _A:continue + H,J,K=G[A];L=X[A][J].astype(np.float64);b=Y[A][K].astype(np.float64);B=L>=float(j) + if B.any():u=H[B];c=np.minimum(b[B],L[B])/np.maximum(L[B],_D);U[u,A]=np.clip(c,_E,_D) + v=U>=0;o=np.zeros(E);d=np.zeros(E) + for A in range(I): + V=v[:,A] + if V.any():o[V]+=n[A]*U[V,A];d[V]+=n[A] + M=np.full(E,-_D);e=d>0;M[e]=np.clip(o[e]/d[e],_E,_D);P=np.full(E,-1,dtype=np.int32) + for A in range(I-1,-1,-1):w=(P<0)&(U[:,A]>=0);P[w]=A + else: + M=np.full(E,-_D);P=np.full(E,-1,dtype=np.int32) + for A in range(I-1,-1,-1): + if G[A]is _A:continue + H,J,K=G[A];L=X[A][J].astype(np.float64);b=Y[A][K].astype(np.float64);B=L>=float(j);Q=B&(M[H]<0) + if Q.any():p=H[Q];c=np.minimum(b[Q],L[Q])/np.maximum(L[Q],_D);M[p]=np.clip(c,_E,_D);P[p]=A + B=M>=0 + if B.any(): + if D.ngram_entropy and S is not _A: + if 
D.ngram_per_order_centers:x=P[B];y=_PER_ORDER_CENTERS[np.clip(x,0,len(_PER_ORDER_CENTERS)-1)];N=D.ngram_ent_base+D.ngram_ent_range/(_D+np.exp(-D.ngram_ent_scale*(S[B]-y))) + else:N=D.ngram_ent_base+D.ngram_ent_range/(_D+np.exp(-D.ngram_ent_scale*(S[B]-D.ngram_ent_thresh))) + else:N=np.full(int(B.sum()),D.ngram_alpha) + if D.ngram_depth_signal:z=np.float64(min(_D,C[_O]*D.ngram_depth_scale));N=N*z + O[B]=(_D-N)*O[B]+N*M[B] + C[_S]=C.get(_S,0)+int((M>=0).sum());C[_T]=C.get(_T,0)+E;O=np.clip(O,1e-12,_D);W=torch.from_numpy(-np.log(O)).to(dtype=torch.float64,device=device) + for A in range(I): + if G[A]is _A:continue + H,J,K=G[A];np.add.at(X[A],J,1);np.add.at(Y[A],K,1) + C[_O]=C.get(_O,0)+E;return W +def eval_val_sliding(args,base_model,rank,world_size,device,val_tokens,base_bytes_lut,has_leading_space_lut,is_boundary_token_lut,stride,batch_seqs=32,eval_seq_len=_A): + Z=batch_seqs;Y=stride;O=val_tokens;N=world_size;M=base_model;K=rank;C=device;A=args;I=eval_seq_len or A.train_seq_len;P=O.numel()-1;a=[A for A in range(0,P,Y)if min(A+I,P)-A>=1];b=len(a);o=b*K//N;p=b*(K+1)//N;c=a[o:p];Q=torch.zeros((),device=C,dtype=torch.float64);L=torch.zeros((),device=C,dtype=torch.float64);R=torch.zeros((),device=C,dtype=torch.float64);J=A.ngram_enabled;E=_A;d=_A + if K==0:print(f"ngram_eval: enabled={J} order={A.ngram_order} min_order={A.ngram_min_order} buckets={A.ngram_buckets} entropy={A.ngram_entropy} jm={A.ngram_jm} per_order_centers={A.ngram_per_order_centers} depth_signal={A.ngram_depth_signal}",flush=_B) + if J:d=O.cpu().numpy();E=_init_ngram_tables(A) + M.eval();q=torch.compile(M.forward_logits,dynamic=_C,fullgraph=_B);e=0 + with torch.inference_mode(): + for f in range(0,len(c),Z): + S=c[f:f+Z];T=len(S);U=torch.zeros(T,I,dtype=torch.int64,device=C);V=torch.zeros(T,I,dtype=torch.int64,device=C);g=[] + for(D,G)in enumerate(S):h=min(G+I,P);B=h-G;g.append(B);i=O[G:h+1].to(dtype=torch.int64,device=C);U[D,:B]=i[:-1];V[D,:B]=i[1:] + with 
torch.autocast(device_type=_I,dtype=torch.bfloat16):W=q(U) + r=F.cross_entropy(W.reshape(-1,W.size(-1)).float(),V.reshape(-1),reduction='none').reshape(T,I) + for(D,G)in enumerate(S): + B=g[D];H=0 if G==0 else max(B-Y,0);X=r[D,H:B].to(torch.float64) + if J:X=_ngram_score_segment(X,W[D,H:B],d,G,H,B,E,A,C) + Q+=X.sum();L+=float(B-H);j=V[D,H:B];s=U[D,H:B];k=base_bytes_lut[j].to(torch.float64);k+=(has_leading_space_lut[j]&~is_boundary_token_lut[s]).to(torch.float64);R+=k.sum() + e+=1 + if J and dist.is_available()and dist.is_initialized()and N>1: + if e%A.ngram_sync_interval==0:_sync_ngram_tables(E,C) + if J and K==0:l=E.get(_S,0);m=E.get(_T,0);t=1e2*l/max(m,1);print(f"ngram_stats: matches={l}/{m} ({t:.1f}%) tokens_in_cache={E.get(_O,0)}",flush=_B) + if dist.is_available()and dist.is_initialized():dist.all_reduce(Q,op=dist.ReduceOp.SUM);dist.all_reduce(L,op=dist.ReduceOp.SUM);dist.all_reduce(R,op=dist.ReduceOp.SUM) + n=(Q/L).item();u=n/math.log(2.);v=L.item()/R.item();M.train();return n,u*v +def eval_val_sliding_ttt(args,base_model,rank,world_size,device,val_tokens,base_bytes_lut,has_leading_space_lut,is_boundary_token_lut,stride,batch_seqs=32,log0=print): + s=batch_seqs;W=stride;V=val_tokens;O=world_size;N=rank;I=log0;H=device;B=base_model;A=args;G=A.train_seq_len;P=V.numel()-1;Q=A.ttt_chunk_tokens;t=[A for A in range(0,P,W)if min(A+G,P)-A>=W or A==0];K=(P+Q-1)//Q;u=[[]for A in range(K)] + for E in t:Z=min(E+G,P);C=Z-E;L=0 if E==0 else max(C-W,0);AG=E+L;D=min(AG//Q,K-1);u[D].append(E) + I(f"ttt_sliding:start chunks={K} chunk_tokens={Q} total_windows={len(t)} stride={W} ttt_lr={A.ttt_lr} ttt_epochs={A.ttt_epochs} freeze_blocks={A.ttt_freeze_blocks}");R=torch.zeros((),device=H,dtype=torch.float64);J=torch.zeros((),device=H,dtype=torch.float64);S=torch.zeros((),device=H,dtype=torch.float64);v=set(range(min(A.ttt_freeze_blocks,len(B.blocks))));T=[] + for(AH,X)in B.named_parameters(): + w=_C + for U in v: + if f"blocks.{U}."in AH:w=_B;break + if w:X.requires_grad_(_C) + 
else:X.requires_grad_(_B);T.append(X) + I(f"ttt_sliding:params unfrozen={sum(A.numel()for A in T)} frozen={sum(A.numel()for A in B.parameters()if not A.requires_grad)}") + if A.ttt_optimizer=='adamw':a=torch.optim.AdamW(T,lr=A.ttt_adamw_lr,weight_decay=_E,betas=(.9,.95));I(f"ttt_optimizer:AdamW lr={A.ttt_adamw_lr}") + else:a=torch.optim.SGD(T,lr=A.ttt_lr,momentum=A.ttt_momentum,nesterov=A.ttt_nesterov);I(f"ttt_optimizer:SGD lr={A.ttt_lr} nesterov={A.ttt_nesterov}") + if A.ttt_temperature!=_D:B._eval_temperature=A.ttt_temperature;I(f"ttt_temperature:{A.ttt_temperature}") + else:B._eval_temperature=_D + AI=torch.compile(B.forward_logits,dynamic=_C,fullgraph=_B);x=_E;y=_E;z=_E;e=A.ngram_enabled;f=_A;A0=_A + if e:A0=V.cpu().numpy();f=_init_ngram_tables(A);I(f"ngram:enabled orders={A.ngram_min_order}-{A.ngram_order} entropy={A.ngram_entropy} alpha={A.ngram_alpha} buckets={A.ngram_buckets} min_count={A.ngram_min_count}") + A1=time.perf_counter() + for D in range(K): + b=u[D] + if not b:continue + g=D*Q;AJ=min((D+1)*Q,P);AK=len(b)*N//O;AL=len(b)*(N+1)//O;A2=b[AK:AL];B.eval();A3=0 + with torch.inference_mode(): + for U in range(0,len(A2),s): + h=A2[U:U+s];i=len(h);j=torch.zeros(i,G,dtype=torch.int64,device=H);k=torch.zeros(i,G,dtype=torch.int64,device=H);A4=[] + for(M,E)in enumerate(h):Z=min(E+G,P);C=Z-E;A4.append(C);A5=V[E:Z+1].to(dtype=torch.int64,device=H);j[M,:C]=A5[:-1];k[M,:C]=A5[1:] + with torch.autocast(device_type=_I,dtype=torch.bfloat16):l=AI(j) + AM=F.cross_entropy(l.reshape(-1,l.size(-1)).float(),k.reshape(-1),reduction='none').reshape(i,G) + for(M,E)in enumerate(h): + C=A4[M];L=0 if E==0 else max(C-W,0);m=AM[M,L:C].to(torch.float64) + if e:m=_ngram_score_segment(m,l[M,L:C],A0,E,L,C,f,A,H) + R+=m.sum();J+=float(C-L);A6,AN=k[M,L:C],j[M,L:C];A7=base_bytes_lut[A6].to(torch.float64);A7+=(has_leading_space_lut[A6]&~is_boundary_token_lut[AN]).to(torch.float64);S+=A7.sum() + A3+=1 + if e and dist.is_available()and dist.is_initialized()and O>1: + if 
A3%A.ngram_sync_interval==0:_sync_ngram_tables(f,H) + AO=D==K-1 + if not AO and A.ttt_epochs>0: + if A.ttt_difficulty and D>0: + AP=R.item()-x;n=J.item()-z;A8=S.item()-y + if n>0 and A8>0:c=AP/n/math.log(2.)*(n/A8) + else:c=1.12 + if c>A.ttt_hard_threshold:Y=A.ttt_hard_epochs + elif c0: + A9=A.ttt_adamw_lr if A.ttt_optimizer=='adamw'else A.ttt_lr;p=A9*.5*(_D+math.cos(math.pi*D/max(K-1,1))) + if A.ttt_lr_floor>0:p=max(p,A9*A.ttt_lr_floor) + for AQ in a.param_groups:AQ[_H]=p + q=o*N//O;AR=o*(N+1)//O;AA=AR-q + for Ae in range(Y): + for AB in range(0,AA,A.ttt_batch_seqs): + AS=min(AB+A.ttt_batch_seqs,AA);AT=q+AB;AU=g+AT*G;AC=g+(q+AS)*G+1 + if AC>V.numel():continue + AD=V[AU:AC].to(device=H,dtype=torch.int64);AV=AD[:-1].reshape(-1,G);AW=AD[1:].reshape(-1,G);a.zero_grad(set_to_none=_B) + with torch.autocast(device_type=_I,dtype=torch.bfloat16):AX=B(AV,AW) + AX.backward() + if O>1: + d=[A.grad for A in T if A.grad is not _A] + if d: + AE=torch._utils._flatten_dense_tensors(d);dist.all_reduce(AE,op=dist.ReduceOp.AVG) + for(AY,AZ)in zip(d,torch._utils._unflatten_dense_tensors(AE,d)):AY.copy_(AZ) + if A.ttt_q_only: + Aa=len(B.blocks) + if B.qo_bank.grad is not _A: + B.qo_bank.grad[Aa:].zero_() + for U in v:B.qo_bank.grad[U].zero_() + if B.kv_bank.grad is not _A:B.kv_bank.grad.zero_() + if B.mlp_up_bank.grad is not _A:B.mlp_up_bank.grad.zero_() + if B.mlp_down_bank.grad is not _A:B.mlp_down_bank.grad.zero_() + torch.nn.utils.clip_grad_norm_(T,A.ttt_grad_clip);a.step() + if N==0 and(D%10==0 or D==K-1):Ab=time.perf_counter()-A1;Ac=R.item()/max(J.item(),1);Ad=Ac/math.log(2.)*(J.item()/max(S.item(),1))if J.item()>0 else _E;I(f" ttt_chunk [{D+1}/{K}] bpb={Ad:.6f} time={Ab:.1f}s") + if dist.is_available()and dist.is_initialized():dist.all_reduce(R,op=dist.ReduceOp.SUM);dist.all_reduce(J,op=dist.ReduceOp.SUM);dist.all_reduce(S,op=dist.ReduceOp.SUM) + r=(R/J).item();AF=r/math.log(2.)*(J.item()/S.item()) + for X in B.parameters():X.requires_grad_(_B) + B.eval();I(f"ttt_sliding:done 
val_loss={r:.6f} val_bpb={AF:.6f} elapsed={time.perf_counter()-A1:.1f}s");return r,AF +def generate_autoregressive_calib(model,device,num_seqs=64,seq_len=2048,vocab_size=1024,temperature=.8,batch_size=8,seed=42): + F=batch_size;E=num_seqs;D=device;C=model;C.eval();B=torch.Generator(device=D);B.manual_seed(seed);G=[] + with torch.inference_mode(),torch.autocast(device_type=_I,dtype=torch.bfloat16): + for J in range(0,E,F): + H=min(F,E-J);A=torch.randint(0,vocab_size,(H,1),device=D,generator=B) + for O in range(seq_len-1):K=C.forward_logits(A);L=K[:,-1,:];M=torch.softmax(L/temperature,dim=-1);N=torch.multinomial(M,1,generator=B);A=torch.cat([A,N],dim=1) + for I in range(H):G.append(A[I:I+1]) + return G +def collect_hessians_from_tokens(hessian_model,token_seqs,device,gptq_damp=.01): + G=device;F=token_seqs;C=hessian_model;A={};H=[] + for(B,D)in C.named_modules(): + if isinstance(D,CastedLinear): + I=B+'.weight';J=D.weight.shape[1];A[I]=torch.zeros(J,J,dtype=torch.float32,device=_N) + def M(pname): + def B(module,input,output): + B=input[0].detach().float() + if B.ndim==3:B=B.reshape(-1,B.shape[-1]) + A[pname]+=(B.T@B).cpu() + return B + E=D.register_forward_hook(M(I));H.append(E) + C.eval() + with torch.inference_mode(),torch.autocast(device_type=_I,dtype=torch.bfloat16): + for K in F:N=K[:,:-1].to(G);O=K[:,1:].to(G);C(N,O) + for E in H:E.remove() + P=len(F) + for B in A:L=A[B];L/=P;A[B]=L + return A +def _hadamard_matrix(n): + A=torch.ones(1,1) + while A.size(0)0 else _D,dtype=torch.float16);C=torch.clamp(torch.round(B/L.float()),-A,A).to(torch.int8);return C,L +def quantize_int6_gptq(weight,hessian=_A,clip_range=31,block_size=128,gptq_damp=.01): + Q=block_size;P=hessian;G=clip_range;E=weight.float() + if E.ndim!=2 or P is _A:return _quantize_int6_percentile(E,G) + 
R,H=E.shape;B=P.float().clone();L=torch.diag(B)==0;B[L,L]=1;g=gptq_damp*torch.mean(torch.diag(B));B[torch.arange(H),torch.arange(H)]+=g;I=torch.argsort(torch.diag(B),descending=_B);h=torch.argsort(I);J=E[:,I].clone();J[:,L[I]]=0;B=B[I][:,I];F=torch.linalg.cholesky(B);F=torch.cholesky_inverse(F);F=torch.linalg.cholesky(F,upper=_B);K=_A;S=_A;T=float(_j) + for U in[.999,.9995,.9999,.99999,_D]: + if U<_D:V=torch.quantile(E.abs(),U,dim=1) + else:V=E.abs().amax(dim=1) + W=(V/G).clamp_min(_D/G).to(torch.float16);M=W.float();N=torch.zeros_like(J,dtype=torch.int8);X=J.clone() + for D in range(0,H,Q): + A=min(D+Q,H);O=A-D;Y=X[:,D:A].clone();Z=torch.zeros(R,O,dtype=torch.int8);a=torch.zeros(R,O);b=F[D:A,D:A] + for C in range(O):c=Y[:,C];i=b[C,C];d=torch.clamp(torch.round(c/M),-G,G).to(torch.int8);Z[:,C]=d;e=(c-d.float()*M)/i;Y[:,C:]-=e.unsqueeze(1)*b[C,C:].unsqueeze(0);a[:,C]=e + N[:,D:A]=Z + if A0 else _D,dtype=torch.float16);C=torch.clamp(torch.round(A/L.float()),-B,B).to(torch.int8);return C,L +def _unbank_state_dict(sd,num_layers): + B={};D=num_layers + for(E,C)in sd.items(): + if E==_W: + for A in range(D):B[f"blocks.{A}.attn.c_q.weight"]=C[A];B[f"blocks.{A}.attn.proj.weight"]=C[D+A] + elif E==_X: + for A in range(D):B[f"blocks.{A}.attn.c_k.weight"]=C[A];B[f"blocks.{A}.attn.c_v.weight"]=C[D+A] + elif E==_Y: + for A in range(D):B[f"blocks.{A}.mlp.fc.weight"]=C[A] + elif E==_Z: + for A in range(D):B[f"blocks.{A}.mlp.proj.weight"]=C[A] + else:B[E]=C + return B +def _rebank_state_dict(sd,num_layers,template_sd): + F=template_sd;A=sd;E={};C=num_layers;G=[_A]*(2*C);H=[_A]*(2*C);O=[_A]*C;P=[_A]*C;D=set() + for B in range(C): + I=f"blocks.{B}.attn.c_q.weight" + if I in A:G[B]=A[I];D.add(I) + J=f"blocks.{B}.attn.proj.weight" + if J in A:G[C+B]=A[J];D.add(J) + K=f"blocks.{B}.attn.c_k.weight" + if K in A:H[B]=A[K];D.add(K) + L=f"blocks.{B}.attn.c_v.weight" + if L in A:H[C+B]=A[L];D.add(L) + M=f"blocks.{B}.mlp.fc.weight" + if M in A:O[B]=A[M];D.add(M) + 
N=f"blocks.{B}.mlp.proj.weight" + if N in A:P[B]=A[N];D.add(N) + E[_W]=torch.stack(G).to(dtype=F[_W].dtype);E[_X]=torch.stack(H).to(dtype=F[_X].dtype);E[_Y]=torch.stack(O).to(dtype=F[_Y].dtype);E[_Z]=torch.stack(P).to(dtype=F[_Z].dtype) + for(Q,R)in A.items(): + if Q not in D:E[Q]=R + return E +class _HessianAttn(nn.Module): + def __init__(A,dim,num_heads,num_kv_heads,rope_base,qk_gain_init):D=num_kv_heads;C=num_heads;B=dim;super().__init__();A.num_heads,A.num_kv_heads=C,D;A.head_dim=B//C;E=D*A.head_dim;A.c_q=CastedLinear(B,B,bias=_C);A.c_k=CastedLinear(B,E,bias=_C);A.c_v=CastedLinear(B,E,bias=_C);A.proj=CastedLinear(B,B,bias=_C);A.q_gain=nn.Parameter(torch.full((C,),qk_gain_init,dtype=torch.float32));A.rope_dims=0;A.rotary=Rotary(A.head_dim,base=rope_base,train_seq_len=1024);A.use_xsa=_C + def _xsa_efficient(K,y,v):A,B,C,D=y.shape;E=v.size(-2);I=C//E;G=y.reshape(A,B,E,I,D);H=F.normalize(v,dim=-1).unsqueeze(-2);J=(G*H).sum(dim=-1,keepdim=_B)*H;return(G-J).reshape(A,B,C,D) + def forward(A,x,v_embed=_A): + I=v_embed;G,E,L=x.shape;B=A.c_q(x).reshape(G,E,A.num_heads,A.head_dim);C=A.c_k(x).reshape(G,E,A.num_kv_heads,A.head_dim);D=A.c_v(x) + if I is not _A:D=D+I + D=D.reshape(G,E,A.num_kv_heads,A.head_dim);B=F.rms_norm(B,(B.size(-1),));C=F.rms_norm(C,(C.size(-1),));J,K=A.rotary(E,x.device,B.dtype);B=apply_rotary_emb(B,J,K,A.rope_dims);C=apply_rotary_emb(C,J,K,A.rope_dims);B=B*A.q_gain.to(dtype=B.dtype)[_A,_A,:,_A] + if HAS_FA3:H=flash_attn_3_func(B,C,D,causal=_B) + else:M=B.transpose(1,2);N=C.transpose(1,2);O=D.transpose(1,2);H=F.scaled_dot_product_attention(M,N,O,is_causal=_B).transpose(1,2) + if A.use_xsa:H=A._xsa_efficient(H,D) + return A.proj(H.reshape(G,E,L)) +class _HessianMLP(nn.Module): + def __init__(B,dim,mlp_mult):C=mlp_mult;A=dim;super().__init__();B.fc=CastedLinear(A,int(C*A),bias=_C);B.proj=CastedLinear(int(C*A),A,bias=_C) + def forward(A,x):return A.proj(F.leaky_relu(A.fc(x),negative_slope=.5).square()) +class _HessianBlock(nn.Module): + def 
__init__(A,dim,num_heads,num_kv_heads,mlp_mult,rope_base,qk_gain_init,layer_idx=0,ln_scale=_C):B=dim;super().__init__();A.attn_norm=RMSNorm();A.mlp_norm=RMSNorm();A.attn=_HessianAttn(B,num_heads,num_kv_heads,rope_base,qk_gain_init);A.mlp=_HessianMLP(B,mlp_mult);A.attn_scale=nn.Parameter(torch.ones(B,dtype=torch.float32));A.mlp_scale=nn.Parameter(torch.ones(B,dtype=torch.float32));A.resid_mix=nn.Parameter(torch.stack((torch.ones(B),torch.zeros(B))).float());A.ln_scale_factor=_D/math.sqrt(layer_idx+1)if ln_scale else _D + def forward(A,x,x0,v_embed=_A):D=A.resid_mix.to(dtype=x.dtype);C=D[0][_A,_A,:]*x+D[1][_A,_A,:]*x0;E=A.attn(A.attn_norm(C)*A.ln_scale_factor,v_embed=v_embed);B=C+A.attn_scale.to(dtype=C.dtype)[_A,_A,:]*E;B=B+A.mlp_scale.to(dtype=B.dtype)[_A,_A,:]*A.mlp(A.mlp_norm(B)*A.ln_scale_factor);return B +class _HessianGPT(nn.Module): + def __init__(A,vocab_size,num_layers,model_dim,num_heads,num_kv_heads,mlp_mult,tie_embeddings,logit_softcap,rope_base,qk_gain_init,bigram_vocab_size=0,bigram_dim=128,xsa_last_n=0,rope_dims=0,ln_scale=_C,ve_enabled=_C,ve_dim=128,ve_layers=_c,trigram=_C): + K=xsa_last_n;J=bigram_vocab_size;I=rope_base;H=tie_embeddings;G=num_kv_heads;F=rope_dims;E=num_heads;D=vocab_size;C=num_layers;B=model_dim;super().__init__();A.tie_embeddings=H;A.logit_softcap=logit_softcap;A.num_layers=C;A.tok_emb=nn.Embedding(D,B);A.bigram=BigramHashEmbedding(J,bigram_dim,B,trigram=trigram)if J>0 else _A;A.smear=SmearGate(B);A.num_encoder_layers=C//2;A.num_decoder_layers=C-A.num_encoder_layers;A.num_skip_weights=min(A.num_encoder_layers,A.num_decoder_layers);A.skip_weights=nn.Parameter(torch.ones(A.num_skip_weights,B,dtype=torch.float32));A.blocks=nn.ModuleList([_HessianBlock(B,E,G,mlp_mult,I,qk_gain_init,layer_idx=A,ln_scale=ln_scale)for A in range(C)]) + if F>0: + M=B//E + for L in A.blocks:L.attn.rope_dims=F;L.attn.rotary=Rotary(M,base=I,train_seq_len=1024,rope_dims=F) + if K>0: + for N in range(max(0,C-K),C):A.blocks[N].attn.use_xsa=_B + 
O=G*(B//E);A.ve_layer_indices=[int(A)for A in ve_layers.split(_K)if A.strip()]if ve_enabled else[] + if A.ve_layer_indices:A.ve_shared=ValueEmbedding(D,ve_dim,O);A.ve_layer_scales=nn.ParameterList([nn.Parameter(torch.ones(1,dtype=torch.float32))for A in A.ve_layer_indices]) + else:A.ve_shared=_A;A.ve_layer_scales=nn.ParameterList() + A.final_norm=RMSNorm();A.lm_head=_A if H else CastedLinear(B,D,bias=_C) + def _get_ve(A,layer_idx,input_ids,ve_cache): + C=layer_idx;B=ve_cache + if A.ve_shared is _A or C not in A.ve_layer_indices:return + if _L not in B:B[_L]=A.ve_shared(input_ids) + D=A.ve_layer_indices.index(C);return B[_L]*A.ve_layer_scales[D].to(dtype=B[_L].dtype) + def forward(B,input_ids,target_ids): + D=input_ids;A=B.tok_emb(D) + if B.bigram is not _A:A=A+B.bigram(D) + A=F.rms_norm(A,(A.size(-1),));A=B.smear(A);H=A;E=[];I={} + for C in range(B.num_encoder_layers):G=B._get_ve(C,D,I);A=B.blocks[C](A,H,v_embed=G);E.append(A) + for C in range(B.num_decoder_layers): + J=B.num_encoder_layers+C + if E:A=A+B.skip_weights[C].to(dtype=A.dtype)[_A,_A,:]*E.pop() + G=B._get_ve(J,D,I);A=B.blocks[J](A,H,v_embed=G) + A=B.final_norm(A);K=A.reshape(-1,A.size(-1));L=target_ids.reshape(-1);M=F.linear(K,B.tok_emb.weight)if B.tie_embeddings else B.lm_head(K);N=B.logit_softcap*torch.tanh(M/B.logit_softcap);return F.cross_entropy(N.float(),L,reduction=_f) +def collect_hessians(hessian_model,train_loader,args,device,grad_accum_steps,num_batches=256): + G=num_batches;C=hessian_model;A={};H=[] + for(D,E)in C.named_modules(): + if isinstance(E,CastedLinear): + I=D+'.weight';J=E.weight.shape[1];A[I]=torch.zeros(J,J,dtype=torch.float32,device=_N) + def K(pname): + def B(module,input,output): + B=input[0].detach().float() + if B.ndim==3:B=B.reshape(-1,B.shape[-1]) + A[pname]+=(B.T@B).cpu() + return B + F=E.register_forward_hook(K(I));H.append(F) + C.eval() + with torch.inference_mode(),torch.autocast(device_type=_I,dtype=torch.bfloat16): + for O in 
range(G):L,M=train_loader.next_batch(args.train_batch_tokens,args.train_seq_len,grad_accum_steps);C(L,M)
+ for F in H:F.remove()
+ for D in A:B=A[D];B/=G;N=getattr(args,'gptq_damp',.01)*torch.diag(B).mean().clamp_min(1e-06);B+=N*torch.eye(B.shape[0]);A[D]=B
+ C.train();return A
+def mixed_quantize_int6(state_dict,int6_cats,hessians=_A,clip_ranges=_A,gptq_damp=.01):
+ J=clip_ranges;I=hessians;H=state_dict;K=max((int(A.split('.')[1])for A in H if A.startswith('blocks.')),default=0)+1;P=set(range(K-2,K));C={};D={}
+ for(A,M)in H.items():
+ B=M.detach().cpu().contiguous();N=_classify_param(A)
+ if not B.is_floating_point()or B.numel()<=65536:C[A]=B.to(torch.float16)if B.is_floating_point()else B;D[A]=_R;continue
+ if any(B in A for B in CONTROL_TENSOR_NAME_PATTERNS):C[A]=B.float();D[A]=_A3;continue
+ if N in int6_cats and B.ndim>=1:
+ E=J.get(A,31)if J else 31;L=I.get(A)if I else _A
+ if L is not _A:F,G=quantize_int6_gptq(B,hessian=L,clip_range=E,gptq_damp=gptq_damp)
+ else:F,G=quantize_int6_per_row(B,clip_range=E)
+ C[A+_a]=F;C[A+_b]=G;O={15:'int5',31:'int6',63:'int7'}.get(E,f"int_cr{E}");D[A]={_k:O}
+ else:F,G=quantize_float_tensor(B);C[A+_a]=F;C[A+_b]=G;D[A]={_k:'int8'}
+ return C,D
+def dequantize_mixed_int6(result,meta,template_sd):
+ F=result;B={}
+ for(A,I)in template_sd.items():
+ H=meta.get(A)
+ if H is _A:continue
+ C=I.dtype
+ if H in(_R,_A3,'passthrough_fp16'):
+ D=F[A]
+ if D.dtype==torch.float16 and C in(torch.float32,torch.bfloat16):D=D.to(C)
+ B[A]=D;continue
+ E,G=F[A+_a],F[A+_b]
+ if G.ndim>0:B[A]=(E.float()*G.float().view(E.shape[0],*[1]*(E.ndim-1))).to(C)
+ else:B[A]=(E.float()*float(G.item())).to(C)
+ return B
+def main():
+ BQ='final_model.int6.ptz';BP='final_model.pt';BO='mtp_heads';BN='WORLD_SIZE';U='base_lr';x=Path(__file__).read_text(encoding=_M);A=Hyperparameters();Q='RANK'in os.environ and BN in os.environ;M=int(os.environ.get('RANK',_F));I=int(os.environ.get(BN,_J));BR=int(os.environ.get('LOCAL_RANK',_F))
+ if I<=0:raise ValueError(f"WORLD_SIZE must be 
positive, got {I}") + if 8%I!=0:raise ValueError(f"WORLD_SIZE={I} must divide 8 so grad_accum_steps stays integral") + L=8//I;Aj=_D/L + if not torch.cuda.is_available():raise RuntimeError('CUDA is required') + E=torch.device(_I,BR);torch.cuda.set_device(E) + if Q:dist.init_process_group(backend='nccl',device_id=E);dist.barrier() + y=M==0;torch.backends.cuda.matmul.allow_tf32=_B;torch.backends.cudnn.allow_tf32=_B;from torch.backends.cuda import enable_cudnn_sdp as Ak,enable_flash_sdp as Al,enable_math_sdp as Am,enable_mem_efficient_sdp as An + if HAS_FA3:Ak(_C);Al(_B);An(_C);Am(_C) + else:Ak(_B);Al(_B);An(_B);Am(_B) + z=_A + if y:os.makedirs('logs',exist_ok=_B);z=f"logs/{A.run_id}.txt";print(z) + def C(msg,console=_B): + if not y:return + if console:print(msg) + if z is not _A: + with open(z,'a',encoding=_M)as A:print(msg,file=A) + C(x,console=_C);C('='*100,console=_C);C(f"Running Python {sys.version}",console=_C);C(f"Running PyTorch {torch.__version__}",console=_C);C(subprocess.run(['nvidia-smi'],stdout=subprocess.PIPE,stderr=subprocess.PIPE,text=_B,check=_C).stdout,console=_C);C('='*100,console=_C);random.seed(A.seed);np.random.seed(A.seed);torch.manual_seed(A.seed);torch.cuda.manual_seed_all(A.seed) + if not A.tokenizer_path.endswith('.model'):raise ValueError(f"Script only setup for SentencePiece .model file: {A.tokenizer_path}") + A8=spm.SentencePieceProcessor(model_file=A.tokenizer_path) + if int(A8.vocab_size())!=A.vocab_size:raise ValueError(f"VOCAB_SIZE={A.vocab_size} does not match tokenizer vocab_size={int(A8.vocab_size())}") + Ao=Path(A.data_path).resolve();BS=len(list(Ao.glob(_l)));A9=A.eval_seq_len if A.eval_seq_len>0 else A.train_seq_len;BT=max(A.train_seq_len,A9);R=load_validation_tokens(A.val_files,BT);V,W,X=build_sentencepiece_luts(A8,A.vocab_size,E);C(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={A.tokenizer_path}");C(f"train_loader:dataset:{Ao.name} train_shards:{BS}");C(f"val_loader:shards pattern={A.val_files} 
tokens:{R.numel()-1}");CastedLinear._qat_enabled=A.qat_enabled;B=GPT(vocab_size=A.vocab_size,num_layers=A.num_layers,model_dim=A.model_dim,num_heads=A.num_heads,num_kv_heads=A.num_kv_heads,mlp_mult=A.mlp_mult,tie_embeddings=A.tie_embeddings,tied_embed_init_std=A.tied_embed_init_std,logit_softcap=A.logit_softcap,rope_base=A.rope_base,qk_gain_init=A.qk_gain_init,mtp_num_heads=A.mtp_num_heads,mtp_loss_weight=A.mtp_loss_weight,bigram_vocab_size=A.bigram_vocab_size,bigram_dim=A.bigram_dim,xsa_last_n=A.xsa_last_n,rope_dims=A.rope_dims,ln_scale=A.ln_scale,dtg=A.dtg_enabled,ve_enabled=A.ve_enabled,ve_dim=A.ve_dim,ve_layers=A.ve_layers,gated_attention=A.gated_attention,value_residual=A.value_residual,recur_layers=A.recur_layers,parallel_start_layer=A.parallel_start_layer,trigram=A.trigram_enabled).to(E).bfloat16();B.qo_bank.data=B.qo_bank.data.float();B.kv_bank.data=B.kv_bank.data.float();B.mlp_up_bank.data=B.mlp_up_bank.data.float();B.mlp_down_bank.data=B.mlp_down_bank.data.float() + for Y in B.modules(): + if isinstance(Y,CastedLinear):Y.float() + restore_low_dim_params_to_fp32(B) + if SKIP_COMPILE:AA=B;C('compile:SKIPPED (SKIP_COMPILE=1)') + else:AA=torch.compile(B,dynamic=_C,fullgraph=_B) + A0=AA;BU=[B.qo_bank,B.kv_bank,B.mlp_up_bank,B.mlp_down_bank];BV=list(B.blocks.named_parameters());N=[A for(B,A)in BV if A.ndim<2 or any(A in B for A in CONTROL_TENSOR_NAME_PATTERNS)] + if B.skip_weights.numel()>0:N.append(B.skip_weights) + N.append(B.smear.gate) + if B.bigram is not _A:N.append(B.bigram.scale) + S=A.tied_embed_lr if A.tie_embeddings else A.embed_lr;AB=[{_G:[B.tok_emb.weight],_H:S,U:S}] + if B.bigram is not _A: + AB.append({_G:[B.bigram.embed.weight],_H:S,U:S}) + if B.bigram.proj is not _A:N.append(B.bigram.proj.weight) + if B.ve_shared is not _A: + AB.append({_G:[B.ve_shared.embed.weight],_H:S,U:S}) + if B.ve_shared.proj is not _A:N.append(B.ve_shared.proj.weight) + N.append(B.ve_shared.scale) + for Z in B.ve_layer_scales:N.append(Z) + 
A1=torch.optim.AdamW(AB,betas=(A.beta1,A.beta2),eps=A.adam_eps,weight_decay=A.adam_wd,fused=_B);j=Muon(BU,lr=A.matrix_lr,momentum=A.muon_momentum,backend_steps=A.muon_backend_steps,weight_decay=A.muon_wd,muon_eq_r=A.muon_eq_r) + for a in j.param_groups:a[U]=A.matrix_lr + Ap=torch.optim.AdamW([{_G:N,_H:A.scalar_lr,U:A.scalar_lr}],betas=(A.beta1,A.beta2),eps=A.adam_eps,weight_decay=A.adam_wd,fused=_B);A2=list(A1.param_groups[0][_G]) + for BW in A1.param_groups[1:]:A2.extend(BW[_G]) + A2.extend(N);k=_A + if B.lm_head is not _A:k=torch.optim.Adam([{_G:[B.lm_head.weight],_H:A.head_lr,U:A.head_lr}],betas=(A.beta1,A.beta2),eps=A.adam_eps,fused=_B);A2.append(B.lm_head.weight) + b=[A1,j,Ap] + if k is not _A:b.append(k) + BX=sum(A.numel()for A in B.parameters());BY=sum(A.numel()for A in B.mtp_heads.parameters());C(f"model_params:{BX}");C(f"mtp_num_heads:{A.mtp_num_heads} mtp_loss_weight:{A.mtp_loss_weight} mtp_params:{BY}");BZ=[A for(A,B)in enumerate(B.blocks)if B.attn.use_xsa];C(f"XSA:last_{A.xsa_last_n} active_layers:{BZ}");C(f"world_size:{I} grad_accum_steps:{L}");C('sdp_backends:cudnn=False flash=True mem_efficient=False math=False');C(f"attention_mode:gqa num_heads:{A.num_heads} num_kv_heads:{A.num_kv_heads}");C(f"tie_embeddings:{A.tie_embeddings} embed_lr:{S} head_lr:{A.head_lr if B.lm_head is not _A else _E} matrix_lr:{A.matrix_lr} scalar_lr:{A.scalar_lr}");C(f"train_batch_tokens:{A.train_batch_tokens} train_seq_len:{A.train_seq_len} iterations:{A.iterations} warmup_steps:{A.warmup_steps} max_wallclock_seconds:{A.max_wallclock_seconds:.3f}");C(f"seed:{A.seed}");AC=DistributedTokenLoader(A.train_files,M,I,E) + def l(): + for A in b:A.zero_grad(set_to_none=_B) + m=1e3*A.max_wallclock_seconds if A.max_wallclock_seconds>0 else _A + def Ba(step,elapsed_ms): + C=elapsed_ms;B=step + if A.warmdown_iters<=0:return _D + if m is _A:F=max(A.iterations-A.warmdown_iters,0);return max((A.iterations-B)/max(A.warmdown_iters,1),_E)if F<=B0: + Bb={A:B.detach().cpu().clone()for(A,B)in 
B.state_dict().items()};Bc=[copy.deepcopy(A.state_dict())for A in b];A0.train() + for AD in range(A.warmup_steps): + l() + for Bd in range(L): + n,o=AC.next_batch(A.train_batch_tokens,A.train_seq_len,L) + with torch.autocast(device_type=_I,dtype=torch.bfloat16,enabled=_B):Be=A0(n,o) + (Be*Aj).backward() + if Q: + for J in B.parameters(): + if J.grad is not _A:dist.all_reduce(J.grad,op=dist.ReduceOp.AVG) + for p in b:p.step() + l() + if A.warmup_steps<=20 or(AD+1)%10==0 or AD+1==A.warmup_steps:C(f"warmup_step:{AD+1}/{A.warmup_steps}") + B.load_state_dict(Bb,strict=_B) + for(p,Bf)in zip(b,Bc,strict=_B):p.load_state_dict(Bf) + l();AC=DistributedTokenLoader(A.train_files,M,I,E) + AE=_A;Aq=0;from collections import deque;q=deque(maxlen=A.lawa_k);Ar={A:B.detach().float().clone()for(A,B)in B.state_dict().items()};As=.997;c=_E;d=_A;torch.cuda.synchronize();A3=time.perf_counter();G=0 + while _B: + At=G==A.iterations or d is not _A and G>=d;Bg=At or A.val_loss_every>0 and G%A.val_loss_every==0 + if Bg:torch.cuda.synchronize();c+=1e3*(time.perf_counter()-A3);Bh,Bi=eval_val(A,A0,M,I,E,L,R,V,W,X);C(f"step:{G}/{A.iterations} val_loss:{Bh:.4f} val_bpb:{Bi:.4f} train_time:{c:.0f}ms step_avg:{c/max(G,1):.2f}ms");torch.cuda.synchronize();A3=time.perf_counter() + if At: + if d is not _A and G0 and e65536 and(_V in D or _U in D or'bank'in D):Bl=J.float().abs().amax(dim=1,keepdim=_B).detach();Z=Bl/31.;AG=AG+(J.float()**2*Z**2/12.).mean() + T=T+A.crown_q_lambda*AG + AF+=T.detach();(T*Aj).backward() + AF/=L;Au=min(G/A.muon_momentum_warmup_steps,_D)if A.muon_momentum_warmup_steps>0 else _D;Bm=(1-Au)*A.muon_momentum_warmup_start+Au*A.muon_momentum + for a in j.param_groups:a[_o]=Bm + for p in b: + for a in p.param_groups:a[_H]=a[U]*e + if A.grad_clip_norm>0:torch.nn.utils.clip_grad_norm_(B.parameters(),A.grad_clip_norm) + j.launch_reduce_scatters() + if Q: + for J in A2: + if J.grad is not _A:dist.all_reduce(J.grad,op=dist.ReduceOp.AVG) + A1.step();Ap.step() + if k is not _A:k.step() + 
j.step();l() + with torch.no_grad(): + for(D,AH)in B.state_dict().items():Ar[D].mul_(As).add_(AH.detach().float(),alpha=_D-As) + G+=1;AI=c+1e3*(time.perf_counter()-A3) + if A.swa_enabled and e<.2 and G%A.swa_every==0: + if AE is _A:AE={A:B.detach().cpu().clone()for(A,B)in B.state_dict().items()};Aq=1;C(f"swa:start step:{G}") + else: + for(D,AH)in B.state_dict().items():AE[D]+=AH.detach().cpu() + Aq+=1 + if A.lawa_enabled and G%A.lawa_freq==0:q.append({A:B.detach().cpu().clone()for(A,B)in B.state_dict().items()}) + Bn=A.train_log_every>0 and(G<=10 or G%A.train_log_every==0 or d is not _A) + if Bn:C(f"step:{G}/{A.iterations} train_loss:{AF.item():.4f} train_time:{AI:.0f}ms step_avg:{AI/G:.2f}ms") + AJ=m is not _A and AI>=m + if Q and m is not _A:Av=torch.tensor(int(AJ),device=E);dist.all_reduce(Av,op=dist.ReduceOp.MAX);AJ=bool(Av.item()) + if d is _A and AJ:d=G + C(f"peak memory allocated: {torch.cuda.max_memory_allocated()//1024//1024} MiB reserved: {torch.cuda.max_memory_reserved()//1024//1024} MiB") + if A.lawa_enabled and len(q)>1: + C(f"lawa:applying LAWA averaging k={len(q)}");A4=B.state_dict();O={A:torch.zeros(B.shape,dtype=torch.float32,device=_N)for(A,B)in A4.items()} + for Bo in q: + for D in O:O[D]+=Bo[D].float() + for D in O:O[D]/=len(q);O[D]=O[D].to(dtype=A4[D].dtype) + B.load_state_dict(O,strict=_B) + else:C('ema:applying EMA weights');A4=B.state_dict();O={A:B.to(dtype=A4[A].dtype)for(A,B)in Ar.items()};B.load_state_dict(O,strict=_B) + torch.cuda.synchronize();Bp=time.perf_counter();Bq,Br=eval_val(A,AA,M,I,E,L,R,V,W,X);torch.cuda.synchronize();C(f"DIAGNOSTIC post_ema val_loss:{Bq:.4f} val_bpb:{Br:.4f} eval_time:{1e3*(time.perf_counter()-Bp):.0f}ms") + if A.prequant_ttt: + C('prequant_ttt:starting discriminative fine-tuning...');Bs=time.perf_counter();Aw=generate_autoregressive_calib(B,E,num_seqs=A.gptq_calib_batches,seq_len=A.train_seq_len,vocab_size=A.vocab_size,temperature=.8,batch_size=8,seed=A.seed+7);Ax=[A.clone().detach()for A in Aw];del Aw + for 
Y in B.modules(): + Ay={} + for(A5,AK)in Y.named_buffers(recurse=_C): + if AK is not _A and AK.is_inference():Ay[A5]=AK.clone() + for(A5,Bt)in Ay.items():delattr(Y,A5);Y.register_buffer(A5,Bt) + Az=A.num_layers;A6=[];AL=set() + for f in range(Az): + AM=[B for(A,B)in B.named_parameters()if B.requires_grad and(f"blocks.{f}."in A or f"recur_scales.{f}"in A)] + if AM: + Bu=.3+.7*(f/max(1,Az-1));A6.append({_G:AM,_H:A.prequant_ttt_lr*Bu}) + for J in AM:AL.add(id(J)) + AN=[A for(B,A)in B.named_parameters()if B in(_W,_X,_Y,_Z)and A.requires_grad] + if AN: + A6.append({_G:AN,_H:A.prequant_ttt_lr*.65}) + for J in AN:AL.add(id(J)) + A_=[A for A in B.parameters()if A.requires_grad and id(A)not in AL] + if A_:A6.append({_G:A_,_H:A.prequant_ttt_lr*.1}) + AO=torch.optim.AdamW(A6,weight_decay=.01);B.train() + for Bv in range(A.prequant_ttt_epochs): + B0=_E;B1=0 + for B2 in Ax: + n=B2[:,:-1].to(E);o=B2[:,1:].to(E) + with torch.autocast(device_type=_I,dtype=torch.bfloat16):B3=B.forward_logits(n);T=F.cross_entropy(B3.reshape(-1,B3.size(-1)),o.reshape(-1)) + AO.zero_grad();T.backward();torch.nn.utils.clip_grad_norm_(B.parameters(),_D);AO.step();B0+=T.item();B1+=1 + C(f"prequant_ttt:epoch {Bv+1}/{A.prequant_ttt_epochs} loss:{B0/max(1,B1):.4f}") + B.eval();del Ax,AO;torch.cuda.empty_cache();C(f"prequant_ttt:done in {time.perf_counter()-Bs:.1f}s") + B4=B.state_dict();AP={A:B for(A,B)in B4.items()if BO not in A};B5=sum(int(B.numel())for(A,B)in B4.items()if BO in A) + if B5>0:C(f"export_excluding_mtp_params:{B5}") + if y:torch.save(AP,BP);Bw=os.path.getsize(BP);AQ=len(x.encode(_M));C(f"Serialized model: {Bw} bytes");C(f"Code size: {AQ} bytes") + if SKIP_QUANTIZE or A.iterations<1000 and A.max_wallclock_seconds>0: + C('SKIP_QUANTIZE: skipping quantization, roundtrip eval, sliding window, and TTT') + if Q:dist.destroy_process_group() + return + B6={A:B.detach().cpu()for(A,B)in AP.items()};r=_unbank_state_dict(B6,A.num_layers);s=_A + if A.gptq_full_hessian: + C('gptq:building non-banked model 
for Hessian collection...');g=_HessianGPT(vocab_size=A.vocab_size,num_layers=A.num_layers,model_dim=A.model_dim,num_heads=A.num_heads,num_kv_heads=A.num_kv_heads,mlp_mult=A.mlp_mult,tie_embeddings=A.tie_embeddings,logit_softcap=A.logit_softcap,rope_base=A.rope_base,qk_gain_init=A.qk_gain_init,bigram_vocab_size=A.bigram_vocab_size,bigram_dim=A.bigram_dim,xsa_last_n=A.xsa_last_n,rope_dims=A.rope_dims,ln_scale=A.ln_scale,ve_enabled=A.ve_enabled,ve_dim=A.ve_dim,ve_layers=A.ve_layers,trigram=A.trigram_enabled).to(E).bfloat16()
 + for t in g.modules():
 + if isinstance(t,CastedLinear):t.float()
 + restore_low_dim_params_to_fp32(g);g.load_state_dict({A:B.to(E)for(A,B)in r.items()if A in g.state_dict()},strict=_C);C(f"gptq:generating autoregressive calibration data ({A.gptq_calib_batches} seqs x {A.train_seq_len} tokens, temp=0.8)...");B.load_state_dict(AP,strict=_C);Bx=time.perf_counter();AR=generate_autoregressive_calib(B,E,num_seqs=A.gptq_calib_batches,seq_len=A.train_seq_len,vocab_size=A.vocab_size,temperature=.8,batch_size=8,seed=A.seed);C(f"gptq:generated {len(AR)} sequences in {time.perf_counter()-Bx:.1f}s");C('gptq:collecting hessians from autoregressive data...');s=collect_hessians_from_tokens(g,AR,E,gptq_damp=A.gptq_damp);C(f"gptq:collected hessians for {len(s)} layers (AR self-gen)");del AR;del g;torch.cuda.empty_cache()
 + if A.hadamard_rotation:r=_apply_hadamard_rotation(r,A.model_dim,C)
 + h={}
 + if A.mixed_bitwidth and s:
 + AS={}
 + for(D,By)in s.items():AS[D]=torch.diag(By).mean().item()
 + B7=sorted(AS.items(),key=lambda x:x[1]);u=len(B7);AT=max(1,u//5);AU=max(1,u//5)
 + for(f,(D,CD))in enumerate(B7):
 + if f<AT:h[D]=15
 + elif f>=u-AU:h[D]=63
 + else:h[D]=31
 + C(f"mixed_bitwidth: {AT} int5, {u-AT-AU} int6, {AU} int7 (of {u} layers)")
 + for(D,B8)in sorted(h.items(),key=lambda x:x[1]):Bz={15:'int5',31:'int6',63:'int7'}.get(B8,f"cr{B8}");C(f" {D}: {Bz} (sensitivity={AS[D]:.4f})")
 + P,AV=mixed_quantize_int6(r,{_U,_V},hessians=s,clip_ranges=h if h else 
_A,gptq_damp=A.gptq_damp);AW=float(os.environ.get('TARGET_MB','15.9'));B_=len(x.encode(_M));K=[] + for(D,B9)in AV.items(): + if not(isinstance(B9,dict)and B9.get(_k,'').startswith('int')):continue + AX,BA=D+_a,D+_b + if AX not in P or BA not in P:continue + v,Z=P[AX],P[BA] + if Z.ndim>0: + AY=v.abs()==1 + if AY.any(): + C0=torch.arange(v.shape[0]).unsqueeze(1).expand_as(v)[AY];C1=torch.arange(v.numel()).reshape(v.shape)[AY];C2=Z.float()[C0].pow(2) + for(C3,C4)in zip(C1.tolist(),C2.tolist()):K.append((AX,C3,C4)) + if K: + K.sort(key=lambda x:x[2]) + def w(n): + A={A:B.clone()for(A,B)in P.items()} + for B in range(min(n,len(K))):A[K[B][0]].view(-1)[K[B][1]]=0 + C=io.BytesIO();torch.save({'w':A,'m':AV},C);D=C.getvalue();E=brotli.compress(D,quality=11)if _HAS_BROTLI else lzma.compress(D,preset=9);return len(E)+B_,A + BB,_=w(0);AZ=int(AW*1024*1024);C(f"selective_prune: {len(K)} ±1 candidates, unpruned={BB/1048576:.2f}MB target={AW}MB") + if BB<=AZ:C('selective_prune: already fits, no pruning needed') + else: + BC,_=w(len(K));C(f"selective_prune: full ±1 prune={BC/1048576:.2f}MB") + if BC>AZ:C('selective_prune: even full prune not enough, applying all');_,P=w(len(K)) + else: + i,Aa=0,len(K) + while i0 and A.eval_stride
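
For readability, the mixed-bitwidth assignment buried in the minified quantization code above can be sketched as follows: layers are ranked by a Hessian-diagonal sensitivity score, the least-sensitive fifth is quantized to int5 (clip range 15), the most-sensitive fifth to int7 (63), and the remainder to int6 (31). Names here are illustrative, not the script's minified identifiers:

```python
def assign_bitwidths(sensitivity):
    """Map layer name -> clip range (15=int5, 31=int6, 63=int7) by sensitivity rank."""
    ranked = sorted(sensitivity.items(), key=lambda kv: kv[1])  # ascending sensitivity
    n = len(ranked)
    lo = max(1, n // 5)   # least sensitive fifth -> int5
    hi = max(1, n // 5)   # most sensitive fifth -> int7
    out = {}
    for i, (name, _) in enumerate(ranked):
        if i < lo:
            out[name] = 15
        elif i >= n - hi:
            out[name] = 63
        else:
            out[name] = 31
    return out
```

With 10 layers this yields two int5 layers, two int7 layers, and six int6 layers, matching the `mixed_bitwidth:` log line format above.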
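
The selective ±1 pruning step above zeroes the cheapest ±1-quantized weights (ranked by squared scale) until the compressed artifact fits `TARGET_MB`; the search over prune counts is garbled in the extracted text, but the surrounding logic (`w(0)`, `w(len(K))`, then `i,Aa=0,len(K)`) suggests a search that exploits compressed size being roughly non-increasing in the number of pruned weights. A hedged sketch of such a search, with a hypothetical `size_of(n)` callback that compresses the state dict with the first `n` candidates zeroed:

```python
def min_prune_count(size_of, budget, max_n):
    """Smallest n in [0, max_n] with size_of(n) <= budget.

    Assumes size_of is (approximately) non-increasing in n, as pruning
    ±1 weights to zero only makes the compressed payload smaller.
    """
    if size_of(0) <= budget:
        return 0            # already fits, no pruning needed
    if size_of(max_n) > budget:
        return max_n        # even full prune not enough; apply all
    lo, hi = 0, max_n       # invariant: size_of(lo) > budget >= size_of(hi)
    while lo + 1 < hi:
        mid = (lo + hi) // 2
        if size_of(mid) <= budget:
            hi = mid
        else:
            lo = mid
    return hi
```

Each probe is expensive (a full `torch.save` + compress), so a bisection over prune counts keeps the number of probes logarithmic in the candidate list length.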
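
The training loop above maintains a float32 EMA of the weights (decay 0.997) and, when LAWA is enabled, a bounded deque of the last k checkpoints that is averaged uniformly at the end. A minimal sketch of the two averaging schemes, using scalar stand-ins for the tensor state dicts (names illustrative):

```python
from collections import deque

def ema_update(shadow, current, decay=0.997):
    # Per parameter: shadow <- decay * shadow + (1 - decay) * current
    for k, v in current.items():
        shadow[k] = decay * shadow[k] + (1.0 - decay) * v

def lawa_average(checkpoints):
    # LAWA: uniform average over the last k snapshots
    keys = checkpoints[0].keys()
    return {k: sum(c[k] for c in checkpoints) / len(checkpoints) for k in keys}

# Snapshots are taken every lawa_freq steps into a deque with maxlen=lawa_k,
# so only the most recent k checkpoints contribute to the average.
snapshots = deque(maxlen=4)
```

The script applies LAWA only when more than one snapshot was collected, otherwise it falls back to the EMA weights, which matches the `lawa:applying` / `ema:applying` log branches above.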
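
Export size accounting in the script: the pruned state dict is serialized with `torch.save`, compressed with Brotli at quality 11 (falling back to LZMA preset 9 when the `brotli` module is unavailable), and the raw code byte count is added on top before checking against the size cap. A dependency-light sketch of that accounting (the payload here is arbitrary bytes rather than a real state dict):

```python
import lzma

try:
    import brotli  # optional; the script falls back to LZMA without it
    HAS_BROTLI = True
except ImportError:
    HAS_BROTLI = False

def artifact_size(payload: bytes, code_bytes: int) -> int:
    """Compressed model bytes plus raw code bytes, as counted against the target size."""
    blob = (brotli.compress(payload, quality=11) if HAS_BROTLI
            else lzma.compress(payload, preset=9))
    return len(blob) + code_bytes
```

This is the quantity the selective-prune search compares against `int(TARGET_MB * 1024 * 1024)` in the code above.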