
feat: multi-step forward unrolling + split backward for fused chunk GLA #147

Open
sii-xinglong wants to merge 1 commit into main from feat/fused-chunk-multistep-bwd-split

Conversation

@sii-xinglong (Contributor) commented Apr 2, 2026

Summary

Port optimizations from Glaucis PR #74 to the current repository's fused chunk GLA kernels:

  • Forward multi-step unrolling (tops/ops/common/fused_chunk.py): Each grid iteration now processes num_steps (default 4) consecutive chunks via a Python for loop, reducing grid overhead by 4x. Auto-fallback ensures correctness when NT is not divisible by num_steps.
  • Backward split into 2 passes (tops/ops/simple_gla/fused_chunk.py): The monolithic backward kernel is split into Pass 1 (dv + dh propagation, sequential) and Pass 2 (dq + dk, fully parallel), reducing register pressure.

Key difference from upstream: sub-steps use a for loop instead of manual unrolling for maintainability.
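The unrolling and auto-fallback described above can be sketched in pure Python; `grid_iterations` below is an illustrative stand-in, not the actual kernel API:

```python
def grid_iterations(NT: int, num_steps: int = 4):
    """Yield absolute chunk indices, num_steps consecutive chunks per
    grid iteration. Falls back by halving num_steps until it divides NT,
    mirroring the auto-fallback described above."""
    while num_steps > 1 and NT % num_steps != 0:
        num_steps //= 2
    for i_t in range(NT // num_steps):      # one grid iteration
        for step in range(num_steps):       # Python for loop, not manual unroll
            yield i_t * num_steps + step    # absolute chunk index
```

With `NT=8` the grid shrinks from 8 to 2 iterations; with `NT=6` the fallback halves `num_steps` to 2, and with a prime `NT` it degrades gracefully to 1.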

Test plan

  • Forward: num_steps=4 vs num_steps=1 produce identical results (max_diff=0.0)
  • Forward: auto-fallback works for NT=1, NT=2
  • Forward: works with g, g_gamma, both, h0, and output_final_state
  • Forward: matches CPU reference (cpu_chunk_simple_gla_fwd)
  • Backward: split produces correct dq/dk/dv/dh0 (no NaN/Inf)
  • Backward: matches CPU reference (cpu_chunk_simple_gla_bwd)
  • Backward: works with h0/dht and NT=1
  • Verify speedup on TPU v7x

🤖 Generated with Claude Code

Summary by CodeRabbit

  • Refactor
    • Optimized forward computation to process multiple consecutive time chunks per iteration, improving throughput.
    • Restructured backward computation into a two-pass pipeline for enhanced efficiency and parallelization.

Port optimizations from Glaucis PR #74:

Forward (tops/ops/common/fused_chunk.py):
- Process num_steps (default 4) consecutive chunks per grid iteration
  via a Python for loop, reducing grid overhead by 4x
- Auto-fallback: if NT % num_steps != 0, reduces to largest valid divisor
- Precomputes loop-invariant constants (causal mask, g_gamma decay)

Backward (tops/ops/simple_gla/fused_chunk.py):
- Split monolithic backward into 2 passes to reduce register pressure
- Pass 1: dv + dh propagation (sequential, "arbitrary" time dim)
- Pass 2: dq + dk using materialized dh_states (fully parallel time dim)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
@coderabbitai bot commented Apr 2, 2026

📝 Walkthrough

Walkthrough

Modified the fused forward and backward Pallas kernels for chunk operations. The forward kernel introduces multi-step processing via a num_steps parameter to handle multiple consecutive time chunks per grid iteration. The backward kernel is split into a two-pass pipeline—Pass 1 computes gradients for values and hidden states, Pass 2 computes gradients for queries and keys using materialized intermediate states.

Changes

Cohort / File(s) Summary
Fused Forward Kernel Enhancement
tops/ops/common/fused_chunk.py
Added num_steps parameter (default 4) to process multiple consecutive time chunks per grid iteration. Restructured kernel logic to use inner for step in range(NUM_STEPS) loop with offset-based chunk indexing. Adjusted grid from (B, H, NK, NV, NT) to (B, H, NK, NV, NT_OUTER) where NT_OUTER = NT // num_steps. Added auto-adjustment logic to halve num_steps until divisibility is satisfied. Updated block specs with WINDOW = num_steps * BT and renamed output_final_state parameter to use_ht.
Backward Kernel Split into Two-Pass Pipeline
tops/ops/simple_gla/fused_chunk.py
Replaced monolithic _fused_chunk_bwd_kernel with two-pass backward pipeline. Pass 1 (_bwd_dv_dh_kernel) computes dv and propagates hidden-state gradients, materializing dh_states and dh0. Pass 2 (_bwd_dq_dk_kernel) computes dq and dk using precomputed forward hidden states and materialized intermediate dh_states. Updated launcher to orchestrate both passes with separate input/output specs and adjusted dimension semantics for parallelizability.
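A shape-level NumPy stand-in for the two-pass dataflow described above (the `dh + 1.0` and `* 2.0` lines are placeholders for the real gradient math, not the kernel's computation):

```python
import numpy as np

def backward_two_pass(NT: int, K: int, V: int):
    """Pass 1 scans chunks in reverse, storing dh BEFORE each update so
    Pass 2 can read dh_states[i_t] independently (parallel over time)."""
    dh = np.zeros((K, V), dtype=np.float32)
    dh_states = np.empty((NT, K, V), dtype=np.float32)
    for i_t in reversed(range(NT)):          # Pass 1: sequential over time
        dh_states[i_t] = dh                  # materialize state at this chunk
        dh = dh + 1.0                        # stand-in for the real dh update
    dh0 = dh                                 # gradient w.r.t. initial state
    # Pass 2: each i_t depends only on dh_states[i_t], so it parallelizes
    per_chunk = [dh_states[i_t] * 2.0 for i_t in range(NT)]  # stand-in dq/dk
    return dh_states, dh0, per_chunk
```

The key property is that once `dh_states` is materialized, Pass 2 has no cross-chunk dependency, which is what lets its time dimension be fully parallel.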

Sequence Diagram(s)

sequenceDiagram
    participant Launcher
    participant FwdKernel as Fused Forward<br/>(num_steps)
    participant Memory as VMEM/SMEM

    Launcher->>FwdKernel: Dispatch grid (B, H, NK, NV, NT_OUTER)
    activate FwdKernel
    FwdKernel->>Memory: Load q, k, v (NUM_STEPS*BT)
    loop step = 0 to NUM_STEPS-1
        FwdKernel->>FwdKernel: Compute partial h, o<br/>(offset = step*BT)
        alt step == 0
            FwdKernel->>Memory: Load initial h_state
        end
        FwdKernel->>Memory: Update h_state in VMEM
    end
    alt i_t == NT_OUTER-1
        FwdKernel->>Memory: Store final h_state
    end
    FwdKernel->>Memory: Write o (NUM_STEPS*BT)
    deactivate FwdKernel
sequenceDiagram
    participant Launcher
    participant Pass1 as Pass 1: dv & dh<br/>(_bwd_dv_dh_kernel)
    participant Pass2 as Pass 2: dq & dk<br/>(_bwd_dq_dk_kernel)
    participant Scratch as VMEM Scratch

    Launcher->>Pass1: Execute backward Pass 1
    activate Pass1
    Pass1->>Scratch: Propagate dh across chunks<br/>(sequential)
    Pass1->>Scratch: Materialize dh_states<br/>(B, H, NT, K, V)
    Pass1->>Scratch: Store dv output
    deactivate Pass1

    Launcher->>Pass2: Execute backward Pass 2
    activate Pass2
    Pass2->>Scratch: Load precomputed h, dh_states
    Pass2->>Pass2: Compute dq, dk<br/>(parallel time dimension)
    Pass2->>Scratch: Store dq, dk outputs
    deactivate Pass2

    Launcher->>Launcher: Collect (dq, dk, dv, dh0)

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

Possibly related PRs

  • fix bwd_dh with large B size #71 — Both PRs modify the SimpleGLA backward path with kernel tiling and data flow changes; the main PR splits backward into a two-pass dv/dh then dq/dk pipeline while the related PR updates backward-dh kernel carry/inputs.
  • Feat/fused bwd #142 — Both PRs modify the fused backward implementation in tops/ops/simple_gla/fused_chunk.py, including kernel/launcher behavior and public function semantics.
  • Perf/gla bwd #130 — Both PRs modify the chunked GLA backward path, changing how dh/dh0 and other gradients are computed within backward kernels.

Suggested reviewers

  • 0xaskr

🐰 A kernel hops through time, one step at a time,
Chunks now walk together, keeping perfect time!
Two passes dance backward, a graceful refrain,
Data flows cleaner through this gradient domain. ✨

🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 warning)

  • Docstring Coverage: ⚠️ Warning. Coverage is 53.85%, below the required 80.00% threshold. Write docstrings for the functions missing them.
✅ Passed checks (2 passed)
  • Description Check: ✅ Passed (check skipped because CodeRabbit's high-level summary is enabled).
  • Title Check: ✅ Passed. The title clearly and specifically summarizes the main changes: multi-step forward unrolling and split backward for fused chunk GLA kernels.


@gemini-code-assist bot left a comment

Code Review

This pull request optimizes the fused chunk forward and backward kernels. The forward kernel now processes multiple sub-chunks per grid iteration to reduce overhead, while the backward kernel is split into two passes to reduce register pressure and prevent spills. My feedback highlights the need to refactor the duplicated num_steps auto-adjustment logic into a shared utility and suggests ensuring explicit initialization of dh_states_ref in the backward pass to avoid potential issues with stale data.

Comment on lines +372 to +373
while num_steps > 1 and NT % num_steps != 0:
    num_steps //= 2
medium

The manual fallback logic for num_steps is duplicated in both the kernel launcher and the public API function. This should be refactored into a shared utility function to ensure consistency and maintainability.
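One possible shape for such a shared helper; the name `normalize_num_steps` and the returned tuple are assumptions for illustration, not existing repo API:

```python
def normalize_num_steps(NT: int, BT: int, num_steps: int):
    """Halve num_steps until it divides NT, then derive the sizes that
    depend on it, so every caller gets the same adjustment exactly once."""
    while num_steps > 1 and NT % num_steps != 0:
        num_steps //= 2
    nt_outer = NT // num_steps   # outer grid extent
    window = num_steps * BT      # block-spec window along time
    return num_steps, nt_outer, window
```

Returning the dependent sizes alongside the adjusted `num_steps` keeps the launcher and the public API from recomputing them independently.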

safe_g_diff = jnp.where(mask, b_g[:, None] - b_g[None, :], 0.0)
b_dA_gated = b_dA * jnp.exp(safe_g_diff)
# Store dh state BEFORE updating (Pass 2 needs dh at this time step)
dh_states_ref[0, 0, 0] = scratch_ref[...].astype(dh_states_ref.dtype)

medium

The variable dh_states_ref is being written to in Pass 1, but it is not explicitly initialized or cleared. While it is overwritten, ensuring it is zero-initialized or handled explicitly if dht is not provided could prevent potential issues with stale data in memory.
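A minimal sketch of the explicit seeding this comment suggests; `init_dh` is a hypothetical helper (in the actual kernel this would be a predicated write to the scratch ref rather than a Python branch):

```python
import numpy as np

def init_dh(K: int, V: int, dht=None):
    """Seed the dh accumulator explicitly: from dht when provided,
    otherwise zeros, never from whatever the scratch buffer held."""
    if dht is not None:
        return np.asarray(dht, dtype=np.float32)
    return np.zeros((K, V), dtype=np.float32)
```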

@coderabbitai bot left a comment

🧹 Nitpick comments (2)
tops/ops/common/fused_chunk.py (1)

212-216: Redundant num_steps auto-adjustment logic.

The while num_steps > 1 and NT % num_steps != 0 loop appears both here (lines 212-214) and in the public fused_chunk_fwd (lines 371-373).

This duplication is intentional since fused_chunk_fwd_kernel can be called directly (it's JIT'd separately), so both entry points need to handle invalid num_steps. However, when called via fused_chunk_fwd, the adjustment runs twice with identical results.

Consider extracting to a helper function for maintainability, or documenting that fused_chunk_fwd_kernel expects pre-validated num_steps and removing the check there.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tops/ops/common/fused_chunk.py` around lines 212 - 216, The num_steps
auto-adjustment loop is duplicated in fused_chunk_fwd_kernel and fused_chunk_fwd
causing redundant execution when fused_chunk_fwd calls the kernel; extract that
loop into a single helper (e.g., normalize_num_steps or adjust_num_steps) and
replace the duplicated blocks in both fused_chunk_fwd_kernel and fused_chunk_fwd
with a call to that helper (or, if you prefer the kernel to assume validated
input, document that fused_chunk_fwd must call the helper and remove the loop
from fused_chunk_fwd_kernel). Ensure the helper returns the adjusted num_steps
(and any dependent values like NT_OUTER/WINDOW) so both functions use the same
logic and avoid double-adjustment.
tops/ops/simple_gla/fused_chunk.py (1)

286-296: Consider matching dh_states dtype to reduce memory/conversion overhead.

dh_states is created with jnp.float32 dtype (line 294), while other backward outputs like dv use v.dtype. If v.dtype is already float32, this is fine. However, if inputs are bfloat16, the materialized dh_states tensor (B, H, NT, K, V) will be 2x larger than necessary in memory.

This is likely intentional for numerical stability in the backward pass, but worth confirming this tradeoff is acceptable given the memory savings goal of the two-pass split.
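A back-of-envelope for the trade-off (the shape values below are illustrative, not taken from the PR):

```python
def dh_states_bytes(B: int, H: int, NT: int, K: int, V: int, itemsize: int) -> int:
    """Bytes needed to materialize dh_states of shape (B, H, NT, K, V)."""
    return B * H * NT * K * V * itemsize

# float32 (itemsize 4) vs bfloat16 (itemsize 2): exactly 2x the memory
fp32_bytes = dh_states_bytes(1, 8, 16, 64, 64, 4)
bf16_bytes = dh_states_bytes(1, 8, 16, 64, 64, 2)
```

Whether the 2x cost is acceptable depends on how large NT gets relative to the register-pressure savings the split was introduced for.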

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tops/ops/simple_gla/fused_chunk.py` around lines 286 - 296, The dh_states
output is hardcoded to jnp.float32 in pass1_out_shapes which can bloat memory
when inputs (v) are bfloat16; change the dh_states dtype to match v.dtype by
using v.dtype (i.e., jax.ShapeDtypeStruct((B, H, NT, K, V), v.dtype)) in
pass1_out_shapes (and similarly ensure any related uses in pass1_out_specs or
downstream code that assume float32 are updated or explicitly guarded), or if
you must keep float32 for numerical-stability reasons, add a clear code comment
explaining this trade-off so the choice is explicit.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 9e2c99cf-9985-49f8-87d2-36bba1d8c01f

📥 Commits

Reviewing files that changed from the base of the PR and between d76b28c and 8b5fc07.

📒 Files selected for processing (2)
  • tops/ops/common/fused_chunk.py
  • tops/ops/simple_gla/fused_chunk.py

