feat(gmm): JIT-compilable JAX GMM for TPU with forward/backward and tokamax comparison #161
# GMM (Grouped Matrix Multiplication) Pallas Kernel Design

**Date:** 2026-04-06
**Status:** Approved
**Phase:** 1 (BF16 only, no quantization)
## Goal

Implement a Pallas TPU kernel for Grouped Matrix Multiplication (GMM) to support MoE (Mixture-of-Experts) layers. This replaces the `tokamax` and `qwix` backends used in maxtext's megablox with a clean, self-contained implementation following tops conventions.
## Semantics

**GMM forward:** For each expert group `i` with rows `[start_i, end_i)`:

```python
out[start_i:end_i, :] = lhs[start_i:end_i, :] @ rhs[i, :, :]
```
**TGMM (transposed GMM, for weight gradients):** For each group `i`:

```python
out[i, :, :] = lhs[start_i:end_i, :].T @ rhs[start_i:end_i, :]
```

Tokens in `lhs` are pre-sorted by expert assignment. `group_sizes[i]` gives the number of rows belonging to expert `i`.
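These semantics can be pinned down with a pure-NumPy sketch (the function names `gmm_reference` and `tgmm_reference` are illustrative, not part of the proposed API; the actual CPU reference would live in `tops/cpu/ops/gmm/naive.py` as pure JAX):

```python
import numpy as np

def gmm_reference(lhs, rhs, group_sizes):
    """out[start_i:end_i] = lhs[start_i:end_i] @ rhs[i] for each group i."""
    m, n = lhs.shape[0], rhs.shape[2]
    out = np.zeros((m, n), dtype=np.float32)
    start = 0
    for i, size in enumerate(group_sizes):
        out[start:start + size] = lhs[start:start + size] @ rhs[i]
        start += size
    return out

def tgmm_reference(lhs, rhs, group_sizes):
    """out[i] = lhs[start_i:end_i].T @ rhs[start_i:end_i] for each group i."""
    num_groups = len(group_sizes)
    out = np.zeros((num_groups, lhs.shape[1], rhs.shape[1]), dtype=np.float32)
    start = 0
    for i, size in enumerate(group_sizes):
        out[i] = lhs[start:start + size].T @ rhs[start:start + size]
        start += size
    return out
```

Note that an empty group (`size == 0`) falls out naturally: the GMM slice assignment is a no-op, and the TGMM matmul of a `(k, 0)` by `(0, n)` pair yields zeros.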
## Public API

```python
def gmm(
    lhs: jnp.ndarray,             # [m, k] bf16 - stacked token activations
    rhs: jnp.ndarray,             # [num_groups, k, n] bf16 - per-expert weights
    group_sizes: jnp.ndarray,     # [num_groups] int32 - token count per expert
    tiling: tuple[int, int, int] = (128, 128, 128),  # (tm, tk, tn)
    transpose_rhs: bool = False,
    preferred_element_type: jnp.dtype = jnp.float32,
) -> jnp.ndarray:                 # [m, n] bf16
```
Fully differentiable via `jax.custom_vjp`:

- **dlhs** = `gmm(grad, rhs, group_sizes, transpose_rhs=True)`
- **drhs** = `tgmm(lhs, grad, group_sizes)`
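These two VJP rules can be checked numerically against finite differences of the scalar loss `sum(grad * out)` without any TPU or JAX dependency. The sketch below re-derives them in NumPy (all names are illustrative; float64 is used so finite differences are accurate):

```python
import numpy as np

def gmm_ref(lhs, rhs, group_sizes, transpose_rhs=False):
    """Grouped matmul; with transpose_rhs, each group uses rhs[i].T."""
    out_n = rhs.shape[1] if transpose_rhs else rhs.shape[2]
    out = np.zeros((lhs.shape[0], out_n), dtype=np.float64)
    start = 0
    for i, size in enumerate(group_sizes):
        w = rhs[i].T if transpose_rhs else rhs[i]
        out[start:start + size] = lhs[start:start + size] @ w
        start += size
    return out

def tgmm_ref(lhs, rhs, group_sizes):
    out = np.zeros((len(group_sizes), lhs.shape[1], rhs.shape[1]), np.float64)
    start = 0
    for i, size in enumerate(group_sizes):
        out[i] = lhs[start:start + size].T @ rhs[start:start + size]
        start += size
    return out

rng = np.random.default_rng(0)
lhs = rng.standard_normal((8, 4))
rhs = rng.standard_normal((2, 4, 3))
sizes = [3, 5]
grad = rng.standard_normal((8, 3))   # upstream gradient d(loss)/d(out)

dlhs = gmm_ref(grad, rhs, sizes, transpose_rhs=True)   # [8, 4]
drhs = tgmm_ref(lhs, grad, sizes)                      # [2, 4, 3]

# central finite differences on one entry of each input
eps = 1e-6
def loss_l(l):
    return float(np.sum(grad * gmm_ref(l, rhs, sizes)))
def loss_r(r):
    return float(np.sum(grad * gmm_ref(lhs, r, sizes)))
bump_l = np.zeros_like(lhs); bump_l[2, 1] = eps
fd_l = (loss_l(lhs + bump_l) - loss_l(lhs - bump_l)) / (2 * eps)
bump_r = np.zeros_like(rhs); bump_r[1, 0, 2] = eps
fd_r = (loss_r(rhs + bump_r) - loss_r(rhs - bump_r)) / (2 * eps)
```

Both finite-difference probes should agree with the corresponding entries of `dlhs` and `drhs`, confirming that the backward pass reduces to one transposed GMM and one TGMM.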
### Internal: `tgmm`

```python
def tgmm(
    lhs: jnp.ndarray,             # [m, k] bf16
    rhs: jnp.ndarray,             # [m, n] bf16
    group_sizes: jnp.ndarray,     # [num_groups] int32
    tiling: tuple[int, int, int] = (128, 128, 128),
    preferred_element_type: jnp.dtype = jnp.float32,
) -> jnp.ndarray:                 # [num_groups, k, n] bf16
```
## Kernel Architecture

### Grid Layout (gmm)

```python
grid = (tiles_n, num_active_tiles, tiles_k)
dimension_semantics = ("parallel", "arbitrary", "arbitrary")
```
- `tiles_n = n // tn` -- parallelized over output columns
- `num_active_tiles` -- total m-tiles across all groups (computed by `make_group_metadata`)
- `tiles_k = k // tk` -- sequential reduction dimension
### Group Metadata (computed on host, passed via scalar prefetch)

`make_group_metadata(group_sizes, m, tm)` produces:

- `group_offsets`: CSR-style cumulative row offsets, rounded to tm boundaries
- `group_ids`: maps each active m-tile index to its group
- `m_tile_ids`: maps each active m-tile index to the lhs/out row tile it processes
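A plausible NumPy sketch of this computation is below. It assumes (as the grid layout implies) that a row tile straddling a group boundary appears once per group that touches it, so `len(group_ids)` equals `num_active_tiles`. The real megablox-style `make_group_metadata` handles further edge cases; this is only illustrative:

```python
import numpy as np

def make_group_metadata(group_sizes, m, tm):
    """Sketch: CSR offsets plus one (group, row-tile) entry per visited tile.

    (m is part of the proposed signature but unused in this simplified sketch.)
    """
    ends = np.cumsum(group_sizes)
    starts = ends - np.asarray(group_sizes)
    # cumulative row offsets rounded up to tm boundaries
    group_offsets = ((ends + tm - 1) // tm) * tm
    group_ids, m_tile_ids = [], []
    for g in range(len(group_sizes)):
        if starts[g] == ends[g]:
            continue                       # empty group contributes no tiles
        first = starts[g] // tm            # first row tile this group touches
        last = (ends[g] - 1) // tm         # last row tile this group touches
        for t in range(first, last + 1):
            group_ids.append(g)            # expert served at this grid step
            m_tile_ids.append(t)           # lhs/out row tile loaded at this step
    return (np.asarray(group_offsets),
            np.asarray(group_ids, np.int32),
            np.asarray(m_tile_ids, np.int32))
```

For example, with `group_sizes=[3, 5]`, `m=8`, `tm=4`, tile 0 holds the boundary between both groups, so it is visited twice and `num_active_tiles = 3`.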
### BlockSpecs

| Tensor | Block shape | Index map |
|--------|-------------|-----------|
| `lhs` | `[tm, tk]` | `(m_tile_ids[grid_m], grid_k)` |
| `rhs` | `[1, tk, tn]` | `(group_ids[grid_m], grid_k, grid_n)` |
| `out` | `[tm, tn]` | `(m_tile_ids[grid_m], grid_n)` |

When `transpose_rhs=True`, the rhs block shape is `[1, tn, tk]`.
### Kernel Body (gmm)

1. Load `lhs_block [tm, tk]` and `rhs_block [tk, tn]` via BlockSpec
2. Accumulate `dot(lhs_block, rhs_block, preferred_element_type=float32)` into VMEM scratch `[tm, tn]`
3. On the last k-tile: apply the group-boundary mask (zero rows outside the group), cast to the output dtype, and store
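The three steps can be simulated in plain NumPy to check that per-tile accumulation plus boundary masking reproduces the untiled result. All tiling values and helper names below are illustrative; the Pallas kernel streams these blocks through VMEM rather than materializing them:

```python
import numpy as np

def gmm_tiled(lhs, rhs, group_sizes, tm=4, tk=2, tn=3):
    """Tiled grouped matmul; assumes m, k, n divisible by tm, tk, tn."""
    m, k = lhs.shape
    n = rhs.shape[2]
    out = np.zeros((m, n), dtype=np.float32)
    ends = np.cumsum(group_sizes)
    starts = ends - np.asarray(group_sizes)
    # active (group, row-tile) pairs: one visit per group touching the tile
    active = [(g, t)
              for g in range(len(group_sizes)) if starts[g] < ends[g]
              for t in range(starts[g] // tm, (ends[g] - 1) // tm + 1)]
    for grid_n in range(n // tn):
        for g, t in active:
            acc = np.zeros((tm, tn), dtype=np.float32)       # "VMEM scratch"
            for grid_k in range(k // tk):                    # sequential k reduction
                lhs_blk = lhs[t * tm:(t + 1) * tm, grid_k * tk:(grid_k + 1) * tk]
                rhs_blk = rhs[g, grid_k * tk:(grid_k + 1) * tk,
                              grid_n * tn:(grid_n + 1) * tn]
                acc += lhs_blk.astype(np.float32) @ rhs_blk.astype(np.float32)
            # after the last k-tile: mask rows outside [start_g, end_g) and store
            rows = np.arange(t * tm, (t + 1) * tm)
            mask = (rows >= starts[g]) & (rows < ends[g])
            out[rows[mask], grid_n * tn:(grid_n + 1) * tn] = acc[mask]
    return out
```

Because each visit stores only the rows belonging to its own group, the two visits to a shared boundary tile never overwrite each other's output rows.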
### Grid Layout (tgmm)

```python
grid = (tiles_n, tiles_k, num_active_tiles)
dimension_semantics = ("parallel", "arbitrary", "arbitrary")
```
### Kernel Body (tgmm)

For each active m-tile, the kernel accumulates `lhs_block^T [tk, tm] @ rhs_block [tm, tn]` into the output for the corresponding group. When the group changes between adjacent m-tiles, the accumulated result is stored and the accumulator is reset.
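A NumPy analogue of this accumulate-and-flush loop is sketched below (illustrative names and tiling; boundary tiles are handled by zeroing the lhs rows that belong to another group before the transposed matmul):

```python
import numpy as np

def tgmm_tiled(lhs, rhs, group_sizes, tm=4):
    """Tiled transposed grouped matmul; assumes m divisible by tm."""
    m, k = lhs.shape
    n = rhs.shape[1]
    num_groups = len(group_sizes)
    ends = np.cumsum(group_sizes)
    starts = ends - np.asarray(group_sizes)
    out = np.zeros((num_groups, k, n), dtype=np.float32)
    active = [(g, t)
              for g in range(num_groups) if starts[g] < ends[g]
              for t in range(starts[g] // tm, (ends[g] - 1) // tm + 1)]
    acc = np.zeros((k, n), dtype=np.float32)   # accumulator ("VMEM scratch")
    for i, (g, t) in enumerate(active):
        rows = np.arange(t * tm, (t + 1) * tm)
        mask = ((rows >= starts[g]) & (rows < ends[g])).astype(np.float32)
        lhs_blk = lhs[rows] * mask[:, None]    # zero rows outside this group
        acc += lhs_blk.T.astype(np.float32) @ rhs[rows].astype(np.float32)
        # flush when the next step belongs to a different group, or at the end
        if i + 1 == len(active) or active[i + 1][0] != g:
            out[g] = acc
            acc = np.zeros((k, n), dtype=np.float32)
    return out
```

Zeroing the masked lhs rows is sufficient on its own: any rhs row paired with a zeroed lhs row contributes nothing to the accumulated product.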
## Precision

- Input dtype: bf16
- Accumulation: float32 (via VMEM scratch)
- Output: cast back to bf16
- `jax.lax.Precision.HIGHEST` on all dot products
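Why float32 accumulation matters can be demonstrated without TPU hardware. NumPy has no bf16, so float16 stands in for the low-precision input type here (a scalar loop is used deliberately, since NumPy's built-in pairwise summation would mask the effect):

```python
import numpy as np

x = np.full(4096, 0.01, dtype=np.float16)   # many small low-precision values

acc32 = np.float32(0)
acc16 = np.float16(0)
for v in x:
    acc32 = np.float32(acc32 + np.float32(v))  # accumulate in f32, like VMEM scratch
    acc16 = np.float16(acc16 + v)              # accumulate in the input precision

# The f32 accumulator stays near the true sum (~41). The f16 accumulator
# stalls once the addend falls below half the spacing between adjacent
# representable values, so it lands well short of the true sum.
```

The same effect scales up inside a long k-reduction, which is why the kernel keeps a float32 scratch accumulator and only casts back to bf16 on the final store.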
## File Layout

```text
tops/ops/gmm/
    __init__.py       # Public API: gmm()
    gmm.py            # Pallas kernels + custom_vjp + tgmm
    metadata.py       # make_group_metadata()

tops/cpu/ops/gmm/
    __init__.py
    naive.py          # Pure JAX reference implementation

tests/ops/gmm/
    test_gmm_tpu.py   # Pallas vs CPU reference tests
    conftest.py       # GMM test fixtures
```
## Testing Strategy

1. **CPU reference:** Pure JAX loop over groups with plain matmul
2. **Forward test:** Compare Pallas gmm output vs reference across configs
3. **Gradient test:** Compare custom_vjp gradients vs `jax.grad` of the reference
4. **Configs:** Vary (m, k, n, num_groups) with distributions:
   - Uniform group sizes
   - Skewed (one large group, many small)
   - Single group (degenerates to plain matmul)
   - Empty groups (group_size=0)
   - Sizes not divisible by tm
5. **Tolerances:** atol ~1e-2, rtol ~1e-2 (bf16 accumulation)
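The group-size distributions in item 4 could be generated with small fixtures like the following (names and exact schemes are illustrative, not prescribed by the design):

```python
import numpy as np

def uniform_groups(m, num_groups):
    """Split m rows as evenly as possible across num_groups."""
    sizes = np.full(num_groups, m // num_groups, dtype=np.int32)
    sizes[: m % num_groups] += 1    # distribute the remainder
    return sizes

def skewed_groups(m, num_groups):
    """One large group takes half the rows; the rest share the remainder."""
    rest = uniform_groups(m - m // 2, num_groups - 1)
    return np.concatenate([[m // 2], rest]).astype(np.int32)

def with_empty_groups(m, num_groups):
    """Every other group is empty (group_size=0)."""
    active = (num_groups + 1) // 2
    sizes = np.zeros(num_groups, dtype=np.int32)
    sizes[::2] = uniform_groups(m, active)
    return sizes
```

Each generator preserves the invariant `sum(group_sizes) == m`, so the same (m, k, n) shapes can be reused across all distributions in a parametrized test.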
## Future Work (Phase 2+)

- Block-wise quantization: (128,128) for weights, (1,128) for activations
- `group_offset` for expert parallelism / sharded groups
- `existing_out` for accumulation into pre-existing buffers
- Double/triple buffering (`input_buffer_count`)
- Async DMA pipelining