@@ -0,0 +1,71 @@
# Record: SP8192 + Score-First TTT + Eval-Time Hash Embedding

**val_bpb: 1.08269** (3-seed mean, std 0.00060) | ~15.99 MB | 8xH100 SXM | ~450s eval

Merged SOTA (PR #1019, 3-seed mean): **1.88218 nats**. This run: **2.79670 nats**. Delta: **-0.914 nats**. Clears the 0.005-nat threshold.

## Results (3-seed)

| Seed | val_bpb | val_loss (nats) | Submission size (bytes) |
|------|---------|-----------------|-------------------------|
| 1337 | **1.08218** | 2.79537 | 15,982,929 |
| 42 | **1.08252** | 2.79626 | 15,988,459 |
| 2025 | **1.08337** | 2.79846 | 15,989,420 |
| **Mean** | **1.08269** | **2.79670** | |
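The val_bpb and val_loss columns are linked through the corpus's bytes-per-token ratio (bpb = nats/token converted to bits, divided by bytes/token). A quick consistency check on the 3-seed means; the ~3.73 bytes/token figure is derived here, not stated anywhere in this record:

```python
import math

val_loss = 2.79670  # nats per token (3-seed mean)
val_bpb = 1.08269   # bits per byte (3-seed mean)

# Convert nats/token to bits/token, then infer the implied bytes/token.
bits_per_token = val_loss / math.log(2)
bytes_per_token = bits_per_token / val_bpb
```

With these numbers, bits_per_token comes out near 4.035 and the implied ratio near 3.73 bytes per token, which is plausible for an 8192-entry SentencePiece vocabulary.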

## Changes from Merged SOTA (PR #1019)

### 1. Eval-Time Hash Embedding (Novel)

A zero-initialized `nn.Embedding(16384, 512)` is created at evaluation time and trained exclusively through the score-first TTT loop. At each position, a bigram hash `h = (prev_token * 2039 + curr_token) % 16384` looks up a residual vector that is added to `tok_emb(x)` before RMSNorm. The hash embedding learns document-local bigram patterns without modifying any pre-trained model weights.
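The mechanism above can be sketched in a few lines. This is a minimal illustration with hypothetical helper names, not the submission's code; in particular, how position 0 (which has no predecessor) is hashed is not specified above, so the fallback used here is an assumption:

```python
import torch
import torch.nn as nn

HASH_BUCKETS, DIM = 16384, 512

# Created fresh at eval start; zero-initialized so it is a no-op until
# the TTT loop updates it.
hash_emb = nn.Embedding(HASH_BUCKETS, DIM)
nn.init.zeros_(hash_emb.weight)

def bigram_hash(tokens: torch.Tensor) -> torch.Tensor:
    """Per-position bigram hash: h = (prev_token * 2039 + curr_token) % 16384.
    Position 0 has no predecessor; reusing the token itself is an assumption."""
    prev = torch.roll(tokens, shifts=1, dims=-1)
    prev[..., 0] = tokens[..., 0]
    return (prev * 2039 + tokens) % HASH_BUCKETS

def embed_with_hash(tok_emb: nn.Embedding, x: torch.Tensor) -> torch.Tensor:
    # Residual lookup added to the token embedding, before RMSNorm and
    # the rest of the (frozen-architecture) forward pass.
    return tok_emb(x) + hash_emb(bigram_hash(x))
```

Because the table starts at zero, the first chunk is scored exactly as the base model would score it; the bigram memory only accumulates signal as TTT proceeds.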

**Nearest PR:** PR #1413 (@kevclark) — legal score-first TTT with full-model weight updates. **Different:** We add an ephemeral hash embedding that is instantiated from zeros at eval start and adapts via the same TTT loop. This is a new adaptation target — the model retunes a separate bigram-keyed memory alongside its existing weights. No existing PR creates and trains a new embedding module from scratch at eval time (LoRA-TTT PRs #1254/#1354 create adapter matrices, but those adapt existing layers, not a standalone hash embedding).

**Measured delta:** -0.0004 bpb vs the packed baseline without the hash embedding (ablation: 1.08307 mean without, 1.08269 mean with).

### 2. Score-First TTT (Legal)

SGD with momentum 0.9, LR=0.005, 3 epochs per 32K-token chunk, cosine decay. All model blocks unfrozen (freeze=0). Same mechanism as PR #549 and PR #1413.

**Measured delta:** -0.002 bpb vs sliding-window evaluation without TTT.
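The score-first ordering is the load-bearing legality property: each chunk contributes to the reported loss before any parameter has seen it. A minimal sketch of that loop follows (hypothetical function name; cosine LR decay and the batched 32-sequence windowing are omitted for brevity):

```python
import torch
import torch.nn.functional as F

def score_first_ttt(model, chunks, lr=0.005, momentum=0.9, epochs=3):
    """Score each chunk under no_grad BEFORE updating, then adapt on it.
    All parameters are unfrozen (freeze=0), matching the config above."""
    opt = torch.optim.SGD(model.parameters(), lr=lr, momentum=momentum)
    total_loss, total_tokens = 0.0, 0
    for chunk in chunks:  # single left-to-right pass over the eval stream
        x, y = chunk[:-1], chunk[1:]
        with torch.no_grad():  # score before any update on this chunk
            logits = model(x.unsqueeze(0))
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)),
                                   y.reshape(-1))
        total_loss += loss.item() * y.numel()
        total_tokens += y.numel()
        for _ in range(epochs):  # now adapt on the already-scored chunk
            opt.zero_grad()
            logits = model(x.unsqueeze(0))
            F.cross_entropy(logits.view(-1, logits.size(-1)),
                            y.reshape(-1)).backward()
            opt.step()
    return total_loss / total_tokens  # mean nats per token
```

The reported val_loss is accumulated entirely from the `no_grad` branch, so a chunk's own updates can only help later chunks, never itself.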

### 3. SP8192 Architecture Stack

- 11 layers, model_dim=512, 8 heads, 4 KV heads
- Parallel residuals (layers 7-10, PaLM-style)
- Depth recurrence (layers 4-5, loop 2x)
- Skip gates (sigmoid-gated skip connections)
- QK-Gain 4.0, XSA (all 11 layers)
- Full Hessian GPTQ int6 + byte-shuffle + brotli compression
- Coprime-stride weighted multi-shard data loader
- Code packed with lzma+base85 self-extracting wrapper (saves 32KB)
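The lzma+base85 self-extracting wrapper in the last bullet can be sketched as below. This is one plausible shape of such a packer, with a hypothetical function name, not the submission's actual implementation:

```python
import base64
import lzma

def pack_self_extracting(source: str) -> str:
    """Wrap Python source in a stub that decompresses and exec()s it on load.

    The payload is lzma-compressed then base85-encoded; base85's alphabet
    contains no quotes or backslashes, so repr() embeds it safely.
    """
    payload = base64.b85encode(lzma.compress(source.encode())).decode()
    return (
        "import base64,lzma\n"
        f"exec(lzma.decompress(base64.b85decode({payload!r})).decode())\n"
    )
```

The size win comes from counting the (small) stub plus the compressed payload against the submission budget instead of the raw source.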

## Compliance

Per Issue #1017 (Track B — legal eval-time adaptation):

- **Condition 1 (Causal/prefix-only):** Hash key uses `(prev_token, curr_token)` — both are input token identities from `x_batch = chunk[:-1]`, not model predictions. The hash embedding at position t depends only on prefix tokens.
- **Condition 2 (Full normalized distribution):** Hash residual is added to the embedding before RMSNorm and the standard transformer + tied LM head + full-vocab softmax.
- **Condition 3 (Score-before-update):** Each chunk is fully scored under `torch.no_grad()` before any TTT parameter update. The hash embedding is updated as part of the standard TTT training step, after scoring.
- **Condition 4 (Single left-to-right pass):** One evaluation pass, no rescoring, no multi-pass selection.
- **Precedent for eval-time-created parameters:** LoRA-TTT PRs #1254, #1354 also instantiate new trainable parameters at eval time.

No SLOT, no pre-quant TTT, no n-gram caches, no ETLB.

## Reproduction

```bash
pip install flash_attn_3 --no-deps --find-links https://windreamer.github.io/flash-attention3-wheels/cu128_torch291/
torchrun --standalone --nproc_per_node=8 train_gpt.py
```

No env vars needed. All defaults are the submission config.

## Credits

- Base architecture: PR #549 (@abaybektursun), PR #1019 (@abaybektursun)
- Score-first TTT framework: PR #549 (@abaybektursun), PR #1413 (@kevclark)
- Parallel residuals + depth recurrence: PR #1204 (@msisovic)
- SP8192 + GPTQ embeddings + SDClip: PR #1394 (@clarkkev)
- Coprime-stride loader: PR #726, PR #1060
- Eval-time hash embedding: original to this submission
@@ -0,0 +1 @@
flash_attn_3
@@ -0,0 +1,8 @@
{
"description": "SP8192 + Score-First TTT + Eval-Time Hash Embedding",
"val_bpb_mean": 1.08269,
"val_bpb_std": 0.00060,
"seeds": [1337, 42, 2025],
"hardware": "8xH100 SXM",
"framework": "PyTorch 2.9.1+cu128"
}

@@ -0,0 +1,292 @@
From https://github.com/resouer/parameter-golf
* branch exp/round-14/packed-baseline -> FETCH_HEAD
Note: switching to 'f8889961853802ce9beae36ae6453f4754fd71ab'.
You are in 'detached HEAD' state. You can look around, make experimental
changes and commit them, and you can discard any commits you make in this
state without impacting any branches by switching back to a branch.
If you want to create a new branch to retain commits you create, you may
do so (now or later) by using -c with the switch command. Example:
git switch -c <new-branch-name>
Or undo this operation with:
git switch -
Turn off this advice by setting config variable advice.detachedHead to false
HEAD is now at f888996 Tune hash embedding: 16K buckets + 10x LR
data_setup: vocab=8192 shards=128
Warning: You are sending unauthenticated requests to the HF Hub. Please set a HF_TOKEN to enable higher rate limits and faster downloads.
W0407 17:26:19.904000 230 torch/distributed/run.py:803]
W0407 17:26:19.904000 230 torch/distributed/run.py:803] *****************************************
W0407 17:26:19.904000 230 torch/distributed/run.py:803] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.
Hyperparameters:
adam_eps: 1e-08
adam_wd: 0.02
beta1: 0.9
beta2: 0.95
compressor: brotli
data_dir: ./data/
datasets_dir: ./data/datasets/fineweb10B_sp8192
distributed: True
ema_decay: 0.997
embed_bits: 8
embed_clip_sigmas: 20.0
embed_lr: 0.6
embed_wd: 0.085
embedding_dim: 512
enable_looping_at: 0.5
eval_seq_len: 2048
eval_stride: 64
gptq_calibration_batches: 64
gptq_reserve_seconds: 12.0
grad_accum_steps: 1
grad_clip_norm: 0.3
head_lr: 0.008
is_main_process: True
iterations: 20000
ln_scale: True
local_rank: 0
logfile: logs/0955bf9f-30f6-456c-9a1b-1bba772c0180.txt
logit_softcap: 30.0
loop_end: 5
loop_start: 4
matrix_bits: 6
matrix_clip_sigmas: 12.85
matrix_lr: 0.02
max_wallclock_seconds: 600.0
min_lr: 0.0
mlp_mult: 4.0
model_dim: 512
model_path: final_model.pt
muon_backend_steps: 5
muon_beta2: 0.95
muon_momentum: 0.99
muon_momentum_warmup_start: 0.92
muon_momentum_warmup_steps: 1500
muon_row_normalize: True
muon_wd: 0.085
num_heads: 8
num_kv_heads: 4
num_layers: 11
num_loops: 2
parallel_start_layer: 7
qk_gain_init: 4.0
quantized_model_path: final_model.int6.ptz
rank: 0
rope_base: 10000.0
rope_dims: 16
rope_train_seq_len: 2048
run_id: 0955bf9f-30f6-456c-9a1b-1bba772c0180
scalar_lr: 0.02
seed: 1337
skip_gates_enabled: True
sliding_window_enabled: True
tie_embeddings: True
tied_embed_init_std: 0.005
tied_embed_lr: 0.03
tokenizer_path: ./data/tokenizers/fineweb_8192_bpe.model
train_batch_tokens: 786432
train_files: ./data/datasets/fineweb10B_sp8192/fineweb_train_*.bin
train_log_every: 500
train_seq_len: 2048
ttt_batch_seqs: 32
ttt_chunk_tokens: 32768
ttt_enabled: True
ttt_epochs: 3
ttt_freeze_blocks: 0
ttt_grad_clip: 1.0
ttt_loop_only: False
ttt_lr: 0.005
ttt_momentum: 0.9
val_batch_tokens: 524288
val_files: ./data/datasets/fineweb10B_sp8192/fineweb_val_*.bin
val_loss_every: 4000
vocab_size: 8192
warmdown_frac: 0.667
warmup_steps: 20
world_size: 8
xsa_last_n: 11
train_shards: 128
val_tokens: 40540160
model_params:35943512
gptq:reserving 12s, effective=588000ms
warmup_step: 1/20
warmup_step: 2/20
warmup_step: 3/20
warmup_step: 4/20
warmup_step: 5/20
warmup_step: 6/20
warmup_step: 10/20
warmup_step: 20/20
loop_warmup:enabled encoder:[0, 1, 2, 3, 4, 5, 4] decoder:[5, 4, 5, 6, 7, 8, 9, 10]
loop_warmup_step: 1/20
loop_warmup_step: 2/20
loop_warmup_step: 3/20
loop_warmup_step: 4/20
loop_warmup_step: 5/20
loop_warmup_step: 6/20
loop_warmup_step: 10/20
loop_warmup_step: 20/20
0/20000 val_loss: 9.0047 val_bpb: 3.4860
1/20000 train_loss: 9.0084 train_time: 0.0m tok/s: 7979883
2/20000 train_loss: 12.3253 train_time: 0.0m tok/s: 7956964
3/20000 train_loss: 11.0666 train_time: 0.0m tok/s: 7894778
4/20000 train_loss: 9.4821 train_time: 0.0m tok/s: 7862760
5/20000 train_loss: 8.4564 train_time: 0.0m tok/s: 7840139
500/20000 train_loss: 3.4002 train_time: 0.9m tok/s: 7642136
1000/20000 train_loss: 3.2103 train_time: 1.7m tok/s: 7638780
1500/20000 train_loss: 3.2043 train_time: 2.6m tok/s: 7640151
2000/20000 train_loss: 3.1272 train_time: 3.4m tok/s: 7644863
2500/20000 train_loss: 2.9939 train_time: 4.3m tok/s: 7646757
layer_loop:enabled step:2858 frac:0.500 encoder:[0, 1, 2, 3, 4, 5, 4] decoder:[5, 4, 5, 6, 7, 8, 9, 10]
3000/20000 train_loss: 2.9880 train_time: 5.2m tok/s: 7534801
3500/20000 train_loss: 3.0117 train_time: 6.4m tok/s: 7212071
4000/20000 train_loss: 2.9522 train_time: 7.5m tok/s: 7004380
4000/20000 val_loss: 2.9168 val_bpb: 1.1292
4500/20000 train_loss: 2.9283 train_time: 8.6m tok/s: 6851646
5000/20000 train_loss: 2.9118 train_time: 9.7m tok/s: 6733583
5031/20000 val_loss: 2.8147 val_bpb: 1.0897
stopping_early: wallclock_cap train_time: 588136ms step: 5031/20000
peak memory allocated: 34604 MiB reserved: 34708 MiB
ema:applying EMA weights
pre-quantization post-ema val_loss:2.81221955 val_bpb:1.08869783 eval_time:7242ms
Serialized model: 135426937 bytes
Code size: 17405 bytes
GPTQ:collecting Hessians from calibration data...
GPTQ:collected 67 Hessians in 11.4s
Quantized weights:
gptq (int6): blocks.attn.c_k.weight, blocks.attn.c_q.weight, blocks.attn.c_v.weight, blocks.attn.proj.weight, blocks.mlp.fc.weight, blocks.mlp.proj.weight
gptq (int8): tok_emb.weight
passthrough (float16): blocks.attn.q_gain, blocks.attn_scale, blocks.mlp_scale, blocks.resid_mix, skip_gates, skip_weights
Serialized model quantized+brotli: 15965524 bytes
Total submission size quantized+brotli: 15982929 bytes
final_int6_roundtrip_exact val_loss:2.84283616 val_bpb:1.10055047 eval_time:28254ms
final_int6_sliding_window val_loss:2.79967517 val_bpb:1.08384151 eval_time:116330ms
eval_hash_emb:init size=16384 dim=512 lr_mult=10x
ttt_sliding:start chunks=1238 chunk_tokens=32768 total_windows=633409 stride=64 ttt_lr=0.005 ttt_epochs=3 freeze_blocks=0 loop_only=False
ttt_sliding:params model=35943512 hash=8388608 frozen=0
ttt_chunk [1/1238] bpb=1.118075 time=39.3s
ttt_chunk [11/1238] bpb=1.072418 time=60.4s
ttt_chunk [21/1238] bpb=1.110299 time=63.4s
ttt_chunk [31/1238] bpb=1.104399 time=66.4s
ttt_chunk [41/1238] bpb=1.097277 time=69.4s
ttt_chunk [51/1238] bpb=1.090595 time=72.5s
ttt_chunk [61/1238] bpb=1.082098 time=75.5s
ttt_chunk [71/1238] bpb=1.089282 time=78.5s
ttt_chunk [81/1238] bpb=1.082622 time=81.5s
ttt_chunk [91/1238] bpb=1.079137 time=84.5s
ttt_chunk [101/1238] bpb=1.079003 time=87.5s
ttt_chunk [111/1238] bpb=1.077055 time=90.5s
ttt_chunk [121/1238] bpb=1.079954 time=93.5s
ttt_chunk [131/1238] bpb=1.083717 time=96.5s
ttt_chunk [141/1238] bpb=1.084246 time=99.5s
ttt_chunk [151/1238] bpb=1.084058 time=102.5s
ttt_chunk [161/1238] bpb=1.084570 time=105.5s
ttt_chunk [171/1238] bpb=1.084351 time=108.5s
ttt_chunk [181/1238] bpb=1.082837 time=111.5s
ttt_chunk [191/1238] bpb=1.082471 time=114.6s
ttt_chunk [201/1238] bpb=1.080035 time=117.6s
ttt_chunk [211/1238] bpb=1.084420 time=120.6s
ttt_chunk [221/1238] bpb=1.084683 time=123.6s
ttt_chunk [231/1238] bpb=1.086376 time=126.7s
ttt_chunk [241/1238] bpb=1.084805 time=129.7s
ttt_chunk [251/1238] bpb=1.084906 time=132.7s
ttt_chunk [261/1238] bpb=1.085894 time=135.8s
ttt_chunk [271/1238] bpb=1.086360 time=138.8s
ttt_chunk [281/1238] bpb=1.085676 time=141.8s
ttt_chunk [291/1238] bpb=1.086788 time=144.9s
ttt_chunk [301/1238] bpb=1.086983 time=147.9s
ttt_chunk [311/1238] bpb=1.085810 time=150.9s
ttt_chunk [321/1238] bpb=1.085742 time=153.9s
ttt_chunk [331/1238] bpb=1.086015 time=156.9s
ttt_chunk [341/1238] bpb=1.085132 time=159.9s
ttt_chunk [351/1238] bpb=1.085876 time=162.9s
ttt_chunk [361/1238] bpb=1.084845 time=165.9s
ttt_chunk [371/1238] bpb=1.083294 time=169.0s
ttt_chunk [381/1238] bpb=1.083703 time=172.0s
ttt_chunk [391/1238] bpb=1.083389 time=175.0s
ttt_chunk [401/1238] bpb=1.083470 time=178.0s
ttt_chunk [411/1238] bpb=1.084060 time=181.0s
ttt_chunk [421/1238] bpb=1.083528 time=184.0s
ttt_chunk [431/1238] bpb=1.083680 time=187.0s
ttt_chunk [441/1238] bpb=1.083740 time=190.0s
ttt_chunk [451/1238] bpb=1.084933 time=193.0s
ttt_chunk [461/1238] bpb=1.083184 time=196.0s
ttt_chunk [471/1238] bpb=1.083202 time=199.0s
ttt_chunk [481/1238] bpb=1.083356 time=202.1s
ttt_chunk [491/1238] bpb=1.083813 time=205.1s
ttt_chunk [501/1238] bpb=1.083402 time=208.1s
ttt_chunk [511/1238] bpb=1.083042 time=211.1s
ttt_chunk [521/1238] bpb=1.082550 time=214.1s
ttt_chunk [531/1238] bpb=1.082540 time=217.1s
ttt_chunk [541/1238] bpb=1.082638 time=220.1s
ttt_chunk [551/1238] bpb=1.082180 time=223.1s
ttt_chunk [561/1238] bpb=1.081472 time=226.1s
ttt_chunk [571/1238] bpb=1.080947 time=229.1s
ttt_chunk [581/1238] bpb=1.081304 time=232.1s
ttt_chunk [591/1238] bpb=1.081547 time=235.2s
ttt_chunk [601/1238] bpb=1.081472 time=238.2s
ttt_chunk [611/1238] bpb=1.082065 time=241.2s
ttt_chunk [621/1238] bpb=1.082927 time=244.2s
ttt_chunk [631/1238] bpb=1.082988 time=247.3s
ttt_chunk [641/1238] bpb=1.083444 time=250.3s
ttt_chunk [651/1238] bpb=1.083771 time=253.3s
ttt_chunk [661/1238] bpb=1.083074 time=256.3s
ttt_chunk [671/1238] bpb=1.082817 time=259.3s
ttt_chunk [681/1238] bpb=1.084133 time=262.3s
ttt_chunk [691/1238] bpb=1.084339 time=265.4s
ttt_chunk [701/1238] bpb=1.084158 time=268.4s
ttt_chunk [711/1238] bpb=1.084863 time=271.4s
ttt_chunk [721/1238] bpb=1.085172 time=274.4s
ttt_chunk [731/1238] bpb=1.084527 time=277.4s
ttt_chunk [741/1238] bpb=1.084264 time=280.4s
ttt_chunk [751/1238] bpb=1.083327 time=283.4s
ttt_chunk [761/1238] bpb=1.082729 time=286.4s
ttt_chunk [771/1238] bpb=1.081737 time=289.4s
ttt_chunk [781/1238] bpb=1.081705 time=292.4s
ttt_chunk [791/1238] bpb=1.082077 time=295.4s
ttt_chunk [801/1238] bpb=1.082371 time=298.4s
ttt_chunk [811/1238] bpb=1.081889 time=301.4s
ttt_chunk [821/1238] bpb=1.080715 time=304.4s
ttt_chunk [831/1238] bpb=1.080390 time=307.4s
ttt_chunk [841/1238] bpb=1.079961 time=310.4s
ttt_chunk [851/1238] bpb=1.079671 time=313.4s
ttt_chunk [861/1238] bpb=1.079308 time=316.4s
ttt_chunk [871/1238] bpb=1.079154 time=319.4s
ttt_chunk [881/1238] bpb=1.078678 time=322.4s
ttt_chunk [891/1238] bpb=1.078162 time=325.4s
ttt_chunk [901/1238] bpb=1.078537 time=328.4s
ttt_chunk [911/1238] bpb=1.078240 time=331.4s
ttt_chunk [921/1238] bpb=1.078529 time=334.4s
ttt_chunk [931/1238] bpb=1.079226 time=337.4s
ttt_chunk [941/1238] bpb=1.079620 time=340.4s
ttt_chunk [951/1238] bpb=1.079547 time=343.4s
ttt_chunk [961/1238] bpb=1.080376 time=346.3s
ttt_chunk [971/1238] bpb=1.080783 time=349.3s
ttt_chunk [981/1238] bpb=1.081129 time=352.3s
ttt_chunk [991/1238] bpb=1.080920 time=355.2s
ttt_chunk [1001/1238] bpb=1.080958 time=358.2s
ttt_chunk [1011/1238] bpb=1.081282 time=361.1s
ttt_chunk [1021/1238] bpb=1.081971 time=364.1s
ttt_chunk [1031/1238] bpb=1.082455 time=367.1s
ttt_chunk [1041/1238] bpb=1.082935 time=370.0s
ttt_chunk [1051/1238] bpb=1.082857 time=373.0s
ttt_chunk [1061/1238] bpb=1.082862 time=375.9s
ttt_chunk [1071/1238] bpb=1.083027 time=378.9s
ttt_chunk [1081/1238] bpb=1.082909 time=381.9s
ttt_chunk [1091/1238] bpb=1.083105 time=384.9s
ttt_chunk [1101/1238] bpb=1.083648 time=387.8s
ttt_chunk [1111/1238] bpb=1.083949 time=390.8s
ttt_chunk [1121/1238] bpb=1.084112 time=393.8s
ttt_chunk [1131/1238] bpb=1.083788 time=396.8s
ttt_chunk [1141/1238] bpb=1.083437 time=399.7s
ttt_chunk [1151/1238] bpb=1.083476 time=402.7s
ttt_chunk [1161/1238] bpb=1.083628 time=405.7s
ttt_chunk [1171/1238] bpb=1.083412 time=408.6s
ttt_chunk [1181/1238] bpb=1.082947 time=411.6s
ttt_chunk [1191/1238] bpb=1.083090 time=414.6s
ttt_chunk [1201/1238] bpb=1.083129 time=417.6s
ttt_chunk [1211/1238] bpb=1.082835 time=420.6s
ttt_chunk [1221/1238] bpb=1.082375 time=423.6s
ttt_chunk [1231/1238] bpb=1.082009 time=426.5s
ttt_chunk [1238/1238] bpb=1.082009 time=447.1s
ttt_sliding:done val_loss=2.795371 val_bpb=1.082175 elapsed=448.0s
legal_ttt_hash val_loss:2.79537088 val_bpb:1.08217518 eval_time:448277ms
results_json: {"val_bpb": 1.08217518, "val_loss": 2.79537088, "bytes_total": 15982929, "peak_memory_mib": 34604}