12L QAT Int4-MLP + Int6-Attn #910

Open
Meirzhan05 wants to merge 2 commits into openai:main from Meirzhan05:main

Conversation

@Meirzhan05

12L QAT Int4-MLP + Int6-Attn

Summary

  • Adds mixed-precision Quantization-Aware Training (int4 for MLP, int6 for attention) via Straight-Through Estimator from the start of training
  • Uses the ~3MB byte savings from int4 MLP compression to fund a 12th transformer layer (vs 11 in SOTA)
  • Halves the sliding window eval stride from 64 → 32 for better per-token context
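The stride change can be illustrated with a small helper. This is a sketch; `sliding_window_spans` and its signature are assumptions for illustration, not the repo's actual eval code:

```python
def sliding_window_spans(n_tokens: int, window: int, stride: int):
    """Spans for strided sliding-window eval (illustrative sketch).

    Each window covers tokens [lo, hi); loss is counted only on
    [score_lo, hi), so every token is scored exactly once. With
    stride 32 instead of 64, each scored token (past the first
    window) sees at least window - stride tokens of left context,
    at roughly 2x the eval compute.
    """
    spans = []
    for lo in range(0, n_tokens, stride):
        hi = min(lo + window, n_tokens)
        # First window scores everything; later windows score only
        # their final `stride` tokens.
        score_lo = lo if lo == 0 else min(lo + window - stride, hi)
        spans.append((lo, hi, score_lo))
        if hi == n_tokens:
            break
    return spans
```

Halving the stride doubles the number of windows but raises the worst-case left context per scored token from `window - 64` to `window - 32`.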

Changes

| Component | SOTA #1 (1.1194) | This PR |
|---|---|---|
| Layers | 11 | 12 |
| MLP precision | int6 | int4 (QAT) |
| Attn precision | int6 | int6 (QAT) |
| Eval stride | 64 | 32 |
| QAT | Late (last ~10%) | Full (from step 0) |

Implementation

QAT is applied directly in `MLP.forward` and `CausalSelfAttention.forward` on the banked weight tensors via `_fake_quantize_ste` (row-wise scale, STE gradient). Clip ranges: `mlp_clip=7` (int4), `attn_clip=31` (int6). Post-training quantization uses a GPTQ-lite clip search with the same ranges.

Expected Results

| Change | Est. bpb delta |
|---|---|
| 12th layer | -0.002 to -0.003 |
| QAT reduced quant penalty | -0.001 to -0.002 |
| Stride 32 eval | -0.001 |
| Target | ~1.114–1.117 |

Test Plan

  • Verify artifact size < 16MB after LZMA compression
  • Verify training completes in < 600s on 8xH100
  • Confirm val_bpb improvement over 1.1194 baseline
  • Run decompression + inference sanity check
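The artifact-size check from the test plan can be sketched with the standard-library `lzma` module; the file path and function name here are placeholders, not the submission's actual tooling:

```python
import lzma

def compressed_size_ok(path: str, cap_bytes: int = 16 * 2**20) -> bool:
    """Check the LZMA-compressed artifact fits under the 16MB cap.

    Sketch only: `path` is a placeholder for the submission artifact.
    """
    with open(path, "rb") as f:
        compressed = lzma.compress(f.read(), preset=9)
    return len(compressed) < cap_bytes
```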

@MatoTeziTanka

Community Review: 12L QAT Int4-MLP + Int6-Attn

BPB: (not parsed — see PR title) | Compliance: LOOKS CLEAN — score-first-per-chunk TTT (legal #1416/#1423 pattern)

What I found in the code (head SHA 8dd90897ab7b, file records/track_10min_16mb/2026-03-25_QAT_Int4MLP_12L/train_gpt.py):

The TTT path at line 1098 implements the score-first-per-chunk pattern: each chunk is scored under torch.no_grad() / inference_mode() before the base_model.train() + SGD adaptation runs on that same chunk, with an is_last_chunk guard so the final chunk gets no adaptation pass. This is the structural shape the legal frontier uses (PRs #1416 erichroepke, #1423 aryanbhosale).

Per Issue #402 and Issue #677, TTT is legal when each token is scored before the adapter updates on it, and that's what the code does here — chunk ci is scored under weights adapted only on chunks 0..ci-1. No prequant_ttt_adapt_adamw(val_tokens, ...) multi-epoch fine-tune, no scored-region SLOT, no target-in-key n-gram cache.
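The score-first-per-chunk loop described above can be sketched as follows. All names (`score_first_ttt`, the model's `(x, y) -> loss` interface) are illustrative assumptions, not the PR's actual API:

```python
import torch

def score_first_ttt(model, chunks, opt):
    """Score-first-per-chunk test-time training (illustrative sketch).

    Chunk ci is scored under weights adapted only on chunks 0..ci-1,
    then (except for the last chunk) the model takes one adaptation
    step on that same chunk.
    """
    total_loss, total_tokens = 0.0, 0
    for ci, (x, y) in enumerate(chunks):
        model.eval()
        with torch.no_grad():
            loss = model(x, y)            # score BEFORE adapting
        total_loss += loss.item() * y.numel()
        total_tokens += y.numel()
        if ci < len(chunks) - 1:          # is_last_chunk guard
            model.train()
            opt.zero_grad()
            model(x, y).backward()        # adapt on the scored chunk
            opt.step()
    return total_loss / total_tokens
```

The `is_last_chunk` guard matters only for wallclock, not legality: adapting after the final scored chunk would update weights that never score anything.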

CPU smoke test (CT2038 proteus-engine, 2026-04-11): import OK in 0.04s, dim=512, layers=12, vocab=1024, code=91173 B, SMOKE_TEST_PASS

Verdict: LOOKS CLEAN.

Recommendation to @cocohearts @valerio-oai @0hq @yuzhougu-oai @notapplica: MERGE pending standard checks (3-seed validation, 16MB artifact cap, 10-min wallclock on 8×H100 SXM). The compliance picture matches the legal reference frontier and no flags were raised by the classification pass.

Auto-classification caveat: this review was drafted by the AST-based classifier against a template derived from manually-reviewed cluster PRs (#1420, #1450, #1487, #1541, #1529, #1533, #1518). If I've misread a subtlety in your eval path — e.g., multi-epoch TTT that I mistook for single-pass, or a target-in-key lookup I missed in a helper function — please flag it and I'll re-run the audit manually.


Reviewed by @MatoTeziTanka (The Agora). Classification via deterministic AST-based classify_prs.py (pattern bank derived from ~65 manually-reviewed PRs earlier in the 2026-04-11 sweep). This review was auto-drafted from a template and spot-checked before posting; if the template misread your code, please call it out so I can iterate the classifier.

