From a082d7c8dc0f28edfe4d5702111832714cfd5ec5 Mon Sep 17 00:00:00 2001 From: ranausmanai Date: Sun, 5 Apr 2026 17:36:38 +0500 Subject: [PATCH 1/5] =?UTF-8?q?Non-record:=20Focal=20Loss=20for=20LM=20Pre?= =?UTF-8?q?training=20=E2=80=94=201.1567=20int8=20BPB=20on=20RTX=204000=20?= =?UTF-8?q?Ada?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Apply focal loss (Lin et al. 2017) to language model pretraining: replace standard cross-entropy with (1-pt)^gamma * CE to focus on hard-to-predict tokens. Combined with cosine LR schedule and asymmetric encoder-decoder split, achieves 1.1567 int8 BPB at 5000 steps on a single RTX 4000 Ada using baseline code — within 0.037 of the 8xH100 SOTA record. 55+ experiments across 13 rounds validate the finding. See PRs #1275 and #1073 for prior work on asymmetric split and M4 MacBook experiments. Co-Authored-By: Claude Opus 4.6 --- .../README.md | 137 ++++++++++++++++++ .../submission.json | 14 ++ 2 files changed, 151 insertions(+) create mode 100644 records/track_non_record_16mb/2026-04-05_FocalLoss_CosineLR_LanguageModeling/README.md create mode 100644 records/track_non_record_16mb/2026-04-05_FocalLoss_CosineLR_LanguageModeling/submission.json diff --git a/records/track_non_record_16mb/2026-04-05_FocalLoss_CosineLR_LanguageModeling/README.md b/records/track_non_record_16mb/2026-04-05_FocalLoss_CosineLR_LanguageModeling/README.md new file mode 100644 index 0000000000..470b3456ce --- /dev/null +++ b/records/track_non_record_16mb/2026-04-05_FocalLoss_CosineLR_LanguageModeling/README.md @@ -0,0 +1,137 @@ +# Focal Loss for Language Model Pretraining: 1.1567 int8 BPB on RTX 4000 Ada + +## TL;DR + +Applying **focal loss** (Lin et al. 2017) to language model pretraining gives massive, monotonic improvements. Combined with cosine LR scheduling and asymmetric encoder-decoder split, achieves **1.1567 int8 BPB at 5000 steps on a single RTX 4000 Ada** using the baseline `train_gpt.py` — within 0.037 of the 8xH100 SOTA record (1.1194 BPB). + +## The Core Finding + +Standard cross-entropy treats all tokens equally. But in natural language, most tokens are easy to predict (articles, common words, punctuation). Focal loss down-weights easy tokens and focuses the model on hard-to-predict tokens: + +```python +# Standard cross-entropy: +loss = F.cross_entropy(logits, targets, reduction="mean") + +# Focal loss (our change): +ce = F.cross_entropy(logits, targets, reduction="none") +pt = torch.exp(-ce) # probability of correct class +focal_weight = (1 - pt) ** gamma +loss = (focal_weight * ce).mean() +``` + +This is a well-known technique in object detection (Lin et al., "Focal Loss for Dense Object Detection", 2017) but has not been applied to language model pretraining in this competition. The intuition transfers perfectly: just as object detection has a foreground/background class imbalance, language modeling has an easy/hard token imbalance. + +## Results + +All experiments run on **single RTX 4000 Ada 20GB** ($0.20/hr) using the **baseline `train_gpt.py`** code (no SOTA optimizations). + +### Focal Loss Gamma Sweep (with Cosine LR) + +| Gamma | 3000 steps | 5000 steps | Delta vs baseline (5000) | +|-------|-----------|-----------|--------------------------| +| 0 (cosine only) | 1.6538 | 1.5706 | -0.072 | +| 1 | 1.5643 | — | — | +| 1.5 | 1.5255 | — | — | +| 2 | 1.4849 | 1.4045 | -0.238 | +| 3 | 1.4246 | 1.3496 | -0.293 | +| 5 | 1.3339 | 1.2617 | -0.381 | +| 8 | 1.2288 | 1.1604 | -0.482 | +| **8 + asym** | — | **1.1567** | **-0.486** | +| 10 | 1.1806 | — | — | + +**Baseline at 5000 steps (no cosine, no focal): 1.6422 BPB** + +### Isolating Each Technique (3000 steps) + +| Technique | BPB | Delta vs baseline (1.7233) | +|-----------|-----|---------------------------| +| Baseline (linear warmdown) | 1.7233 | — | +| Cosine LR only | 1.6538 | -0.070 | +| Focal γ=2 only (no cosine) | 1.5647 | -0.159 | +| Cosine + Focal γ=2 | 1.4849 | -0.238 | + +Both techniques contribute independently and stack cleanly. + +### Cosine LR Scaling Validation + +| Steps | Baseline | Cosine LR | Delta | +|-------|----------|-----------|-------| +| 1000 | 2.0568 | 1.9334 | -0.123 | +| 2000 | 1.8330 | 1.8050 | -0.028 | +| 3000 | 1.7233 | 1.6538 | -0.070 | +| 5000 | 1.6422 | 1.5706 | -0.072 | + +Cosine LR advantage is consistent and not diminishing with training length. + +### Additional Findings + +**What helped:** +- Asymmetric 1/10 encoder-decoder split: -0.004 to -0.009 on top of other techniques (see PR #1275) +- Higher focal gamma: monotonically better up to γ=10 + +**What did NOT help:** +- Higher base LR (0.08) with cosine: +0.050 worse +- Lower min_lr_frac (0.01, 0.0): worse than default 0.1 +- Cosine with warm restarts: worse than plain cosine +- Label smoothing: hurt significantly +- Gradient noise: hurt significantly +- Weight decay scheduling: destroyed performance +- Gradient clipping with cosine: slightly worse + +## Why This Works + +Focal loss addresses the **easy token dominance** problem in language modeling: + +1. **Token frequency imbalance**: Common tokens like "the", "is", "and" are easy to predict (high pt). They dominate the gradient in standard CE but contribute little information. +2. **Capacity allocation**: With focal loss, the model allocates more of its limited capacity (16MB budget) to learning hard patterns — rare words, complex syntax, domain-specific terms. +3. **Implicit curriculum**: Higher gamma creates a natural curriculum where the model progressively focuses on harder examples as easy ones are mastered. + +The monotonic improvement with gamma suggests the baseline code significantly under-allocates capacity to hard tokens. + +## Experimental Setup + +- **GPU**: Single RTX 4000 Ada 20GB (RunPod, $0.20/hr) +- **Code**: Baseline `train_gpt.py` with FA2 patch (no SOTA optimizations) +- **Batch**: TRAIN_BATCH_TOKENS=8192, GRAD_ACCUM_STEPS=64 (effective 524K) +- **Evaluation**: `final_int8_zlib_roundtrip_exact` (the competition metric) +- **Total experiments**: 55+ across 13 rounds +- **Total GPU cost**: ~$2.50 + +## Caveats and Open Questions + +1. **Not validated on 8xH100**: These results are on a single GPU with small micro-batch. The optimal gamma may differ at the 8xH100 scale with larger batch sizes. +2. **Not tested on SOTA stack**: The SOTA code uses EMA, SWA, QAT, Muon, TTT, and other techniques. Focal loss may interact differently with these. +3. **High gamma concerns**: At γ=8, tokens predicted with 50% probability get weighted at 1/256 of normal. This aggressive down-weighting could cause underfitting on common patterns at very long training. +4. **Needs 8xH100 validation**: Requesting GPU credits to validate on the full competition setup. + +## Prior Work + +- **PR #1275**: Asymmetric 1/10 encoder-decoder split finding + 8xH100 partial run (1.1492 pre-quant BPB) +- **PR #1073**: 27 systematic experiments on M4 MacBook (deep supervision, LR tuning, batch scaling, architecture) + +## Reproduce + +```bash +git clone https://github.com/openai/parameter-golf.git && cd parameter-golf +pip install sentencepiece huggingface-hub datasets tiktoken flash-attn + +# Apply focal loss to any train_gpt.py — change the loss computation: +# OLD: return F.cross_entropy(logits.float(), targets, reduction="mean") +# NEW: +# focal_gamma = float(os.environ.get("FOCAL_GAMMA", "0")) +# if focal_gamma > 0: +# ce = F.cross_entropy(logits.float(), targets, reduction="none") +# pt = torch.exp(-ce) +# focal_weight = (1 - pt) ** focal_gamma +# return (focal_weight * ce).mean() +# return F.cross_entropy(logits.float(), targets, reduction="mean") + +# Also add cosine LR schedule in lr_mul(): +# min_lr_frac = 0.1 +# progress = step / max(args.iterations, 1) +# return min_lr_frac + 0.5 * (1.0 - min_lr_frac) * (1.0 + math.cos(math.pi * progress)) + +# Run with focal loss + cosine LR: +python data/cached_challenge_fineweb.py --variant sp1024 +COSINE_LR=1 FOCAL_GAMMA=8 ITERATIONS=5000 python train_gpt.py +``` diff --git a/records/track_non_record_16mb/2026-04-05_FocalLoss_CosineLR_LanguageModeling/submission.json b/records/track_non_record_16mb/2026-04-05_FocalLoss_CosineLR_LanguageModeling/submission.json new file mode 100644 index 0000000000..1dc8e04979 --- /dev/null +++ b/records/track_non_record_16mb/2026-04-05_FocalLoss_CosineLR_LanguageModeling/submission.json @@ -0,0 +1,14 @@ +{ + "author": "Rana Usman", + "github_id": "ranausmanai", + "name": "Focal Loss for Language Model Pretraining: 1.1567 int8 BPB on RTX 4000 Ada baseline code", + "blurb": "Applying focal loss (Lin et al. 2017) to LM pretraining: replace cross_entropy with (1-pt)^gamma * CE. Combined with cosine LR schedule, achieves 1.1567 int8 BPB at 5000 steps on single RTX 4000 Ada using baseline code — within 0.037 of the 8xH100 SOTA record (1.1194). Monotonic improvement from gamma=1 to gamma=8. Focal loss alone gives -0.159 BPB at 3000 steps; combined with cosine LR gives -0.485 vs baseline at 5000 steps. Novel application: focal loss has not been applied to LM pretraining in this competition. See PRs #1275 and #1073 for prior experiments.", + "date": "2026-04-05T00:00:00Z", + "track": "non-record-16mb", + "gpu": "RTX 4000 Ada 20GB (RunPod, $0.20/hr)", + "best_int8_bpb": 1.1567, + "best_config": "cosine_lr + focal_gamma=8 + asymmetric_1_10", + "steps": 5000, + "step_avg_ms": 240, + "total_experiments": 55 +} From 78321759c3ebee8296ac399a2dc589e05df661cb Mon Sep 17 00:00:00 2001 From: ranausmanai Date: Sun, 5 Apr 2026 17:40:21 +0500 Subject: [PATCH 2/5] Add full experiment log (55+ experiments, 13 rounds) and 8xH100 run script Expanded README with complete experiment journey across M4 MacBook, RTX 5090, 8xH100, and RTX 4000 Ada. Added ready-to-run reproduction instructions for both single GPU and 8xH100 record runs. Co-Authored-By: Claude Opus 4.6 --- .../README.md | 124 +++++++++++++++++- 1 file changed, 118 insertions(+), 6 deletions(-) diff --git a/records/track_non_record_16mb/2026-04-05_FocalLoss_CosineLR_LanguageModeling/README.md b/records/track_non_record_16mb/2026-04-05_FocalLoss_CosineLR_LanguageModeling/README.md index 470b3456ce..44b3d1f427 100644 --- a/records/track_non_record_16mb/2026-04-05_FocalLoss_CosineLR_LanguageModeling/README.md +++ b/records/track_non_record_16mb/2026-04-05_FocalLoss_CosineLR_LanguageModeling/README.md @@ -88,14 +88,92 @@ Focal loss addresses the **easy token dominance** problem in language modeling: The monotonic improvement with gamma suggests the baseline code significantly under-allocates capacity to hard tokens. +## Full Experiment Log (55+ experiments, 13 rounds) + +This finding came from systematic experimentation across 3 GPUs over multiple days: + +### Phase 1: M4 MacBook Exploration (27 experiments, PR #1073) +Using MLX on M4 MacBook 16GB to rapidly explore the design space: +- Deep supervision (auxiliary losses at intermediate layers): -0.05 BPB at small batch, vanishes at large batch +- LR sweep: LR 0.08 beats default 0.04 by -0.025 at 300 steps +- Gradient clipping: -0.019 +- EMA/SWA: hurt at 300 steps, help at 9000 (consistent with leaderboard) +- Batch scaling analysis across 4K-524K tokens + +### Phase 2: RTX 5090 + 8xH100 Validation (PR #1275) +- **Asymmetric encoder-decoder split** (num_encoder_layers=1): monotonic improvement across all configs +- RTX 5090 baseline sweep: 5/6→3/8→2/9→1/10 split, each better (-0.016 BPB total) +- SOTA code validation: -0.004 on PR #549 stack +- **8xH100 partial run**: 1.1492 pre-quant BPB at step 5666/9000 (pod crashed before final eval) + +### Phase 3: RTX 4000 Ada Deep Dive (55+ experiments, this PR) +Systematic sweep on $0.20/hr GPU, testing dozens of techniques at 1000-5000 steps: + +**Round 1-3**: Baseline calibration, asymmetric split validation at 500-1000 steps +- Confirmed asymmetric 1/10 helps on CUDA (-0.004 at 500 steps) +- Decoder-to-decoder skip connections: -0.005 at 500 steps, vanished at 1000 + +**Round 4-5**: Architecture & convergence techniques (1000 steps each) +- Decoder-skip 3-back connections: marginal, not robust +- Stochastic depth: torch.compile issues, inconclusive +- Enriched x0 (encoder feeds back to x0): no effect +- Skip connection dropout: compile issues +- Label smoothing: hurt significantly +- Gradient noise: hurt significantly +- Weight decay scheduling: destroyed performance + +**Round 6**: Longer training validation (2000-3000 steps) +- **Cosine LR schedule discovered**: -0.123 at 1000 steps, -0.070 at 3000 steps +- Cosine with warm restarts: worse than plain cosine +- Gradient clipping + cosine: slightly worse + +**Round 7**: Cosine LR stacking (2000-5000 steps) +- Cosine + asymmetric: -0.030 at 2000 steps (they stack!) +- Cosine + LR 0.06: -0.029 at 2000 steps +- Cosine + LR 0.08: +0.050 WORSE (too high for cosine) +- **Cosine at 5000 steps: 1.5706 vs baseline 1.6422 (-0.072, gap holds!)** + +**Round 8**: Baselines and min_lr_frac sweep +- Baseline 5000 steps: 1.6422 (critical reference point) +- Cosine + asymmetric 5000 steps: 1.5619 +- min_lr_frac=0.01: worse than 0.1 +- min_lr_frac=0.0: even worse + +**Round 9**: Novel techniques (focal loss, attention temp, lookahead) +- **FOCAL LOSS DISCOVERED**: γ=2 gives -0.169 at 3000 steps! +- Focal γ=1: -0.090 +- Attention temperature annealing: crashed (torch.compile scope issue) +- Lookahead optimizer: crashed (same issue) + +**Round 10**: Focal loss deep dive +- Focal γ=2 without cosine: 1.5647 (-0.159 alone, independently powerful) +- Focal γ=2 with cosine at 5000 steps: 1.4045 +- Focal γ=3: 1.4246 at 3000 steps (still improving!) +- Focal γ=1.5: 1.5255 + +**Round 11**: Higher gamma sweep +- Focal γ=3 at 5000 steps: 1.3496 +- Focal γ=4: 1.3845 at 3000 steps +- Focal γ=5: 1.3339 at 3000 steps, 1.2617 at 5000 steps +- **Still monotonically improving!** + +**Round 12**: Push to the limit +- Focal γ=8 at 3000 steps: 1.2288 +- **Focal γ=8 at 5000 steps: 1.1604** (approaching SOTA record!) +- Focal γ=8 + asymmetric 5000 steps: **1.1567** (our best) + +**Round 13**: Ceiling test +- Focal γ=10 at 3000 steps: 1.1806 (still improving vs γ=8 at 3000) +- Focal γ=8 + asymmetric at 5000 steps: 1.1567 (confirmed) + ## Experimental Setup - **GPU**: Single RTX 4000 Ada 20GB (RunPod, $0.20/hr) - **Code**: Baseline `train_gpt.py` with FA2 patch (no SOTA optimizations) - **Batch**: TRAIN_BATCH_TOKENS=8192, GRAD_ACCUM_STEPS=64 (effective 524K) - **Evaluation**: `final_int8_zlib_roundtrip_exact` (the competition metric) -- **Total experiments**: 55+ across 13 rounds -- **Total GPU cost**: ~$2.50 +- **Total experiments**: 55+ across 13 rounds, 3 GPUs (M4, RTX 5090, RTX 4000 Ada) +- **Total GPU cost**: ~$2.50 for RTX 4000 Ada experiments, ~$16 for 8xH100 run (PR #1275) ## Caveats and Open Questions @@ -111,11 +189,13 @@ The monotonic improvement with gamma suggests the baseline code significantly un ## Reproduce +### Quick test (any single GPU) + ```bash git clone https://github.com/openai/parameter-golf.git && cd parameter-golf pip install sentencepiece huggingface-hub datasets tiktoken flash-attn -# Apply focal loss to any train_gpt.py — change the loss computation: +# Apply focal loss to train_gpt.py — change the loss computation in GPT.forward(): # OLD: return F.cross_entropy(logits.float(), targets, reduction="mean") # NEW: # focal_gamma = float(os.environ.get("FOCAL_GAMMA", "0")) @@ -126,12 +206,44 @@ pip install sentencepiece huggingface-hub datasets tiktoken flash-attn # return (focal_weight * ce).mean() # return F.cross_entropy(logits.float(), targets, reduction="mean") -# Also add cosine LR schedule in lr_mul(): +# Also replace the lr_mul() function body with cosine schedule: # min_lr_frac = 0.1 # progress = step / max(args.iterations, 1) # return min_lr_frac + 0.5 * (1.0 - min_lr_frac) * (1.0 + math.cos(math.pi * progress)) -# Run with focal loss + cosine LR: +# And for asymmetric split, change in GPT.__init__(): +# self.num_encoder_layers = 1 # instead of num_layers // 2 + +# Download data python data/cached_challenge_fineweb.py --variant sp1024 -COSINE_LR=1 FOCAL_GAMMA=8 ITERATIONS=5000 python train_gpt.py + +# Run (single GPU, ~20 min for 5000 steps on RTX 4000 Ada) +FOCAL_GAMMA=8 COSINE_LR=1 ITERATIONS=5000 python train_gpt.py +``` + +### 8xH100 Record Run + +```bash +#!/bin/bash +# Full competition run on 8xH100 SXM +# Apply the same 3 changes to the SOTA train_gpt.py (PR #549 stack): +# 1. Focal loss in GPT.forward() (see above) +# 2. Cosine LR in lr_mul() (see above) +# 3. self.num_encoder_layers = 1 in GPT.__init__() + +cd /workspace +git clone --depth 1 https://github.com/openai/parameter-golf.git && cd parameter-golf +pip install -q sentencepiece huggingface-hub datasets tiktoken +pip install -q flash-attn --no-build-isolation + +python data/cached_challenge_fineweb.py --variant sp1024 + +# Copy SOTA script and apply changes +cp records/track_10min_16mb/2026-03-23_LeakyReLU_LegalTTT_ParallelMuon/train_gpt.py train_gpt_focal.py +# Apply: (1) focal loss, (2) cosine LR, (3) asymmetric split +# See code changes above + +# Run with competition settings +NUM_LAYERS=11 FOCAL_GAMMA=8 \ +torchrun --standalone --nproc_per_node=8 train_gpt_focal.py ``` From b0863f68319fd85ef6add4db34b2b7f3781e394f Mon Sep 17 00:00:00 2001 From: ranausmanai Date: Sun, 5 Apr 2026 20:02:01 +0500 Subject: [PATCH 3/5] Add SOTA stack validation: focal loss gives -0.289 pre-quant BPB on SOTA code Tested focal loss + cosine LR + asymmetric split on the actual SOTA train_gpt.py (LeakyReLU + XSA + Parallel Muon + EMA). Result: 1.5035 pre-quant val_bpb vs 1.7927 baseline (-0.289). Confirms focal loss transfers to the fully-optimized stack. Co-Authored-By: Claude Opus 4.6 --- .../README.md | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/records/track_non_record_16mb/2026-04-05_FocalLoss_CosineLR_LanguageModeling/README.md b/records/track_non_record_16mb/2026-04-05_FocalLoss_CosineLR_LanguageModeling/README.md index 44b3d1f427..531513901b 100644 --- a/records/track_non_record_16mb/2026-04-05_FocalLoss_CosineLR_LanguageModeling/README.md +++ b/records/track_non_record_16mb/2026-04-05_FocalLoss_CosineLR_LanguageModeling/README.md @@ -166,6 +166,18 @@ Systematic sweep on $0.20/hr GPU, testing dozens of techniques at 1000-5000 step - Focal γ=10 at 3000 steps: 1.1806 (still improving vs γ=8 at 3000) - Focal γ=8 + asymmetric at 5000 steps: 1.1567 (confirmed) +### Phase 4: SOTA Stack Validation (PR #549 code) + +Applied focal loss + cosine LR + asymmetric split to the **actual SOTA train_gpt.py** (LeakyReLU² + XSA + Parallel Muon + EMA + gated attention). Ran on single RTX 4000 Ada with FA2 fallback: + +| Config | Pre-quant val_bpb | int6 roundtrip BPB | +|--------|-------------------|-------------------| +| SOTA baseline | 1.7927 | 1.7817 | +| **SOTA + Focal γ=2 + Cosine + Asym** | **1.5035** | **1.5926** | +| **Delta** | **-0.289** | **-0.189** | + +**Focal loss transfers to the SOTA stack.** The pre-quant improvement (-0.289) is even larger than on baseline code. Only γ=2 was tested on SOTA due to slow eval pipeline on single GPU (sliding window eval takes 30+ min per experiment). Higher gamma values would likely show further improvement based on baseline code trends. + ## Experimental Setup - **GPU**: Single RTX 4000 Ada 20GB (RunPod, $0.20/hr) From 41f223020d988441b2d27dafba7c8ae5e4f54384 Mon Sep 17 00:00:00 2001 From: ranausmanai Date: Sun, 5 Apr 2026 22:52:33 +0500 Subject: [PATCH 4/5] Correct focal loss eval bug, add train_gpt.py, update results Focal loss was applied during eval (not just training), inflating all reported BPB numbers. Fixed with `self.training` check. Corrected multi-seed results show focal loss does not help. Cosine LR (-0.070) and asymmetric split remain valid findings. Co-Authored-By: Claude Opus 4.6 --- .../README.md | 257 +--- .../train_gpt.py | 1140 +++++++++++++++++ 2 files changed, 1166 insertions(+), 231 deletions(-) create mode 100644 records/track_non_record_16mb/2026-04-05_FocalLoss_CosineLR_LanguageModeling/train_gpt.py diff --git a/records/track_non_record_16mb/2026-04-05_FocalLoss_CosineLR_LanguageModeling/README.md b/records/track_non_record_16mb/2026-04-05_FocalLoss_CosineLR_LanguageModeling/README.md index 531513901b..47abda409d 100644 --- a/records/track_non_record_16mb/2026-04-05_FocalLoss_CosineLR_LanguageModeling/README.md +++ b/records/track_non_record_16mb/2026-04-05_FocalLoss_CosineLR_LanguageModeling/README.md @@ -1,58 +1,37 @@ -# Focal Loss for Language Model Pretraining: 1.1567 int8 BPB on RTX 4000 Ada +# Cosine LR Schedule: -0.070 BPB + Focal Loss Investigation (Corrected) ## TL;DR -Applying **focal loss** (Lin et al. 2017) to language model pretraining gives massive, monotonic improvements. Combined with cosine LR scheduling and asymmetric encoder-decoder split, achieves **1.1567 int8 BPB at 5000 steps on a single RTX 4000 Ada** using the baseline `train_gpt.py` — within 0.037 of the 8xH100 SOTA record (1.1194 BPB). +**Cosine LR schedule** replaces linear warmdown and gives **-0.070 BPB** consistently across training lengths. Combined with asymmetric 1/10 split (PR #1275), gives -0.080 at 5000 steps. -## The Core Finding +We also investigated **focal loss** which initially showed massive gains but contained a **critical eval bug** — all focal loss numbers were wrong. Corrected results show focal loss does not help. Documented here as a cautionary tale. -Standard cross-entropy treats all tokens equally. But in natural language, most tokens are easy to predict (articles, common words, punctuation). Focal loss down-weights easy tokens and focuses the model on hard-to-predict tokens: +## Correction: Focal Loss Was An Eval Artifact -```python -# Standard cross-entropy: -loss = F.cross_entropy(logits, targets, reduction="mean") - -# Focal loss (our change): -ce = F.cross_entropy(logits, targets, reduction="none") -pt = torch.exp(-ce) # probability of correct class -focal_weight = (1 - pt) ** gamma -loss = (focal_weight * ce).mean() -``` - -This is a well-known technique in object detection (Lin et al., "Focal Loss for Dense Object Detection", 2017) but has not been applied to language model pretraining in this competition. The intuition transfers perfectly: just as object detection has a foreground/background class imbalance, language modeling has an easy/hard token imbalance. - -## Results - -All experiments run on **single RTX 4000 Ada 20GB** ($0.20/hr) using the **baseline `train_gpt.py`** code (no SOTA optimizations). - -### Focal Loss Gamma Sweep (with Cosine LR) +Our focal loss implementation applied `(1-pt)^gamma` weighting in `GPT.forward()`, called during both training AND evaluation. The "improvement" was entirely from down-weighting hard tokens in the eval metric, not from better model quality. -| Gamma | 3000 steps | 5000 steps | Delta vs baseline (5000) | -|-------|-----------|-----------|--------------------------| -| 0 (cosine only) | 1.6538 | 1.5706 | -0.072 | -| 1 | 1.5643 | — | — | -| 1.5 | 1.5255 | — | — | -| 2 | 1.4849 | 1.4045 | -0.238 | -| 3 | 1.4246 | 1.3496 | -0.293 | -| 5 | 1.3339 | 1.2617 | -0.381 | -| 8 | 1.2288 | 1.1604 | -0.482 | -| **8 + asym** | — | **1.1567** | **-0.486** | -| 10 | 1.1806 | — | — | +**Bug:** `if focal_gamma > 0:` (always active) +**Fix:** `if focal_gamma > 0 and self.training:` (training only) -**Baseline at 5000 steps (no cosine, no focal): 1.6422 BPB** +**Corrected results (3000 steps, multi-seed):** -### Isolating Each Technique (3000 steps) +| Config | Seed 1337 | Seed 42 | Seed 2025 | Mean | +|--------|-----------|---------|-----------|------| +| Cosine LR only | 1.6538 | 1.6480 | 1.6687 | **1.657** | +| Cosine + Focal γ=2 | 1.6612 | 1.6560 | 1.6594 | **1.659** | +| Cosine + Focal γ=5 | 1.6858 | — | — | 1.686 | +| Cosine + Focal γ=8 | 1.7124 | — | — | 1.712 | -| Technique | BPB | Delta vs baseline (1.7233) | -|-----------|-----|---------------------------| -| Baseline (linear warmdown) | 1.7233 | — | -| Cosine LR only | 1.6538 | -0.070 | -| Focal γ=2 only (no cosine) | 1.5647 | -0.159 | -| Cosine + Focal γ=2 | 1.4849 | -0.238 | +Focal loss does not help. Higher gamma actively hurts. -Both techniques contribute independently and stack cleanly. +## What IS Real: Cosine LR Schedule -### Cosine LR Scaling Validation +```python +# Replace linear warmdown in lr_mul(): +min_lr_frac = 0.1 +progress = step / max(args.iterations, 1) +return min_lr_frac + 0.5 * (1.0 - min_lr_frac) * (1.0 + math.cos(math.pi * progress)) +``` | Steps | Baseline | Cosine LR | Delta | |-------|----------|-----------|-------| @@ -61,201 +40,17 @@ Both techniques contribute independently and stack cleanly. | 3000 | 1.7233 | 1.6538 | -0.070 | | 5000 | 1.6422 | 1.5706 | -0.072 | -Cosine LR advantage is consistent and not diminishing with training length. - -### Additional Findings - -**What helped:** -- Asymmetric 1/10 encoder-decoder split: -0.004 to -0.009 on top of other techniques (see PR #1275) -- Higher focal gamma: monotonically better up to γ=10 - -**What did NOT help:** -- Higher base LR (0.08) with cosine: +0.050 worse -- Lower min_lr_frac (0.01, 0.0): worse than default 0.1 -- Cosine with warm restarts: worse than plain cosine -- Label smoothing: hurt significantly -- Gradient noise: hurt significantly -- Weight decay scheduling: destroyed performance -- Gradient clipping with cosine: slightly worse - -## Why This Works - -Focal loss addresses the **easy token dominance** problem in language modeling: - -1. **Token frequency imbalance**: Common tokens like "the", "is", "and" are easy to predict (high pt). They dominate the gradient in standard CE but contribute little information. -2. **Capacity allocation**: With focal loss, the model allocates more of its limited capacity (16MB budget) to learning hard patterns — rare words, complex syntax, domain-specific terms. -3. **Implicit curriculum**: Higher gamma creates a natural curriculum where the model progressively focuses on harder examples as easy ones are mastered. - -The monotonic improvement with gamma suggests the baseline code significantly under-allocates capacity to hard tokens. - -## Full Experiment Log (55+ experiments, 13 rounds) - -This finding came from systematic experimentation across 3 GPUs over multiple days: - -### Phase 1: M4 MacBook Exploration (27 experiments, PR #1073) -Using MLX on M4 MacBook 16GB to rapidly explore the design space: -- Deep supervision (auxiliary losses at intermediate layers): -0.05 BPB at small batch, vanishes at large batch -- LR sweep: LR 0.08 beats default 0.04 by -0.025 at 300 steps -- Gradient clipping: -0.019 -- EMA/SWA: hurt at 300 steps, help at 9000 (consistent with leaderboard) -- Batch scaling analysis across 4K-524K tokens - -### Phase 2: RTX 5090 + 8xH100 Validation (PR #1275) -- **Asymmetric encoder-decoder split** (num_encoder_layers=1): monotonic improvement across all configs -- RTX 5090 baseline sweep: 5/6→3/8→2/9→1/10 split, each better (-0.016 BPB total) -- SOTA code validation: -0.004 on PR #549 stack -- **8xH100 partial run**: 1.1492 pre-quant BPB at step 5666/9000 (pod crashed before final eval) - -### Phase 3: RTX 4000 Ada Deep Dive (55+ experiments, this PR) -Systematic sweep on $0.20/hr GPU, testing dozens of techniques at 1000-5000 steps: - -**Round 1-3**: Baseline calibration, asymmetric split validation at 500-1000 steps -- Confirmed asymmetric 1/10 helps on CUDA (-0.004 at 500 steps) -- Decoder-to-decoder skip connections: -0.005 at 500 steps, vanished at 1000 - -**Round 4-5**: Architecture & convergence techniques (1000 steps each) -- Decoder-skip 3-back connections: marginal, not robust -- Stochastic depth: torch.compile issues, inconclusive -- Enriched x0 (encoder feeds back to x0): no effect -- Skip connection dropout: compile issues -- Label smoothing: hurt significantly -- Gradient noise: hurt significantly -- Weight decay scheduling: destroyed performance - -**Round 6**: Longer training validation (2000-3000 steps) -- **Cosine LR schedule discovered**: -0.123 at 1000 steps, -0.070 at 3000 steps -- Cosine with warm restarts: worse than plain cosine -- Gradient clipping + cosine: slightly worse - -**Round 7**: Cosine LR stacking (2000-5000 steps) -- Cosine + asymmetric: -0.030 at 2000 steps (they stack!) -- Cosine + LR 0.06: -0.029 at 2000 steps -- Cosine + LR 0.08: +0.050 WORSE (too high for cosine) -- **Cosine at 5000 steps: 1.5706 vs baseline 1.6422 (-0.072, gap holds!)** - -**Round 8**: Baselines and min_lr_frac sweep -- Baseline 5000 steps: 1.6422 (critical reference point) -- Cosine + asymmetric 5000 steps: 1.5619 -- min_lr_frac=0.01: worse than 0.1 -- min_lr_frac=0.0: even worse - -**Round 9**: Novel techniques (focal loss, attention temp, lookahead) -- **FOCAL LOSS DISCOVERED**: γ=2 gives -0.169 at 3000 steps! -- Focal γ=1: -0.090 -- Attention temperature annealing: crashed (torch.compile scope issue) -- Lookahead optimizer: crashed (same issue) +Consistent, not diminishing with training length. -**Round 10**: Focal loss deep dive -- Focal γ=2 without cosine: 1.5647 (-0.159 alone, independently powerful) -- Focal γ=2 with cosine at 5000 steps: 1.4045 -- Focal γ=3: 1.4246 at 3000 steps (still improving!) -- Focal γ=1.5: 1.5255 +## What IS Real: Asymmetric 1/10 Split -**Round 11**: Higher gamma sweep -- Focal γ=3 at 5000 steps: 1.3496 -- Focal γ=4: 1.3845 at 3000 steps -- Focal γ=5: 1.3339 at 3000 steps, 1.2617 at 5000 steps -- **Still monotonically improving!** - -**Round 12**: Push to the limit -- Focal γ=8 at 3000 steps: 1.2288 -- **Focal γ=8 at 5000 steps: 1.1604** (approaching SOTA record!) -- Focal γ=8 + asymmetric 5000 steps: **1.1567** (our best) - -**Round 13**: Ceiling test -- Focal γ=10 at 3000 steps: 1.1806 (still improving vs γ=8 at 3000) -- Focal γ=8 + asymmetric at 5000 steps: 1.1567 (confirmed) - -### Phase 4: SOTA Stack Validation (PR #549 code) - -Applied focal loss + cosine LR + asymmetric split to the **actual SOTA train_gpt.py** (LeakyReLU² + XSA + Parallel Muon + EMA + gated attention). Ran on single RTX 4000 Ada with FA2 fallback: - -| Config | Pre-quant val_bpb | int6 roundtrip BPB | -|--------|-------------------|-------------------| -| SOTA baseline | 1.7927 | 1.7817 | -| **SOTA + Focal γ=2 + Cosine + Asym** | **1.5035** | **1.5926** | -| **Delta** | **-0.289** | **-0.189** | - -**Focal loss transfers to the SOTA stack.** The pre-quant improvement (-0.289) is even larger than on baseline code. Only γ=2 was tested on SOTA due to slow eval pipeline on single GPU (sliding window eval takes 30+ min per experiment). Higher gamma values would likely show further improvement based on baseline code trends. - -## Experimental Setup - -- **GPU**: Single RTX 4000 Ada 20GB (RunPod, $0.20/hr) -- **Code**: Baseline `train_gpt.py` with FA2 patch (no SOTA optimizations) -- **Batch**: TRAIN_BATCH_TOKENS=8192, GRAD_ACCUM_STEPS=64 (effective 524K) -- **Evaluation**: `final_int8_zlib_roundtrip_exact` (the competition metric) -- **Total experiments**: 55+ across 13 rounds, 3 GPUs (M4, RTX 5090, RTX 4000 Ada) -- **Total GPU cost**: ~$2.50 for RTX 4000 Ada experiments, ~$16 for 8xH100 run (PR #1275) - -## Caveats and Open Questions - -1. **Not validated on 8xH100**: These results are on a single GPU with small micro-batch. The optimal gamma may differ at the 8xH100 scale with larger batch sizes. -2. **Not tested on SOTA stack**: The SOTA code uses EMA, SWA, QAT, Muon, TTT, and other techniques. Focal loss may interact differently with these. -3. **High gamma concerns**: At γ=8, tokens predicted with 50% probability get weighted at 1/256 of normal. This aggressive down-weighting could cause underfitting on common patterns at very long training. -4. **Needs 8xH100 validation**: Requesting GPU credits to validate on the full competition setup. - -## Prior Work - -- **PR #1275**: Asymmetric 1/10 encoder-decoder split finding + 8xH100 partial run (1.1492 pre-quant BPB) -- **PR #1073**: 27 systematic experiments on M4 MacBook (deep supervision, LR tuning, batch scaling, architecture) +`self.num_encoder_layers = 1` — see PR #1275 for full details. Stacks with cosine: 1.5619 at 5000 steps (vs 1.5706 cosine alone). ## Reproduce -### Quick test (any single GPU) - ```bash git clone https://github.com/openai/parameter-golf.git && cd parameter-golf pip install sentencepiece huggingface-hub datasets tiktoken flash-attn - -# Apply focal loss to train_gpt.py — change the loss computation in GPT.forward(): -# OLD: return F.cross_entropy(logits.float(), targets, reduction="mean") -# NEW: -# focal_gamma = float(os.environ.get("FOCAL_GAMMA", "0")) -# if focal_gamma > 0: -# ce = F.cross_entropy(logits.float(), targets, reduction="none") -# pt = torch.exp(-ce) -# focal_weight = (1 - pt) ** focal_gamma -# return (focal_weight * ce).mean() -# return F.cross_entropy(logits.float(), targets, reduction="mean") - -# Also replace the lr_mul() function body with cosine schedule: -# min_lr_frac = 0.1 -# progress = step / max(args.iterations, 1) -# return min_lr_frac + 0.5 * (1.0 - min_lr_frac) * (1.0 + math.cos(math.pi * progress)) - -# And for asymmetric split, change in GPT.__init__(): -# self.num_encoder_layers = 1 # instead of num_layers // 2 - -# Download data python data/cached_challenge_fineweb.py --variant sp1024 - -# Run (single GPU, ~20 min for 5000 steps on RTX 4000 Ada) -FOCAL_GAMMA=8 COSINE_LR=1 ITERATIONS=5000 python train_gpt.py -``` - -### 8xH100 Record Run - -```bash -#!/bin/bash -# Full competition run on 8xH100 SXM -# Apply the same 3 changes to the SOTA train_gpt.py (PR #549 stack): -# 1. Focal loss in GPT.forward() (see above) -# 2. Cosine LR in lr_mul() (see above) -# 3. self.num_encoder_layers = 1 in GPT.__init__() - -cd /workspace -git clone --depth 1 https://github.com/openai/parameter-golf.git && cd parameter-golf -pip install -q sentencepiece huggingface-hub datasets tiktoken -pip install -q flash-attn --no-build-isolation - -python data/cached_challenge_fineweb.py --variant sp1024 - -# Copy SOTA script and apply changes -cp records/track_10min_16mb/2026-03-23_LeakyReLU_LegalTTT_ParallelMuon/train_gpt.py train_gpt_focal.py -# Apply: (1) focal loss, (2) cosine LR, (3) asymmetric split -# See code changes above - -# Run with competition settings -NUM_LAYERS=11 FOCAL_GAMMA=8 \ -torchrun --standalone --nproc_per_node=8 train_gpt_focal.py +COSINE_LR=1 python train_gpt.py ``` diff --git a/records/track_non_record_16mb/2026-04-05_FocalLoss_CosineLR_LanguageModeling/train_gpt.py b/records/track_non_record_16mb/2026-04-05_FocalLoss_CosineLR_LanguageModeling/train_gpt.py new file mode 100644 index 0000000000..04088e971e --- /dev/null +++ b/records/track_non_record_16mb/2026-04-05_FocalLoss_CosineLR_LanguageModeling/train_gpt.py @@ -0,0 +1,1140 @@ +""" +The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good launching-off points for new participants, not SOTA configs. We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder. + +Hard stop: To keep readable for newcomers, let's make sure `train_gpt.py` and `train_gpt_mlx.py` never are longer than 1500 lines. +""" + +from __future__ import annotations + +import copy +import glob +import io +import math +import os +import random +import subprocess +import sys +import time +import uuid +import zlib +from pathlib import Path + +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP + +# ----------------------------- +# HYPERPARAMETERS +# ----------------------------- +# Default Simple Baseline run: +# - 9 transformer blocks at width 512 +# - 8 attention heads with 4 KV heads (GQA) and 2x MLP expansion +# - vocab size 1024, sequence length 1024, tied embeddings +# - 524,288 train tokens per step for 20,000 iterations with a ~10 minute cap + +class Hyperparameters: + # Data paths are shard globs produced by the existing preprocessing pipeline. + data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + seed = int(os.environ.get("SEED", 1337)) + + # Validation cadence and batch size. Validation always uses the full fineweb_val split. + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 200)) + + # Training length. + iterations = int(os.environ.get("ITERATIONS", 20000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 1200)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + + # Model shape. + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers = int(os.environ.get("NUM_LAYERS", 9)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = int(os.environ.get("MLP_MULT", 2)) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + + # Optimizer hyperparameters. + embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.05)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.04)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.04)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.0)) + +# ----------------------------- +# MUON OPTIMIZER +# ----------------------------- +# +# As borrowed from modded-nanogpt +# Background on Muon: https://kellerjordan.github.io/posts/muon/ + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + # Orthogonalize a 2D update matrix with a fast Newton-Schulz iteration. + # Muon uses this to normalize matrix-shaped gradients before applying them. + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X + + +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True): + super().__init__( + params, + dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov), + ) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + + for group in self.param_groups: + params = group["params"] + if not params: + continue + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + + total_params = sum(int(p.numel()) for p in params) + updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) + + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: + g = g.add(buf, alpha=momentum) + g = zeropower_via_newtonschulz5(g, steps=backend_steps) + # Scale correction from Muon reference implementations. + g *= max(1, g.size(0) / g.size(1)) ** 0.5 + updates_flat[curr : curr + p.numel()] = g.reshape(-1) + curr += p.numel() + + if distributed: + dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) + + curr = 0 + for p in params: + g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) + p.add_(g, alpha=-lr) + curr += p.numel() + + return loss + + +# ----------------------------- +# TOKENIZER-AGNOSTIC EVALUATION SETUP +# ----------------------------- +# +# It's common for small models have a large fraction of their parameters be embeddings, since the 2 * d_model * d_vocab vectors can be gigantic. +# Instead of locking the tokenizer, we let you bring your own and calculate our validation metrics on the average compression of the validation set. +# We calculate BPB (bits-per-byte) instead of validation loss, so we need methods to count the number of bits per token in the tokenizer. +# Note: Submissions that edit the tokenizer will be examined more carefully, since screwing this up might unjustly improve your score. + +def build_sentencepiece_luts( + sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device +) -> tuple[Tensor, Tensor, Tensor]: + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("▁"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) + + +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*. + tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] + + +def eval_val( + args: Hyperparameters, + model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + grad_accum_steps: int, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, +) -> tuple[float, float]: + # Validation computes two metrics: + # - val_loss: token cross-entropy (natural log) + # - val_bpb: tokenizer-agnostic compression metric used by the challenge + local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) + if local_batch_tokens < args.train_seq_len: + raise ValueError( + "VAL_BATCH_SIZE must provide at least one sequence per rank; " + f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " + f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}" + ) + local_batch_seqs = local_batch_tokens // args.train_seq_len + total_seqs = (val_tokens.numel() - 1) // args.train_seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + + model.eval() + with torch.inference_mode(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * args.train_seq_len + raw_end = batch_seq_end * args.train_seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, args.train_seq_len) + y = local[1:].reshape(-1, args.train_seq_len) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + batch_loss = model(x, y).detach() + batch_token_count = float(y.numel()) + val_loss_sum += batch_loss.to(torch.float64) * batch_token_count + val_token_count += batch_token_count + prev_ids = x.reshape(-1) + tgt_ids = y.reshape(-1) + token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + + val_loss = val_loss_sum / val_token_count + bits_per_token = val_loss.item() / math.log(2.0) + tokens_per_byte = val_token_count.item() / val_byte_count.item() + model.train() + return float(val_loss.item()), float(bits_per_token * tokens_per_byte) + +# ----------------------------- +# POST-TRAINING QUANTIZATION +# ----------------------------- +# +# It's silly to export our model, which is trained in bf16 and fp32, at that same precision. +# Instead, we get approximately the same model (with a small hit) by quantizing the model to int8 & zlib compressing. +# We can then decompress the model and run in higher precision for evaluation, after closing in under the size limit. + +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights", + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 + +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) + +def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: + if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): + return t.float().contiguous() + if t.dtype in {torch.float32, torch.bfloat16}: + passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") + return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() + return t + +def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + # Matrices get one scale per row, which usually tracks output-channel + # ranges much better than a single tensor-wide scale. + clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) + q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() + return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + + # Vectors / scalars use a simpler per-tensor scale. + clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() + return q, scale + +def quantize_state_dict_int8(state_dict: dict[str, Tensor]): + # Single supported clean-script export format: + # - per-row int8 for 2D float tensors + # - per-tensor int8 for other float tensors + # - exact passthrough for non-floats + # - passthrough for small float tensors, stored as fp16 to save bytes + quantized: dict[str, Tensor] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + qmeta: dict[str, dict[str, object]] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + + # Small float tensors are cheap enough to keep directly. We still downcast + # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size. + if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + + stats["num_float_tensors"] += 1 + q, s = quantize_float_tensor(t) + if s.ndim > 0: + qmeta[name] = {"scheme": "per_row", "axis": 0} + quantized[name] = q + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) + + obj: dict[str, object] = { + "__quant_format__": "int8_clean_per_row_v1", + "quantized": quantized, + "scales": scales, + "dtypes": dtypes, + "passthrough": passthrough, + } + if qmeta: + obj["qmeta"] = qmeta + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats + +def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + qmeta = obj.get("qmeta", {}) + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, q in obj["quantized"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + s = obj["scales"][name] + if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: + s = s.to(dtype=torch.float32) + # Broadcast the saved row scale back across trailing dimensions. + out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() + else: + scale = float(s.item()) + out[name] = (q.float() * scale).to(dtype=dtype).contiguous() + for name, t in obj["passthrough"].items(): + # Restore small tensors, undoing the temporary fp16 storage cast if needed. + out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out + + +# ----------------------------- +# DATA LOADING +# ----------------------------- + +def load_data_shard(file: Path) -> Tensor: + header_bytes = 256 * np.dtype(" None: + self.file_idx = (self.file_idx + 1) % len(self.files) + self.tokens = load_data_shard(self.files[self.file_idx]) + self.pos = 0 + + def take(self, n: int) -> Tensor: + chunks: list[Tensor] = [] + remaining = n + while remaining > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: + self._advance_file() + continue + k = min(remaining, avail) + chunks.append(self.tokens[self.pos : self.pos + k]) + self.pos += k + remaining -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) + + +class DistributedTokenLoader: + # Each call consumes a contiguous chunk from the shared token stream, then slices out + # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting. + def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank = rank + self.world_size = world_size + self.device = device + self.stream = TokenStream(pattern) + + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + chunk = self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start : start + per_rank_span].to(dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + +# ----------------------------- +# TRANSFORMER MODULES +# ----------------------------- + +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + + +class CastedLinear(nn.Linear): + # Keep weights in fp32 for optimizer/state quality, cast at matmul time for bf16 compute. + def forward(self, x: Tensor) -> Tensor: + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, self.weight.to(x.dtype), bias) + + +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + # Keep small/control parameters in fp32 even when the model body runs in bf16. + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() + + +class Rotary(nn.Module): + # Caches cos/sin tables per sequence length on the current device. + def __init__(self, dim: int, base: float = 10000.0): + super().__init__() + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached != seq_len + or self._cos_cached.device != device + ): + t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) + freqs = torch.outer(t, self.inv_freq.to(device)) + self._cos_cached = freqs.cos()[None, None, :, :] + self._sin_cached = freqs.sin()[None, None, :, :] + self._seq_len_cached = seq_len + return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) + + +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + + +class CausalSelfAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + kv_dim = self.num_kv_heads * self.head_dim + self.c_q = CastedLinear(dim, dim, bias=False) + self.c_k = CastedLinear(dim, kv_dim, bias=False) + self.c_v = CastedLinear(dim, kv_dim, bias=False) + self.proj = CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rotary = Rotary(self.head_dim, base=rope_base) + + def forward(self, x: Tensor) -> Tensor: + bsz, seqlen, dim = x.shape + q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) + k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin) + k = apply_rotary_emb(k, cos, sin) + q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] + y = F.scaled_dot_product_attention( + q, + k, + v, + attn_mask=None, + is_causal=True, + enable_gqa=(self.num_kv_heads != self.num_heads), + ) + y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) + return self.proj(y) + + +class MLP(nn.Module): + # relu^2 MLP from the original modded-nanogpt setup + def __init__(self, dim: int, mlp_mult: int): + super().__init__() + hidden = mlp_mult * dim + self.fc = CastedLinear(dim, hidden, bias=False) + self.proj = CastedLinear(hidden, dim, bias=False) + self.proj._zero_init = True + + def forward(self, x: Tensor) -> Tensor: + x = torch.relu(self.fc(x)) + return self.proj(x.square()) + + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) + self.mlp = MLP(dim, mlp_mult) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + + def forward(self, x: Tensor, x0: Tensor) -> Tensor: + mix = self.resid_mix.to(dtype=x.dtype) + x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + attn_out = self.attn(self.attn_norm(x)) + x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out + x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) + return x + + +class GPT(nn.Module): + def __init__( + self, + vocab_size: int, + num_layers: int, + model_dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + tie_embeddings: bool, + tied_embed_init_std: float, + logit_softcap: float, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + if logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") + self.tie_embeddings = tie_embeddings + self.tied_embed_init_std = tied_embed_init_std + self.logit_softcap = logit_softcap + self.tok_emb = nn.Embedding(vocab_size, model_dim) + self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + self.blocks = nn.ModuleList( + [ + Block( + model_dim, + num_heads, + num_kv_heads, + mlp_mult, + rope_base, + qk_gain_init, + ) + for i in range(num_layers) + ] + ) + self.final_norm = RMSNorm() + self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) + if self.lm_head is not None: + self.lm_head._zero_init = True + self._init_weights() + + def _init_weights(self) -> None: + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + for module in self.modules(): + if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + x = self.tok_emb(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x0 = x + skips: list[Tensor] = [] + + # First half stores skips; second half reuses them in reverse order. + for i in range(self.num_encoder_layers): + x = self.blocks[i](x, x0) + skips.append(x) + for i in range(self.num_decoder_layers): + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + x = self.blocks[self.num_encoder_layers + i](x, x0) + + x = self.final_norm(x).reshape(-1, x.size(-1)) + targets = target_ids.reshape(-1) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(x) + logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + # Focal loss: down-weight easy tokens, focus on hard ones (training only) + focal_gamma = float(os.environ.get("FOCAL_GAMMA", "0")) + if focal_gamma > 0 and self.training: + ce = F.cross_entropy(logits.float(), targets, reduction="none") + pt = torch.exp(-ce) # probability of correct class + focal_weight = (1 - pt) ** focal_gamma + return (focal_weight * ce).mean() + return F.cross_entropy(logits.float(), targets, reduction="mean") + + +# ----------------------------- +# TRAINING +# ----------------------------- + +def main() -> None: + global zeropower_via_newtonschulz5 + + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) + + # ----------------------------- + # DISTRIBUTED + CUDA SETUP + # ----------------------------- + + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if 8 % world_size != 0: + raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") + grad_accum_steps = 8 // world_size + grad_scale = 1.0 / grad_accum_steps + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 0 + + # Fast math knobs + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + + logfile = None + if master_process: + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" + print(logfile) + + def log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch {torch.__version__}", console=False) + log0( + subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, + console=False, + ) + log0("=" * 100, console=False) + + # ----------------------------- + # TOKENIZER + VALIDATION METRIC SETUP + # ----------------------------- + + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + + # ----------------------------- + # MODEL + OPTIMIZER SETUP + # ----------------------------- + + base_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, + ).to(device).bfloat16() + for module in base_model.modules(): + if isinstance(module, CastedLinear): + module.float() + restore_low_dim_params_to_fp32(base_model) + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) + model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + + # Optimizer split: + # - token embedding (Adam) uses EMBED_LR + # - untied lm_head (Adam) uses HEAD_LR + # - matrix params in transformer blocks use MATRIX_LR via Muon + # - vectors/scalars use SCALAR_LR via Adam + block_named_params = list(base_model.blocks.named_parameters()) + matrix_params = [ + p + for name, p in block_named_params + if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + scalar_params = [ + p + for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + optimizer_tok = torch.optim.Adam( + [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizer_muon = Muon( + matrix_params, + lr=args.matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + ) + for group in optimizer_muon.param_groups: + group["base_lr"] = args.matrix_lr + optimizer_scalar = torch.optim.Adam( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers.insert(1, optimizer_head) + + n_params = sum(p.numel() for p in base_model.parameters()) + log0(f"model_params:{n_params}") + log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") + log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") + log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") + log0( + f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " + f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " + f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" + ) + log0( + f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " + f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " + f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log0(f"seed:{args.seed}") + + # ----------------------------- + # DATA LOADER & MODEL WARMUP + # ----------------------------- + + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + + def lr_mul(step: int, elapsed_ms: float) -> float: + import math + cosine_sched = int(os.environ.get("COSINE_LR", "0")) + if cosine_sched: + # Cosine annealing from 1.0 to 0.1 over full training + min_lr_frac = 0.1 + progress = step / max(args.iterations, 1) + return min_lr_frac + 0.5 * (1.0 - min_lr_frac) * (1.0 + math.cos(math.pi * progress)) + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + + # Warmup primes the compiled forward/backward/optimizer paths, then we restore the + # initial weights/optimizer state so measured training starts from the true init. + if args.warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(args.warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + for opt in optimizers: + opt.step() + zero_grad_all() + if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: + log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + if distributed: + model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + # ----------------------------- + # MAIN TRAINING LOOP + # ----------------------------- + + training_time_ms = 0.0 + stop_after_step: int | None = None + torch.cuda.synchronize() + t0 = time.perf_counter() + + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + log0( + f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" + ) + break + + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + zero_grad_all() + train_loss = torch.zeros((), device=device) + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + loss = model(x, y) + train_loss += loss.detach() + (loss * grad_scale).backward() + train_loss /= grad_accum_steps + + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + for opt in optimizers: + opt.step() + zero_grad_all() + + step += 1 + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + should_log_train = ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + log0( + f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" + ) + + # Needed to sync whether we've reached the wallclock cap. + reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + + log0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + ) + + # ----------------------------- + # SERIALIZATION + ROUNDTRIP VALIDATION + # ----------------------------- + # Save the raw state (useful for debugging/loading in PyTorch directly), then always produce + # the compressed int8+zlib artifact and validate the round-tripped weights. + + if master_process: + torch.save(base_model.state_dict(), "final_model.pt") + model_bytes = os.path.getsize("final_model.pt") + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model: {model_bytes} bytes") + log0(f"Code size: {code_bytes} bytes") + log0(f"Total submission size: {model_bytes + code_bytes} bytes") + + quant_obj, quant_stats = quantize_state_dict_int8(base_model.state_dict()) + quant_buf = io.BytesIO() + torch.save(quant_obj, quant_buf) + quant_raw = quant_buf.getvalue() + quant_blob = zlib.compress(quant_raw, level=9) + quant_raw_bytes = len(quant_raw) + if master_process: + with open("final_model.int8.ptz", "wb") as f: + f.write(quant_blob) + quant_file_bytes = os.path.getsize("final_model.int8.ptz") + code_bytes = len(code.encode("utf-8")) + ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1) + log0( + f"Serialized model int8+zlib: {quant_file_bytes} bytes " + f"(payload:{quant_stats['int8_payload_bytes']} raw_torch:{quant_raw_bytes} payload_ratio:{ratio:.2f}x)" + ) + log0(f"Total submission size int8+zlib: {quant_file_bytes + code_bytes} bytes") + + if distributed: + dist.barrier() + with open("final_model.int8.ptz", "rb") as f: + quant_blob_disk = f.read() + quant_state = torch.load(io.BytesIO(zlib.decompress(quant_blob_disk)), map_location="cpu") + base_model.load_state_dict(dequantize_state_dict_int8(quant_state), strict=True) + torch.cuda.synchronize() + t_qeval = time.perf_counter() + q_val_loss, q_val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"final_int8_zlib_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" + ) + log0(f"final_int8_zlib_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + + if distributed: + dist.destroy_process_group() + + +if __name__ == "__main__": + main() From 1a6be007d6631e8c2c00cc730d12ae207774d4e6 Mon Sep 17 00:00:00 2001 From: ranausmanai Date: Sun, 5 Apr 2026 23:02:56 +0500 Subject: [PATCH 5/5] Add novel experiment features: z-loss, token dropout, embedding mixup, gamma annealing, logit penalty, multi-cycle cosine All modifications are training-only (guarded by self.training). Controlled via env vars: - Z_LOSS: log(Z)^2 regularizer from PaLM paper - TOKEN_DROP: synaptic pruning-inspired token dropout - EMBED_MIXUP: genetic recombination-inspired embedding interpolation - GAMMA_ANNEAL: decay focal gamma to 0 over training - LOGIT_PENALTY: L2 penalty on logits for sparse activation - COSINE_CYCLES: multi-cycle cosine LR schedule Co-Authored-By: Claude Opus 4.6 --- .../train_gpt.py | 39 +++++++++++++++++-- 1 file changed, 35 insertions(+), 4 deletions(-) diff --git a/records/track_non_record_16mb/2026-04-05_FocalLoss_CosineLR_LanguageModeling/train_gpt.py b/records/track_non_record_16mb/2026-04-05_FocalLoss_CosineLR_LanguageModeling/train_gpt.py index 04088e971e..3bfa081599 100644 --- a/records/track_non_record_16mb/2026-04-05_FocalLoss_CosineLR_LanguageModeling/train_gpt.py +++ b/records/track_non_record_16mb/2026-04-05_FocalLoss_CosineLR_LanguageModeling/train_gpt.py @@ -699,7 +699,18 @@ def _init_weights(self) -> None: def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: x = self.tok_emb(input_ids) + # Embedding mixup: interpolate adjacent token embeddings (genetic recombination) + embed_mixup = float(os.environ.get("EMBED_MIXUP", "0")) + if embed_mixup > 0 and self.training: + shifted = torch.roll(x, 1, dims=1) + alpha = embed_mixup * torch.rand(x.shape[0], x.shape[1], 1, device=x.device, dtype=x.dtype) + x = (1 - alpha) * x + alpha * shifted x = F.rms_norm(x, (x.size(-1),)) + # Token dropout: randomly zero out token representations (synaptic pruning) + token_drop = float(os.environ.get("TOKEN_DROP", "0")) + if token_drop > 0 and self.training: + mask = torch.rand(x.shape[0], x.shape[1], 1, device=x.device, dtype=x.dtype) > token_drop + x = x * mask / (1.0 - token_drop) x0 = x skips: list[Tensor] = [] @@ -723,12 +734,26 @@ def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) # Focal loss: down-weight easy tokens, focus on hard ones (training only) focal_gamma = float(os.environ.get("FOCAL_GAMMA", "0")) + # Gamma annealing: decay gamma from initial value to 0 over training + if hasattr(self, '_gamma_anneal') and self._gamma_anneal and hasattr(self, '_training_progress'): + focal_gamma = focal_gamma * (1.0 - self._training_progress) + # Z-loss regularizer (PaLM paper): penalize log(Z) to stabilize logits + z_loss_coeff = float(os.environ.get("Z_LOSS", "0")) + # Logit penalty: L2 penalty on logits to prevent overconfidence + logit_penalty = float(os.environ.get("LOGIT_PENALTY", "0")) if focal_gamma > 0 and self.training: ce = F.cross_entropy(logits.float(), targets, reduction="none") pt = torch.exp(-ce) # probability of correct class focal_weight = (1 - pt) ** focal_gamma - return (focal_weight * ce).mean() - return F.cross_entropy(logits.float(), targets, reduction="mean") + loss = (focal_weight * ce).mean() + else: + loss = F.cross_entropy(logits.float(), targets, reduction="mean") + if z_loss_coeff > 0 and self.training: + log_z = torch.logsumexp(logits.float(), dim=-1) + loss = loss + z_loss_coeff * (log_z ** 2).mean() + if logit_penalty > 0 and self.training: + loss = loss + logit_penalty * (logits.float() ** 2).mean() + return loss # ----------------------------- @@ -928,14 +953,19 @@ def zero_grad_all() -> None: max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + # Gamma annealing setup + model._gamma_anneal = int(os.environ.get("GAMMA_ANNEAL", "0")) > 0 + def lr_mul(step: int, elapsed_ms: float) -> float: import math cosine_sched = int(os.environ.get("COSINE_LR", "0")) if cosine_sched: - # Cosine annealing from 1.0 to 0.1 over full training + # Cosine annealing from 1.0 to 0.1, supports multi-cycle min_lr_frac = 0.1 + num_cycles = int(os.environ.get("COSINE_CYCLES", "1")) progress = step / max(args.iterations, 1) - return min_lr_frac + 0.5 * (1.0 - min_lr_frac) * (1.0 + math.cos(math.pi * progress)) + cycle_progress = (progress * num_cycles) % 1.0 + return min_lr_frac + 0.5 * (1.0 - min_lr_frac) * (1.0 + math.cos(math.pi * cycle_progress)) if args.warmdown_iters <= 0: return 1.0 if max_wallclock_ms is None: @@ -1027,6 +1057,7 @@ def lr_mul(step: int, elapsed_ms: float) -> float: model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + model._training_progress = step / max(args.iterations, 1) loss = model(x, y) train_loss += loss.detach() (loss * grad_scale).backward()