Skip to content

feat(gmm): JIT-compilable JAX GMM for TPU with forward/backward and tokamax comparison#161

Open
sii-xinglong wants to merge 15 commits intomainfrom
worktree-wise-swinging-kahan
Open

feat(gmm): JIT-compilable JAX GMM for TPU with forward/backward and tokamax comparison#161
sii-xinglong wants to merge 15 commits intomainfrom
worktree-wise-swinging-kahan

Conversation

@sii-xinglong
Copy link
Copy Markdown
Contributor

@sii-xinglong sii-xinglong commented Apr 6, 2026

Summary

  • Add CPU reference implementations (gmm_ref, tgmm_ref) with bf16 mul / f32 accumulation
  • Add JIT-compilable gmm() and tgmm() using lax.scan + dynamic_slice that run on TPU via XLA compilation
  • Wire custom_vjp for differentiable GMM: forward via GMM, backward dlhs via GMM (transposed), backward drhs via TGMM
  • Add tokamax>=0.0.12 as optional TPU dependency for comparison testing
  • Write 16 parametrized comparison tests (12 forward + 4 backward) comparing JAX GMM vs tokamax GMM in bf16

Key Design Decisions

  • Algorithm: lax.scan iterates over groups sequentially; dynamic_slice with padded arrays ([2m, k]) avoids OOB clamping; lax.dot with preferred_element_type=jnp.float32 for bf16 mul / f32 accumulation
  • Backward pass: dlhs = gmm(grad, rhs, group_sizes, not transpose_rhs), drhs = tgmm(lhs, grad, group_sizes) (or swapped args when transpose_rhs=True)
  • Tokamax comparison: Tests use pytest.importorskip + jax.default_backend() != "tpu" skip — gracefully skip on non-TPU environments
  • Compute trade-off: O(G * m * k * n) vs optimal O(m * k * n) — acceptable for correctness reference; tokamax is the high-performance path

Files

  • tops/ops/gmm/gmm.py_gmm_impl, _tgmm_impl, gmm (custom_vjp), tgmm
  • tops/ops/gmm/__init__.py — Public API exports
  • tops/cpu/ops/gmm/naive.py — CPU reference implementations
  • tests/ops/gmm/test_cpu_ref.py — 15 tests (CPU ref + JAX impl + gradients)
  • tests/ops/gmm/test_gmm_vs_tokamax.py — 16 parametrized tokamax comparison tests
  • pyproject.tomltokamax>=0.0.12 in [tpu] extras

Test plan

  • CPU reference tests (4 gmm_ref + 2 tgmm_ref) — verify per-group matmul correctness
  • JAX impl tests (3 gmm + 2 tgmm) — verify JIT-compilable impl matches CPU reference
  • Gradient tests (4 tests) — verify custom_vjp produces correct-shape, non-NaN gradients for lhs/rhs with and without transpose_rhs
  • Tokamax comparison tests (12 forward + 4 backward) — skip without tokamax/TPU, compare bf16 numerical accuracy

🤖 Generated with Claude Code

Summary by CodeRabbit

  • New Features

    • Added grouped matrix multiplication with CPU reference and JIT/TPU-optimized implementations, including transpose mode and gradient support.
  • Tests

    • Added comprehensive CPU tests and TPU-focused tests (compare against optional accelerator kernel) covering single/multi-group cases, transpose behavior, empty-group handling, JIT, and gradient checks.
  • Chores

    • Optional dependency for the TPU kernel added to project metadata.

sii-xinglong and others added 7 commits April 6, 2026 15:36
Phase 1 design for grouped matrix multiplication kernel that replaces
tokamax/qwix backends. BF16 with float32 accumulation, custom_vjp for
full differentiability.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
5-task TDD implementation plan covering CPU reference, group metadata,
GMM/TGMM kernels, custom_vjp, and public API with complete code.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Add two core Pallas kernel functions for grouped matrix multiplication:
- gmm_forward: forward pass computing per-group matmuls with support
  for transposed rhs and data-dependent indexing via scalar prefetch
- tgmm_forward: transposed GMM for weight gradients, accumulating
  lhs^T @ rhs per group with proper boundary masking

Both kernels use PrefetchScalarGridSpec with group metadata for
efficient ragged-batch processing. Tests validate accuracy against
CPU reference implementations across uniform, skewed, single-group,
and with-empty-group distributions using PALLAS_INTERPRET mode.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
- Add @jax.jit with static_argnames to gmm_forward and tgmm_forward
- Remove int() casts on num_active_tiles (traced values under JIT)
- Use pl.num_programs(2) in tgmm kernel instead of captured constant

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Wire up custom_vjp for the gmm function so it supports jax.grad:
- dlhs computed via gmm_forward with transposed rhs
- drhs computed via tgmm_forward
Export gmm, gmm_forward, tgmm_forward from tops.ops.gmm and
add gmm to tops.ops public API. Add 3 gradient test cases
comparing Pallas backward against a differentiable JAX reference.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
@coderabbitai
Copy link
Copy Markdown

coderabbitai bot commented Apr 6, 2026

📝 Walkthrough

Walkthrough

Adds a grouped matrix multiplication feature: CPU reference implementations (gmm_ref, tgmm_ref), a JIT/TPU-oriented implementation with custom VJP (gmm, tgmm), new tests (CPU and TPU/tokamax comparisons), module exports, and a small pyproject optional-dependency update.

Changes

Cohort / File(s) Summary
CPU reference implementations
tops/cpu/ops/gmm/naive.py, tops/cpu/ops/gmm/__init__.py
Introduces gmm_ref and tgmm_ref decorated CPU-reference functions that validate shapes, iterate groups using group_sizes, perform per-group dot products (optional RHS transpose), skip empty groups, and export them via __all__.
JIT / TPU implementation & autodiff
tops/ops/gmm/gmm.py, tops/ops/gmm/__init__.py
Adds scan-based _gmm_impl and _tgmm_impl using lax.scan/dynamic_slice, padding/masking for group rows, public gmm and tgmm functions, and a jax.custom_vjp for gmm that computes gradients via the tgmm kernel.
CPU-focused tests
tests/ops/gmm/test_cpu_ref.py
Adds pytest suite validating gmm_ref/tgmm_ref and that tops.ops.gmm (JIT-able) matches CPU refs across single/multi-group, empty-group, transpose-RHS, and gradient shape/sanity checks.
TPU/tokamax tests
tests/ops/gmm/test_gmm_vs_tokamax.py
Adds TPU-only tests that compare tops.ops.gmm.gmm (bf16 TPU cases) to tokamax kernel outputs and to CPU reference; includes forward and selected backward gradient comparisons (conditional on TPU backend and tokamax import).
Build metadata
pyproject.toml
Updates optional tpu extra to add tokamax>=0.0.12.

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

Poem

🐰 In grouped rows we hop and play,
Dots and scans compile the day.
CPU checks and TPU friends,
Gradients traced where padding ends.
Hooray for matrices, marching away! 🥕✨

🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 warning)

Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 59.02% which is insufficient. The required threshold is 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (2 passed)
Check name Status Explanation
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.
Title check ✅ Passed The title accurately describes the main feature addition: JIT-compilable JAX GMM implementation for TPU with forward/backward passes and tokamax comparison, which aligns with the changeset's core focus.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.

✨ Finishing Touches
📝 Generate docstrings
  • Create stacked PR
  • Commit on current branch
🧪 Generate unit tests (beta)
  • Create PR with unit tests
  • Commit unit tests in branch worktree-wise-swinging-kahan

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands and usage tips.

Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request implements a Pallas TPU kernel for Grouped Matrix Multiplication (GMM) and its transpose (TGMM), specifically designed to support Mixture-of-Experts (MoE) layers. The implementation includes logic for generating CSR-style group metadata to handle ragged groups, a differentiable wrapper using jax.custom_vjp, and a JAX-based CPU reference for testing. Feedback was provided to simplify the conditional logic for handling transposed RHS within the kernel to improve readability.

Comment on lines +158 to +161
if transpose_rhs:
dims = ((1,), (1,)), ((), ())
else:
dims = ((1,), (0,)), ((), ())
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The conditional logic for transpose_rhs inside the kernel can be simplified to avoid redundant if checks or complex dims tuples, improving readability and potentially performance.

    dims = (((1,), (1,)), ((), ())) if transpose_rhs else (((1,), (0,)), ((), ()))

Copy link
Copy Markdown

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 4

🧹 Nitpick comments (6)
tops/cpu/ops/gmm/naive.py (1)

9-54: Consider adding dimension compatibility validation for k.

The function validates tensor ranks but doesn't verify that lhs.shape[1] (k) matches rhs.shape[1] (when transpose_rhs=False) or rhs.shape[2] (when transpose_rhs=True). A mismatch would cause a runtime error during matmul, but an explicit assertion would provide a clearer error message.

As per coding guidelines: "All public functions must enforce strict constraints and validation on the shape and types of input variables using assert instructions."

Proposed validation
  assert lhs.ndim == 2, f"lhs must be 2D, got {lhs.ndim}D"
  assert rhs.ndim == 3, f"rhs must be 3D, got {rhs.ndim}D"
  assert group_sizes.ndim == 1, f"group_sizes must be 1D, got {group_sizes.ndim}D"
+  expected_k = rhs.shape[2] if transpose_rhs else rhs.shape[1]
+  assert lhs.shape[1] == expected_k, (
+    f"lhs k-dim ({lhs.shape[1]}) must match rhs {'n' if transpose_rhs else 'k'}-dim ({expected_k})"
+  )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tops/cpu/ops/gmm/naive.py` around lines 9 - 54, Add an explicit shape
compatibility assertion in gmm_ref: verify k = lhs.shape[1] matches rhs.shape[1]
when transpose_rhs is False or rhs.shape[2] when transpose_rhs is True, and
raise a clear assertion error mentioning gmm_ref, lhs, rhs, and transpose_rhs;
update the checks near the existing rank assertions so callers get a
deterministic, informative message instead of a downstream matmul runtime error.
tests/ops/gmm/test_cpu_ref.py (1)

72-94: Consider adding an empty group test for tgmm_ref.

TestGmmRef includes test_empty_group to verify empty group handling, but TestTgmmRef lacks a similar test. Since tgmm_forward uses visit_empty_groups=True, testing that tgmm_ref correctly produces a zero output tensor for empty groups would improve coverage.

Example test case
def test_empty_group(self):
  """Empty group produces zero output for that group."""
  lhs = jnp.array([[1.0], [2.0]], dtype=jnp.float32)
  rhs = jnp.array([[3.0], [4.0]], dtype=jnp.float32)
  gs = jnp.array([0, 2], dtype=jnp.int32)
  out = tgmm_ref(lhs, rhs, gs)
  # Group 0: empty -> zeros
  # Group 1: [1,2]^T @ [3,4] = [[1*3+2*4]] = [[11]]
  expected = jnp.array([[[0.0]], [[11.0]]], dtype=jnp.float32)
  np.testing.assert_allclose(np.array(out), np.array(expected), atol=1e-5)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/ops/gmm/test_cpu_ref.py` around lines 72 - 94, Add a test_empty_group
to TestTgmmRef that mirrors TestGmmRef's empty-group case: call tgmm_ref with
group ids that include an empty group index (e.g., gs = jnp.array([0,2])),
assert that the output for the empty group is a zero tensor and other groups
produce the correct outer-product results; locate TestTgmmRef and add a new
method test_empty_group that builds lhs, rhs, gs, calls tgmm_ref and uses
np.testing.assert_allclose to compare against the expected array containing
zeros for the empty group.
tests/ops/gmm/test_metadata.py (1)

36-41: Prefix unused unpacked variables with underscore.

Static analysis flagged several unused unpacked variables. Prefixing them with _ follows Python convention and silences the warnings.

Proposed fixes
  def test_uniform_groups_multi_tile(self):
-   (offsets, gids, mids), num_tiles = make_group_metadata(
+   (_offsets, gids, mids), num_tiles = make_group_metadata(
      group_sizes=gs, m=512, tm=128
    )

  def test_shared_tile_at_boundary(self):
-   (offsets, gids, mids), num_tiles = make_group_metadata(
+   (_offsets, gids, mids), num_tiles = make_group_metadata(
      group_sizes=gs, m=128, tm=128
    )

  def test_empty_group(self):
-   (offsets, gids, mids), num_tiles = make_group_metadata(
+   (_offsets, gids, _mids), num_tiles = make_group_metadata(
      group_sizes=gs, m=128, tm=128
    )

  def test_visit_empty_groups(self):
-   (offsets, gids, mids), num_tiles = make_group_metadata(
+   (_offsets, gids, _mids), num_tiles = make_group_metadata(
      group_sizes=gs, m=128, tm=128, visit_empty_groups=True
    )

Also applies to: 48-54, 59-63, 68-72

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/ops/gmm/test_metadata.py` around lines 36 - 41, The test unpacks
(offsets, gids, mids), num_tiles = make_group_metadata(...) but some of those
unpacked variables are unused; update the test to prefix unused variables with
an underscore (e.g., use _offsets, _gids, _mids or _offsets where appropriate)
so static analysis warnings are silenced; apply the same change for the other
unpack patterns in this file (the blocks around lines 48-54, 59-63, 68-72) while
keeping references to make_group_metadata and num_tiles intact.
tops/ops/gmm/metadata.py (1)

15-48: Consider adding group_sizes dtype validation.

The function validates group_sizes.ndim but not its dtype. The callers (gmm_forward, tgmm_forward) validate group_sizes.dtype == jnp.int32, but adding validation here would provide defense-in-depth and clearer error messages at the source.

Proposed validation
  assert group_sizes.ndim == 1, "group_sizes must be 1-D"
+  assert group_sizes.dtype == jnp.int32, f"group_sizes must be int32, got {group_sizes.dtype}"
  assert m > 0, "m must be positive"
  assert tm > 0, "tm must be positive"
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tops/ops/gmm/metadata.py` around lines 15 - 48, Add a dtype check for
group_sizes inside make_group_metadata to ensure defensive validation: after the
existing assert that group_sizes.ndim == 1, assert that group_sizes.dtype ==
jnp.int32 (or jnp.int32 equivalent) and raise a clear error message like
"group_sizes must be int32" so callers (e.g., gmm_forward/tgmm_forward) get
consistent validation at the source; keep the check in make_group_metadata to
avoid silent type issues downstream and mirror the callers' expectations.
docs/plans/2026-04-06-gmm-kernel-impl.md (2)

158-166: Note: _make_gmm_inputs helper defined in plan but not used in test_cpu_ref.py.

The plan defines a _make_gmm_inputs helper function for generating random test inputs, but the actual tests/ops/gmm/test_cpu_ref.py doesn't include or use this helper. This is fine since the test file uses inline test data for clarity, but the plan could be updated to match if desired.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@docs/plans/2026-04-06-gmm-kernel-impl.md` around lines 158 - 166, The plan
includes a helper _make_gmm_inputs to generate random GMM test data but the
actual test file test_cpu_ref.py doesn't import or use it; either remove or
update the plan to match the test, or modify test_cpu_ref.py to import and use
_make_gmm_inputs. Locate the helper definition named _make_gmm_inputs in the
plan and the test module test_cpu_ref.py and either (a) delete or comment out
the unused helper from docs/plans/2026-04-06-gmm-kernel-impl.md, or (b) update
test_cpu_ref.py to call _make_gmm_inputs(m, k, n, num_groups, group_sizes, seed,
dtype) and replace the inline test data with the generated lhs/rhs/gs so both
the plan and test stay consistent.

1-13: Minor markdown formatting issues.

Static analysis flagged two markdown issues:

  1. Line 13: Heading jumps from # (h1) to ### (h3), skipping h2
  2. Line 1290: Code block lacks a language specifier
Suggested fixes
-### Task 1: CPU Reference Implementation
+## Task 1: CPU Reference Implementation

At line 1290:

-```
+```text
 tops/ops/gmm/metadata.py      <- no internal deps
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@docs/plans/2026-04-06-gmm-kernel-impl.md` around lines 1 - 13, Fix the
markdown heading and code fence: change the third-level heading that immediately
follows the top-level "# GMM (Grouped Matrix Multiplication) Implementation
Plan" to a second-level heading (replace the `###` with `##`) so headings don’t
skip h2, and add a language specifier to the fenced code block that currently
shows "```text tops/ops/gmm/metadata.py ..." (update the backticks to include a
language like ```text or ```bash) so the block has a proper language tag. Ensure
the edits touch the top-level heading area and the code fence near the large
code block referenced in the doc.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@docs/plans/2026-04-06-gmm-kernel-design.md`:
- Around line 67-73: The docstring in the design doc is inconsistent with the
actual implementation in tops/ops/gmm/metadata.py: update the description of
make_group_metadata to match the code (or vice versa). Specifically, state that
group_offsets is the plain cumulative sum of unpadded group_sizes (not
tm-rounded offsets) and that m_tile_ids are global row-tile ids (not per-group
offsets); keep group_ids semantics as in the code. Reference
make_group_metadata, group_offsets, group_ids, and m_tile_ids when making the
doc edit so the doc and tops/ops/gmm/metadata.py agree.
- Around line 14-16: The fenced code examples (e.g., the block containing the
expression "out[start_i:end_i, :] = lhs[start_i:end_i, :] @ rhs[i, :, :]" and
the other similar fenced blocks) lack language identifiers causing markdownlint
MD040; update each triple-backtick fence to include an explicit language tag
such as ```python (or ```text if it's plain output) so the blocks are properly
classified and rendered; ensure you update all occurrences mentioned (the blocks
around the example expression and the other fenced regions referenced) to use
the same convention.

In `@tops/ops/gmm/gmm.py`:
- Around line 221-240: Add a centralized validator (e.g.,
validate_group_partition) and call it from both gmm entry points before any
kernel/tiling work; the validator should assert group_sizes.ndim == 1, all
entries are integer and >= 0, group_sizes.sum() == m, and in gmm_forward also
assert group_sizes.shape[0] == rhs.shape[0] (or conversely that num_groups ==
group_sizes.shape[0]) so we fail fast on invalid partitions; place calls to this
validator alongside the existing asserts (near where lhs/rhs shapes and k are
validated) and keep using assert-style checks to satisfy the public-function
validation requirement.
- Around line 628-634: The backward path builds drhs by calling
tgmm_forward(lhs, grad, ...) which returns shape [num_groups, k, n] but when
transpose_rhs=True the primal rhs is [num_groups, n, k], so before casting drhs
to rhs.dtype you must transpose the last two axes (swap axis 1 and 2) to produce
[num_groups, n, k]; update the code around the drhs assignment in gmm.py to do
this conditional transpose when transpose_rhs is True and then
astype(rhs.dtype), and add a rectangular test case exercising transpose_rhs=True
(non-square n!=k) so the transposed gradient layout is validated.

---

Nitpick comments:
In `@docs/plans/2026-04-06-gmm-kernel-impl.md`:
- Around line 158-166: The plan includes a helper _make_gmm_inputs to generate
random GMM test data but the actual test file test_cpu_ref.py doesn't import or
use it; either remove or update the plan to match the test, or modify
test_cpu_ref.py to import and use _make_gmm_inputs. Locate the helper definition
named _make_gmm_inputs in the plan and the test module test_cpu_ref.py and
either (a) delete or comment out the unused helper from
docs/plans/2026-04-06-gmm-kernel-impl.md, or (b) update test_cpu_ref.py to call
_make_gmm_inputs(m, k, n, num_groups, group_sizes, seed, dtype) and replace the
inline test data with the generated lhs/rhs/gs so both the plan and test stay
consistent.
- Around line 1-13: Fix the markdown heading and code fence: change the
third-level heading that immediately follows the top-level "# GMM (Grouped
Matrix Multiplication) Implementation Plan" to a second-level heading (replace
the `###` with `##`) so headings don’t skip h2, and add a language specifier to
the fenced code block that currently shows "```text tops/ops/gmm/metadata.py
..." (update the backticks to include a language like ```text or ```bash) so the
block has a proper language tag. Ensure the edits touch the top-level heading
area and the code fence near the large code block referenced in the doc.

In `@tests/ops/gmm/test_cpu_ref.py`:
- Around line 72-94: Add a test_empty_group to TestTgmmRef that mirrors
TestGmmRef's empty-group case: call tgmm_ref with group ids that include an
empty group index (e.g., gs = jnp.array([0,2])), assert that the output for the
empty group is a zero tensor and other groups produce the correct outer-product
results; locate TestTgmmRef and add a new method test_empty_group that builds
lhs, rhs, gs, calls tgmm_ref and uses np.testing.assert_allclose to compare
against the expected array containing zeros for the empty group.

In `@tests/ops/gmm/test_metadata.py`:
- Around line 36-41: The test unpacks (offsets, gids, mids), num_tiles =
make_group_metadata(...) but some of those unpacked variables are unused; update
the test to prefix unused variables with an underscore (e.g., use _offsets,
_gids, _mids or _offsets where appropriate) so static analysis warnings are
silenced; apply the same change for the other unpack patterns in this file (the
blocks around lines 48-54, 59-63, 68-72) while keeping references to
make_group_metadata and num_tiles intact.

In `@tops/cpu/ops/gmm/naive.py`:
- Around line 9-54: Add an explicit shape compatibility assertion in gmm_ref:
verify k = lhs.shape[1] matches rhs.shape[1] when transpose_rhs is False or
rhs.shape[2] when transpose_rhs is True, and raise a clear assertion error
mentioning gmm_ref, lhs, rhs, and transpose_rhs; update the checks near the
existing rank assertions so callers get a deterministic, informative message
instead of a downstream matmul runtime error.

In `@tops/ops/gmm/metadata.py`:
- Around line 15-48: Add a dtype check for group_sizes inside
make_group_metadata to ensure defensive validation: after the existing assert
that group_sizes.ndim == 1, assert that group_sizes.dtype == jnp.int32 (or
jnp.int32 equivalent) and raise a clear error message like "group_sizes must be
int32" so callers (e.g., gmm_forward/tgmm_forward) get consistent validation at
the source; keep the check in make_group_metadata to avoid silent type issues
downstream and mirror the callers' expectations.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 65d1de2e-80e9-49c6-a742-fd46a7f97c8e

📥 Commits

Reviewing files that changed from the base of the PR and between 90ad718 and 9e2ffa5.

📒 Files selected for processing (12)
  • docs/plans/2026-04-06-gmm-kernel-design.md
  • docs/plans/2026-04-06-gmm-kernel-impl.md
  • tests/ops/gmm/__init__.py
  • tests/ops/gmm/test_cpu_ref.py
  • tests/ops/gmm/test_gmm_tpu.py
  • tests/ops/gmm/test_metadata.py
  • tops/cpu/ops/gmm/__init__.py
  • tops/cpu/ops/gmm/naive.py
  • tops/ops/__init__.py
  • tops/ops/gmm/__init__.py
  • tops/ops/gmm/gmm.py
  • tops/ops/gmm/metadata.py

Comment on lines +14 to +16
```
out[start_i:end_i, :] = lhs[start_i:end_i, :] @ rhs[i, :, :]
```
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

Add language identifiers to these fenced blocks.

markdownlint is already flagging these examples under MD040. Adding text/python language tags will clear the warning and improve rendering.

Also applies to: 19-21, 58-61, 92-95, 110-123

🧰 Tools
🪛 markdownlint-cli2 (0.22.0)

[warning] 14-14: Fenced code blocks should have a language specified

(MD040, fenced-code-language)

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@docs/plans/2026-04-06-gmm-kernel-design.md` around lines 14 - 16, The fenced
code examples (e.g., the block containing the expression "out[start_i:end_i, :]
= lhs[start_i:end_i, :] @ rhs[i, :, :]" and the other similar fenced blocks)
lack language identifiers causing markdownlint MD040; update each
triple-backtick fence to include an explicit language tag such as ```python (or
```text if it's plain output) so the blocks are properly classified and
rendered; ensure you update all occurrences mentioned (the blocks around the
example expression and the other fenced regions referenced) to use the same
convention.

Comment on lines +67 to +73
### Group Metadata (computed on host, passed via scalar prefetch)

`make_group_metadata(group_sizes, m, tm)` produces:
- `group_offsets`: CSR-style cumulative row offsets, rounded to tm boundaries
- `group_ids`: maps each active m-tile index to its group
- `m_tile_ids`: maps each active m-tile index to its row-tile offset within the group

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

The metadata contract here is out of sync with make_group_metadata.

In tops/ops/gmm/metadata.py, group_offsets stays as the plain cumulative sum of unpadded group_sizes, and m_tile_ids are global row-tile ids. Describing them here as tm-rounded offsets and per-group tile offsets documents a different scheduler than the one in code.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@docs/plans/2026-04-06-gmm-kernel-design.md` around lines 67 - 73, The
docstring in the design doc is inconsistent with the actual implementation in
tops/ops/gmm/metadata.py: update the description of make_group_metadata to match
the code (or vice versa). Specifically, state that group_offsets is the plain
cumulative sum of unpadded group_sizes (not tm-rounded offsets) and that
m_tile_ids are global row-tile ids (not per-group offsets); keep group_ids
semantics as in the code. Reference make_group_metadata, group_offsets,
group_ids, and m_tile_ids when making the doc edit so the doc and
tops/ops/gmm/metadata.py agree.

Comment on lines +221 to +240
assert lhs.ndim == 2, f"lhs must be 2D, got {lhs.ndim}D"
assert rhs.ndim == 3, f"rhs must be 3D, got {rhs.ndim}D"
assert group_sizes.ndim == 1, f"group_sizes must be 1D, got {group_sizes.ndim}D"

if interpret is None:
interpret = get_interpret()
if preferred_element_type is None:
preferred_element_type = lhs.dtype

m, k_lhs = lhs.shape
num_groups = rhs.shape[0]

if transpose_rhs:
n, k_rhs = rhs.shape[1], rhs.shape[2]
else:
k_rhs, n = rhs.shape[1], rhs.shape[2]
assert k_lhs == k_rhs, f"lhs K ({k_lhs}) must match rhs K ({k_rhs})"

k = k_lhs
tm, tk, tn = _validate_tiling(tiling, m, k, n)
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Validate the group partition before building metadata.

Both entry points only validate rank/K compatibility today. If group_sizes is non-integer/negative, does not sum to m, or—gmm_forward only—doesn't match rhs.shape[0], the scheduler can leave rows unwritten or address the wrong expert buffer instead of failing fast. Please centralize those invariants in a shared validator before launching the kernels. As per coding guidelines, "All public functions must enforce strict constraints and validation on the shape and types of input variables using assert instructions (or utilities like assert_shape_or_none from tops.utils) before executing the main logic"

Also applies to: 460-476

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tops/ops/gmm/gmm.py` around lines 221 - 240, Add a centralized validator
(e.g., validate_group_partition) and call it from both gmm entry points before
any kernel/tiling work; the validator should assert group_sizes.ndim == 1, all
entries are integer and >= 0, group_sizes.sum() == m, and in gmm_forward also
assert group_sizes.shape[0] == rhs.shape[0] (or conversely that num_groups ==
group_sizes.shape[0]) so we fail fast on invalid partitions; place calls to this
validator alongside the existing asserts (near where lhs/rhs shapes and k are
validated) and keep using assert-style checks to satisfy the public-function
validation requirement.

Comment on lines +628 to +634
drhs = tgmm_forward(
lhs,
grad,
group_sizes,
tiling=tiling,
preferred_element_type=preferred_element_type,
).astype(rhs.dtype)
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical

Transpose drhs in the transpose_rhs=True backward path.

tgmm_forward(lhs, grad, ...) always returns [num_groups, k, n], but the primal rhs is [num_groups, n, k] when transpose_rhs=True. The custom VJP therefore returns a cotangent with the wrong layout here: rectangular cases will fail, and square cases silently get the transposed weight gradient. Please swap the last two axes before casting back to rhs.dtype; also add a transpose_rhs=True rectangular gradient case so this path stays covered.

💡 Suggested fix
   drhs = tgmm_forward(
     lhs,
     grad,
     group_sizes,
     tiling=tiling,
     preferred_element_type=preferred_element_type,
-  ).astype(rhs.dtype)
+  )
+  if transpose_rhs:
+    drhs = jnp.swapaxes(drhs, -1, -2)
+  drhs = drhs.astype(rhs.dtype)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tops/ops/gmm/gmm.py` around lines 628 - 634, The backward path builds drhs by
calling tgmm_forward(lhs, grad, ...) which returns shape [num_groups, k, n] but
when transpose_rhs=True the primal rhs is [num_groups, n, k], so before casting
drhs to rhs.dtype you must transpose the last two axes (swap axis 1 and 2) to
produce [num_groups, n, k]; update the code around the drhs assignment in gmm.py
to do this conditional transpose when transpose_rhs is True and then
astype(rhs.dtype), and add a rectangular test case exercising transpose_rhs=True
(non-square n!=k) so the transposed gradient layout is validated.

sii-xinglong and others added 8 commits April 6, 2026 16:52
Design and implementation plan documents are not needed in the final
codebase — the code and tests serve as the authoritative reference.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
The docstring incorrectly stated rhs shape as [num_groups, k, n] when
transpose_rhs=True. The actual input shape is [num_groups, n, k], with
each slice transposed to [k, n] before matmul.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
… in refs

Remove all Pallas TPU kernel code (gmm.py, metadata.py) and their tests
(test_gmm_tpu.py, test_metadata.py). Retain CPU reference implementations
as the ground truth.

Change naive refs from "cast to f32 then multiply" to "bf16 multiply with
f32 accumulation" via lax.dot(preferred_element_type=f32), matching TPU
MXU semantics. Output is now f32 directly instead of casting back to
input dtype.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Change all test inputs from float32 to bfloat16 to match the intended
usage pattern (bf16 mul + f32 accumulation). Expected values remain
f32 as the reference now outputs f32. Tolerances relaxed to 1e-2.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Implements _gmm_impl and _tgmm_impl with scan-based grouped matmul
using dynamic_slice for TPU/JIT compatibility. bf16 inputs with f32
accumulation to match TPU MXU semantics. Adds tests verifying JAX
implementations match CPU reference outputs.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
@sii-xinglong sii-xinglong changed the title feat(gmm): add Pallas TPU kernels for Grouped Matrix Multiplication feat(gmm): JIT-compilable JAX GMM for TPU with forward/backward and tokamax comparison Apr 6, 2026
Copy link
Copy Markdown

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🧹 Nitpick comments (4)
tests/ops/gmm/test_gmm_vs_tokamax.py (3)

59-67: Same issue: function call in default argument.

Apply the same fix as _make_inputs above.

♻️ Proposed fix
-def _make_inputs_transposed(case, key=jax.random.PRNGKey(42)):
+def _make_inputs_transposed(case, key=None):
+  if key is None:
+    key = jax.random.PRNGKey(42)
   k1, k2 = jax.random.split(key)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/ops/gmm/test_gmm_vs_tokamax.py` around lines 59 - 67, The function
_make_inputs_transposed uses a function call jax.random.PRNGKey(42) as a default
argument which is evaluated at import time; change the signature to accept
key=None and inside _make_inputs_transposed set key = jax.random.PRNGKey(42) if
key is None, then proceed to call jax.random.split(key) and generate lhs, rhs,
and gs as before (same pattern used in _make_inputs). Update references to the
parameter name if needed but keep the function name and behavior unchanged
otherwise.

145-163: Consider expanding backward test coverage.

The backward tests only run on CASES[:2], leaving the more complex cases (8-group, varying group sizes) untested for gradients. The comment from past reviews also noted missing transpose_rhs=True gradient tests. Consider adding at least one transpose_rhs=True backward case to validate the transposed gradient path, especially given the past review flagging a potential layout issue in that path.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/ops/gmm/test_gmm_vs_tokamax.py` around lines 145 - 163, The
test_grad_lhs currently only paramaterizes CASES[:2], missing complex scenarios
and the transpose_rhs=True path; update the parametrize on test_grad_lhs to
include additional CASES (at least the 8-group / varying group-size cases) and
add a separate parametrized test or include a case where _call_tokamax_gmm and
gmm are invoked with transpose_rhs=True so the backward/transposed gradient path
is exercised (ensure the new tests call jax.grad over tokamax_loss and jax_loss
as in test_grad_lhs to compare gradients).

49-56: Avoid function call in default argument.

jax.random.PRNGKey(42) is evaluated once at function definition time, not at call time. While deterministic here, this pattern can cause subtle bugs if the key were ever mutated or if the function is called multiple times expecting fresh state. Move the default inside the function body.

♻️ Proposed fix
-def _make_inputs(case, key=jax.random.PRNGKey(42)):
+def _make_inputs(case, key=None):
+  if key is None:
+    key = jax.random.PRNGKey(42)
   k1, k2 = jax.random.split(key)
   lhs = jax.random.normal(k1, (case["m"], case["k"]), dtype=jnp.bfloat16)
   rhs = jax.random.normal(
     k2, (case["num_groups"], case["k"], case["n"]), dtype=jnp.bfloat16
   )
   gs = jnp.array(case["group_sizes"], dtype=jnp.int32)
   return lhs, rhs, gs
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/ops/gmm/test_gmm_vs_tokamax.py` around lines 49 - 56, The _make_inputs
function currently uses jax.random.PRNGKey(42) as a default argument which is
evaluated at definition time; change the signature to accept key=None (or omit
the default) and inside the function do if key is None: key =
jax.random.PRNGKey(42) before splitting; keep the rest of the logic (k1, k2 =
jax.random.split(key), lhs/rhs/gs creation) unchanged so each call gets a fresh
key when desired and avoids evaluating the PRNGKey at import time.
tops/ops/gmm/gmm.py (1)

41-41: Minor: offset array creation could use explicit dtype.

The jnp.zeros(1, dtype=jnp.int32) is good, but jnp.cumsum(group_sizes) inherits dtype from group_sizes. This is fine assuming group_sizes is always int32, but an explicit cast could guard against unexpected input dtypes.

♻️ Optional: explicit dtype for cumsum
-  offsets = jnp.concatenate([jnp.zeros(1, dtype=jnp.int32), jnp.cumsum(group_sizes)])
+  offsets = jnp.concatenate([jnp.zeros(1, dtype=jnp.int32), jnp.cumsum(group_sizes, dtype=jnp.int32)])
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tops/ops/gmm/gmm.py` at line 41, The offsets construction uses
jnp.cumsum(group_sizes) which may inherit an unintended dtype; update the
offsets creation (the line building offsets that concatenates jnp.zeros and
jnp.cumsum) to ensure jnp.cumsum(group_sizes) is explicitly cast to the desired
dtype (e.g., jnp.int32) before concatenation so both parts share the same dtype
and avoid dtype-mismatch issues with the offsets variable.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Nitpick comments:
In `@tests/ops/gmm/test_gmm_vs_tokamax.py`:
- Around line 59-67: The function _make_inputs_transposed uses a function call
jax.random.PRNGKey(42) as a default argument which is evaluated at import time;
change the signature to accept key=None and inside _make_inputs_transposed set
key = jax.random.PRNGKey(42) if key is None, then proceed to call
jax.random.split(key) and generate lhs, rhs, and gs as before (same pattern used
in _make_inputs). Update references to the parameter name if needed but keep the
function name and behavior unchanged otherwise.
- Around line 145-163: The test_grad_lhs currently only paramaterizes CASES[:2],
missing complex scenarios and the transpose_rhs=True path; update the
parametrize on test_grad_lhs to include additional CASES (at least the 8-group /
varying group-size cases) and add a separate parametrized test or include a case
where _call_tokamax_gmm and gmm are invoked with transpose_rhs=True so the
backward/transposed gradient path is exercised (ensure the new tests call
jax.grad over tokamax_loss and jax_loss as in test_grad_lhs to compare
gradients).
- Around line 49-56: The _make_inputs function currently uses
jax.random.PRNGKey(42) as a default argument which is evaluated at definition
time; change the signature to accept key=None (or omit the default) and inside
the function do if key is None: key = jax.random.PRNGKey(42) before splitting;
keep the rest of the logic (k1, k2 = jax.random.split(key), lhs/rhs/gs creation)
unchanged so each call gets a fresh key when desired and avoids evaluating the
PRNGKey at import time.

In `@tops/ops/gmm/gmm.py`:
- Line 41: The offsets construction uses jnp.cumsum(group_sizes) which may
inherit an unintended dtype; update the offsets creation (the line building
offsets that concatenates jnp.zeros and jnp.cumsum) to ensure
jnp.cumsum(group_sizes) is explicitly cast to the desired dtype (e.g.,
jnp.int32) before concatenation so both parts share the same dtype and avoid
dtype-mismatch issues with the offsets variable.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 27d82765-d49b-4c95-b472-1ff6830d93d4

📥 Commits

Reviewing files that changed from the base of the PR and between afa058e and 6d312fe.

📒 Files selected for processing (5)
  • pyproject.toml
  • tests/ops/gmm/test_cpu_ref.py
  • tests/ops/gmm/test_gmm_vs_tokamax.py
  • tops/ops/gmm/__init__.py
  • tops/ops/gmm/gmm.py
✅ Files skipped from review due to trivial changes (2)
  • pyproject.toml
  • tests/ops/gmm/test_cpu_ref.py
🚧 Files skipped from review as they are similar to previous changes (1)
  • tops/ops/gmm/init.py

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant