feat(gmm): JIT-compilable JAX GMM for TPU with forward/backward and tokamax comparison #161
sii-xinglong wants to merge 15 commits into main from
Conversation
Phase 1 design for grouped matrix multiplication kernel that replaces tokamax/qwix backends. BF16 with float32 accumulation, custom_vjp for full differentiability. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
5-task TDD implementation plan covering CPU reference, group metadata, GMM/TGMM kernels, custom_vjp, and public API with complete code. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Add two core Pallas kernel functions for grouped matrix multiplication:
- gmm_forward: forward pass computing per-group matmuls with support for transposed rhs and data-dependent indexing via scalar prefetch
- tgmm_forward: transposed GMM for weight gradients, accumulating lhs^T @ rhs per group with proper boundary masking

Both kernels use PrefetchScalarGridSpec with group metadata for efficient ragged-batch processing. Tests validate accuracy against CPU reference implementations across uniform, skewed, single-group, and with-empty-group distributions using PALLAS_INTERPRET mode. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
- Add @jax.jit with static_argnames to gmm_forward and tgmm_forward
- Remove int() casts on num_active_tiles (traced values under JIT)
- Use pl.num_programs(2) in tgmm kernel instead of captured constant

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
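As a hedged illustration of why these changes matter under `jax.jit` (the `scale_rows` function and its `tiling` parameter below are hypothetical stand-ins, not the PR's actual signatures):

```python
import functools

import jax
import jax.numpy as jnp

# Hypothetical stand-in for a tiling argument: static_argnames makes
# `tiling` a compile-time constant, so it can be unpacked as plain Python
# ints. Traced arguments, by contrast, cannot be passed through int() --
# which is why the commit removes those casts.
@functools.partial(jax.jit, static_argnames=("tiling",))
def scale_rows(x, tiling=(128, 128)):
    tm, tn = tiling  # ordinary Python ints under jit
    return x * tm

out = scale_rows(jnp.ones(2), tiling=(64, 128))
```

Each distinct static `tiling` value triggers a recompile; that is the trade-off for being able to use it as a plain Python constant inside the jitted function.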
Wire up custom_vjp for the gmm function so it supports jax.grad:
- dlhs computed via gmm_forward with transposed rhs
- drhs computed via tgmm_forward

Export gmm, gmm_forward, tgmm_forward from tops.ops.gmm and add gmm to tops.ops public API. Add 3 gradient test cases comparing Pallas backward against a differentiable JAX reference. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
📝 Walkthrough

Adds a grouped matrix multiplication feature: CPU reference implementations (…)

Changes
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes
🚥 Pre-merge checks: ✅ 2 passed | ❌ 1 failed (1 warning)
Code Review
This pull request implements a Pallas TPU kernel for Grouped Matrix Multiplication (GMM) and its transpose (TGMM), specifically designed to support Mixture-of-Experts (MoE) layers. The implementation includes logic for generating CSR-style group metadata to handle ragged groups, a differentiable wrapper using jax.custom_vjp, and a JAX-based CPU reference for testing. Feedback was provided to simplify the conditional logic for handling transposed RHS within the kernel to improve readability.
tops/ops/gmm/gmm.py (Outdated)

```python
if transpose_rhs:
    dims = ((1,), (1,)), ((), ())
else:
    dims = ((1,), (0,)), ((), ())
```
Actionable comments posted: 4
🧹 Nitpick comments (6)
tops/cpu/ops/gmm/naive.py (1)
9-54: Consider adding dimension compatibility validation for `k`.

The function validates tensor ranks but doesn't verify that `lhs.shape[1]` (k) matches `rhs.shape[1]` (when `transpose_rhs=False`) or `rhs.shape[2]` (when `transpose_rhs=True`). A mismatch would cause a runtime error during matmul, but an explicit assertion would provide a clearer error message. As per coding guidelines: "All public functions must enforce strict constraints and validation on the shape and types of input variables using assert instructions."

Proposed validation

```diff
 assert lhs.ndim == 2, f"lhs must be 2D, got {lhs.ndim}D"
 assert rhs.ndim == 3, f"rhs must be 3D, got {rhs.ndim}D"
 assert group_sizes.ndim == 1, f"group_sizes must be 1D, got {group_sizes.ndim}D"
+expected_k = rhs.shape[2] if transpose_rhs else rhs.shape[1]
+assert lhs.shape[1] == expected_k, (
+    f"lhs k-dim ({lhs.shape[1]}) must match rhs {'n' if transpose_rhs else 'k'}-dim ({expected_k})"
+)
```

🤖 Prompt for AI Agents: Verify the finding against the current code and only fix it if needed. In `tops/cpu/ops/gmm/naive.py` around lines 9-54, add a shape compatibility assertion in gmm_ref: verify that lhs.shape[1] matches rhs.shape[1] when transpose_rhs is False or rhs.shape[2] when transpose_rhs is True, with a clear assertion message, so callers get a deterministic, informative error instead of a downstream matmul runtime error.

tests/ops/gmm/test_cpu_ref.py (1)
72-94: Consider adding an empty group test for `tgmm_ref`.

`TestGmmRef` includes `test_empty_group` to verify empty group handling, but `TestTgmmRef` lacks a similar test. Since `tgmm_forward` uses `visit_empty_groups=True`, testing that `tgmm_ref` correctly produces a zero output tensor for empty groups would improve coverage.

Example test case

```python
def test_empty_group(self):
    """Empty group produces zero output for that group."""
    lhs = jnp.array([[1.0], [2.0]], dtype=jnp.float32)
    rhs = jnp.array([[3.0], [4.0]], dtype=jnp.float32)
    gs = jnp.array([0, 2], dtype=jnp.int32)
    out = tgmm_ref(lhs, rhs, gs)
    # Group 0: empty -> zeros
    # Group 1: [1,2]^T @ [3,4] = [[1*3+2*4]] = [[11]]
    expected = jnp.array([[[0.0]], [[11.0]]], dtype=jnp.float32)
    np.testing.assert_allclose(np.array(out), np.array(expected), atol=1e-5)
```

🤖 Prompt for AI Agents: Verify the finding against the current code and only fix it if needed. In `tests/ops/gmm/test_cpu_ref.py` around lines 72-94, add a test_empty_group to TestTgmmRef mirroring TestGmmRef's empty-group case: build lhs, rhs, and group sizes that include an empty group, call tgmm_ref, and use np.testing.assert_allclose to check that the empty group's output is all zeros while the other groups match the expected values.

tests/ops/gmm/test_metadata.py (1)
36-41: Prefix unused unpacked variables with underscore.

Static analysis flagged several unused unpacked variables. Prefixing them with `_` follows Python convention and silences the warnings.

Proposed fixes

```diff
 def test_uniform_groups_multi_tile(self):
-    (offsets, gids, mids), num_tiles = make_group_metadata(
+    (_offsets, gids, mids), num_tiles = make_group_metadata(
         group_sizes=gs, m=512, tm=128
     )

 def test_shared_tile_at_boundary(self):
-    (offsets, gids, mids), num_tiles = make_group_metadata(
+    (_offsets, gids, mids), num_tiles = make_group_metadata(
         group_sizes=gs, m=128, tm=128
     )

 def test_empty_group(self):
-    (offsets, gids, mids), num_tiles = make_group_metadata(
+    (_offsets, gids, _mids), num_tiles = make_group_metadata(
         group_sizes=gs, m=128, tm=128
     )

 def test_visit_empty_groups(self):
-    (offsets, gids, mids), num_tiles = make_group_metadata(
+    (_offsets, gids, _mids), num_tiles = make_group_metadata(
         group_sizes=gs, m=128, tm=128, visit_empty_groups=True
     )
```

Also applies to: 48-54, 59-63, 68-72

🤖 Prompt for AI Agents: Verify the finding against the current code and only fix it if needed. In `tests/ops/gmm/test_metadata.py` around lines 36-41 (and the similar blocks at 48-54, 59-63, 68-72), prefix the unused variables unpacked from make_group_metadata with an underscore so static analysis warnings are silenced, keeping references to make_group_metadata and num_tiles intact.

tops/ops/gmm/metadata.py (1)
15-48: Consider adding `group_sizes` dtype validation.

The function validates `group_sizes.ndim` but not its dtype. The callers (`gmm_forward`, `tgmm_forward`) validate `group_sizes.dtype == jnp.int32`, but adding validation here would provide defense-in-depth and clearer error messages at the source.

Proposed validation

```diff
 assert group_sizes.ndim == 1, "group_sizes must be 1-D"
+assert group_sizes.dtype == jnp.int32, f"group_sizes must be int32, got {group_sizes.dtype}"
 assert m > 0, "m must be positive"
 assert tm > 0, "tm must be positive"
```

🤖 Prompt for AI Agents: Verify the finding against the current code and only fix it if needed. In `tops/ops/gmm/metadata.py` around lines 15-48, add a dtype check for group_sizes inside make_group_metadata after the existing ndim assert, asserting group_sizes.dtype == jnp.int32 with a clear message, so callers get consistent validation at the source and silent type issues downstream are avoided.

docs/plans/2026-04-06-gmm-kernel-impl.md (2)
158-166: Note: `_make_gmm_inputs` helper defined in plan but not used in `test_cpu_ref.py`.

The plan defines a `_make_gmm_inputs` helper function for generating random test inputs, but the actual `tests/ops/gmm/test_cpu_ref.py` doesn't include or use this helper. This is fine since the test file uses inline test data for clarity, but the plan could be updated to match if desired.

🤖 Prompt for AI Agents: Verify the finding against the current code and only fix it if needed. In `docs/plans/2026-04-06-gmm-kernel-impl.md` around lines 158-166, either remove the unused `_make_gmm_inputs` helper from the plan or update `test_cpu_ref.py` to import and use it, so the plan and the test stay consistent.
1-13: Minor markdown formatting issues.

Static analysis flagged two markdown issues:
- Line 13: Heading jumps from `#` (h1) to `###` (h3), skipping h2
- Line 1290: Code block lacks a language specifier

Suggested fixes

```diff
-### Task 1: CPU Reference Implementation
+## Task 1: CPU Reference Implementation
```

At line 1290, add a language tag (e.g., `text`) to the opening fence of the block containing `tops/ops/gmm/metadata.py <- no internal deps`.

🤖 Prompt for AI Agents: Verify the finding against the current code and only fix it if needed. In `docs/plans/2026-04-06-gmm-kernel-impl.md` around lines 1-13, change the `###` heading that immediately follows the top-level h1 to `##` so headings don't skip a level, and add a language specifier to the fenced code block at line 1290.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 65d1de2e-80e9-49c6-a742-fd46a7f97c8e
📒 Files selected for processing (12)
- docs/plans/2026-04-06-gmm-kernel-design.md
- docs/plans/2026-04-06-gmm-kernel-impl.md
- tests/ops/gmm/__init__.py
- tests/ops/gmm/test_cpu_ref.py
- tests/ops/gmm/test_gmm_tpu.py
- tests/ops/gmm/test_metadata.py
- tops/cpu/ops/gmm/__init__.py
- tops/cpu/ops/gmm/naive.py
- tops/ops/__init__.py
- tops/ops/gmm/__init__.py
- tops/ops/gmm/gmm.py
- tops/ops/gmm/metadata.py
```python
out[start_i:end_i, :] = lhs[start_i:end_i, :] @ rhs[i, :, :]
```
Add language identifiers to these fenced blocks.
markdownlint is already flagging these examples under MD040. Adding text/python language tags will clear the warning and improve rendering.
Also applies to: 19-21, 58-61, 92-95, 110-123
🧰 Tools
🪛 markdownlint-cli2 (0.22.0)
[warning] 14-14: Fenced code blocks should have a language specified
(MD040, fenced-code-language)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@docs/plans/2026-04-06-gmm-kernel-design.md` around lines 14 - 16, The fenced
code examples (e.g., the block containing the expression "out[start_i:end_i, :]
= lhs[start_i:end_i, :] @ rhs[i, :, :]" and the other similar fenced blocks)
lack language identifiers causing markdownlint MD040; update each
triple-backtick fence to include an explicit language tag such as ```python (or
```text if it's plain output) so the blocks are properly classified and
rendered; ensure you update all occurrences mentioned (the blocks around the
example expression and the other fenced regions referenced) to use the same
convention.
```markdown
### Group Metadata (computed on host, passed via scalar prefetch)

`make_group_metadata(group_sizes, m, tm)` produces:
- `group_offsets`: CSR-style cumulative row offsets, rounded to tm boundaries
- `group_ids`: maps each active m-tile index to its group
- `m_tile_ids`: maps each active m-tile index to its row-tile offset within the group
```
The metadata contract here is out of sync with make_group_metadata.
In tops/ops/gmm/metadata.py, group_offsets stays as the plain cumulative sum of unpadded group_sizes, and m_tile_ids are global row-tile ids. Describing them here as tm-rounded offsets and per-group tile offsets documents a different scheduler than the one in code.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@docs/plans/2026-04-06-gmm-kernel-design.md` around lines 67 - 73, The
docstring in the design doc is inconsistent with the actual implementation in
tops/ops/gmm/metadata.py: update the description of make_group_metadata to match
the code (or vice versa). Specifically, state that group_offsets is the plain
cumulative sum of unpadded group_sizes (not tm-rounded offsets) and that
m_tile_ids are global row-tile ids (not per-group offsets); keep group_ids
semantics as in the code. Reference make_group_metadata, group_offsets,
group_ids, and m_tile_ids when making the doc edit so the doc and
tops/ops/gmm/metadata.py agree.
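For concreteness, the contract this comment describes (plain cumulative offsets, global row-tile ids) can be sketched as an illustrative NumPy model. This mirrors the described semantics only; it is not the repository's `make_group_metadata`, and its handling of empty groups is an assumption.

```python
import numpy as np

def make_group_metadata_sketch(group_sizes, m, tm):
    """Illustrative metadata per the reviewer's description:
    - group_offsets: plain cumulative sum of the unpadded group_sizes
    - group_ids:     group owning each active m-tile (one entry per visit)
    - m_tile_ids:    GLOBAL row-tile id of each active m-tile
    """
    group_offsets = np.concatenate([[0], np.cumsum(group_sizes)])
    group_ids, m_tile_ids = [], []
    for tile in range(m // tm):  # global row-tile ids
        lo, hi = tile * tm, (tile + 1) * tm
        for g in range(len(group_sizes)):
            # group g visits this tile if its row range overlaps [lo, hi)
            if group_offsets[g] < hi and group_offsets[g + 1] > lo:
                group_ids.append(g)
                m_tile_ids.append(tile)
    return group_offsets, np.array(group_ids), np.array(m_tile_ids)
```

A tile that straddles a group boundary is visited once per overlapping group, which is why `group_ids`/`m_tile_ids` can be longer than `m // tm`.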
tops/ops/gmm/gmm.py (Outdated)

```python
assert lhs.ndim == 2, f"lhs must be 2D, got {lhs.ndim}D"
assert rhs.ndim == 3, f"rhs must be 3D, got {rhs.ndim}D"
assert group_sizes.ndim == 1, f"group_sizes must be 1D, got {group_sizes.ndim}D"

if interpret is None:
    interpret = get_interpret()
if preferred_element_type is None:
    preferred_element_type = lhs.dtype

m, k_lhs = lhs.shape
num_groups = rhs.shape[0]

if transpose_rhs:
    n, k_rhs = rhs.shape[1], rhs.shape[2]
else:
    k_rhs, n = rhs.shape[1], rhs.shape[2]
assert k_lhs == k_rhs, f"lhs K ({k_lhs}) must match rhs K ({k_rhs})"

k = k_lhs
tm, tk, tn = _validate_tiling(tiling, m, k, n)
```
Validate the group partition before building metadata.
Both entry points only validate rank/K compatibility today. If group_sizes is non-integer/negative, does not sum to m, or—gmm_forward only—doesn't match rhs.shape[0], the scheduler can leave rows unwritten or address the wrong expert buffer instead of failing fast. Please centralize those invariants in a shared validator before launching the kernels. As per coding guidelines, "All public functions must enforce strict constraints and validation on the shape and types of input variables using assert instructions (or utilities like assert_shape_or_none from tops.utils) before executing the main logic"
Also applies to: 460-476
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@tops/ops/gmm/gmm.py` around lines 221 - 240, Add a centralized validator
(e.g., validate_group_partition) and call it from both gmm entry points before
any kernel/tiling work; the validator should assert group_sizes.ndim == 1, all
entries are integer and >= 0, group_sizes.sum() == m, and in gmm_forward also
assert group_sizes.shape[0] == rhs.shape[0] (or conversely that num_groups ==
group_sizes.shape[0]) so we fail fast on invalid partitions; place calls to this
validator alongside the existing asserts (near where lhs/rhs shapes and k are
validated) and keep using assert-style checks to satisfy the public-function
validation requirement.
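A sketch of the centralized validator this comment asks for (the name, messages, and exact checks are assumptions based on the review text, not repository code):

```python
import jax.numpy as jnp

def validate_group_partition(group_sizes, m, num_groups=None):
    """Hypothetical shared validator: fail fast on invalid partitions."""
    assert group_sizes.ndim == 1, "group_sizes must be 1-D"
    assert jnp.issubdtype(group_sizes.dtype, jnp.integer), "group_sizes must be integer"
    assert int(group_sizes.min()) >= 0, "group_sizes entries must be non-negative"
    assert int(group_sizes.sum()) == m, f"group_sizes must sum to m ({m})"
    if num_groups is not None:
        # In gmm_forward, num_groups would come from rhs.shape[0].
        assert group_sizes.shape[0] == num_groups, "need one size per expert"
```

Note the caveat: the `sum`/`min` checks need concrete values, so a validator like this must run before tracing (the rank and length checks remain safe under `jit` since they only touch static shape information).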
tops/ops/gmm/gmm.py (Outdated)

```python
drhs = tgmm_forward(
    lhs,
    grad,
    group_sizes,
    tiling=tiling,
    preferred_element_type=preferred_element_type,
).astype(rhs.dtype)
```
Transpose drhs in the transpose_rhs=True backward path.
tgmm_forward(lhs, grad, ...) always returns [num_groups, k, n], but the primal rhs is [num_groups, n, k] when transpose_rhs=True. The custom VJP therefore returns a cotangent with the wrong layout here: rectangular cases will fail, and square cases silently get the transposed weight gradient. Please swap the last two axes before casting back to rhs.dtype; also add a transpose_rhs=True rectangular gradient case so this path stays covered.
💡 Suggested fix

```diff
 drhs = tgmm_forward(
     lhs,
     grad,
     group_sizes,
     tiling=tiling,
     preferred_element_type=preferred_element_type,
-).astype(rhs.dtype)
+)
+if transpose_rhs:
+    drhs = jnp.swapaxes(drhs, -1, -2)
+drhs = drhs.astype(rhs.dtype)
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@tops/ops/gmm/gmm.py` around lines 628 - 634, The backward path builds drhs by
calling tgmm_forward(lhs, grad, ...) which returns shape [num_groups, k, n] but
when transpose_rhs=True the primal rhs is [num_groups, n, k], so before casting
drhs to rhs.dtype you must transpose the last two axes (swap axis 1 and 2) to
produce [num_groups, n, k]; update the code around the drhs assignment in gmm.py
to do this conditional transpose when transpose_rhs is True and then
astype(rhs.dtype), and add a rectangular test case exercising transpose_rhs=True
(non-square n!=k) so the transposed gradient layout is validated.
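The layout argument can be checked numerically, independent of the kernels (an illustrative NumPy check with a rectangular n != k case):

```python
import numpy as np

rng = np.random.default_rng(0)
m, n, k = 3, 4, 2                      # rectangular: n != k
lhs = rng.standard_normal((m, k))
rhs = rng.standard_normal((n, k))      # transpose_rhs=True layout: [n, k]
grad = rng.standard_normal((m, n))     # upstream cotangent of out = lhs @ rhs.T

# For out = lhs @ rhs.T, the cotangent w.r.t. rhs is grad.T @ lhs: shape [n, k].
# A tgmm-style lhs.T @ grad produces [k, n], so its last two axes must be
# swapped before it can serve as drhs -- exactly the fix suggested above.
tgmm_out = lhs.T @ grad                # [k, n]
drhs = grad.T @ lhs                    # [n, k], correct cotangent layout
assert tgmm_out.shape == (k, n) and drhs.shape == (n, k)
assert np.allclose(tgmm_out.T, drhs)
```

With square n == k the shapes coincide and the bug would only show up as a silently transposed gradient, which is why the rectangular test case matters.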
Design and implementation plan documents are not needed in the final codebase — the code and tests serve as the authoritative reference. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
The docstring incorrectly stated rhs shape as [num_groups, k, n] when transpose_rhs=True. The actual input shape is [num_groups, n, k], with each slice transposed to [k, n] before matmul. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
… in refs Remove all Pallas TPU kernel code (gmm.py, metadata.py) and their tests (test_gmm_tpu.py, test_metadata.py). Retain CPU reference implementations as the ground truth. Change naive refs from "cast to f32 then multiply" to "bf16 multiply with f32 accumulation" via lax.dot(preferred_element_type=f32), matching TPU MXU semantics. Output is now f32 directly instead of casting back to input dtype. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
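The bf16-multiply / f32-accumulate convention this commit describes can be illustrated as follows (a sketch of the idea, not the repository's `naive.py`):

```python
import jax.numpy as jnp
from jax import lax

def gmm_ref_sketch(lhs, rhs, group_sizes):
    """Per-group matmul: bf16 multiply with f32 accumulation, f32 output.

    lax.dot with preferred_element_type=f32 mirrors TPU MXU semantics:
    operands stay bf16, the accumulator (and output) are float32.
    """
    outs, start = [], 0
    for g, size in enumerate(group_sizes):  # concrete sizes: reference runs eagerly
        block = lax.dot(lhs[start:start + size], rhs[g],
                        preferred_element_type=jnp.float32)
        outs.append(block)
        start += size
    return jnp.concatenate(outs, axis=0)  # f32, not cast back to input dtype
```

Casting to f32 *before* the multiply would hide bf16 rounding of the operands; this form keeps the multiply in bf16 so the reference matches what the hardware actually computes.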
Change all test inputs from float32 to bfloat16 to match the intended usage pattern (bf16 mul + f32 accumulation). Expected values remain f32 as the reference now outputs f32. Tolerances relaxed to 1e-2. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Implements _gmm_impl and _tgmm_impl with scan-based grouped matmul using dynamic_slice for TPU/JIT compatibility. bf16 inputs with f32 accumulation to match TPU MXU semantics. Adds tests verifying JAX implementations match CPU reference outputs. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
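A minimal sketch of the scan + dynamic_slice pattern the commit describes (the function name and the write-then-overwrite scheme are assumptions based on the commit message, not the actual `_gmm_impl`):

```python
import jax
import jax.numpy as jnp

def gmm_scan_sketch(lhs, rhs, group_sizes):
    """Grouped matmul via lax.scan over groups, JIT-compatible.

    dynamic_slice / dynamic_update_slice on zero-padded [2m, ...] buffers
    avoid out-of-bounds start clamping; lax.dot with
    preferred_element_type=f32 gives bf16-mul / f32-accumulate semantics.
    """
    m, k = lhs.shape
    n = rhs.shape[2]
    lhs_padded = jnp.concatenate([lhs, jnp.zeros_like(lhs)], axis=0)  # [2m, k]
    offsets = jnp.concatenate([jnp.zeros(1, jnp.int32),
                               jnp.cumsum(group_sizes, dtype=jnp.int32)])

    def body(out_padded, g):
        start = offsets[g]
        # Fixed-size [m, k] window starting at this group's first row.
        window = jax.lax.dynamic_slice(lhs_padded, (start, 0), (m, k))
        mask = (jnp.arange(m) < group_sizes[g])[:, None]  # rows past the group -> 0
        partial = jax.lax.dot(jnp.where(mask, window, 0), rhs[g],
                              preferred_element_type=jnp.float32)
        # Groups are scanned in order, so later groups overwrite the
        # zero rows this full-window write leaves behind.
        return jax.lax.dynamic_update_slice(out_padded, partial, (start, 0)), None

    out0 = jnp.zeros((2 * m, n), jnp.float32)
    out_padded, _ = jax.lax.scan(body, out0, jnp.arange(group_sizes.shape[0]))
    return out_padded[:m]
```

All shapes here are static, so the whole function traces cleanly under `jax.jit` even though the group boundaries are data-dependent.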
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
🧹 Nitpick comments (4)
tests/ops/gmm/test_gmm_vs_tokamax.py (3)
59-67: Same issue: function call in default argument.

Apply the same fix as `_make_inputs` above.

♻️ Proposed fix

```diff
-def _make_inputs_transposed(case, key=jax.random.PRNGKey(42)):
+def _make_inputs_transposed(case, key=None):
+    if key is None:
+        key = jax.random.PRNGKey(42)
     k1, k2 = jax.random.split(key)
```

🤖 Prompt for AI Agents: Verify the finding against the current code and only fix it if needed. In `tests/ops/gmm/test_gmm_vs_tokamax.py` around lines 59-67, change `_make_inputs_transposed` to accept `key=None` and set `key = jax.random.PRNGKey(42)` inside the body when it is None, as in `_make_inputs`, instead of evaluating the PRNGKey in the default argument at import time.
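The pitfall behind this comment is general Python behavior and easy to demonstrate (an illustrative snippet, unrelated to the JAX specifics):

```python
import time

def stamp(t=time.time()):  # default evaluated once, when `def` runs
    return t

a = stamp()
time.sleep(0.01)
b = stamp()
# Both calls see the same timestamp: the default never refreshes.
assert a == b
```

The `key=None` + in-body default pattern avoids this by deferring evaluation to call time.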
145-163: Consider expanding backward test coverage.

The backward tests only run on `CASES[:2]`, leaving the more complex cases (8-group, varying group sizes) untested for gradients. The comment from past reviews also noted missing `transpose_rhs=True` gradient tests. Consider adding at least one `transpose_rhs=True` backward case to validate the transposed gradient path, especially given the past review flagging a potential layout issue in that path.

🤖 Prompt for AI Agents: Verify the finding against the current code and only fix it if needed. In `tests/ops/gmm/test_gmm_vs_tokamax.py` around lines 145-163, extend the parametrization of test_grad_lhs beyond CASES[:2] (at least the 8-group / varying group-size cases) and add a case where _call_tokamax_gmm and gmm are invoked with transpose_rhs=True, comparing jax.grad of tokamax_loss and jax_loss as in test_grad_lhs.
49-56: Avoid function call in default argument.

`jax.random.PRNGKey(42)` is evaluated once at function definition time, not at call time. While deterministic here, this pattern can cause subtle bugs if the key were ever mutated or if the function is called multiple times expecting fresh state. Move the default inside the function body.

♻️ Proposed fix

```diff
-def _make_inputs(case, key=jax.random.PRNGKey(42)):
+def _make_inputs(case, key=None):
+    if key is None:
+        key = jax.random.PRNGKey(42)
     k1, k2 = jax.random.split(key)
     lhs = jax.random.normal(k1, (case["m"], case["k"]), dtype=jnp.bfloat16)
     rhs = jax.random.normal(
         k2, (case["num_groups"], case["k"], case["n"]), dtype=jnp.bfloat16
     )
     gs = jnp.array(case["group_sizes"], dtype=jnp.int32)
     return lhs, rhs, gs
```

🤖 Prompt for AI Agents: Verify the finding against the current code and only fix it if needed. In `tests/ops/gmm/test_gmm_vs_tokamax.py` around lines 49-56, change `_make_inputs` to accept `key=None` and set `key = jax.random.PRNGKey(42)` inside the function body when it is None, keeping the rest of the logic unchanged, so the PRNGKey is not evaluated at definition time.

tops/ops/gmm/gmm.py (1)
41-41: Minor: offset array creation could use explicit dtype.

The `jnp.zeros(1, dtype=jnp.int32)` is good, but `jnp.cumsum(group_sizes)` inherits its dtype from `group_sizes`. This is fine assuming `group_sizes` is always int32, but an explicit cast could guard against unexpected input dtypes.

♻️ Optional: explicit dtype for cumsum

```diff
-offsets = jnp.concatenate([jnp.zeros(1, dtype=jnp.int32), jnp.cumsum(group_sizes)])
+offsets = jnp.concatenate([jnp.zeros(1, dtype=jnp.int32), jnp.cumsum(group_sizes, dtype=jnp.int32)])
```

🤖 Prompt for AI Agents: Verify the finding against the current code and only fix it if needed. In `tops/ops/gmm/gmm.py` at line 41, pass an explicit dtype (jnp.int32) to jnp.cumsum(group_sizes) in the offsets construction so both concatenated parts share the same dtype.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 27d82765-d49b-4c95-b472-1ff6830d93d4
📒 Files selected for processing (5)
- pyproject.toml
- tests/ops/gmm/test_cpu_ref.py
- tests/ops/gmm/test_gmm_vs_tokamax.py
- tops/ops/gmm/__init__.py
- tops/ops/gmm/gmm.py
✅ Files skipped from review due to trivial changes (2)
- pyproject.toml
- tests/ops/gmm/test_cpu_ref.py
🚧 Files skipped from review as they are similar to previous changes (1)
- tops/ops/gmm/__init__.py
Summary

- CPU reference implementations (`gmm_ref`, `tgmm_ref`) with bf16 mul / f32 accumulation
- JAX `gmm()` and `tgmm()` using `lax.scan` + `dynamic_slice` that run on TPU via XLA compilation
- `custom_vjp` for differentiable GMM: forward via GMM, backward dlhs via GMM (transposed), backward drhs via TGMM
- `tokamax>=0.0.12` as optional TPU dependency for comparison testing

Key Design Decisions

- `lax.scan` iterates over groups sequentially; `dynamic_slice` with padded arrays (`[2m, k]`) avoids OOB clamping; `lax.dot` with `preferred_element_type=jnp.float32` for bf16 mul / f32 accumulation
- `dlhs = gmm(grad, rhs, group_sizes, not transpose_rhs)`, `drhs = tgmm(lhs, grad, group_sizes)` (or swapped args when `transpose_rhs=True`)
- `pytest.importorskip` + `jax.default_backend() != "tpu"` guard: gracefully skip on non-TPU environments

Files

- `tops/ops/gmm/gmm.py`: `_gmm_impl`, `_tgmm_impl`, `gmm` (custom_vjp), `tgmm`
- `tops/ops/gmm/__init__.py`: public API exports
- `tops/cpu/ops/gmm/naive.py`: CPU reference implementations
- `tests/ops/gmm/test_cpu_ref.py`: 15 tests (CPU ref + JAX impl + gradients)
- `tests/ops/gmm/test_gmm_vs_tokamax.py`: 16 parametrized tokamax comparison tests
- `pyproject.toml`: `tokamax>=0.0.12` in `[tpu]` extras

Test plan
🤖 Generated with Claude Code