
WIP: Mamba-Attention Hybrid — 18L (15 SSM + 3 GQA), First SSM Entry#1382

Draft
johnlennyt5 wants to merge 32 commits into openai:main from johnlennyt5:arch1/mamba-hybrid

Conversation

@johnlennyt5

WIP: Mamba-Attention Hybrid — 18L (15 Mamba-2 SSM + 3 GQA Attention)

Summary

First SSM-based submission to Parameter Golf. Replaces 15 of 18 layers with Mamba-2 selective state-space blocks while retaining 3
GQA attention layers for global context mixing.

Requesting RunPod credits for GPU training and evaluation.

Architecture

| Parameter | Value |
|---|---|
| Total layers | 18 |
| Mamba-2 SSM layers | 15 (layers 0–11, 15–17) |
| GQA attention layers | 3 (layers 12–14) |
| d_model | 512 |
| d_state (SSM) | 32 |
| d_conv | 4 |
| expand | 1.5 |
| Total params | 27,191,065 |
| Mamba params | 19,249,920 (70.8%) |
| Estimated artifact | 7.3 MB (well under the 16 MB cap) |

Current Status — CPU Validated, Awaiting GPU

  • 115 unit/integration tests passing on CPU
  • Full hybrid forward/backward verified (all params receive gradients)
  • GPTQ int6 quantization pipeline validated end-to-end (quantize → LZMA → dequantize → forward)
  • Training convergence verified on real FineWeb data
  • BPB evaluation pipeline tested with actual validation tokens
  • All planned ablation configurations pre-validated
  • Artifact size confirmed under budget (7.3 MB LZMA, 8.7 MB headroom)
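The quantize → LZMA → dequantize roundtrip listed above can be sketched in a few lines. This uses simple symmetric per-tensor quantization as a stand-in for GPTQ, and the function name is illustrative, not taken from the PR:

```python
import lzma
import numpy as np

def quantize_roundtrip(w: np.ndarray, bits: int = 6):
    """Sketch: symmetric quantize -> LZMA compress -> decompress -> dequantize.

    A stand-in for the PR's GPTQ int6 pipeline (assumption: the real code
    quantizes per-group with Hessian information; this is per-tensor).
    """
    qmax = 2 ** (bits - 1) - 1                       # 31 for int6
    scale = np.abs(w).max() / qmax
    # int6 values stored in int8 containers
    q = np.clip(np.round(w / scale), -qmax - 1, qmax).astype(np.int8)
    blob = lzma.compress(q.tobytes())                # LZMA-compressed artifact
    deq = np.frombuffer(lzma.decompress(blob), dtype=np.int8)
    deq = deq.astype(w.dtype).reshape(w.shape) * scale
    return deq, len(blob)
```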

Why This Is Interesting

OpenAI specifically requested state-space model submissions. This is (to our knowledge) the first Mamba-based entry to Parameter Golf. The hybrid design addresses SSMs' known weakness on associative recall by keeping 3 attention layers at strategic positions,
while leveraging Mamba-2's O(n) sequence processing for the majority of layers.

Key technical contributions:

  • Hybrid dispatch: Mamba and attention layers share a U-Net encoder-decoder with skip connections
  • Muon optimizer for SSM: Mamba matrix params (in_proj, out_proj, dt_proj, c_proj) trained with Newton-Schulz Muon
  • GPTQ for Mamba: Extended Hessian collection and int6 quantization to SSM parameters
  • Gradient checkpointing: Selective checkpointing for Mamba layers (attention uses FA3's built-in memory management)

Test Plan

  • Verify CUDA kernels on RunPod (mamba-ssm + causal-conv1d)
  • Measure step time on 8xH100 (target: <100ms)
  • Full 600s training run, seed 42
  • Architecture ablation (layer count, attention positions, d_state)
  • 3-seed evaluation with Welch's t-test vs SOTA
  • Create submission package with logs

Run Command

```
MAMBA_LAYERS=0,1,2,3,4,5,6,7,8,9,10,11,15,16,17 NUM_LAYERS=18 \
MAMBA_D_STATE=32 MAMBA_D_CONV=4 MAMBA_EXPAND=1.5 \
MAMBA_MATRIX_LR=0.015 BIGRAM_VOCAB_SIZE=2048 BIGRAM_DIM=128 \
WARMDOWN_ITERS=3500 TARGET_MB=15.9 SEED=42 \
torchrun --standalone --nproc_per_node=8 train_gpt.py
```

Requirements

mamba-ssm>=2.2.0
causal-conv1d>=1.4.0
sentencepiece

johnlennyt5 and others added 22 commits April 3, 2026 15:29
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
… forward, and _selective_scan

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
…ts all passing

Tests cover: forward shape (2B.1.1), gradient flow (2B.1.2), param count (2B.1.3),
SSM numerical correctness (2B.1.4), causal masking (2B.1.5), determinism (2B.1.6)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
…dispatch

- Add mamba_layers, mamba_d_state, mamba_d_conv, mamba_expand params to GPT
- Size parameter banks for attention-only layers (n_attn, not num_layers)
- Create mamba_blocks ModuleList and index maps (mamba_idx_map, attn_idx_map)
- Only create Block objects for non-Mamba layers
- Update _init_weights for new bank sizing
- Pass Mamba args through both base_model and eval_model instantiation

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
… tests

- Implement _forward_layer() for Mamba/Attention dispatch
- Update forward() and forward_logits() to use hybrid dispatch
- Add tests/test_hybrid_gpt.py: 15 tests covering instantiation, bank sizing,
  forward pass, forward_logits, gradient flow, U-Net skips, 18-layer config
- All 49 tests passing (34 MambaBlock + 15 hybrid GPT)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
…ck in all test files

Both test files now use _fake_flash_attn (proper GQA-aware CPU attention)
instead of MagicMock, and patch train_gpt.flash_attn_3_func directly to
handle cross-file import ordering.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
…line for Mamba params

- _classify_param: add "mamba" category for mamba_blocks.*
- Muon: separate param group for Mamba matrix params (in_proj, out_proj,
  dt_proj, c_proj) with mamba_matrix_lr=0.015
- Adam: Mamba scalar params (A_log, D, conv1d, dt_proj.bias) added to
  scalar_params group
- _unbank/_rebank_state_dict: accept n_attn param for correct bank sizing
- mixed_quantize_int6: include "mamba" in quantized categories
- Logging: show mamba param count, layer config, and mamba_matrix_lr

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
…SM output

c_proj.weight now initialized with std=0.01 (was default Kaiming).
All other Mamba inits already correct: A_log=log(arange), D=ones,
out_proj=small normal, dt_proj.bias=inv_softplus of log-uniform.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
- Add _HessianMambaBlock with CastedLinear for in_proj/out_proj
  (enables Hessian collection for GPTQ quantization)
- Update _HessianGPT.__init__ to accept mamba_layers, dispatch to
  Mamba vs attention blocks
- Update _HessianGPT.forward with _forward_layer dispatch
- Update _HessianGPT instantiation in main() to pass mamba params
- Add 7 new tests: unbank/rebank roundtrip, _classify_param coverage,
  _HessianGPT hybrid instantiation/forward/CastedLinear checks
- Total: 56 tests passing

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
…j (QAT support)

MambaBlock now uses CastedLinear (instead of nn.Linear) for in_proj and
out_proj, enabling automatic late QAT noise injection when
CastedLinear._qat_enabled is set. Small params (dt_proj, c_proj, A_log,
D, conv1d) stay as regular modules.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
…bf16 stability

Hidden state h, discretized dA/dB, and intermediates now computed in
float32 during the sequential scan loop. Output cast back to input
dtype. Prevents precision loss during long-sequence recurrence under
bf16/autocast training.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
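The float32 scan this commit describes might look like the following sequential reference loop. Shapes and the function name are assumptions for illustration; the real `MambaBlock._selective_scan` will differ:

```python
import torch

def selective_scan_fp32(x, dt, A, B, C, D):
    """Sequential SSM scan run in float32 for bf16 stability (sketch).

    Assumed shapes: x, dt: (batch, seq, d_inner); A: (d_inner, d_state);
    B, C: (batch, seq, d_state); D: (d_inner,).
    """
    in_dtype = x.dtype
    # upcast everything touched by the recurrence to float32
    x, dt, B, C = (t.float() for t in (x, dt, B, C))
    A, D = A.float(), D.float()
    b, seq, d_inner = x.shape
    d_state = A.shape[1]
    h = torch.zeros(b, d_inner, d_state, dtype=torch.float32)
    ys = []
    for t in range(seq):
        dA = torch.exp(dt[:, t, :, None] * A)        # discretized state transition
        dB = dt[:, t, :, None] * B[:, t, None, :]    # discretized input matrix
        h = dA * h + dB * x[:, t, :, None]           # recurrence in float32
        ys.append((h * C[:, t, None, :]).sum(-1) + D * x[:, t])
    # cast back to the input dtype only at the output
    return torch.stack(ys, dim=1).to(in_dtype)
```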
…A kernels

MambaBlock.forward now has dual paths:
- CUDA fast path: uses selective_scan_fn (fused discretization + scan +
  gating) and causal_conv1d_fn (fused conv + SiLU) when mamba-ssm and
  causal-conv1d packages are available and input is on CUDA
- Sequential fallback: original PyTorch loop for CPU/testing

Imports are conditional (try/except) so the code works without GPU libs.
All 56 CPU tests passing on fallback path.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
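The conditional-import pattern behind the dual path could be sketched like this. The import paths follow the public mamba-ssm and causal-conv1d packages, but should be verified against the installed versions:

```python
# Conditional kernel imports: code must still work without the GPU packages.
try:
    from mamba_ssm.ops.selective_scan_interface import selective_scan_fn
    from causal_conv1d import causal_conv1d_fn
    HAS_CUDA_KERNELS = True
except ImportError:
    selective_scan_fn = None
    causal_conv1d_fn = None
    HAS_CUDA_KERNELS = False

def use_fast_path(x) -> bool:
    """Fused CUDA path only when the kernels import AND input is on CUDA;
    otherwise the caller falls back to the sequential PyTorch loop."""
    return HAS_CUDA_KERNELS and x.is_cuda
```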
…raining tests

- Mamba layers use torch.utils.checkpoint during training when
  MAMBA_GRAD_CHECKPOINT=1 (default on), reducing peak memory
- Attention layers stay unchecked (FA3 handles memory internally)
- New tests:
  - test_gradient_checkpoint_mamba: verifies grads flow through checkpoint
  - test_multi_step_loss_decreases: 10-step CPU training validation (2.1.4/2.2.4)
  - test_activation_norms_reasonable: init activation norms in range (2.3.3)
- Total: 59 tests passing

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
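A minimal sketch of the selective checkpointing described above, using `torch.utils.checkpoint`; the helper name and flag handling are hypothetical:

```python
import torch
from torch.utils.checkpoint import checkpoint

def run_mamba_layer(block, x, training: bool, use_ckpt: bool = True):
    """Checkpoint Mamba layers during training to trade recompute for peak
    activation memory; attention layers bypass this (FA3 manages its own)."""
    if training and use_ckpt:
        # use_reentrant=False is the recommended non-reentrant path
        return checkpoint(block, x, use_reentrant=False)
    return block(x)
```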
- Add @torch.compiler.disable to _selective_scan to prevent graph breaks
  when fullgraph=True compilation encounters the sequential Python loop
  (on GPU the CUDA fast path is used, so this is a safety guard)
- Add test_mamba_block_compile_eager: verifies MambaBlock compiles with
  torch.compile(backend="eager") and produces identical output
- Add test_hybrid_gpt_compile_eager: verifies full hybrid GPT model
  compiles and forward_logits matches uncompiled output
- 61 tests passing

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
…ep tests

- test_all_params_gradient_hybrid: verifies every Mamba param (A_log, D,
  conv1d, in_proj, out_proj, dt_proj, c_proj) and key attention banks
  (qo_bank, mlp_down_bank) receive non-zero gradients in hybrid model
- test_optimizer_step_updates_params: verifies 5 SGD steps update all
  key weight matrices (Mamba projections, banks, embedding) without NaN/Inf
- 63 tests passing

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
…e remaining CPU tests

New tests (10 total):
- 2B.3.3: test_classify_param_all_keys_covered, test_optimizer_param_groups_no_duplicates
- 2B.3.4: test_ema_state_includes_mamba, test_swa_state_includes_mamba
- 2B.5.1: test_attention_only_matches_baseline
- 2B.5.2: test_attention_layers_unchanged_in_hybrid
- 2B.5.3: test_shared_components_unaffected
- 2B.7.1: test_mini_e2e_cpu (init→train→quantize→dequant→rebank→eval pipeline)
- 4.1.2: test_hessian_collection_mamba_keys, test_hessian_collection_functional

73 tests passing (34 MambaBlock + 39 Hybrid GPT)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
…line validation

Pre-GPU validation covering:
- 18-layer hybrid model instantiation (production config)
- Forward/backward pass with gradient flow verification
- Optimizer setup (Muon + Adam, Mamba param split)
- Multi-step training convergence on fixed data
- GPTQ quantization pipeline (unbank → int6 → dequant → rebank)
- Artifact serialization with LZMA compression
- Quantized model roundtrip (load + forward)
- All ablation configurations pre-validated
- HessianGPT integration for GPTQ calibration

Total: 104 tests passing (73 existing + 31 new)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
- run_smoke_1gpu.sh: Cheapest possible validation on 1xH100 (~$2.50)
- run_mamba_hybrid.sh: Full pipeline (phase1-5) for 8xH100 execution
  - Phase 1: Smoke test + Go/No-Go decision
  - Phase 2: Baseline training (hybrid + SOTA comparison)
  - Phase 3: Ablation experiments (architecture + hyperparams)
  - Phase 4: 3-seed final evaluation with Welch's t-test
  - Phase 5: Submission package creation

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
…eWeb data

Tests real data pipeline before GPU credits:
- FineWeb data shard loading (100M train + 62M val tokens)
- Tokenizer + BPB lookup tables
- Training on real language data (0.23 nat improvement in 15 steps)
- BPB evaluation on real validation data
- Production-size model: 27,191,065 params (19.2M Mamba + 7.9M other)
- Production artifact: 7.27 MB LZMA (fits 16MB with 8MB headroom!)
- GPTQ Hessian collection + int6 quantization end-to-end
- Autoregressive calibration data generation
- forward_logits on real tokens

Total: 115 tests passing (104 existing + 11 new)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
@johnlennyt5 johnlennyt5 marked this pull request as draft April 6, 2026 16:31
johnlennyt5 and others added 7 commits April 6, 2026 14:25
- Add legal score-first TTT (eval_val_sliding_ttt) to train_gpt.py
- Add TTT hyperparameters (TTT_ENABLED, TTT_LR, etc.) to Hyperparameters
- Hook TTT into main() after sliding window eval
- Delete run_smoke_1gpu.sh and run_mamba_hybrid.sh (not part of competition)
- Delete tests/ directory (CPU-only tests, not needed for competition)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
…l scan

The @torch.compiler.disable decorator on MambaBlock._selective_scan
breaks fullgraph=True compilation. Disable fullgraph when Mamba layers
are active.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
The restore_low_dim_params_to_fp32 function keeps conv1d bias in fp32
while weights are bf16. causal_conv1d_fn requires matching dtypes.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Mamba CUDA selective scan can return fp32 due to float() D parameter.
This propagates to attention layers where FlashAttention rejects fp32.
Cast output back to input dtype (bf16) before residual add.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
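The fix amounts to casting at the residual add, roughly (names are illustrative):

```python
import torch

def mamba_residual(block_out: torch.Tensor, residual: torch.Tensor) -> torch.Tensor:
    """Cast the (possibly fp32) SSM output back to the residual stream's dtype
    before adding, so downstream FlashAttention layers see bf16 (sketch)."""
    return residual + block_out.to(residual.dtype)
```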
torch.compile with hybrid Mamba/Attention dispatch causes:
1. Cache limit exceeded (18 layer specializations > limit of 8)
2. Stale compiled FlashAttention graphs with fp32 instead of bf16

Disable compile for training, eval_val_sliding, and final eval when
Mamba layers are active. Pure attention models still get full compile.
Optimization deferred to Epic 3.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
When Mamba layers output fp32, it propagates through F.linear via
q_w.to(x.dtype), producing fp32 q/k/v that FlashAttention rejects.
Explicitly cast to bf16 at the flash_attn call site.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
johnlennyt5 and others added 3 commits April 7, 2026 01:29
…A eval

compiled_model is only defined when torch.compile is active (no Mamba).
The diagnostic eval should use model which is always defined.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
- Re-enable torch.compile for Mamba hybrid models (fullgraph=False, cache_size=64)
- Disable gradient checkpointing by default (MAMBA_GRAD_CHECKPOINT=0)
- Fix Mamba CUDA output dtype: cast to residual dtype (bf16) at source
- Remove fragile .bfloat16() workaround from flash_attn calls
- Add startup assertions for CUDA Mamba/causal-conv1d kernels

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
…text

Mamba CUDA kernels (causal_conv1d, selective_scan) cause graph breaks
in torch.compile. After the break, inductor loses autocast dtype tracking,
so fp32 tensors reach flash_attn. The explicit .bfloat16() cast is the
correct fix — it's a safety guarantee, not a workaround.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
taka6745 pushed a commit to taka6745/parameter-golf that referenced this pull request Apr 7, 2026
…penai#1440

EngramLiteHead: learnable hash-embedding n-gram head with sigmoid gates.
Generalizes static n-gram bias (Patch 6) by adding a parallel LEARNABLE
head over hashed bigram + trigram contexts.

PR openai#1440 attributes -0.003 BPB to EngramLite alone within their stack.
~460KB params at vocab=1024 (3072 buckets x 112 dim embed + proj).

Experiments queued:
- EL0_engram_lite_alone (new technique solo)
- EL1_engram_lite_plus_static_ng (stack with Patch 6 static n-gram)
- EL2_engram_lite_seed42 (multi-seed validation)

Also queued for MTP follow-up:
- MTP1_seed42_validation, MTP1_seed999_validation (validate Patch 21 win)
- MTP3_two_heads (test 2-head MTP from DeepSeek-V3 paper)

Mamba-2 hybrid (PR openai#1382) DEFER: 1300+ lines, mamba-ssm + causal-conv1d
external deps, no GPU validation in PR.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
@MatoTeziTanka

Community Review — WIP: Mamba-Attention Hybrid — 18L (15 SSM + 3 GQA), First SSM Entry

BPB: (not parsed — see PR title) | Compliance: LOOKS CLEAN — score-first-per-chunk TTT (legal #1413 dexhunter pattern)

What I found in the code (head SHA 9b775e152041, file train_gpt.py):

The TTT path at line 1276 implements the score-first-per-chunk pattern: each chunk is scored under torch.no_grad() / inference_mode() before the base_model.train() + SGD adaptation runs on that same chunk, with an is_last_chunk guard so the final chunk gets no adaptation pass. This is the structural shape of the current leaderboard's legal frontier (PR #1413 dexhunter, the 1.0828 SP8192 + QK-Gain 5 + Legal TTT entry — verified at its head SHA against the is_last_chunk + torch.no_grad() score-first accumulator pattern).

Per Issue #402 and Issue #677, TTT is legal when each token is scored before the adapter updates on it, and that's what the code does here — chunk ci is scored under weights adapted only on chunks 0..ci-1. No prequant_ttt_adapt_adamw(val_tokens, ...) multi-epoch fine-tune, no scored-region SLOT, no target-in-key n-gram cache.
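The score-first-per-chunk shape described above can be sketched as follows; function and variable names are illustrative, not lifted from train_gpt.py:

```python
import torch

def eval_val_sliding_ttt_sketch(model, chunks, opt_fn, loss_fn):
    """Score-first-per-chunk TTT (sketch of the legal pattern): each chunk is
    scored under no_grad BEFORE the model adapts on it, so chunk i is always
    evaluated under weights trained only on chunks 0..i-1."""
    total = 0.0
    opt = opt_fn(model)
    for i, chunk in enumerate(chunks):
        with torch.no_grad():               # score first
            total += loss_fn(model, chunk).item()
        if i < len(chunks) - 1:             # is_last_chunk guard: no adaptation
            model.train()                   # after the final chunk is scored
            opt.zero_grad()
            loss_fn(model, chunk).backward()
            opt.step()
            model.eval()
    return total / len(chunks)
```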

CPU smoke test (CT2038 proteus-engine, 2026-04-11): import OK in 1.27s, dim=512, layers=11, vocab=1024, code=128982 B, SMOKE_TEST_PASS

Verdict: LOOKS CLEAN.

Recommendation to @cocohearts @valerio-oai @0hq @yuzhougu-oai @notapplica: MERGE pending standard checks (3-seed validation, 16MB artifact cap, 10-min wallclock on 8×H100 SXM). The compliance picture matches the legal reference frontier and no flags were raised by the classification pass.

Auto-classification caveat: this review was drafted by the AST-based classifier against a template derived from manually-reviewed cluster PRs (#1420, #1450, #1487, #1541, #1529, #1533, #1518). If I've misread a subtlety in your eval path — e.g., multi-epoch TTT that I mistook for single-pass, or a target-in-key lookup I missed in a helper function — please flag it and I'll re-run the audit manually.


Reviewed by @MatoTeziTanka (The Agora). Classification via deterministic AST-based classify_prs.py (pattern bank derived from ~65 manually-reviewed PRs earlier in the 2026-04-11 sweep). This review was auto-drafted from a template and spot-checked before posting — if the template misread your code, please call it out so I can iterate the classifier.
