
perf(glm52): MLA decode arena + CUDA graph capture on top of #535 (−36%/layer) + kernel bench#533

Open
n-WN wants to merge 8 commits into
openinfer-project:mainfrom
n-WN:feat/glm52-kernel-bench


@n-WN n-WN commented Jul 3, 2026

Contributor

Rebased onto #535 and re-measured; the scope narrowed accordingly. #535 landed the FlashMLA tile-schedule hoist for every path (this PR's step 0b, implemented there as Glm52MlaSchedMetadata). The branch now reuses that type (the scratch owns a Glm52MlaSchedMetadata, with no duplicate plan storage), and its remaining contribution is the two things #535 didn't cover: the MLA-layer allocation arena (step 0a) and CUDA Graph capture (step 0c).

Measured on top of #535 (glm52_kernel_bench, bs=1, iters=128, ctx 2048, H100)

| stage | gpu / wall | cumulative |
| --- | --- | --- |
| as-is forward (already includes #535's hoisted schedule) | 267.6 / 278.0 µs | |
| 0a MLA-layer allocation arena (Glm52MlaDecodeScratch + forward_into) | 194.7 / 205.1 µs | −73 µs |
| 0c CUDA Graph capture (CudaGraphState::run_or_capture) | 167.9 / 178.3 µs | −100 µs (−36%) |
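The cumulative column can be re-derived from the stage timings. The sketch below is plain arithmetic on the numbers in the table; the helper name `percent_change` is made up for illustration. On these figures the −36% appears to correspond to the wall-clock column (the gpu column works out to roughly −37%), and −100 µs is 99.7 µs rounded.

```rust
/// Relative change from `before_us` to `after_us`, in percent.
/// Negative values mean the stage got faster.
fn percent_change(before_us: f64, after_us: f64) -> f64 {
    (after_us - before_us) / before_us * 100.0
}

// (gpu, wall) microseconds per stage, copied from the table above.
const AS_IS: (f64, f64) = (267.6, 278.0);
const ARENA: (f64, f64) = (194.7, 205.1);
const GRAPH: (f64, f64) = (167.9, 178.3);

/// Cumulative saving of a stage against the as-is forward, as
/// (gpu microseconds saved, wall microseconds saved).
fn saved_us(stage: (f64, f64)) -> (f64, f64) {
    (AS_IS.0 - stage.0, AS_IS.1 - stage.1)
}
```

Calling `saved_us(ARENA)` gives about 72.9 µs on both columns, and `percent_change(AS_IS.1, GRAPH.1)` gives about −35.9.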

Correctness: verify_scratch_parity asserts the scratch forward is bitwise identical to the bring-up forward before any timing; the (--ignored, H200) MLA oracle gate also replays every position through the scratch forward against the real checkpoint.
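As a rough illustration of what "bitwise identical" means here, the sketch below compares two host-side `f32` buffers by bit pattern rather than by value. It is not the body of verify_scratch_parity, which this page does not show; the function name and the assumption that outputs are copied back as `f32` are both hypothetical.

```rust
/// Returns the index of the first element whose bit pattern differs, or
/// `None` when the two buffers are bitwise identical.
/// A length mismatch is reported at the end of the shorter buffer.
fn first_bitwise_mismatch(a: &[f32], b: &[f32]) -> Option<usize> {
    if a.len() != b.len() {
        return Some(a.len().min(b.len()));
    }
    // Compare raw bits: stricter than `==`, which treats 0.0 and -0.0 as
    // equal and treats NaN as unequal to itself.
    a.iter()
        .zip(b)
        .position(|(x, y)| x.to_bits() != y.to_bits())
}
```

A bit-level comparison is the stricter choice for a parity gate: a tolerance-based check could hide a reordering of floating-point reductions, while this one fails on any difference at all.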

Measured non-result, recorded in the doc

num_sm_parts tuning: the device default (132) over-splits bs=1 — the isolated flashmla stage is 1.68× faster at 16 splits (output bitwise-identical) — but it does not carry through the schedule-hoisted forward (the isolated win was mostly the metadata kernel that is now computed once). This also deprioritizes a ThunderMLA-style partial+combine fusion: the o_accum round-trip it removes is already off the critical path. The sweep (--sm-parts) and a bitwise cross-split diff stay in the bench so the conclusion is re-checkable per hardware.

What lands

All numbers re-measured on this exact branch (c5173c2, rebased onto #535). Pre-#535 history: 288 → 166 µs (−42%) with the schedule hoist still in-branch; #535 absorbed that step (as-is dropped 288 → 268), hence today's honest −36%.

🤖 Generated with Claude Code

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment


💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 63971ca5f8


Comment thread openinfer-glm52/src/mla_decode.rs Outdated
Comment on lines +529 to +532
    ensure!(
        scratch.sched.len() >= contract.tile_scheduler_metadata_len()
            && scratch.latent.len() >= contract.latent_len(),
        "GLM5.2 MLA scratch was built for a different FlashMLA contract"


P2: Reject scratch reuse with a different SM split count

When a scratch arena built for one num_sm_parts is later used with a smaller or tuned contract, this length-only guard can still pass because the buffers are large enough, even though sched/splits were generated in Glm52MlaDecodeScratch::new for the old batch_size + num_sm_parts. The subsequent FlashMLA launch then interprets stale schedule metadata under the new contract, which can corrupt the latent attention output instead of failing. Either store the scratch contract and require the schedule-driving fields to match exactly, or rebuild the metadata when they change.
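The difference between the two guards can be shown with a toy model. Everything below is hypothetical: `Contract`, its fields, and the sizing rule in `sched_len` are stand-ins chosen for illustration, not the real FlashMLA contract type.

```rust
/// Hypothetical stand-in for the FlashMLA contract; field names are assumed.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
struct Contract {
    batch_size: usize,
    num_sm_parts: usize,
}

impl Contract {
    /// Placeholder sizing rule: a fixed number of metadata words per SM part.
    fn sched_len(&self) -> usize {
        self.num_sm_parts * 8
    }
}

/// Length-only guard: passes whenever the old buffers are large enough.
fn length_only_guard(built_for: Contract, requested: Contract) -> bool {
    built_for.sched_len() >= requested.sched_len()
}

/// Exact guard: the schedule-driving fields must match field for field.
fn exact_guard(built_for: Contract, requested: Contract) -> Result<(), String> {
    if built_for == requested {
        Ok(())
    } else {
        Err(format!(
            "scratch built for {built_for:?}, called with {requested:?}"
        ))
    }
}
```

With a scratch built for 132 parts and a call requesting 16, `length_only_guard` returns `true` while `exact_guard` returns an error, which is the gap the review describes.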

Contributor Author


Right — with the tile schedule hoisted into the scratch, sched/splits are only meaningful under the exact batch_size + num_sm_parts they were generated with, and the length-only >= guard would wave through a smaller mismatched contract straight into corrupted latent output.

Fixed in 75e8330, one level stronger than an equality check: the contract now lives in Glm52MlaDecodeScratch and glm52_mla_decode_forward_into takes no contract parameter — a scratch/contract mismatch is unrepresentable at the call, rather than caught at runtime. (Same commit also hardens the bench around --sm-parts from an adversarial pass: default_num_sm_parts() now reports the device query rather than the override, so the sweep's * label and the cross-split diff baseline stay honest, and graph capture warms the kernel sequence un-captured first.)
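A minimal sketch of the "unrepresentable mismatch" idea, assuming a much simplified scratch: the contract is stored at construction and the forward function has no contract parameter, so there is nothing for a caller to get wrong. `ScratchContract`, `ToyScratch`, `toy_forward_into`, and the placeholder math are all invented for illustration and do not mirror the real Glm52MlaDecodeScratch layout.

```rust
/// Hypothetical contract; only the schedule-driving fields are modelled.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
struct ScratchContract {
    batch_size: usize,
    num_sm_parts: usize,
}

/// Toy scratch that owns the contract its schedule was generated under.
struct ToyScratch {
    contract: ScratchContract,
    sched: Vec<u32>,
    latent: Vec<f32>,
}

impl ToyScratch {
    fn new(contract: ScratchContract, latent_dim: usize) -> Self {
        // The schedule is derived from the contract exactly once, here.
        let sched = (0..contract.num_sm_parts as u32).collect();
        let latent = vec![0.0; contract.batch_size * latent_dim];
        Self { contract, sched, latent }
    }

    /// Read-only access; callers cannot swap the contract after construction.
    fn contract(&self) -> ScratchContract {
        self.contract
    }
}

/// No contract parameter: the only contract available is the scratch's own.
fn toy_forward_into<'a>(scratch: &'a mut ToyScratch, input: &[f32]) -> &'a [f32] {
    assert_eq!(input.len(), scratch.latent.len(), "input must match the scratch shape");
    let parts = scratch.sched.len() as f32;
    for (out, x) in scratch.latent.iter_mut().zip(input) {
        *out = *x * parts; // placeholder math standing in for the real kernels
    }
    &scratch.latent
}
```

Compared with a runtime equality check, this moves the invariant into the function signature: changing num_sm_parts requires building a new scratch, which also rebuilds the schedule.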

Re-verified on H100 after the change: clean rebuild, parity gate passes, graph forward 166.9 µs (−41% vs as-is), and the sweep plus the p16-vs-default bitwise diff (0.0) are unchanged.

n-WN added 7 commits July 4, 2026 00:23
Synthesis of the no-bubbles Llama-1B megakernel, ThunderMLA, Mirage MPK,
and the TP-Llama-70B kernel, read against openinfer's model lines. Key
transfers: CUDA Graphs only eliminate launch cost and leave a measured
~50->78% HBM-bandwidth gap (bubbles + producer/consumer locality); the
value ladder for GLM5.2 is arena+graph first (dwarfs fusion), then a
ThunderMLA-style partial+reduction fusion ported onto our sparse
FlashMLA decode (20-35% precedent at exactly our work size), then
per-layer glue kernels on the 7-instruction template. Whole-model
fusion is explicitly not worth its maintenance cost (the authors say
so). Next step gated on glm52_kernel_bench baseline numbers.
…ist + CUDA graph (-42%/layer)

Measured on H100 (bs=1, glm52_kernel_bench, parity-verified vs the
alloc-heavy forward): as-is 288us -> 0a arena 218 -> 0b tile-schedule
hoist 190 -> 0c graph 166us; re-verified natively on a cuda-compat
575.57.08 host at 158us (-44%).

0a: Glm52MlaDecodeScratch pre-allocates all ~20 intermediates (synchronous
cudaMalloc serializes the decode stream, so this drops GPU time not just host).
0b: the FlashMLA tile schedule is data-independent (batch_size + num_sm_parts
only) — computed once in scratch::new, dropped from the per-token path.
0c: measure_forward_graph captures the pure-kernel forward via
CudaGraphState::run_or_capture; replay removes inter-kernel GPU bubbles.
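Steps 0a and 0b can be pictured with a CPU-only toy: intermediates are allocated once in a scratch, and a schedule that depends only on shape is computed once at construction. This is a sketch under stated assumptions, not the real implementation. `ArenaScratch`, both forward functions, the chunking rule, and the arithmetic are hypothetical, and host `Vec` allocation stands in for cudaMalloc.

```rust
/// Toy scratch: buffers and a shape-only schedule, both built once.
struct ArenaScratch {
    /// (start, end) element ranges, one per part; depends only on shape.
    sched: Vec<(usize, usize)>,
    tmp: Vec<f32>,
    out: Vec<f32>,
}

impl ArenaScratch {
    fn new(len: usize, num_parts: usize) -> Self {
        // Ceiling division, clamped so an empty input cannot yield step 0.
        let chunk = ((len + num_parts - 1) / num_parts).max(1);
        let sched = (0..len)
            .step_by(chunk)
            .map(|start| (start, (start + chunk).min(len)))
            .collect();
        Self { sched, tmp: vec![0.0; len], out: vec![0.0; len] }
    }
}

/// Baseline: allocates every intermediate on each call.
fn alloc_heavy_forward(input: &[f32]) -> Vec<f32> {
    let tmp: Vec<f32> = input.iter().map(|x| x * 2.0).collect();
    tmp.iter().map(|t| t + 1.0).collect()
}

/// Arena version: reuses the scratch buffers and the precomputed schedule.
fn arena_forward_into(scratch: &mut ArenaScratch, input: &[f32]) {
    assert_eq!(input.len(), scratch.out.len());
    for &(start, end) in &scratch.sched {
        for i in start..end {
            scratch.tmp[i] = input[i] * 2.0;
            scratch.out[i] = scratch.tmp[i] + 1.0;
        }
    }
}
```

Repeated calls to `arena_forward_into` leave the buffer addresses unchanged and produce the same values as `alloc_heavy_forward`, which is the property a parity check against the alloc-heavy forward relies on.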

Records the measured non-result: num_sm_parts tuning (isolated flashmla
1.68x at 16 vs default 132, output bitwise-identical) does NOT carry through
the graphed forward — the isolated win was the metadata kernel 0b already
hoists, and 16 splits underutilize the GPU. Fixes the kernel_bench proj
helper (fn, not closure).
@n-WN n-WN force-pushed the feat/glm52-kernel-bench branch 2 times, most recently from bdfba03 to 9e80cfe Compare July 3, 2026 16:53
Codex flagged that forward_into's length-only guard lets a scratch built
for one num_sm_parts run under a different contract — the pre-computed
tile schedule is only meaningful under the exact batch_size+num_sm_parts
it was generated with, so a mismatch could corrupt the latent output
instead of failing. The contract now lives in Glm52MlaDecodeScratch and
forward_into takes no contract parameter: a mismatch is unrepresentable.

Also from the adversarial pass: default_num_sm_parts() reports the device
query (not a --sm-parts override), so the sweep's '*' label and the
cross-split diff baseline stay honest under overrides; graph capture warms
the kernel sequence un-captured first (capture aborts on lazy first-touch
allocation); the sweep dedups when the default collides with a fixed step.
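The sweep de-duplication and the '*' label could look roughly like the sketch below. The function names and the fixed step values used in the example are assumptions; the bench's actual step list is not shown on this page.

```rust
/// Builds the sorted list of split counts to sweep: the fixed steps plus
/// the device default, with duplicates removed so a default that collides
/// with a fixed step is measured once.
fn sweep_candidates(fixed_steps: &[u32], device_default: u32) -> Vec<u32> {
    let mut parts = fixed_steps.to_vec();
    parts.push(device_default);
    parts.sort_unstable();
    parts.dedup(); // removes adjacent duplicates, hence the sort above
    parts
}

/// Marks the device default with '*'. The default must come from the
/// device query, not from a user override, for the label to stay accurate.
fn sweep_label(parts: u32, device_default: u32) -> String {
    if parts == device_default {
        format!("{parts}*")
    } else {
        parts.to_string()
    }
}
```

For example, `sweep_candidates(&[8, 16, 132], 132)` yields `[8, 16, 132]` rather than listing 132 twice, and `sweep_label(132, 132)` yields `"132*"`.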
@n-WN n-WN force-pushed the feat/glm52-kernel-bench branch from 9e80cfe to c5173c2 Compare July 3, 2026 16:57
@n-WN n-WN changed the title perf(glm52): single-layer decode kernel bench + step-0 optimizations (arena, schedule hoist, CUDA graph — −42%/layer) perf(glm52): MLA decode arena + CUDA graph capture on top of #535 (−36%/layer) + kernel bench Jul 3, 2026