records/track_10min_16mb/2026-03-31_KitchenSinkV3/README.md
# Record: Window Attention + Mixed Seq_Len Training

**val_bpb: 1.1108** (5-seed mean, std 0.0013) | **1.8755 nats** | **~15.73 MB** | 8xH100 SXM, 600s | No TTT

I started from [PR #1130](https://github.com/openai/parameter-golf/pull/1130) (KitchenSinkV2 Improved), which added split early/late LR banks, MiLe margin loss, cache+backout residual, residual lambdas, bigger bigram/VE, and FA3 on top of the PR #549 stack. On top of that, I ported the fused Triton MLP from [PR #1072](https://github.com/openai/parameter-golf/pull/1072) and the sigmoid-gated skips + brotli+byte-shuffle compression from [PR #1089](https://github.com/openai/parameter-golf/pull/1089). I also increased the layer count to 12 and tuned qk_gain to 2.5.

The two main contributions of this submission are window attention and mixed seq_len training, described below.

## Results (8xH100 80GB SXM, 600s, no TTT)

| Seed | Steps | ms/step | Post-EMA BPB | **Sliding BPB** | val_loss (nats) | Artifact |
|------|-------|---------|--------------|-----------------|-----------------|----------|
| 2 | 8,428 | 69.6 | 1.1250 | **1.1094** | 1.8731 | 15,726,762 |
| 1337 | 8,428 | 69.6 | 1.1250 | **1.1101** | 1.8742 | 15,721,698 |
| 42 | 8,428 | 69.6 | 1.1250 | **1.1103** | 1.8746 | 15,725,995 |
| 7 | 8,428 | 69.6 | 1.1250 | **1.1119** | 1.8773 | 15,723,346 |
| 22 | 8,428 | 69.6 | 1.1250 | **1.1126** | 1.8785 | 15,720,902 |
| **Mean** | | | | **1.1108** | **1.8755** | **15,723,741** |

Current merged SOTA ([2026-03-25 AR Self-Gen GPTQ + XSA-all + BigramHash 3072x112](https://github.com/openai/parameter-golf/blob/main/records/track_10min_16mb/2026-03-25_ValCalib_GPTQ_XSA_BigramHash3072/README.md)): **1.11473 BPB**.
Delta vs current merged SOTA: **-0.0039 BPB** (**-0.0066 nats**).

## Window attention

Instead of full causal attention on every layer, layers 2, 4, 6, 8, and 10 use a sliding window of 512 tokens via Flash Attention 3's `window_size` parameter. The remaining layers (0, 1, 3, 5, 7, 9, 11) keep full attention.
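The per-layer split can be sketched with FA3's `(left, right)` window convention, where `(-1, -1)` disables the window and `right=0` preserves causality. The helper name is hypothetical, and whether the window bound counts the current token depends on the kernel convention:

```python
# Layers that use a sliding window; all others keep full causal attention.
WINDOWED_LAYERS = {2, 4, 6, 8, 10}

def window_size_for(layer: int, window: int = 512) -> tuple[int, int]:
    """window_size tuple to pass to FA3 for a given layer.

    (window, 0): each query sees at most `window` previous positions.
    (-1, -1):    unrestricted (full causal) attention.
    """
    return (window, 0) if layer in WINDOWED_LAYERS else (-1, -1)

per_layer = [window_size_for(i) for i in range(12)]
```

Only 5 of the 12 entries carry a finite window, matching the `window_attn:size=512 layers:[2, 4, 6, 8, 10]` line in the logs.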

The motivation was to enable training at longer sequence lengths without proportionally increasing compute. Full quadratic attention at seq_len=6144 is expensive, but with window attention on 5 of 12 layers, those layers run in O(n * w) instead of O(n^2), cutting the per-step cost significantly. The layers with full attention still give the model access to the full context.
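Back-of-envelope, causal full attention computes ~n²/2 score entries per layer while a w-token window computes ~n·w. A small attention-only cost model (ignoring MLPs and kernel constants, so it overstates the end-to-end speedup):

```python
def attn_scores(seq_len: int, n_layers: int = 12,
                windowed: int = 5, window: int = 512) -> float:
    """Approximate attention-score entries per step, summed over layers."""
    full = seq_len * seq_len / 2   # causal full attention, ~n^2/2
    win = seq_len * window         # sliding window, ~n*w
    return (n_layers - windowed) * full + windowed * win

def attn_saving(seq_len: int) -> float:
    """Fraction of attention compute saved vs. all layers full."""
    return 1 - attn_scores(seq_len) / attn_scores(seq_len, windowed=0)
```

At seq_len=6144 this predicts roughly 35% less attention compute; the observed 21% end-to-end speedup is smaller because MLPs and the seven full-attention layers still dominate much of the step.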

I swept several configurations: window sizes (256, 512, 1024), which layers to window (sparse, dense, even-indexed), and how many layers to window. A 512-token window on the even-indexed layers was the sweet spot: enough layers windowed to get the speedup, and enough full-attention layers to preserve long-range modeling.

At seq_len=2048, where full quadratic attention is still cheap, windowed attention adds a small overhead (~2-3%). The benefit kicks in at longer sequences: 15% faster at 4096, 21% at 6144, 25% at 8192.

## Mixed seq_len training

Different GPUs train with different sequence lengths within the same step. In the final configuration, 5 GPUs train at seq_len=2048 and 3 GPUs train at seq_len=6144. The number of sequences per GPU is set so that the total ms per step stays roughly constant.

The idea came from noticing that the sliding-window eval (which uses long sequences) gave substantially better scores than the standard 2048-token eval, but training at long sequence lengths was slow. By having most GPUs train cheaply at 2048 and a few GPUs see long context at 6144, the model gets the best of both: high step throughput from the short-sequence GPUs and long-range learning from the long-sequence ones.
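The per-GPU layout can be written out directly from the run logs (the variable names here are illustrative, not the repo's):

```python
# 5 GPUs at short sequences, 3 at long. Sequence counts per GPU are
# chosen so per-GPU step time stays roughly balanced: long sequences
# cost more per token, so those GPUs carry fewer tokens.
seq_len_per_gpu = [2048] * 5 + [6144] * 3
n_seqs_per_gpu  = [36] * 5 + [10] * 3

tokens_per_gpu = [s * n for s, n in zip(seq_len_per_gpu, n_seqs_per_gpu)]
total_tokens = sum(tokens_per_gpu)  # matches the log line "total=552960"
```

Each short-sequence GPU processes 73,728 tokens per step and each long-sequence GPU 61,440, for 552,960 tokens per step overall.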

I ran an extensive sweep of seq_len combinations. Some findings:

- **3x2048 + 1x6144** (eval at 6144) gave the best int6 roundtrip BPB (**1.1292**) in 4-GPU experiments, beating both pure 4x2048 (1.1417) and pure 4x6144 (1.1360)
- Having at least one GPU on a long sequence (4096+) was critical for good quantized performance
- More short-sequence GPUs = more steps in the same wallclock, which helps training loss
- More long-sequence GPUs = better post-EMA loss, but fewer steps and worse quantization
- 8192 was too slow to be worthwhile — the step-time penalty outweighed the context benefit

For the final 8-GPU submission, I used 5x2048 + 3x6144, which balances throughput and long-context exposure.

## Other changes

- **12 layers** (up from 11) with split early/late LR banks
- **Sigmoid-gated skip connections** — `x += sigmoid(gate) * skip` replaces learned scalar skip weights
- **Fused Triton MLP** (PR #1072) — LeakyReLU(0.5)-squared fused with matmuls
- **Brotli + byte-shuffle compression** (PR #1089) — better compression of quantized weights
- **Bigram hash 5120**, VE dim 128, qk_gain 2.5
- **Eval**: sliding window, seq_len=6144, stride=128
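The sliding-window eval in the last bullet can be sketched as a span schedule: after the first window, each window covers `seq_len` tokens but scores only its newest `stride` tokens, with the rest as left context. The exact scoring convention is my assumption; the PR #50 eval may differ in details:

```python
def sliding_window_spans(n_tokens: int, seq_len: int = 6144, stride: int = 128):
    """Return (start, end, score_start) spans covering n_tokens.

    Tokens in [score_start, end) are scored; [start, score_start) serve
    as context only. Every token is scored exactly once.
    """
    spans = []
    score_start = 0
    while score_start < n_tokens:
        if score_start == 0:
            start, end = 0, min(seq_len, n_tokens)  # first window scores all
        else:
            end = min(score_start + stride, n_tokens)
            start = max(end - seq_len, 0)           # maximal left context
        spans.append((start, end, score_start))
        score_start = end
    return spans
```

Almost every scored token gets ~6K tokens of left context instead of at most 2048, which is why the sliding BPB beats the plain fixed-window eval.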

## Artifact size (worst-case, seed 2)

| Component | Bytes |
|-----------|-------|
| Model (int6+brotli) | 15,692,661 |
| Code | 34,101 |
| **Total** | **15,726,762** |

Under the 16,000,000 byte limit.
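The byte-shuffle step before brotli (the PR #1089 idea) groups the k-th byte of every quantized item together, so the strongly correlated high/low bytes form long similar runs that an LZ-family coder compresses much better. A minimal pure-Python sketch; the repo's int6 bit-packing is more involved than this:

```python
def byte_shuffle(buf: bytes, itemsize: int) -> bytes:
    """Transpose fixed-size items from item-major to byte-plane-major order."""
    n = len(buf) // itemsize
    return bytes(buf[i * itemsize + k] for k in range(itemsize) for i in range(n))

def byte_unshuffle(buf: bytes, itemsize: int) -> bytes:
    """Inverse of byte_shuffle: restore item-major order."""
    n = len(buf) // itemsize
    out = bytearray(len(buf))
    for k in range(itemsize):
        for i in range(n):
            out[i * itemsize + k] = buf[k * n + i]
    return bytes(out)

# Compression path: brotli.compress(byte_shuffle(weight_bytes, itemsize)),
# then byte_unshuffle(brotli.decompress(blob), itemsize) at load time.
```

For little-endian int16 data with small magnitudes, the shuffle turns every second byte into a run of zeros, which is the effect being exploited here.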

## Acknowledgments

This submission builds on many contributions from the parameter-golf community:

- **Baseline** (modded-nanogpt, @KellerJordan et al.) — Muon optimizer, relu², U-Net skips, softcap, RoPE, Q-gain, ResidMix
- [modded-nanogpt PR #140](https://github.com/KellerJordan/modded-nanogpt/pull/140) (@ClassicLarry / @snimu) — backout residual
- **PR #50** (@mattqlf) — sliding window eval
- **PR #64** (@yesbhautik) — GPTQ-lite (clip percentile search)
- **PR #89 / #95** (@vmfunc / @MatoTeziTanka) — EMA / SWA
- **PR #162** (@raahilshah) — BigramHash
- **PR #287** (@jfprincz) — XSA (cross-head subtracted attention; [arXiv:2603.09078](https://arxiv.org/abs/2603.09078))
- **PR #315** (@jfprincz) — partial RoPE, layerwise LN scale
- **PR #374** (@unnir) — value embeddings
- **PR #399** (@abaybektursun) — parallel Muon with parameter banking
- **PR #493** (@parinzee) — LeakyReLU(0.5)²
- **PR #535** (@raahilshah) — full Hessian GPTQ
- **PR #549** (@abaybektursun) — banked weight matrices, SmearGate
- **PR #726** (@DeepReinforce) — coprime-stride multi-shard data loader
- **PR #1072** (@vimeto) — fused Triton LeakyReLU-squared MLP kernel
- **PR #1089** (@mikeapedia) — sigmoid-gated skip connections, byte-shuffle + brotli compression
- Flash Attention 3 with `window_size` for efficient window attention

## Reproducibility

The main training runs used the following command:

```bash
SEED=$SEED \
MATRIX_LR=0.024 MATRIX_LR_LATE=0.019 \
SCALAR_LR=0.020 SCALAR_LR_LATE=0.038 \
TIED_EMBED_LR=0.022 \
MUON_MOMENTUM=0.985 WARMDOWN_ITERS=4000 \
TRAIN_BATCH_TOKENS=589824 \
NUM_LAYERS=12 BIGRAM_VOCAB_SIZE=5120 VE_DIM=128 \
WINDOW_SIZE=512 WINDOW_ATTN_LAYERS=2,4,6,8,10 \
LOCAL_SEQS_PER_GPU=36,36,36,36,36,10,10,10 \
SEQS_PER_GPU=2048,2048,2048,2048,2048,6144,6144,6144 \
MAX_WALLCLOCK_SECONDS=600 \
torchrun --standalone --nproc_per_node=8 train_gpt.py
```

`brotli` must be installed for the final artifact compression path, and Flash Attention 3 (`flash_attn_interface`) is required for the windowed attention layers.
W0331 16:22:12.901000 86416 torch/distributed/run.py:851]
W0331 16:22:12.901000 86416 torch/distributed/run.py:851] *****************************************
W0331 16:22:12.901000 86416 torch/distributed/run.py:851] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.
W0331 16:22:12.901000 86416 torch/distributed/run.py:851] *****************************************
logs/973f1c0c-853b-47d1-ab98-f10d9f5c1909.txt
mixed_seq_len: GPU0=2048x36=73728tok GPU1=2048x36=73728tok GPU2=2048x36=73728tok GPU3=2048x36=73728tok GPU4=2048x36=73728tok GPU5=6144x10=61440tok GPU6=6144x10=61440tok GPU7=6144x10=61440tok total=552960
val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model
train_loader:dataset:fineweb10B_sp1024 train_shards:80
val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62017536
model_params:29751910
mtp_num_heads:0 mtp_loss_weight:0.2 mtp_params:0
XSA:last_7 active_layers:[5, 6, 7, 8, 9, 10, 11]
window_attn:size=512 layers:[2, 4, 6, 8, 10]
world_size:8 grad_accum_steps:1
sdp_backends:cudnn=False flash=True mem_efficient=False math=False
attention_mode:gqa num_heads:8 num_kv_heads:4
tie_embeddings:True embed_lr:0.022 head_lr:0.0 matrix_lr:0.024 matrix_lr_late:0.019 scalar_lr:0.02 scalar_lr_late:0.038 leaky_slope:0.5
train_batch_tokens:589824 train_seq_len:2048 iterations:20000 warmup_steps:20 max_wallclock_seconds:600.000
seed:1337
gptq:reserving 14000ms from training budget, effective=586000ms
warmup_step:1/20
warmup_step:2/20
warmup_step:3/20
warmup_step:4/20
warmup_step:5/20
warmup_step:6/20
warmup_step:7/20
warmup_step:8/20
warmup_step:9/20
warmup_step:10/20
warmup_step:11/20
warmup_step:12/20
warmup_step:13/20
warmup_step:14/20
warmup_step:15/20
warmup_step:16/20
warmup_step:17/20
warmup_step:18/20
warmup_step:19/20
warmup_step:20/20
step:0/20000 val_loss:6.9292 val_bpb:4.1039 train_time:0ms step_avg:0.01ms
step:1/20000 train_loss:6.9236 train_time:115ms step_avg:114.85ms
step:2/20000 train_loss:6.8930 train_time:163ms step_avg:81.41ms
step:3/20000 train_loss:6.5216 train_time:231ms step_avg:77.11ms
step:4/20000 train_loss:6.6048 train_time:300ms step_avg:74.96ms
step:5/20000 train_loss:6.6792 train_time:368ms step_avg:73.57ms
step:6/20000 train_loss:6.4394 train_time:436ms step_avg:72.66ms
step:7/20000 train_loss:6.1488 train_time:504ms step_avg:72.05ms
step:8/20000 train_loss:6.0559 train_time:573ms step_avg:71.57ms
step:9/20000 train_loss:5.8979 train_time:641ms step_avg:71.23ms
step:10/20000 train_loss:5.6586 train_time:709ms step_avg:70.91ms
step:500/20000 train_loss:2.2787 train_time:34572ms step_avg:69.14ms
step:1000/20000 train_loss:2.1944 train_time:69234ms step_avg:69.23ms
step:1500/20000 train_loss:2.1041 train_time:103918ms step_avg:69.28ms
step:2000/20000 train_loss:2.0576 train_time:138646ms step_avg:69.32ms
step:2500/20000 train_loss:1.9723 train_time:173384ms step_avg:69.35ms
step:3000/20000 train_loss:1.8794 train_time:208128ms step_avg:69.38ms
step:3500/20000 train_loss:1.8804 train_time:242869ms step_avg:69.39ms
step:4000/20000 train_loss:1.9143 train_time:277604ms step_avg:69.40ms
step:4000/20000 val_loss:2.0613 val_bpb:1.2208 train_time:277627ms step_avg:69.41ms
step:4500/20000 train_loss:1.8491 train_time:312344ms step_avg:69.41ms
step:5000/20000 train_loss:2.0239 train_time:347087ms step_avg:69.42ms
step:5500/20000 train_loss:1.9999 train_time:381836ms step_avg:69.42ms
step:6000/20000 train_loss:1.9671 train_time:416571ms step_avg:69.43ms
step:6500/20000 train_loss:1.9624 train_time:451297ms step_avg:69.43ms
step:7000/20000 train_loss:1.9521 train_time:486017ms step_avg:69.43ms
step:7500/20000 train_loss:1.9364 train_time:520744ms step_avg:69.43ms
swa:start step:7650
late_qat:enabled step:7836 scale:0.1499
step:8000/20000 train_loss:1.9445 train_time:555854ms step_avg:69.48ms
step:8000/20000 val_loss:1.9169 val_bpb:1.1353 train_time:555941ms step_avg:69.49ms
step:8428/20000 val_loss:1.9008 val_bpb:1.1258 train_time:586064ms step_avg:69.54ms
stopping_early: wallclock_cap train_time:586064ms step:8428/20000
peak memory allocated: 18548 MiB reserved: 18806 MiB
ema:applying EMA weights
DIAGNOSTIC post_ema val_loss:1.8996 val_bpb:1.1250 eval_time:2497ms
Serialized model: 116408386 bytes
Code size: 34101 bytes
gptq:building non-banked model for Hessian collection...
gptq:calibrating with 256 batches (train data)...
gptq:collected hessians for 74 layers (train data)
Serialized model int6+brotli: 15687597 bytes
Total submission size int6+brotli: 15721698 bytes
final_int6_roundtrip val_loss:1.8906 val_bpb:1.1197 eval_time:44320ms
final_int6_roundtrip_exact val_loss:1.89055057 val_bpb:1.11968671
final_int6_sliding_window val_loss:1.8742 val_bpb:1.1101 stride:128 eval_time:148032ms
final_int6_sliding_window_exact val_loss:1.87423223 val_bpb:1.11006071
records/track_10min_16mb/2026-03-31_KitchenSinkV3/run_v3_seed2.log
W0331 17:28:58.848000 112208 torch/distributed/run.py:851]
W0331 17:28:58.848000 112208 torch/distributed/run.py:851] *****************************************
W0331 17:28:58.848000 112208 torch/distributed/run.py:851] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.
W0331 17:28:58.848000 112208 torch/distributed/run.py:851] *****************************************
logs/cab54c4a-5bc2-49b1-a8e1-e5cda436759a.txt
mixed_seq_len: GPU0=2048x36=73728tok GPU1=2048x36=73728tok GPU2=2048x36=73728tok GPU3=2048x36=73728tok GPU4=2048x36=73728tok GPU5=6144x10=61440tok GPU6=6144x10=61440tok GPU7=6144x10=61440tok total=552960
val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model
train_loader:dataset:fineweb10B_sp1024 train_shards:80
val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62017536
model_params:29751910
mtp_num_heads:0 mtp_loss_weight:0.2 mtp_params:0
XSA:last_7 active_layers:[5, 6, 7, 8, 9, 10, 11]
window_attn:size=512 layers:[2, 4, 6, 8, 10]
world_size:8 grad_accum_steps:1
sdp_backends:cudnn=False flash=True mem_efficient=False math=False
attention_mode:gqa num_heads:8 num_kv_heads:4
tie_embeddings:True embed_lr:0.022 head_lr:0.0 matrix_lr:0.024 matrix_lr_late:0.019 scalar_lr:0.02 scalar_lr_late:0.038 leaky_slope:0.5
train_batch_tokens:589824 train_seq_len:2048 iterations:20000 warmup_steps:20 max_wallclock_seconds:600.000
seed:2
gptq:reserving 14000ms from training budget, effective=586000ms
warmup_step:1/20
warmup_step:2/20
warmup_step:3/20
warmup_step:4/20
warmup_step:5/20
warmup_step:6/20
warmup_step:7/20
warmup_step:8/20
warmup_step:9/20
warmup_step:10/20
warmup_step:11/20
warmup_step:12/20
warmup_step:13/20
warmup_step:14/20
warmup_step:15/20
warmup_step:16/20
warmup_step:17/20
warmup_step:18/20
warmup_step:19/20
warmup_step:20/20
step:0/20000 val_loss:6.9295 val_bpb:4.1040 train_time:0ms step_avg:0.01ms
step:1/20000 train_loss:6.9234 train_time:117ms step_avg:116.89ms
step:2/20000 train_loss:6.8490 train_time:164ms step_avg:82.10ms
step:3/20000 train_loss:6.5213 train_time:232ms step_avg:77.25ms
step:4/20000 train_loss:6.5410 train_time:301ms step_avg:75.14ms
step:5/20000 train_loss:6.6556 train_time:369ms step_avg:73.79ms
step:6/20000 train_loss:6.5937 train_time:438ms step_avg:73.06ms
step:7/20000 train_loss:6.2665 train_time:507ms step_avg:72.41ms
step:8/20000 train_loss:6.0400 train_time:575ms step_avg:71.86ms
step:9/20000 train_loss:5.8710 train_time:644ms step_avg:71.52ms
step:10/20000 train_loss:5.6984 train_time:713ms step_avg:71.27ms
step:500/20000 train_loss:2.2753 train_time:34620ms step_avg:69.24ms
step:1000/20000 train_loss:2.1915 train_time:69284ms step_avg:69.28ms
step:1500/20000 train_loss:2.1015 train_time:103993ms step_avg:69.33ms
step:2000/20000 train_loss:2.0562 train_time:138747ms step_avg:69.37ms
step:2500/20000 train_loss:1.9619 train_time:173524ms step_avg:69.41ms
step:3000/20000 train_loss:1.8785 train_time:208302ms step_avg:69.43ms
step:3500/20000 train_loss:1.8760 train_time:243100ms step_avg:69.46ms
step:4000/20000 train_loss:1.9118 train_time:277882ms step_avg:69.47ms
step:4000/20000 val_loss:2.0583 val_bpb:1.2190 train_time:277906ms step_avg:69.48ms
step:4500/20000 train_loss:1.8503 train_time:312660ms step_avg:69.48ms
step:5000/20000 train_loss:2.0179 train_time:347440ms step_avg:69.49ms
step:5500/20000 train_loss:1.9952 train_time:382224ms step_avg:69.50ms
step:6000/20000 train_loss:1.9640 train_time:416999ms step_avg:69.50ms
step:6500/20000 train_loss:1.9601 train_time:451771ms step_avg:69.50ms
step:7000/20000 train_loss:1.9507 train_time:486532ms step_avg:69.50ms
step:7500/20000 train_loss:1.9296 train_time:521298ms step_avg:69.51ms
swa:start step:7650
late_qat:enabled step:7827 scale:0.1498
step:8000/20000 train_loss:1.9427 train_time:556434ms step_avg:69.55ms
step:8000/20000 val_loss:1.9151 val_bpb:1.1342 train_time:556516ms step_avg:69.56ms
step:8419/20000 val_loss:1.8995 val_bpb:1.1250 train_time:586049ms step_avg:69.61ms
stopping_early: wallclock_cap train_time:586049ms step:8419/20000
peak memory allocated: 18548 MiB reserved: 18800 MiB
ema:applying EMA weights
DIAGNOSTIC post_ema val_loss:1.8983 val_bpb:1.1243 eval_time:2503ms
Serialized model: 116408386 bytes
Code size: 34101 bytes
gptq:building non-banked model for Hessian collection...
gptq:calibrating with 256 batches (train data)...
gptq:collected hessians for 74 layers (train data)
Serialized model int6+brotli: 15692661 bytes
Total submission size int6+brotli: 15726762 bytes
final_int6_roundtrip val_loss:1.8894 val_bpb:1.1190 eval_time:7361ms
final_int6_roundtrip_exact val_loss:1.88938736 val_bpb:1.11899779
final_int6_sliding_window val_loss:1.8731 val_bpb:1.1094 stride:128 eval_time:123593ms
final_int6_sliding_window_exact val_loss:1.87306252 val_bpb:1.10936792