Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ gpu = [
tpu = [
"jax[tpu]>=0.8.1",
"torch",
"tokamax>=0.0.12",
]
profile = [
"xprof==2.22.0",
Expand Down
Empty file added tests/ops/gmm/__init__.py
Empty file.
222 changes: 222 additions & 0 deletions tests/ops/gmm/test_cpu_ref.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,222 @@
"""Verify CPU reference implementations for GMM/TGMM are correct."""

from __future__ import annotations

import sys
from pathlib import Path

sys.path.insert(0, str(Path(__file__).resolve().parents[3]))

import jax
import jax.numpy as jnp
import numpy as np
import pytest

from tops.cpu.ops.gmm import gmm_ref, tgmm_ref
from tops.ops.gmm import gmm, tgmm


class TestGmmRef:
"""Test gmm_ref against manual numpy computation."""

def test_single_group(self):
"""Single group = standard matmul."""
lhs = jnp.array([[1.0, 2.0], [3.0, 4.0]], dtype=jnp.bfloat16)
rhs = jnp.array([[[1.0, 0.0], [0.0, 1.0]]], dtype=jnp.bfloat16)
gs = jnp.array([2], dtype=jnp.int32)
out = gmm_ref(lhs, rhs, gs)
expected = jnp.array([[1.0, 2.0], [3.0, 4.0]], dtype=jnp.float32)
np.testing.assert_allclose(np.array(out), np.array(expected), atol=1e-2)

def test_two_groups(self):
"""Two groups with different weights."""
lhs = jnp.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]], dtype=jnp.bfloat16)
rhs = jnp.array(
[
[[2.0, 0.0], [0.0, 2.0]], # group 0: scale by 2
[[0.0, 1.0], [1.0, 0.0]], # group 1: swap columns
],
dtype=jnp.bfloat16,
)
gs = jnp.array([1, 2], dtype=jnp.int32)
out = gmm_ref(lhs, rhs, gs)
expected = jnp.array([[2.0, 0.0], [1.0, 0.0], [1.0, 1.0]], dtype=jnp.float32)
np.testing.assert_allclose(np.array(out), np.array(expected), atol=1e-2)

def test_empty_group(self):
"""Empty group produces zeros for those rows (none exist)."""
lhs = jnp.array([[1.0, 2.0]], dtype=jnp.bfloat16)
rhs = jnp.array(
[
[[1.0], [1.0]], # group 0: empty
[[1.0], [1.0]], # group 1: 1 row
],
dtype=jnp.bfloat16,
)
gs = jnp.array([0, 1], dtype=jnp.int32)
out = gmm_ref(lhs, rhs, gs)
expected = jnp.array([[3.0]], dtype=jnp.float32)
np.testing.assert_allclose(np.array(out), np.array(expected), atol=1e-2)

def test_transpose_rhs(self):
"""transpose_rhs transposes each rhs[i] before matmul."""
lhs = jnp.array([[1.0, 2.0]], dtype=jnp.bfloat16)
rhs = jnp.array([[[3.0, 4.0], [5.0, 6.0]]], dtype=jnp.bfloat16)
gs = jnp.array([1], dtype=jnp.int32)
# Without transpose: lhs [1,2] @ rhs [2,2] = [1*3+2*5, 1*4+2*6] = [13, 16]
out_normal = gmm_ref(lhs, rhs, gs)
np.testing.assert_allclose(np.array(out_normal), [[13.0, 16.0]], atol=1e-2)
# With transpose: lhs [1,2] @ rhs.T [2,2] = [1*3+2*4, 1*5+2*6] = [11, 17]
out_transposed = gmm_ref(lhs, rhs, gs, transpose_rhs=True)
np.testing.assert_allclose(np.array(out_transposed), [[11.0, 17.0]], atol=1e-2)


class TestTgmmRef:
"""Test tgmm_ref against manual numpy computation."""

def test_single_group(self):
"""Single group: lhs^T @ rhs."""
lhs = jnp.array([[1.0, 2.0], [3.0, 4.0]], dtype=jnp.bfloat16)
rhs = jnp.array([[5.0], [6.0]], dtype=jnp.bfloat16)
gs = jnp.array([2], dtype=jnp.int32)
out = tgmm_ref(lhs, rhs, gs)
# lhs^T [2,2] @ rhs [2,1] = [[1*5+3*6], [2*5+4*6]] = [[23], [34]]
expected = jnp.array([[[23.0], [34.0]]], dtype=jnp.float32)
np.testing.assert_allclose(np.array(out), np.array(expected), atol=1e-2)

def test_two_groups(self):
"""Two groups produce separate outer products."""
lhs = jnp.array([[1.0], [2.0], [3.0]], dtype=jnp.bfloat16)
rhs = jnp.array([[4.0], [5.0], [6.0]], dtype=jnp.bfloat16)
gs = jnp.array([1, 2], dtype=jnp.int32)
out = tgmm_ref(lhs, rhs, gs)
# Group 0: [1]^T @ [4] = [[4]]
# Group 1: [2,3]^T @ [5,6] = [[2*5+3*6]] = [[28]]
expected = jnp.array([[[4.0]], [[28.0]]], dtype=jnp.float32)
np.testing.assert_allclose(np.array(out), np.array(expected), atol=1e-2)


class TestGmmJax:
"""Test JIT-compilable gmm against CPU reference."""

def test_single_group(self):
"""Single group: gmm matches gmm_ref."""
lhs = jnp.array([[1.0, 2.0], [3.0, 4.0]], dtype=jnp.bfloat16)
rhs = jnp.array([[[1.0, 0.0], [0.0, 1.0]]], dtype=jnp.bfloat16)
gs = jnp.array([2], dtype=jnp.int32)
out = gmm(lhs, rhs, gs)
ref = gmm_ref(lhs, rhs, gs)
np.testing.assert_allclose(np.array(out), np.array(ref), atol=1e-2)

def test_two_groups(self):
"""Two groups: gmm matches gmm_ref."""
lhs = jnp.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]], dtype=jnp.bfloat16)
rhs = jnp.array(
[
[[2.0, 0.0], [0.0, 2.0]],
[[0.0, 1.0], [1.0, 0.0]],
],
dtype=jnp.bfloat16,
)
gs = jnp.array([1, 2], dtype=jnp.int32)
out = gmm(lhs, rhs, gs)
ref = gmm_ref(lhs, rhs, gs)
np.testing.assert_allclose(np.array(out), np.array(ref), atol=1e-2)

def test_transpose_rhs(self):
"""transpose_rhs: gmm matches gmm_ref."""
lhs = jnp.array([[1.0, 2.0]], dtype=jnp.bfloat16)
rhs = jnp.array([[[3.0, 4.0], [5.0, 6.0]]], dtype=jnp.bfloat16)
gs = jnp.array([1], dtype=jnp.int32)
out = gmm(lhs, rhs, gs, transpose_rhs=True)
ref = gmm_ref(lhs, rhs, gs, transpose_rhs=True)
np.testing.assert_allclose(np.array(out), np.array(ref), atol=1e-2)


class TestTgmmJax:
"""Test JIT-compilable tgmm against CPU reference."""

def test_single_group(self):
"""Single group: tgmm matches tgmm_ref."""
lhs = jnp.array([[1.0, 2.0], [3.0, 4.0]], dtype=jnp.bfloat16)
rhs = jnp.array([[5.0], [6.0]], dtype=jnp.bfloat16)
gs = jnp.array([2], dtype=jnp.int32)
out = tgmm(lhs, rhs, gs)
ref = tgmm_ref(lhs, rhs, gs)
np.testing.assert_allclose(np.array(out), np.array(ref), atol=1e-2)

def test_two_groups(self):
"""Two groups: tgmm matches tgmm_ref."""
lhs = jnp.array([[1.0], [2.0], [3.0]], dtype=jnp.bfloat16)
rhs = jnp.array([[4.0], [5.0], [6.0]], dtype=jnp.bfloat16)
gs = jnp.array([1, 2], dtype=jnp.int32)
out = tgmm(lhs, rhs, gs)
ref = tgmm_ref(lhs, rhs, gs)
np.testing.assert_allclose(np.array(out), np.array(ref), atol=1e-2)


class TestGmmGrad:
"""Test gmm gradient via custom_vjp."""

def test_grad_lhs(self):
"""Gradient w.r.t. lhs is correct."""
lhs = jnp.array([[1.0, 2.0], [3.0, 4.0]], dtype=jnp.bfloat16)
rhs = jnp.array([[[1.0, 0.5], [0.5, 1.0]]], dtype=jnp.bfloat16)
gs = jnp.array([2], dtype=jnp.int32)

def loss_fn(x):
return gmm(x, rhs, gs).sum()

grad = jax.grad(loss_fn)(lhs)
assert grad.shape == lhs.shape
assert not jnp.any(jnp.isnan(grad))

def test_grad_rhs(self):
"""Gradient w.r.t. rhs is correct."""
lhs = jnp.array([[1.0, 2.0], [3.0, 4.0]], dtype=jnp.bfloat16)
rhs = jnp.array([[[1.0, 0.5], [0.5, 1.0]]], dtype=jnp.bfloat16)
gs = jnp.array([2], dtype=jnp.int32)

def loss_fn(w):
return gmm(lhs, w, gs).sum()

grad = jax.grad(loss_fn)(rhs)
assert grad.shape == rhs.shape
assert not jnp.any(jnp.isnan(grad))

def test_grad_both(self):
"""Gradient w.r.t. both lhs and rhs, two groups."""
lhs = jnp.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]], dtype=jnp.bfloat16)
rhs = jnp.array(
[[[2.0, 0.0], [0.0, 2.0]], [[0.0, 1.0], [1.0, 0.0]]],
dtype=jnp.bfloat16,
)
gs = jnp.array([1, 2], dtype=jnp.int32)

def loss_fn(x, w):
return gmm(x, w, gs).sum()

grad_lhs, grad_rhs = jax.grad(loss_fn, argnums=(0, 1))(lhs, rhs)
assert grad_lhs.shape == lhs.shape
assert grad_rhs.shape == rhs.shape
assert not jnp.any(jnp.isnan(grad_lhs))
assert not jnp.any(jnp.isnan(grad_rhs))

def test_grad_transpose_rhs(self):
"""Gradient with transpose_rhs=True."""
lhs = jnp.array([[1.0, 2.0]], dtype=jnp.bfloat16)
rhs = jnp.array([[[3.0, 4.0], [5.0, 6.0]]], dtype=jnp.bfloat16)
gs = jnp.array([1], dtype=jnp.int32)

def loss_fn(x, w):
return gmm(x, w, gs, transpose_rhs=True).sum()

grad_lhs, grad_rhs = jax.grad(loss_fn, argnums=(0, 1))(lhs, rhs)
assert grad_lhs.shape == lhs.shape
assert grad_rhs.shape == rhs.shape
assert not jnp.any(jnp.isnan(grad_lhs))
assert not jnp.any(jnp.isnan(grad_rhs))


if __name__ == "__main__":
pytest.main([__file__, "-v"])
Loading
Loading