diff --git a/pyproject.toml b/pyproject.toml
index 0a5e44be..7c3e1ab8 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -31,6 +31,7 @@ gpu = [
 tpu = [
     "jax[tpu]>=0.8.1",
     "torch",
+    "tokamax>=0.0.12",
 ]
 profile = [
     "xprof==2.22.0",
diff --git a/tests/ops/gmm/__init__.py b/tests/ops/gmm/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/tests/ops/gmm/test_cpu_ref.py b/tests/ops/gmm/test_cpu_ref.py
new file mode 100644
index 00000000..024ba5d1
--- /dev/null
+++ b/tests/ops/gmm/test_cpu_ref.py
@@ -0,0 +1,222 @@
+"""Verify CPU reference implementations for GMM/TGMM are correct."""
+
+from __future__ import annotations
+
+import sys
+from pathlib import Path
+
+sys.path.insert(0, str(Path(__file__).resolve().parents[3]))
+
+import jax
+import jax.numpy as jnp
+import numpy as np
+import pytest
+
+from tops.cpu.ops.gmm import gmm_ref, tgmm_ref
+from tops.ops.gmm import gmm, tgmm
+
+
+class TestGmmRef:
+    """Test gmm_ref against manual numpy computation."""
+
+    def test_single_group(self):
+        """Single group = standard matmul."""
+        lhs = jnp.array([[1.0, 2.0], [3.0, 4.0]], dtype=jnp.bfloat16)
+        rhs = jnp.array([[[1.0, 0.0], [0.0, 1.0]]], dtype=jnp.bfloat16)
+        gs = jnp.array([2], dtype=jnp.int32)
+        out = gmm_ref(lhs, rhs, gs)
+        expected = jnp.array([[1.0, 2.0], [3.0, 4.0]], dtype=jnp.float32)
+        np.testing.assert_allclose(np.array(out), np.array(expected), atol=1e-2)
+
+    def test_two_groups(self):
+        """Two groups with different weights."""
+        lhs = jnp.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]], dtype=jnp.bfloat16)
+        rhs = jnp.array(
+            [
+                [[2.0, 0.0], [0.0, 2.0]],  # group 0: scale by 2
+                [[0.0, 1.0], [1.0, 0.0]],  # group 1: swap columns
+            ],
+            dtype=jnp.bfloat16,
+        )
+        gs = jnp.array([1, 2], dtype=jnp.int32)
+        out = gmm_ref(lhs, rhs, gs)
+        expected = jnp.array([[2.0, 0.0], [1.0, 0.0], [1.0, 1.0]], dtype=jnp.float32)
+        np.testing.assert_allclose(np.array(out), np.array(expected), atol=1e-2)
+
+    def test_empty_group(self):
+        """A group with size 0 consumes no rows; later groups still compute."""
+        lhs = jnp.array([[1.0, 2.0]], dtype=jnp.bfloat16)
+        rhs = jnp.array(
+            [
+                [[1.0], [1.0]],  # group 0: empty
+                [[1.0], [1.0]],  # group 1: 1 row
+            ],
+            dtype=jnp.bfloat16,
+        )
+        gs = jnp.array([0, 1], dtype=jnp.int32)
+        out = gmm_ref(lhs, rhs, gs)
+        expected = jnp.array([[3.0]], dtype=jnp.float32)
+        np.testing.assert_allclose(np.array(out), np.array(expected), atol=1e-2)
+
+    def test_transpose_rhs(self):
+        """transpose_rhs transposes each rhs[i] before matmul."""
+        lhs = jnp.array([[1.0, 2.0]], dtype=jnp.bfloat16)
+        rhs = jnp.array([[[3.0, 4.0], [5.0, 6.0]]], dtype=jnp.bfloat16)
+        gs = jnp.array([1], dtype=jnp.int32)
+        # Without transpose: lhs [1,2] @ rhs [2,2] = [1*3+2*5, 1*4+2*6] = [13, 16]
+        out_normal = gmm_ref(lhs, rhs, gs)
+        np.testing.assert_allclose(np.array(out_normal), [[13.0, 16.0]], atol=1e-2)
+        # With transpose: lhs [1,2] @ rhs.T [2,2] = [1*3+2*4, 1*5+2*6] = [11, 17]
+        out_transposed = gmm_ref(lhs, rhs, gs, transpose_rhs=True)
+        np.testing.assert_allclose(np.array(out_transposed), [[11.0, 17.0]], atol=1e-2)
+
+
+class TestTgmmRef:
+    """Test tgmm_ref against manual numpy computation."""
+
+    def test_single_group(self):
+        """Single group: lhs^T @ rhs."""
+        lhs = jnp.array([[1.0, 2.0], [3.0, 4.0]], dtype=jnp.bfloat16)
+        rhs = jnp.array([[5.0], [6.0]], dtype=jnp.bfloat16)
+        gs = jnp.array([2], dtype=jnp.int32)
+        out = tgmm_ref(lhs, rhs, gs)
+        # lhs^T [2,2] @ rhs [2,1] = [[1*5+3*6], [2*5+4*6]] = [[23], [34]]
+        expected = jnp.array([[[23.0], [34.0]]], dtype=jnp.float32)
+        np.testing.assert_allclose(np.array(out), np.array(expected), atol=1e-2)
+
+    def test_two_groups(self):
+        """Two groups produce separate outer products."""
+        lhs = jnp.array([[1.0], [2.0], [3.0]], dtype=jnp.bfloat16)
+        rhs = jnp.array([[4.0], [5.0], [6.0]], dtype=jnp.bfloat16)
+        gs = jnp.array([1, 2], dtype=jnp.int32)
+        out = tgmm_ref(lhs, rhs, gs)
+        # Group 0: [1]^T @ [4] = [[4]]
+        # Group 1: [2,3]^T @ [5,6] = [[2*5+3*6]] = [[28]]
+        expected = jnp.array([[[4.0]], [[28.0]]], dtype=jnp.float32)
+        np.testing.assert_allclose(np.array(out), np.array(expected), atol=1e-2)
+
+
+class TestGmmJax:
+    """Test JIT-compilable gmm against CPU reference."""
+
+    def test_single_group(self):
+        """Single group: gmm matches gmm_ref."""
+        lhs = jnp.array([[1.0, 2.0], [3.0, 4.0]], dtype=jnp.bfloat16)
+        rhs = jnp.array([[[1.0, 0.0], [0.0, 1.0]]], dtype=jnp.bfloat16)
+        gs = jnp.array([2], dtype=jnp.int32)
+        out = gmm(lhs, rhs, gs)
+        ref = gmm_ref(lhs, rhs, gs)
+        np.testing.assert_allclose(np.array(out), np.array(ref), atol=1e-2)
+
+    def test_two_groups(self):
+        """Two groups: gmm matches gmm_ref."""
+        lhs = jnp.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]], dtype=jnp.bfloat16)
+        rhs = jnp.array(
+            [
+                [[2.0, 0.0], [0.0, 2.0]],
+                [[0.0, 1.0], [1.0, 0.0]],
+            ],
+            dtype=jnp.bfloat16,
+        )
+        gs = jnp.array([1, 2], dtype=jnp.int32)
+        out = gmm(lhs, rhs, gs)
+        ref = gmm_ref(lhs, rhs, gs)
+        np.testing.assert_allclose(np.array(out), np.array(ref), atol=1e-2)
+
+    def test_transpose_rhs(self):
+        """transpose_rhs: gmm matches gmm_ref."""
+        lhs = jnp.array([[1.0, 2.0]], dtype=jnp.bfloat16)
+        rhs = jnp.array([[[3.0, 4.0], [5.0, 6.0]]], dtype=jnp.bfloat16)
+        gs = jnp.array([1], dtype=jnp.int32)
+        out = gmm(lhs, rhs, gs, transpose_rhs=True)
+        ref = gmm_ref(lhs, rhs, gs, transpose_rhs=True)
+        np.testing.assert_allclose(np.array(out), np.array(ref), atol=1e-2)
+
+
+class TestTgmmJax:
+    """Test JIT-compilable tgmm against CPU reference."""
+
+    def test_single_group(self):
+        """Single group: tgmm matches tgmm_ref."""
+        lhs = jnp.array([[1.0, 2.0], [3.0, 4.0]], dtype=jnp.bfloat16)
+        rhs = jnp.array([[5.0], [6.0]], dtype=jnp.bfloat16)
+        gs = jnp.array([2], dtype=jnp.int32)
+        out = tgmm(lhs, rhs, gs)
+        ref = tgmm_ref(lhs, rhs, gs)
+        np.testing.assert_allclose(np.array(out), np.array(ref), atol=1e-2)
+
+    def test_two_groups(self):
+        """Two groups: tgmm matches tgmm_ref."""
+        lhs = jnp.array([[1.0], [2.0], [3.0]], dtype=jnp.bfloat16)
+        rhs = jnp.array([[4.0], [5.0], [6.0]], dtype=jnp.bfloat16)
+        gs = jnp.array([1, 2], dtype=jnp.int32)
+        out = tgmm(lhs, rhs, gs)
+        ref = tgmm_ref(lhs, rhs, gs)
+        np.testing.assert_allclose(np.array(out), np.array(ref), atol=1e-2)
+
+
+class TestGmmGrad:
+    """Test gmm gradient via custom_vjp."""
+
+    def test_grad_lhs(self):
+        """Gradient w.r.t. lhs is correct."""
+        lhs = jnp.array([[1.0, 2.0], [3.0, 4.0]], dtype=jnp.bfloat16)
+        rhs = jnp.array([[[1.0, 0.5], [0.5, 1.0]]], dtype=jnp.bfloat16)
+        gs = jnp.array([2], dtype=jnp.int32)
+
+        def loss_fn(x):
+            return gmm(x, rhs, gs).sum()
+
+        grad = jax.grad(loss_fn)(lhs)
+        assert grad.shape == lhs.shape
+        assert not jnp.any(jnp.isnan(grad))
+
+    def test_grad_rhs(self):
+        """Gradient w.r.t. rhs is correct."""
+        lhs = jnp.array([[1.0, 2.0], [3.0, 4.0]], dtype=jnp.bfloat16)
+        rhs = jnp.array([[[1.0, 0.5], [0.5, 1.0]]], dtype=jnp.bfloat16)
+        gs = jnp.array([2], dtype=jnp.int32)
+
+        def loss_fn(w):
+            return gmm(lhs, w, gs).sum()
+
+        grad = jax.grad(loss_fn)(rhs)
+        assert grad.shape == rhs.shape
+        assert not jnp.any(jnp.isnan(grad))
+
+    def test_grad_both(self):
+        """Gradient w.r.t. both lhs and rhs, two groups."""
+        lhs = jnp.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]], dtype=jnp.bfloat16)
+        rhs = jnp.array(
+            [[[2.0, 0.0], [0.0, 2.0]], [[0.0, 1.0], [1.0, 0.0]]],
+            dtype=jnp.bfloat16,
+        )
+        gs = jnp.array([1, 2], dtype=jnp.int32)
+
+        def loss_fn(x, w):
+            return gmm(x, w, gs).sum()
+
+        grad_lhs, grad_rhs = jax.grad(loss_fn, argnums=(0, 1))(lhs, rhs)
+        assert grad_lhs.shape == lhs.shape
+        assert grad_rhs.shape == rhs.shape
+        assert not jnp.any(jnp.isnan(grad_lhs))
+        assert not jnp.any(jnp.isnan(grad_rhs))
+
+    def test_grad_transpose_rhs(self):
+        """Gradient with transpose_rhs=True."""
+        lhs = jnp.array([[1.0, 2.0]], dtype=jnp.bfloat16)
+        rhs = jnp.array([[[3.0, 4.0], [5.0, 6.0]]], dtype=jnp.bfloat16)
+        gs = jnp.array([1], dtype=jnp.int32)
+
+        def loss_fn(x, w):
+            return gmm(x, w, gs, transpose_rhs=True).sum()
+
+        grad_lhs, grad_rhs = jax.grad(loss_fn, argnums=(0, 1))(lhs, rhs)
+        assert grad_lhs.shape == lhs.shape
+        assert grad_rhs.shape == rhs.shape
+        assert not jnp.any(jnp.isnan(grad_lhs))
+        assert not jnp.any(jnp.isnan(grad_rhs))
+
+
+if __name__ == "__main__":
+    pytest.main([__file__, "-v"])
diff --git a/tests/ops/gmm/test_gmm_vs_tokamax.py b/tests/ops/gmm/test_gmm_vs_tokamax.py
new file mode 100644
index 00000000..5f9b1d74
--- /dev/null
+++ b/tests/ops/gmm/test_gmm_vs_tokamax.py
@@ -0,0 +1,187 @@
+"""Compare JAX GMM against tokamax GMM in bf16 on TPU.
+
+Tests skip automatically if tokamax is not installed or not on TPU.
+"""
+
+from __future__ import annotations
+
+import sys
+from pathlib import Path
+
+sys.path.insert(0, str(Path(__file__).resolve().parents[3]))
+
+import jax
+import jax.numpy as jnp
+import numpy as np
+import pytest
+
+from tops.ops.gmm import gmm
+from tops.cpu.ops.gmm import gmm_ref
+
+# Skip entire module if tokamax is not installed.
+tokamax_kernel = pytest.importorskip(
+    "tokamax._src.ops.ragged_dot.pallas_mosaic_tpu_kernel"
+)
+
+# Skip if not running on TPU.
+pytestmark = pytest.mark.skipif(
+    jax.default_backend() != "tpu",
+    reason="tokamax comparison requires TPU",
+)
+
+
+# ---------------------------------------------------------------------------
+# Test cases: TPU-aligned shapes (multiples of 128)
+# ---------------------------------------------------------------------------
+
+CASES = [
+    dict(m=128, k=128, n=128, num_groups=1, group_sizes=[128]),
+    dict(m=256, k=128, n=128, num_groups=4, group_sizes=[64, 64, 64, 64]),
+    dict(m=512, k=256, n=256, num_groups=8, group_sizes=[64] * 8),
+    dict(m=384, k=128, n=128, num_groups=4, group_sizes=[128, 64, 128, 64]),
+]
+
+
+def _case_id(case):
+    return f"m{case['m']}_k{case['k']}_n{case['n']}_g{case['num_groups']}"
+
+
+def _make_inputs(case, key=None):
+    # None sentinel: a PRNGKey default would run JAX work at import time (B008).
+    k1, k2 = jax.random.split(jax.random.PRNGKey(42) if key is None else key)
+    lhs = jax.random.normal(k1, (case["m"], case["k"]), dtype=jnp.bfloat16)
+    rhs = jax.random.normal(
+        k2, (case["num_groups"], case["k"], case["n"]), dtype=jnp.bfloat16
+    )
+    gs = jnp.array(case["group_sizes"], dtype=jnp.int32)
+    return lhs, rhs, gs
+
+
+def _make_inputs_transposed(case, key=None):
+    # None sentinel: a PRNGKey default would run JAX work at import time (B008).
+    k1, k2 = jax.random.split(jax.random.PRNGKey(42) if key is None else key)
+    lhs = jax.random.normal(k1, (case["m"], case["k"]), dtype=jnp.bfloat16)
+    # For transpose_rhs, rhs shape is [num_groups, n, k]
+    rhs = jax.random.normal(
+        k2, (case["num_groups"], case["n"], case["k"]), dtype=jnp.bfloat16
+    )
+    gs = jnp.array(case["group_sizes"], dtype=jnp.int32)
+    return lhs, rhs, gs
+
+
+def _call_tokamax_gmm(lhs, rhs, gs, transpose_rhs=False):
+    """Call tokamax GMM with the standard calling convention."""
+    return tokamax_kernel.gmm(
+        lhs=lhs,
+        rhs=rhs,
+        group_sizes=gs,
+        precision=jax.lax.Precision.DEFAULT,
+        out_dtype=jnp.float32,
+        tiling=(128, 128, 128),
+        transpose_rhs=transpose_rhs,
+        interpret=False,
+    )
+
+
+# ---------------------------------------------------------------------------
+# Forward tests
+# ---------------------------------------------------------------------------
+
+
+class TestGmmForwardVsTokamax:
+    """Compare forward output of JAX gmm vs tokamax gmm."""
+
+    @pytest.mark.parametrize("case", CASES, ids=[_case_id(c) for c in CASES])
+    def test_forward_bf16(self, case):
+        lhs, rhs, gs = _make_inputs(case)
+        jax_out = gmm(lhs, rhs, gs)
+        tokamax_out = _call_tokamax_gmm(lhs, rhs, gs)
+        np.testing.assert_allclose(
+            np.array(jax_out),
+            np.array(tokamax_out),
+            atol=1e-2,
+            rtol=1e-2,
+        )
+
+    @pytest.mark.parametrize("case", CASES, ids=[_case_id(c) for c in CASES])
+    def test_forward_transpose_rhs(self, case):
+        lhs, rhs, gs = _make_inputs_transposed(case)
+        jax_out = gmm(lhs, rhs, gs, transpose_rhs=True)
+        tokamax_out = _call_tokamax_gmm(lhs, rhs, gs, transpose_rhs=True)
+        np.testing.assert_allclose(
+            np.array(jax_out),
+            np.array(tokamax_out),
+            atol=1e-2,
+            rtol=1e-2,
+        )
+
+    @pytest.mark.parametrize("case", CASES, ids=[_case_id(c) for c in CASES])
+    def test_forward_vs_cpu_ref(self, case):
+        """Both JAX and tokamax should match the CPU reference."""
+        lhs, rhs, gs = _make_inputs(case)
+        jax_out = gmm(lhs, rhs, gs)
+        tokamax_out = _call_tokamax_gmm(lhs, rhs, gs)
+        ref_out = gmm_ref(lhs, rhs, gs)
+        np.testing.assert_allclose(
+            np.array(jax_out),
+            np.array(ref_out),
+            atol=1e-2,
+            rtol=1e-2,
+        )
+        np.testing.assert_allclose(
+            np.array(tokamax_out),
+            np.array(ref_out),
+            atol=1e-2,
+            rtol=1e-2,
+        )
+
+
+# ---------------------------------------------------------------------------
+# Backward tests
+# ---------------------------------------------------------------------------
+
+
+class TestGmmBackwardVsTokamax:
+    """Compare backward gradients of JAX gmm vs tokamax gmm."""
+
+    @pytest.mark.parametrize("case", CASES[:2], ids=[_case_id(c) for c in CASES[:2]])
+    def test_grad_lhs(self, case):
+        """dlhs should match between JAX and tokamax."""
+        lhs, rhs, gs = _make_inputs(case)
+
+        def jax_loss(x):
+            return gmm(x, rhs, gs).sum()
+
+        def tokamax_loss(x):
+            return _call_tokamax_gmm(x, rhs, gs).sum()
+
+        jax_grad = jax.grad(jax_loss)(lhs)
+        tokamax_grad = jax.grad(tokamax_loss)(lhs)
+        np.testing.assert_allclose(
+            np.array(jax_grad),
+            np.array(tokamax_grad),
+            atol=1e-1,
+            rtol=1e-1,
+        )
+
+    @pytest.mark.parametrize("case", CASES[:2], ids=[_case_id(c) for c in CASES[:2]])
+    def test_grad_rhs(self, case):
+        """drhs should match between JAX and tokamax."""
+        lhs, rhs, gs = _make_inputs(case)
+
+        def jax_loss(w):
+            return gmm(lhs, w, gs).sum()
+
+        def tokamax_loss(w):
+            return _call_tokamax_gmm(lhs, w, gs).sum()
+
+        jax_grad = jax.grad(jax_loss)(rhs)
+        tokamax_grad = jax.grad(tokamax_loss)(rhs)
+        np.testing.assert_allclose(
+            np.array(jax_grad),
+            np.array(tokamax_grad),
+            atol=1e-1,
+            rtol=1e-1,
+        )
+
+
+if __name__ == "__main__":
+    pytest.main([__file__, "-v"])
diff --git a/tops/cpu/ops/gmm/__init__.py b/tops/cpu/ops/gmm/__init__.py
new file mode 100644
index 00000000..f8379c46
--- /dev/null
+++ b/tops/cpu/ops/gmm/__init__.py
@@ -0,0 +1,3 @@
+from .naive import gmm_ref, tgmm_ref
+
+__all__ = ["gmm_ref", "tgmm_ref"]
diff --git a/tops/cpu/ops/gmm/naive.py b/tops/cpu/ops/gmm/naive.py
new file mode 100644
index 00000000..8b146346
--- /dev/null
+++ b/tops/cpu/ops/gmm/naive.py
@@ -0,0 +1,108 @@
+"""Pure JAX CPU reference for Grouped Matrix Multiplication.
+
+Uses bf16 multiplication with f32 accumulation to match TPU MXU semantics.
+"""
+
+import jax
+import jax.numpy as jnp
+from jax import lax
+
+from tops.cpu.ops import cpu_reference
+
+
+@cpu_reference
+def gmm_ref(
+    lhs: jax.Array,
+    rhs: jax.Array,
+    group_sizes: jax.Array,
+    transpose_rhs: bool = False,
+) -> jax.Array:
+    """Grouped matrix multiplication reference implementation.
+
+    For each group i with rows [start_i, end_i):
+        out[start_i:end_i, :] = lhs[start_i:end_i, :] @ rhs[i]
+
+    Uses bf16 multiplication with f32 accumulation (``preferred_element_type``).
+
+    Args:
+        lhs: [m, k] input activations.
+        rhs: [num_groups, k, n] per-group weights.
+            If transpose_rhs=True, rhs is [num_groups, n, k] and each slice
+            is transposed to [k, n] before matmul.
+        group_sizes: [num_groups] int32, number of rows per group.
+        transpose_rhs: If True, transpose each rhs[i] before matmul.
+
+    Returns:
+        [m, output_dim] where output_dim = rhs.shape[2] if not transpose_rhs
+        else rhs.shape[1].
+    """
+    assert lhs.ndim == 2, f"lhs must be 2D, got {lhs.ndim}D"
+    assert rhs.ndim == 3, f"rhs must be 3D, got {rhs.ndim}D"
+    assert group_sizes.ndim == 1, f"group_sizes must be 1D, got {group_sizes.ndim}D"
+
+    m = lhs.shape[0]
+    num_groups = rhs.shape[0]
+    n = rhs.shape[1] if transpose_rhs else rhs.shape[2]
+
+    out = jnp.zeros((m, n), dtype=jnp.float32)
+    start = 0
+    for i in range(num_groups):
+        size = int(group_sizes[i])
+        end = start + size
+        if size > 0:
+            lhs_slice = lhs[start:end]
+            rhs_mat = rhs[i]
+            if transpose_rhs:
+                rhs_mat = rhs_mat.T
+            out = out.at[start:end].set(
+                lax.dot(lhs_slice, rhs_mat, preferred_element_type=jnp.float32)
+            )
+        start = end
+    return out
+
+
+@cpu_reference
+def tgmm_ref(
+    lhs: jax.Array,
+    rhs: jax.Array,
+    group_sizes: jax.Array,
+) -> jax.Array:
+    """Transposed grouped matrix multiplication reference implementation.
+
+    For each group i with rows [start_i, end_i):
+        out[i] = lhs[start_i:end_i, :]^T @ rhs[start_i:end_i, :]
+
+    Uses bf16 multiplication with f32 accumulation (``preferred_element_type``).
+
+    Args:
+        lhs: [m, k] input activations.
+        rhs: [m, n] gradient or second operand.
+        group_sizes: [num_groups] int32, number of rows per group.
+
+    Returns:
+        [num_groups, k, n] per-group outer products.
+    """
+    assert lhs.ndim == 2, f"lhs must be 2D, got {lhs.ndim}D"
+    assert rhs.ndim == 2, f"rhs must be 2D, got {rhs.ndim}D"
+    assert group_sizes.ndim == 1, f"group_sizes must be 1D, got {group_sizes.ndim}D"
+    assert lhs.shape[0] == rhs.shape[0], (
+        f"lhs and rhs must have same m dim, got {lhs.shape[0]} vs {rhs.shape[0]}"
+    )
+
+    k = lhs.shape[1]
+    n = rhs.shape[1]
+    num_groups = group_sizes.shape[0]
+
+    out = jnp.zeros((num_groups, k, n), dtype=jnp.float32)
+    start = 0
+    for i in range(num_groups):
+        size = int(group_sizes[i])
+        end = start + size
+        if size > 0:
+            lhs_slice = lhs[start:end]
+            rhs_slice = rhs[start:end]
+            out = out.at[i].set(
+                lax.dot(lhs_slice.T, rhs_slice, preferred_element_type=jnp.float32)
+            )
+        start = end
+    return out
diff --git a/tops/ops/gmm/__init__.py b/tops/ops/gmm/__init__.py
new file mode 100644
index 00000000..1b99b467
--- /dev/null
+++ b/tops/ops/gmm/__init__.py
@@ -0,0 +1,5 @@
+"""Public API for grouped matrix multiplication."""
+
+from .gmm import gmm, tgmm
+
+__all__ = ["gmm", "tgmm"]
diff --git a/tops/ops/gmm/gmm.py b/tops/ops/gmm/gmm.py
new file mode 100644
index 00000000..24082a9d
--- /dev/null
+++ b/tops/ops/gmm/gmm.py
@@ -0,0 +1,195 @@
+"""JIT-compilable Grouped Matrix Multiplication for TPU.
+
+Uses lax.scan + dynamic_slice for TPU-compatible grouped matmul.
+bf16 multiplication with f32 accumulation to match TPU MXU semantics.
+"""
+
+import functools
+
+import jax
+import jax.numpy as jnp
+from jax import lax
+
+
+def _gmm_impl(
+    lhs: jax.Array,
+    rhs: jax.Array,
+    group_sizes: jax.Array,
+    transpose_rhs: bool = False,
+) -> jax.Array:
+    """Core scan-based grouped matmul.
+
+    For each group i with rows [start_i, end_i):
+        out[start_i:end_i, :] = lhs[start_i:end_i, :] @ rhs[i]
+
+    Uses lax.scan over groups with dynamic_slice for JIT/TPU compatibility.
+    bf16 inputs, f32 accumulation via lax.dot preferred_element_type.
+
+    Args:
+        lhs: [m, k] input activations.
+        rhs: [num_groups, k, n] per-group weights.
+            If transpose_rhs=True, rhs is [num_groups, n, k].
+        group_sizes: [num_groups] int32, number of rows per group.
+        transpose_rhs: If True, transpose each rhs[i] before matmul.
+
+    Returns:
+        [m, n] output in float32.
+    """
+    m, k = lhs.shape
+    n = rhs.shape[1] if transpose_rhs else rhs.shape[2]
+
+    offsets = jnp.concatenate([jnp.zeros(1, dtype=jnp.int32), jnp.cumsum(group_sizes)])
+
+    # Pad lhs so dynamic_slice(start, m) never triggers OOB clamping.
+    lhs_padded = jnp.pad(lhs, ((0, m), (0, 0)))  # [2m, k]
+
+    def body(carry, i):
+        out_padded = carry  # [2m, n]
+        start = offsets[i]
+        size = group_sizes[i]
+
+        lhs_slice = lax.dynamic_slice(lhs_padded, (start, 0), (m, k))
+        rhs_mat = rhs[i]
+        if transpose_rhs:
+            rhs_mat = rhs_mat.T
+
+        prod = lax.dot(lhs_slice, rhs_mat, preferred_element_type=jnp.float32)
+
+        # Zero out rows beyond this group's size.
+        valid = jnp.arange(m) < size
+        prod = prod * valid[:, None]
+
+        out_padded = lax.dynamic_update_slice(out_padded, prod, (start, 0))
+        return out_padded, None
+
+    out_padded = jnp.zeros((2 * m, n), dtype=jnp.float32)
+    out_padded, _ = lax.scan(body, out_padded, jnp.arange(offsets.shape[0] - 1))
+    return out_padded[:m]
+
+
+def _tgmm_impl(
+    lhs: jax.Array,
+    rhs: jax.Array,
+    group_sizes: jax.Array,
+) -> jax.Array:
+    """Core scan-based transposed grouped matmul.
+
+    For each group i with rows [start_i, end_i):
+        out[i] = lhs[start_i:end_i, :]^T @ rhs[start_i:end_i, :]
+
+    Uses lax.scan over groups with dynamic_slice for JIT/TPU compatibility.
+    bf16 inputs, f32 accumulation via lax.dot preferred_element_type.
+
+    Args:
+        lhs: [m, k] input activations.
+        rhs: [m, n] second operand.
+        group_sizes: [num_groups] int32, number of rows per group.
+
+    Returns:
+        [num_groups, k, n] per-group products in float32.
+    """
+    m, k = lhs.shape
+    n = rhs.shape[1]
+
+    offsets = jnp.concatenate([jnp.zeros(1, dtype=jnp.int32), jnp.cumsum(group_sizes)])
+
+    lhs_padded = jnp.pad(lhs, ((0, m), (0, 0)))  # [2m, k]
+    rhs_padded = jnp.pad(rhs, ((0, m), (0, 0)))  # [2m, n]
+
+    def body(_, i):
+        start = offsets[i]
+        size = group_sizes[i]
+
+        lhs_slice = lax.dynamic_slice(lhs_padded, (start, 0), (m, k))
+        rhs_slice = lax.dynamic_slice(rhs_padded, (start, 0), (m, n))
+
+        # Mask invalid rows to zero (masking one operand suffices).
+        valid = jnp.arange(m) < size
+        lhs_slice = lhs_slice * valid[:, None]
+
+        result = lax.dot(
+            lhs_slice.T, rhs_slice, preferred_element_type=jnp.float32
+        )  # [k, n]
+        return None, result
+
+    _, results = lax.scan(body, None, jnp.arange(group_sizes.shape[0]))
+    return results  # [num_groups, k, n]
+
+
+def _gmm_fwd(lhs, rhs, group_sizes, transpose_rhs):
+    """Forward rule: compute output and save residuals."""
+    out = _gmm_impl(lhs, rhs, group_sizes, transpose_rhs)
+    return out, (lhs, rhs, group_sizes)
+
+
+def _gmm_bwd(transpose_rhs, residuals, grad):
+    """Backward rule: compute dlhs via GMM, drhs via TGMM."""
+    lhs, rhs, group_sizes = residuals
+
+    # dlhs rows of group i: grad @ rhs[i], with the transpose flag flipped
+    # relative to the forward pass (grad @ W^T when forward was lhs @ W).
+    dlhs = _gmm_impl(grad, rhs, group_sizes, not transpose_rhs)
+
+    # drhs: depends on transpose_rhs
+    if transpose_rhs:
+        # rhs shape [G, n, k], drhs[i] = grad_slice^T @ lhs_slice
+        drhs = _tgmm_impl(grad, lhs, group_sizes)
+    else:
+        # rhs shape [G, k, n], drhs[i] = lhs_slice^T @ grad_slice
+        drhs = _tgmm_impl(lhs, grad, group_sizes)
+
+    # group_sizes is int32, not differentiable
+    return dlhs, drhs, None
+
+
+@functools.partial(jax.custom_vjp, nondiff_argnums=(3,))
+def gmm(
+    lhs: jax.Array,
+    rhs: jax.Array,
+    group_sizes: jax.Array,
+    transpose_rhs: bool = False,
+) -> jax.Array:
+    """Grouped matrix multiplication. JIT-compilable, runs on TPU.
+
+    For each group i with rows [start_i, end_i):
+        out[start_i:end_i, :] = lhs[start_i:end_i, :] @ rhs[i]
+
+    Uses bf16 multiplication with f32 accumulation.
+    Differentiable via custom_vjp (forward: GMM, backward: GMM + TGMM).
+
+    Args:
+        lhs: [m, k] input activations.
+        rhs: [num_groups, k, n] per-group weights.
+            If transpose_rhs=True, rhs is [num_groups, n, k].
+        group_sizes: [num_groups] int32, number of rows per group.
+        transpose_rhs: If True, transpose each rhs[i] before matmul.
+
+    Returns:
+        [m, n] output in float32.
+    """
+    return _gmm_impl(lhs, rhs, group_sizes, transpose_rhs)
+
+
+gmm.defvjp(_gmm_fwd, _gmm_bwd)
+
+
+def tgmm(
+    lhs: jax.Array,
+    rhs: jax.Array,
+    group_sizes: jax.Array,
+) -> jax.Array:
+    """Transposed grouped matrix multiplication. JIT-compilable, runs on TPU.
+
+    For each group i with rows [start_i, end_i):
+        out[i] = lhs[start_i:end_i, :]^T @ rhs[start_i:end_i, :]
+
+    Uses bf16 multiplication with f32 accumulation.
+
+    Args:
+        lhs: [m, k] input activations.
+        rhs: [m, n] second operand.
+        group_sizes: [num_groups] int32, number of rows per group.
+
+    Returns:
+        [num_groups, k, n] per-group products in float32.
+    """
+    return _tgmm_impl(lhs, rhs, group_sizes)