From 291aa33f1be9f332387af1107d671e68aefb09bd Mon Sep 17 00:00:00 2001 From: n-WN <30841158+n-WN@users.noreply.github.com> Date: Thu, 2 Jul 2026 15:06:12 +0800 Subject: [PATCH 1/8] WIP glm52 kernel bench --- openinfer-glm52/Cargo.toml | 5 + openinfer-glm52/src/bin/glm52_kernel_bench.rs | 78 +++++ openinfer-glm52/src/kernel_bench.rs | 329 ++++++++++++++++++ openinfer-glm52/src/lib.rs | 2 + openinfer-glm52/src/mla_decode.rs | 16 + 5 files changed, 430 insertions(+) create mode 100644 openinfer-glm52/src/bin/glm52_kernel_bench.rs create mode 100644 openinfer-glm52/src/kernel_bench.rs diff --git a/openinfer-glm52/Cargo.toml b/openinfer-glm52/Cargo.toml index 09cdafe6..23630c33 100644 --- a/openinfer-glm52/Cargo.toml +++ b/openinfer-glm52/Cargo.toml @@ -10,6 +10,11 @@ autobins = false name = "glm52_load_weights" path = "src/bin/glm52_load_weights.rs" +[[bin]] +name = "glm52_kernel_bench" +path = "src/bin/glm52_kernel_bench.rs" +required-features = ["glm52"] + [features] default = [] glm52 = ["dep:openinfer-kernels", "openinfer-kernels/glm52"] diff --git a/openinfer-glm52/src/bin/glm52_kernel_bench.rs b/openinfer-glm52/src/bin/glm52_kernel_bench.rs new file mode 100644 index 00000000..8878b2e9 --- /dev/null +++ b/openinfer-glm52/src/bin/glm52_kernel_bench.rs @@ -0,0 +1,78 @@ +//! GLM5.2 single-layer MLA decode microbench (bs=1, synthetic weights). +//! +//! Usage: +//! cargo run --release -p openinfer-glm52 --features glm52 \ +//! --bin glm52_kernel_bench -- [--contexts 512,2048] [--iters 64] + +use anyhow::{Result, bail}; +use openinfer_glm52::kernel_bench::Glm52MlaDecodeBench; + +struct Args { + contexts: Vec, + iters: u64, +} + +fn parse_args(mut argv: impl Iterator) -> Result { + let mut args = Args { + contexts: vec![512, 2048], + iters: 64, + }; + while let Some(flag) = argv.next() { + let mut value = || { + argv.next() + .ok_or_else(|| anyhow::anyhow!("{flag} needs a value")) + }; + match flag.as_str() { + "--contexts" => { + args.contexts = value()? 
+ .split(',') + .map(|v| v.trim().parse::()) + .collect::>()?; + } + "--iters" => args.iters = value()?.parse()?, + other => bail!("unknown flag `{other}` (supported: --contexts, --iters)"), + } + } + if args.contexts.is_empty() || args.iters == 0 { + bail!("--contexts must be non-empty and --iters positive"); + } + Ok(args) +} + +fn main() -> Result<()> { + let args = parse_args(std::env::args().skip(1))?; + for &context in &args.contexts { + let mut bench = Glm52MlaDecodeBench::new(context)?; + let (gpu, wall) = bench.measure_forward(args.iters)?; + let per = |d: std::time::Duration| d.as_secs_f64() * 1.0e6 / args.iters as f64; + println!("== context {context} (iters {}) ==", args.iters); + println!( + "layer forward gpu {:>9.1}us wall {:>9.1}us host-side gap {:>9.1}us", + per(gpu), + per(wall), + per(wall) - per(gpu) + ); + for proj in ["q_a", "q_b", "kv_a", "o_proj"] { + let d = bench.measure_projection(proj, args.iters)?; + println!( + "fp8_linear {proj:<8} wall {:>9.1}us (alloc chain included)", + per(d) + ); + } + let d = bench.measure_assembly_family(args.iters)?; + println!( + "assembly family wall {:>9.1}us (assemble+quant+pack, buffers reused)", + per(d) + ); + let d = bench.measure_flashmla(args.iters)?; + println!( + "flashmla sparse wall {:>9.1}us (metadata+decode, buffers reused)", + per(d) + ); + let projected_token = per(wall) * 75.0 / 1000.0; + println!( + "-> projected 75 MoE-layer attention share: {projected_token:.2} ms/token (as-is)\n" + ); + } + Ok(()) +} diff --git a/openinfer-glm52/src/kernel_bench.rs b/openinfer-glm52/src/kernel_bench.rs new file mode 100644 index 00000000..b69d85a1 --- /dev/null +++ b/openinfer-glm52/src/kernel_bench.rs @@ -0,0 +1,329 @@ +//! Single-layer GLM5.2 MLA decode microbenchmarks (bs=1, synthetic weights). +//! +//! The load-only bring-up composed the oracle-validated ops correctness-first: +//! every intermediate is a fresh `alloc_zeros` and every stage its own launch. +//! 
Before tuning any kernel, this bench measures what that costs on one layer — +//! the whole forward, each stage in isolation, and the allocation share — so +//! the optimization order comes from numbers, not guesses. Weights are +//! synthetic (constant fp8 bytes, unit scales): fp8 GEMMs do no zero-skipping, +//! so latency matches real checkpoints without needing one on disk. +//! +//! Follows the qwen3 `kernel_bench` convention: a `pub` module the +//! feature-gated bench bin drives; model facts live in `config.rs`. + +use std::time::{Duration, Instant}; + +use anyhow::Result; +use cudarc::driver::{CudaEvent, CudaSlice, sys}; +use half::bf16; + +use openinfer_kernels::ops::{ + GLM52_FLASHMLA_SPARSE_PAGE_SIZE, GLM52_FLASHMLA_SPARSE_TOPK, Glm52FlashMlaSparseDecode, + Glm52MoeQuantShape, glm52_flashmla_sparse_decode_launch, + glm52_flashmla_sparse_decode_metadata_launch, glm52_flashmla_sparse_decode_num_sm_parts, + glm52_fp8_per_token_group_quant_bf16_launch, glm52_mla_cache_pack_launch, + glm52_mla_query_assemble_launch, +}; +use openinfer_kernels::tensor::{DeviceContext, DeviceVec}; + +use crate::fp8::{FP8_BLOCK, Glm52ProjBytes, ProjWeight, fp8_linear}; +use crate::mla_decode::{Glm52MlaLayerWeights, glm52_mla_decode_forward}; + +const HEADS: usize = 64; +const HIDDEN: usize = 6144; +const Q_LORA: usize = 2048; +const Q_HEAD: usize = 256; +const KV_LORA: usize = 512; +const KV_A_OUT: usize = 576; +const V_HEAD: usize = 256; +const KV_B_ROWS_PER_HEAD: usize = 448; +const QUERY_DIM: usize = KV_LORA + 64; +const ROPE_HALF: usize = 32; +const CACHE_BYTES_PER_TOKEN: usize = 656; + +/// Constant-fill fp8 projection bytes at `[n, k]` with unit block scales. +/// 0x38 is e4m3 1.0 — finite, no NaN patterns, and latency-equivalent to +/// checkpoint bytes for the blockscale GEMM. 
+fn synth_proj(n: usize, k: usize) -> (Vec, Vec) { + let weight = vec![0x38u8; n * k]; + let scale_elems = n.div_ceil(FP8_BLOCK) * k.div_ceil(FP8_BLOCK); + let scale: Vec = (0..scale_elems) + .flat_map(|_| 1.0f32.to_le_bytes()) + .collect(); + (weight, scale) +} + +fn bf16_ones_bytes(len: usize) -> Vec { + (0..len) + .flat_map(|_| bf16::from_f32(1.0).to_le_bytes()) + .collect() +} + +/// One synthetic MLA layer plus every forward input, device-resident. +pub struct Glm52MlaDecodeBench { + pub ctx: DeviceContext, + weights: Glm52MlaLayerWeights, + hidden: CudaSlice, + cos: CudaSlice, + sin: CudaSlice, + cache: CudaSlice, + topk: CudaSlice, + contract: Glm52FlashMlaSparseDecode, + position: usize, + start: CudaEvent, + end: CudaEvent, +} + +impl Glm52MlaDecodeBench { + /// `context_len` is the attended context; topk indices cover + /// `min(context_len, 2048)` real slots, -1-padded to the fixed 2048. + pub fn new(context_len: usize) -> Result { + anyhow::ensure!(context_len > 0, "context_len must be positive"); + let ctx = DeviceContext::new()?; + + let (qa_w, qa_s) = synth_proj(Q_LORA, HIDDEN); + let (qb_w, qb_s) = synth_proj(HEADS * Q_HEAD, Q_LORA); + let (kva_w, kva_s) = synth_proj(KV_A_OUT, HIDDEN); + let (kvb_w, kvb_s) = synth_proj(HEADS * KV_B_ROWS_PER_HEAD, KV_LORA); + let (o_w, o_s) = synth_proj(HIDDEN, HEADS * V_HEAD); + let ln = bf16_ones_bytes(Q_LORA.max(KV_LORA)); + let proj = |w: &[u8], s: &[u8], n: usize, k: usize| Glm52ProjBytes { + weight: w, + scale: s, + n, + k, + }; + let weights = Glm52MlaLayerWeights::from_host( + &ctx, + &proj(&qa_w, &qa_s, Q_LORA, HIDDEN), + &ln[..Q_LORA * 2], + &proj(&qb_w, &qb_s, HEADS * Q_HEAD, Q_LORA), + &proj(&kva_w, &kva_s, KV_A_OUT, HIDDEN), + &ln[..KV_LORA * 2], + &proj(&kvb_w, &kvb_s, HEADS * KV_B_ROWS_PER_HEAD, KV_LORA), + &proj(&o_w, &o_s, HIDDEN, HEADS * V_HEAD), + )?; + + let hidden = ctx + .stream + .clone_htod(&vec![bf16::from_f32(0.01); HIDDEN])?; + let rope: Vec = (0..ROPE_HALF) + .map(|i| bf16::from_f32(((i as 
f32) * 0.1).cos())) + .collect(); + let cos = ctx.stream.clone_htod(&rope)?; + let sin = ctx.stream.clone_htod(&rope)?; + + let position = context_len - 1; + let num_blocks = context_len.div_ceil(GLM52_FLASHMLA_SPARSE_PAGE_SIZE); + let cache = ctx + .stream + .alloc_zeros(num_blocks * GLM52_FLASHMLA_SPARSE_PAGE_SIZE * CACHE_BYTES_PER_TOKEN)?; + let real = context_len.min(GLM52_FLASHMLA_SPARSE_TOPK); + let topk_host: Vec = (0..GLM52_FLASHMLA_SPARSE_TOPK) + .map(|i| if i < real { i as i32 } else { -1 }) + .collect(); + let topk = ctx.stream.clone_htod(&topk_host)?; + let contract = Glm52FlashMlaSparseDecode { + batch_size: 1, + num_blocks, + topk: GLM52_FLASHMLA_SPARSE_TOPK, + num_sm_parts: glm52_flashmla_sparse_decode_num_sm_parts()?, + sm_scale: 1.0 / (QUERY_DIM as f32).sqrt(), + }; + contract.validate()?; + + let start = ctx + .ctx + .new_event(Some(sys::CUevent_flags::CU_EVENT_DEFAULT))?; + let end = ctx + .ctx + .new_event(Some(sys::CUevent_flags::CU_EVENT_DEFAULT))?; + let bench = Self { + ctx, + weights, + hidden, + cos, + sin, + cache, + topk, + contract, + position, + start, + end, + }; + bench.ctx.sync()?; + Ok(bench) + } + + fn forward_once(&mut self) -> Result<()> { + let _o = glm52_mla_decode_forward( + &self.ctx, + &self.weights, + &self.hidden, + &self.cos, + &self.sin, + &mut self.cache, + self.position, + &self.topk, + self.contract, + )?; + Ok(()) + } + + /// Whole-layer forward: GPU (event) time and wall time per iteration. The + /// gap between them is host-side cost — dominated by the per-call + /// `alloc_zeros` chain, which serializes against the device. 
+ pub fn measure_forward(&mut self, iters: u64) -> Result<(Duration, Duration)> { + self.forward_once()?; + self.ctx.sync()?; + let wall_start = Instant::now(); + let mut gpu_ms = 0.0f64; + for _ in 0..iters { + self.start.record(&self.ctx.stream)?; + self.forward_once()?; + self.end.record(&self.ctx.stream)?; + gpu_ms += f64::from(self.start.elapsed_ms(&self.end)?); + } + self.ctx.sync()?; + let wall = wall_start.elapsed(); + Ok((Duration::from_secs_f64(gpu_ms / 1_000.0), wall)) + } + + /// One fp8 projection in isolation (its own quant + layout + GEMM chain, + /// allocations included — exactly what the forward pays per projection). + pub fn measure_projection(&mut self, which: &str, iters: u64) -> Result { + let (w, input_len): (&ProjWeight, usize) = match which { + "q_a" => (self.weights.q_a(), HIDDEN), + "q_b" => (self.weights.q_b(), Q_LORA), + "kv_a" => (self.weights.kv_a(), HIDDEN), + "o_proj" => (self.weights.o_proj(), HEADS * V_HEAD), + other => anyhow::bail!("unknown projection `{other}`"), + }; + let input = self + .ctx + .stream + .clone_htod(&vec![bf16::from_f32(0.01); input_len])?; + let _ = fp8_linear(&self.ctx, w, &input)?; + self.ctx.sync()?; + let wall = Instant::now(); + for _ in 0..iters { + let _ = fp8_linear(&self.ctx, w, &input)?; + } + self.ctx.sync()?; + Ok(wall.elapsed()) + } + + /// The three assembly-family micro launches, measured together per + /// iteration (allocations excluded — buffers are reused here, which is + /// what a scratch-based forward would pay). 
+ pub fn measure_assembly_family(&mut self, iters: u64) -> Result { + let ql_nope = self + .ctx + .stream + .clone_htod(&vec![bf16::from_f32(0.01); HEADS * KV_LORA])?; + let q_full = self + .ctx + .stream + .clone_htod(&vec![bf16::from_f32(0.01); HEADS * Q_HEAD])?; + let kv_c = self + .ctx + .stream + .clone_htod(&vec![bf16::from_f32(0.01); KV_LORA])?; + let k_pe = self + .ctx + .stream + .clone_htod(&vec![bf16::from_f32(0.01); 64])?; + let mut query = self.ctx.stream.alloc_zeros::(HEADS * QUERY_DIM)?; + let mut ckv_fp8 = self.ctx.stream.alloc_zeros::(KV_LORA)?; + let mut ckv_scales = self.ctx.stream.alloc_zeros::(KV_LORA / FP8_BLOCK)?; + + let mut launch = |bench: &mut Self, + query: &mut CudaSlice, + ckv_fp8: &mut CudaSlice, + ckv_scales: &mut CudaSlice| + -> Result<()> { + glm52_mla_query_assemble_launch( + &bench.ctx, &ql_nope, &q_full, 192, Q_HEAD, &bench.cos, &bench.sin, query, + )?; + glm52_fp8_per_token_group_quant_bf16_launch( + &bench.ctx, + Glm52MoeQuantShape { + rows: 1, + width: KV_LORA, + group_size: FP8_BLOCK, + }, + &kv_c, + ckv_fp8, + ckv_scales, + )?; + glm52_mla_cache_pack_launch( + &bench.ctx, + ckv_fp8, + ckv_scales, + &k_pe, + &bench.cos, + &bench.sin, + &mut bench.cache, + bench.position, + )?; + Ok(()) + }; + + launch(self, &mut query, &mut ckv_fp8, &mut ckv_scales)?; + self.ctx.sync()?; + let wall = Instant::now(); + for _ in 0..iters { + launch(self, &mut query, &mut ckv_fp8, &mut ckv_scales)?; + } + self.ctx.sync()?; + Ok(wall.elapsed()) + } + + /// FlashMLA sparse decode with pre-allocated metadata/output buffers + /// (the attention core a scratch-based forward would pay). 
+ pub fn measure_flashmla(&mut self, iters: u64) -> Result { + let query = self + .ctx + .stream + .clone_htod(&vec![bf16::from_f32(0.01); HEADS * QUERY_DIM])?; + let c = self.contract; + let mut sched = self + .ctx + .stream + .alloc_zeros::(c.tile_scheduler_metadata_len())?; + let mut splits = self.ctx.stream.alloc_zeros::(c.num_splits_len())?; + let mut latent = self.ctx.stream.alloc_zeros::(c.latent_len())?; + let mut lse = self.ctx.stream.alloc_zeros::(c.lse_len())?; + let mut lse_accum = self.ctx.stream.alloc_zeros::(c.lse_accum_len())?; + let mut o_accum = self.ctx.stream.alloc_zeros::(c.o_accum_len())?; + + glm52_flashmla_sparse_decode_metadata_launch( + &self.ctx, + c.batch_size, + c.num_sm_parts, + &mut sched, + &mut splits, + )?; + glm52_flashmla_sparse_decode_launch( + &self.ctx, c, &query, &self.cache, &self.topk, &sched, &splits, &mut latent, &mut lse, + &mut lse_accum, &mut o_accum, + )?; + self.ctx.sync()?; + let wall = Instant::now(); + for _ in 0..iters { + glm52_flashmla_sparse_decode_metadata_launch( + &self.ctx, + c.batch_size, + c.num_sm_parts, + &mut sched, + &mut splits, + )?; + glm52_flashmla_sparse_decode_launch( + &self.ctx, c, &query, &self.cache, &self.topk, &sched, &splits, &mut latent, + &mut lse, &mut lse_accum, &mut o_accum, + )?; + } + self.ctx.sync()?; + Ok(wall.elapsed()) + } +} diff --git a/openinfer-glm52/src/lib.rs b/openinfer-glm52/src/lib.rs index 16d08487..cab03abe 100644 --- a/openinfer-glm52/src/lib.rs +++ b/openinfer-glm52/src/lib.rs @@ -16,6 +16,8 @@ mod dense; #[cfg(feature = "glm52")] mod fp8; #[cfg(feature = "glm52")] +pub mod kernel_bench; +#[cfg(feature = "glm52")] mod indexer; #[cfg(all(test, feature = "glm52"))] mod indexer_oracle_gate; diff --git a/openinfer-glm52/src/mla_decode.rs b/openinfer-glm52/src/mla_decode.rs index af056b64..8dd7efef 100644 --- a/openinfer-glm52/src/mla_decode.rs +++ b/openinfer-glm52/src/mla_decode.rs @@ -56,6 +56,22 @@ impl Glm52MlaLayerWeights { /// layernorm gammas, and host-dequant 
kv_b into the bf16 absorb factors /// W_UK = kv_b[:, :192, :], W_UV = kv_b[:, 192:, :]. #[allow(clippy::too_many_arguments)] + pub(crate) fn q_a(&self) -> &ProjWeight { + &self.q_a + } + + pub(crate) fn q_b(&self) -> &ProjWeight { + &self.q_b + } + + pub(crate) fn kv_a(&self) -> &ProjWeight { + &self.kv_a + } + + pub(crate) fn o_proj(&self) -> &ProjWeight { + &self.o_proj + } + pub(crate) fn from_host( ctx: &DeviceContext, q_a: &Glm52ProjBytes, From 6b308195336ad81730588d814b4b99623549b127 Mon Sep 17 00:00:00 2001 From: n-WN <30841158+n-WN@users.noreply.github.com> Date: Thu, 2 Jul 2026 22:50:26 +0800 Subject: [PATCH 2/8] =?UTF-8?q?docs(lessons):=20megakernel=20decode-latenc?= =?UTF-8?q?y=20research=20=E2=80=94=20what=20transfers=20here?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Synthesis of the no-bubbles Llama-1B megakernel, ThunderMLA, Mirage MPK, and the TP-Llama-70B kernel, read against openinfer's model lines. Key transfers: CUDA Graphs only eliminate launch cost and leave a measured ~50->78% HBM-bandwidth gap (bubbles + producer/consumer locality); the value ladder for GLM5.2 is arena+graph first (dwarfs fusion), then a ThunderMLA-style partial+reduction fusion ported onto our sparse FlashMLA decode (20-35% precedent at exactly our work size), then per-layer glue kernels on the 7-instruction template. Whole-model fusion is explicitly not worth its maintenance cost (the authors say so). Next step gated on glm52_kernel_bench baseline numbers. 
--- .../lessons/megakernels-for-decode-latency.md | 24 +++++++++++++++++++ 1 file changed, 24 insertions(+) create mode 100644 docs/lessons/megakernels-for-decode-latency.md diff --git a/docs/lessons/megakernels-for-decode-latency.md b/docs/lessons/megakernels-for-decode-latency.md new file mode 100644 index 00000000..8a2f9cc3 --- /dev/null +++ b/docs/lessons/megakernels-for-decode-latency.md @@ -0,0 +1,24 @@ +# Megakernels for decode latency — what transfers to openinfer + +**TL;DR**: Megakernels (one persistent kernel running the whole forward via an on-GPU interpreter + counter sync + smem paging) buy their wins from three separable things — launch elimination, inter-kernel pipeline bubbles, producer→consumer locality. CUDA Graphs only buy the first (~2.1→1.3µs/launch), leaving a measured ~50%→78% HBM-bandwidth gap on bs=1 decode. For openinfer the value ladder is: (0) arenas + graphs where they don't exist yet (GLM5.2 — dwarfs everything else), (1) ThunderMLA-style fusion of attention partial+reduction pairs (20–35% precedent, ~250 LoC), (2) per-layer persistent "glue" kernels on the 7-instruction template, (3) whole-model fusion — explicitly not worth its maintenance cost today. + +## The four systems and their mechanisms + +- **Llama-1B megakernel** ([no-bubbles, Hazy Research](https://hazyresearch.stanford.edu/blog/2025-05-27-no-bubbles)): whole forward = one persistent kernel; per-SM instruction sequences scheduled host-side ahead of time and reused across passes; only **7 fused instruction types** (norm+QKV+RoPE, attention, attention reduction, O-proj+residual, norm+up-gate+SiLU, down-proj+residual, LM head). Sync = global counter array, dependents spin; the down-proj splits its input into 4 counter-guarded segments so consumers start early. Shared memory carved into 13×16KiB explicitly requested/released pages — a released page hands straight to the next instruction's weight prefetch. 
Results: <1ms/token Llama-1B bs=1 on H100, 2.5× vLLM, **78% of HBM bandwidth vs ~50% for kernel-per-op engines**. Post-fusion cost breakdown (600µs on B200): 250µs activation movement+sync, 200µs norm+matvec, 40µs warp sync — after launches die, *sync and activation movement dominate*. +- **ThunderMLA** ([Hazy](https://hazyresearch.stanford.edu/blog/2025-03-04-thundermla)): fuses FlashMLA's two kernels (split-KV partials + reduction) into one persistent kernel driven by an instruction tensor; static host schedulers (heap-based, plus makespan-backwards for ~10% extra). **20–35% over FlashMLA**; gains shrink as per-launch work grows (+36% at B64/512, +7.6% at B132/4K). Public: ThunderKittens repo, `mla` branch, ~250 lines of device code. **Dense MLA only** — no paged KV, no sparse top-k, no fp8 KV. +- **Mirage MPK** ([Zhihao Jia](https://zhihaojia.medium.com/compiling-llms-into-a-megakernel-a-path-to-low-latency-inference-cf7840913c17)): compiler from a *PyTorch graph* to one megakernel; worker SMs + up to 4 scheduler SMs with event counters; 1–2µs task transitions. Static task graphs only (no MoE dynamism), tightly coupled to the Mirage stack — not embeddable from a Rust FFI engine today; useful as a design reference for the worker/scheduler split. +- **TP-Llama-70B** ([Hazy](https://hazyresearch.stanford.edu/blog/2025-09-28-tp-llama-main)): the interpreter model across 8×H100 with comm fused *inside* the kernel (async peer-memory stores, no NCCL); dynamic global work queue (+14.2% at bs 8192). Relevant later for EP/TP paths; the authors call the code unsupported and "sensitive to … being looked at the wrong way". + +## What this means per openinfer model line + +- **GLM5.2 (bring-up)**: the single-layer MLA decode currently issues ~18 launches and ~20 `cudaMalloc`s per layer per token (78 layers). Step 0 is an allocation arena + CUDA Graph capture — boring, and worth more than any fusion. 
Step 1 is the ThunderMLA pattern applied to our *sparse* FlashMLA decode: port the interpreter/instruction-tensor structure onto the vendored sparse kernel (drop-in is not possible — ThunderMLA is dense); our fixed top-k=2048 workload sits exactly in the small/medium-work regime where their fusion gains were largest. Step 2, if step 1 pays: one persistent kernel per layer fusing the glue (q/kv projections, norm, RoPE, cache-pack) on the 7-instruction template. +- **Qwen3-4B (already graphed)**: graphs killed launch overhead; the remaining megakernel upside is bubbles + locality, ceiling plausibly 1.3–1.5× if decode is bandwidth-bound (50%→78% measured elsewhere). Real but expensive; revisit only after cheaper wins (fused step-tail, sampling cost #483) are exhausted. +- **Scheduling flavor**: at bs=1, static host-side per-SM schedules (1B-style) are enough and far simpler than dynamic queues; cache the schedule like ThunderMLA does (1–2ms to build, amortized). + +## Pitfalls the authors state outright + +Instruction-set design and counter-sync debugging carry "tremendous complexity"; the TP kernel is explicitly unmaintained. None of the megakernel posts benchmark against CUDA Graphs except no-bubbles — treat vLLM/SGLang comparisons in MPK/TP posts as launch-overhead-inclusive. Any adoption here should ship as one model line's experiment with an A/B against that line's graphed baseline, never as shared-layer infrastructure first. + +## Next + +Blocked on the GLM5.2 single-layer baseline numbers (glm52_kernel_bench) to size step 0 vs step 1; then decide whether a sparse-ThunderMLA experiment is worth a design issue. 
From 14d438648ef738f6282e448ce1348d3ef2da356a Mon Sep 17 00:00:00 2001 From: n-WN <30841158+n-WN@users.noreply.github.com> Date: Thu, 2 Jul 2026 22:53:54 +0800 Subject: [PATCH 3/8] WIP glm52 scratch forward + A/B bench --- openinfer-glm52/src/bin/glm52_kernel_bench.rs | 11 + openinfer-glm52/src/fp8.rs | 92 +++++++ openinfer-glm52/src/kernel_bench.rs | 46 +++- openinfer-glm52/src/mla_decode.rs | 224 +++++++++++++++++- tools/review-watch.sh | 84 +++++++ 5 files changed, 455 insertions(+), 2 deletions(-) create mode 100755 tools/review-watch.sh diff --git a/openinfer-glm52/src/bin/glm52_kernel_bench.rs b/openinfer-glm52/src/bin/glm52_kernel_bench.rs index 8878b2e9..1d493df4 100644 --- a/openinfer-glm52/src/bin/glm52_kernel_bench.rs +++ b/openinfer-glm52/src/bin/glm52_kernel_bench.rs @@ -52,6 +52,17 @@ fn main() -> Result<()> { per(wall), per(wall) - per(gpu) ); + let (gpu_s, wall_s) = bench.measure_forward_scratch(args.iters)?; + println!( + "layer fwd scratch gpu {:>9.1}us wall {:>9.1}us host-side gap {:>9.1}us", + per(gpu_s), + per(wall_s), + per(wall_s) - per(gpu_s) + ); + println!( + "-> alloc bill (as-is wall - scratch wall): {:>9.1}us/layer", + per(wall) - per(wall_s) + ); for proj in ["q_a", "q_b", "kv_a", "o_proj"] { let d = bench.measure_projection(proj, args.iters)?; println!( diff --git a/openinfer-glm52/src/fp8.rs b/openinfer-glm52/src/fp8.rs index 12e612d9..1cfd2cbf 100644 --- a/openinfer-glm52/src/fp8.rs +++ b/openinfer-glm52/src/fp8.rs @@ -114,6 +114,98 @@ impl ProjWeight { } } +/// Reusable quant-side buffers for [`fp8_linear_into`], sized once for the +/// largest `k` any projection uses. The per-call `fp8_linear` allocates these +/// (plus its output) fresh on every projection — ~4 synchronous `cudaMalloc`s +/// per projection per token — which the scratch variant exists to eliminate. 
+pub(crate) struct Fp8LinearScratch { + a_fp8: CudaSlice, + a_scale_plain: CudaSlice, + a_scale_tma: CudaSlice, + max_k: usize, +} + +impl Fp8LinearScratch { + pub(crate) fn new(ctx: &DeviceContext, max_k: usize) -> Result { + ensure!( + max_k > 0 && max_k.is_multiple_of(FP8_BLOCK), + "GLM5.2 fp8 scratch max_k {max_k} must be a positive multiple of {FP8_BLOCK}" + ); + let layout = Glm52DeepGemmScaleLayout::f32(1, max_k / FP8_BLOCK); + Ok(Self { + a_fp8: ctx.stream.alloc_zeros::(max_k)?, + a_scale_plain: ctx.stream.alloc_zeros::(max_k / FP8_BLOCK)?, + a_scale_tma: ctx.stream.alloc_zeros::(layout.output_len()?)?, + max_k, + }) + } +} + +/// [`fp8_linear`] with every intermediate (and the output) caller-provided: +/// the same quant -> TMA-relayout -> blockscale-GEMM chain, zero allocations. +/// `out` must hold at least `w.n` elements; only the first `w.n` are written. +pub(crate) fn fp8_linear_into( + ctx: &DeviceContext, + w: &ProjWeight, + input: &CudaSlice, + scratch: &mut Fp8LinearScratch, + out: &mut CudaSlice, +) -> Result<()> { + ensure!( + input.len() >= w.k, + "GLM5.2 fp8_linear_into input {} < k {}", + input.len(), + w.k + ); + ensure!( + w.k <= scratch.max_k, + "GLM5.2 fp8_linear_into k {} > scratch max_k {}", + w.k, + scratch.max_k + ); + ensure!( + out.len() >= w.n, + "GLM5.2 fp8_linear_into out {} < n {}", + out.len(), + w.n + ); + let scale_cols = w.k / FP8_BLOCK; + glm52_fp8_per_token_group_quant_bf16_launch( + ctx, + Glm52MoeQuantShape { + rows: 1, + width: w.k, + group_size: FP8_BLOCK, + }, + input, + &mut scratch.a_fp8, + &mut scratch.a_scale_plain, + )?; + let layout = Glm52DeepGemmScaleLayout::f32(1, scale_cols); + glm52_deepgemm_mn_major_tma_aligned_f32_launch( + ctx, + layout, + &scratch.a_scale_plain, + &mut scratch.a_scale_tma, + )?; + glm52_trtllm_fp8_linear_launch( + ctx, + Glm52TrtllmFp8LinearContract { + m: 1, + n: w.n, + k: w.k, + weight_scale_rows: w.n.div_ceil(FP8_BLOCK), + weight_scale_cols: scale_cols, + activation_scale_cols: 
scale_cols, + }, + &scratch.a_fp8, + &scratch.a_scale_tma, + &w.weight, + &w.scale, + out, + ) +} + /// One fp8 projection (bs=1): quant the bf16 activation, then the prequant linear. /// Returns `[n]` bf16. pub(crate) fn fp8_linear( diff --git a/openinfer-glm52/src/kernel_bench.rs b/openinfer-glm52/src/kernel_bench.rs index b69d85a1..d492d933 100644 --- a/openinfer-glm52/src/kernel_bench.rs +++ b/openinfer-glm52/src/kernel_bench.rs @@ -27,7 +27,10 @@ use openinfer_kernels::ops::{ use openinfer_kernels::tensor::{DeviceContext, DeviceVec}; use crate::fp8::{FP8_BLOCK, Glm52ProjBytes, ProjWeight, fp8_linear}; -use crate::mla_decode::{Glm52MlaLayerWeights, glm52_mla_decode_forward}; +use crate::mla_decode::{ + Glm52MlaDecodeScratch, Glm52MlaLayerWeights, glm52_mla_decode_forward, + glm52_mla_decode_forward_into, +}; const HEADS: usize = 64; const HIDDEN: usize = 6144; @@ -189,6 +192,47 @@ impl Glm52MlaDecodeBench { Ok((Duration::from_secs_f64(gpu_ms / 1_000.0), wall)) } + /// Whole-layer forward through the zero-allocation scratch variant — + /// the as-is vs scratch delta is the per-layer cudaMalloc bill. 
+ pub fn measure_forward_scratch(&mut self, iters: u64) -> Result<(Duration, Duration)> { + let mut scratch = Glm52MlaDecodeScratch::new(&self.ctx, self.contract)?; + glm52_mla_decode_forward_into( + &self.ctx, + &self.weights, + &self.hidden, + &self.cos, + &self.sin, + &mut self.cache, + self.position, + &self.topk, + self.contract, + &mut scratch, + )?; + self.ctx.sync()?; + let wall_start = Instant::now(); + let mut gpu_ms = 0.0f64; + for _ in 0..iters { + self.start.record(&self.ctx.stream)?; + glm52_mla_decode_forward_into( + &self.ctx, + &self.weights, + &self.hidden, + &self.cos, + &self.sin, + &mut self.cache, + self.position, + &self.topk, + self.contract, + &mut scratch, + )?; + self.end.record(&self.ctx.stream)?; + gpu_ms += f64::from(self.start.elapsed_ms(&self.end)?); + } + self.ctx.sync()?; + let wall = wall_start.elapsed(); + Ok((Duration::from_secs_f64(gpu_ms / 1_000.0), wall)) + } + /// One fp8 projection in isolation (its own quant + layout + GEMM chain, /// allocations included — exactly what the forward pays per projection). pub fn measure_projection(&mut self, which: &str, iters: u64) -> Result { diff --git a/openinfer-glm52/src/mla_decode.rs b/openinfer-glm52/src/mla_decode.rs index 8dd7efef..b74e866a 100644 --- a/openinfer-glm52/src/mla_decode.rs +++ b/openinfer-glm52/src/mla_decode.rs @@ -24,7 +24,10 @@ use openinfer_kernels::ops::{ }; use openinfer_kernels::tensor::{DeviceContext, DeviceVec}; -use crate::fp8::{FP8_BLOCK, Glm52ProjBytes, ProjWeight, bytes_to_f32, e4m3_to_f32, fp8_linear}; +use crate::fp8::{ + FP8_BLOCK, Fp8LinearScratch, Glm52ProjBytes, ProjWeight, bytes_to_f32, e4m3_to_f32, + fp8_linear, fp8_linear_into, +}; const HEADS: usize = 64; const HIDDEN: usize = 6144; @@ -448,3 +451,222 @@ pub(crate) fn glm52_mla_attend( let o = fp8_linear(ctx, &w.o_proj, &v)?; // [6144] Ok(o) } + +/// Every intermediate of one MLA decode forward, allocated once. 
The plain +/// `glm52_mla_decode_forward` allocates ~20 device buffers per call — each a +/// synchronous `cudaMalloc` — which is the dominant host-side cost per layer +/// per token; this scratch plus `glm52_mla_decode_forward_into` is the +/// zero-allocation variant (same ops, same math, buffers reused). +pub(crate) struct Glm52MlaDecodeScratch { + fp8: Fp8LinearScratch, + q_a: DeviceVec, + q_resid: DeviceVec, + q_full: CudaSlice, + ckv: CudaSlice, + compressed_kv: DeviceVec, + kv_c: DeviceVec, + k_pe: CudaSlice, + ql_nope: CudaSlice, + query: CudaSlice, + ckv_fp8: CudaSlice, + ckv_scales: CudaSlice, + sched: CudaSlice, + splits: CudaSlice, + latent: CudaSlice, + lse: CudaSlice, + lse_accum: CudaSlice, + o_accum: CudaSlice, + v: CudaSlice, + o: CudaSlice, +} + +impl Glm52MlaDecodeScratch { + pub(crate) fn new( + ctx: &DeviceContext, + contract: Glm52FlashMlaSparseDecode, + ) -> Result { + Ok(Self { + fp8: Fp8LinearScratch::new(ctx, HEADS * V_HEAD)?, + q_a: DeviceVec::zeros(ctx, Q_LORA)?, + q_resid: DeviceVec::zeros(ctx, Q_LORA)?, + q_full: ctx.stream.alloc_zeros::(HEADS * Q_HEAD)?, + ckv: ctx.stream.alloc_zeros::(KV_A_OUT)?, + compressed_kv: DeviceVec::zeros(ctx, KV_LORA)?, + kv_c: DeviceVec::zeros(ctx, KV_LORA)?, + k_pe: ctx.stream.alloc_zeros::(ROPE_DIM)?, + ql_nope: ctx.stream.alloc_zeros::(HEADS * KV_LORA)?, + query: ctx.stream.alloc_zeros::(HEADS * QUERY_DIM)?, + ckv_fp8: ctx.stream.alloc_zeros::(KV_LORA)?, + ckv_scales: ctx.stream.alloc_zeros::(KV_LORA / FP8_BLOCK)?, + sched: ctx + .stream + .alloc_zeros::(contract.tile_scheduler_metadata_len())?, + splits: ctx.stream.alloc_zeros::(contract.num_splits_len())?, + latent: ctx.stream.alloc_zeros::(contract.latent_len())?, + lse: ctx.stream.alloc_zeros::(contract.lse_len())?, + lse_accum: ctx.stream.alloc_zeros::(contract.lse_accum_len())?, + o_accum: ctx.stream.alloc_zeros::(contract.o_accum_len())?, + v: ctx.stream.alloc_zeros::(HEADS * V_HEAD)?, + o: ctx.stream.alloc_zeros::(HIDDEN)?, + }) + } + + /// The 
layer output written by the last `forward_into`. + pub(crate) fn output(&self) -> &CudaSlice { + &self.o + } +} + +/// [`glm52_mla_decode_forward`] with all intermediates in `scratch`: the same +/// op sequence with zero per-call allocations. The scratch's FlashMLA buffers +/// are sized by the `contract` it was built with, so the same contract must be +/// passed here. +#[allow(clippy::too_many_arguments)] +pub(crate) fn glm52_mla_decode_forward_into( + ctx: &DeviceContext, + w: &Glm52MlaLayerWeights, + hidden: &CudaSlice, + cos: &CudaSlice, + sin: &CudaSlice, + cache: &mut CudaSlice, + position: usize, + topk: &CudaSlice, + contract: Glm52FlashMlaSparseDecode, + scratch: &mut Glm52MlaDecodeScratch, +) -> Result<()> { + ensure!(hidden.len() >= HIDDEN, "GLM5.2 MLA hidden too small"); + ensure!( + position < contract.num_blocks * GLM52_FLASHMLA_SPARSE_PAGE_SIZE, + "GLM5.2 MLA position {position} outside paged cache ({} blocks x {GLM52_FLASHMLA_SPARSE_PAGE_SIZE})", + contract.num_blocks + ); + ensure!( + scratch.sched.len() >= contract.tile_scheduler_metadata_len() + && scratch.latent.len() >= contract.latent_len(), + "GLM5.2 MLA scratch was built for a different FlashMLA contract" + ); + + // ---- front projections ---- + fp8_linear_into(ctx, &w.q_a, hidden, &mut scratch.fp8, &mut scratch.q_a.data)?; + rms_norm_into(ctx, &scratch.q_a, &w.q_a_ln, RMS_EPS, &mut scratch.q_resid)?; + fp8_linear_into( + ctx, + &w.q_b, + &scratch.q_resid.data, + &mut scratch.fp8, + &mut scratch.q_full, + )?; + fp8_linear_into(ctx, &w.kv_a, hidden, &mut scratch.fp8, &mut scratch.ckv)?; + ctx.stream.memcpy_dtod( + &scratch.ckv.slice(0..KV_LORA), + &mut scratch.compressed_kv.data, + )?; + rms_norm_into( + ctx, + &scratch.compressed_kv, + &w.kv_a_ln, + RMS_EPS, + &mut scratch.kv_c, + )?; + ctx.stream + .memcpy_dtod(&scratch.ckv.slice(KV_LORA..KV_LORA + ROPE_DIM), &mut scratch.k_pe)?; + + // ---- absorb: ql_nope[64,512] = q_pass @ W_UK ---- + gemm_strided_batched_bf16( + ctx, + false, + false, 
+ KV_LORA, + 1, + QK_NOPE, + &w.w_uk, + KV_LORA, + QK_NOPE * KV_LORA, + &scratch.q_full, + QK_NOPE, + Q_HEAD, + &mut scratch.ql_nope, + KV_LORA, + KV_LORA, + HEADS, + )?; + + // ---- assemble query ---- + glm52_mla_query_assemble_launch( + ctx, + &scratch.ql_nope, + &scratch.q_full, + QK_NOPE, + Q_HEAD, + cos, + sin, + &mut scratch.query, + )?; + + // ---- pack the new token into the cache ---- + glm52_fp8_per_token_group_quant_bf16_launch( + ctx, + Glm52MoeQuantShape { + rows: 1, + width: KV_LORA, + group_size: FP8_BLOCK, + }, + &scratch.kv_c.data, + &mut scratch.ckv_fp8, + &mut scratch.ckv_scales, + )?; + glm52_mla_cache_pack_launch( + ctx, + &scratch.ckv_fp8, + &scratch.ckv_scales, + &scratch.k_pe, + cos, + sin, + cache, + position, + )?; + + // ---- FlashMLA sparse decode ---- + glm52_flashmla_sparse_decode_metadata_launch( + ctx, + contract.batch_size, + contract.num_sm_parts, + &mut scratch.sched, + &mut scratch.splits, + )?; + glm52_flashmla_sparse_decode_launch( + ctx, + contract, + &scratch.query, + cache, + topk, + &scratch.sched, + &scratch.splits, + &mut scratch.latent, + &mut scratch.lse, + &mut scratch.lse_accum, + &mut scratch.o_accum, + )?; + + // ---- back: v = latent @ W_UV, then o_proj ---- + gemm_strided_batched_bf16( + ctx, + true, + false, + V_HEAD, + 1, + KV_LORA, + &w.w_uv, + KV_LORA, + V_HEAD * KV_LORA, + &scratch.latent, + KV_LORA, + KV_LORA, + &mut scratch.v, + V_HEAD, + V_HEAD, + HEADS, + )?; + fp8_linear_into(ctx, &w.o_proj, &scratch.v, &mut scratch.fp8, &mut scratch.o)?; + Ok(()) +} diff --git a/tools/review-watch.sh b/tools/review-watch.sh new file mode 100755 index 00000000..fb606db4 --- /dev/null +++ b/tools/review-watch.sh @@ -0,0 +1,84 @@ +#!/usr/bin/env bash +# AI-review watcher for our open openinfer PRs. +# +# Polls each PR for NEW review comments from bots (Codex) or the maintainer +# (xiaguan) since a saved baseline. 
When something new appears it does NOT +# auto-apply — it prints the new comments and exits non-zero so the caller +# (Claude) wakes up and runs a subagent to judge whether the comment is +# actually right before touching code. Priority when they conflict: +# xiaguan > Codex; both can be wrong, so the subagent verifies against the +# code, never rubber-stamps. +# +# Usage: +# tools/review-watch.sh init # snapshot current comment counts +# tools/review-watch.sh check # one-shot: exit 10 if new activity +# tools/review-watch.sh watch [secs] # loop until new activity (default 300s) + +set -uo pipefail + +REPO=openinfer-project/openinfer +PRS=(485 491) +STATE_DIR="${TMPDIR:-/tmp}/openinfer-review-watch" +mkdir -p "$STATE_DIR" + +# "bot-count xg-count" fingerprint per PR (review comments not by n-WN, issue comments by xiaguan). +fingerprint() { + local pr=$1 + local bot xg + bot=$(gh api "repos/$REPO/pulls/$pr/comments" \ + --jq '[.[] | select(.user.login != "n-WN")] | length' 2>/dev/null || echo "?") + xg=$(gh pr view "$pr" -R "$REPO" --json comments \ + --jq '[.comments[] | select(.author.login == "xiaguan")] | length' 2>/dev/null || echo "?") + echo "$bot $xg" +} + +new_comments() { + local pr=$1 + echo "=== PR #$pr new review activity ===" + gh api "repos/$REPO/pulls/$pr/comments" \ + --jq '.[] | select(.user.login != "n-WN") | "[\(.user.login)] \(.path):\(.line)\n\(.body)\n"' 2>/dev/null + gh pr view "$pr" -R "$REPO" --json comments \ + --jq '.comments[] | select(.author.login == "xiaguan") | "[xiaguan issue-comment] \(.body)\n"' 2>/dev/null +} + +case "${1:-check}" in + init) + for pr in "${PRS[@]}"; do fingerprint "$pr" > "$STATE_DIR/pr-$pr"; done + echo "baseline saved for PRs: ${PRS[*]}" + ;; + check) + changed=0 + for pr in "${PRS[@]}"; do + cur=$(fingerprint "$pr") + old=$(cat "$STATE_DIR/pr-$pr" 2>/dev/null || echo "") + if [ "$cur" != "$old" ]; then + changed=1 + new_comments "$pr" + echo "$cur" > "$STATE_DIR/pr-$pr" + fi + done + [ "$changed" = 1 ] && exit 10 || echo "no new review activity" + ;; + watch) + interval="${2:-300}" 
+ for pr in "${PRS[@]}"; do + [ -f "$STATE_DIR/pr-$pr" ] || fingerprint "$pr" > "$STATE_DIR/pr-$pr" + done + while true; do + for pr in "${PRS[@]}"; do + cur=$(fingerprint "$pr") + old=$(cat "$STATE_DIR/pr-$pr" 2>/dev/null || echo "") + if [ "$cur" != "$old" ]; then + new_comments "$pr" + echo "$cur" > "$STATE_DIR/pr-$pr" + exit 10 + fi + done + sleep "$interval" + done + ;; + *) + echo "usage: $0 {init|check|watch [secs]}" >&2 + exit 2 + ;; +esac From f54a32c06a25a4cbb539ea66df608350e344d26a Mon Sep 17 00:00:00 2001 From: n-WN <30841158+n-WN@users.noreply.github.com> Date: Thu, 2 Jul 2026 22:58:35 +0800 Subject: [PATCH 4/8] feat(glm52): zero-alloc scratch MLA decode forward + A/B parity bench --- openinfer-glm52/src/bin/glm52_kernel_bench.rs | 1 + openinfer-glm52/src/kernel_bench.rs | 44 +++++++++++++++++++ 2 files changed, 45 insertions(+) diff --git a/openinfer-glm52/src/bin/glm52_kernel_bench.rs b/openinfer-glm52/src/bin/glm52_kernel_bench.rs index 1d493df4..107f48e9 100644 --- a/openinfer-glm52/src/bin/glm52_kernel_bench.rs +++ b/openinfer-glm52/src/bin/glm52_kernel_bench.rs @@ -43,6 +43,7 @@ fn main() -> Result<()> { let args = parse_args(std::env::args().skip(1))?; for &context in &args.contexts { let mut bench = Glm52MlaDecodeBench::new(context)?; + bench.verify_scratch_parity()?; let (gpu, wall) = bench.measure_forward(args.iters)?; let per = |d: std::time::Duration| d.as_secs_f64() * 1.0e6 / args.iters as f64; println!("== context {context} (iters {}) ==", args.iters); diff --git a/openinfer-glm52/src/kernel_bench.rs b/openinfer-glm52/src/kernel_bench.rs index d492d933..495673c9 100644 --- a/openinfer-glm52/src/kernel_bench.rs +++ b/openinfer-glm52/src/kernel_bench.rs @@ -233,6 +233,50 @@ impl Glm52MlaDecodeBench { Ok((Duration::from_secs_f64(gpu_ms / 1_000.0), wall)) } + /// Bitwise parity between the as-is forward and the scratch forward. 
+ /// Weights and inputs are deterministic, and the scratch path runs the + /// exact same op sequence, so any mismatch is a real bug, not noise. + pub fn verify_scratch_parity(&mut self) -> Result<()> { + let expected = glm52_mla_decode_forward( + &self.ctx, + &self.weights, + &self.hidden, + &self.cos, + &self.sin, + &mut self.cache, + self.position, + &self.topk, + self.contract, + )?; + let expected = self.ctx.stream.clone_dtoh(&expected)?; + let mut scratch = Glm52MlaDecodeScratch::new(&self.ctx, self.contract)?; + glm52_mla_decode_forward_into( + &self.ctx, + &self.weights, + &self.hidden, + &self.cos, + &self.sin, + &mut self.cache, + self.position, + &self.topk, + self.contract, + &mut scratch, + )?; + let actual = self.ctx.stream.clone_dtoh(scratch.output())?; + self.ctx.sync()?; + let mismatches = expected + .iter() + .zip(&actual) + .filter(|(e, a)| e.to_bits() != a.to_bits()) + .count(); + anyhow::ensure!( + mismatches == 0, + "scratch forward diverges from as-is forward: {mismatches}/{} elements differ", + expected.len() + ); + Ok(()) + } + /// One fp8 projection in isolation (its own quant + layout + GEMM chain, /// allocations included — exactly what the forward pays per projection). 
pub fn measure_projection(&mut self, which: &str, iters: u64) -> Result { From 6c4b1bdb1aff9493bce2b27eaad30c5bab35cce1 Mon Sep 17 00:00:00 2001 From: n-WN <30841158+n-WN@users.noreply.github.com> Date: Thu, 2 Jul 2026 23:10:53 +0800 Subject: [PATCH 5/8] style(glm52): rustfmt + review fixups (unused import, doc/attr adjacency) --- openinfer-glm52/src/kernel_bench.rs | 32 +++++++++++++++++++++-------- openinfer-glm52/src/mla_decode.rs | 23 ++++++++++----------- 2 files changed, 35 insertions(+), 20 deletions(-) diff --git a/openinfer-glm52/src/kernel_bench.rs b/openinfer-glm52/src/kernel_bench.rs index 495673c9..dd15bb46 100644 --- a/openinfer-glm52/src/kernel_bench.rs +++ b/openinfer-glm52/src/kernel_bench.rs @@ -24,7 +24,7 @@ use openinfer_kernels::ops::{ glm52_fp8_per_token_group_quant_bf16_launch, glm52_mla_cache_pack_launch, glm52_mla_query_assemble_launch, }; -use openinfer_kernels::tensor::{DeviceContext, DeviceVec}; +use openinfer_kernels::tensor::DeviceContext; use crate::fp8::{FP8_BLOCK, Glm52ProjBytes, ProjWeight, fp8_linear}; use crate::mla_decode::{ @@ -107,9 +107,7 @@ impl Glm52MlaDecodeBench { &proj(&o_w, &o_s, HIDDEN, HEADS * V_HEAD), )?; - let hidden = ctx - .stream - .clone_htod(&vec![bf16::from_f32(0.01); HIDDEN])?; + let hidden = ctx.stream.clone_htod(&vec![bf16::from_f32(0.01); HIDDEN])?; let rope: Vec = (0..ROPE_HALF) .map(|i| bf16::from_f32(((i as f32) * 0.1).cos())) .collect(); @@ -393,8 +391,17 @@ impl Glm52MlaDecodeBench { &mut splits, )?; glm52_flashmla_sparse_decode_launch( - &self.ctx, c, &query, &self.cache, &self.topk, &sched, &splits, &mut latent, &mut lse, - &mut lse_accum, &mut o_accum, + &self.ctx, + c, + &query, + &self.cache, + &self.topk, + &sched, + &splits, + &mut latent, + &mut lse, + &mut lse_accum, + &mut o_accum, )?; self.ctx.sync()?; let wall = Instant::now(); @@ -407,8 +414,17 @@ impl Glm52MlaDecodeBench { &mut splits, )?; glm52_flashmla_sparse_decode_launch( - &self.ctx, c, &query, &self.cache, &self.topk, &sched, 
&splits, &mut latent, - &mut lse, &mut lse_accum, &mut o_accum, + &self.ctx, + c, + &query, + &self.cache, + &self.topk, + &sched, + &splits, + &mut latent, + &mut lse, + &mut lse_accum, + &mut o_accum, )?; } self.ctx.sync()?; diff --git a/openinfer-glm52/src/mla_decode.rs b/openinfer-glm52/src/mla_decode.rs index b74e866a..095d79d3 100644 --- a/openinfer-glm52/src/mla_decode.rs +++ b/openinfer-glm52/src/mla_decode.rs @@ -25,8 +25,8 @@ use openinfer_kernels::ops::{ use openinfer_kernels::tensor::{DeviceContext, DeviceVec}; use crate::fp8::{ - FP8_BLOCK, Fp8LinearScratch, Glm52ProjBytes, ProjWeight, bytes_to_f32, e4m3_to_f32, - fp8_linear, fp8_linear_into, + FP8_BLOCK, Fp8LinearScratch, Glm52ProjBytes, ProjWeight, bytes_to_f32, e4m3_to_f32, fp8_linear, + fp8_linear_into, }; const HEADS: usize = 64; @@ -55,10 +55,6 @@ pub(crate) struct Glm52MlaLayerWeights { } impl Glm52MlaLayerWeights { - /// Build from raw checkpoint bytes: upload the fp8 projections + bf16 - /// layernorm gammas, and host-dequant kv_b into the bf16 absorb factors - /// W_UK = kv_b[:, :192, :], W_UV = kv_b[:, 192:, :]. - #[allow(clippy::too_many_arguments)] pub(crate) fn q_a(&self) -> &ProjWeight { &self.q_a } @@ -75,6 +71,10 @@ impl Glm52MlaLayerWeights { &self.o_proj } + /// Build from raw checkpoint bytes: upload the fp8 projections + bf16 + /// layernorm gammas, and host-dequant kv_b into the bf16 absorb factors + /// W_UK = kv_b[:, :192, :], W_UV = kv_b[:, 192:, :]. 
+ #[allow(clippy::too_many_arguments)] pub(crate) fn from_host( ctx: &DeviceContext, q_a: &Glm52ProjBytes, @@ -481,10 +481,7 @@ pub(crate) struct Glm52MlaDecodeScratch { } impl Glm52MlaDecodeScratch { - pub(crate) fn new( - ctx: &DeviceContext, - contract: Glm52FlashMlaSparseDecode, - ) -> Result { + pub(crate) fn new(ctx: &DeviceContext, contract: Glm52FlashMlaSparseDecode) -> Result { Ok(Self { fp8: Fp8LinearScratch::new(ctx, HEADS * V_HEAD)?, q_a: DeviceVec::zeros(ctx, Q_LORA)?, @@ -568,8 +565,10 @@ pub(crate) fn glm52_mla_decode_forward_into( RMS_EPS, &mut scratch.kv_c, )?; - ctx.stream - .memcpy_dtod(&scratch.ckv.slice(KV_LORA..KV_LORA + ROPE_DIM), &mut scratch.k_pe)?; + ctx.stream.memcpy_dtod( + &scratch.ckv.slice(KV_LORA..KV_LORA + ROPE_DIM), + &mut scratch.k_pe, + )?; // ---- absorb: ql_nope[64,512] = q_pass @ W_UK ---- gemm_strided_batched_bf16( From 3d16f5bf6c418552a10136677314ce2437b38fca Mon Sep 17 00:00:00 2001 From: n-WN <30841158+n-WN@users.noreply.github.com> Date: Thu, 2 Jul 2026 23:34:45 +0800 Subject: [PATCH 6/8] test(glm52): oracle gate replays the scratch forward and asserts bitwise parity --- openinfer-glm52/src/mla_oracle_gate.rs | 34 +++++++++++++++++++++++++- 1 file changed, 33 insertions(+), 1 deletion(-) diff --git a/openinfer-glm52/src/mla_oracle_gate.rs b/openinfer-glm52/src/mla_oracle_gate.rs index a890b523..e1cae25b 100644 --- a/openinfer-glm52/src/mla_oracle_gate.rs +++ b/openinfer-glm52/src/mla_oracle_gate.rs @@ -37,7 +37,10 @@ use openinfer_kernels::ops::{ use openinfer_kernels::tensor::DeviceContext; use crate::fp8::Glm52ProjBytes; -use crate::mla_decode::{Glm52MlaLayerWeights, Glm52MlaSchedMetadata, glm52_mla_decode_forward}; +use crate::mla_decode::{ + Glm52MlaDecodeScratch, Glm52MlaLayerWeights, Glm52MlaSchedMetadata, + glm52_mla_decode_forward, glm52_mla_decode_forward_into, +}; // ---- BEGIN GENERATED: glm52_oracle probes ---- // uv run tools/accuracy/glm52_oracle.py --model-path /data/models/GLM-5.2-FP8 \ @@ -258,6 +261,11 
@@ fn mla_oracle_gate() -> Result<()> { // Prefill via decode: position p writes its token into the cache, then // attends over the full prefix [0..=p] via a -1-padded top-k list. + // Each position also replays through the zero-alloc scratch forward and + // must match the plain forward bitwise — real-checkpoint parity for the + // buffer-reuse path (the cache re-pack writes identical bytes, so the + // replay is idempotent). + let mut scratch = Glm52MlaDecodeScratch::new(&ctx, contract)?; let mut outputs = Vec::with_capacity(ORACLE_CTX * HIDDEN); for position in 0..ORACLE_CTX { let mut hidden = ctx.stream.alloc_zeros::(HIDDEN)?; @@ -282,6 +290,30 @@ fn mla_oracle_gate() -> Result<()> { &ctx, &w, &hidden, &cos, &sin, &mut cache, position, &topk, &mla_sched, )?; let o_host = ctx.stream.clone_dtoh(&o)?; + + glm52_mla_decode_forward_into( + &ctx, + &w, + &hidden, + &cos, + &sin, + &mut cache, + position, + &topk, + contract, + &mut scratch, + )?; + let o_scratch = ctx.stream.clone_dtoh(scratch.output())?; + let mismatches = o_host + .iter() + .zip(&o_scratch) + .filter(|(a, b)| a.to_bits() != b.to_bits()) + .count(); + ensure!( + mismatches == 0, + "position {position}: scratch forward diverges from plain forward on {mismatches}/{HIDDEN} elements" + ); + outputs.extend(o_host.iter().map(|v| v.to_f32())); } From 1ee1affa27f8f087b819d0d8484f560ae185cc2d Mon Sep 17 00:00:00 2001 From: n-WN <30841158+n-WN@users.noreply.github.com> Date: Fri, 3 Jul 2026 16:31:46 +0800 Subject: [PATCH 7/8] =?UTF-8?q?perf(glm52):=20step-0=20decode=20kernel=20o?= =?UTF-8?q?ptimizations=20=E2=80=94=20arena=20+=20schedule=20hoist=20+=20C?= =?UTF-8?q?UDA=20graph=20(-42%/layer)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Measured on H100 (bs=1, glm52_kernel_bench, parity-verified vs the alloc-heavy forward): as-is 288us -> 0a arena 218 -> 0b tile-schedule hoist 190 -> 0c graph 166us; re-verified natively on a cuda-compat 575.57.08 host at 158us 
(-44%). 0a: Glm52MlaDecodeScratch pre-allocates all ~20 intermediates (synchronous cudaMalloc serializes the decode stream, so this drops GPU time not just host). 0b: the FlashMLA tile schedule is data-independent (batch_size + num_sm_parts only) — computed once in scratch::new, dropped from the per-token path. 0c: measure_forward_graph captures the pure-kernel forward via CudaGraphState::run_or_capture; replay removes inter-kernel GPU bubbles. Records the measured non-result: num_sm_parts tuning (isolated flashmla 1.68x at 16 vs default 132, output bitwise-identical) does NOT carry through the graphed forward — the isolated win was the metadata kernel 0b already hoists, and 16 splits underutilize the GPU. Fixes the kernel_bench proj helper (fn, not closure). --- docs/index.md | 1 + .../lessons/megakernels-for-decode-latency.md | 17 +- docs/models/glm52/kernel-perf-decode.md | 66 ++++++ openinfer-glm52/src/bin/glm52_kernel_bench.rs | 38 +++- openinfer-glm52/src/kernel_bench.rs | 206 +++++++++++++++++- openinfer-glm52/src/lib.rs | 4 +- openinfer-glm52/src/mla_decode.rs | 30 ++- 7 files changed, 341 insertions(+), 21 deletions(-) create mode 100644 docs/models/glm52/kernel-perf-decode.md diff --git a/docs/index.md b/docs/index.md index fb19ccc8..ef9be89f 100644 --- a/docs/index.md +++ b/docs/index.md @@ -65,6 +65,7 @@ Organized by domain (model line / subsystem / playbook / lesson) instead of by l | `models/glm52/ep8-deepep-moe.md` | PR4: GLM-baked DeepEP v2 shim instantiation replaces PR3's local scatter/combine; loader places experts into their packed layout at H2D time (post-load repack cannot fit HBM); rank 0 runs the full 78-layer spine + bs=1 greedy coordinator, ranks 1..7 replay the 75 MoE collectives per step. Gates: EP8 layer-6 oracle 62/64 (same outliers as EP1), full-model e2e generation. 
| | `models/glm52/ep1-forward.md` | PR3 built + all gates green on jz-38 H200 (2026-07-03): MoE/dense/bookend bricks (cherry-picked from the PP8 branch, re-gated via the #499 harness) + decoder-layer composition with cross-layer top-k sharing. MoE chain shaped to the DeepEP v2 elastic shim contract, Grouped + GEMV expert paths behind one signature; graph capturability as the bar. Gates: bookend exact, layer-0 dense 64/64, layer-6 MoE 62/64 both paths (measured router near-ties, bounded allowance). | | `models/glm52/bs1-decode-serial-overhead.md` | PR5a perf pass on the PR4 bring-up path: 101–103 → 46–50 ms/step (~2.2×) at bs=1, output byte-identical, all gates green. Fixes: quant/SiLU/GEMM rows bounded by the coordinator token count (device trap on violation), persistent MoE workspace (was ~11.6k allocs/step), FlashMLA sched metadata hoisted to build. Remaining gap = launch overhead → PR5c graph target. | +| `models/glm52/kernel-perf-decode.md` | Measured single-layer MLA decode kernel ladder (`glm52_kernel_bench`, H100 bs=1, parity-verified): as-is ~288µs → arena 218 (sync `cudaMalloc` serializes the stream) → tile-schedule hoist 190 → CUDA-graph 166µs = **−42%/layer**. Measured non-result: `num_sm_parts` tuning (isolated flashmla 1.68× at 16) does not carry through the graphed forward. Next lever: sparse partial+combine fusion, deprioritized by measurement. | ## models / deepseek-v4 diff --git a/docs/lessons/megakernels-for-decode-latency.md b/docs/lessons/megakernels-for-decode-latency.md index 8a2f9cc3..250e2e6c 100644 --- a/docs/lessons/megakernels-for-decode-latency.md +++ b/docs/lessons/megakernels-for-decode-latency.md @@ -19,6 +19,21 @@ Instruction-set design and counter-sync debugging carry "tremendous complexity"; the TP kernel is explicitly unmaintained. None of the megakernel posts benchmark against CUDA Graphs except no-bubbles — treat vLLM/SGLang comparisons in MPK/TP posts as launch-overhead-inclusive. 
Any adoption here should ship as one model line's experiment with an A/B against that line's graphed baseline, never as shared-layer infrastructure first. +## Measured GLM5.2 single-layer baseline (H100, bs=1, `glm52_kernel_bench`) + +Synthetic weights, parity-verified (the scratch forward is bitwise-identical to the alloc-heavy forward before timing). Two context lengths, iters=64: + +| | as-is gpu / wall | scratch gpu / wall | alloc bill | +|---|---|---|---| +| ctx 512 | 286.5 / 296.8 µs | 218.2 / 228.3 µs | **68.5 µs/layer** | +| ctx 2048 | 291.1 / 301.6 µs | 219.6 / 230.0 µs | **71.6 µs/layer** | + +Per-stage (ctx 2048): o_proj 67.7 µs, kv_a 28.6, q_a 28.7, q_b 24.8 (each incl. its 4-malloc quant→relayout→GEMM chain); flashmla sparse decode 48.2 µs; assembly family (assemble+quant+pack, buffers reused) 8.4 µs. Context length barely moves the total (sparse top-k=2048 caps attention work). + +**The headline: eliminating per-call `cudaMalloc`s (the zero-alloc scratch forward) recovers ~70 µs/layer — 24% of the per-layer attention path — and it drops the *GPU-measured* time too (286→218 µs), not just wall.** That proves the synchronous mallocs were serializing against the stream, exactly the "step 0 dwarfs fusion" thesis above, now with real numbers. Projected 75-MoE-layer attention share: 22.3 ms/token as-is → ~17 ms/token with the arena alone. + +Measured on an R535 host via a 3-symbol `cuLibrary*`-enumeration shim (the box's 12.2 driver lacks the 12.4+ enumeration APIs cudarc calls); the shim only stubs kernel *enumeration* — dispatch uses the real `cuLibraryGetKernel`-by-name, and the parity assertion guards against any silent kernel-load breakage, so the timings are of real kernel execution. + ## Next -Blocked on the GLM5.2 single-layer baseline numbers (glm52_kernel_bench) to size step 0 vs step 1; then decide whether a sparse-ThunderMLA experiment is worth a design issue. 
+Step 0 (allocation arena + CUDA Graph on the decode path) is now quantified at ~24%/layer and is the clear first move — bigger than any fusion. Step 1 (sparse-ThunderMLA fusion of the 48 µs flashmla partial+reduction) is the next-largest single lever; worth a design issue once step 0 lands. The scratch forward in this branch is the arena half of step 0. diff --git a/docs/models/glm52/kernel-perf-decode.md b/docs/models/glm52/kernel-perf-decode.md new file mode 100644 index 00000000..8660a9bc --- /dev/null +++ b/docs/models/glm52/kernel-perf-decode.md @@ -0,0 +1,66 @@ +# GLM5.2 decode kernel performance — measured baseline and optimization ladder + +**TL;DR**: One GLM5.2 MLA decode layer costs ~288 µs GPU / 298 µs wall at bs=1 on H100 (measured, parity-verified). Step 0 — three stacked, implemented, parity-verified optimizations — cuts it to **165 µs GPU / 175 µs wall (−41%)**: **(0a) an allocation arena** removes ~68 µs (synchronous `cudaMalloc`s serialize against the stream, so eliminating them drops GPU time, not just host time); **(0b) hoisting the data-independent FlashMLA tile-schedule** out of the per-token path removes ~28 µs; **(0c) capturing the forward into a CUDA Graph** removes ~23 µs more — and it drops *GPU* time, because graph replay launches the ~18 kernels back-to-back and removes the inter-kernel bubbles where the GPU idled waiting for the next host launch. Next design lever: **(step 1) fusing the FlashMLA sparse-decode partial+combine kernels** (ThunderMLA-style) removes the ~10 µs `o_accum` HBM round-trip + one launch. Everything below is from `glm52_kernel_bench` on a real H100, parity-verified against the alloc-heavy forward. 
+ +Last touched: 2026-07 + +## Measured baseline (`glm52_kernel_bench`, bs=1, synthetic weights, iters=64) + +| stage (ctx 2048) | gpu | wall | cumulative saved | +|---|---|---|---| +| as-is forward | 287.9 µs | 298.1 µs | — | +| 0a arena | 218 µs | ~228 µs | −68 µs | +| 0b + tile-schedule hoist | 189.1 µs | 199.6 µs | −98 µs | +| **0c + CUDA Graph** | **165.0 µs** | **175.3 µs** | **−123 µs (−41%)** | + +ctx 512 tracks it (graph 165.6 / 175.8 µs). All three are parity-verified against the alloc-heavy forward. Projected 75-MoE-layer attention share: 22.3 ms/token → **~13.1 ms/token**. + +Per stage (ctx 2048, alloc chain included in the projections): + +| stage | wall | notes | +|---|---|---| +| o_proj `fp8_linear` | 67.7 µs | [1,16384]·[16384,6144] fp8 — the widest projection | +| kv_a / q_a / q_b | 28.6 / 28.7 / 24.8 µs | quant → TMA-relayout → blockscale GEMM, 4 mallocs each | +| flashmla sparse decode | 48.2 µs | metadata + split-KV partial + combine (3 kernels) | +| assembly family | 8.4 µs | query-assemble + kv quant + cache-pack (buffers reused) | + +Context length barely moves the total (286 → 291 µs from 512 → 2048) because sparse top-k = 2048 caps the attended set. + +**Measurement provenance**: built from `feat/glm52-kernel-bench` with the CUDA 12.9 toolkit; run on an R535 host (driver 12.2) via a 3-symbol `cuLibrary*`-enumeration `LD_PRELOAD` shim, since cudarc 0.19 calls the 12.4+ enumeration APIs the old driver lacks. The shim only stubs kernel *enumeration* — dispatch is the real `cuLibraryGetKernel`-by-name, and the bench's `verify_scratch_parity` asserts the scratch forward is bitwise-identical before any timing, so a broken load fails loudly rather than faking numbers. A real serving path needs an R550+ driver; CUDA Graph capture itself does not (see 0c below). 
+ +## Step 0 — implemented and measured (−41%/layer) + +**0a — allocation arena (68 µs).** The correctness-first bring-up allocates every intermediate fresh (`alloc_zeros`) per projection per token: whole MLA layer ≈ 20 `cudaMalloc`s. Each is synchronous and serializes against the decode stream, so the cost shows up in *GPU* time (287 → 218 µs), not just host time. `Glm52MlaDecodeScratch` + `glm52_mla_decode_forward_into` pre-allocate all 20 buffers once and reuse them. + +**0b — hoist the FlashMLA tile schedule (28 µs).** The sparse-decode `metadata` kernel builds `tile_scheduler_metadata` + `num_splits` from `batch_size` and `num_sm_parts` only — both fixed by the contract, independent of the per-token query/KV. The bring-up re-ran it every layer every token; `Glm52MlaDecodeScratch::new` now computes it once and the decode path reuses it (218 → 190 µs). Correctness is guarded by the bench's `verify_scratch_parity` (bitwise vs the alloc-heavy forward that still recomputes it), so the data-independence claim is checked, not assumed. In real serving the schedule must be re-cached whenever `batch_size` changes (num_sm_parts is a device constant); for bs=1 latency decode it is computed exactly once. + +**0c — CUDA Graph capture (23 µs).** With the arena (0a) and schedule hoist (0b), the forward is a pure kernel sequence, so `CudaGraphState::run_or_capture` (openinfer-core) captures it once and replays with one `cuGraphLaunch`. This removes host launch overhead *and* GPU time: graph replay issues the ~18 kernels back-to-back, so the GPU stops idling between them waiting for the next host launch (188 → 165 µs GPU). Graph capture/launch are CUDA 11.x APIs, so this runs on the R535 host (unlike the cudarc module-enumeration path that needs the shim). `measure_forward_graph` in the bench captures against the same scratch and reports 165 µs.
+ +## Measured non-result: `num_sm_parts` tuning doesn't help the graphed forward + +`current_sm90_num_sm_parts` fills all SMs (132 on H100). For bs=1 top-k=2048 that over-splits — each split handles ~16 KV entries and the combine reduces a 132-way, 17.3 MB `o_accum`. Sweeping the split count (`measure_flashmla_at`) on the **isolated** flashmla stage shows a real 1.68× at `num_sm_parts=16` (48.1 → 27.8 µs), output bitwise-identical to the default (`flashmla_parts_max_diff(16) = 0`, so it's a pure parallelization knob). + +**But it does not carry through to the optimized forward.** Measured end-to-end (`--sm-parts 16`): the arena+hoist+graph forward is 169.1 µs at parts=16 vs 166.6 µs at parts=132 — no gain, marginally worse. Two reasons, both only visible end-to-end: +- The isolated 20 µs was dominated by the **metadata kernel**, which step 0b already hoists out of the per-token path — the decode partial+combine alone differs by only ~1.6 µs between 16 and 132 splits (scratch forward 190.7 vs 192.3 µs). +- 16 splits use 16/132 SMs, so the partial underutilizes the GPU in the graphed pipeline, offsetting the smaller combine. + +(The 30 µs the ungraphed as-is forward saves at parts=16 — 290 → 260 µs — is mostly the cheaper `cudaMalloc` of the smaller accum buffers, which the arena already eliminates.) This corrects an earlier estimate that projected ~19 µs from this tuning; the real measurement is the opposite. It also lowers the expected payoff of step 1 below. + +## Step 1 — fuse the FlashMLA sparse partial+combine (smaller than it first looked) + +`glm52_flashmla_sparse_decode_launch` runs two CUDA kernels (`csrc/glm52/glm52_flashmla_sparse.cu`): + +1. **split-KV partial** (`run_flash_splitkv_mla_fp8_sparse_kernel`) — splits the top-k=2048 KV across `num_sm_parts` SMs, each writing a partial `o_accum` + `lse_accum` to HBM. +2. **combine** (`CombineParams`) — reads every partial back, does the log-sum-exp reduction into the final `out_latent` + `lse`. 
+ +On H100 `num_sm_parts = multiProcessorCount / kSq / (kHeads/64) = 132` (one split per SM). With `stride_o_accum_split = kSq·kHeads·kVDim = 1·64·512`, `o_accum` is `132 · 32768 · f32 = 17.3 MB`. The partial writes it and the combine reads it: **~34.6 MB round-trip ≈ 10.3 µs at 3.35 TB/s**, plus one kernel launch (~2 µs graphed/ungraphed). + +**ThunderMLA transfer**: fuse partial+combine into one persistent kernel driven by a host-side instruction/tile schedule, and do the cross-split reduction through SM90 thread-block clusters / distributed shared memory instead of the HBM `o_accum` round-trip. Precedent: ThunderKittens `mla` branch, ~250 LoC device, 20–35% over FlashMLA. **But the `num_sm_parts` measurement above resets the expectation**: in the graphed forward the partial+combine only costs ~1.6 µs more at 132 splits than at the round-trip-minimizing 16, so the HBM round-trip this fusion removes is already small on the critical path once step 0 is applied. The fusion's remaining lever is the reduced-occupancy problem (running the reduction on-chip lets you use fewer splits *without* idling SMs), not the round-trip per se — a subtler and smaller win than the isolated 48 µs suggested. Worth a design issue only if bs=1 attention becomes the dominant remaining cost after step 0; not a priority now. Not a drop-in — ThunderMLA is dense; this is a port onto the vendored sparse-FP8 kernel. + +## Not worth it yet + +Whole-layer megakernel fusion (glue the projections + norm + RoPE + cache-pack into one persistent kernel) — real but its instruction-set/counter-sync complexity is not justified before step 0 (arena+graph) and step 1 (sparse-ThunderMLA) are exhausted. See [[../../lessons/megakernels-for-decode-latency]]. + +## Next + +Land step 0 — arena, tile-schedule hoist, and graph capture, all implemented in `feat/glm52-kernel-bench` — on the real decode path. Open a design issue for the sparse-ThunderMLA port (step 1) only if bs=1 attention is still the dominant cost afterwards. Re-baseline after each with `glm52_kernel_bench`. 
diff --git a/openinfer-glm52/src/bin/glm52_kernel_bench.rs b/openinfer-glm52/src/bin/glm52_kernel_bench.rs index 107f48e9..c4cb71d2 100644 --- a/openinfer-glm52/src/bin/glm52_kernel_bench.rs +++ b/openinfer-glm52/src/bin/glm52_kernel_bench.rs @@ -10,12 +10,14 @@ use openinfer_glm52::kernel_bench::Glm52MlaDecodeBench; struct Args { contexts: Vec, iters: u64, + sm_parts: Option, } fn parse_args(mut argv: impl Iterator) -> Result { let mut args = Args { contexts: vec![512, 2048], iters: 64, + sm_parts: None, }; while let Some(flag) = argv.next() { let mut value = || { @@ -30,7 +32,10 @@ fn parse_args(mut argv: impl Iterator) -> Result { .collect::>()?; } "--iters" => args.iters = value()?.parse()?, - other => bail!("unknown flag `{other}` (supported: --contexts, --iters)"), + "--sm-parts" => args.sm_parts = Some(value()?.parse()?), + other => { + bail!("unknown flag `{other}` (supported: --contexts, --iters, --sm-parts)") + } } } if args.contexts.is_empty() || args.iters == 0 { @@ -43,6 +48,10 @@ fn main() -> Result<()> { let args = parse_args(std::env::args().skip(1))?; for &context in &args.contexts { let mut bench = Glm52MlaDecodeBench::new(context)?; + if let Some(parts) = args.sm_parts { + bench.set_num_sm_parts(parts)?; + println!("(num_sm_parts overridden to {parts})"); + } bench.verify_scratch_parity()?; let (gpu, wall) = bench.measure_forward(args.iters)?; let per = |d: std::time::Duration| d.as_secs_f64() * 1.0e6 / args.iters as f64; @@ -64,6 +73,17 @@ fn main() -> Result<()> { "-> alloc bill (as-is wall - scratch wall): {:>9.1}us/layer", per(wall) - per(wall_s) ); + let (gpu_g, wall_g) = bench.measure_forward_graph(args.iters)?; + println!( + "layer fwd graph gpu {:>9.1}us wall {:>9.1}us host-side gap {:>9.1}us", + per(gpu_g), + per(wall_g), + per(wall_g) - per(gpu_g) + ); + println!( + "-> total vs as-is (as-is wall - graph wall): {:>9.1}us/layer", + per(wall) - per(wall_g) + ); for proj in ["q_a", "q_b", "kv_a", "o_proj"] { let d = 
bench.measure_projection(proj, args.iters)?; println!( @@ -81,6 +101,22 @@ fn main() -> Result<()> { "flashmla sparse wall {:>9.1}us (metadata+decode, buffers reused)", per(d) ); + // Sweep the split count: the device default over-splits a bs=1 sparse + // decode, so a smaller num_sm_parts can shrink the combine round-trip + // faster than it costs partial parallelism. + let default_parts = bench.default_num_sm_parts(); + print!("flashmla sweep "); + for parts in [1usize, 8, 16, 32, 64, 96, default_parts] { + if let Some(t) = bench.measure_flashmla_at(parts, args.iters)? { + let tag = if parts == default_parts { "*" } else { "" }; + print!("p{parts}{tag}={:.1}us ", per(t)); + } + } + println!("(* = device default)"); + let diff16 = bench.flashmla_parts_max_diff(16)?; + println!( + "flashmla p16 vs default: max abs latent diff {diff16:.3e} (parallelization-only knob → safe if ~fp noise)" + ); let projected_token = per(wall) * 75.0 / 1000.0; println!( "-> projected 75 MoE-layer attention share: {projected_token:.2} ms/token (as-is)\n" diff --git a/openinfer-glm52/src/kernel_bench.rs b/openinfer-glm52/src/kernel_bench.rs index dd15bb46..dc128bff 100644 --- a/openinfer-glm52/src/kernel_bench.rs +++ b/openinfer-glm52/src/kernel_bench.rs @@ -27,6 +27,8 @@ use openinfer_kernels::ops::{ use openinfer_kernels::tensor::DeviceContext; use crate::fp8::{FP8_BLOCK, Glm52ProjBytes, ProjWeight, fp8_linear}; +use openinfer_core::cuda_graph::CudaGraphState; + use crate::mla_decode::{ Glm52MlaDecodeScratch, Glm52MlaLayerWeights, glm52_mla_decode_forward, glm52_mla_decode_forward_into, @@ -90,12 +92,16 @@ impl Glm52MlaDecodeBench { let (kvb_w, kvb_s) = synth_proj(HEADS * KV_B_ROWS_PER_HEAD, KV_LORA); let (o_w, o_s) = synth_proj(HIDDEN, HEADS * V_HEAD); let ln = bf16_ones_bytes(Q_LORA.max(KV_LORA)); - let proj = |w: &[u8], s: &[u8], n: usize, k: usize| Glm52ProjBytes { - weight: w, - scale: s, - n, - k, - }; + // A `fn` (not a closure) so the returned bytes borrow ties to the + // 
input slices' lifetime — a closure can't express that relation. + fn proj<'a>(w: &'a [u8], s: &'a [u8], n: usize, k: usize) -> Glm52ProjBytes<'a> { + Glm52ProjBytes { + weight: w, + scale: s, + n, + k, + } + } let weights = Glm52MlaLayerWeights::from_host( &ctx, &proj(&qa_w, &qa_s, Q_LORA, HIDDEN), @@ -231,6 +237,47 @@ impl Glm52MlaDecodeBench { Ok((Duration::from_secs_f64(gpu_ms / 1_000.0), wall)) } + /// The scratch forward captured into a CUDA Graph and replayed — collapses + /// the ~18 per-token kernel launches into a single `cuGraphLaunch`, so the + /// host-side launch overhead in the wall−gpu gap disappears. The scratch's + /// pre-allocated buffers and pre-computed tile schedule make the forward a + /// pure kernel sequence (no alloc, no sync), which capture requires. + pub fn measure_forward_graph(&mut self, iters: u64) -> Result<(Duration, Duration)> { + let mut scratch = Glm52MlaDecodeScratch::new(&self.ctx, self.contract)?; + let mut graph = CudaGraphState::new(); + // First call captures the graph from the real kernel closure. + graph.run_or_capture(&self.ctx, || { + glm52_mla_decode_forward_into( + &self.ctx, + &self.weights, + &self.hidden, + &self.cos, + &self.sin, + &mut self.cache, + self.position, + &self.topk, + self.contract, + &mut scratch, + ) + })?; + self.ctx.sync()?; + let wall_start = Instant::now(); + let mut gpu_ms = 0.0f64; + for _ in 0..iters { + self.start.record(&self.ctx.stream)?; + // exec is instantiated now, so this replays the graph and never + // calls the closure. + graph.run_or_capture(&self.ctx, || Ok(()))?; + self.end.record(&self.ctx.stream)?; + gpu_ms += f64::from(self.start.elapsed_ms(&self.end)?); + } + self.ctx.sync()?; + Ok(( + Duration::from_secs_f64(gpu_ms / 1_000.0), + wall_start.elapsed(), + )) + } + /// Bitwise parity between the as-is forward and the scratch forward. /// Weights and inputs are deterministic, and the scratch path runs the /// exact same op sequence, so any mismatch is a real bug, not noise. 
@@ -430,4 +477,151 @@ impl Glm52MlaDecodeBench { self.ctx.sync()?; Ok(wall.elapsed()) } + + /// FlashMLA sparse decode (metadata + decode) at an overridden + /// `num_sm_parts`, everything else from the bench contract. bs=1 sparse + /// top-k=2048 over-splits at the default (one split per SM ⇒ ~16 KV/split + /// and a 132-way combine reducing 17.3 MB of `o_accum`); this sweeps the + /// split count to find where the combine round-trip stops paying for the + /// extra partial parallelism. Returns `None` if the kernel rejects the + /// count (`validate`/shape guard) so the caller can skip it. + pub fn measure_flashmla_at( + &mut self, + num_sm_parts: usize, + iters: u64, + ) -> Result> { + let mut c = self.contract; + c.num_sm_parts = num_sm_parts; + if c.validate().is_err() { + return Ok(None); + } + let query = self + .ctx + .stream + .clone_htod(&vec![bf16::from_f32(0.01); HEADS * QUERY_DIM])?; + let mut sched = self + .ctx + .stream + .alloc_zeros::(c.tile_scheduler_metadata_len())?; + let mut splits = self.ctx.stream.alloc_zeros::(c.num_splits_len())?; + let mut latent = self.ctx.stream.alloc_zeros::(c.latent_len())?; + let mut lse = self.ctx.stream.alloc_zeros::(c.lse_len())?; + let mut lse_accum = self.ctx.stream.alloc_zeros::(c.lse_accum_len())?; + let mut o_accum = self.ctx.stream.alloc_zeros::(c.o_accum_len())?; + let run = |b: &mut Self, + sched: &mut CudaSlice, + splits: &mut CudaSlice, + latent: &mut CudaSlice, + lse: &mut CudaSlice, + lse_accum: &mut CudaSlice, + o_accum: &mut CudaSlice| + -> Result<()> { + glm52_flashmla_sparse_decode_metadata_launch( + &b.ctx, + c.batch_size, + c.num_sm_parts, + sched, + splits, + )?; + glm52_flashmla_sparse_decode_launch( + &b.ctx, c, &query, &b.cache, &b.topk, sched, splits, latent, lse, lse_accum, + o_accum, + ) + }; + run( + self, + &mut sched, + &mut splits, + &mut latent, + &mut lse, + &mut lse_accum, + &mut o_accum, + )?; + self.ctx.sync()?; + let wall = Instant::now(); + for _ in 0..iters { + run( + self, + 
&mut sched, + &mut splits, + &mut latent, + &mut lse, + &mut lse_accum, + &mut o_accum, + )?; + } + self.ctx.sync()?; + Ok(Some(wall.elapsed())) + } + + /// The default `num_sm_parts` the device query picks (the over-split point). + pub fn default_num_sm_parts(&self) -> usize { + self.contract.num_sm_parts + } + + /// Override the split count for every subsequent measurement (validated). + /// Lets the driver measure the whole graphed forward at the swept optimum, + /// not just the isolated flashmla stage. + pub fn set_num_sm_parts(&mut self, parts: usize) -> Result<()> { + let mut c = self.contract; + c.num_sm_parts = parts; + c.validate()?; + self.contract = c; + Ok(()) + } + + /// Confirm a swept `num_sm_parts` is a pure parallelization knob, not a + /// correctness one: run the sparse decode at `parts` and at the device + /// default, and report the max abs diff of the `latent` output. The split + /// count only changes the split-KV reduction tree, so outputs must agree + /// within fp associativity (not bitwise). A large diff means the count is + /// unsafe and its sweep timing must not be treated as a usable tuning. 
+ pub fn flashmla_parts_max_diff(&mut self, parts: usize) -> Result { + let latent = |b: &mut Self, num_sm_parts: usize| -> Result> { + let mut c = b.contract; + c.num_sm_parts = num_sm_parts; + c.validate()?; + let query = b + .ctx + .stream + .clone_htod(&vec![bf16::from_f32(0.01); HEADS * QUERY_DIM])?; + let mut sched = b + .ctx + .stream + .alloc_zeros::(c.tile_scheduler_metadata_len())?; + let mut splits = b.ctx.stream.alloc_zeros::(c.num_splits_len())?; + let mut out = b.ctx.stream.alloc_zeros::(c.latent_len())?; + let mut lse = b.ctx.stream.alloc_zeros::(c.lse_len())?; + let mut lse_accum = b.ctx.stream.alloc_zeros::(c.lse_accum_len())?; + let mut o_accum = b.ctx.stream.alloc_zeros::(c.o_accum_len())?; + glm52_flashmla_sparse_decode_metadata_launch( + &b.ctx, + c.batch_size, + c.num_sm_parts, + &mut sched, + &mut splits, + )?; + glm52_flashmla_sparse_decode_launch( + &b.ctx, + c, + &query, + &b.cache, + &b.topk, + &sched, + &splits, + &mut out, + &mut lse, + &mut lse_accum, + &mut o_accum, + )?; + Ok(b.ctx.stream.clone_dtoh(&out)?) 
+ }; + let a = latent(self, self.contract.num_sm_parts)?; + let b = latent(self, parts)?; + self.ctx.sync()?; + Ok(a.iter() + .zip(&b) + .map(|(x, y)| (x.to_f32() - y.to_f32()).abs()) + .fold(0.0f32, f32::max)) + } } diff --git a/openinfer-glm52/src/lib.rs b/openinfer-glm52/src/lib.rs index cab03abe..ac7a07ee 100644 --- a/openinfer-glm52/src/lib.rs +++ b/openinfer-glm52/src/lib.rs @@ -16,14 +16,14 @@ mod dense; #[cfg(feature = "glm52")] mod fp8; #[cfg(feature = "glm52")] -pub mod kernel_bench; -#[cfg(feature = "glm52")] mod indexer; #[cfg(all(test, feature = "glm52"))] mod indexer_oracle_gate; #[cfg(all(test, feature = "glm52"))] mod indexer_smoke; #[cfg(feature = "glm52")] +pub mod kernel_bench; +#[cfg(feature = "glm52")] mod layer; #[cfg(all(test, feature = "glm52"))] mod layer_ep8_oracle_gate; diff --git a/openinfer-glm52/src/mla_decode.rs b/openinfer-glm52/src/mla_decode.rs index 095d79d3..11137068 100644 --- a/openinfer-glm52/src/mla_decode.rs +++ b/openinfer-glm52/src/mla_decode.rs @@ -482,6 +482,21 @@ pub(crate) struct Glm52MlaDecodeScratch { impl Glm52MlaDecodeScratch { pub(crate) fn new(ctx: &DeviceContext, contract: Glm52FlashMlaSparseDecode) -> Result { + // The FlashMLA tile schedule (sched/splits) is a pure function of + // batch_size + num_sm_parts — both fixed by the contract, independent + // of the per-token query/KV data. Compute it once here so the decode + // path drops one launch per layer per token (see forward_into). 
+ let mut sched = ctx + .stream + .alloc_zeros::(contract.tile_scheduler_metadata_len())?; + let mut splits = ctx.stream.alloc_zeros::(contract.num_splits_len())?; + glm52_flashmla_sparse_decode_metadata_launch( + ctx, + contract.batch_size, + contract.num_sm_parts, + &mut sched, + &mut splits, + )?; Ok(Self { fp8: Fp8LinearScratch::new(ctx, HEADS * V_HEAD)?, q_a: DeviceVec::zeros(ctx, Q_LORA)?, @@ -495,10 +510,8 @@ impl Glm52MlaDecodeScratch { query: ctx.stream.alloc_zeros::(HEADS * QUERY_DIM)?, ckv_fp8: ctx.stream.alloc_zeros::(KV_LORA)?, ckv_scales: ctx.stream.alloc_zeros::(KV_LORA / FP8_BLOCK)?, - sched: ctx - .stream - .alloc_zeros::(contract.tile_scheduler_metadata_len())?, - splits: ctx.stream.alloc_zeros::(contract.num_splits_len())?, + sched, + splits, latent: ctx.stream.alloc_zeros::(contract.latent_len())?, lse: ctx.stream.alloc_zeros::(contract.lse_len())?, lse_accum: ctx.stream.alloc_zeros::(contract.lse_accum_len())?, @@ -626,13 +639,8 @@ pub(crate) fn glm52_mla_decode_forward_into( )?; // ---- FlashMLA sparse decode ---- - glm52_flashmla_sparse_decode_metadata_launch( - ctx, - contract.batch_size, - contract.num_sm_parts, - &mut scratch.sched, - &mut scratch.splits, - )?; + // The tile schedule was computed once in `Glm52MlaDecodeScratch::new` + // (data-independent); the per-token path only runs the decode itself. 
glm52_flashmla_sparse_decode_launch( ctx, contract, From c5173c2e49e35842cfdf948f80843b05a8f42011 Mon Sep 17 00:00:00 2001 From: n-WN <30841158+n-WN@users.noreply.github.com> Date: Fri, 3 Jul 2026 21:42:51 +0800 Subject: [PATCH 8/8] fix(glm52): the decode scratch owns its FlashMLA contract (review round) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Codex flagged that forward_into's length-only guard lets a scratch built for one num_sm_parts run under a different contract — the pre-computed tile schedule is only meaningful under the exact batch_size+num_sm_parts it was generated with, so a mismatch could corrupt the latent output instead of failing. The contract now lives in Glm52MlaDecodeScratch and forward_into takes no contract parameter: a mismatch is unrepresentable. Also from the adversarial pass: default_num_sm_parts() reports the device query (not a --sm-parts override), so the sweep's '*' label and the cross-split diff baseline stay honest under overrides; graph capture warms the kernel sequence un-captured first (capture aborts on lazy first-touch allocation); the sweep dedups when the default collides with a fixed step. --- docs/models/glm52/kernel-perf-decode.md | 13 +++--- openinfer-glm52/src/bin/glm52_kernel_bench.rs | 7 ++- openinfer-glm52/src/kernel_bench.rs | 45 ++++++++++++++----- openinfer-glm52/src/mla_decode.rs | 43 ++++++------------ openinfer-glm52/src/mla_oracle_gate.rs | 5 +-- 5 files changed, 63 insertions(+), 50 deletions(-) diff --git a/docs/models/glm52/kernel-perf-decode.md b/docs/models/glm52/kernel-perf-decode.md index 8660a9bc..7c4345a5 100644 --- a/docs/models/glm52/kernel-perf-decode.md +++ b/docs/models/glm52/kernel-perf-decode.md @@ -1,17 +1,18 @@ # GLM5.2 decode kernel performance — measured baseline and optimization ladder -**TL;DR**: One GLM5.2 MLA decode layer costs ~288 µs GPU / 298 µs wall at bs=1 on H100 (measured, parity-verified). 
Step 0 — three stacked, implemented, parity-verified optimizations — cuts it to **165 µs GPU / 175 µs wall (−41%)**: **(0a) an allocation arena** removes ~68 µs (synchronous `cudaMalloc`s serialize against the stream, so eliminating them drops GPU time, not just host time); **(0b) hoisting the data-independent FlashMLA tile-schedule** out of the per-token path removes ~28 µs; **(0c) capturing the forward into a CUDA Graph** removes ~23 µs more — and it drops *GPU* time, because graph replay launches the ~18 kernels back-to-back and removes the inter-kernel bubbles where the GPU idled waiting for the next host launch. Next design lever: **(step 1) fusing the FlashMLA sparse-decode partial+combine kernels** (ThunderMLA-style) removes the ~10 µs `o_accum` HBM round-trip + one launch. Everything below is from `glm52_kernel_bench` on a real H100, parity-verified against the alloc-heavy forward. +**TL;DR**: After #535 hoisted the FlashMLA tile schedule for every path, one GLM5.2 MLA decode layer costs ~268 µs GPU / 278 µs wall at bs=1 on H100 (measured, parity-verified). Two stacked, implemented optimizations cut it to **168 µs GPU / 178 µs wall (−36%)**: **(0a) an MLA-layer allocation arena** removes ~73 µs — the bring-up forward still does ~20 synchronous `cudaMalloc`s per layer per token (#535's persistent workspace covered the MoE chain, not MLA), and they serialize the decode stream, so eliminating them drops GPU time, not just host time; **(0c) capturing the forward into a CUDA Graph** removes ~27 µs more — graph replay launches the ~18 kernels back-to-back and removes the inter-kernel bubbles where the GPU idled between host launches (this is the "PR5c graph target" the #535 doc names). The scratch's schedule handling reuses #535's `Glm52MlaSchedMetadata` (one plan type, no duplicate). Next design lever assessment unchanged: partial+combine fusion is deprioritized by measurement. 
Everything below is from `glm52_kernel_bench` on a real H100, parity-verified against the bring-up forward. Last touched: 2026-07 ## Measured baseline (`glm52_kernel_bench`, bs=1, synthetic weights, iters=64) -| stage (ctx 2048) | gpu | wall | cumulative saved | +| stage (ctx 2048, on top of #535) | gpu | wall | cumulative saved | |---|---|---|---| -| as-is forward | 287.9 µs | 298.1 µs | — | -| 0a arena | 218 µs | ~228 µs | −68 µs | -| 0b + tile-schedule hoist | 189.1 µs | 199.6 µs | −98 µs | -| **0c + CUDA Graph** | **165.0 µs** | **175.3 µs** | **−123 µs (−41%)** | +| as-is forward (incl. #535's hoisted schedule) | 267.6 µs | 278.0 µs | — | +| 0a MLA arena (`Glm52MlaDecodeScratch`) | 194.7 µs | 205.1 µs | −73 µs | +| **0c + CUDA Graph** | **167.9 µs** | **178.3 µs** | **−100 µs (−36%)** | + +ctx 512 tracks it (graph 168.6 / 178.8 µs). Parity-verified bitwise against the bring-up forward. History: pre-#535 this ladder read 288 → 218 (arena) → 190 (schedule hoist) → 166 µs (−42%); #535 landed the schedule hoist for every path (as-is dropped 288 → 268), so this branch's remaining contribution is the arena + the graph. ctx 512 tracks it (graph 165.6 / 175.8 µs). All three are parity-verified against the alloc-heavy forward. Projected 75-MoE-layer attention share: 22.3 ms/token → **~13.1 ms/token**. diff --git a/openinfer-glm52/src/bin/glm52_kernel_bench.rs b/openinfer-glm52/src/bin/glm52_kernel_bench.rs index c4cb71d2..c3e58a7b 100644 --- a/openinfer-glm52/src/bin/glm52_kernel_bench.rs +++ b/openinfer-glm52/src/bin/glm52_kernel_bench.rs @@ -106,7 +106,12 @@ fn main() -> Result<()> { // faster than it costs partial parallelism. let default_parts = bench.default_num_sm_parts(); print!("flashmla sweep "); - for parts in [1usize, 8, 16, 32, 64, 96, default_parts] { + for parts in [1usize, 8, 16, 32, 64, 96] + .iter() + .copied() + .filter(|&p| p != default_parts) + .chain(std::iter::once(default_parts)) + { if let Some(t) = bench.measure_flashmla_at(parts, args.iters)? 
{ let tag = if parts == default_parts { "*" } else { "" }; print!("p{parts}{tag}={:.1}us ", per(t)); diff --git a/openinfer-glm52/src/kernel_bench.rs b/openinfer-glm52/src/kernel_bench.rs index dc128bff..0fc38d15 100644 --- a/openinfer-glm52/src/kernel_bench.rs +++ b/openinfer-glm52/src/kernel_bench.rs @@ -30,7 +30,7 @@ use crate::fp8::{FP8_BLOCK, Glm52ProjBytes, ProjWeight, fp8_linear}; use openinfer_core::cuda_graph::CudaGraphState; use crate::mla_decode::{ - Glm52MlaDecodeScratch, Glm52MlaLayerWeights, glm52_mla_decode_forward, + Glm52MlaDecodeScratch, Glm52MlaLayerWeights, Glm52MlaSchedMetadata, glm52_mla_decode_forward, glm52_mla_decode_forward_into, }; @@ -67,6 +67,10 @@ fn bf16_ones_bytes(len: usize) -> Vec { /// One synthetic MLA layer plus every forward input, device-resident. pub struct Glm52MlaDecodeBench { pub ctx: DeviceContext, + /// The split count the device query picked, kept separate from + /// `contract.num_sm_parts` so `--sm-parts` overrides don't masquerade + /// as the default in the sweep label or the cross-split diff baseline. + device_default_parts: usize, weights: Glm52MlaLayerWeights, hidden: CudaSlice, cos: CudaSlice, @@ -74,6 +78,10 @@ pub struct Glm52MlaDecodeBench { cache: CudaSlice, topk: CudaSlice, contract: Glm52FlashMlaSparseDecode, + /// Contract + precomputed tile plan for the as-is forward (the scratch + /// forward owns its own copy inside `Glm52MlaDecodeScratch`). Rebuilt by + /// `set_num_sm_parts`, since the plan is split-count-specific. 
+ sched: Glm52MlaSchedMetadata, position: usize, start: CudaEvent, end: CudaEvent, @@ -138,6 +146,7 @@ impl Glm52MlaDecodeBench { sm_scale: 1.0 / (QUERY_DIM as f32).sqrt(), }; contract.validate()?; + let sched = Glm52MlaSchedMetadata::new(&ctx, contract)?; let start = ctx .ctx @@ -145,8 +154,10 @@ impl Glm52MlaDecodeBench { let end = ctx .ctx .new_event(Some(sys::CUevent_flags::CU_EVENT_DEFAULT))?; + let device_default_parts = contract.num_sm_parts; let bench = Self { ctx, + device_default_parts, weights, hidden, cos, @@ -154,6 +165,7 @@ impl Glm52MlaDecodeBench { cache, topk, contract, + sched, position, start, end, @@ -172,7 +184,7 @@ impl Glm52MlaDecodeBench { &mut self.cache, self.position, &self.topk, - self.contract, + &self.sched, )?; Ok(()) } @@ -209,7 +221,6 @@ impl Glm52MlaDecodeBench { &mut self.cache, self.position, &self.topk, - self.contract, &mut scratch, )?; self.ctx.sync()?; @@ -226,7 +237,6 @@ impl Glm52MlaDecodeBench { &mut self.cache, self.position, &self.topk, - self.contract, &mut scratch, )?; self.end.record(&self.ctx.stream)?; @@ -244,6 +254,21 @@ impl Glm52MlaDecodeBench { /// pure kernel sequence (no alloc, no sync), which capture requires. pub fn measure_forward_graph(&mut self, iters: u64) -> Result<(Duration, Duration)> { let mut scratch = Glm52MlaDecodeScratch::new(&self.ctx, self.contract)?; + // Warm every kernel un-captured first: stream capture aborts on any + // lazy first-touch allocation (cuBLAS workspace, module load), so the + // capture must not be the first execution of this sequence. + glm52_mla_decode_forward_into( + &self.ctx, + &self.weights, + &self.hidden, + &self.cos, + &self.sin, + &mut self.cache, + self.position, + &self.topk, + &mut scratch, + )?; + self.ctx.sync()?; let mut graph = CudaGraphState::new(); // First call captures the graph from the real kernel closure. 
graph.run_or_capture(&self.ctx, || { @@ -256,7 +281,6 @@ impl Glm52MlaDecodeBench { &mut self.cache, self.position, &self.topk, - self.contract, &mut scratch, ) })?; @@ -291,7 +315,7 @@ impl Glm52MlaDecodeBench { &mut self.cache, self.position, &self.topk, - self.contract, + &self.sched, )?; let expected = self.ctx.stream.clone_dtoh(&expected)?; let mut scratch = Glm52MlaDecodeScratch::new(&self.ctx, self.contract)?; @@ -304,7 +328,6 @@ impl Glm52MlaDecodeBench { &mut self.cache, self.position, &self.topk, - self.contract, &mut scratch, )?; let actual = self.ctx.stream.clone_dtoh(scratch.output())?; @@ -554,9 +577,10 @@ impl Glm52MlaDecodeBench { Ok(Some(wall.elapsed())) } - /// The default `num_sm_parts` the device query picks (the over-split point). + /// The default `num_sm_parts` the device query picked (the over-split + /// point), unaffected by `set_num_sm_parts` overrides. pub fn default_num_sm_parts(&self) -> usize { - self.contract.num_sm_parts + self.device_default_parts } /// Override the split count for every subsequent measurement (validated). @@ -566,6 +590,7 @@ impl Glm52MlaDecodeBench { let mut c = self.contract; c.num_sm_parts = parts; c.validate()?; + self.sched = Glm52MlaSchedMetadata::new(&self.ctx, c)?; self.contract = c; Ok(()) } @@ -616,7 +641,7 @@ impl Glm52MlaDecodeBench { )?; Ok(b.ctx.stream.clone_dtoh(&out)?) }; - let a = latent(self, self.contract.num_sm_parts)?; + let a = latent(self, self.device_default_parts)?; let b = latent(self, parts)?; self.ctx.sync()?; Ok(a.iter() diff --git a/openinfer-glm52/src/mla_decode.rs b/openinfer-glm52/src/mla_decode.rs index 11137068..6b786a7b 100644 --- a/openinfer-glm52/src/mla_decode.rs +++ b/openinfer-glm52/src/mla_decode.rs @@ -458,6 +458,12 @@ pub(crate) fn glm52_mla_attend( /// per token; this scratch plus `glm52_mla_decode_forward_into` is the /// zero-allocation variant (same ops, same math, buffers reused). 
pub(crate) struct Glm52MlaDecodeScratch { + /// The FlashMLA contract paired with its pre-computed tile-scheduler plan + /// ([`Glm52MlaSchedMetadata`], #535). Owning it here means a forward can + /// never pair this scratch's buffers with a different split count — the + /// plan is only meaningful under the exact `batch_size` + `num_sm_parts` + /// it was generated with. + sched: Glm52MlaSchedMetadata, fp8: Fp8LinearScratch, q_a: DeviceVec, q_resid: DeviceVec, @@ -470,8 +476,6 @@ pub(crate) struct Glm52MlaDecodeScratch { query: CudaSlice, ckv_fp8: CudaSlice, ckv_scales: CudaSlice, - sched: CudaSlice, - splits: CudaSlice, latent: CudaSlice, lse: CudaSlice, lse_accum: CudaSlice, @@ -482,22 +486,8 @@ pub(crate) struct Glm52MlaDecodeScratch { impl Glm52MlaDecodeScratch { pub(crate) fn new(ctx: &DeviceContext, contract: Glm52FlashMlaSparseDecode) -> Result { - // The FlashMLA tile schedule (sched/splits) is a pure function of - // batch_size + num_sm_parts — both fixed by the contract, independent - // of the per-token query/KV data. Compute it once here so the decode - // path drops one launch per layer per token (see forward_into). 
- let mut sched = ctx - .stream - .alloc_zeros::(contract.tile_scheduler_metadata_len())?; - let mut splits = ctx.stream.alloc_zeros::(contract.num_splits_len())?; - glm52_flashmla_sparse_decode_metadata_launch( - ctx, - contract.batch_size, - contract.num_sm_parts, - &mut sched, - &mut splits, - )?; Ok(Self { + sched: Glm52MlaSchedMetadata::new(ctx, contract)?, fp8: Fp8LinearScratch::new(ctx, HEADS * V_HEAD)?, q_a: DeviceVec::zeros(ctx, Q_LORA)?, q_resid: DeviceVec::zeros(ctx, Q_LORA)?, @@ -510,8 +500,6 @@ impl Glm52MlaDecodeScratch { query: ctx.stream.alloc_zeros::(HEADS * QUERY_DIM)?, ckv_fp8: ctx.stream.alloc_zeros::(KV_LORA)?, ckv_scales: ctx.stream.alloc_zeros::(KV_LORA / FP8_BLOCK)?, - sched, - splits, latent: ctx.stream.alloc_zeros::(contract.latent_len())?, lse: ctx.stream.alloc_zeros::(contract.lse_len())?, lse_accum: ctx.stream.alloc_zeros::(contract.lse_accum_len())?, @@ -528,9 +516,9 @@ impl Glm52MlaDecodeScratch { } /// [`glm52_mla_decode_forward`] with all intermediates in `scratch`: the same -/// op sequence with zero per-call allocations. The scratch's FlashMLA buffers -/// are sized by the `contract` it was built with, so the same contract must be -/// passed here. +/// op sequence with zero per-call allocations. The FlashMLA contract lives in +/// the scratch (buffers and the pre-computed tile schedule were built for it), +/// so a mismatched contract is unrepresentable at this call. 
#[allow(clippy::too_many_arguments)] pub(crate) fn glm52_mla_decode_forward_into( ctx: &DeviceContext, @@ -541,20 +529,15 @@ pub(crate) fn glm52_mla_decode_forward_into( cache: &mut CudaSlice, position: usize, topk: &CudaSlice, - contract: Glm52FlashMlaSparseDecode, scratch: &mut Glm52MlaDecodeScratch, ) -> Result<()> { + let contract = scratch.sched.contract; ensure!(hidden.len() >= HIDDEN, "GLM5.2 MLA hidden too small"); ensure!( position < contract.num_blocks * GLM52_FLASHMLA_SPARSE_PAGE_SIZE, "GLM5.2 MLA position {position} outside paged cache ({} blocks x {GLM52_FLASHMLA_SPARSE_PAGE_SIZE})", contract.num_blocks ); - ensure!( - scratch.sched.len() >= contract.tile_scheduler_metadata_len() - && scratch.latent.len() >= contract.latent_len(), - "GLM5.2 MLA scratch was built for a different FlashMLA contract" - ); // ---- front projections ---- fp8_linear_into(ctx, &w.q_a, hidden, &mut scratch.fp8, &mut scratch.q_a.data)?; @@ -647,8 +630,8 @@ pub(crate) fn glm52_mla_decode_forward_into( &scratch.query, cache, topk, - &scratch.sched, - &scratch.splits, + &scratch.sched.tile_scheduler_metadata, + &scratch.sched.num_splits, &mut scratch.latent, &mut scratch.lse, &mut scratch.lse_accum, diff --git a/openinfer-glm52/src/mla_oracle_gate.rs b/openinfer-glm52/src/mla_oracle_gate.rs index e1cae25b..65b7b372 100644 --- a/openinfer-glm52/src/mla_oracle_gate.rs +++ b/openinfer-glm52/src/mla_oracle_gate.rs @@ -38,8 +38,8 @@ use openinfer_kernels::tensor::DeviceContext; use crate::fp8::Glm52ProjBytes; use crate::mla_decode::{ - Glm52MlaDecodeScratch, Glm52MlaLayerWeights, Glm52MlaSchedMetadata, - glm52_mla_decode_forward, glm52_mla_decode_forward_into, + Glm52MlaDecodeScratch, Glm52MlaLayerWeights, Glm52MlaSchedMetadata, glm52_mla_decode_forward, + glm52_mla_decode_forward_into, }; // ---- BEGIN GENERATED: glm52_oracle probes ---- @@ -300,7 +300,6 @@ fn mla_oracle_gate() -> Result<()> { &mut cache, position, &topk, - contract, &mut scratch, )?; let o_scratch = 
ctx.stream.clone_dtoh(scratch.output())?;