Skip to content
Open
Show file tree
Hide file tree
Changes from 5 commits
Commits
Show all changes
16 commits
Select commit Hold shift + click to select a range
76b5444
feat(gla): add fused forward kernel for g_gamma mode
sii-xinglong Mar 31, 2026
71a8429
feat(gla): add fused backward kernel for g_gamma mode
sii-xinglong Mar 31, 2026
22c6c54
style(chunk_fused_kernels): use exp() utility consistently in backward
sii-xinglong Mar 31, 2026
05c6726
feat(gla): dispatch to fused kernels for g_gamma mode on TPU
sii-xinglong Mar 31, 2026
d8d3523
bench(gla): add chunk_fused and chunk_fused_bwd benchmark providers
sii-xinglong Mar 31, 2026
3a88c91
fix(chunk_fused_kernels): cast float32 scratch to h_ref dtype on save
sii-xinglong Mar 31, 2026
3e0372b
fix(chunk_fused_kernels): make scale a static JIT argument
sii-xinglong Mar 31, 2026
572ed36
fix(chunk_fused_kernels): use parallel semantics for K/V grid dims
sii-xinglong Mar 31, 2026
2b71a64
fix(chunk_fused_kernels): use f32 inputs for all matmuls in Pallas kernels
sii-xinglong Mar 31, 2026
1247d97
fix(gla): correct cu_seqlens parameter names in chunk kernel calls
sii-xinglong Mar 31, 2026
66c19d8
fix(tests): use bfloat16 inputs for both fused and reference paths
sii-xinglong Mar 31, 2026
f15840f
fix(tests): use naive recurrent reference instead of non-fused Pallas
sii-xinglong Mar 31, 2026
7ffeae2
fix: scale g_gamma in E2E test to prevent numerical overflow
sii-xinglong Mar 31, 2026
81458e0
fix: address PR #122 review comments
sii-xinglong Mar 31, 2026
7e9d63f
perf(chunk_gla): eliminate jnp.flip copies in fused backward via reve…
sii-xinglong Mar 31, 2026
1de5415
bench(gla): add memory profiling script for fused vs non-fused
sii-xinglong Mar 31, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 32 additions & 0 deletions benchmarks/ops/benchmark_gla.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,11 +47,13 @@
"fused_recurrent",
"chunk",
"fused_chunk",
"chunk_fused",
"simple_gla_naive",
"simple_gla_chunk",
"fused_recurrent_bwd",
"chunk_bwd",
"fused_chunk_bwd",
"chunk_fused_bwd",
"simple_gla_chunk_bwd",
]

Expand Down Expand Up @@ -170,6 +172,36 @@ def loss_fn(q_, k_, v_, gk_):
return jax.grad(loss_fn, argnums=(0, 1, 2, 3))(q, k, v, gk)

fn = fwd_bwd
elif provider == "chunk_fused":
from tops.ops.gla.chunk_fused_kernels import chunk_fwd_fused_g_gamma
chunk_size = 64
if T % chunk_size != 0 or D % 128 != 0:
return None
fn = partial(
chunk_fwd_fused_g_gamma,
q, k, v, g_gamma, scale=scale, chunk_size=chunk_size,
)
elif provider == "chunk_fused_bwd":
from tops.ops.gla.chunk_fused_kernels import (
chunk_fwd_fused_g_gamma,
chunk_bwd_fused_g_gamma,
)
chunk_size = 64
if T % chunk_size != 0 or D % 128 != 0:
return None

h_fwd, o_fwd = chunk_fwd_fused_g_gamma(
q, k, v, g_gamma, scale=scale, chunk_size=chunk_size,
)
do = jnp.ones_like(o_fwd)

@jax.jit
def run_bwd():
return chunk_bwd_fused_g_gamma(
q, k, v, h_fwd, do, g_gamma, scale=scale, chunk_size=chunk_size,
)

fn = run_bwd
elif provider == "simple_gla_naive":
if T > NAIVE_MAX_T:
return None
Expand Down
123 changes: 123 additions & 0 deletions tests/ops/gla/test_pallas_chunk_fused_bwd.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
"""Tests for the fused chunked GLA backward kernel (g_gamma mode).

Compares the output (dq, dk, dv) of the fused single-pallas_call backward
against the non-fused ``chunk_gla_bwd_with_pl`` reference that uses separate
dh propagation and gradient computation kernels.
"""

from __future__ import annotations

import sys
from pathlib import Path

sys.path.insert(0, str(Path(__file__).resolve().parents[3]))

import jax
import jax.numpy as jnp
import numpy as np
import pytest

from tests.utils import compare_tensor
from tops.ops.gla.chunk import chunk_gla_bwd_with_pl
from tops.ops.gla.chunk_fused_kernels import (
chunk_bwd_fused_g_gamma,
chunk_fwd_fused_g_gamma,
)


def _make_test_data(B, T, H, K, V, seed=42):
"""Create deterministic (q, k, v, g_gamma, do) for a GLA backward test."""
key = jax.random.PRNGKey(seed)
k1, k2, k3, k4, k5 = jax.random.split(key, 5)
q = jax.random.normal(k1, (B, T, H, K), dtype=jnp.bfloat16)
k_arr = jax.random.normal(k2, (B, T, H, K), dtype=jnp.bfloat16)
v = jax.random.normal(k3, (B, T, H, V), dtype=jnp.bfloat16)
g_gamma = -jnp.abs(jax.random.normal(k4, (H,), dtype=jnp.float32)) * 0.1
do = jax.random.normal(k5, (B, T, H, V), dtype=jnp.bfloat16)
return q, k_arr, v, g_gamma, do


@pytest.mark.tpu_only
class TestChunkBwdFused:
    """Tests for chunk_bwd_fused_g_gamma against the non-fused reference.

    Both test cases follow the same recipe — fused forward (for h), fused
    backward, non-fused reference backward, elementwise comparison — so the
    shared steps live in ``_check_case`` and each test supplies only its
    dimensions and seed.
    """

    def _run_reference(self, q, k, v, g_gamma, do, scale, chunk_size):
        """Run the non-fused chunk_gla_bwd_with_pl to get reference dq, dk, dv.

        Inputs are upcast to float32 so the reference is not limited by
        bfloat16 precision; g_gamma is reshaped to the (1, 1, H, 1) layout
        the reference kernel expects.
        """
        q_f32 = q.astype(jnp.float32)
        k_f32 = k.astype(jnp.float32)
        v_f32 = v.astype(jnp.float32)
        do_f32 = do.astype(jnp.float32)
        ref_dq, ref_dk, ref_dv, _, _ = chunk_gla_bwd_with_pl(
            q_f32,
            k_f32,
            v_f32,
            g=None,
            g_gamma=g_gamma.reshape(1, 1, -1, 1),
            g_cumsum=None,
            scale=scale,
            initial_state=None,
            h=None,
            A=None,
            do=do_f32,
            dht=None,
            chunk_size=chunk_size,
        )
        return ref_dq, ref_dk, ref_dv

    def _check_case(self, B, T, H, K, V, C, seed):
        """Run fused fwd+bwd for the given dims and compare to the reference."""
        q, k, v, g_gamma, do = _make_test_data(B, T, H, K, V, seed=seed)
        scale = K**-0.5

        # Get h from fused forward (the fused backward consumes it).
        h, _ = chunk_fwd_fused_g_gamma(q, k, v, g_gamma, scale, C)

        # Fused backward kernel
        dq_fused, dk_fused, dv_fused = chunk_bwd_fused_g_gamma(
            q, k, v, h, do, g_gamma, scale, C
        )

        # Non-fused reference
        ref_dq, ref_dk, ref_dv = self._run_reference(q, k, v, g_gamma, do, scale, C)

        # Loose tolerances: fused path runs matmuls in bf16/f32 mixed precision.
        for name, ref, got in (
            ("dq", ref_dq, dq_fused),
            ("dk", ref_dk, dk_fused),
            ("dv", ref_dv, dv_fused),
        ):
            assert compare_tensor(name, ref, got, atol=1e-2, rtol=1e-2, dtype=np.float32)

    def test_fused_bwd_basic(self):
        """Basic fused backward: B=2, T=256, H=4, K=128, V=128, C=64."""
        self._check_case(2, 256, 4, 128, 128, 64, seed=42)

    def test_fused_bwd_al_dims(self):
        """AL model dimensions: B=2, T=4096, H=16, K=128, V=128, C=64."""
        self._check_case(2, 4096, 16, 128, 128, 64, seed=123)


if __name__ == "__main__":
    # Allow running this test module directly (outside a pytest invocation).
    pytest.main([__file__, "-v"])
91 changes: 91 additions & 0 deletions tests/ops/gla/test_pallas_chunk_fused_fwd.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
"""Tests for the fused chunked GLA forward kernel (g_gamma mode).

Compares the output of the fused single-pallas_call forward against the
non-fused ``chunk_gla_fwd`` reference that uses three separate kernels.
"""

from __future__ import annotations

import sys
from pathlib import Path

sys.path.insert(0, str(Path(__file__).resolve().parents[3]))

import jax
import jax.numpy as jnp
import numpy as np
import pytest

from tests.utils import compare_tensor
from tops.ops.gla.chunk import chunk_gla_fwd
from tops.ops.gla.chunk_fused_kernels import chunk_fwd_fused_g_gamma


def _make_test_data(B, T, H, K, V, seed=42):
"""Create deterministic (q, k, v, g_gamma) for a GLA test case."""
key = jax.random.PRNGKey(seed)
k1, k2, k3, k4 = jax.random.split(key, 4)
q = jax.random.normal(k1, (B, T, H, K), dtype=jnp.bfloat16)
k_arr = jax.random.normal(k2, (B, T, H, K), dtype=jnp.bfloat16)
v = jax.random.normal(k3, (B, T, H, V), dtype=jnp.bfloat16)
g_gamma = -jnp.abs(jax.random.normal(k4, (H,), dtype=jnp.float32)) * 0.1
return q, k_arr, v, g_gamma


@pytest.mark.tpu_only
class TestChunkFwdFused:
    """Tests for chunk_fwd_fused_g_gamma against the non-fused reference.

    Both test cases run the same steps — fused forward, non-fused reference
    forward, elementwise comparison of (h, o) — so the shared logic lives in
    ``_check_case`` and each test supplies only its dimensions and seed.
    """

    def _run_reference(self, q, k, v, g_gamma, scale, chunk_size):
        """Run the non-fused chunk_gla_fwd to get reference (h, o).

        Inputs are upcast to float32 so the reference is not limited by
        bfloat16 precision; g_gamma is reshaped to the (1, 1, H, 1) layout
        the reference kernel expects.
        """
        q_f32 = q.astype(jnp.float32)
        k_f32 = k.astype(jnp.float32)
        v_f32 = v.astype(jnp.float32)
        _, _, h_ref, _, o_ref = chunk_gla_fwd(
            q_f32,
            k_f32,
            v_f32,
            g=None,
            g_gamma=g_gamma.reshape(1, 1, -1, 1),
            g_cumsum=None,
            scale=scale,
            initial_state=None,
            output_final_state=False,
            cu_seqlens=None,
            chunk_size=chunk_size,
        )
        return h_ref, o_ref

    def _check_case(self, B, T, H, K, V, C, seed):
        """Run the fused forward for the given dims and compare to the reference."""
        q, k, v, g_gamma = _make_test_data(B, T, H, K, V, seed=seed)
        scale = K**-0.5

        # Fused kernel
        h_fused, o_fused = chunk_fwd_fused_g_gamma(q, k, v, g_gamma, scale, C)

        # Non-fused reference
        h_ref, o_ref = self._run_reference(q, k, v, g_gamma, scale, C)

        # Loose tolerances: fused path runs matmuls in bf16/f32 mixed precision.
        assert compare_tensor("o", o_ref, o_fused, atol=1e-2, rtol=1e-2, dtype=np.float32)
        assert compare_tensor("h", h_ref, h_fused, atol=1e-2, rtol=1e-2, dtype=np.float32)

    def test_fused_fwd_basic(self):
        """Basic fused forward: B=2, T=256, H=4, K=128, V=128, C=64."""
        self._check_case(2, 256, 4, 128, 128, 64, seed=42)

    def test_fused_fwd_al_dims(self):
        """AL model dimensions: B=2, T=4096, H=16, K=128, V=128, C=64."""
        self._check_case(2, 4096, 16, 128, 128, 64, seed=123)


if __name__ == "__main__":
    # Allow running this test module directly (outside a pytest invocation).
    pytest.main([__file__, "-v"])
106 changes: 106 additions & 0 deletions tests/ops/gla/test_pallas_chunk_gla_fused_e2e.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
"""End-to-end tests for fused chunk-GLA dispatch path (g_gamma mode).

Verifies that ``chunk_gla_fwd`` and ``chunk_gla_bwd_with_pl`` correctly
dispatch to the fused kernels when running in g_gamma mode on TPU, and
that the results match the non-fused reference path.
"""

from __future__ import annotations

import sys
from pathlib import Path

sys.path.insert(0, str(Path(__file__).resolve().parents[3]))

import jax
import jax.numpy as jnp
import numpy as np
import pytest

from tests.utils import compare_tensor
from tops.ops.gla.chunk import chunk_gla_fwd, chunk_gla_bwd_with_pl


def _make_test_data(B, T, H, K, V, seed=42):
"""Create deterministic (q, k, v, g_gamma, do) for a GLA test case."""
key = jax.random.PRNGKey(seed)
k1, k2, k3, k4, k5 = jax.random.split(key, 5)
q = jax.random.normal(k1, (B, T, H, K), dtype=jnp.bfloat16)
k_arr = jax.random.normal(k2, (B, T, H, K), dtype=jnp.bfloat16)
v = jax.random.normal(k3, (B, T, H, V), dtype=jnp.bfloat16)
g_gamma = -jnp.abs(jax.random.normal(k4, (H,), dtype=jnp.float32)) * 0.1
do = jax.random.normal(k5, (B, T, H, V), dtype=jnp.bfloat16)
return q, k_arr, v, g_gamma, do


@pytest.mark.tpu_only
class TestChunkGlaFusedE2E:
    """End-to-end fused dispatch vs non-fused reference."""

    def test_fwd_dispatch_output_shape(self):
        """Forward dispatch produces correct output shapes."""
        B, T, H, K, V, C = 2, 256, 4, 128, 128, 64
        q, k, v, g_gamma, _ = _make_test_data(B, T, H, K, V)
        sm_scale = K**-0.5

        _, _, h_out, final_state, out = chunk_gla_fwd(
            q, k, v, g=None, g_gamma=g_gamma, g_cumsum=None,
            scale=sm_scale, initial_state=None, output_final_state=False,
            chunk_size=C,
        )
        # output_final_state=False, so no final state is returned.
        assert final_state is None
        assert out.shape == (B, T, H, V)
        assert h_out.shape[0] == B

    def test_fwd_dispatch_matches_reference(self):
        """Forward fused dispatch matches non-fused reference on output."""
        B, T, H, K, V, C = 2, 256, 4, 128, 128, 64
        q, k, v, g_gamma, _ = _make_test_data(B, T, H, K, V)
        sm_scale = K**-0.5

        # Run via the dispatch path (g_gamma mode, should use fused on TPU)
        _, _, h_disp, _, o_disp = chunk_gla_fwd(
            q, k, v, g=None, g_gamma=g_gamma, g_cumsum=None,
            scale=sm_scale, initial_state=None, output_final_state=False,
            chunk_size=C,
        )

        # Broadcasting g_gamma to a full per-token gate forces the
        # non-fused path, bypassing the fused dispatch.
        g_full = jnp.broadcast_to(g_gamma.reshape(1, 1, -1, 1), q.shape)
        _, _, h_ref, _, o_ref = chunk_gla_fwd(
            q, k, v, g=g_full, g_gamma=g_gamma, g_cumsum=None,
            scale=sm_scale, initial_state=None, output_final_state=False,
            chunk_size=C,
        )

        for name, ref, got in (("o", o_ref, o_disp), ("h", h_ref, h_disp)):
            assert compare_tensor(name, ref, got, atol=1e-2, rtol=1e-2, dtype=np.float32)

    def test_bwd_dispatch_matches_reference(self):
        """Backward fused dispatch matches non-fused reference gradients."""
        B, T, H, K, V, C = 2, 256, 4, 128, 128, 64
        q, k, v, g_gamma, do = _make_test_data(B, T, H, K, V)
        sm_scale = K**-0.5

        # Backward via dispatch path (g_gamma mode, should use fused on TPU)
        dq_disp, dk_disp, dv_disp, _, _ = chunk_gla_bwd_with_pl(
            q, k, v, g=None, g_gamma=g_gamma, g_cumsum=None,
            scale=sm_scale, initial_state=None, h=None, A=None,
            do=do, dht=None, chunk_size=C,
        )

        # Broadcasting g_gamma to a full per-token gate forces the
        # non-fused path, bypassing the fused dispatch.
        g_full = jnp.broadcast_to(g_gamma.reshape(1, 1, -1, 1), q.shape)
        dq_ref, dk_ref, dv_ref, _, _ = chunk_gla_bwd_with_pl(
            q, k, v, g=g_full, g_gamma=g_gamma, g_cumsum=None,
            scale=sm_scale, initial_state=None, h=None, A=None,
            do=do, dht=None, chunk_size=C,
        )

        for name, ref, got in (
            ("dq", dq_ref, dq_disp),
            ("dk", dk_ref, dk_disp),
            ("dv", dv_ref, dv_disp),
        ):
            assert compare_tensor(name, ref, got, atol=1e-2, rtol=1e-2, dtype=np.float32)


if __name__ == "__main__":
    # Allow running this test module directly (outside a pytest invocation).
    pytest.main([__file__, "-v"])
Loading
Loading