Non-record: Polar STE QAT for structural weights#1154

Open
LucasErcolano wants to merge 8 commits into openai:main from LucasErcolano:codex/polar-ste-weights
Conversation

@LucasErcolano commented Mar 30, 2026

Summary

Stacked on top of #1149 and the Triton backend follow-up PR.

This PR extends the same non-record submission with Polar STE QAT for structural weights and a matching polar export path:

  • QAT_SCHEME=polar for CastedLinear weight fake-quant during training
  • WEIGHT_QUANT_SCHEME=polar for large 2D structural tensors at export time
  • polar row encode/decode helpers shared by QAT and serialization
  • roundtrip dequantization wired into the existing final validation path
  • a fresh-process artifact-load harness for checking autonomous reconstruction from polar+zlib
  • a distributed-safe final KV-eval path so WORLD_SIZE>1 runs can train under DDP without deadlocking at final eval
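The PR page doesn't include the polar encode/decode implementation itself, so the following is only a minimal sketch of one plausible reading of "polar" row quantization: a sign/magnitude decomposition with one sign per weight plus a single per-row scale. All names here are hypothetical and the real scheme in the PR may differ.

```python
# Hypothetical sketch of polar row encode/decode -- NOT the PR's actual code.
# Assumption: "polar" means sign bits plus a per-row scale (mean |w|).

def polar_encode_row(row):
    """Encode a row as (signs, scale); scale is the mean absolute value."""
    scale = sum(abs(w) for w in row) / len(row)
    signs = [1 if w >= 0 else -1 for w in row]
    return signs, scale

def polar_decode_row(signs, scale):
    """Reconstruct the row as sign * scale."""
    return [s * scale for s in signs]

def roundtrip_mse(row):
    """Encode/decode roundtrip error, as in the PR's sanity check."""
    signs, scale = polar_encode_row(row)
    recon = polar_decode_row(signs, scale)
    return sum((a - b) ** 2 for a, b in zip(row, recon)) / len(row)
```

Under a scheme like this, STE-based QAT would run the forward pass on the decoded weights while letting gradients flow through to the full-precision weights unchanged (in PyTorch terms, something like `w + (decoded - w).detach()`).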

Local Validation

Short RTX 3090 smoke training (ITERATIONS=8, TRAIN_SEQ_LEN=256, QAT_SCHEME=polar, WEIGHT_QUANT_SCHEME=polar):

  • train loss continued to decrease after an early spike: 17.3560 -> 10.4964 over steps 2 -> 8
  • final teacher-forced val_bpb: 5.5849
  • final autoregressive qjl_triton smoke eval: 70.65 tok/s on 256 validation tokens
  • serialized polar+zlib artifact size: 14,510,929 bytes

Extended RTX 3090 local stress test (ITERATIONS=640, TRAIN_SEQ_LEN=256, VAL_MAX_TOKENS=32768):

  • teacher-forced val_bpb: 4.0166 -> 2.3356 -> 2.1329 -> 2.0218 -> 1.9775 -> 1.9278 -> 1.8482 -> 1.8164 -> 1.8250 -> 1.7893 -> 1.7757
  • crossed below 2.0 at step 256
  • no late loss spike or gradient collapse; a small wobble at step 512 recovered by the end of the run
  • final autoregressive qjl_triton eval on 1024 validation tokens: val_bpb=2.3861, 73.81 tok/s
  • serialized polar+zlib artifact size after the long run: 14,782,032 bytes

Artifact isolation / autonomy check (artifact_isolation_check.py, fresh Python process):

  • loaded only final_model.polar.ptz and generated a continuation from the prompt "La cuantizacion polar"
  • sample continuation: "s, and the polars are the best p"
  • profiled 2048 autoregressive tokens with qjl_triton
  • allocator stabilized after filling the 256-token context window: steady_state_allocated_growth=9216 bytes, peak_memory_allocated=88,040,960 bytes, peak_memory_reserved=106,954,752 bytes

Hopper validation:

  • 1xH100 template smoke (MAX_WALLCLOCK_SECONDS=60) compiled and ran successfully with qjl_triton; no VRAM anomalies (peak memory allocated: 1987 MiB)
  • eager-mode Hopper decode favored native qjl over qjl_triton on this workload (115.51 tok/s vs 110.13 tok/s on a 1024-token isolated profile)
  • ENABLE_TORCH_COMPILE=1 was functional but incurred ~200s of compile overhead, so the intended record-run configuration keeps ENABLE_TORCH_COMPILE=0
  • 8xH100 DDP smoke (torchrun --nproc_per_node=8, MAX_WALLCLOCK_SECONDS=60, KV_QUANT_BACKEND=qjl) completed without deadlock; final rank-0-only autoregressive eval also completed successfully
  • 8xH100 smoke result: 319 train steps in 60.176s, teacher-forced val_bpb=1.8377, final autoregressive qjl eval val_bpb=2.1360, 92.55 tok/s

Additional sanity checks:

  • polar row roundtrip MSE on random weights: ~3.6e-4
  • KV backend selftests still pass after the weight-quantization changes
  • fixed a RoPE cache interaction where validation under torch.inference_mode() could leave cached inference tensors behind and break later training steps

Review Notes

This branch is intentionally stacked. Until lower PRs merge, the review focus is the top commit sequence:

  • ba91df0 Add Polar STE QAT and polar weight export
  • b393705 Fix RoPE eval cache and add artifact isolation harness
  • ed8f388 Make final KV eval safe under distributed training

@LucasErcolano (Author)

Extended local stress test on 1x RTX 3090 completed.

  • ITERATIONS=640, TRAIN_SEQ_LEN=256, VAL_MAX_TOKENS=32768
  • val_bpb crossed below 2.0 at step 256 and finished at 1.7757
  • no late loss spike or gradient collapse; a small wobble at step 512 recovered by the end of the run
  • final autoregressive qjl_triton eval on 1024 validation tokens: val_bpb=2.3861, 73.81 tok/s

@LucasErcolano (Author)

Full 8xH100 Hopper Validation (10-minute wallclock)

We executed a full-length run on the official target hardware to validate the distributed scaling and wallclock budget of this Polar + QJL stack.

  • Hardware: 8x H100 80GB (WORLD_SIZE=8)
  • Config: KV_QUANT_BACKEND=qjl, ENABLE_TORCH_COMPILE=0
    Eager mode was strictly better than the current Triton decode path for batch=1 autoregressive decode on Hopper, and it also avoids the ~200s compile overhead.
  • Time: 592,209 ms
    This includes the wallclock budget fix now pushed in 8979432, which subtracts pre-training setup overhead before entering the main training loop.
  • Steps: 3382
  • VRAM: 1933 MiB peak allocated, 2080 MiB peak reserved
  • Metrics: teacher-forced val_bpb=1.4594 -> autoregressive val_bpb=2.12830032
  • Throughput: 93.51 tok/s in final autoregressive qjl eval
  • Artifact: 14,751,006 bytes (polar+zlib)
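The wallclock-budget fix mentioned above (commit 8979432) is described as subtracting pre-training setup overhead before the main loop. A minimal sketch of that accounting, with hypothetical names:

```python
# Hypothetical sketch of the wallclock-budget fix -- NOT the PR's actual code.
# Idea: measure elapsed setup time from process start and subtract it from the
# budget, so the training loop stops within MAX_WALLCLOCK_SECONDS overall.
import time

def remaining_budget(max_wallclock_seconds, process_start_time):
    """Seconds left in the budget after setup/compile overhead."""
    elapsed = time.monotonic() - process_start_time
    return max(max_wallclock_seconds - elapsed, 0.0)
```

The training loop would then check `remaining_budget(...)` against its per-step time estimate instead of the raw budget.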

Conclusion: the infrastructure, synchronization, export/reload path, and end-to-end math are all proven on Hopper. The remaining problem is model quality, not systems stability: the current gap between teacher-forced and autoregressive BPB indicates the topology still needs stronger structural regularization to handle decode-time asymmetric quantization noise.

@LucasErcolano (Author)

Follow-up local KV-cache result on top of this PR: pushed 97f028f (Add recent-key hybrid QJL eval backend).

This adds qjl_recentk, a hybrid KV evaluator that keeps a short recent suffix of K in fp16 and spills the older prefix to the existing QJL representation, while leaving V on the same grouped quantized path as before.
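The hybrid layout described above can be sketched as follows. This is only an illustration of the "exact recent suffix, quantized old prefix" idea, with a placeholder standing in for the real QJL representation; all names are hypothetical.

```python
# Hypothetical sketch of the recent-key hybrid idea -- NOT the PR's actual
# qjl_recentk code. fake_quantize() is a coarse-rounding stand-in for QJL.

def fake_quantize(key, step=0.25):
    """Placeholder lossy representation for old keys."""
    return [round(k / step) * step for k in key]

def hybrid_keys(keys, recent_exact=8):
    """Quantize all but the last `recent_exact` keys; keep the suffix exact."""
    split = max(len(keys) - recent_exact, 0)
    prefix = [fake_quantize(k) for k in keys[:split]]
    return prefix + keys[split:]
```

Attention scores would then be computed against this mixed list, so the most recent keys keep their full geometry while the long prefix stays compressed.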

Local RTX 3090 probe on the recurrent 3x3 d512 recipe (matrix_lr=scalar_lr=0.05, 90s budget, same trained weights for both eval backends):

  • qjl: val_bpb=2.34748538, 72.04 tok/s, peak_cache_bytes=387072
  • qjl_recentk with KV_RECENT_FP16_TOKENS=8: val_bpb=2.26498743, 65.93 tok/s, peak_cache_bytes=421056

Short sweep over the recent exact-key window:

  • 4: 2.26712904
  • 8: 2.26498743
  • 16: 2.27225724
  • 32: 2.29400939
  • 64: 2.30457652

So the best local point so far is a very small exact-key suffix (8 tokens), which suggests the autoregressive gap is dominated by the most recent key geometry rather than the quantized value path.

I have not re-run Hopper validation for qjl_recentk yet because the available RunPod credits were exhausted after the earlier H100/DDP validation. That remains the next systems check before treating this as a serious 8xH100 candidate:

  • 1xH100 smoke to confirm no Hopper-specific decode regression
  • 8xH100 DDP smoke to confirm the existing rank-0-only final eval path remains deadlock-free with qjl_recentk

The current code path itself does not add any new distributed collectives; it only changes rank-local KV-cache layout and score computation.


MatoTeziTanka commented Apr 11, 2026

[RETRACTED 2026-04-11] — This IMPORT_FAIL was a false positive. Root cause: sibling module exists in same records/ folder; runner sys.path bug. Your code is not broken. See correction below: #1154 (comment)


Community Review — Non-record: Polar STE QAT for structural weights

Compliance: NEEDS AUTHOR ACTION — train_gpt.py fails to import on CT2038 (Python 3.10 / torch 2.10.0+cpu)

What I found: The CPU smoke test on CT2038 (proteus-engine, 128 GB RAM, Triton 3.6.0, flash_attn stub, cutlass_evt_fusion stub) failed at the import step with:

ModuleNotFoundError: No module named 'triton_kv_ops'


Recommendation: Could you run python3 -c "import py_compile; py_compile.compile('train_gpt.py')" on your records-folder train_gpt.py under Python 3.10 specifically? The eval image is Python 3.10 per Issue #17 / the README, so any parse error on 3.10 blocks the submission at import time before any of the scored-eval logic runs.

Once the parse/import issue is fixed, I'll re-run the compliance audit through the normal pipeline. No other flags identified yet because the audit halts at the import step.


Reviewed by @MatoTeziTanka / The Agora. CPU smoke test (CT2038 proteus-engine, 2026-04-11): IMPORT_FAIL — ModuleNotFoundError: No module named 'triton_kv_ops'. Classification via classify_prs.py AST-based classifier; full compliance audit deferred until the import issue is resolved. Auto-drafted from a template and spot-checked before posting.

@MatoTeziTanka

Retraction — this IMPORT_FAIL was a bug in my smoke runner

Sorry @LucasErcolano, this one's on me. I re-audited the IMPORT_FAIL I posted above and it was a false positive — the fault is in how my CPU smoke runner set up sys.path, not in your code.

What happened:

The runner imported your records/track_non_record_16mb/2026-03-30_KVCache_QJL_Polar_Turbo_1x3090/train_gpt.py without the script's own folder on sys.path, so when your file did from triton_kv_ops import ... it couldn't resolve the sibling triton_kv_ops.py that lives in the same 2026-03-30_KVCache_QJL_Polar_Turbo_1x3090/ directory. The error I reported — ModuleNotFoundError: No module named 'triton_kv_ops' — looked like a missing file, but I re-checked the head SHA 97f028f and records/track_non_record_16mb/2026-03-30_KVCache_QJL_Polar_Turbo_1x3090/triton_kv_ops.py is right there, committed to the PR, next to train_gpt.py.

Verified at head 97f028f:

records/track_non_record_16mb/2026-03-30_KVCache_QJL_Polar_Turbo_1x3090/triton_kv_ops.py   ← sibling module, exists
records/track_non_record_16mb/2026-03-30_KVCache_QJL_Polar_Turbo_1x3090/train_gpt.py   ← imports it

On the real eval image (Python 3.10, records/*/ as the working dir), this import resolves correctly because the records folder ends up on sys.path via the standard cwd-driven import or via the eval harness's per-record entry point.
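For reference, the runner-side fix amounts to putting the record's own folder on sys.path before loading the script. A minimal sketch (paths and names illustrative, not the actual smoke-runner code):

```python
# Hypothetical sketch of the smoke-runner fix -- NOT the actual runner code.
# Insert the script's directory on sys.path so sibling modules like
# triton_kv_ops.py resolve, then load the script as a module.
import importlib.util
import sys
from pathlib import Path

def import_record_script(script_path):
    script_path = Path(script_path)
    sys.path.insert(0, str(script_path.parent))  # sibling imports now resolve
    try:
        spec = importlib.util.spec_from_file_location("train_gpt", script_path)
        module = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(module)
        return module
    finally:
        sys.path.pop(0)  # leave sys.path as we found it
```

Running the script with the record folder as the working directory has the same effect, which is why the real eval image never hit this.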

Your PR is not broken by this error. I'm retracting the IMPORT_FAIL classification. I'll re-queue the full compliance audit (BPB check, n-gram / TTT / SLOT flags, etc.) on the current head and post findings separately.

Again — sorry for the noise. These community reviews only work if I actually read what I'm reviewing, and I didn't in this case.

@MatoTeziTanka

Community Review — Non-record: Polar STE QAT for structural weights

BPB: 2.3861 | Compliance: LOOKS CLEAN — pure-neural submission, no TTT/SLOT/n-gram-cache

What I found in the code (head SHA 97f028fcc4a9, file records/track_non_record_16mb/2026-03-30_KVCache_QJL_Polar_Turbo_1x3090/train_gpt.py):

Static code review found no TTT adaptation function, no SLOT optimization loop, no n-gram-cache class, and no pre-quant val-token fine-tune. The eval path uses the standard sliding-window stride-64 pattern. The submission is a pure-neural architecture iteration on the standard SP1024/SP4096/SP8192 baseline.

CPU smoke test (CT2038 proteus-engine, 2026-04-11): import OK in 0.35s, dim=512, layers=9, vocab=1024, code=130060 B, SMOKE_TEST_PASS

Verdict: LOOKS CLEAN.

Recommendation to @cocohearts @valerio-oai @0hq @yuzhougu-oai @notapplica: MERGE pending the usual record-track checks (3-seed validation, under-16MB artifact cap, ≤600s train + ≤600s eval on 8×H100 SXM). No compliance flags from the classification pass — this looks like a clean pure-neural iteration on the standard baseline.

Auto-classification caveat: this review was drafted by the AST-based classifier. If there's a non-standard eval mechanism (logit postprocessing, hedge mixing, etc.) that I missed because it's factored into a helper file or a non-standard function name, please flag it and I'll re-run the audit manually.


Reviewed by @MatoTeziTanka / The Agora. CPU smoke test (CT2038 proteus-engine, 2026-04-11): import OK in 0.35s, dim=512, layers=9, vocab=1024, code=130060 B, SMOKE_TEST_PASS. Classification via deterministic AST-based classify_prs.py (pattern bank derived from ~65 manually-reviewed PRs earlier in the 2026-04-11 sweep). This review was auto-drafted from a template and spot-checked before posting — if the template misread your code, please call it out so I can iterate the classifier.
