# Pre-Quant TTT + ETLB: Eval-Time Logit Bias for Neural Language Model Compression

## Summary

**3-seed mean BPB: 1.0898 (std: 0.0008)**

This submission introduces **Eval-Time Logit Bias (ETLB)**, a novel eval-time adaptation technique that optimizes a warm-started vocabulary bias vector during sliding-window evaluation. Combined with pre-quantization test-time training (Pre-Quant TTT), it achieves a new best pure neural BPB on the 10-minute 16MB track.

Built on PR #1285's architecture (MuonEq-R + Depth Recurrence + All-Int6 GPTQ).

## Results

| Seed | Sliding BPB | ETLB BPB | Artifact Size | Fits? |
|------|------------|----------|---------------|-------|
| 1337 | 1.0916 | **1.0897** | 16,084,685 bytes | ✅ |
| 42 | 1.0926 | **1.0906** | 16,092,287 bytes | ✅ |
| 2025 | 1.0908 | **1.0891** | 16,087,467 bytes | ✅ |
| **Mean** | 1.0917 | **1.0898** | | ✅ |
| **Std** | 0.0009 | **0.0008** | | |

Hardware: 8×H100 SXM; ~5,500 training steps within the 600 s budget at ~7.3–7.9M tok/s.
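
For reference, the BPB figures above follow the standard conversion from per-token cross-entropy in nats. The tokens-per-byte ratio used below (≈0.2972) is inferred from the attached log's pre-quantization eval (2.5652 nats ↔ 1.0997 BPB) and is an assumption about the held-out set, not a value taken from the submission's code:

```python
import math

def nats_to_bpb(nats_per_token: float, tokens_per_byte: float) -> float:
    """bits-per-byte = (nats per token) / ln 2 * (tokens per byte)."""
    return nats_per_token / math.log(2) * tokens_per_byte

# Ratio ~0.2972 tokens/byte is inferred from the log, not from the code.
bpb = nats_to_bpb(2.56523285, 0.297156)
```

With that ratio, the logged pre-quant val_loss of 2.5652 nats reproduces the logged 1.0997 BPB to four decimal places.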

## Novel Techniques

### 1. Pre-Quantization Test-Time Training (Pre-Quant TTT)

Adapts the full-precision EMA model weights on validation data **before** GPTQ quantization. The adapted weights are baked into the artifact — no eval-time overhead.

- **Freeze:** First 9 of 11 blocks frozen, last 2 blocks adapted
- **Optimizer:** AdamW, lr=0.0005
- **Data:** Validation chunks (32768 tokens), 1 epoch
- **Trainable params:** 5.77M / 34.4M total
- **Time:** ~112s (fits within the 10-minute budget)
- **Score-first compliant:** Each chunk is scored under `inference_mode()` before being used for training
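
The loop above can be sketched as follows. This is a minimal illustration, assuming a model that exposes its transformer blocks as a `blocks` ModuleList; the names (`pre_quant_ttt`, `val_chunks`) are illustrative stand-ins, not the submission's actual functions:

```python
import torch
import torch.nn.functional as F

def pre_quant_ttt(model, val_chunks, freeze_blocks=9, lr=5e-4):
    """Adapt the last blocks on validation chunks, scoring each chunk
    under inference_mode (score-first) before training on it."""
    # Freeze the first `freeze_blocks` transformer blocks.
    for i, block in enumerate(model.blocks):
        for p in block.parameters():
            p.requires_grad = i >= freeze_blocks

    trainable = [p for p in model.parameters() if p.requires_grad]
    opt = torch.optim.AdamW(trainable, lr=lr)

    total, n = 0.0, 0
    for chunk in val_chunks:          # each chunk: 1-D LongTensor of tokens
        x, y = chunk[:-1], chunk[1:]
        # Score-first: record the loss before any adaptation on this chunk.
        with torch.inference_mode():
            total += F.cross_entropy(model(x[None]).squeeze(0), y).item()
        n += 1
        # Then take one AdamW step on the same chunk.
        loss = F.cross_entropy(model(x[None]).squeeze(0), y)
        opt.zero_grad()
        loss.backward()
        opt.step()
    return total / max(n, 1)
```

One epoch over all chunks, with only the unfrozen parameters in the optimizer, matches the budget figures above; the adapted weights are then handed to GPTQ.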

### 2. Eval-Time Logit Bias (ETLB) — *Novel*

During sliding window evaluation, ETLB optimizes a bias vector `b ∈ ℝ^vocab` added to output logits. The bias captures document-level token frequency patterns and adapts the model's output distribution to the local context.

**Algorithm:**
```
Initialize b = zeros(vocab_size)
For each sliding window:
    1. Forward pass → logits (frozen model, no gradient)
    2. Split window into context tokens (already scored) and stride tokens (to be scored)
    3. Optimize b on context tokens via SGD (5 steps, lr=0.05)
       - Loss: cross-entropy(logits[context] + b, targets[context])
    4. Clip b to [-3.0, 3.0]
    5. Score stride tokens using logits[stride] + b
    6. Warm-start: carry b into next window
```
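
The same algorithm as a framework-free sketch, operating on plain Python lists of per-token logits (the actual implementation presumably uses GPU tensors). The gradient of cross-entropy with respect to the bias is simply `softmax(logits + b) − onehot(target)`, averaged over context tokens:

```python
import math

def softmax(xs):
    m = max(xs)
    es = [math.exp(x - m) for x in xs]
    s = sum(es)
    return [e / s for e in es]

def etlb_window(bias, ctx_logits, ctx_targets, lr=0.05, steps=5, clip=3.0):
    """One window: SGD on the bias using already-scored context tokens.
    Returns the updated bias, warm-started into the next window."""
    V = len(bias)
    for _ in range(steps):
        grad = [0.0] * V
        for logits, t in zip(ctx_logits, ctx_targets):
            p = softmax([l + b for l, b in zip(logits, bias)])
            for v in range(V):
                # d(cross-entropy)/d(bias_v) = p_v - 1[v == target]
                grad[v] += p[v] - (1.0 if v == t else 0.0)
        n = max(len(ctx_targets), 1)
        bias = [max(-clip, min(clip, b - lr * g / n))
                for b, g in zip(bias, grad)]
    return bias

def score(bias, logits, target):
    """Negative log-likelihood of a stride token under logits + bias."""
    return -math.log(softmax([l + b for l, b in zip(logits, bias)])[target])
```

Stride tokens are scored with `score(...)` using the bias trained only on tokens that were themselves already scored, which is what keeps the procedure strictly causal.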

**Key properties:**
- **Strictly causal:** Only trains on already-scored context tokens, applies to new stride tokens
- **No model weight modification:** Operates purely in logit space
- **No hidden-state leakage:** Unlike SLOT's delta in hidden space, ETLB adds a bias after the LM head
- **Warm-started across windows:** Bias carries forward, learning document-level token preferences
- **Lightweight:** Only `vocab_size` (4096) parameters, SGD optimizer, 5 steps per window

**Improvement:** a consistent ~0.002 BPB gain across all 3 seeds (−0.0019, −0.0020, −0.0017)

### How ETLB differs from prior work

| Method | Space | Cross-window | Modifies weights | Legality |
|--------|-------|-------------|-----------------|----------|
| SLOT (Hu et al.) | Hidden states | Shared delta (leak) | No | ❌ Flagged |
| Dynamic Eval (Krause 2019) | All weights | Yes | Yes | ✅ Legal |
| PR #1318 L-BFGS SLOT | Logits | Yes | No | ✅ Legal |
| **ETLB (ours)** | **Logits** | **Warm-start only** | **No** | **✅ Legal** |

ETLB is most similar to PR #1318's approach but simpler: SGD instead of L-BFGS, with explicit clipping to prevent drift.

## Architecture (from PR #1285)

- Vocab: 4096 (sp4096 BPE tokenizer from sproos/parameter-golf-tokenizers)
- Layers: 11 physical + depth recurrence (layers 4,5 repeated = 13 virtual)
- Model dim: 512, MLP 4× with LeakyReLU(0.5)²
- Attention: GQA 8H/4KV, XSA all 11 layers, Partial RoPE (16 dims)
- Value Embedding: 128d, layers 9,10
- Skip gates: Sigmoid-gated residual connections
- Optimizer: MuonEq-R, WD=0.090
- QK_GAIN_INIT: 5.0
- EMA: 0.997
- Quantization: Full Hessian GPTQ int6, all 66 layers
- Compression: Brotli-11 + byte-shuffle
- Code: LZMA2 minification wrapper
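
The byte-shuffle step before compression can be sketched as below, assuming the serialized weights are fixed-width items; stdlib `lzma` stands in for Brotli-11 to keep the example dependency-free. Grouping like-significance bytes across all weights produces the long near-constant runs the entropy coder exploits:

```python
import lzma

def byte_shuffle(raw: bytes, itemsize: int) -> bytes:
    """Transpose so byte 0 of every item comes first, then byte 1, etc.
    Like-significance bytes are similar across weights and compress better."""
    assert len(raw) % itemsize == 0
    return bytes(raw[j] for i in range(itemsize)
                 for j in range(i, len(raw), itemsize))

def byte_unshuffle(shuffled: bytes, itemsize: int) -> bytes:
    """Exact inverse of byte_shuffle."""
    n = len(shuffled) // itemsize
    out = bytearray(len(shuffled))
    k = 0
    for i in range(itemsize):
        for j in range(n):
            out[j * itemsize + i] = shuffled[k]
            k += 1
    return bytes(out)

# Round-trip through the shuffle plus a generic compressor:
data = b"".join(v.to_bytes(2, "little") for v in range(1000))
blob = lzma.compress(byte_shuffle(data, itemsize=2))
assert byte_unshuffle(lzma.decompress(blob), itemsize=2) == data
```

The shuffle is lossless and costs only a transpose at decompression time; the compressor choice (Brotli here, LZMA in the sketch) is orthogonal to it.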

## Hyperparameters

### Training
```
SEED={1337,42,2025}
MUON_WD=0.090
EMBED_WD=0.090
QK_GAIN_INIT=5.0
```

### Pre-Quant TTT
```
PRE_QUANT_TTT=1
PRE_QUANT_TTT_LR=0.0005
PRE_QUANT_TTT_EPOCHS=1
PRE_QUANT_TTT_FREEZE=9
PRE_QUANT_TTT_CHUNK=32768
```

### ETLB
```
ETLB_ENABLED=1
ETLB_LR=0.05
ETLB_STEPS=5
ETLB_CLIP=3.0
```

## Reproduction

```bash
pip install brotli
SEED=1337 PRE_QUANT_TTT=1 PRE_QUANT_TTT_LR=0.0005 PRE_QUANT_TTT_EPOCHS=1 \
PRE_QUANT_TTT_FREEZE=9 MUON_WD=0.090 EMBED_WD=0.090 QK_GAIN_INIT=5.0 \
ETLB_ENABLED=1 ETLB_LR=0.05 ETLB_STEPS=5 ETLB_CLIP=3.0 \
torchrun --standalone --nproc_per_node=8 train_gpt.py
```

## Ablation

| Component | BPB (seed 1337) | Delta |
|-----------|----------------|-------|
| Base (no TTT, no ETLB) | ~1.0960 | — |
| + Pre-Quant TTT | 1.0916 | -0.0044 |
| + ETLB | **1.0897** | -0.0019 |
| **Total improvement** | | **-0.0063** |

## Acknowledgments

- PR #1285 (@dexhunter) for the base architecture
- PR #549 (@abaybektursun) for TTT/sliding window framework
- sproos for the official sp4096 tokenizer
- SLOT paper (Hu et al., 2025) for inspiration on delta optimization
- Dynamic Evaluation (Krause et al., 2019) for the concept of eval-time adaptation
## Training log (seed 1337)
W0405 16:53:03.530000 66909 torch/distributed/run.py:803]
W0405 16:53:03.530000 66909 torch/distributed/run.py:803] *****************************************
W0405 16:53:03.530000 66909 torch/distributed/run.py:803] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.
W0405 16:53:03.530000 66909 torch/distributed/run.py:803] *****************************************
Hyperparameters:
adam_eps: 1e-08
adam_wd: 0.02
beta1: 0.9
beta2: 0.95
compressor: brotli
data_dir: ./data/
datasets_dir: ./data/datasets/fineweb10B_sp4096
disable_layer0_attn: False
distributed: True
ema_decay: 0.997
embed_lr: 0.6
embed_wd: 0.09
embedding_dim: 512
etlb_clip: 3.0
etlb_enabled: True
etlb_lr: 0.05
etlb_steps: 5
eval_seq_len: 2048
eval_stride: 64
gptq_calibration_batches: 64
gptq_enabled: True
gptq_reserve_seconds: 10.0
grad_accum_steps: 1
grad_clip_norm: 0.3
head_lr: 0.008
is_main_process: True
iterations: 20000
ln_scale: True
local_rank: 0
logfile: logs/4823c3f7-5341-4cdc-8d13-81777149002a.txt
logit_softcap: 30.0
matrix_lr: 0.02
max_wallclock_seconds: 600.0
min_lr: 0.0
mixed_quant: False
mlp_mult: 4.0
model_dim: 512
model_path: final_model.pt
muon_backend_steps: 5
muon_beta2: 0.95
muon_momentum: 0.99
muon_momentum_warmup_start: 0.92
muon_momentum_warmup_steps: 1500
muon_wd: 0.09
n_int6_layers: 32
num_heads: 8
num_kv_heads: 4
num_layers: 11
parallel_residual: False
parallel_start_layer: 7
parallel_start_layer_is_physical: True
pre_quant_ttt_chunk_tokens: 32768
pre_quant_ttt_enabled: True
pre_quant_ttt_epochs: 1
pre_quant_ttt_freeze_blocks: 9
pre_quant_ttt_lr: 0.0005
qk_gain_init: 5.0
quantized_model_path: final_model.int6.ptz
rank: 0
recur_layers_str: 4,5
recur_start_step: 3000
recur_warmup_steps: 20
repeat_untie_mlp: none
repeat_untie_mlp_layers:
rope_base: 10000.0
rope_dims: 16
rope_train_seq_len: 2048
run_id: 4823c3f7-5341-4cdc-8d13-81777149002a
scalar_lr: 0.02
seed: 1337
skip_gates_enabled: True
sliding_window_enabled: True
tie_embeddings: True
tied_embed_init_std: 0.005
tied_embed_lr: 0.03
tokenizer_path: ./data/tokenizers/fineweb_4096_bpe.model
train_batch_tokens: 786432
train_files: ./data/datasets/fineweb10B_sp4096/fineweb_train_*.bin
train_log_every: 500
train_seq_len: 2048
val_batch_tokens: 524288
val_files: ./data/datasets/fineweb10B_sp4096/fineweb_val_*.bin
val_loss_every: 4000
ve_dim: 128
ve_enabled: True
ve_layers: 9,10
vocab_size: 4096
warmdown_frac: 0.667
warmup_steps: 20
world_size: 8
xsa_last_n: 11
train_shards: 141
val_tokens: 44795904
model_params:34401371
parallel_residual: active=0 start_layer=7 start_mode=physical params=0
recurrence: layers=[4, 5] start_step=3000 active=0
repeat_untie_mlp: mode=none layers=[] params=0
gptq:reserving 10s, effective=590000ms
[rank0]:[W405 16:53:28.890015246 reducer.cpp:1431] Warning: find_unused_parameters=True was specified in DDP constructor, but did not find any unused parameters in the forward pass. This flag results in an extra traversal of the autograd graph every iteration, which can adversely affect performance. If your model indeed never has any unused parameters in the forward pass, consider turning this flag off. Note that this warning may be a false positive if your model has flow control causing later iterations to have unused parameters. (function operator())
warmup_step: 1/20
warmup_step: 2/20
warmup_step: 3/20
warmup_step: 4/20
warmup_step: 5/20
warmup_step: 6/20
warmup_step: 10/20
warmup_step: 20/20
recurrence:prewarm active=1 virtual_layers:13
recur_warmup_step: 1/20
recur_warmup_step: 2/20
recur_warmup_step: 3/20
recur_warmup_step: 4/20
recur_warmup_step: 5/20
recur_warmup_step: 6/20
recur_warmup_step: 10/20
recur_warmup_step: 20/20
0/20000 val_loss: 8.3181 val_bpb: 3.5660
1/20000 train_loss: 8.3186 train_time: 0.0m tok/s: 8460106
2/20000 train_loss: 12.3286 train_time: 0.0m tok/s: 8365816
3/20000 train_loss: 10.8566 train_time: 0.0m tok/s: 8249917
4/20000 train_loss: 9.2077 train_time: 0.0m tok/s: 8188067
5/20000 train_loss: 7.9932 train_time: 0.0m tok/s: 8156066
500/20000 train_loss: 3.0767 train_time: 0.8m tok/s: 7893819
1000/20000 train_loss: 2.9346 train_time: 1.7m tok/s: 7869176
1500/20000 train_loss: 2.8356 train_time: 2.5m tok/s: 7862525
2000/20000 train_loss: 2.8129 train_time: 3.3m tok/s: 7863024
2500/20000 train_loss: 2.8101 train_time: 4.2m tok/s: 7863306
3000/20000 train_loss: 2.8053 train_time: 5.0m tok/s: 7863625
recurrence:activated step:3000 layers:[4, 5] virtual_layers:13
3500/20000 train_loss: 2.7436 train_time: 6.0m tok/s: 7689449
4000/20000 train_loss: 2.6716 train_time: 6.9m tok/s: 7563170
4000/20000 val_loss: 2.6883 val_bpb: 1.1525
4500/20000 train_loss: 2.7236 train_time: 7.9m tok/s: 7469615
5000/20000 train_loss: 2.5950 train_time: 8.9m tok/s: 7396307
5500/20000 train_loss: 2.5446 train_time: 9.8m tok/s: 7337530
5505/20000 val_loss: 2.5678 val_bpb: 1.1008
stopping_early: wallclock_cap train_time: 590069ms step: 5505/20000
peak memory allocated: 30215 MiB reserved: 30244 MiB
ema:applying EMA weights
pre-quantization post-ema val_loss:2.56523285 val_bpb:1.09972808 eval_time:1979ms
pre_quant_ttt:start lr=0.0005 epochs=1 freeze=9/11
pre_quant_ttt:params trainable=5771280 frozen=28630091
ttt_epoch:1/1 chunk:50 loss:2.4895 avg:2.5570 time:4.3s
ttt_epoch:1/1 chunk:100 loss:2.5000 avg:2.5323 time:8.4s
ttt_epoch:1/1 chunk:150 loss:2.5013 avg:2.5425 time:12.5s
ttt_epoch:1/1 chunk:200 loss:2.6541 avg:2.5409 time:16.6s
ttt_epoch:1/1 chunk:250 loss:2.6488 avg:2.5551 time:20.7s
ttt_epoch:1/1 chunk:300 loss:2.6645 avg:2.5638 time:24.8s
ttt_epoch:1/1 chunk:350 loss:2.4499 avg:2.5727 time:28.9s
ttt_epoch:1/1 chunk:400 loss:2.5943 avg:2.5770 time:33.0s
ttt_epoch:1/1 chunk:450 loss:2.4820 avg:2.5764 time:37.1s
ttt_epoch:1/1 chunk:500 loss:2.3067 avg:2.5773 time:41.2s
ttt_epoch:1/1 chunk:550 loss:2.5263 avg:2.5767 time:45.4s
ttt_epoch:1/1 chunk:600 loss:2.4939 avg:2.5739 time:49.5s
ttt_epoch:1/1 chunk:650 loss:2.5699 avg:2.5689 time:53.6s
ttt_epoch:1/1 chunk:700 loss:2.7190 avg:2.5687 time:57.7s
ttt_epoch:1/1 chunk:750 loss:2.6283 avg:2.5735 time:61.8s
ttt_epoch:1/1 chunk:800 loss:2.6811 avg:2.5775 time:65.9s
ttt_epoch:1/1 chunk:850 loss:2.6007 avg:2.5781 time:70.0s
ttt_epoch:1/1 chunk:900 loss:2.4841 avg:2.5803 time:74.1s
ttt_epoch:1/1 chunk:950 loss:2.5396 avg:2.5795 time:78.2s
ttt_epoch:1/1 chunk:1000 loss:2.7307 avg:2.5797 time:82.3s
ttt_epoch:1/1 chunk:1050 loss:2.4955 avg:2.5781 time:86.4s
ttt_epoch:1/1 chunk:1100 loss:2.6772 avg:2.5775 time:90.5s
ttt_epoch:1/1 chunk:1150 loss:2.6038 avg:2.5790 time:94.6s
ttt_epoch:1/1 chunk:1200 loss:2.5224 avg:2.5815 time:98.7s
ttt_epoch:1/1 chunk:1250 loss:2.6767 avg:2.5841 time:102.8s
ttt_epoch:1/1 chunk:1300 loss:2.6651 avg:2.5846 time:106.9s
ttt_epoch:1/1 chunk:1350 loss:2.5651 avg:2.5852 time:111.0s
ttt_epoch:1/1 done chunks:1368 avg_loss:2.5847 time:112.5s
pre_quant_ttt:done epochs=1 total_time=112.5s
Serialized model: 132405891 bytes
Code size: 68057 bytes
GPTQ:collecting Hessians from calibration data...
GPTQ:collected 66 Hessians in 9.7s
GPTQ quantization: 66 layers with full GPTQ, 0 fallback to clip-search
Serialized model int6+brotli: 16016628 bytes
Total submission size int6+brotli: 16084685 bytes
final_int6_roundtrip val_loss:2.58639181 val_bpb:1.10879903 eval_time:7263ms
final_int6_sliding_window val_loss:2.54631334 val_bpb:1.09161719 eval_time:74216ms
etlb:start windows=87488 lr=0.05 steps=5 clip=3.0
etlb:window 5000/87488 bias_norm=2.3464
etlb:window 10000/87488 bias_norm=3.2945
etlb:window 15000/87488 bias_norm=4.0261
etlb:window 20000/87488 bias_norm=4.5116
etlb:window 25000/87488 bias_norm=5.0892
etlb:window 30000/87488 bias_norm=5.5829
etlb:window 35000/87488 bias_norm=5.8927
etlb:window 40000/87488 bias_norm=6.2522
etlb:window 45000/87488 bias_norm=6.5896
etlb:window 50000/87488 bias_norm=6.8019
etlb:window 55000/87488 bias_norm=7.0639
etlb:window 60000/87488 bias_norm=7.2803
etlb:window 65000/87488 bias_norm=7.5512
etlb:window 70000/87488 bias_norm=7.7423
etlb:window 75000/87488 bias_norm=7.9370
etlb:window 80000/87488 bias_norm=8.0704
etlb:window 85000/87488 bias_norm=8.2228
final_int6_sliding_etlb val_loss:2.54188969 val_bpb:1.08972075 eval_time:1167545ms