Feat/varlen#153

Open
pathfinder-pf wants to merge 12 commits into main from feat/varlen

Conversation

@pathfinder-pf
Collaborator

@pathfinder-pf pathfinder-pf commented Apr 2, 2026

Benchmark: varlen vs. no varlen


Summary by CodeRabbit

  • New Features

    • Added device-packed variable-length sequence support for chunked GLA forward with correct per-sequence state handling and optional per-sequence gating; new varlen forward entrypoint and dispatch to handle packed inputs.
  • Bug Fixes

    • Relaxed device-side chunk-alignment requirement so packed sequences no longer require strict per-chunk divisibility.
  • Tests

    • Added comprehensive TPU-oriented tests covering varlen forward correctness, sequence isolation, and CPU parity (including single-sequence cases).
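The CPU-parity checks listed above typically boil down to slicing the packed stream per sequence and running a dense reference on each slice. The sketch below is illustrative only: the helper name and the trivial stand-in forward are hypothetical, not the project's actual test code.

```python
import numpy as np

def varlen_reference(fwd, q, k, v, cu_seqlens):
    """Run a dense single-sequence forward `fwd` on each packed sequence.

    q/k/v are packed as [1, T_sum, H, D]; cu_seqlens is [N+1].
    Each slice starts from a fresh state, which is exactly the
    sequence-isolation behavior the varlen kernel must reproduce at BOS.
    """
    outs = [
        fwd(q[:, bos:eos], k[:, bos:eos], v[:, bos:eos])
        for bos, eos in zip(cu_seqlens[:-1], cu_seqlens[1:])
    ]
    return np.concatenate(outs, axis=1)

# Toy check with a stand-in "forward" (elementwise sum):
q = np.arange(8, dtype=np.float64).reshape(1, 8, 1, 1)
out = varlen_reference(lambda a, b, c: a + b + c, q, q, q, np.array([0, 4, 8]))
assert np.allclose(out, 3 * q)
```

A parity test then compares the kernel's packed output against this per-sequence concatenation.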

@coderabbitai

coderabbitai bot commented Apr 2, 2026

No actionable comments were generated in the recent review. 🎉

ℹ️ Recent review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 6c0dd21f-f481-431a-bf74-a65b97633abd

📥 Commits

Reviewing files that changed from the base of the PR and between 59c1b3d and bfd155a.

📒 Files selected for processing (1)
  • tops/ops/common/chunk_o.py
💤 Files with no reviewable changes (1)
  • tops/ops/common/chunk_o.py

📝 Walkthrough

Walkthrough

Adds device-packed variable-length (cu_seqlens) support for chunked Simple GLA: new varlen hidden-state kernel and JIT entrypoint, a varlen chunked Simple GLA forward function, dispatch updates to route varlen calls, a benchmark provider entry, and TPU tests validating varlen behavior.
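The chunk-to-sequence mapping mentioned in the walkthrough can be illustrated with a small NumPy sketch. The helper name is hypothetical; the kernel derives this mapping on-device from cu_seqlens, and the sketch assumes each sequence length is a multiple of the chunk size so no chunk straddles two sequences.

```python
import numpy as np

def chunk_to_seq(cu_seqlens: np.ndarray, chunk_size: int) -> np.ndarray:
    """Map each chunk index to the sequence that owns it."""
    total = int(cu_seqlens[-1])
    starts = np.arange(0, total, chunk_size)  # first token of each chunk
    # searchsorted(side="right") - 1 yields the sequence whose [bos, eos)
    # interval contains each chunk start.
    return np.searchsorted(cu_seqlens, starts, side="right") - 1

# Two packed sequences of lengths 128 and 64 with chunk_size=64:
# chunks 0 and 1 belong to sequence 0, chunk 2 to sequence 1.
print(chunk_to_seq(np.array([0, 128, 192]), 64))  # → [0 0 1]
```

With the mapping in hand, the kernel can reset its scratch state whenever a chunk's start equals the owning sequence's BOS, and write the final state when it reaches the EOS.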

Changes

Cohort / File(s): Summary

  • Varlen hidden-state kernel (tops/ops/common/chunk_h.py): Renamed the per-head decay param to g_gamma_ref; added _chunk_fwd_h_kernel_varlen and chunk_fwd_h_kernel_varlen to compute chunk→sequence mappings using cu_seqlens, reset scratch at BOS, store chunk-boundary intermediates, and optionally return the final state. Added a check_chunk_fwd helper.
  • Varlen Simple GLA forward & dispatch (tops/ops/simple_gla/chunk.py, tops/ops/simple_gla/__init__.py): Imported the new varlen hidden-state kernel and added chunk_simple_gla_fwd_varlen(...) (asserts cu_seqlens_dev is present, cu_seqlens_cpu is None, and B == 1), which routes inter-chunk state through the varlen kernel. The dispatcher now selects the varlen kernel when cu_seqlens_dev is provided; CHUNK/FUSED_CHUNK dispatch was tightened to require no cu_seqlens_dev.
  • Output kernel alignment change (tops/ops/common/chunk_o.py): Removed the device-side assertion requiring cu_seqlens_dev entries to be divisible by chunk_size, allowing varlen cu_seqlens not aligned to chunk_size on device.
  • Tests, TPU (tests/ops/simple_gla/test_chunk_simple_gla_varlen_tpu.py): New TPU test module with parametrized tests comparing the varlen Pallas kernel against a CPU reference, single-sequence equivalence, and sequence-isolation checks; covers optional h0 and g_gamma.
  • Benchmarks (benchmarks/ops/benchmark_gla.py): Added a "simple_gla_chunk_varlen" provider that flattens (B, T, H, D) into (1, B*T, H, D), enforces T % chunk_size == 0 and D % 128 == 0, and builds cu_seqlens via jnp.arange(0, B*T + 1, T).
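A NumPy sketch of that packing scheme (shapes are illustrative; the actual provider uses jnp):

```python
import numpy as np

# Hypothetical benchmark sizes satisfying the provider's constraints.
B, T, H, D = 4, 256, 8, 128
chunk_size = 64
assert T % chunk_size == 0 and D % 128 == 0

q = np.zeros((B, T, H, D), dtype=np.float32)
q_packed = q.reshape(1, B * T, H, D)  # flatten the batch into one packed stream

# Equal-length sequences give evenly spaced cumulative boundaries:
# [0, 256, 512, 768, 1024] for B=4, T=256.
cu_seqlens = np.arange(0, B * T + 1, T)
```

Because every boundary is a multiple of T (and T is a multiple of chunk_size), this provider stays within the kernel's alignment assumptions while still exercising the varlen path.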

Sequence Diagram(s)

sequenceDiagram
    participant User as User / Benchmark
    participant Dispatch as simple_gla_fwd (dispatcher)
    participant Varlen as chunk_simple_gla_fwd_varlen
    participant ChunkH as chunk_fwd_h_kernel_varlen
    participant ChunkO as chunk_fwd_o
    participant Output as Output (o, ht)

    User->>Dispatch: q, k, v, cu_seqlens_dev
    Dispatch->>Varlen: route to varlen forward
    Varlen->>ChunkH: compute per-sequence h using cu_seqlens_dev
    activate ChunkH
    ChunkH->>ChunkH: reset at BOS, store chunk boundaries, write ht at EOS
    ChunkH-->>Varlen: return h (and ht)
    deactivate ChunkH
    Varlen->>ChunkO: compute o from h, q, k, v
    ChunkO-->>Output: return o
    Varlen-->>Output: attach ht if requested
    Output-->>User: (o, ht)

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes


Suggested reviewers

  • 0xaskr

Poem

🐰 I hop through chunks and pack the queues,

cu_seqlens guiding all my tiny fuse.
At BOS I reset, at EOS I cheer,
Chunked and varlen — each sequence clear.
A kernel hop, a tester's grin, varlen runs begin!

🚥 Pre-merge checks | ✅ 1 | ❌ 2

❌ Failed checks (1 warning, 1 inconclusive)

  • Docstring Coverage — ⚠️ Warning: docstring coverage is 27.59%, below the required 80.00% threshold. Resolution: write docstrings for the functions missing them.
  • Title check — ❓ Inconclusive: the title 'Feat/varlen' is vague and does not communicate what the feature is or what problem it solves. Resolution: use a more descriptive title such as 'Add variable-length sequence support to GLA chunk kernels' or 'Implement varlen forward kernel for packed sequences'.
✅ Passed checks (1 passed)
  • Description Check — ✅ Passed: check skipped because CodeRabbit's high-level summary is enabled.



Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request adds variable-length sequence support for the simple_gla forward pass on TPU, including a new Pallas kernel, benchmarks, and tests. Feedback identifies a debug print, redundant initializations, swapped index map names, and an unrealistic VMEM limit that should be corrected.

h0 = jax.random.normal(keys[4], (N, H, K, V), dtype=dtype)

cu_seqlens_dev = jnp.array(cu_seqlens, dtype=jnp.int32)
print('fix', cu_seqlens_dev)
Contributor


medium

This print statement appears to be debug code and should be removed.

Comment on lines +906 to +927
    b_h_start = jnp.zeros((BK, BV), dtype=jnp.float32)
    seq_idx = jnp.array(0, dtype=jnp.int32)

    i_h, i_k, i_v, i_t = pl.program_id(0), pl.program_id(1), pl.program_id(2), pl.program_id(3)

    if g_gamma_ref is not None:
        b_g = g_gamma_ref[i_h].astype(jnp.float32) * (jnp.arange(0, BT) + 1)
    t0 = i_t * BT

    seq_idx = chunk_to_seq[i_t]

    bos = cu_seqlens_ref[seq_idx]
    eos = cu_seqlens_ref[seq_idx + 1]
    @pl.when(bos != eos)
    def _():
        # reset h state
        @pl.when(t0 == bos)
        def reset_state():
            if h0_ref is not None:
                scratch_ref[...] = h0_ref[seq_idx, 0].astype(scratch_ref.dtype)
            else:
                scratch_ref[...] = b_h_start
Contributor


medium

The initializations of b_h_start and seq_idx at the start of the kernel are redundant. seq_idx is overwritten immediately on line 915, and b_h_start can be replaced with a literal 0.0 in the reset_state function. Removing these avoids unnecessary array creation inside the kernel and improves clarity.

    i_h, i_k, i_v, i_t = pl.program_id(0), pl.program_id(1), pl.program_id(2), pl.program_id(3)

    if g_gamma_ref is not None:
        b_g = g_gamma_ref[i_h].astype(jnp.float32) * (jnp.arange(0, BT) + 1)
    t0 = i_t * BT

    seq_idx = chunk_to_seq[i_t]

    bos = cu_seqlens_ref[seq_idx]
    eos = cu_seqlens_ref[seq_idx + 1]
    @pl.when(bos != eos)
    def _():
        # reset h state
        @pl.when(t0 == bos)
        def reset_state():
            if h0_ref is not None:
                scratch_ref[...] =  h0_ref[seq_idx, 0].astype(scratch_ref.dtype)
            else:
                scratch_ref[...] =  0.0

Comment on lines +1044 to +1048
def h_index_map(head_index, k_index, v_index, t_index):
    return 0, head_index, k_index, v_index

def ht_index_map(head_index, k_index, v_index, t_index):
    return 0, head_index, k_index, v_index
Contributor


medium

The names h_index_map and ht_index_map are swapped relative to their usage in out_specs (lines 1055 and 1058). h_index_map is used for the final state (ht) and ht_index_map is used for the intermediate states (h). This is confusing and should be corrected to match the output names.

        "parallel",
        "arbitrary",
    ),
    vmem_limit_bytes=128 * 1024 * 1024,
Contributor


medium

The vmem_limit_bytes is set to 128MB, which is significantly higher than the physical VMEM capacity of TPU hardware (typically 16MB or 32MB). This may cause compilation issues or inefficient memory usage. It should be reduced to a more realistic value, such as 32MB, to match the other kernels in this file.

Suggested change
vmem_limit_bytes=128 * 1024 * 1024,
vmem_limit_bytes=32 * 1024 * 1024,
References
  1. Commented-out code for frequently tuned parameters, such as vmem_limit_bytes in CompilerParams, can be retained for discoverability and ease of modification.


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 3

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (2)
tops/ops/common/chunk_h.py (1)

95-96: ⚠️ Potential issue | 🟡 Minor

Duplicate function definition: check_chunk_fwd is defined twice.

There are two identical definitions of check_chunk_fwd at lines 95-96 and 966-967. The second definition at line 966 shadows the first one. This appears to be a copy-paste error.

🐛 Proposed fix: Remove the duplicate definition

Remove lines 966-967 since the function already exists at lines 95-96:

-def check_chunk_fwd(x):
-    assert x is None, "x should be None."
-
-
 # note: The precision difference between this kernel on the TPU and FLA on the GPU is 5e-2.
 @functools.partial(

Also applies to: 966-967

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `tops/ops/common/chunk_h.py` around lines 95-96, there are two identical
definitions of the function check_chunk_fwd (one at the top and one near the
end) causing the latter to shadow the former; remove the duplicate definition
(the second occurrence of check_chunk_fwd) so the file only contains a single
check_chunk_fwd function and ensure no other references rely on the duplicated
block.
tops/ops/simple_gla/__init__.py (1)

182-190: ⚠️ Potential issue | 🔴 Critical

Logic error: elif cu_seqlens_dev is not None branch is unreachable.

The condition elif cu_seqlens_dev is not None at line 187 will never execute because:

  1. If mode == SimpleGLAKernelMode.CHUNK, line 184 executes
  2. If mode == SimpleGLAKernelMode.FUSED_CHUNK, line 186 executes
  3. For any other mode value, the else at line 189 raises an exception

The cu_seqlens_dev check should be nested within the FUSED_CHUNK branch (or CHUNK branch) to work correctly.

🐛 Proposed fix
     fn = None
     if mode == SimpleGLAKernelMode.CHUNK:
-        fn = chunk_simple_gla_fwd
+        if cu_seqlens_dev is not None:
+            fn = chunk_simple_gla_fwd_varlen
+        else:
+            fn = chunk_simple_gla_fwd
     elif mode == SimpleGLAKernelMode.FUSED_CHUNK:
         fn = fused_chunk_simple_gla_fwd
-    elif cu_seqlens_dev is not None:
-        fn = chunk_simple_gla_fwd_varlen
     else:
         raise Exception(f"mode {mode} not support")

Or if varlen should apply to FUSED_CHUNK mode:

     fn = None
     if mode == SimpleGLAKernelMode.CHUNK:
         fn = chunk_simple_gla_fwd
     elif mode == SimpleGLAKernelMode.FUSED_CHUNK:
-        fn = fused_chunk_simple_gla_fwd
-    elif cu_seqlens_dev is not None:
-        fn = chunk_simple_gla_fwd_varlen
+        if cu_seqlens_dev is not None:
+            fn = chunk_simple_gla_fwd_varlen
+        else:
+            fn = fused_chunk_simple_gla_fwd
     else:
         raise Exception(f"mode {mode} not support")
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `tops/ops/simple_gla/__init__.py` around lines 182-190, the branch checking
cu_seqlens_dev is unreachable because the mode cases return earlier; move the
cu_seqlens_dev check inside the appropriate mode branch (e.g., within the
SimpleGLAKernelMode.FUSED_CHUNK branch or CHUNK branch as intended) and select
chunk_simple_gla_fwd_varlen when cu_seqlens_dev is not None; specifically,
inside the block for SimpleGLAKernelMode.FUSED_CHUNK (or CHUNK if varlen applies
there) choose fused_chunk_simple_gla_fwd vs fused_chunk_simple_gla_fwd_varlen
(or chunk_simple_gla_fwd vs chunk_simple_gla_fwd_varlen) based on cu_seqlens_dev
to ensure the varlen implementation is reachable.
🧹 Nitpick comments (3)
tops/ops/common/chunk_h.py (2)

1055-1058: Potential confusion: h and ht BlockSpec use swapped index_map names.

The out_specs list uses ht_index_map for h (intermediate states) and h_index_map for ht (final states). While functionally equivalent since both maps are identical, the naming is confusing.

♻️ Suggested fix for clarity
-    out_specs = [pl.BlockSpec((NS, 1, BK, BV), ht_index_map)]
+    out_specs = [pl.BlockSpec((NS, 1, BK, BV), h_index_map)]
     if output_final_state:
         out_shape.append(jax.ShapeDtypeStruct(shape=(N, H, K, V), dtype=jnp.float32))
-        out_specs.append(pl.BlockSpec((N, 1, BK, BV), h_index_map))
+        out_specs.append(pl.BlockSpec((N, 1, BK, BV), ht_index_map))
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `tops/ops/common/chunk_h.py` around lines 1055-1058, the out_specs
construction swaps the index_map names causing confusion: currently pl.BlockSpec
for the intermediate `h` state uses ht_index_map and the final `ht` state uses
h_index_map; update the calls so the intermediate state (the first pl.BlockSpec
added to out_specs) uses h_index_map and the conditional append (when
output_final_state is true) uses ht_index_map, keeping the symbols out_specs,
pl.BlockSpec, ht_index_map, h_index_map, output_final_state, and out_shape to
locate and adjust the two BlockSpec constructions accordingly.

887-901: Missing docstring for the varlen kernel function.

Per coding guidelines, all public functions must have a clear docstring explaining business semantics, tensor shapes, and dimension meanings. The _chunk_fwd_h_kernel_varlen function lacks documentation.

📝 Suggested docstring
 def _chunk_fwd_h_kernel_varlen(
     k_ref,  # [1, BT, BK]
     v_ref,  # [1, BT, BV]
     h0_ref,  # [N, 1, BK, BV]
     gk_ref,  # [1, BT, BK]
     g_gamma_ref, # [H,]
     cu_seqlens_ref,  # [num_seq+1]
     chunk_to_seq,  # [T_sum/BT]
     h_ref,  # [NS, 1, BK, BV]
     ht_ref,  # [N, 1, BK , BV]
     scratch_ref, # [BK, BV]
     *,
     BT,
     BS,
 ):
+    """Variable-length forward kernel for inter-chunk hidden state propagation.
+
+    Processes chunks from multiple packed sequences, resetting state at
+    sequence boundaries using cu_seqlens_ref to determine sequence ownership.
+
+    Grid: (H, ceil(K/BK), ceil(V/BV), T_sum/BT)
+    """
     BT, BK = k_ref.shape[1], k_ref.shape[2]
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `tops/ops/common/chunk_h.py` around lines 887-901, add a clear docstring to
the function _chunk_fwd_h_kernel_varlen describing its business semantics and
documenting all arguments and tensor shapes/dimensions (e.g., k_ref: [1, BT,
BK], v_ref: [1, BT, BV], h0_ref: [N, 1, BK, BV], gk_ref: [1, BT, BK],
g_gamma_ref: [H,], cu_seqlens_ref: [num_seq+1], chunk_to_seq: [T_sum/BT], h_ref:
[NS, 1, BK, BV], ht_ref: [N, 1, BK, BV], scratch_ref: [BK, BV]) and the meanings
of BT and BS; include a brief note on return/side-effects and any important
invariants (e.g., layout, contiguous requirements, dtype expectations) so
callers know how to use _chunk_fwd_h_kernel_varlen.
tops/ops/simple_gla/chunk.py (1)

522-577: Missing docstring for public chunk_simple_gla_fwd_varlen function.

Per coding guidelines, all public functions must have a clear docstring explaining business semantics, tensor shapes, and dimension meanings for inputs/outputs.

📝 Suggested docstring
 def chunk_simple_gla_fwd_varlen(
     q: jax.Array,
     k: jax.Array,
     v: jax.Array,
     *,
     g: jax.Array | None = None,
     g_gamma: jax.Array | None = None,
     scale: float | None = None,
     h0: jax.Array | None = None,
     use_ht: bool = False,
     cu_seqlens_cpu: jax.Array | None = None,
     cu_seqlens_dev: jax.Array | None = None,
     chunk_size: int = 64,
     interpret: bool | None = None,
 ) -> tuple[jax.Array, jax.Array | None]:
+    """Variable-length Simple GLA forward pass.
+
+    Computes Simple GLA for packed variable-length sequences using the
+    varlen chunk kernel. Each sequence is processed independently with
+    its own initial state, and outputs are concatenated.
+
+    Args:
+        q: Queries, shape [1, T_sum, H, K] (B must be 1).
+        k: Keys, shape [1, T_sum, H, K].
+        v: Values, shape [1, T_sum, H, V].
+        g: Per-token gate (must be None for chunk path).
+        g_gamma: Per-head log-decay, shape [H].
+        scale: Query scaling factor (default: K**-0.5).
+        h0: Initial state per sequence, shape [N, H, K, V].
+        use_ht: Whether to return final state per sequence.
+        cu_seqlens_cpu: Must be None (not supported).
+        cu_seqlens_dev: Cumulative sequence lengths, shape [N+1]. Required.
+        chunk_size: Block size (must divide all sequence lengths).
+        interpret: Use interpret mode for debugging.
+
+    Returns:
+        o: Output tensor, shape [1, T_sum, H, V].
+        ht: Final state per sequence [N, H, K, V] if use_ht else None.
+    """
     B, T, H, K, V = *q.shape, v.shape[-1]
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `tops/ops/simple_gla/chunk.py` around lines 522-577, add a clear docstring
to the public function chunk_simple_gla_fwd_varlen describing its purpose (what
it computes), the meaning and expected shapes/dtypes of each parameter (q, k, v,
optional g, g_gamma, scale, h0, use_ht, cu_seqlens_cpu, cu_seqlens_dev,
chunk_size, interpret), any preconditions/assertions (e.g. B==1, K and V
multiples of 128, T % chunk_size == 0), the shapes and semantics of return
values (o and ht), and any side-effects or errors raised; place the docstring
immediately below the def line so callers and autogenerated docs see the
business semantics and tensor dimension conventions.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `tests/ops/simple_gla/test_chunk_simple_gla_varlen_tpu.py`:
- Line 131: Remove the debug print statement print('fix', cu_seqlens_dev) from
the test; locate the occurrence of the literal print call referencing
cu_seqlens_dev in the test_chunk_simple_gla_varlen_tpu test and delete it so
tests don’t emit debug output.

In `tops/ops/common/chunk_h.py`:
- Around line 981-1004: The function chunk_fwd_h_kernel_varlen lacks a docstring
describing tensor shapes/semantics and risks dereferencing cu_seqlens_dev when
it is None; add a clear docstring at the top of chunk_fwd_h_kernel_varlen that
documents the expected shapes and meanings of k, v, g, g_gamma, gk, gv, h0,
cu_seqlens_dev, and return values, and add an explicit guard near the start
(after check_chunk_fwd calls) that raises a descriptive TypeError or ValueError
if cu_seqlens_dev is None (or make it required by removing | None from its
annotation) so subsequent uses like len(cu_seqlens_dev) are safe.

In `tops/ops/simple_gla/chunk.py`:
- Around line 552-576: chunk_fwd_h_varlen returns h shaped (NS, H, K, V) but
chunk_fwd_o expects h shaped (B, NT, H, K, V); in chunk_simple_gla_fwd_varlen
(where assert B == 1) reshape h to add the singleton batch dimension before
calling chunk_fwd_o (e.g., unsqueeze or view to (1, NS, H, K, V)), and pass that
reshaped tensor into chunk_fwd_o so shapes align with its reference
implementation.

---

Outside diff comments:
In `tops/ops/common/chunk_h.py`:
- Around line 95-96: There are two identical definitions of the function
check_chunk_fwd (one at the top and one near the end) causing the latter to
shadow the former; remove the duplicate definition (the second occurrence of
check_chunk_fwd) so the file only contains a single check_chunk_fwd function and
ensure no other references rely on the duplicated block.

In `tops/ops/simple_gla/__init__.py`:
- Around line 182-190: The branch checking cu_seqlens_dev is unreachable because
the mode cases return earlier; move the cu_seqlens_dev check inside the
appropriate mode branch (e.g., within the SimpleGLAKernelMode.FUSED_CHUNK branch
or CHUNK branch as intended) and select chunk_simple_gla_fwd_varlen when
cu_seqlens_dev is not None; specifically, inside the block for
SimpleGLAKernelMode.FUSED_CHUNK (or CHUNK if varlen applies there) choose
fused_chunk_simple_gla_fwd vs fused_chunk_simple_gla_fwd_varlen (or
chunk_simple_gla_fwd vs chunk_simple_gla_fwd_varlen) based on cu_seqlens_dev to
ensure the varlen implementation is reachable.

---

Nitpick comments:
In `tops/ops/common/chunk_h.py`:
- Around line 1055-1058: The out_specs construction swaps the index_map names
causing confusion: currently pl.BlockSpec for the intermediate `h` state uses
ht_index_map and the final `ht` state uses h_index_map; update the calls so the
intermediate state (the first pl.BlockSpec added to out_specs) uses h_index_map
and the conditional append (when output_final_state is true) uses ht_index_map,
keeping the symbols out_specs, pl.BlockSpec, ht_index_map, h_index_map,
output_final_state, and out_shape to locate and adjust the two BlockSpec
constructions accordingly.
- Around line 887-901: Add a clear docstring to the function
_chunk_fwd_h_kernel_varlen describing its business semantics and documenting all
arguments and tensor shapes/dimensions (e.g., k_ref: [1, BT, BK], v_ref: [1, BT,
BV], h0_ref: [N, 1, BK, BV], gk_ref: [1, BT, BK], g_gamma_ref: [H,],
cu_seqlens_ref: [num_seq+1], chunk_to_seq: [T_sum/BT], h_ref: [NS, 1, BK, BV],
ht_ref: [N, 1, BK, BV], scratch_ref: [BK, BV]) and the meanings of BT and BS;
include a brief note on return/side-effects and any important invariants (e.g.,
layout, contiguous requirements, dtype expectations) so callers know how to use
_chunk_fwd_h_kernel_varlen.

In `tops/ops/simple_gla/chunk.py`:
- Around line 522-577: Add a clear docstring to the public function
chunk_simple_gla_fwd_varlen describing its purpose (what it computes), the
meaning and expected shapes/dtypes of each parameter (q, k, v, optional g,
g_gamma, scale, h0, use_ht, cu_seqlens_cpu, cu_seqlens_dev, chunk_size,
interpret), any preconditions/assertions (e.g. B==1, K and V multiples of 128, T
% chunk_size == 0), the shapes and semantics of return values (o and ht), and
any side-effects or errors raised; place the docstring immediately below the def
line so callers and autogenerated docs see the business semantics and tensor
dimension conventions.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: ba559f3d-2220-4fec-a25e-593483922a46

📥 Commits

Reviewing files that changed from the base of the PR and between 4d28bec and 59e0dfb.

📒 Files selected for processing (5)
  • benchmarks/ops/benchmark_gla.py
  • tests/ops/simple_gla/test_chunk_simple_gla_varlen_tpu.py
  • tops/ops/common/chunk_h.py
  • tops/ops/simple_gla/__init__.py
  • tops/ops/simple_gla/chunk.py

Comment on lines +981 to +1004
def chunk_fwd_h_kernel_varlen(
    k: jax.Array,  # [B,T,H,K]
    v: jax.Array,  # [B,T,H,V]
    g: jax.Array | None = None,  # [B,T,H]
    g_gamma: jax.Array | None = None,  # (H,)
    gk: jax.Array | None = None,  # [B,T,H,K]
    gv: jax.Array | None = None,  # [B,T,H,V]
    h0: jax.Array | None = None,  # [N,H,K,V]
    output_final_state: bool = False,
    cu_seqlens_dev: jax.Array | None = None,
    chunk_size: int = 128,
    split_size: int | None = None,
    states_in_fp32: bool = False,
    interpret: bool = False,
):
    check_chunk_fwd(g)
    check_chunk_fwd(gv)
    # todo: tune BK and BV for best performance
    BK = 128
    BV = 128
    B, T, H, K, V = *k.shape, v.shape[-1]
    assert K % 128 == 0, "K % 128 must equal to 0."
    assert V % 128 == 0, "V % 128 must equal to 0."
    assert T % chunk_size == 0, "T mod chunk_size must equal to 0."


⚠️ Potential issue | 🟡 Minor

Missing docstring and potential None dereference for cu_seqlens_dev.

  1. Per coding guidelines, public functions must have docstrings explaining tensor shapes and semantics.
  2. If cu_seqlens_dev is None, line 1017 (len(cu_seqlens_dev)) will raise TypeError. The function signature shows it as optional (| None), but it's required for this varlen kernel.
🐛 Proposed fix
+def chunk_fwd_h_kernel_varlen(
+    k: jax.Array,  # [B,T,H,K]
+    v: jax.Array,  # [B,T,H,V]
+    g: jax.Array | None = None,  # [B,T,H]
+    g_gamma: jax.Array | None = None,  # (H,)
+    gk: jax.Array | None = None,  # [B,T,H,K]
+    gv: jax.Array | None = None,  # [B,T,H,V]
+    h0: jax.Array | None = None,  # [N,H,K,V]
+    output_final_state: bool = False,
+    cu_seqlens_dev: jax.Array | None = None,
+    chunk_size: int = 128,
+    split_size: int | None = None,
+    states_in_fp32: bool = False,
+    interpret: bool = False,
+) -> tuple[jax.Array, jax.Array | None]:
+    """Variable-length inter-chunk hidden state propagation kernel.
+
+    Computes hidden state at the start of each chunk for packed variable-length
+    sequences. State is reset at sequence boundaries defined by cu_seqlens_dev.
+
+    Args:
+        k: Keys, shape [B, T, H, K] where B must be 1 for varlen.
+        v: Values, shape [B, T, H, V].
+        g: Per-token gate (must be None for this kernel).
+        g_gamma: Per-head fixed decay rate, shape [H].
+        gk: Per-K-dim gate (optional), shape [B, T, H, K].
+        gv: Per-V-dim gate (must be None for this kernel).
+        h0: Initial hidden state per sequence, shape [N, H, K, V].
+        output_final_state: Whether to return final state per sequence.
+        cu_seqlens_dev: Cumulative sequence lengths, shape [N+1]. Required.
+        chunk_size: Block size (must divide all sequence lengths).
+        split_size: Split size for intermediate state storage.
+        states_in_fp32: Store intermediate states in float32.
+        interpret: Use interpret mode for debugging.
+
+    Returns:
+        h: Hidden states at chunk boundaries, shape [NS, H, K, V].
+        ht: Final hidden state per sequence [N, H, K, V] if output_final_state else None.
+    """
     check_chunk_fwd(g)
     check_chunk_fwd(gv)
+    assert cu_seqlens_dev is not None, "cu_seqlens_dev is required for varlen kernel"
+    # todo: tune BK and BV for best performance
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `tops/ops/common/chunk_h.py` around lines 981-1004, the function
chunk_fwd_h_kernel_varlen lacks a docstring describing tensor shapes/semantics
and risks dereferencing cu_seqlens_dev when it is None; add a clear docstring at
the top of chunk_fwd_h_kernel_varlen that documents the expected shapes and
meanings of k, v, g, g_gamma, gk, gv, h0, cu_seqlens_dev, and return values, and
add an explicit guard near the start (after check_chunk_fwd calls) that raises a
descriptive TypeError or ValueError if cu_seqlens_dev is None (or make it
required by removing | None from its annotation) so subsequent uses like
len(cu_seqlens_dev) are safe.

Comment on lines +552 to +576
    h, ht = chunk_fwd_h_varlen(
        k=k,
        v=v,
        g=g,
        g_gamma=g_gamma,
        gk=None,
        gv=None,
        h0=h0,
        output_final_state=use_ht,
        states_in_fp32=False,
        cu_seqlens_dev=cu_seqlens_dev,
        chunk_size=chunk_size,
    )
    o = chunk_fwd_o(
        q=q,
        k=k,
        v=v,
        g=g,
        g_gamma=g_gamma,
        h=h,
        scale=scale,
        cu_seqlens_cpu=cu_seqlens_cpu,
        cu_seqlens_dev=cu_seqlens_dev,
        chunk_size=chunk_size,
    )


⚠️ Potential issue | 🔴 Critical

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
# Check the return shape of chunk_fwd_h_kernel_varlen and expected input shape of chunk_fwd_o

echo "=== chunk_fwd_h_kernel_varlen output shape ==="
rg -n "out_shape = \[" tops/ops/common/chunk_h.py -A 5 | head -20

echo "=== chunk_fwd_o h parameter shape ==="
rg -n "def chunk_fwd_o" tops/ops/common/chunk_o.py -A 30 | head -40

Repository: primatrix/pallas-kernel

Length of output: 2103


🏁 Script executed:

#!/bin/bash
# Check chunk_fwd_o implementation and what shape it expects for h
echo "=== chunk_fwd_o full signature and initial logic ==="
rg -n "def chunk_fwd_o" tops/ops/common/chunk_o.py -A 50 | head -70

echo "=== Check if h.shape is validated or documented in chunk_fwd_o ==="
rg -n "h.shape" tops/ops/common/chunk_o.py -B 2 -A 2

Repository: primatrix/pallas-kernel

Length of output: 2912


🏁 Script executed:

#!/bin/bash
# Check the context around lines 552-576 in chunk.py to see if reshape happens
echo "=== Code context around the h assignment in chunk_simple_gla_fwd_varlen ==="
sed -n '545,580p' tops/ops/simple_gla/chunk.py | cat -n

Repository: primatrix/pallas-kernel

Length of output: 1273


🏁 Script executed:

#!/bin/bash
# Check the function signature of chunk_simple_gla_fwd_varlen to understand B, NS, NT
echo "=== chunk_simple_gla_fwd_varlen signature and early logic ==="
rg -n "def chunk_simple_gla_fwd_varlen" tops/ops/simple_gla/chunk.py -A 20 | head -30

Repository: primatrix/pallas-kernel

Length of output: 877


🏁 Script executed:

#!/bin/bash
# Find the full chunk_fwd_o implementation
echo "=== chunk_fwd_o implementation (after line 440) ==="
sed -n '440,550p' tops/ops/common/chunk_o.py | cat -n

Repository: primatrix/pallas-kernel

Length of output: 4714


🏁 Script executed:

#!/bin/bash
# Check if there's a pl.pallas_call or similar kernel invocation in chunk_fwd_o
echo "=== Search for pallas_call or kernel setup in chunk_o.py after line 440 ==="
rg -n "pl.pallas_call|pallas_call" tops/ops/common/chunk_o.py -B 3 -A 10 | head -60

Repository: primatrix/pallas-kernel

Length of output: 1638


Missing reshape of h before passing to chunk_fwd_o in the varlen path.

chunk_fwd_h_varlen returns h with shape (NS, H, K, V), but chunk_fwd_o expects h with shape (B, NT, H, K, V) per its reference implementation at tops/ops/common/chunk_o.py:388. Since assert B == 1 is enforced in chunk_simple_gla_fwd_varlen, the tensor must be reshaped to add the batch dimension before passing to chunk_fwd_o.

🐛 Proposed fix
     h, ht = chunk_fwd_h_varlen(
         k=k,
         v=v,
         g=g,
         g_gamma=g_gamma,
         gk=None,
         gv=None,
         h0=h0,
         output_final_state=use_ht,
         states_in_fp32=False,
         cu_seqlens_dev=cu_seqlens_dev,
         chunk_size=chunk_size,
     )
+    # Reshape h from (NS, H, K, V) to (B, NS, H, K, V) for chunk_fwd_o
+    h = h.reshape(B, -1, H, K, V)
     o = chunk_fwd_o(
         q=q,
         k=k,
         v=v,
         g=g,
         g_gamma=g_gamma,
         h=h,
         scale=scale,
         cu_seqlens_cpu=cu_seqlens_cpu,
         cu_seqlens_dev=cu_seqlens_dev,
         chunk_size=chunk_size,
     )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tops/ops/simple_gla/chunk.py` around lines 552 - 576, chunk_fwd_h_varlen
returns h shaped (NS, H, K, V) but chunk_fwd_o expects h shaped (B, NT, H, K,
V); in chunk_simple_gla_fwd_varlen (where assert B == 1) reshape h to add the
singleton batch dimension before calling chunk_fwd_o (e.g., unsqueeze or view to
(1, NS, H, K, V)), and pass that reshaped tensor into chunk_fwd_o so shapes
align with its reference implementation.
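The shape mismatch behind this finding can be demonstrated with a minimal NumPy sketch (the dimension sizes are arbitrary placeholders, not values from the kernel):

```python
import numpy as np

# chunk_fwd_h_varlen returns per-split states without a batch axis,
# while chunk_fwd_o expects a leading batch dimension (B == 1 here).
NS, H, K, V, B = 3, 2, 4, 4, 1
h = np.zeros((NS, H, K, V))            # shape produced by the varlen path
h_batched = h.reshape(B, -1, H, K, V)  # shape chunk_fwd_o consumes
assert h_batched.shape == (1, NS, H, K, V)
```

The reshape is free (a view in NumPy/JAX), so inserting it before the `chunk_fwd_o` call has no runtime cost.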

whz added 2 commits April 2, 2026 20:43

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 3

🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@tops/ops/common/chunk_h.py`:
- Around line 929-932: The current indexing logic (NS = T_sum // BS and s_i =
i_t // NTS inside store_fn) counts splits across the concatenated stream which
breaks when split_size != chunk_size: first chunk of a new sequence can miss its
start state across packed boundaries; either enforce and reject split_size !=
chunk_size early (validate split_size == chunk_size and raise/return a clear
error) or implement per-sequence split bookkeeping (track splits per sequence
instead of using global NS/T_sum and compute s_i per-sequence using sequence
offsets and per-seq NTS) so that store_fn (and the same pattern at the other
affected site around lines 1015-1018) indexes h_ref correctly across packed
sequence boundaries.
- Around line 995-996: check_chunk_fwd currently hard-rejects a scalar gating
argument `g` even though the varlen public entrypoint forwards `g` into
chunk_fwd_h_varlen; fix by either (A) plumbing `g` through this kernel the same
way chunk_fwd_h_kernel does: accept a nullable/optional `g` param in
check_chunk_fwd and pass it into chunk_fwd_h_varlen, preserving calling
convention and tests, or (B) move the rejection earlier in the varlen public
entrypoint (the code in tops/ops/simple_gla/chunk.py that calls
chunk_fwd_h_varlen) so that the public API refuses scalar `g` before reaching
check_chunk_fwd; pick one consistent approach and update the related callers and
docstring/comments for chunk_fwd_h_varlen/check_chunk_fwd accordingly.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 9e7379ec-1fe8-40b3-a55a-a8482260d395

📥 Commits

Reviewing files that changed from the base of the PR and between 4dd2878 and 59c1b3d.

📒 Files selected for processing (2)
  • tests/ops/simple_gla/test_chunk_simple_gla_varlen_tpu.py
  • tops/ops/common/chunk_h.py
🚧 Files skipped from review as they are similar to previous changes (1)
  • tests/ops/simple_gla/test_chunk_simple_gla_varlen_tpu.py

Comment on lines +929 to +932
@pl.when(i_t % NTS == 0)
def store_fn():
    s_i = i_t // NTS
    h_ref[s_i, 0] = scratch_ref[...].astype(h_ref.dtype)

⚠️ Potential issue | 🟠 Major

split_size > chunk_size breaks state indexing across packed sequence boundaries.

NS = T_sum // BS and s_i = i_t // NTS count splits across the concatenated stream, not per sequence. If a new sequence starts in the middle of a split window, its first chunk never gets a stored start state. Unless packed varlen grows per-sequence split bookkeeping, this path should reject split_size != chunk_size.

🐛 Minimal guard
     BT = chunk_size
     BS = BT if split_size is None else split_size
     assert BS % BT == 0, (
         f"The `split_size` (got {BS}) must be a multiple of `chunk_size` {BT}"
     )
+    assert BS == BT, "packed varlen currently requires split_size == chunk_size"

Also applies to: 1015-1018

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tops/ops/common/chunk_h.py` around lines 929 - 932, The current indexing
logic (NS = T_sum // BS and s_i = i_t // NTS inside store_fn) counts splits
across the concatenated stream which breaks when split_size != chunk_size: first
chunk of a new sequence can miss its start state across packed boundaries;
either enforce and reject split_size != chunk_size early (validate split_size ==
chunk_size and raise/return a clear error) or implement per-sequence split
bookkeeping (track splits per sequence instead of using global NS/T_sum and
compute s_i per-sequence using sequence offsets and per-seq NTS) so that
store_fn (and the same pattern at the other affected site around lines
1015-1018) indexes h_ref correctly across packed sequence boundaries.

Comment on lines +995 to +996
check_chunk_fwd(g)
check_chunk_fwd(gv)

⚠️ Potential issue | 🔴 Critical

g is still accepted by the varlen API, but this kernel hard-rejects it.

tops/ops/simple_gla/chunk.py:552-564 forwards g=g into chunk_fwd_h_varlen, so any packed varlen call with scalar gating now fails on Line 995. Either plumb g through this kernel like chunk_fwd_h_kernel, or reject g earlier in the public varlen entrypoint instead of accepting it here.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tops/ops/common/chunk_h.py` around lines 995 - 996, check_chunk_fwd currently
hard-rejects a scalar gating argument `g` even though the varlen public
entrypoint forwards `g` into chunk_fwd_h_varlen; fix by either (A) plumbing `g`
through this kernel the same way chunk_fwd_h_kernel does: accept a
nullable/optional `g` param in check_chunk_fwd and pass it into
chunk_fwd_h_varlen, preserving calling convention and tests, or (B) move the
rejection earlier in the varlen public entrypoint (the code in
tops/ops/simple_gla/chunk.py that calls chunk_fwd_h_varlen) so that the public
API refuses scalar `g` before reaching check_chunk_fwd; pick one consistent
approach and update the related callers and docstring/comments for
chunk_fwd_h_varlen/check_chunk_fwd accordingly.

Comment on lines +1000 to +1014
B, T, H, K, V = *k.shape, v.shape[-1]
assert K % 128 == 0, "K % 128 must equal to 0."
assert V % 128 == 0, "V % 128 must equal to 0."
assert T % chunk_size == 0, "T mod chunk_size must equal to 0."

BT = chunk_size
BS = BT if split_size is None else split_size
assert BS % BT == 0, (
    f"The `split_size` (got {BS}) must be a multiple of `chunk_size` {BT}"
)
# N: the actual number of sequences in the batch with either equal or variable lengths

T_sum = B * T
chunk_to_seq = _build_chunk_map(cu_seqlens=cu_seqlens_dev, T_sum=T_sum, BT=BT)


⚠️ Potential issue | 🟠 Major

Mirror the fixed-length entrypoint’s input contract here.

This exported varlen entrypoint only checks K/V/T divisibility before calling _build_chunk_map. It still skips the upfront assert_shape / assert_shape_or_none guardrails from chunk_fwd_h_kernel, and it does not enforce the packed-varlen contract this kernel relies on (B == 1 and chunk-aligned packed spans). Without those checks, malformed inputs either fail deep inside the Pallas path or let a BT tile straddle two sequences and mix hidden state across boundaries. As per coding guidelines, public functions must enforce strict constraints and validation on the shape and types of input variables using assert instructions (or utilities like assert_shape_or_none from tops.utils) before executing the main logic.
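A hedged sketch of what such upfront guardrails could look like (the function and argument names are hypothetical; the real entrypoint would use `assert_shape_or_none` from `tops.utils` for the tensor-rank checks):

```python
import numpy as np

def check_varlen_inputs(k_shape, cu_seqlens, chunk_size):
    # Hypothetical guardrails mirroring the fixed-length entrypoint's contract.
    B, T, H, K = k_shape
    assert B == 1, "packed varlen expects a single batch row"
    assert K % 128 == 0, "K % 128 must equal to 0."
    assert T % chunk_size == 0, "T mod chunk_size must equal to 0."
    assert cu_seqlens[0] == 0 and cu_seqlens[-1] == B * T
    # Chunk-aligned spans: otherwise a BT tile straddles two sequences
    # and mixes hidden state across the boundary.
    assert (np.diff(cu_seqlens) % chunk_size == 0).all(), (
        "each packed span must be a multiple of chunk_size"
    )

check_varlen_inputs((1, 12, 2, 128), np.array([0, 4, 12]), 4)  # passes
```

Failing fast here turns a silent cross-sequence state leak deep inside the Pallas path into an immediate, readable error at the public API boundary.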
