perf(glm52): MLA decode arena + CUDA graph capture on top of #535 (−36%/layer) + kernel bench#533
perf(glm52): MLA decode arena + CUDA graph capture on top of #535 (−36%/layer) + kernel bench#533n-WN wants to merge 8 commits into
Conversation
There was a problem hiding this comment.
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: 63971ca5f8
ℹ️ About Codex in GitHub
Codex has been enabled to automatically review pull requests in this repo. Reviews are triggered when you
- Open a pull request for review
- Mark a draft as ready
- Comment "@codex review".
If Codex has suggestions, it will comment; otherwise it will react with 👍.
When you sign up for Codex through ChatGPT, Codex can also answer questions or update the PR, like "@codex address that feedback".
| ensure!( | ||
| scratch.sched.len() >= contract.tile_scheduler_metadata_len() | ||
| && scratch.latent.len() >= contract.latent_len(), | ||
| "GLM5.2 MLA scratch was built for a different FlashMLA contract" |
There was a problem hiding this comment.
Reject scratch reuse with a different SM split count
When a scratch arena built for one num_sm_parts is later used with a smaller/tuned contract, this length-only guard can still pass because the buffers are large enough, but sched/splits were generated in Glm52MlaDecodeScratch::new for the old batch_size + num_sm_parts. The subsequent FlashMLA launch then interprets stale schedule metadata under the new contract, which can corrupt the latent attention output instead of failing; store the scratch contract and require the schedule-driving fields to match exactly, or rebuild the metadata when they change.
Useful? React with 👍 / 👎.
There was a problem hiding this comment.
Right — with the tile schedule hoisted into the scratch, sched/splits are only meaningful under the exact batch_size + num_sm_parts they were generated with, and the length-only >= guard would wave through a smaller mismatched contract straight into corrupted latent output.
Fixed in 75e8330, one level stronger than an equality check: the contract now lives in Glm52MlaDecodeScratch and glm52_mla_decode_forward_into takes no contract parameter — a scratch/contract mismatch is unrepresentable at the call, rather than caught at runtime. (Same commit also hardens the bench around --sm-parts from an adversarial pass: default_num_sm_parts() now reports the device query rather than the override, so the sweep's * label and the cross-split diff baseline stay honest, and graph capture warms the kernel sequence un-captured first.)
Re-verified on H100 after the change: rebuild clean, parity gate passes, graph forward 166.9µs (−41% vs as-is), sweep + p16-vs-default bitwise diff 0.0 unchanged.
Synthesis of the no-bubbles Llama-1B megakernel, ThunderMLA, Mirage MPK, and the TP-Llama-70B kernel, read against openinfer's model lines. Key transfers: CUDA Graphs only eliminate launch cost and leave a measured ~50->78% HBM-bandwidth gap (bubbles + producer/consumer locality); the value ladder for GLM5.2 is arena+graph first (dwarfs fusion), then a ThunderMLA-style partial+reduction fusion ported onto our sparse FlashMLA decode (20-35% precedent at exactly our work size), then per-layer glue kernels on the 7-instruction template. Whole-model fusion is explicitly not worth its maintenance cost (the authors say so). Next step gated on glm52_kernel_bench baseline numbers.
…ist + CUDA graph (-42%/layer) Measured on H100 (bs=1, glm52_kernel_bench, parity-verified vs the alloc-heavy forward): as-is 288us -> 0a arena 218 -> 0b tile-schedule hoist 190 -> 0c graph 166us; re-verified natively on a cuda-compat 575.57.08 host at 158us (-44%). 0a: Glm52MlaDecodeScratch pre-allocates all ~20 intermediates (synchronous cudaMalloc serializes the decode stream, so this drops GPU time not just host). 0b: the FlashMLA tile schedule is data-independent (batch_size + num_sm_parts only) — computed once in scratch::new, dropped from the per-token path. 0c: measure_forward_graph captures the pure-kernel forward via CudaGraphState::run_or_capture; replay removes inter-kernel GPU bubbles. Records the measured non-result: num_sm_parts tuning (isolated flashmla 1.68x at 16 vs default 132, output bitwise-identical) does NOT carry through the graphed forward — the isolated win was the metadata kernel 0b already hoists, and 16 splits underutilize the GPU. Fixes the kernel_bench proj helper (fn, not closure).
bdfba03 to
9e80cfe
Compare
Codex flagged that forward_into's length-only guard lets a scratch built for one num_sm_parts run under a different contract — the pre-computed tile schedule is only meaningful under the exact batch_size+num_sm_parts it was generated with, so a mismatch could corrupt the latent output instead of failing. The contract now lives in Glm52MlaDecodeScratch and forward_into takes no contract parameter: a mismatch is unrepresentable. Also from the adversarial pass: default_num_sm_parts() reports the device query (not a --sm-parts override), so the sweep's '*' label and the cross-split diff baseline stay honest under overrides; graph capture warms the kernel sequence un-captured first (capture aborts on lazy first-touch allocation); the sweep dedups when the default collides with a fixed step.
9e80cfe to
c5173c2
Compare
Rebased onto #535 and re-measured — the scope narrowed honestly. #535 landed the FlashMLA tile-schedule hoist for every path (this PR's step 0b, implemented there as
Glm52MlaSchedMetadata); the branch now reuses that type (the scratch owns aGlm52MlaSchedMetadata, no duplicate plan storage) and its remaining contribution is the two things #535 didn't cover:Measured on top of #535 (
glm52_kernel_bench, bs=1, iters=128, ctx 2048, H100)Glm52MlaDecodeScratch+forward_into)CudaGraphState::run_or_capture)cudaMallocserializes the decode stream — eliminating them drops GPU time, not just host time.Glm52MlaDecodeScratchowns aGlm52MlaSchedMetadata(perf(glm52): 2.2x bs=1 decode — bound MoE rows by token count, persistent workspace, hoisted MLA sched metadata #535's type), so a scratch/plan split-count mismatch is unrepresentable (this was also the Codex round's finding, resolved by construction).Correctness:
verify_scratch_parityasserts the scratch forward is bitwise identical to the bring-up forward before any timing; the (--ignored, H200) MLA oracle gate also replays every position through the scratch forward against the real checkpoint.Measured non-result, recorded in the doc
num_sm_partstuning: the device default (132) over-splits bs=1 — the isolated flashmla stage is 1.68× faster at 16 splits (output bitwise-identical) — but it does not carry through the schedule-hoisted forward (the isolated win was mostly the metadata kernel that is now computed once). This also deprioritizes a ThunderMLA-style partial+combine fusion: theo_accumround-trip it removes is already off the critical path. The sweep (--sm-parts) and a bitwise cross-split diff stay in the bench so the conclusion is re-checkable per hardware.What lands
openinfer-glm52/src/kernel_bench.rs+--features glm52bin: forward/scratch/graph A/B, per-projection and per-stage isolation,num_sm_partssweep, parity gates. Synthetic constant-fp8 weights, no checkpoint needed.Glm52MlaDecodeScratch/glm52_mla_decode_forward_into: the zero-alloc MLA forward, coexisting with the bring-up forward (which stays the parity reference) — the pointer-stable substrate graph capture needs.docs/models/glm52/kernel-perf-decode.md(+ index row) and the megakernel lessons doc, with the post-perf(glm52): 2.2x bs=1 decode — bound MoE rows by token count, persistent workspace, hoisted MLA sched metadata #535 numbers and the pre-perf(glm52): 2.2x bs=1 decode — bound MoE rows by token count, persistent workspace, hoisted MLA sched metadata #535 history for provenance.All numbers re-measured on this exact branch (c5173c2, rebased onto #535). Pre-#535 history: 288 → 166 µs (−42%) with the schedule hoist still in-branch; #535 absorbed that step (as-is dropped 288 → 268), hence today's honest −36%.
🤖 Generated with Claude Code