# [Non-Record] Hymba-8L: Hybrid SSM + Sliding Window Attention with 32K Context (1.1470 BPB)

## Summary

This submission uses a hybrid architecture combining Mamba SSM with sliding window attention (SWA), which allows training at **32x longer context** (32,768 tokens) than the standard baseline (1,024 tokens) under the same compute and time constraints. Unlike full attention, which scales quadratically with sequence length, SWA and Mamba both scale linearly, making long-context training feasible within the 10-minute wall-clock budget.
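The scaling claim can be made concrete with a rough pair-count model (an illustrative sketch of query-key pairs touched, not the kernels' actual FLOP counts):

```python
def attn_pairs_full(T):
    # causal full attention: each query attends to all previous positions
    return T * (T + 1) // 2

def attn_pairs_sliding(T, W):
    # causal SWA: each query attends to at most W positions (itself + W-1 back)
    return sum(min(q + 1, W) for q in range(T))

full_1k = attn_pairs_full(1024)            # baseline: 1K context, full attention
swa_32k = attn_pairs_sliding(32768, 1024)  # this run: 32K context, SWA-1024
full_32k = attn_pairs_full(32768)          # hypothetical: 32K full attention

# SWA-1024 at 32K touches ~63x the pairs of the 1K baseline (linear in T),
# while full attention at 32K would touch ~1023x (quadratic in T).
print(swa_32k / full_1k, full_32k / full_1k)
```

The Mamba branch is linear in `T` by construction (one recurrent update per token), so the whole block stays linear.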

Building on our previous Hymba submission (1.1873 BPB, 7L), this version adds a systematic ablation study across architecture, regularization, quantization, and evaluation strategies, yielding a **-0.040 BPB improvement**.

## Results

| Seed | val_bpb | val_loss | Steps | Artifact Size |
|------|---------|----------|-------|---------------|
| 1337 | 1.1474 | 1.9374 | 6,621 | 15.7 MB |
| 42 | 1.1469 | 1.9366 | 6,620 | 15.6 MB |
| 7 | 1.1468 | 1.9363 | 6,606 | 15.3 MB |
| **Mean** | **1.1470 ± 0.0003** | | | |

- Training: 600s on 8xH100 SXM, ~90.7 ms/step
- Evaluation: Score-first TTT (25 epochs), ~580s
- Artifact: int8 + zstd-22, under 16 MB
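For readers unfamiliar with the metric: val_bpb is the validation cross-entropy converted from nats per token to bits per byte. A minimal sketch of the conversion, where the bytes-per-token ratio (~2.436) is an assumption inferred from this run's own log (1.9374 nats/token corresponds to 1.1474 BPB), not read from the code:

```python
import math

def loss_to_bpb(val_loss_nats, bytes_per_token):
    """Convert mean next-token cross-entropy (nats/token) to bits per byte."""
    bits_per_token = val_loss_nats / math.log(2)
    return bits_per_token / bytes_per_token

# ~2.436 bytes/token for the fineweb_1024 sentencepiece tokenizer (inferred)
print(round(loss_to_bpb(1.9374, 2.436), 4))
```

This is why BPB is comparable across tokenizers while raw val_loss is not: the byte count of the underlying text is tokenizer-independent.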

## Key Improvements Over Previous Submission (1.1873 BPB)

### 1. 8 Layers (up from 7)
Added an 8th layer for more model capacity. Despite slower per-step time (~180 vs ~160 ms on 4xH100), the quality improvement (-0.009 BPB) more than compensates. Artifact still fits under 16 MB with int8 + WD=0.14.

### 2. SWA-1024 (up from SWA-512)
Doubled the sliding window attention window from 512 to 1024 tokens. Each token now attends to 1024 previous tokens via local attention while Mamba handles global context through recurrent state. This yielded another -0.009 BPB with only ~6 ms/step overhead.
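The mask this describes is a simple predicate; PyTorch's flex_attention expresses it as a `mask_mod` with signature `(b, h, q_idx, kv_idx)`, but the logic is framework-free:

```python
def swa_allowed(q_idx, kv_idx, window=1024):
    """Causal sliding-window mask: query q may attend key k iff k is not in
    the future and lies within the last `window` positions (inclusive of q)."""
    return kv_idx <= q_idx and q_idx - kv_idx < window

# With window=1024, position 5000 sees positions 3977..5000 and nothing else.
assert swa_allowed(5000, 5000)
assert swa_allowed(5000, 3977)
assert not swa_allowed(5000, 3976)   # fell out of the window
assert not swa_allowed(5000, 5001)   # future token
```

Everything beyond the window is out of the attention branch's reach by design; that long-range information has to flow through Mamba's recurrent state.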

### 3. SSM State=4 (down from 8)
Counter-intuitively, reducing the Mamba state dimension from 8 to 4 improved both speed and quality. With SWA-1024 handling local patterns, the SSM needs less recurrent state. This gave ~8 ms/step speedup and -0.002 BPB improvement.
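To see what the state dimension buys, here is a minimal diagonal linear recurrence (an illustrative sketch only; Mamba's selective scan makes `a`, `b`, `c` input-dependent and runs a fused CUDA kernel):

```python
def ssm_scan(xs, a, b, c):
    """Toy diagonal linear SSM: per channel i, h_t[i] = a[i]*h_{t-1}[i] + b[i]*x_t,
    output y_t = sum_i c[i]*h_t[i]. len(a) is the recurrent state dimension:
    the number of decaying accumulators summarizing the entire past."""
    n = len(a)
    h = [0.0] * n
    ys = []
    for x in xs:
        h = [a[i] * h[i] + b[i] * x for i in range(n)]
        ys.append(sum(c[i] * h[i] for i in range(n)))
    return ys

# state dim 4: four accumulators at different decay rates carry global context
y = ssm_scan([1.0, 0.0, 0.0], a=[0.9, 0.5, 0.1, 0.0], b=[1.0] * 4, c=[0.25] * 4)
```

Halving the state dimension halves this per-token recurrent work, which is where the ~8 ms/step comes from; the quality result suggests four decay channels suffice once SWA-1024 covers local structure.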

### 4. Untied Embeddings
Uses a separate lm_head instead of tying it to tok_emb. This trains faster (~159 vs ~176 ms/step tied) and reaches better BPB; the speed gain alone buys ~200 more steps in 10 minutes.

### 5. High Weight Decay (WD=0.15) + int8 Quantization
Higher WD acts as strong regularization, improving pre-quant BPB monotonically up to WD=0.14. WD=0.15 is used to ensure all seeds fit under 16 MB with a safety margin. Combined with full int8 quantization (not int6), this gives the best post-quant BPB while fitting under 16 MB. Key insight: the model is overfitting at this training duration, so aggressive regularization helps generalization.
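The WD/quantization interaction can be seen in a minimal symmetric per-tensor int8 sketch (an assumption for illustration; the submission's actual scheme may be per-channel, and it compresses with zstd level 22 rather than the stdlib zlib used here for self-containment):

```python
import struct
import zlib

def quantize_int8(weights):
    """Symmetric per-tensor int8: scale = max|w| / 127. Higher weight decay
    shrinks max|w|, hence the scale, hence the rounding error per weight."""
    scale = max(abs(w) for w in weights) / 127.0 or 1.0
    q = [max(-127, min(127, round(w / scale))) for w in weights]
    return q, scale

def pack_and_compress(q, scale):
    # one float32 scale followed by the int8 payload, entropy-coded
    payload = struct.pack("f", scale) + bytes((v & 0xFF) for v in q)
    return zlib.compress(payload, 9)

w = [0.5, -1.27, 0.01, 1.27]
q, s = quantize_int8(w)
blob = pack_and_compress(q, s)
# dequantized value q[i]*s lands within half a quantization step of the original
assert all(abs(qi * s - wi) <= s / 2 + 1e-9 for qi, wi in zip(q, w))
```

Smaller weights also have lower entropy after quantization, which is why the high-WD checkpoint compresses under the 16 MB cap at full int8.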

### 6. Aggressive Warmdown (7000 iters)
Extended cosine LR warmdown from 3000 to 7000 iterations (wall-clock based). The model benefits greatly from prolonged LR decay, with a large BPB drop during the warmdown phase. This also reduces step time late in training due to smaller weight updates.
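The schedule's shape can be sketched as follows (an approximation: the real run triggers the warmdown from wall-clock time, and the names here are illustrative, not `train_gpt.py`'s):

```python
import math

def lr_scale(step, total_steps, warmdown_iters=7000, warmup_steps=20):
    """LR multiplier: linear warmup, flat middle, cosine decay to zero over
    the final `warmdown_iters` steps."""
    if step < warmup_steps:
        return step / warmup_steps
    start = total_steps - warmdown_iters
    if step < start:
        return 1.0
    frac = (step - start) / warmdown_iters
    return 0.5 * (1.0 + math.cos(math.pi * frac))

# full LR at warmdown start, half LR midway through it, ~0 at the end
```

With the wall-clock cap landing around step 6,600, a 7,000-iteration warmdown means most of training happens inside the decay phase, which is where the large BPB drop occurs.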

### 7. TTT: 25 Epochs, No Freeze
Increased test-time training from 3 epochs with 2 frozen blocks to 25 epochs with all blocks unfrozen. The cosine LR decay in TTT prevents catastrophic forgetting even without freezing. Score-first TTT remains the evaluation strategy.
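"Score-first" means each validation chunk is scored with the current weights *before* the model adapts on it, so no chunk's score is computed by weights that have already seen that chunk. A toy sketch of the loop shape, with an add-one unigram counter standing in for the network:

```python
import math

def score_first_ttt(chunks, vocab=256):
    """Score-first TTT loop: (1) score the chunk, (2) then adapt on it."""
    counts, total = {}, 0
    chunk_bits = []
    for chunk in chunks:
        bits = 0.0
        for sym in chunk:                      # 1) score with current "weights"
            p = (counts.get(sym, 0) + 1) / (total + vocab)
            bits += -math.log2(p)
        chunk_bits.append(bits / len(chunk))   # per-chunk bits/symbol
        for sym in chunk:                      # 2) adapt on the scored chunk
            counts[sym] = counts.get(sym, 0) + 1
            total += 1
    return chunk_bits  # the reported score averages these

bits = score_first_ttt([b"abababab"] * 4)
assert bits[-1] < bits[0]  # later chunks are cheaper: the model has adapted
```

This mirrors the `ttt_chunk [k/119]` lines in the log: the running bpb falls from ~1.20 on the first chunk to ~1.143 by the last as the unfrozen model adapts.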

### 8. GRAD_ACCUM_STEPS=1
Eliminated gradient accumulation overhead by processing the full local batch in a single micro-step per GPU. This saves ~6 ms/step, yielding ~200 more training steps.

## Hymba Hybrid Architecture

Based on the Hymba paper (arXiv:2411.13676), each block runs attention and Mamba **in parallel** within a single layer:
- Attention branch: Q projection + shared KV projection, GQA (8 heads, 4 KV heads), RoPE, QK-norm, SWA-1024
- Mamba branch: Selective scan with causal 1D convolution, gated output, state dim=4
- Learned merge: sigmoid-gated weighted sum of both branches
- Post-merge: output projection + residual with learned scale

Additional: LeakyReLU(0.9)^2 MLP, SmearGate + BigramHash embedding, U-Net skip connections, EMA(0.997).
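The learned merge in the list above reduces to a small gating computation. A single-channel sketch, where the convex `g`/`1-g` form is an assumption on our part (the block could equally learn two independent gates); the real layer applies this per channel before the output projection:

```python
import math

def hymba_merge(attn_out, mamba_out, gate_logit):
    """Sigmoid-gated weighted sum of the parallel attention and Mamba branch
    outputs; `gate_logit` is the learned parameter."""
    g = 1.0 / (1.0 + math.exp(-gate_logit))
    return g * attn_out + (1.0 - g) * mamba_out

# gate_logit = 0 -> g = 0.5: an equal mix of the two branches
assert abs(hymba_merge(2.0, 4.0, 0.0) - 3.0) < 1e-9
```

Because both branches run on the same input within one layer (rather than alternating layers), the gate lets each channel choose between local attention evidence and global recurrent state.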

## Ablation Summary

Over 50 ablation experiments were conducted across two days. Key findings:
- **Architecture**: 8L > 7L, SWA-1024 > 512 > 256, SSM_STATE=4 > 8 > 16, untied embeddings
- **Regularization**: WD=0.14 optimal for int8 under 16MB, WD=0.12 better BPB but over budget
- **Quantization**: int8 with high WD beats mixed int8/int6, GPTQ_LITE=0 works fine
- **Training**: warmdown=7000, GRAD_ACCUM_STEPS=1, 524K batch, EMA_EVERY=10, Muon steps=5
- **TTT**: LR=0.002, 25 epochs no-freeze, cosine decay
- **Not helpful**: XSA, Partial RoPE (16/64), LZMA compression, smaller/larger batch sizes

## Run Command

```bash
SEED=1337 SLIDING_WINDOW=1024 SWA_GLOBAL_LAYERS=none TRAIN_SEQ_LEN=32768 \
NUM_LAYERS=8 MLP_MULT=4 MODEL_DIM=512 NUM_HEADS=8 NUM_KV_HEADS=4 \
MATRIX_LR=0.02 SCALAR_LR=0.02 WARMDOWN_ITERS=7000 WARMDOWN_SHAPE=cosine \
EVAL_STRIDE=0 EVAL_BATCH_SEQS=4 GPTQ_LITE=0 QUANT_BITS=8 \
HYMBA_EXPAND=1 HYMBA_SSM_STATE=4 \
USE_SMEARGATE=1 USE_BIGRAM_HASH=1 TIE_EMBEDDINGS=0 LEAKY_RELU_SLOPE=0.9 \
WEIGHT_DECAY=0.15 EMA_EVERY=10 GRAD_ACCUM_STEPS=1 \
TTT_ENABLED=1 TTT_LR=0.002 TTT_EPOCHS=25 TTT_CHUNK_TOKENS=524288 \
TTT_FREEZE_BLOCKS=0 TTT_BATCH_SEQS=4 \
MAX_WALLCLOCK_SECONDS=600 \
torchrun --standalone --nproc_per_node=8 train_gpt.py
```

## Dependencies

```bash
export PATH=/usr/local/cuda/bin:$PATH
pip install --no-build-isolation --break-system-packages mamba-ssm causal-conv1d zstandard sentencepiece
```

Requires PyTorch >= 2.5 for flex_attention (sliding window). Tested on PyTorch 2.8.0+cu128.

## Training Log (seed 1337)

W0402 00:40:53.132000 2084 torch/distributed/run.py:774]
W0402 00:40:53.132000 2084 torch/distributed/run.py:774] *****************************************
W0402 00:40:53.132000 2084 torch/distributed/run.py:774] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.
W0402 00:40:53.132000 2084 torch/distributed/run.py:774] *****************************************
logs/e5aa5e77-7ed1-41d9-bc71-db78c0ccb6a5.txt
val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model
train_loader:dataset:fineweb10B_sp1024 train_shards:80
val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:61997056
model_params:31357512
world_size:8 grad_accum_steps:1
attention_mode:gqa num_heads:8 num_kv_heads:4
num_layers:8 mlp_mult:4 kv_share_every:1 meta_tokens:0 sliding_window:1024
tie_embeddings:False embed_lr:0.6 head_lr:0.008 matrix_lr:0.02 scalar_lr:0.02
train_batch_tokens:524288 train_seq_len:32768 iterations:20000 warmup_steps:20 max_wallclock_seconds:600.000
seed:1337
/usr/local/lib/python3.12/dist-packages/torch/_dynamo/variables/functions.py:1481: UserWarning: Dynamo does not know how to trace the builtin `selective_scan_cuda.PyCapsule.fwd.` This function is either a Python builtin (e.g. _warnings.warn) or a third-party C/C++ Python extension (perhaps created with pybind).
If it is a Python builtin, please file an issue on GitHub so the PyTorch team can add support for it and see the next case for a workaround.
If it is a third-party C/C++ Python extension, please either wrap it into a PyTorch-understood custom operator (see https://pytorch.org/tutorials/advanced/custom_ops_landing_page.html for more details) or, if it is traceable, use `torch.compiler.allow_in_graph`.
torch._dynamo.utils.warn_once(explanation + "\n" + "\n".join(hints))
warmup_step:1/20
warmup_step:2/20
warmup_step:3/20
warmup_step:4/20
warmup_step:5/20
warmup_step:6/20
warmup_step:7/20
warmup_step:8/20
warmup_step:9/20
warmup_step:10/20
warmup_step:11/20
warmup_step:12/20
warmup_step:13/20
warmup_step:14/20
warmup_step:15/20
warmup_step:16/20
warmup_step:17/20
warmup_step:18/20
warmup_step:19/20
warmup_step:20/20
step:0/20000 val_loss:6.9315 val_bpb:4.1051 train_time:0ms step_avg:0.02ms
step:1/20000 train_loss:6.9315 train_time:94ms step_avg:94.50ms
step:2/20000 train_loss:6.3263 train_time:181ms step_avg:90.46ms
step:3/20000 train_loss:5.2470 train_time:272ms step_avg:90.68ms
step:4/20000 train_loss:5.0697 train_time:363ms step_avg:90.69ms
step:5/20000 train_loss:4.7019 train_time:453ms step_avg:90.69ms
step:6/20000 train_loss:4.9286 train_time:544ms step_avg:90.69ms
step:7/20000 train_loss:4.5586 train_time:635ms step_avg:90.69ms
step:8/20000 train_loss:4.5049 train_time:726ms step_avg:90.69ms
step:9/20000 train_loss:4.4422 train_time:816ms step_avg:90.69ms
step:10/20000 train_loss:4.4506 train_time:916ms step_avg:91.59ms
step:200/20000 train_loss:2.6481 train_time:18059ms step_avg:90.29ms
step:400/20000 train_loss:2.2019 train_time:36067ms step_avg:90.17ms
step:600/20000 train_loss:2.4352 train_time:54078ms step_avg:90.13ms
step:800/20000 train_loss:2.1768 train_time:72092ms step_avg:90.11ms
step:1000/20000 train_loss:2.2964 train_time:90115ms step_avg:90.12ms
step:1000/20000 val_loss:2.2466 val_bpb:1.3305 train_time:90120ms step_avg:90.12ms
step:1200/20000 train_loss:2.3181 train_time:108331ms step_avg:90.28ms
step:1400/20000 train_loss:2.3548 train_time:126560ms step_avg:90.40ms
step:1600/20000 train_loss:2.0210 train_time:144787ms step_avg:90.49ms
step:1800/20000 train_loss:2.1315 train_time:163327ms step_avg:90.74ms
step:2000/20000 train_loss:2.1750 train_time:181365ms step_avg:90.68ms
step:2000/20000 val_loss:2.1597 val_bpb:1.2791 train_time:181370ms step_avg:90.68ms
step:2200/20000 train_loss:2.0051 train_time:199612ms step_avg:90.73ms
step:2400/20000 train_loss:2.1220 train_time:217638ms step_avg:90.68ms
step:2600/20000 train_loss:2.3470 train_time:235680ms step_avg:90.65ms
step:2800/20000 train_loss:2.1597 train_time:253729ms step_avg:90.62ms
step:3000/20000 train_loss:2.1332 train_time:271770ms step_avg:90.59ms
step:3000/20000 val_loss:2.1015 val_bpb:1.2446 train_time:271774ms step_avg:90.59ms
step:3200/20000 train_loss:2.0857 train_time:289796ms step_avg:90.56ms
step:3400/20000 train_loss:2.0564 train_time:307817ms step_avg:90.53ms
step:3600/20000 train_loss:1.9898 train_time:326135ms step_avg:90.59ms
step:3800/20000 train_loss:2.0978 train_time:344519ms step_avg:90.66ms
step:4000/20000 train_loss:2.0531 train_time:362907ms step_avg:90.73ms
step:4000/20000 val_loss:2.0412 val_bpb:1.2089 train_time:362914ms step_avg:90.73ms
step:4200/20000 train_loss:2.0314 train_time:381190ms step_avg:90.76ms
step:4400/20000 train_loss:1.9528 train_time:399857ms step_avg:90.88ms
step:4600/20000 train_loss:1.8126 train_time:418261ms step_avg:90.93ms
step:4800/20000 train_loss:2.0926 train_time:436476ms step_avg:90.93ms
step:5000/20000 train_loss:1.8367 train_time:454788ms step_avg:90.96ms
step:5000/20000 val_loss:1.9764 val_bpb:1.1705 train_time:454816ms step_avg:90.96ms
step:5200/20000 train_loss:1.9892 train_time:473000ms step_avg:90.96ms
step:5400/20000 train_loss:1.9889 train_time:491231ms step_avg:90.97ms
step:5600/20000 train_loss:1.9759 train_time:509439ms step_avg:90.97ms
step:5800/20000 train_loss:1.9218 train_time:527650ms step_avg:90.97ms
step:6000/20000 train_loss:2.0098 train_time:545863ms step_avg:90.98ms
step:6000/20000 val_loss:1.9253 val_bpb:1.1402 train_time:545872ms step_avg:90.98ms
step:6200/20000 train_loss:1.8672 train_time:564422ms step_avg:91.04ms
step:6400/20000 train_loss:1.9434 train_time:582790ms step_avg:91.06ms
step:6590/20000 val_loss:1.9176 val_bpb:1.1357 train_time:600072ms step_avg:91.06ms
stopping_early: wallclock_cap train_time:600072ms step:6590/20000
peak memory allocated: 18320 MiB reserved: 18748 MiB
ema:applying EMA weights
Serialized model: 123061651 bytes
Code size: 85305 bytes
Total submission size: 123146956 bytes
Serialized model int8+zstd-22: 15679725 bytes (payload:32444704 raw_torch:32512335 payload_ratio:3.79x)
Total submission size int8+zstd-22: 15765030 bytes
ttt:start chunks=119 chunk_tokens=524288 seq_len=32768 ttt_lr=0.002 ttt_epochs=25 freeze_blocks=0
ttt:params unfrozen=31357512 frozen=0
ttt_chunk [1/119] bpb=1.200632 time=11.5s
ttt_chunk [11/119] bpb=1.152016 time=60.7s
ttt_chunk [21/119] bpb=1.158212 time=109.9s
ttt_chunk [31/119] bpb=1.153099 time=159.0s
ttt_chunk [41/119] bpb=1.156847 time=208.2s
ttt_chunk [51/119] bpb=1.150066 time=257.4s
ttt_chunk [61/119] bpb=1.148734 time=306.6s
ttt_chunk [71/119] bpb=1.148728 time=355.7s
ttt_chunk [81/119] bpb=1.145716 time=405.0s
ttt_chunk [91/119] bpb=1.144754 time=454.0s
ttt_chunk [101/119] bpb=1.145109 time=503.0s
ttt_chunk [111/119] bpb=1.144611 time=552.0s
ttt_chunk [119/119] bpb=1.143314 time=586.3s
ttt:done val_loss=1.937354 val_bpb=1.147383 elapsed=588.3s
final_int8_zlib_roundtrip val_loss:1.9374 val_bpb:1.1474 eval_mode:score_first_ttt eval_time:588286ms
final_int8_zlib_roundtrip_exact val_loss:1.93735364 val_bpb:1.14738259