# GMM (Grouped Matrix Multiplication) Pallas Kernel Design

**Date:** 2026-04-06
**Status:** Approved
**Phase:** 1 (BF16 only, no quantization)

## Goal

Implement a Pallas TPU kernel for Grouped Matrix Multiplication (GMM) to support MoE (Mixture-of-Experts) layers. This replaces the `tokamax` and `qwix` backends used in maxtext's megablox with a clean, self-contained implementation following tops conventions.

## Semantics

**GMM forward:** For each expert group `i` with rows `[start_i, end_i)`:
```text
out[start_i:end_i, :] = lhs[start_i:end_i, :] @ rhs[i, :, :]
```
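This semantics can be expressed as a short pure-JAX loop, in the spirit of the naive CPU reference this plan calls for (the function name and shapes here are illustrative, not the final `tops/cpu/ops/gmm` API):

```python
import jax.numpy as jnp
import numpy as np


def reference_gmm(lhs, rhs, group_sizes):
    """Naive grouped matmul: host-side loop over groups with plain matmuls."""
    sizes = np.asarray(group_sizes)
    starts = np.concatenate([[0], np.cumsum(sizes)])  # CSR-style row offsets
    out = jnp.zeros((lhs.shape[0], rhs.shape[2]), dtype=lhs.dtype)
    for i, (s, e) in enumerate(zip(starts[:-1], starts[1:])):
        out = out.at[s:e, :].set(lhs[s:e, :] @ rhs[i, :, :])
    return out
```

Because the group boundaries are data-dependent, this loop only works when `group_sizes` is known on the host; the Pallas kernel instead handles boundaries via the scheduler metadata described below.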
**TGMM (transposed GMM, for weight gradients):** For each group `i`:
```text
out[i, :, :] = lhs[start_i:end_i, :]^T @ rhs[start_i:end_i, :]
```
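A matching naive reference for the transposed form (again a sketch, not the final API; note that a slice of zero rows naturally yields a zero `[k, n]` block, so empty groups fall out for free):

```python
import jax.numpy as jnp
import numpy as np


def reference_tgmm(lhs, rhs, group_sizes):
    """Naive transposed grouped matmul: per-group lhs^T @ rhs."""
    starts = np.concatenate([[0], np.cumsum(np.asarray(group_sizes))])
    out = jnp.zeros((len(group_sizes), lhs.shape[1], rhs.shape[1]), lhs.dtype)
    for i in range(len(group_sizes)):
        s, e = starts[i], starts[i + 1]
        out = out.at[i].set(lhs[s:e, :].T @ rhs[s:e, :])  # [k, m_i] @ [m_i, n]
    return out
```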

Tokens in `lhs` are pre-sorted by expert assignment. `group_sizes[i]` gives the number of rows belonging to expert `i`.

## Public API

```python
def gmm(
lhs: jnp.ndarray, # [m, k] bf16 - stacked token activations
rhs: jnp.ndarray, # [num_groups, k, n] bf16 - per-expert weights
group_sizes: jnp.ndarray, # [num_groups] int32 - token count per expert
tiling: tuple[int, int, int] = (128, 128, 128), # (tm, tk, tn)
transpose_rhs: bool = False,
preferred_element_type: jnp.dtype = jnp.float32,
) -> jnp.ndarray: # [m, n] bf16
```

Fully differentiable via `jax.custom_vjp`:
- **dlhs** = `gmm(grad, rhs, group_sizes, transpose_rhs=True)`
- **drhs** = `tgmm(lhs, grad, group_sizes)`

### Internal: `tgmm`

```python
def tgmm(
lhs: jnp.ndarray, # [m, k] bf16
rhs: jnp.ndarray, # [m, n] bf16
group_sizes: jnp.ndarray, # [num_groups] int32
tiling: tuple[int, int, int] = (128, 128, 128),
preferred_element_type: jnp.dtype = jnp.float32,
) -> jnp.ndarray: # [num_groups, k, n] bf16
```

## Kernel Architecture

### Grid Layout (gmm)

```python
grid = (tiles_n, num_active_tiles, tiles_k)
dimension_semantics = ("parallel", "arbitrary", "arbitrary")
```

- `tiles_n = n // tn` -- parallelized over output columns
- `num_active_tiles` = total m-tiles across all groups (computed by `make_group_metadata`)
- `tiles_k = k // tk` -- sequential reduction dimension

### Group Metadata (computed on host, passed via scalar prefetch)

`make_group_metadata(group_sizes, m, tm)` produces:
- `group_offsets`: plain cumulative row offsets of the unpadded `group_sizes` (CSR-style)
- `group_ids`: maps each active m-tile index to its group
- `m_tile_ids`: maps each active m-tile index to its global row-tile id
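The scheduling can be sketched as a host-side NumPy routine (illustrative only; the real `make_group_metadata` in `tops/ops/gmm/metadata.py` returns device arrays and may differ in detail). A row-tile that straddles a group boundary is scheduled once per group it touches, which is why the kernel needs the group-boundary mask:

```python
import numpy as np


def make_group_metadata(group_sizes, m, tm):
    """Host-side scheduler metadata for the gmm grid (sketch)."""
    sizes = np.asarray(group_sizes)
    group_offsets = np.concatenate([[0], np.cumsum(sizes)])  # plain cumsum
    group_ids, m_tile_ids = [], []
    for g, (s, e) in enumerate(zip(group_offsets[:-1], group_offsets[1:])):
        if s == e:  # empty group contributes no active tiles
            continue
        first, last = s // tm, (e - 1) // tm  # global row-tiles this group touches
        for t in range(first, last + 1):
            group_ids.append(g)
            m_tile_ids.append(t)
    return group_offsets, np.array(group_ids), np.array(m_tile_ids)
```

`num_active_tiles` is then `len(group_ids)`.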

### BlockSpecs

| Tensor | Block shape | Index map |
|--------|-------------|-----------|
| `lhs` | `[tm, tk]` | `(m_tile_ids[grid_m], grid_k)` |
| `rhs` | `[1, tk, tn]` | `(group_ids[grid_m], grid_k, grid_n)` |
| `out` | `[tm, tn]` | `(m_tile_ids[grid_m], grid_n)` |

When `transpose_rhs=True`, rhs block shape is `[1, tn, tk]`.

### Kernel Body (gmm)

1. Load `lhs_block [tm, tk]` and `rhs_block [tk, tn]` via BlockSpec
2. Accumulate `dot(lhs_block, rhs_block, preferred_element_type=float32)` into VMEM scratch `[tm, tn]`
3. On last k-tile: apply group-boundary mask (zero rows outside the group), cast to output dtype, store
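The group-boundary mask in step 3 can be illustrated in plain JAX (hypothetical names; inside the kernel this runs on the `[tm, tn]` scratch block and would use `jax.lax.broadcasted_iota` for the row indices):

```python
import jax.numpy as jnp


def mask_rows_to_group(acc, tile_row0, group_start, group_end):
    """Zero accumulator rows whose absolute row index falls outside the group."""
    tm = acc.shape[0]
    rows = tile_row0 + jnp.arange(tm)              # absolute row of each tile row
    keep = (rows >= group_start) & (rows < group_end)
    return jnp.where(keep[:, None], acc, 0.0)
```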

### Grid Layout (tgmm)

```python
grid = (tiles_n, tiles_k, num_active_tiles)
dimension_semantics = ("parallel", "arbitrary", "arbitrary")
```

### Kernel Body (tgmm)

For each active m-tile, accumulates `lhs_block^T [tk, tm] @ rhs_block [tm, tn]` into the output for the corresponding group. When the group changes between adjacent m-tiles, the accumulated result is stored and the accumulator reset.

## Precision

- Input dtype: bf16
- Accumulation: float32 (via VMEM scratch)
- Output: cast back to bf16
- `jax.lax.Precision.HIGHEST` on all dot products
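The precision recipe in miniature (plain `jax.lax.dot`, standing in for the in-kernel dot on VMEM blocks):

```python
import jax
import jax.numpy as jnp

a = jnp.ones((128, 128), dtype=jnp.bfloat16)
b = jnp.ones((128, 128), dtype=jnp.bfloat16)

# bf16 inputs, f32 accumulation, full-precision MXU passes.
acc = jax.lax.dot(a, b,
                  precision=jax.lax.Precision.HIGHEST,
                  preferred_element_type=jnp.float32)
out = acc.astype(jnp.bfloat16)  # cast back only when storing the tile
```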

## File Layout

```text
tops/ops/gmm/
__init__.py # Public API: gmm()
gmm.py # Pallas kernels + custom_vjp + tgmm
metadata.py # make_group_metadata()

tops/cpu/ops/gmm/
__init__.py
naive.py # Pure JAX reference implementation

tests/ops/gmm/
test_gmm_tpu.py # Pallas vs CPU reference tests
conftest.py # GMM test fixtures
```

## Testing Strategy

1. **CPU reference:** Pure JAX loop over groups with plain matmul
2. **Forward test:** Compare Pallas gmm output vs reference across configs
3. **Gradient test:** Compare custom_vjp gradients vs `jax.grad` of reference
4. **Configs:** Vary (m, k, n, num_groups) with distributions:
- Uniform group sizes
- Skewed (one large group, many small)
- Single group (degenerates to plain matmul)
- Empty groups (group_size=0)
- Sizes not divisible by tm
5. **Tolerances:** atol ~1e-2, rtol ~1e-2 (bf16 accumulation)
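The group-size distributions in step 4 can come from a small fixture helper (hypothetical sketch of what `conftest.py` might provide):

```python
import numpy as np


def group_size_cases(m, num_groups):
    """Yield (name, group_sizes) arrays summing to m, one per distribution."""
    base = m // num_groups
    uniform = np.full(num_groups, base)
    uniform[: m - base * num_groups] += 1           # distribute remainder
    yield "uniform", uniform
    skewed = np.ones(num_groups, dtype=int)
    skewed[0] = m - (num_groups - 1)                # one large group, rest tiny
    yield "skewed", skewed
    single = np.zeros(num_groups, dtype=int)
    single[0] = m                                   # degenerates to plain matmul
    yield "single", single
    empty = uniform.copy()
    empty[1], empty[0] = 0, empty[0] + empty[1]     # force a group_size == 0
    yield "empty", empty
```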

## Future Work (Phase 2+)

- Block-wise quantization: (128,128) for weights, (1,128) for activations
- `group_offset` for expert parallelism / sharded groups
- `existing_out` for accumulation into pre-existing buffers
- Double/triple buffering (`input_buffer_count`)
- Async DMA pipelining