125 changes: 125 additions & 0 deletions records/track_10min_16mb/2026-04-07_idan3011_sp4096_10L_TTT/README.md
@@ -0,0 +1,125 @@
# sp4096 Custom Tokenizer + 10L 3.5x MLP + GPTQ + Score-First TTT

**val_bpb: 1.1266** | **artifact: 15.99 MB** | **8xH100** | **600s wallclock**

## Headline metrics

| Stage | val_bpb | val_loss |
|-------|--------:|---------:|
| Pre-quant (step 5952) | 1.1427 | 2.6289 |
| Post-quant (int6+brotli roundtrip) | 1.1439 | 2.6318 |
| Sliding window (stride=64) | 1.1277 | — |
| **Score-first TTT (final)** | **1.1266** | — |

## Architecture

- 10 layers, 512 dim, 8 heads, 4 KV heads (GQA 2:1)
- 3.5x MLP expansion (1792 hidden dim per block)
- Tied input/output embeddings (single 4096×512 matrix)
- LeakyReLU(0.5)² MLP activation
- Logit softcap at 30 via tanh
- U-Net skip connections (encoder layers feed matching decoder layers via per-layer scale weights)
- Last 4 blocks use cross-sequence attention (XSA)
- 28.3M parameters total
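
As a sanity check, the listed dimensions roughly reproduce the parameter count. This is a back-of-the-envelope sketch assuming head_dim 64 (8 query heads → 512 dims, 4 KV heads → 256 dims) and no biases; the small remainder is presumably the per-layer scales, mixes, and skip weights.

```python
# Hypothetical parameter-count check from the dimensions listed above.
vocab, dim, layers, mlp_hidden, kv_heads, head_dim = 4096, 512, 10, 1792, 4, 64

emb = vocab * dim                       # tied input/output embedding: 2,097,152
attn = dim * dim                        # q projection
attn += 2 * dim * (kv_heads * head_dim) # k and v projections (4 KV heads x 64)
attn += dim * dim                       # output projection
mlp = 2 * dim * mlp_hidden              # fc + proj at 3.5x expansion
total = emb + layers * (attn + mlp)

print(total)  # 28311552 -- within 0.1% of the reported 28,334,672
```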

## Custom sp4096 SentencePiece tokenizer

A 4096-vocab SentencePiece BPE tokenizer trained on FineWeb, hosted at `idan3011/parameter-golf-sp4096` on HuggingFace. The script auto-downloads the dataset + tokenizer on first run.

Compared to the default sp1024 tokenizer:
- **~26% fewer tokens per byte** (more efficient compression)
- More room in the param budget — fewer tokens means more "value" per parameter spent
- Required tuning the embedding quantization separately (see below)
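
The tokenizer's effect on the headline metric follows from the bpb identity: bpb = nats_per_token / (bytes_per_token × ln 2). A sketch, where the ~3.32 bytes/token figure is implied by the reported pre-quant numbers rather than measured here:

```python
import math

def bpb(nats_per_token, bytes_per_token):
    # bits-per-byte from mean token loss and tokenizer byte efficiency
    return nats_per_token / (bytes_per_token * math.log(2))

# bytes/token implied by the reported pre-quant val_loss and val_bpb
implied_bpt = 2.6289 / (1.1427 * math.log(2))
print(round(implied_bpt, 2))  # -> 3.32

# At equal per-token loss, ~26% fewer tokens per byte means ~26% lower bpb:
print(round(bpb(2.6289, implied_bpt / 0.74) / bpb(2.6289, implied_bpt), 2))  # -> 0.74
```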

## Training

- **786,432-token batch** (8 GPUs × 1 grad accum) — chosen over 524,288 for a smoother warmdown trajectory
- **Muon optimizer** for matrix parameters in transformer blocks
- **Adam** for embeddings, output head, and scalar/vector parameters
- **Wallclock-fraction warmdown**: cosine LR decay over the last 35% of remaining wallclock (rather than fixed step count) to maximize useful training time
- **EMA** with decay 0.997 maintained throughout
- **SWA** averaging in the last 50% of training, blended with EMA at the end (weighted average of 198 checkpoints)
- **QAT** (fake-quantized weights during forward) on MLP CastedLinear layers to make the model more robust to int5 quantization
- Hit wallclock cap at step 5952/20000 in 600.054s
- Pre-quant val_loss: 2.6289, val_bpb: **1.1427**
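
The wallclock-fraction warmdown can be sketched as follows. This is an assumed form (constant LR until 65% of the budget is spent, then cosine decay to zero over the final 35%, keyed on elapsed seconds rather than step count); the base LR of 0.02 is a placeholder, not taken from the script.

```python
import math

def warmdown_lr(base_lr, elapsed_s, budget_s=600.0, warmdown_frac=0.35):
    # LR as a function of elapsed wallclock, not step index
    f = min(elapsed_s / budget_s, 1.0)
    start = 1.0 - warmdown_frac
    if f <= start:
        return base_lr
    progress = (f - start) / warmdown_frac     # 0 -> 1 over the final 35%
    return base_lr * 0.5 * (1.0 + math.cos(math.pi * progress))

print(warmdown_lr(0.02, 0.0))    # full LR early in the run
print(warmdown_lr(0.02, 600.0))  # ~0 at the wallclock cap
```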

## Quantization & compression

### GPTQ with AR self-generated calibration

Rather than calibrating GPTQ on a separate dataset, the trained model **generates its own calibration sequences** via autoregressive sampling: 16 sequences of 512 tokens each, sampled from the model's own distribution and therefore closely matched to its activation statistics.

GPTQ then uses Hessian-aware error compensation to quantize each weight column-by-column, propagating the rounding error to the remaining columns. This minimizes the L2 reconstruction error of the layer outputs.
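
The column-by-column compensation can be sketched for one weight row. This is a minimal pure-Python illustration, not the actual script's GPTQ: `mat_inv`, the damping constant, and the symmetric int5 grid are assumptions for the demo.

```python
import random

def mat_inv(A):
    # Gauss-Jordan inverse of a small square matrix (lists of lists)
    n = len(A)
    M = [row[:] + [float(i == j) for j in range(n)] for i, row in enumerate(A)]
    for col in range(n):
        piv = max(range(col, n), key=lambda r: abs(M[r][col]))
        M[col], M[piv] = M[piv], M[col]
        p = M[col][col]
        M[col] = [v / p for v in M[col]]
        for r in range(n):
            if r != col and M[r][col]:
                f = M[r][col]
                M[r] = [a - f * b for a, b in zip(M[r], M[col])]
    return [row[n:] for row in M]

def gptq_row(w, X, half=15):                       # half=15 ~ symmetric int5
    d = len(w)
    # Hessian H = X^T X (+ small diagonal damping), as in GPTQ
    H = [[sum(x[i] * x[j] for x in X) + (1e-4 if i == j else 0.0)
          for j in range(d)] for i in range(d)]
    Hinv = mat_inv(H)
    scale = max(abs(v) for v in w) / half
    w = list(w)
    q = [0] * d
    for j in range(d):
        q[j] = max(-half, min(half, round(w[j] / scale)))
        err = (w[j] - q[j] * scale) / Hinv[j][j]
        for k in range(j, d):                      # fold error into remaining cols
            w[k] -= err * Hinv[k][j]
    return [v * scale for v in q], scale

random.seed(0)
X = [[random.gauss(0, 1) for _ in range(4)] for _ in range(64)]  # calibration inputs
w = [random.gauss(0, 1) for _ in range(4)]                       # one weight row
w_q, scale = gptq_row(w, X)
```

In practice this greedy compensation typically yields lower layer-output reconstruction error than independent round-to-nearest, which is what makes int5 viable here.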

### Mixed quantization scheme

| Tensor class | Bits | Reasoning |
|---|---|---|
| Attention weights (q/k/v/proj) | int5 per-row | Aggressive but stable with GPTQ |
| MLP weights (fc/proj) | int5 per-row | Stable with QAT during training |
| **tok_emb.weight (tied)** | **int8 per-row** | int5 destroys tied embedding (input AND output projection) — discovered painfully via experimentation. int8 yields near-zero quant gap. |
| Control tensors (scales, mixes, q_gain, skip_weights) | fp32 passthrough | Small total size, needed for stability |
| All other small tensors | fp16 passthrough | <65K elements |

Final post-quant gap is only **0.0012 BPB** (1.1427 → 1.1439) — exceptionally small for such aggressive quantization.
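
The per-row schemes above can be sketched as symmetric abs-max quantization; for int5 with levels −15..15, the roundtrip error per element is bounded by half the row scale:

```python
def quant_row_int5(row):
    # symmetric per-row int5: abs-max scale, levels -15..15 (assumed scheme)
    scale = max(abs(v) for v in row) / 15 or 1e-12
    q = [max(-15, min(15, round(v / scale))) for v in row]
    return q, scale

row = [0.31, -0.08, 0.002, -0.27, 0.15]
q, scale = quant_row_int5(row)
recon = [v * scale for v in q]
assert max(abs(a - b) for a, b in zip(row, recon)) <= scale / 2 + 1e-12
```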

### brotli + byte-shuffle compression

After int5 quantization, weights are compressed with **brotli (quality 11)** instead of LZMA. To boost compression further:

- **Byte-shuffle pre-filter**: quantized values are serialized as little-endian integers. Most values cluster near zero, so the high-order bytes are mostly zero or constant. Reordering bytes column-wise (all byte-0s, then all byte-1s, then ...) groups that structure together and lets brotli's context modeling exploit it.

This combination saved ~280KB vs plain LZMA, and ~700KB vs naive int8 + zlib. Final artifact: **15,918,111 bytes** (model) + 71,265 bytes (code) = **15,989,376 bytes total** — about 10KB under the 16 MB (16,000,000-byte) cap.
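
The byte-shuffle pre-filter can be sketched as below — using int16 serialization and stdlib `zlib` as stand-ins for the actual payload format and brotli (quality 11); the point is the lossless byte-plane reordering, not the exact codec.

```python
import random
import struct
import zlib

def shuffle_bytes(raw, width=2):
    # all byte-0s, then all byte-1s, ... so near-constant high bytes group up
    return b"".join(raw[i::width] for i in range(width))

def unshuffle_bytes(shuf, width=2):
    # inverse: split into planes and re-interleave
    n = len(shuf) // width
    planes = [shuf[i * n:(i + 1) * n] for i in range(width)]
    return bytes(b for group in zip(*planes) for b in group)

random.seed(0)
values = [random.randint(-40, 40) for _ in range(4096)]   # cluster near zero
raw = struct.pack(f"<{len(values)}h", *values)            # little-endian int16
shuf = shuffle_bytes(raw)
assert unshuffle_bytes(shuf) == raw                       # lossless roundtrip
print(len(zlib.compress(raw, 9)), len(zlib.compress(shuf, 9)))
```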

## Score-First TTT (eval-time adaptation)

Test-Time Training adapts the model on validation data **after scoring it**, exploiting val/train distribution shift. Strictly legal under issue #402 / issue #1017: every token is scored before any weight update that has seen it, in a single left-to-right pass.

### Algorithm

1. Build sliding-window scoring positions globally over the validation set (stride=64, seq_len=2048)
2. Group windows into chunks based on which `chunk_tok` block their **scored region** falls in
3. For each chunk:
- **Score** the chunk's windows under `inference_mode` (forward only) — accumulate L (loss), T (token count), B (byte count)
- If not the last chunk: **train** on the chunk's contiguous sequences via SGD + grad clipping, with **all_reduce** of gradients across GPUs
4. Final BPB = L / (B × ln 2)
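
The score-first discipline in the steps above can be sketched with a toy stand-in model (a single scalar fit by SGD); the point is the event order — each chunk is scored under the current weights before any update on it, and the final chunk is never trained on:

```python
def score_first_ttt(chunks, lr=0.1, epochs=3):
    # toy model: scalar w, squared-error "loss" per token (illustrative only)
    w, events, total_loss, total_n = 0.0, [], 0.0, 0
    for i, chunk in enumerate(chunks):
        # 1) score under current weights (forward only)
        total_loss += sum((w - x) ** 2 for x in chunk)
        total_n += len(chunk)
        events.append(("score", i))
        # 2) then adapt -- except on the final chunk
        if i < len(chunks) - 1:
            for _ in range(epochs):
                grad = sum(2 * (w - x) for x in chunk) / len(chunk)
                w -= lr * grad
            events.append(("train", i))
    return total_loss / total_n, events

chunks = [[1.0, 1.2], [0.9, 1.1], [1.0, 1.3]]
mean_loss, events = score_first_ttt(chunks)
```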

### Hyperparameters

- 348 chunks of 131,072 tokens each
- 20 SGD epochs per chunk
- SGD lr=0.003 with cosine decay across chunks (chunk i uses lr × 0.5 × (1 + cos(π·i/N)))
- Momentum 0.9, no weight decay
- Grad clipping at 1.0
- Freeze: 0 blocks (all 10 layers trainable)
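
The per-chunk cosine schedule stated above, as code:

```python
import math

def chunk_lr(i, n_chunks=348, base_lr=0.003):
    # chunk i of N uses lr * 0.5 * (1 + cos(pi * i / N))
    return base_lr * 0.5 * (1.0 + math.cos(math.pi * i / n_chunks))

print(chunk_lr(0))    # base LR at the first chunk
print(chunk_lr(174))  # half the base LR at the midpoint
```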

### Distributed implementation

- Windows split across ranks contiguously: `windows[rank*N//W : (rank+1)*N//W]` for N windows across W ranks
- Each GPU scores its windows independently
- During training: each GPU processes its own sequence partition, then `dist.all_reduce(grad, AVG)` synchronizes gradients before optimizer step
- Final L/T/B all-reduced via SUM
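
Why `all_reduce(AVG)` over equal-size shards reproduces the full-batch gradient can be sketched in plain Python (the script itself uses `torch.distributed`; the scalar model here is only illustrative):

```python
def shard_mean_grad(samples, w):
    # per-"rank" gradient: mean over that rank's samples (toy squared error)
    return sum(2 * (w - x) for x in samples) / len(samples)

data = [0.5, 1.5, 2.0, 1.0, 0.0, 3.0, 2.5, 1.5]
w = 0.7
shards = [data[i::4] for i in range(4)]   # 4 ranks, 2 samples each

# all_reduce(AVG): average the per-rank mean gradients
all_reduced = sum(shard_mean_grad(s, w) for s in shards) / len(shards)
full_batch = shard_mean_grad(data, w)
assert abs(all_reduced - full_batch) < 1e-12   # identical for equal shards
```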

### Result

Sliding baseline: **1.1277** → TTT: **1.1266** (-0.0011 BPB)

The improvement is small because the base model is already well-optimized (only 28M params, fully trained for 600s on FineWeb).

## Reproduce

```bash
pip install -r requirements.txt
torchrun --standalone --nproc_per_node=8 train_gpt.py
```

Every hyperparameter is baked into the script. Data and tokenizer auto-download from HuggingFace on first run. No env vars, no shell scripts, no setup steps.

Total runtime: ~10 min training + ~7 min eval (post-quant + sliding + TTT) on 8xH100.

## Included files

- `train_gpt.py` — frozen training script (1387 lines)
- `train.log` — full training + eval log
- `submission.json` — leaderboard metadata
- `README.md` — this file
@@ -0,0 +1,18 @@
{
"author": "idan3011",
"github_id": "idan3011",
"name": "sp4096 + 10L 3.5x MLP + GPTQ + Score-First TTT",
"blurb": "Custom sp4096 SentencePiece tokenizer + 10-layer tied-embedding GPT with 3.5x MLP, int5-all GPTQ quantization, brotli+byte-shuffle compression, and score-first TTT eval-time adaptation.",
"date": "2026-04-07T00:00:00Z",
"val_loss": 2.63175670,
"val_bpb": 1.12662548,
"pre_quant_val_loss": 2.6289,
"pre_quant_val_bpb": 1.1427,
"sliding_val_bpb": 1.12768974,
"ttt_val_bpb": 1.12662548,
"step_stop": 5952,
"wallclock_seconds": 600.054,
"bytes_total": 15989376,
"bytes_code": 71265,
"bytes_model_int6_brotli": 15918111
}
123 changes: 123 additions & 0 deletions records/track_10min_16mb/2026-04-07_idan3011_sp4096_10L_TTT/train.log
@@ -0,0 +1,123 @@
data:fineweb10B_sp4096 train_shards:86 val_tokens:45516800
model_params:28334672 world_size:8 grad_accum:1
batch:786432 seq:2048 warmup:20 wallclock:600s
warmup_step:1/20
warmup_step:2/20
warmup_step:3/20
warmup_step:4/20
warmup_step:5/20
warmup_step:6/20
warmup_step:7/20
warmup_step:8/20
warmup_step:9/20
warmup_step:10/20
warmup_step:11/20
warmup_step:12/20
warmup_step:13/20
warmup_step:14/20
warmup_step:15/20
warmup_step:16/20
warmup_step:17/20
warmup_step:18/20
warmup_step:19/20
warmup_step:20/20
step:0/20000 val_loss:8.3190 val_bpb:3.6160 train_time:0ms step_avg:0.02ms
step:1/20000 train_loss:8.3188 train_time:91ms step_avg:91.35ms
step:2/20000 train_loss:13.8742 train_time:175ms step_avg:87.26ms
step:3/20000 train_loss:9.2554 train_time:271ms step_avg:90.49ms
step:4/20000 train_loss:7.9238 train_time:368ms step_avg:92.06ms
step:5/20000 train_loss:7.9832 train_time:466ms step_avg:93.11ms
step:6/20000 train_loss:8.2685 train_time:563ms step_avg:93.76ms
step:7/20000 train_loss:7.9973 train_time:660ms step_avg:94.23ms
step:8/20000 train_loss:7.5139 train_time:757ms step_avg:94.67ms
step:9/20000 train_loss:7.3772 train_time:854ms step_avg:94.93ms
step:10/20000 train_loss:7.0031 train_time:952ms step_avg:95.20ms
step:200/20000 train_loss:3.3868 train_time:19750ms step_avg:98.75ms
step:400/20000 train_loss:3.0708 train_time:39679ms step_avg:99.20ms
step:600/20000 train_loss:3.0345 train_time:59635ms step_avg:99.39ms
step:800/20000 train_loss:3.0738 train_time:79663ms step_avg:99.58ms
step:1000/20000 train_loss:2.9608 train_time:99657ms step_avg:99.66ms
step:1000/20000 val_loss:2.9573 val_bpb:1.2854 train_time:99680ms step_avg:99.68ms
step:1200/20000 train_loss:2.9294 train_time:119705ms step_avg:99.75ms
step:1400/20000 train_loss:2.9476 train_time:139737ms step_avg:99.81ms
step:1600/20000 train_loss:2.9330 train_time:159693ms step_avg:99.81ms
step:1800/20000 train_loss:3.0283 train_time:179688ms step_avg:99.83ms
step:2000/20000 train_loss:2.7261 train_time:199636ms step_avg:99.82ms
step:2000/20000 val_loss:2.8506 val_bpb:1.2391 train_time:199659ms step_avg:99.83ms
step:2200/20000 train_loss:2.8482 train_time:219637ms step_avg:99.84ms
step:2400/20000 train_loss:2.8831 train_time:239556ms step_avg:99.81ms
step:2600/20000 train_loss:2.6709 train_time:259540ms step_avg:99.82ms
step:2800/20000 train_loss:2.7688 train_time:279524ms step_avg:99.83ms
step:3000/20000 train_loss:2.8072 train_time:299432ms step_avg:99.81ms
step:3000/20000 val_loss:2.8106 val_bpb:1.2217 train_time:299454ms step_avg:99.82ms
step:3200/20000 train_loss:2.7424 train_time:319390ms step_avg:99.81ms
step:3400/20000 train_loss:2.8552 train_time:339286ms step_avg:99.79ms
step:3600/20000 train_loss:2.7908 train_time:359257ms step_avg:99.79ms
step:3800/20000 train_loss:2.8224 train_time:379151ms step_avg:99.78ms
step:4000/20000 train_loss:2.7302 train_time:399129ms step_avg:99.78ms
step:4000/20000 val_loss:2.7829 val_bpb:1.2097 train_time:399151ms step_avg:99.79ms
step:4200/20000 train_loss:2.7838 train_time:419083ms step_avg:99.78ms
step:4400/20000 train_loss:2.9248 train_time:438982ms step_avg:99.77ms
step:4600/20000 train_loss:2.7707 train_time:458930ms step_avg:99.77ms
step:4800/20000 train_loss:2.6911 train_time:478829ms step_avg:99.76ms
step:5000/20000 train_loss:2.7248 train_time:499194ms step_avg:99.84ms
step:5000/20000 val_loss:2.7115 val_bpb:1.1786 train_time:499195ms step_avg:99.84ms
step:5200/20000 train_loss:2.5774 train_time:520236ms step_avg:100.05ms
step:5400/20000 train_loss:2.5983 train_time:541284ms step_avg:100.24ms
step:5600/20000 train_loss:2.6296 train_time:562313ms step_avg:100.41ms
step:5800/20000 train_loss:2.6004 train_time:583708ms step_avg:100.64ms
step:5952/20000 val_loss:2.6289 val_bpb:1.1427 train_time:600054ms step_avg:100.82ms
stopping_early: wallclock_cap train_time:600054ms step:5952/20000
peak memory allocated: 20258 MiB reserved: 20994 MiB
swa: averaging 198 checkpoints on top of EMA
ema: loading weights
gptq: generating AR calibration data...
gptq: AR gen done in 10.2s
gptq: collecting Hessians...
gptq: Hessians for 60 layers in 0.1s
gptq: quantized 60 layers in 4.2s, total 14.5s
Serialized model: 109181987 bytes Code: 71265 Total: 109253252
Serialized model int6+brotli: 15918111 bytes (payload:28489024 raw:28539065)
Total submission size: 15989376 bytes
final_int8_zlib_roundtrip val_loss:2.6318 val_bpb:1.1439 eval_time:25216ms
final_int8_zlib_roundtrip_exact val_loss:2.63175670 val_bpb:1.14393687
final_sliding_window val_bpb:1.1277 eval_time:140435ms
final_sliding_window_exact val_bpb:1.12768974
ttt: starting
ttt: chunk=0/348
ttt: chunk=10/348
ttt: chunk=20/348
ttt: chunk=30/348
ttt: chunk=40/348
ttt: chunk=50/348
ttt: chunk=60/348
ttt: chunk=70/348
ttt: chunk=80/348
ttt: chunk=90/348
ttt: chunk=100/348
ttt: chunk=110/348
ttt: chunk=120/348
ttt: chunk=130/348
ttt: chunk=140/348
ttt: chunk=150/348
ttt: chunk=160/348
ttt: chunk=170/348
ttt: chunk=180/348
ttt: chunk=190/348
ttt: chunk=200/348
ttt: chunk=210/348
ttt: chunk=220/348
ttt: chunk=230/348
ttt: chunk=240/348
ttt: chunk=250/348
ttt: chunk=260/348
ttt: chunk=270/348
ttt: chunk=280/348
ttt: chunk=290/348
ttt: chunk=300/348
ttt: chunk=310/348
ttt: chunk=320/348
ttt: chunk=330/348
ttt: chunk=340/348
final_ttt val_bpb:1.1266 eval_time:245233ms
final_ttt_exact val_bpb:1.12662548
3 changes: 2 additions & 1 deletion requirements.txt
@@ -7,4 +7,5 @@ setuptools
typing-extensions==4.15.0
datasets
tiktoken
sentencepiece
sentencepiece
brotli