feat: multi-step forward unrolling + split backward for fused chunk GLA #147

sii-xinglong wants to merge 1 commit into main
Conversation
Port optimizations from Glaucis PR #74.

Forward (`tops/ops/common/fused_chunk.py`):
- Process `num_steps` (default 4) consecutive chunks per grid iteration via a Python for loop, reducing grid overhead by 4x
- Auto-fallback: if `NT % num_steps != 0`, reduces to the largest valid divisor
- Precomputes loop-invariant constants (causal mask, g_gamma decay)

Backward (`tops/ops/simple_gla/fused_chunk.py`):
- Split the monolithic backward into 2 passes to reduce register pressure
- Pass 1: dv + dh propagation (sequential, "arbitrary" time dim)
- Pass 2: dq + dk using materialized dh_states (fully parallel time dim)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
📝 Walkthrough

Modified the fused forward and backward Pallas kernels for chunk operations. The forward kernel introduces multi-step processing via a `num_steps` inner loop.

Changes
Sequence Diagram(s)

sequenceDiagram
participant Launcher
participant FwdKernel as Fused Forward<br/>(num_steps)
participant Memory as VMEM/SMEM
Launcher->>FwdKernel: Dispatch grid (B, H, NK, NV, NT_OUTER)
activate FwdKernel
FwdKernel->>Memory: Load q, k, v (NUM_STEPS*BT)
loop step = 0 to NUM_STEPS-1
FwdKernel->>FwdKernel: Compute partial h, o<br/>(offset = step*BT)
alt step == 0
FwdKernel->>Memory: Load initial h_state
end
FwdKernel->>Memory: Update h_state in VMEM
end
alt i_t == NT_OUTER-1
FwdKernel->>Memory: Store final h_state
end
FwdKernel->>Memory: Write o (NUM_STEPS*BT)
deactivate FwdKernel
sequenceDiagram
participant Launcher
participant Pass1 as Pass 1: dv & dh<br/>(_bwd_dv_dh_kernel)
participant Pass2 as Pass 2: dq & dk<br/>(_bwd_dq_dk_kernel)
participant Scratch as VMEM Scratch
Launcher->>Pass1: Execute backward Pass 1
activate Pass1
Pass1->>Scratch: Propagate dh across chunks<br/>(sequential)
Pass1->>Scratch: Materialize dh_states<br/>(B, H, NT, K, V)
Pass1->>Scratch: Store dv output
deactivate Pass1
Launcher->>Pass2: Execute backward Pass 2
activate Pass2
Pass2->>Scratch: Load precomputed h, dh_states
Pass2->>Pass2: Compute dq, dk<br/>(parallel time dimension)
Pass2->>Scratch: Store dq, dk outputs
deactivate Pass2
Launcher->>Launcher: Collect (dq, dk, dv, dh0)
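The two-pass backward structure in the diagram above can be illustrated on a scalar-state toy version of the gated recurrence (this is a sketch for intuition, not the actual Pallas kernels; all names here are illustrative, and the dg gradient is omitted):

```python
import numpy as np

def fwd(q, k, v, g):
    """Toy gated recurrence: h_t = g_t*h_{t-1} + k_t*v_t, o_t = q_t*h_t."""
    T = len(q)
    h = 0.0
    hs, o = np.zeros(T), np.zeros(T)
    for t in range(T):
        h = g[t] * h + k[t] * v[t]
        hs[t] = h          # forward states, reused by Pass 2
        o[t] = q[t] * h
    return o, hs

def bwd_two_pass(q, k, v, g, do, hs):
    """Backward split as in the PR: sequential dh/dv pass, then parallel dq/dk."""
    T = len(q)
    # Pass 1 (sequential, reverse time): propagate dh, materialize it, emit dv.
    dh_states = np.zeros(T)
    dv = np.zeros(T)
    dh = 0.0
    for t in reversed(range(T)):
        dh = q[t] * do[t] + (g[t + 1] * dh if t + 1 < T else 0.0)
        dh_states[t] = dh  # "materialized" dh that Pass 2 reads back
        dv[t] = k[t] * dh
    # Pass 2 (fully parallel over time): dq, dk from precomputed hs and dh_states.
    dq = do * hs
    dk = v * dh_states
    return dq, dk, dv
```

Splitting this way means Pass 2 has no sequential dependence at all, which is what lets the real kernel parallelize the time dimension there.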
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes
🚥 Pre-merge checks: ✅ 2 passed | ❌ 1 failed (1 warning)
Code Review
This pull request optimizes the fused chunk forward and backward kernels. The forward kernel now processes multiple sub-chunks per grid iteration to reduce overhead, while the backward kernel is split into two passes to reduce register pressure and prevent spills. My feedback highlights the need to refactor the duplicated num_steps auto-adjustment logic into a shared utility and suggests ensuring explicit initialization of dh_states_ref in the backward pass to avoid potential issues with stale data.
    while num_steps > 1 and NT % num_steps != 0:
        num_steps //= 2

    safe_g_diff = jnp.where(mask, b_g[:, None] - b_g[None, :], 0.0)
    b_dA_gated = b_dA * jnp.exp(safe_g_diff)

    # Store dh state BEFORE updating (Pass 2 needs dh at this time step)
    dh_states_ref[0, 0, 0] = scratch_ref[...].astype(dh_states_ref.dtype)
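The fallback loop shown in the diff can be factored into a small standalone helper, as the review below suggests (a sketch; the name `normalize_num_steps` is illustrative, not from the codebase):

```python
def normalize_num_steps(num_steps: int, NT: int) -> int:
    """Halve num_steps until it evenly divides NT (bottoming out at 1),
    mirroring the auto-fallback loop in the diff above."""
    while num_steps > 1 and NT % num_steps != 0:
        num_steps //= 2
    return num_steps
```

Note that halving only visits powers of two: for `NT=6, num_steps=4` it lands on 2, and for a prime `NT` it falls all the way back to 1, so "largest valid divisor" in the PR description should be read as "largest valid power-of-two divisor of the starting value".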
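The `safe_g_diff` pattern in the diff guards against overflow: for masked-out (non-causal) entries, the raw gate difference `b_g[i] - b_g[j]` can be a large positive number, and `exp` of it would overflow. Zeroing the difference first makes those entries a harmless `exp(0)=1`, to be discarded by the mask afterwards. A numpy sketch of the idea (shapes and the final masking step are illustrative, not the kernel code):

```python
import numpy as np

BT = 4
b_g = np.array([0.0, -1.0, -2.0, -3.0])       # cumulative log-decay per position
b_dA = np.ones((BT, BT))
mask = np.tril(np.ones((BT, BT), dtype=bool))  # causal: row i attends to cols j <= i

# Masked (upper-triangular) entries have b_g[i] - b_g[j] > 0, which can overflow
# under exp; zero the difference there before exponentiating.
safe_g_diff = np.where(mask, b_g[:, None] - b_g[None, :], 0.0)
b_dA_gated = b_dA * np.exp(safe_g_diff) * mask  # masked entries discarded afterwards
```

Every surviving entry is `exp` of a non-positive number, so the gated matrix stays finite regardless of how large the decay range gets.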
🧹 Nitpick comments (2)
tops/ops/common/fused_chunk.py (1)

212-216: Redundant num_steps auto-adjustment logic.

The `while num_steps > 1 and NT % num_steps != 0` loop appears both here (lines 212-214) and in the public `fused_chunk_fwd` (lines 371-373).

This duplication is intentional since `fused_chunk_fwd_kernel` can be called directly (it's JIT'd separately), so both entry points need to handle invalid `num_steps`. However, when called via `fused_chunk_fwd`, the adjustment runs twice with identical results. Consider extracting to a helper function for maintainability, or documenting that `fused_chunk_fwd_kernel` expects pre-validated `num_steps` and removing the check there.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tops/ops/common/fused_chunk.py` around lines 212-216: the num_steps auto-adjustment loop is duplicated in fused_chunk_fwd_kernel and fused_chunk_fwd, causing redundant execution when fused_chunk_fwd calls the kernel; extract that loop into a single helper (e.g., normalize_num_steps or adjust_num_steps) and replace the duplicated blocks in both fused_chunk_fwd_kernel and fused_chunk_fwd with a call to that helper (or, if you prefer the kernel to assume validated input, document that fused_chunk_fwd must call the helper and remove the loop from fused_chunk_fwd_kernel). Ensure the helper returns the adjusted num_steps (and any dependent values like NT_OUTER/WINDOW) so both functions use the same logic and avoid double-adjustment.

tops/ops/simple_gla/fused_chunk.py (1)
286-296: Consider matching dh_states dtype to reduce memory/conversion overhead.

`dh_states` is created with `jnp.float32` dtype (line 294), while other backward outputs like `dv` use `v.dtype`. If `v.dtype` is already float32, this is fine. However, if inputs are bfloat16, the materialized `dh_states` tensor `(B, H, NT, K, V)` will be 2x larger than necessary in memory.

This is likely intentional for numerical stability in the backward pass, but worth confirming this tradeoff is acceptable given the memory savings goal of the two-pass split.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tops/ops/simple_gla/fused_chunk.py` around lines 286 - 296, The dh_states output is hardcoded to jnp.float32 in pass1_out_shapes which can bloat memory when inputs (v) are bfloat16; change the dh_states dtype to match v.dtype by using v.dtype (i.e., jax.ShapeDtypeStruct((B, H, NT, K, V), v.dtype)) in pass1_out_shapes (and similarly ensure any related uses in pass1_out_specs or downstream code that assume float32 are updated or explicitly guarded), or if you must keep float32 for numerical-stability reasons, add a clear code comment explaining this trade-off so the choice is explicit.
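The memory impact described in this comment is easy to quantify: float32 exactly doubles the footprint of the materialized `(B, H, NT, K, V)` tensor relative to bfloat16. A quick sketch (the shape values are made up for illustration):

```python
def tensor_bytes(shape, itemsize):
    """Total bytes for a dense tensor of the given shape and per-element size."""
    n = 1
    for d in shape:
        n *= d
    return n * itemsize

# Illustrative sizes, not from the repo.
B, H, NT, K, V = 8, 16, 64, 128, 128
f32_bytes = tensor_bytes((B, H, NT, K, V), 4)   # jnp.float32: 4 bytes/elem
bf16_bytes = tensor_bytes((B, H, NT, K, V), 2)  # bfloat16: 2 bytes/elem
```

Since `dh_states` is an extra buffer the two-pass split introduces specifically to save memory elsewhere, the 2x factor is worth weighing against the numerical-stability benefit of accumulating in float32.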
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 9e2c99cf-9985-49f8-87d2-36bba1d8c01f
📒 Files selected for processing (2)

- tops/ops/common/fused_chunk.py
- tops/ops/simple_gla/fused_chunk.py
Summary
Port optimizations from Glaucis PR #74 to the current repository's fused chunk GLA kernels:
- Forward (`tops/ops/common/fused_chunk.py`): Each grid iteration now processes `num_steps` (default 4) consecutive chunks via a Python for loop, reducing grid overhead by 4x. Auto-fallback ensures correctness when NT is not divisible by num_steps.
- Backward (`tops/ops/simple_gla/fused_chunk.py`): The monolithic backward kernel is split into Pass 1 (dv + dh propagation, sequential) and Pass 2 (dq + dk, fully parallel), reducing register pressure.

Key difference from upstream: sub-steps use a for loop instead of manual unrolling for maintainability.
Test plan
- `num_steps=4` vs `num_steps=1` produce identical results (max_diff=0.0)
- Covered configurations: `g`, `g_gamma`, both, `h0`, and `output_final_state`
- Forward checked against the reference (`cpu_chunk_simple_gla_fwd`)
- Backward checked against the reference (`cpu_chunk_simple_gla_bwd`)

🤖 Generated with Claude Code