Commits
16 commits
76b5444
feat(gla): add fused forward kernel for g_gamma mode
sii-xinglong Mar 31, 2026
71a8429
feat(gla): add fused backward kernel for g_gamma mode
sii-xinglong Mar 31, 2026
22c6c54
style(chunk_fused_kernels): use exp() utility consistently in backwar…
sii-xinglong Mar 31, 2026
05c6726
feat(gla): dispatch to fused kernels for g_gamma mode on TPU
sii-xinglong Mar 31, 2026
d8d3523
bench(gla): add chunk_fused and chunk_fused_bwd benchmark providers
sii-xinglong Mar 31, 2026
3a88c91
fix(chunk_fused_kernels): cast float32 scratch to h_ref dtype on save
sii-xinglong Mar 31, 2026
3e0372b
fix(chunk_fused_kernels): make scale a static JIT argument
sii-xinglong Mar 31, 2026
572ed36
fix(chunk_fused_kernels): use parallel semantics for K/V grid dims
sii-xinglong Mar 31, 2026
2b71a64
fix(chunk_fused_kernels): use f32 inputs for all matmuls in Pallas ke…
sii-xinglong Mar 31, 2026
1247d97
fix(gla): correct cu_seqlens parameter names in chunk kernel calls
sii-xinglong Mar 31, 2026
66c19d8
fix(tests): use bfloat16 inputs for both fused and reference paths
sii-xinglong Mar 31, 2026
f15840f
fix(tests): use naive recurrent reference instead of non-fused Pallas
sii-xinglong Mar 31, 2026
7ffeae2
fix: scale g_gamma in E2E test to prevent numerical overflow
sii-xinglong Mar 31, 2026
81458e0
fix: address PR #122 review comments
sii-xinglong Mar 31, 2026
7e9d63f
perf(chunk_gla): eliminate jnp.flip copies in fused backward via reve…
sii-xinglong Mar 31, 2026
1de5415
bench(gla): add memory profiling script for fused vs non-fused
sii-xinglong Mar 31, 2026
32 changes: 32 additions & 0 deletions benchmarks/ops/benchmark_gla.py
@@ -47,11 +47,13 @@
"fused_recurrent",
"chunk",
"fused_chunk",
"chunk_fused",
"simple_gla_naive",
"simple_gla_chunk",
"fused_recurrent_bwd",
"chunk_bwd",
"fused_chunk_bwd",
"chunk_fused_bwd",
"simple_gla_chunk_bwd",
]

@@ -170,6 +172,36 @@ def loss_fn(q_, k_, v_, gk_):
return jax.grad(loss_fn, argnums=(0, 1, 2, 3))(q, k, v, gk)

fn = fwd_bwd
elif provider == "chunk_fused":
from tops.ops.gla.chunk_fused_kernels import chunk_fwd_fused_g_gamma
chunk_size = 64
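# The fused g_gamma kernels assume a chunk-aligned sequence length and
# 128-lane-aligned head dims; skip configurations that do not satisfy this.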
if T % chunk_size != 0 or D % 128 != 0:
return None
fn = partial(
chunk_fwd_fused_g_gamma,
q, k, v, g_gamma, scale=scale, chunk_size=chunk_size,
)
elif provider == "chunk_fused_bwd":
from tops.ops.gla.chunk_fused_kernels import (
chunk_fwd_fused_g_gamma,
chunk_bwd_fused_g_gamma,
)
chunk_size = 64
if T % chunk_size != 0 or D % 128 != 0:
return None

h_fwd, o_fwd = chunk_fwd_fused_g_gamma(
q, k, v, g_gamma, scale=scale, chunk_size=chunk_size,
)
do = jnp.ones_like(o_fwd)

@jax.jit
def run_bwd():
return chunk_bwd_fused_g_gamma(
q, k, v, h_fwd, do, g_gamma, scale=scale, chunk_size=chunk_size,
)

fn = run_bwd
elif provider == "simple_gla_naive":
if T > NAIVE_MAX_T:
return None
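For reference, a minimal sketch of the call pattern the two new providers exercise, assembled from the signatures visible in this diff and from the tests added below (shapes, dtypes, and the returned (h, o) / (dq, dk, dv) tuples are taken from those files; anything beyond that is an assumption, and since the kernels are dispatched for TPU this presumably needs a TPU backend):

import jax
import jax.numpy as jnp
from tops.ops.gla.chunk_fused_kernels import (
    chunk_fwd_fused_g_gamma,
    chunk_bwd_fused_g_gamma,
)

# T must be a multiple of chunk_size; head dims are kept 128-aligned.
B, T, H, K, V, C = 2, 128, 4, 128, 128, 64
kq, kk, kv, kg = jax.random.split(jax.random.PRNGKey(0), 4)
q = jax.random.normal(kq, (B, T, H, K), dtype=jnp.bfloat16)
k = jax.random.normal(kk, (B, T, H, K), dtype=jnp.bfloat16)
v = jax.random.normal(kv, (B, T, H, V), dtype=jnp.bfloat16)
g_gamma = -jnp.abs(jax.random.normal(kg, (H,), dtype=jnp.float32)) * 0.1
scale = K ** -0.5

h, o = chunk_fwd_fused_g_gamma(q, k, v, g_gamma, scale=scale, chunk_size=C)
do = jnp.ones_like(o)
dq, dk, dv = chunk_bwd_fused_g_gamma(
    q, k, v, h, do, g_gamma, scale=scale, chunk_size=C
)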
211 changes: 211 additions & 0 deletions benchmarks/ops/profile_gla_memory.py
@@ -0,0 +1,211 @@
"""Memory profiling: fused vs non-fused chunk-GLA (forward + backward).

Analyzes HBM footprint including activation retention between fwd and bwd.

XLA copy semantics:
- transpose() / reshape(): logical view, zero copy
- broadcast_to(): logical view, zero copy (but materialized when fed to pallas; see the illustration below this docstring)

Usage:
python benchmarks/ops/profile_gla_memory.py
python benchmarks/ops/profile_gla_memory.py --B 2 --T 4096 --H 16 --K 128 --V 128 --C 64
"""

from __future__ import annotations

import argparse


def _bytes(shape: tuple[int, ...], elem_bytes: int) -> int:
r = elem_bytes
for s in shape:
r *= s
return r


def _fmt(nbytes: int) -> str:
if nbytes < 0:
return f"-{_fmt(-nbytes)}"
if nbytes >= 1 << 30:
return f"{nbytes / (1 << 30):.2f} GiB"
if nbytes >= 1 << 20:
return f"{nbytes / (1 << 20):.2f} MiB"
if nbytes >= 1 << 10:
return f"{nbytes / (1 << 10):.2f} KiB"
return f"{nbytes} B"


BF16 = 2
F32 = 4


def profile(B, T, H, K, V, C):
NT = T // C

print(f"\n{'=' * 72}")
print(f" Memory Profile: B={B}, T={T}, H={H}, K={K}, V={V}, C={C}, NT={NT}")
print(f"{'=' * 72}")

# Tensor sizes
q_sz = _bytes((B, T, H, K), BF16)
k_sz = _bytes((B, T, H, K), BF16)
v_sz = _bytes((B, T, H, V), BF16)
do_sz = _bytes((B, T, H, V), BF16)
inputs_fwd = q_sz + k_sz + v_sz
inputs_bwd = inputs_fwd + do_sz

h_sz = _bytes((B, NT, H, K, V), BF16)
o_sz = _bytes((B, T, H, V), BF16)
g_cumsum_sz = _bytes((B, T, H, K), F32) # note: f32
A_sz = _bytes((B, H, NT, C, C), BF16)
dh_sz = _bytes((B, NT, H, K, V), BF16)
dq_sz = _bytes((B, T, H, K), BF16)
dk_sz = _bytes((B, T, H, K), BF16)
dv_sz = _bytes((B, T, H, V), BF16)
dg_sz = _bytes((B, T, H, K), F32) # f32, before reduce to [H]
g_gamma_sz = _bytes((H,), F32)
vmem_scratch = _bytes((128, 128), F32)

# Pallas 5D outputs
pallas_5d = dq_sz + dk_sz + dv_sz

# ======================================================================
# 1. ACTIVATION RETENTION (fwd -> bwd)
# ======================================================================
print(f"\n{'─' * 72}")
print(f" 1. ACTIVATION RETENTION (saved from fwd for bwd)")
print(f"{'─' * 72}")
print(f" Inputs (q, k, v) are always retained by JAX AD.")
print()

print(f" Non-fused — retained activations:")
print(f" g_cumsum [B,T,H,K] f32 {_fmt(g_cumsum_sz):>12} needed by h_kernel, intra_gk, o_gk bwd")
print(f" h [B,NT,H,K,V] bf16 {_fmt(h_sz):>12} needed by o_gk bwd for inter-chunk grads")
print(f" A [B,H,NT,C,C] bf16 {_fmt(A_sz):>12} needed by o_gk bwd for A @ v")
nf_activations = g_cumsum_sz + h_sz + A_sz
print(f" Total activations {_fmt(nf_activations):>12}")

print()
print(f" Fused — retained activations:")
print(f" h [B,NT,H,K,V] bf16 {_fmt(h_sz):>12} only residual needed")
f_activations = h_sz
print(f" Total activations {_fmt(f_activations):>12}")

act_saved = nf_activations - f_activations
print(f"\n >>> Activation savings: {_fmt(act_saved)}")
print(f" g_cumsum eliminated: {_fmt(g_cumsum_sz)}")
print(f" A eliminated: {_fmt(A_sz)}")

# ======================================================================
# 2. FORWARD PEAK
# ======================================================================
print(f"\n{'─' * 72}")
print(f" 2. FORWARD PEAK HBM")
print(f"{'─' * 72}")

nf_fwd_peak = inputs_fwd + g_cumsum_sz + h_sz + A_sz + o_sz
print(f" Non-fused: inputs + g_cumsum + h + A + o")
print(f" = {_fmt(inputs_fwd)} + {_fmt(g_cumsum_sz)} + {_fmt(h_sz)} + {_fmt(A_sz)} + {_fmt(o_sz)}")
print(f" = {_fmt(nf_fwd_peak)}")

f_fwd_peak = inputs_fwd + h_sz + o_sz
print(f" Fused: inputs + h + o")
print(f" = {_fmt(inputs_fwd)} + {_fmt(h_sz)} + {_fmt(o_sz)}")
print(f" = {_fmt(f_fwd_peak)}")
print(f" Savings: {_fmt(nf_fwd_peak - f_fwd_peak)}")

# ======================================================================
# 3. BACKWARD PEAK
# ======================================================================
print(f"\n{'─' * 72}")
print(f" 3. BACKWARD PEAK HBM (including retained activations)")
print(f"{'─' * 72}")

nf_bwd_peak = inputs_bwd + g_cumsum_sz + h_sz + dh_sz + dq_sz + dk_sz + dv_sz + dg_sz
print(f" Non-fused: inputs + do + g_cumsum + h + dh + dq + dk + dv + dg")
nf_bwd_parts = [
("inputs+do", inputs_bwd),
("g_cumsum", g_cumsum_sz),
("h", h_sz),
("dh", dh_sz),
("dq+dk+dv", dq_sz + dk_sz + dv_sz),
("dg", dg_sz),
]
for name, sz in nf_bwd_parts:
print(f" {name:20s} {_fmt(sz):>12}")
print(f" {'Peak':20s} {_fmt(nf_bwd_peak):>12}")

# Fused bwd: reverse index_maps eliminate all flip copies.
# Only: inputs + do + h + pallas_5d_outputs
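# (Illustrative only, not taken from the kernel source: a "reverse" BlockSpec
# index_map walks chunk blocks back-to-front, e.g.
#   index_map=lambda b, h, i_t: (b, NT - 1 - i_t, h, 0, 0)
# so the backward pass streams h / do in reverse order without materializing
# jnp.flip copies; the grid-axis order here is an assumption.)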
f_bwd_peak = inputs_bwd + h_sz + g_gamma_sz + pallas_5d
print(f"\n Fused: inputs + do + h + pallas_outputs (no flips)")
f_bwd_parts = [
("inputs+do", inputs_bwd),
("h (retained)", h_sz),
("g_gamma [H] f32", g_gamma_sz),
("pallas 5D out (dq,dk,dv)", pallas_5d),
]
for name, sz in f_bwd_parts:
print(f" {name:30s} {_fmt(sz):>12}")
print(f" g_cumsum, dh, dg {'ELIMINATED':>12}")
print(f" jnp.flip copies {'ELIMINATED':>12} (reverse index_map)")
print(f" VMEM scratch [BK,BV] f32 {_fmt(vmem_scratch):>12} (on-chip)")
print(f" {'Peak':30s} {_fmt(f_bwd_peak):>12}")

bwd_saved = nf_bwd_peak - f_bwd_peak
print(f"\n >>> Backward savings: {_fmt(bwd_saved)}")

# ======================================================================
# 4. TRAINING TOTAL
# ======================================================================
print(f"\n{'─' * 72}")
print(f" 4. TRAINING PEAK (activations + bwd working set)")
print(f"{'─' * 72}")

nf_train = nf_bwd_peak
f_train = f_bwd_peak

print(f" Non-fused training peak: {_fmt(nf_train)}")
print(f" Fused training peak: {_fmt(f_train)}")
train_saved = nf_train - f_train
print(f" Training savings: {_fmt(train_saved)}")

# ======================================================================
# 5. SUMMARY
# ======================================================================
print(f"\n{'=' * 72}")
print(f" SUMMARY")
print(f"{'=' * 72}")
print(f" {'':32s} {'Non-fused':>12s} {'Fused':>12s} {'Saved':>12s}")
print(f" {'─' * 66}")
print(f" {'Activations (fwd->bwd)':32s} {_fmt(nf_activations):>12s} {_fmt(f_activations):>12s} {_fmt(act_saved):>12s}")
print(f" {'Forward peak':32s} {_fmt(nf_fwd_peak):>12s} {_fmt(f_fwd_peak):>12s} {_fmt(nf_fwd_peak-f_fwd_peak):>12s}")
print(f" {'Backward peak':32s} {_fmt(nf_bwd_peak):>12s} {_fmt(f_bwd_peak):>12s} {_fmt(bwd_saved):>12s}")
print(f" {'Training peak':32s} {_fmt(nf_train):>12s} {_fmt(f_train):>12s} {_fmt(train_saved):>12s}")
print(f" {'─' * 66}")

print(f"\n What fused ELIMINATES from HBM:")
print(f" g_cumsum [B,T,H,K] f32 -{_fmt(g_cumsum_sz):>12}")
print(f" A [B,H,NT,C,C] bf16 (fwd) -{_fmt(A_sz):>12}")
print(f" dh [B,NT,H,K,V] bf16 (bwd) -{_fmt(dh_sz):>12}")
print(f" dg [B,T,H,K] f32 (bwd) -{_fmt(dg_sz):>12}")
print(f" jnp.flip copies (bwd) - 0 B (reverse index_map)")
elim_total = g_cumsum_sz + A_sz + dh_sz + dg_sz
print(f" Total eliminated -{_fmt(elim_total):>12}")
print(f"{'=' * 72}\n")


if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Memory profiling: fused vs non-fused chunk-GLA"
)
parser.add_argument("--B", type=int, default=2)
parser.add_argument("--T", type=int, default=4096)
parser.add_argument("--H", type=int, default=16)
parser.add_argument("--K", type=int, default=128)
parser.add_argument("--V", type=int, default=128)
parser.add_argument("--C", type=int, default=64, help="Chunk size")
args = parser.parse_args()

assert args.T % args.C == 0, f"T={args.T} must be multiple of C={args.C}"
profile(args.B, args.T, args.H, args.K, args.V, args.C)
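As a quick sanity check on the accounting above, the dominant terms for the default configuration (B=2, T=4096, H=16, K=128, V=128, C=64, hence NT=64) can be reproduced with a standalone back-of-the-envelope snippet (not part of the script):

B, T, H, K, V, C = 2, 4096, 16, 128, 128, 64
NT = T // C                              # 64 chunks
g_cumsum = B * T * H * K * 4             # f32  -> 64 MiB
h        = B * NT * H * K * V * 2        # bf16 -> 64 MiB
A        = B * H * NT * C * C * 2        # bf16 -> 16 MiB
print(g_cumsum >> 20, h >> 20, A >> 20)  # 64 64 16
# The fused path retains only h between fwd and bwd, so the activation
# savings for this shape are g_cumsum + A = 80 MiB.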
114 changes: 114 additions & 0 deletions tests/ops/gla/test_pallas_chunk_fused_bwd.py
@@ -0,0 +1,114 @@
"""Tests for the fused chunked GLA backward kernel (g_gamma mode).

Compares gradients (dq, dk, dv) from the fused backward kernel against
gradients computed via jax.grad on the pure JAX naive recurrent reference.
"""

from __future__ import annotations

import sys
from pathlib import Path

sys.path.insert(0, str(Path(__file__).resolve().parents[3]))

import jax
import jax.numpy as jnp
import numpy as np
import pytest

from tests.utils import compare_tensor
from tops.ops.gla.naive import naive_recurrent_gla
from tops.ops.gla.chunk_fused_kernels import (
chunk_bwd_fused_g_gamma,
chunk_fwd_fused_g_gamma,
)


def _make_test_data(B, T, H, K, V, seed=42):
"""Create deterministic (q, k, v, g_gamma, do) for a GLA backward test."""
key = jax.random.PRNGKey(seed)
k1, k2, k3, k4, k5 = jax.random.split(key, 5)
q = jax.random.normal(k1, (B, T, H, K), dtype=jnp.bfloat16)
k_arr = jax.random.normal(k2, (B, T, H, K), dtype=jnp.bfloat16)
v = jax.random.normal(k3, (B, T, H, V), dtype=jnp.bfloat16)
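# Negative per-head log-decay rates, scaled by 0.1 to keep the chunked
# cumulative decay numerically well-behaved (cf. the "scale g_gamma ... to
# prevent numerical overflow" commit in this PR).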
g_gamma = -jnp.abs(jax.random.normal(k4, (H,), dtype=jnp.float32)) * 0.1
do = jax.random.normal(k5, (B, T, H, V), dtype=jnp.bfloat16)
return q, k_arr, v, g_gamma, do


def _naive_grads(q, k, v, g_gamma, do, scale):
"""Compute dq, dk, dv via jax.grad on naive_recurrent_gla."""
B, T, H, K = q.shape
gk = jnp.broadcast_to(g_gamma.reshape(1, 1, -1, 1), (B, T, H, K))

def loss_fn(q_, k_, v_):
o, _ = naive_recurrent_gla(q_, k_, v_, gk, scale=scale)
return (o * do.astype(jnp.float32)).sum()

dq, dk, dv = jax.grad(loss_fn, argnums=(0, 1, 2))(
q.astype(jnp.float32), k.astype(jnp.float32), v.astype(jnp.float32),
)
return dq, dk, dv


@pytest.mark.tpu_only
class TestChunkBwdFused:
"""Tests for chunk_bwd_fused_g_gamma against naive JAX reference gradients."""

def test_fused_bwd_basic(self):
"""Basic fused backward: B=2, T=128, H=4, K=128, V=128, C=64."""
B, T, H, K, V, C = 2, 128, 4, 128, 128, 64
q, k, v, g_gamma, do = _make_test_data(B, T, H, K, V, seed=42)
scale = K**-0.5

# Forward to get h
h, _ = chunk_fwd_fused_g_gamma(q, k, v, g_gamma, scale, C)

# Fused backward
dq_fused, dk_fused, dv_fused = chunk_bwd_fused_g_gamma(
q, k, v, h, do, g_gamma, scale, C
)

# Naive reference gradients
dq_ref, dk_ref, dv_ref = _naive_grads(q, k, v, g_gamma, do, scale)

assert compare_tensor(
"dq", dq_ref, dq_fused, atol=5e-2, rtol=5e-2, dtype=np.float32
)
assert compare_tensor(
"dk", dk_ref, dk_fused, atol=5e-2, rtol=5e-2, dtype=np.float32
)
assert compare_tensor(
"dv", dv_ref, dv_fused, atol=5e-2, rtol=5e-2, dtype=np.float32
)

def test_fused_bwd_small(self):
"""Small case: B=1, T=64, H=2, K=128, V=128, C=64."""
B, T, H, K, V, C = 1, 64, 2, 128, 128, 64
q, k, v, g_gamma, do = _make_test_data(B, T, H, K, V, seed=7)
scale = K**-0.5

# Forward to get h
h, _ = chunk_fwd_fused_g_gamma(q, k, v, g_gamma, scale, C)

# Fused backward
dq_fused, dk_fused, dv_fused = chunk_bwd_fused_g_gamma(
q, k, v, h, do, g_gamma, scale, C
)

# Naive reference gradients
dq_ref, dk_ref, dv_ref = _naive_grads(q, k, v, g_gamma, do, scale)

assert compare_tensor(
"dq", dq_ref, dq_fused, atol=5e-2, rtol=5e-2, dtype=np.float32
)
assert compare_tensor(
"dk", dk_ref, dk_fused, atol=5e-2, rtol=5e-2, dtype=np.float32
)
assert compare_tensor(
"dv", dv_ref, dv_fused, atol=5e-2, rtol=5e-2, dtype=np.float32
)


if __name__ == "__main__":
pytest.main([__file__, "-v"])