diff --git a/benchmarks/ops/benchmark_gla.py b/benchmarks/ops/benchmark_gla.py index eb23d3e0..a7747890 100644 --- a/benchmarks/ops/benchmark_gla.py +++ b/benchmarks/ops/benchmark_gla.py @@ -47,11 +47,13 @@ "fused_recurrent", "chunk", "fused_chunk", + "chunk_fused", "simple_gla_naive", "simple_gla_chunk", "fused_recurrent_bwd", "chunk_bwd", "fused_chunk_bwd", + "chunk_fused_bwd", "simple_gla_chunk_bwd", ] @@ -170,6 +172,36 @@ def loss_fn(q_, k_, v_, gk_): return jax.grad(loss_fn, argnums=(0, 1, 2, 3))(q, k, v, gk) fn = fwd_bwd + elif provider == "chunk_fused": + from tops.ops.gla.chunk_fused_kernels import chunk_fwd_fused_g_gamma + chunk_size = 64 + if T % chunk_size != 0 or D != 128: + return None + fn = partial( + chunk_fwd_fused_g_gamma, + q, k, v, g_gamma, scale=scale, chunk_size=chunk_size, + ) + elif provider == "chunk_fused_bwd": + from tops.ops.gla.chunk_fused_kernels import ( + chunk_fwd_fused_g_gamma, + chunk_bwd_fused_g_gamma, + ) + chunk_size = 64 + if T % chunk_size != 0 or D != 128: + return None + + h_fwd, o_fwd = chunk_fwd_fused_g_gamma( + q, k, v, g_gamma, scale=scale, chunk_size=chunk_size, + ) + do = jnp.ones_like(o_fwd) + + @jax.jit + def run_bwd(): + return chunk_bwd_fused_g_gamma( + q, k, v, h_fwd, do, g_gamma, scale=scale, chunk_size=chunk_size, + ) + + fn = run_bwd elif provider == "simple_gla_naive": if T > NAIVE_MAX_T: return None diff --git a/benchmarks/ops/profile_gla_memory.py b/benchmarks/ops/profile_gla_memory.py new file mode 100644 index 00000000..fe3bac04 --- /dev/null +++ b/benchmarks/ops/profile_gla_memory.py @@ -0,0 +1,211 @@ +"""Memory profiling: fused vs non-fused chunk-GLA (forward + backward). + +Analyzes HBM footprint including activation retention between fwd and bwd. 
+ +XLA copy semantics: + - transpose() / reshape(): logical view, zero copy + - broadcast_to(): logical view, zero copy (but materialized when fed to pallas) + +Usage: + python benchmarks/ops/profile_gla_memory.py + python benchmarks/ops/profile_gla_memory.py --B 2 --T 4096 --H 16 --K 128 --V 128 --C 64 +""" + +from __future__ import annotations + +import argparse + + +def _bytes(shape: tuple[int, ...], elem_bytes: int) -> int: + r = elem_bytes + for s in shape: + r *= s + return r + + +def _fmt(nbytes: int) -> str: + if nbytes < 0: + return f"-{_fmt(-nbytes)}" + if nbytes >= 1 << 30: + return f"{nbytes / (1 << 30):.2f} GiB" + if nbytes >= 1 << 20: + return f"{nbytes / (1 << 20):.2f} MiB" + if nbytes >= 1 << 10: + return f"{nbytes / (1 << 10):.2f} KiB" + return f"{nbytes} B" + + +BF16 = 2 +F32 = 4 + + +def profile(B, T, H, K, V, C): + NT = T // C + + print(f"\n{'=' * 72}") + print(f" Memory Profile: B={B}, T={T}, H={H}, K={K}, V={V}, C={C}, NT={NT}") + print(f"{'=' * 72}") + + # Tensor sizes + q_sz = _bytes((B, T, H, K), BF16) + k_sz = _bytes((B, T, H, K), BF16) + v_sz = _bytes((B, T, H, V), BF16) + do_sz = _bytes((B, T, H, V), BF16) + inputs_fwd = q_sz + k_sz + v_sz + inputs_bwd = inputs_fwd + do_sz + + h_sz = _bytes((B, NT, H, K, V), BF16) + o_sz = _bytes((B, T, H, V), BF16) + g_cumsum_sz = _bytes((B, T, H, K), F32) # note: f32 + A_sz = _bytes((B, H, NT, C, C), BF16) + dh_sz = _bytes((B, NT, H, K, V), BF16) + dq_sz = _bytes((B, T, H, K), BF16) + dk_sz = _bytes((B, T, H, K), BF16) + dv_sz = _bytes((B, T, H, V), BF16) + dg_sz = _bytes((B, T, H, K), F32) # f32, before reduce to [H] + g_gamma_sz = _bytes((H,), F32) + vmem_scratch = _bytes((128, 128), F32) + + # Pallas 5D outputs + pallas_5d = dq_sz + dk_sz + dv_sz + + # ====================================================================== + # 1. ACTIVATION RETENTION (fwd -> bwd) + # ====================================================================== + print(f"\n{'─' * 72}") + print(f" 1. 
ACTIVATION RETENTION (saved from fwd for bwd)") + print(f"{'─' * 72}") + print(f" Inputs (q, k, v) are always retained by JAX AD.") + print() + + print(f" Non-fused — retained activations:") + print(f" g_cumsum [B,T,H,K] f32 {_fmt(g_cumsum_sz):>12} needed by h_kernel, intra_gk, o_gk bwd") + print(f" h [B,NT,H,K,V] bf16 {_fmt(h_sz):>12} needed by o_gk bwd for inter-chunk grads") + print(f" A [B,H,NT,C,C] bf16 {_fmt(A_sz):>12} needed by o_gk bwd for A @ v") + nf_activations = g_cumsum_sz + h_sz + A_sz + print(f" Total activations {_fmt(nf_activations):>12}") + + print() + print(f" Fused — retained activations:") + print(f" h [B,NT,H,K,V] bf16 {_fmt(h_sz):>12} only residual needed") + f_activations = h_sz + print(f" Total activations {_fmt(f_activations):>12}") + + act_saved = nf_activations - f_activations + print(f"\n >>> Activation savings: {_fmt(act_saved)}") + print(f" g_cumsum eliminated: {_fmt(g_cumsum_sz)}") + print(f" A eliminated: {_fmt(A_sz)}") + + # ====================================================================== + # 2. FORWARD PEAK + # ====================================================================== + print(f"\n{'─' * 72}") + print(f" 2. FORWARD PEAK HBM") + print(f"{'─' * 72}") + + nf_fwd_peak = inputs_fwd + g_cumsum_sz + h_sz + A_sz + o_sz + print(f" Non-fused: inputs + g_cumsum + h + A + o") + print(f" = {_fmt(inputs_fwd)} + {_fmt(g_cumsum_sz)} + {_fmt(h_sz)} + {_fmt(A_sz)} + {_fmt(o_sz)}") + print(f" = {_fmt(nf_fwd_peak)}") + + f_fwd_peak = inputs_fwd + h_sz + o_sz + print(f" Fused: inputs + h + o") + print(f" = {_fmt(inputs_fwd)} + {_fmt(h_sz)} + {_fmt(o_sz)}") + print(f" = {_fmt(f_fwd_peak)}") + print(f" Savings: {_fmt(nf_fwd_peak - f_fwd_peak)}") + + # ====================================================================== + # 3. BACKWARD PEAK + # ====================================================================== + print(f"\n{'─' * 72}") + print(f" 3. 
BACKWARD PEAK HBM (including retained activations)") + print(f"{'─' * 72}") + + nf_bwd_peak = inputs_bwd + g_cumsum_sz + h_sz + dh_sz + dq_sz + dk_sz + dv_sz + dg_sz + print(f" Non-fused: inputs + do + g_cumsum + h + dh + dq + dk + dv + dg") + nf_bwd_parts = [ + ("inputs+do", inputs_bwd), + ("g_cumsum", g_cumsum_sz), + ("h", h_sz), + ("dh", dh_sz), + ("dq+dk+dv", dq_sz + dk_sz + dv_sz), + ("dg", dg_sz), + ] + for name, sz in nf_bwd_parts: + print(f" {name:20s} {_fmt(sz):>12}") + print(f" {'Peak':20s} {_fmt(nf_bwd_peak):>12}") + + # Fused bwd: reverse index_maps eliminate all flip copies. + # Only: inputs + do + h + pallas_5d_outputs + f_bwd_peak = inputs_bwd + h_sz + g_gamma_sz + pallas_5d + print(f"\n Fused: inputs + do + h + pallas_outputs (no flips)") + f_bwd_parts = [ + ("inputs+do", inputs_bwd), + ("h (retained)", h_sz), + ("g_gamma [H] f32", g_gamma_sz), + ("pallas 5D out (dq,dk,dv)", pallas_5d), + ] + for name, sz in f_bwd_parts: + print(f" {name:30s} {_fmt(sz):>12}") + print(f" g_cumsum, dh, dg {'ELIMINATED':>12}") + print(f" jnp.flip copies {'ELIMINATED':>12} (reverse index_map)") + print(f" VMEM scratch [BK,BV] f32 {_fmt(vmem_scratch):>12} (on-chip)") + print(f" {'Peak':30s} {_fmt(f_bwd_peak):>12}") + + bwd_saved = nf_bwd_peak - f_bwd_peak + print(f"\n >>> Backward savings: {_fmt(bwd_saved)}") + + # ====================================================================== + # 4. TRAINING TOTAL + # ====================================================================== + print(f"\n{'─' * 72}") + print(f" 4. TRAINING PEAK (activations + bwd working set)") + print(f"{'─' * 72}") + + nf_train = nf_bwd_peak + f_train = f_bwd_peak + + print(f" Non-fused training peak: {_fmt(nf_train)}") + print(f" Fused training peak: {_fmt(f_train)}") + train_saved = nf_train - f_train + print(f" Training savings: {_fmt(train_saved)}") + + # ====================================================================== + # 5. 
SUMMARY + # ====================================================================== + print(f"\n{'=' * 72}") + print(f" SUMMARY") + print(f"{'=' * 72}") + print(f" {'':32s} {'Non-fused':>12s} {'Fused':>12s} {'Saved':>12s}") + print(f" {'─' * 66}") + print(f" {'Activations (fwd->bwd)':32s} {_fmt(nf_activations):>12s} {_fmt(f_activations):>12s} {_fmt(act_saved):>12s}") + print(f" {'Forward peak':32s} {_fmt(nf_fwd_peak):>12s} {_fmt(f_fwd_peak):>12s} {_fmt(nf_fwd_peak-f_fwd_peak):>12s}") + print(f" {'Backward peak':32s} {_fmt(nf_bwd_peak):>12s} {_fmt(f_bwd_peak):>12s} {_fmt(bwd_saved):>12s}") + print(f" {'Training peak':32s} {_fmt(nf_train):>12s} {_fmt(f_train):>12s} {_fmt(train_saved):>12s}") + print(f" {'─' * 66}") + + print(f"\n What fused ELIMINATES from HBM:") + print(f" g_cumsum [B,T,H,K] f32 -{_fmt(g_cumsum_sz):>12}") + print(f" A [B,H,NT,C,C] bf16 (fwd) -{_fmt(A_sz):>12}") + print(f" dh [B,NT,H,K,V] bf16 (bwd) -{_fmt(dh_sz):>12}") + print(f" dg [B,T,H,K] f32 (bwd) -{_fmt(dg_sz):>12}") + print(f" jnp.flip copies (bwd) - 0 B (reverse index_map)") + elim_total = g_cumsum_sz + A_sz + dh_sz + dg_sz + print(f" Total eliminated -{_fmt(elim_total):>12}") + print(f"{'=' * 72}\n") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Memory profiling: fused vs non-fused chunk-GLA" + ) + parser.add_argument("--B", type=int, default=2) + parser.add_argument("--T", type=int, default=4096) + parser.add_argument("--H", type=int, default=16) + parser.add_argument("--K", type=int, default=128) + parser.add_argument("--V", type=int, default=128) + parser.add_argument("--C", type=int, default=64, help="Chunk size") + args = parser.parse_args() + + assert args.T % args.C == 0, f"T={args.T} must be multiple of C={args.C}" + profile(args.B, args.T, args.H, args.K, args.V, args.C) diff --git a/tests/ops/gla/test_pallas_chunk_fused_bwd.py b/tests/ops/gla/test_pallas_chunk_fused_bwd.py new file mode 100644 index 00000000..340bba40 --- /dev/null +++ 
b/tests/ops/gla/test_pallas_chunk_fused_bwd.py @@ -0,0 +1,114 @@ +"""Tests for the fused chunked GLA backward kernel (g_gamma mode). + +Compares gradients (dq, dk, dv) from the fused backward kernel against +gradients computed via jax.grad on the pure JAX naive recurrent reference. +""" + +from __future__ import annotations + +import sys +from pathlib import Path + +sys.path.insert(0, str(Path(__file__).resolve().parents[3])) + +import jax +import jax.numpy as jnp +import numpy as np +import pytest + +from tests.utils import compare_tensor +from tops.ops.gla.naive import naive_recurrent_gla +from tops.ops.gla.chunk_fused_kernels import ( + chunk_bwd_fused_g_gamma, + chunk_fwd_fused_g_gamma, +) + + +def _make_test_data(B, T, H, K, V, seed=42): + """Create deterministic (q, k, v, g_gamma, do) for a GLA backward test.""" + key = jax.random.PRNGKey(seed) + k1, k2, k3, k4, k5 = jax.random.split(key, 5) + q = jax.random.normal(k1, (B, T, H, K), dtype=jnp.bfloat16) + k_arr = jax.random.normal(k2, (B, T, H, K), dtype=jnp.bfloat16) + v = jax.random.normal(k3, (B, T, H, V), dtype=jnp.bfloat16) + g_gamma = -jnp.abs(jax.random.normal(k4, (H,), dtype=jnp.float32)) * 0.1 + do = jax.random.normal(k5, (B, T, H, V), dtype=jnp.bfloat16) + return q, k_arr, v, g_gamma, do + + +def _naive_grads(q, k, v, g_gamma, do, scale): + """Compute dq, dk, dv via jax.grad on naive_recurrent_gla.""" + B, T, H, K = q.shape + gk = jnp.broadcast_to(g_gamma.reshape(1, 1, -1, 1), (B, T, H, K)) + + def loss_fn(q_, k_, v_): + o, _ = naive_recurrent_gla(q_, k_, v_, gk, scale=scale) + return (o * do.astype(jnp.float32)).sum() + + dq, dk, dv = jax.grad(loss_fn, argnums=(0, 1, 2))( + q.astype(jnp.float32), k.astype(jnp.float32), v.astype(jnp.float32), + ) + return dq, dk, dv + + +@pytest.mark.tpu_only +class TestChunkBwdFused: + """Tests for chunk_bwd_fused_g_gamma against naive JAX reference gradients.""" + + def test_fused_bwd_basic(self): + """Basic fused backward: B=2, T=128, H=4, K=128, V=128, C=64.""" + 
B, T, H, K, V, C = 2, 128, 4, 128, 128, 64 + q, k, v, g_gamma, do = _make_test_data(B, T, H, K, V, seed=42) + scale = K**-0.5 + + # Forward to get h + h, _ = chunk_fwd_fused_g_gamma(q, k, v, g_gamma, scale, C) + + # Fused backward + dq_fused, dk_fused, dv_fused = chunk_bwd_fused_g_gamma( + q, k, v, h, do, g_gamma, scale, C + ) + + # Naive reference gradients + dq_ref, dk_ref, dv_ref = _naive_grads(q, k, v, g_gamma, do, scale) + + assert compare_tensor( + "dq", dq_ref, dq_fused, atol=5e-2, rtol=5e-2, dtype=np.float32 + ) + assert compare_tensor( + "dk", dk_ref, dk_fused, atol=5e-2, rtol=5e-2, dtype=np.float32 + ) + assert compare_tensor( + "dv", dv_ref, dv_fused, atol=5e-2, rtol=5e-2, dtype=np.float32 + ) + + def test_fused_bwd_small(self): + """Small case: B=1, T=64, H=2, K=128, V=128, C=64.""" + B, T, H, K, V, C = 1, 64, 2, 128, 128, 64 + q, k, v, g_gamma, do = _make_test_data(B, T, H, K, V, seed=7) + scale = K**-0.5 + + # Forward to get h + h, _ = chunk_fwd_fused_g_gamma(q, k, v, g_gamma, scale, C) + + # Fused backward + dq_fused, dk_fused, dv_fused = chunk_bwd_fused_g_gamma( + q, k, v, h, do, g_gamma, scale, C + ) + + # Naive reference gradients + dq_ref, dk_ref, dv_ref = _naive_grads(q, k, v, g_gamma, do, scale) + + assert compare_tensor( + "dq", dq_ref, dq_fused, atol=5e-2, rtol=5e-2, dtype=np.float32 + ) + assert compare_tensor( + "dk", dk_ref, dk_fused, atol=5e-2, rtol=5e-2, dtype=np.float32 + ) + assert compare_tensor( + "dv", dv_ref, dv_fused, atol=5e-2, rtol=5e-2, dtype=np.float32 + ) + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/tests/ops/gla/test_pallas_chunk_fused_fwd.py b/tests/ops/gla/test_pallas_chunk_fused_fwd.py new file mode 100644 index 00000000..55e69fd4 --- /dev/null +++ b/tests/ops/gla/test_pallas_chunk_fused_fwd.py @@ -0,0 +1,77 @@ +"""Tests for the fused chunked GLA forward kernel (g_gamma mode). 
+ +Compares the output of the fused single-pallas_call forward against the +pure JAX naive recurrent reference (no Pallas kernels involved). +""" + +from __future__ import annotations + +import sys +from pathlib import Path + +sys.path.insert(0, str(Path(__file__).resolve().parents[3])) + +import jax +import jax.numpy as jnp +import numpy as np +import pytest + +from tests.utils import compare_tensor +from tops.ops.gla.naive import naive_recurrent_gla +from tops.ops.gla.chunk_fused_kernels import chunk_fwd_fused_g_gamma + + +def _make_test_data(B, T, H, K, V, seed=42): + """Create deterministic (q, k, v, g_gamma) for a GLA test case.""" + key = jax.random.PRNGKey(seed) + k1, k2, k3, k4 = jax.random.split(key, 4) + q = jax.random.normal(k1, (B, T, H, K), dtype=jnp.bfloat16) + k_arr = jax.random.normal(k2, (B, T, H, K), dtype=jnp.bfloat16) + v = jax.random.normal(k3, (B, T, H, V), dtype=jnp.bfloat16) + g_gamma = -jnp.abs(jax.random.normal(k4, (H,), dtype=jnp.float32)) * 0.1 + return q, k_arr, v, g_gamma + + +def _run_naive_reference(q, k, v, g_gamma, scale): + """Run pure JAX naive recurrent GLA as ground truth.""" + B, T, H, K = q.shape + gk = jnp.broadcast_to(g_gamma.reshape(1, 1, -1, 1), (B, T, H, K)) + o_ref, _ = naive_recurrent_gla(q, k, v, gk, scale=scale) + return o_ref + + +@pytest.mark.tpu_only +class TestChunkFwdFused: + """Tests for chunk_fwd_fused_g_gamma against pure JAX naive reference.""" + + def test_fused_fwd_basic(self): + """Basic fused forward: B=2, T=256, H=4, K=128, V=128, C=64.""" + B, T, H, K, V, C = 2, 256, 4, 128, 128, 64 + q, k, v, g_gamma = _make_test_data(B, T, H, K, V, seed=42) + scale = K**-0.5 + + # Fused kernel + _, o_fused = chunk_fwd_fused_g_gamma(q, k, v, g_gamma, scale, C) + + # Naive reference (pure JAX, no Pallas) + o_ref = _run_naive_reference(q, k, v, g_gamma, scale) + + assert compare_tensor("o", o_ref, o_fused, atol=5e-2, rtol=5e-2, dtype=np.float32) + + def test_fused_fwd_small(self): + """Small case: B=1, T=128, H=2, 
K=128, V=128, C=64.""" + B, T, H, K, V, C = 1, 128, 2, 128, 128, 64 + q, k, v, g_gamma = _make_test_data(B, T, H, K, V, seed=7) + scale = K**-0.5 + + h_fused, o_fused = chunk_fwd_fused_g_gamma(q, k, v, g_gamma, scale, C) + + o_ref = _run_naive_reference(q, k, v, g_gamma, scale) + + assert o_fused.shape == (B, T, H, V) + assert h_fused.shape[0] == B + assert compare_tensor("o", o_ref, o_fused, atol=5e-2, rtol=5e-2, dtype=np.float32) + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/tests/ops/gla/test_pallas_chunk_gla_fused_e2e.py b/tests/ops/gla/test_pallas_chunk_gla_fused_e2e.py new file mode 100644 index 00000000..0acf550d --- /dev/null +++ b/tests/ops/gla/test_pallas_chunk_gla_fused_e2e.py @@ -0,0 +1,77 @@ +"""End-to-end tests for fused chunk-GLA dispatch path (g_gamma mode). + +Verifies that ``chunk_gla_fwd`` and ``chunk_gla_bwd_with_pl`` correctly +dispatch to the fused kernels when running in g_gamma mode on TPU, and +that the results match the pure JAX naive recurrent reference. 
+""" + +from __future__ import annotations + +import sys +from pathlib import Path + +sys.path.insert(0, str(Path(__file__).resolve().parents[3])) + +import jax +import jax.numpy as jnp +import numpy as np +import pytest + +from tests.utils import compare_tensor +from tops.ops.gla.chunk import chunk_gla_fwd, chunk_gla_bwd_with_pl +from tops.ops.gla.naive import naive_recurrent_gla + + +def _make_test_data(B, T, H, K, V, seed=42): + """Create deterministic (q, k, v, g_gamma, do) for a GLA test case.""" + key = jax.random.PRNGKey(seed) + k1, k2, k3, k4, k5 = jax.random.split(key, 5) + q = jax.random.normal(k1, (B, T, H, K), dtype=jnp.bfloat16) + k_arr = jax.random.normal(k2, (B, T, H, K), dtype=jnp.bfloat16) + v = jax.random.normal(k3, (B, T, H, V), dtype=jnp.bfloat16) + g_gamma = -jnp.abs(jax.random.normal(k4, (H,), dtype=jnp.float32)) * 0.1 + do = jax.random.normal(k5, (B, T, H, V), dtype=jnp.bfloat16) + return q, k_arr, v, g_gamma, do + + +@pytest.mark.tpu_only +class TestChunkGlaFusedE2E: + """End-to-end fused dispatch vs pure JAX naive reference.""" + + def test_fwd_dispatch_output_shape(self): + """Forward dispatch produces correct output shapes.""" + B, T, H, K, V, C = 2, 256, 4, 128, 128, 64 + q, k, v, g_gamma, _ = _make_test_data(B, T, H, K, V) + scale = K**-0.5 + + _, _, h, ht, o = chunk_gla_fwd( + q, k, v, g=None, g_gamma=g_gamma, g_cumsum=None, + scale=scale, initial_state=None, output_final_state=False, + chunk_size=C, + ) + assert o.shape == (B, T, H, V) + assert h.shape[0] == B + assert ht is None + + def test_fwd_dispatch_matches_reference(self): + """Forward fused dispatch matches naive recurrent reference on output.""" + B, T, H, K, V, C = 2, 256, 4, 128, 128, 64 + q, k, v, g_gamma, _ = _make_test_data(B, T, H, K, V) + scale = K**-0.5 + + # Run via the dispatch path (g_gamma ndim==1, should use fused on TPU) + _, _, _, _, o_disp = chunk_gla_fwd( + q, k, v, g=None, g_gamma=g_gamma, g_cumsum=None, + scale=scale, initial_state=None, 
output_final_state=False, + chunk_size=C, + ) + + # Naive recurrent reference (pure JAX, no Pallas) + gk = jnp.broadcast_to(g_gamma.reshape(1, 1, -1, 1), q.shape) + o_ref, _ = naive_recurrent_gla(q, k, v, gk, scale=scale) + + assert compare_tensor("o", o_ref, o_disp, atol=5e-2, rtol=5e-2, dtype=np.float32) + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/tops/ops/gla/chunk.py b/tops/ops/gla/chunk.py index 7a285117..4451b9ef 100644 --- a/tops/ops/gla/chunk.py +++ b/tops/ops/gla/chunk.py @@ -870,7 +870,8 @@ def chunk_gla_bwd( # Broadcast g_gamma into full g if g is not provided if g is None: if g_gamma is not None: - g = jnp.broadcast_to(g_gamma, q.shape) + gg = g_gamma.reshape(1, 1, -1, 1) if g_gamma.ndim == 1 else g_gamma + g = jnp.broadcast_to(gg, q.shape) else: g = jnp.zeros_like(q) @@ -1006,7 +1007,8 @@ def chunk_gla_bwd_with_pl( # Broadcast g_gamma into full g if g is not provided if g is None: if g_gamma is not None: - g = jnp.broadcast_to(g_gamma, q.shape) + gg = g_gamma.reshape(1, 1, -1, 1) if g_gamma.ndim == 1 else g_gamma + g = jnp.broadcast_to(gg, q.shape) else: g = jnp.zeros_like(q) @@ -1015,13 +1017,52 @@ def chunk_gla_bwd_with_pl( "cu_seqlens must be multiples of chunk_size for chunk_gla_bwd_with_pl" ) + # Fast path: fused backward for g_gamma mode (fixed-length, K/V <= 128 since the fused kernels are single-tile, on TPU) + if (g_gamma is not None and g_gamma.ndim == 1 + and g_orig is None and cu_seqlens is None + and initial_state is None and dht is None and q.shape[-1] <= 128 and v.shape[-1] <= 128 and is_tpu_runtime()): + from tops.ops.gla.chunk_fused_kernels import ( + chunk_fwd_fused_g_gamma, + chunk_bwd_fused_g_gamma, + ) + # Pad K/V dims to multiples of 128 for fused kernels + _q, _k = (pad_to_multiple(x, 128, axis=3, val=0) for x in (q, k)) + _v = pad_to_multiple(v, 128, axis=3, val=0) + _do = pad_to_multiple(do, 128, axis=3, val=0) + + if h is None: + h_fused, _ = chunk_fwd_fused_g_gamma( + _q, _k, _v, g_gamma.astype(jnp.float32), scale, C, + ) + else: + h_fused = pad_to_multiple( + pad_to_multiple(h, 128, 
axis=3, val=0), 128, axis=4, val=0, + ) + + dq, dk, dv = chunk_bwd_fused_g_gamma( + _q, _k, _v, h_fused, _do, g_gamma.astype(jnp.float32), scale, C, + ) + + # Strip K/V padding from gradients + dq = dq[..., :K] + dk = dk[..., :K] + dv = dv[..., :V] + + # dg: zeros — in g_gamma mode (ndim==1) the gate is a fixed constant, + # not a learnable parameter, so no gradient is needed. If g_gamma + # gradients are required, pass g_gamma.reshape(1, 1, H, 1) so the + # fused dispatch is skipped and the non-fused backward computes dg. + dg = jnp.zeros_like(g_gamma) + return dq, dk, dv, dg, None + # 1. Chunk-local cumsum if g_cumsum is None: if g_gamma is not None and cu_seqlens is None: _, T_pad, _, _ = q.shape pos = jnp.arange(1, C + 1, dtype=jnp.float32) pos = jnp.tile(pos, T_pad // C).reshape(1, T_pad, 1, 1) - g_cumsum = jnp.broadcast_to(g_gamma * pos, q.shape) + gg = g_gamma.reshape(1, 1, -1, 1) if g_gamma.ndim == 1 else g_gamma + g_cumsum = jnp.broadcast_to(gg * pos, q.shape) else: g_cumsum = chunk_local_cumsum_vector(g, C, cu_seqlens=cu_seqlens) @@ -1033,7 +1074,7 @@ def chunk_gla_bwd_with_pl( gk=g_cumsum, h0=initial_state, output_final_state=False, - cu_seqlens=cu_seqlens, + cu_seqlens_cpu=cu_seqlens, chunk_size=C, ) if cu_seqlens is None: @@ -1049,7 +1090,7 @@ def chunk_gla_bwd_with_pl( dht=dht, scale=scale, output_dh0=(initial_state is not None or dht is not None), - cu_seqlens=cu_seqlens, + cu_seqlens_dev=cu_seqlens, chunk_size=C, ) if cu_seqlens is None: @@ -1107,13 +1148,33 @@ def chunk_gla_fwd( V = v.shape[-1] C = chunk_size + # Fast path: fused forward for g_gamma mode (fixed-length, K/V <= 128 since the fused kernel is single-tile, on TPU) + if (g_gamma is not None and g_gamma.ndim == 1 + and g is None and cu_seqlens is None + and initial_state is None and not output_final_state + and q.shape[-1] <= 128 and v.shape[-1] <= 128 and is_tpu_runtime()): + from tops.ops.gla.chunk_fused_kernels import chunk_fwd_fused_g_gamma + _q, _k, _v = q, k, v + if T % C != 0: + _q, _k, _v = (pad_to_multiple(x, C, axis=1, val=0) for x in (_q, _k, _v)) + _q, _k = 
(pad_to_multiple(x, 128, axis=3, val=0) for x in (_q, _k)) + _v = pad_to_multiple(_v, 128, axis=3, val=0) + h_fused, o_fused = chunk_fwd_fused_g_gamma( + _q, _k, _v, g_gamma.astype(jnp.float32), scale, chunk_size=C, + ) + o_fused = o_fused[..., :V] + h_fused = h_fused[..., :K, :V] + o_fused = o_fused[:, :T] + return None, None, h_fused, None, o_fused + # --- padding --- orig_seqlens = None padded_seqlens = None if g is None: if g_gamma is not None: - g = jnp.broadcast_to(g_gamma, q.shape) + gg = g_gamma.reshape(1, 1, -1, 1) if g_gamma.ndim == 1 else g_gamma + g = jnp.broadcast_to(gg, q.shape) else: g = jnp.zeros_like(q) @@ -1140,7 +1201,8 @@ def chunk_gla_fwd( _, T_pad, _, _ = q.shape pos = jnp.arange(1, C + 1, dtype=jnp.float32) pos = jnp.tile(pos, T_pad // C).reshape(1, T_pad, 1, 1) - g_cumsum = jnp.broadcast_to(g_gamma * pos, q.shape) + gg = g_gamma.reshape(1, 1, -1, 1) if g_gamma.ndim == 1 else g_gamma + g_cumsum = jnp.broadcast_to(gg * pos, q.shape) else: g_cumsum = chunk_local_cumsum_vector(g, C, cu_seqlens=cu_seqlens) @@ -1152,7 +1214,7 @@ def chunk_gla_fwd( gk=g_cumsum, h0=initial_state, output_final_state=output_final_state, - cu_seqlens=cu_seqlens, + cu_seqlens_cpu=cu_seqlens, chunk_size=C, ) if cu_seqlens is None: diff --git a/tops/ops/gla/chunk_fused_kernels.py b/tops/ops/gla/chunk_fused_kernels.py new file mode 100644 index 00000000..fabffb8a --- /dev/null +++ b/tops/ops/gla/chunk_fused_kernels.py @@ -0,0 +1,579 @@ +"""Fused chunk-GLA Pallas TPU kernels for g_gamma (per-head constant gate) mode. + +Merges three separate pallas_calls (h propagation + A recomputation + output +computation) into a single pallas_call. This eliminates: + + 1. Two kernel launch overheads + 2. The h tensor HBM round-trip (h is produced and consumed in VMEM) + 3. The A tensor HBM round-trip (A is recomputed inline from q, k, g) + 4. 
The g_cumsum tensor entirely (gating is computed from the g_gamma scalar) + +The combined kernel uses grid (B, H, K/BK, V/BV, NT) with time as an +"arbitrary" dimension. VMEM scratch holds the h state [BK, BV] in float32 +across time steps. At each time step t: + + 1. Save current h to h_ref output (for backward residuals) + 2. Compute output: o = q_gated @ h * scale + A_masked @ v + 3. Update h: h = h * decay + k.T @ (v * gating) + +For the target shape K=128=BK, V=128=BV the grid is (B, H, 1, 1, NT). +Each tile sees the full h, so no cross-tile reduction is needed. +""" + +import functools + +import jax +import jax.lax as lax +import jax.numpy as jnp +from jax.experimental import pallas as pl +from jax.experimental.pallas import tpu as pltpu + +from tops.ops.utils import exp +from tops.utils import assert_shape, export_public + + +def _chunk_fwd_fused_kernel( + q_ref, + k_ref, + v_ref, + g_gamma, + h_ref, + o_ref, + scratch_ref, + *, + BT, + NT, + scale, +): + """Fused forward kernel: h propagation + A recomputation + output. + + At each time step t: + 1. Save h[t] to output (for backward residuals) + 2. Compute o[t] = q_gated @ h[t] * scale + A_masked @ v + where A is recomputed inline from q, k, g_gamma + 3. 
Update h[t+1] = h[t] * decay + k.T @ (v * gating) + """ + BK = k_ref.shape[3] + BV = v_ref.shape[3] + i_b, i_h, i_k, i_v, i_t = ( + pl.program_id(0), + pl.program_id(1), + pl.program_id(2), + pl.program_id(3), + pl.program_id(4), + ) + + # Per-position gating: g_gamma * (1, 2, ..., BT) + b_g_ramp = g_gamma[i_h].astype(jnp.float32) * (jnp.arange(0, BT) + 1) + # State decay for one full chunk + b_g_last = g_gamma[i_h].astype(jnp.float32) * BT + + # --- Initialize h state in scratch at t=0 --- + @pl.when(i_t == 0) + def init(): + scratch_ref[:, :] = jnp.zeros((BK, BV), dtype=jnp.float32) + + # --- Step 1: Save current h to output (for backward residuals) --- + h_ref[0, i_t, 0] = scratch_ref[...].astype(h_ref.dtype) + + # --- Step 2: Compute output o[t] --- + b_q = q_ref[0, 0] # [BT, BK] + b_k = k_ref[0, 0] # [BT, BK] + b_v = v_ref[0, 0] # [BT, BV] + + # Compute gated q and k for A recomputation (keep in f32 to avoid + # Mosaic bf16 matmul compilation issues on TPU v7x) + exp_g = exp(b_g_ramp) # [BT] + exp_neg_g = exp(-b_g_ramp) # [BT] + b_qg = (b_q.astype(jnp.float32) * exp_g[:, None]) # [BT, BK] f32 + b_kg = (b_k.astype(jnp.float32) * exp_neg_g[:, None]) # [BT, BK] f32 + + # Recompute A = b_qg @ b_kg.T * scale [BT, BT] + b_A = ( + jnp.dot( + b_qg, + b_kg.T, + precision=lax.Precision.HIGHEST, + preferred_element_type=jnp.float32, + ) + * scale + ) + + # Causal mask + m_s = jnp.arange(BT)[:, None] >= jnp.arange(BT)[None, :] + b_A_masked = jnp.where(m_s, b_A, 0.0) + + # Inter-chunk: b_qg @ h * scale [BT, BV] + b_o_inter = ( + jnp.dot( + b_qg, + scratch_ref[...], + precision=lax.Precision.HIGHEST, + preferred_element_type=jnp.float32, + ) + * scale + ) + + # Intra-chunk: A_masked @ v [BT, BV] + b_o_intra = jnp.dot( + b_A_masked, + b_v.astype(jnp.float32), + precision=lax.Precision.HIGHEST, + preferred_element_type=jnp.float32, + ) + + o_ref[0, 0] = (b_o_inter + b_o_intra).astype(o_ref.dtype) + + # --- Step 3: Update h for next time step --- + # h = h * exp(g_gamma * BT) 
+ k.T @ (v * exp(g_gamma*BT - g_ramp)) + scratch_ref[...] *= exp(b_g_last) + + # v_gated = v * exp(g_gamma*BT - g_ramp) [BT, BV] + v_gated = (b_v * exp(b_g_last - b_g_ramp)[:, None]).astype(b_v.dtype) + + # k.T @ v_gated: [BK, BT] @ [BT, BV] = [BK, BV] + scratch_ref[...] = scratch_ref[...] + jax.lax.dot( + b_k.astype(jnp.float32).T, # [BK, BT] + v_gated.astype(jnp.float32), # [BT, BV] + precision=lax.Precision.HIGHEST, + preferred_element_type=jnp.float32, + ) + + +@functools.partial(jax.jit, static_argnames=["chunk_size", "scale"]) +def chunk_fwd_fused_g_gamma( + q: jax.Array, + k: jax.Array, + v: jax.Array, + g_gamma: jax.Array, + scale: float, + chunk_size: int, +) -> tuple[jax.Array, jax.Array]: + """Fused chunked GLA forward pass for g_gamma (per-head constant gate) mode. + + Replaces the three-kernel forward pipeline (chunk_fwd_h + intra_gk + + fwd_o_gk) with a single Pallas kernel. The hidden state ``h`` stays + in VMEM scratch instead of making an HBM round-trip, and the attention + matrix ``A`` is recomputed inline rather than materialised. + + Args: + q: [B, T, H, K] queries in bfloat16. + k: [B, T, H, K] keys in bfloat16. + v: [B, T, H, V] values in bfloat16. + g_gamma: [H] per-head constant log-space gate in float32. + scale: Scaling factor, typically K**-0.5. + chunk_size: Block size along the time dimension. + + Returns: + h: [B, NT, H, K, V] hidden states at each chunk boundary (bfloat16), + retained for the backward pass. + o: [B, T, H, V] output (bfloat16). 
+ """ + BK, BV, BT = 128, 128, chunk_size + B, T, H, K_dim = q.shape + V = v.shape[-1] + NT = T // BT + + assert_shape(q, (B, T, H, K_dim), "q") + assert_shape(k, (B, T, H, K_dim), "k") + assert_shape(v, (B, T, H, V), "v") + assert_shape(g_gamma, (H,), "g_gamma") + assert T % BT == 0, f"T ({T}) must be a multiple of chunk_size ({BT})" + assert K_dim == BK, ( + f"Fused forward kernel requires K == {BK}; " + f"multi-tile reduction not implemented (got K={K_dim})" + ) + assert V == BV, ( + f"Fused forward kernel requires V == {BV}; " + f"multi-tile reduction not implemented (got V={V})" + ) + + # Layout: (B, H, T, dim) -- time axis will be "arbitrary" + q_t = jnp.transpose(q, (0, 2, 1, 3)) # [B, H, T, K] + k_t = jnp.transpose(k, (0, 2, 1, 3)) # [B, H, T, K] + v_t = jnp.transpose(v, (0, 2, 1, 3)) # [B, H, T, V] + + grid = (B, H, 1, 1, NT) + + # Index maps: all take 5 grid dims (b, h, ki, vi, t) + def q_map(b, h, ki, vi, t): + return b, h, t, ki + + def k_map(b, h, ki, vi, t): + return b, h, t, ki + + def v_map(b, h, ki, vi, t): + return b, h, t, vi + + def h_map(b, h, ki, vi, t): + return b, 0, h, ki, vi + + def o_map(b, h, ki, vi, t): + return b, h, t, vi + + h_all, o_t = pl.pallas_call( + functools.partial(_chunk_fwd_fused_kernel, BT=BT, NT=NT, scale=scale), + grid_spec=pltpu.PrefetchScalarGridSpec( + num_scalar_prefetch=0, + grid=grid, + in_specs=[ + pl.BlockSpec((1, 1, BT, BK), q_map), # q + pl.BlockSpec((1, 1, BT, BK), k_map), # k + pl.BlockSpec((1, 1, BT, BV), v_map), # v + pl.BlockSpec(memory_space=pltpu.SMEM), # g_gamma + ], + out_specs=[ + pl.BlockSpec((1, NT, 1, BK, BV), h_map), # h output + pl.BlockSpec((1, 1, BT, BV), o_map), # o output + ], + scratch_shapes=[pltpu.VMEM((BK, BV), jnp.float32)], + ), + out_shape=[ + jax.ShapeDtypeStruct((B, NT, H, K_dim, V), q.dtype), # h + jax.ShapeDtypeStruct((B, H, T, V), q.dtype), # o + ], + compiler_params=pltpu.CompilerParams( + dimension_semantics=( + "parallel", + "parallel", + "parallel", + "parallel", + 
# ============================================================
# Backward: Fused dh reverse propagation + dq/dk/dv computation
#
# Merges the separate backward kernels (dh propagation, dA, dv,
# dq/dk intra, dq/dk/dg inter) into a single pallas_call.
#
# The kernel processes chunks in REVERSE time order via reverse
# BlockSpec index_maps (t → NT-1-t) — no jnp.flip copies needed.
# VMEM scratch holds the dh state [BK, BV] in float32, accumulated
# in reverse. g_cumsum is NOT loaded — gating is recomputed from
# g_gamma scalar as BT-length vectors.
# dg is NOT computed (dead output elimination for g_gamma mode).
# ============================================================


def _chunk_bwd_fused_kernel(
    q_ref,
    k_ref,
    v_ref,
    h_ref,
    do_ref,
    g_gamma,
    dq_ref,
    dk_ref,
    dv_ref,
    scratch_ref,
    *,
    BT,
    NT,
    scale,
):
    """Fused backward kernel body with g_cumsum eliminated.

    Chunks are visited in REVERSE time order: the input BlockSpec index maps
    send grid step ``i_t`` to chunk ``NT - 1 - i_t``, so step 0 sees the last
    chunk and step ``NT - 1`` sees chunk 0. At each step the kernel:

    1. Loads the running ``dh`` state from ``scratch_ref``.
    2. Computes dq/dk/dv for the current chunk using that state.
    3. Writes them to slot ``NT - 1 - i_t`` of the output blocks, i.e. in
       forward time order.
    4. Updates the state: ``dh = dh * exp(g_gamma * BT) + q_hat.T @ do``.

    Args:
        q_ref, k_ref: [1, 1, BT, BK] input blocks.
        v_ref, do_ref: [1, 1, BT, BV] input blocks.
        h_ref: [1, 1, 1, BK, BV] forward hidden-state block for this chunk.
        g_gamma: per-head log-space gate, indexed by the head program id.
        dq_ref, dk_ref: [1, NT, 1, BT, BK] output blocks.
        dv_ref: [1, NT, 1, BT, BV] output block.
        scratch_ref: [BK, BV] float32 scratch carrying ``dh`` across steps.
        BT: chunk length.
        NT: number of chunks.
        scale: attention scale applied to q.
    """
    BK = q_ref.shape[3]
    BV = do_ref.shape[3]
    # Only the head index (gate lookup) and the scan step are needed here;
    # batch / K-tile / V-tile positions are fully handled by the BlockSpecs.
    i_h = pl.program_id(1)
    i_t = pl.program_id(4)

    # All matmuls run in f32 at highest precision (kept in f32 to avoid
    # Mosaic bf16 matmul compilation issues on TPU v7x).
    dot_f32 = functools.partial(
        jnp.dot,
        precision=lax.Precision.HIGHEST,
        preferred_element_type=jnp.float32,
    )

    # Recompute gating from the per-head g_gamma scalar.
    gamma = g_gamma[i_h].astype(jnp.float32)
    b_g_ramp = gamma * (jnp.arange(0, BT) + 1)  # [BT] within-chunk cumulative gate
    b_g_last = gamma * BT  # scalar: gate accumulated over one full chunk

    # Zero the dh state on the first step of the REVERSE scan, i.e. at the
    # END of the original sequence (the index maps deliver the last chunk
    # first; nothing is physically flipped).
    @pl.when(i_t == 0)
    def _reset_dh():
        scratch_ref[:, :] = jnp.zeros((BK, BV), dtype=jnp.float32)

    # Load this step's tiles and upcast once; every use below is in f32.
    b_q = q_ref[0, 0].astype(jnp.float32)  # [BT, BK]
    b_k = k_ref[0, 0].astype(jnp.float32)  # [BT, BK]
    b_v = v_ref[0, 0].astype(jnp.float32)  # [BT, BV]
    b_do = do_ref[0, 0].astype(jnp.float32)  # [BT, BV]
    b_h = h_ref[0, 0, 0].astype(jnp.float32)  # [BK, BV]

    # dh accumulated so far by the reverse scan.
    b_dh = scratch_ref[...]  # [BK, BV]

    # Phase 0 (VPU): pre-compute all exp/gate values once.
    exp_pos = exp(b_g_ramp)  # [BT]
    exp_neg = exp(-b_g_ramp)  # [BT]
    exp_gn_minus = exp(b_g_last - b_g_ramp)  # [BT]

    k_neg = b_k * exp_neg[:, None]  # [BT, K]
    k_decay = b_k * exp_gn_minus[:, None]  # [BT, K]
    q_pos = b_q * exp_pos[:, None]  # [BT, K]

    # Phase 1 (MXU): recompute A and compute dA.
    b_a = dot_f32(q_pos, k_neg.T) * scale  # [BT, BT]
    b_dA_raw = dot_f32(b_do, b_v.T) * scale  # [BT, BT]

    # Phase 2 (VPU): causal (lower-triangular, diagonal included) masks.
    mask = jnp.arange(BT)[:, None] >= jnp.arange(BT)[None, :]
    b_dA = jnp.where(mask, b_dA_raw, 0.0)
    b_a_masked = jnp.where(mask, b_a, 0.0)

    # Phase 3 (MXU): four independent products.
    b_dv_intra = dot_f32(b_a_masked.T, b_do)  # [BT, V]
    b_dv_inter = dot_f32(k_decay, b_dh)  # [BT, V]
    b_dq_inter = dot_f32(b_do, b_h.T)  # [BT, K]
    b_dk_inter = dot_f32(b_v, b_dh.T)  # [BT, K]

    # Phase 4 (MXU): intra-chunk dq and dk.
    b_dq_intra_raw = dot_f32(b_dA, k_neg)  # [BT, K]
    b_dk_intra_raw = dot_f32(b_dA.T, q_pos)  # [BT, K]

    # Phase 5 (VPU): combine and store. Writing at NT-1-i_t puts results in
    # forward time order, so no post-kernel jnp.flip is required.
    i_out = NT - 1 - i_t
    dv_ref[0, i_out, 0] = (b_dv_intra + b_dv_inter).astype(dv_ref.dtype)

    b_dq = b_dq_intra_raw * exp_pos[:, None] + b_dq_inter * (scale * exp_pos[:, None])
    dq_ref[0, i_out, 0] = b_dq.astype(dq_ref.dtype)

    b_dk = b_dk_intra_raw * exp_neg[:, None] + b_dk_inter * exp_gn_minus[:, None]
    dk_ref[0, i_out, 0] = b_dk.astype(dk_ref.dtype)

    # Phase 6: advance the dh state for the next (earlier) chunk:
    #   dh = dh * exp(g_gamma * BT) + q_hat.T @ do,  q_hat = q * exp(g) * scale
    # q_pos already holds q * exp(g), so the ramp exponential is not
    # recomputed, and decay + accumulate are folded into a single store.
    q_hat = q_pos * scale  # [BT, BK]
    scratch_ref[...] = scratch_ref[...] * exp(b_g_last) + dot_f32(q_hat.T, b_do)


@functools.partial(jax.jit, static_argnames=["chunk_size", "scale"])
def chunk_bwd_fused_g_gamma(
    q: jax.Array,
    k: jax.Array,
    v: jax.Array,
    h: jax.Array,
    do: jax.Array,
    g_gamma: jax.Array,
    scale: float,
    chunk_size: int,
) -> tuple[jax.Array, jax.Array, jax.Array]:
    """Fused chunked GLA backward pass for g_gamma (per-head constant gate) mode.

    Replaces the multi-kernel backward pipeline (dh propagation + dA + dv +
    dq/dk intra + dq/dk/dg inter) with a single Pallas kernel. The hidden
    state gradient ``dh`` stays in VMEM scratch instead of making an HBM
    round-trip, and gating is recomputed from the ``g_gamma`` scalar.

    The kernel processes chunks in reverse time order via reverse BlockSpec
    index_maps (``t → NT-1-t``), so no ``jnp.flip`` of the inputs is issued.
    Outputs are written to their forward time positions inside the kernel,
    so no post-processing flip is needed either. ``dg`` is not produced.

    Args:
        q: [B, T, H, K] queries.
        k: [B, T, H, K] keys.
        v: [B, T, H, V] values.
        h: [B, NT, H, K, V] hidden states from the forward pass, with
            ``NT = T // chunk_size``.
        do: [B, T, H, V] upstream output gradients.
        g_gamma: [H] per-head constant log-space gate.
        scale: Scaling factor, typically K**-0.5.
        chunk_size: Block size along the time dimension; must divide T.

    Returns:
        dq: [B, T, H, K] query gradients, in ``q.dtype``.
        dk: [B, T, H, K] key gradients, in ``k.dtype``.
        dv: [B, T, H, V] value gradients, in ``v.dtype``.

    Raises:
        AssertionError: if T is not a multiple of ``chunk_size``, or if K or V
            differs from the fixed 128-wide tile (multi-tile reduction is not
            implemented). Shape mismatches are reported by ``assert_shape``.
    """
    BK, BV, BT = 128, 128, chunk_size
    B, T, H, K = q.shape
    V = v.shape[-1]

    assert_shape(q, (B, T, H, K), "q")
    assert_shape(k, (B, T, H, K), "k")
    assert_shape(v, (B, T, H, V), "v")
    # Check divisibility before anything that depends on NT: with a ragged T
    # the expected h shape below would be meaningless and the resulting
    # error would point at h instead of at chunk_size.
    assert T % BT == 0, f"T ({T}) must be a multiple of chunk_size ({BT})"
    NT = T // BT
    assert_shape(h, (B, NT, H, K, V), "h")
    assert_shape(do, (B, T, H, V), "do")
    assert_shape(g_gamma, (H,), "g_gamma")
    assert K == BK, (
        f"Fused backward kernel requires K == {BK}; "
        f"multi-tile reduction not implemented (got K={K})"
    )
    assert V == BV, (
        f"Fused backward kernel requires V == {BV}; "
        f"multi-tile reduction not implemented (got V={V})"
    )

    # Move heads ahead of time: (B, T, H, dim) -> (B, H, T, dim).
    # NOTE(review): whether XLA materializes these transposes before the
    # custom call is not established here — confirm with a profile.
    def _heads_first(x):
        return jnp.transpose(x, (0, 2, 1, 3))

    q_t = _heads_first(q)  # [B, H, T, K]
    k_t = _heads_first(k)  # [B, H, T, K]
    v_t = _heads_first(v)  # [B, H, T, V]
    do_t = _heads_first(do)  # [B, H, T, V]

    # h is [B, NT, H, K, V]; the kernel expects [B, H, NT, K, V].
    h_bhntKV = jnp.transpose(h, (0, 2, 1, 3, 4))

    grid = (B, H, 1, 1, NT)

    # Reverse input index maps: t → NT-1-t fetches chunks in reverse time
    # order, so the reversal happens at tile-fetch time (no jnp.flip).
    # All maps take the 5 grid indices (b, h, ki, vi, t).
    def qk_map(b, h, ki, vi, t):
        return b, h, NT - 1 - t, ki

    def v_map(b, h, ki, vi, t):
        return b, h, NT - 1 - t, vi

    def h_map(b, h, ki, vi, t):
        return b, h, NT - 1 - t, 0, 0

    # Outputs are 5D [B, NT, H, BT, K/V] and each block spans the full NT
    # axis: the block index does not depend on t, and the kernel picks the
    # slot NT-1-i_t itself.
    def out_map(b, h, ki, vi, t):
        return b, 0, h, 0, 0

    dq_5d, dk_5d, dv_5d = pl.pallas_call(
        functools.partial(
            _chunk_bwd_fused_kernel,
            BT=BT,
            NT=NT,
            scale=scale,
        ),
        grid_spec=pltpu.PrefetchScalarGridSpec(
            num_scalar_prefetch=0,
            grid=grid,
            in_specs=[
                pl.BlockSpec((1, 1, BT, BK), qk_map),  # q: 4D
                pl.BlockSpec((1, 1, BT, BK), qk_map),  # k: 4D
                pl.BlockSpec((1, 1, BT, BV), v_map),  # v: 4D
                pl.BlockSpec((1, 1, 1, BK, BV), h_map),  # h: 5D
                pl.BlockSpec((1, 1, BT, BV), v_map),  # do: 4D
                pl.BlockSpec(memory_space=pltpu.SMEM),  # g_gamma
            ],
            out_specs=[
                pl.BlockSpec((1, NT, 1, BT, BK), out_map),  # dq: 5D
                pl.BlockSpec((1, NT, 1, BT, BK), out_map),  # dk: 5D
                pl.BlockSpec((1, NT, 1, BT, BV), out_map),  # dv: 5D
            ],
            scratch_shapes=[pltpu.VMEM((BK, BV), jnp.float32)],
        ),
        out_shape=[
            jax.ShapeDtypeStruct((B, NT, H, BT, K), q.dtype),  # dq
            jax.ShapeDtypeStruct((B, NT, H, BT, K), k.dtype),  # dk
            jax.ShapeDtypeStruct((B, NT, H, BT, V), v.dtype),  # dv
        ],
        compiler_params=pltpu.CompilerParams(
            # The time axis carries the dh state in scratch, so it must run
            # sequentially ("arbitrary"); the other axes are independent.
            dimension_semantics=(
                "parallel",
                "parallel",
                "parallel",
                "parallel",
                "arbitrary",
            ),
            disable_bounds_checks=True,
        ),
    )(q_t, k_t, v_t, h_bhntKV, do_t, g_gamma)

    # [B, NT, H, BT, D] (already in forward time order) -> [B, T, H, D].
    def _merge_chunks(x_5d, dim):
        return x_5d.transpose(0, 1, 3, 2, 4).reshape(B, T, H, dim)

    dq = _merge_chunks(dq_5d, K)
    dk = _merge_chunks(dk_5d, K)
    dv = _merge_chunks(dv_5d, V)

    return dq, dk, dv


__all__ = export_public(globals())