diff --git a/benchmarks/ops/benchmark_gla.py b/benchmarks/ops/benchmark_gla.py index 9f7610a8..40ecbb00 100644 --- a/benchmarks/ops/benchmark_gla.py +++ b/benchmarks/ops/benchmark_gla.py @@ -36,6 +36,7 @@ from tops.ops.simple_gla import ( chunk_simple_gla_bwd, chunk_simple_gla_fwd, + chunk_simple_gla_fwd_varlen, simple_gla_naive, fused_chunk_simple_gla_fwd, fused_chunk_simple_gla_bwd, @@ -52,6 +53,7 @@ "fused_chunk", "simple_gla_naive", "simple_gla_chunk", + "simple_gla_chunk_varlen", "fused_simple_gla_chunk", "fused_recurrent_bwd", "chunk_bwd", @@ -191,6 +193,21 @@ def loss_fn(q_, k_, v_, gk_): chunk_simple_gla_fwd, q, k, v, g_gamma=g_gamma, scale=scale, chunk_size=chunk_size, ) + elif provider == "simple_gla_chunk_varlen": + # chunk_simple_gla_fwd requires T % chunk_size == 0 and D % 128 == 0 + chunk_size = 64 + if T % chunk_size != 0 or D % 128 != 0: + return None + _b,_t,_h,_k,_v = q.shape[0],q.shape[1],q.shape[2],q.shape[3],v.shape[3] + + q=jnp.reshape(q, (1, _b*_t, _h, _k)) + k=jnp.reshape(k, (1, _b*_t, _h, _k)) + v=jnp.reshape(v, (1, _b*_t, _h, _v)) + cu_seqlens = jnp.arange(0, _b * _t + 1, _t) + fn = partial( + chunk_simple_gla_fwd_varlen, q, k, v, + g_gamma=g_gamma, scale=scale, chunk_size=chunk_size, cu_seqlens_dev=cu_seqlens, + ) elif provider == "fused_simple_gla_chunk": # chunk_simple_gla_fwd requires T % chunk_size == 0 and D % 128 == 0 chunk_size = 64 diff --git a/tests/ops/simple_gla/test_chunk_simple_gla_varlen_tpu.py b/tests/ops/simple_gla/test_chunk_simple_gla_varlen_tpu.py new file mode 100644 index 00000000..0040a532 --- /dev/null +++ b/tests/ops/simple_gla/test_chunk_simple_gla_varlen_tpu.py @@ -0,0 +1,402 @@ +"""simple_gla varlen forward: Pallas chunk_simple_gla_fwd_varlen vs CPU reference. + +Tests that chunk_simple_gla_fwd_varlen (Pallas, cu_seqlens_dev != None) produces +equivalent results to cpu_chunk_simple_gla_fwd run independently per sequence. 
+ +Constraints for chunk_simple_gla_fwd_varlen: + - B must be 1 (packed varlen layout) + - cu_seqlens_dev values must be multiples of chunk_size + - T must be a multiple of chunk_size + - K, V must be multiples of 128 + - cu_seqlens_cpu must be None + - Only g_gamma gate mode (chunk path does not support per-token g) +""" + +from __future__ import annotations + +import sys +from pathlib import Path + +sys.path.insert(0, str(Path(__file__).resolve().parents[3])) + +import pytest +import jax +import jax.numpy as jnp + +from tops.ops.simple_gla import chunk_simple_gla_fwd_varlen +from tops.cpu.ops.simple_gla import chunk_simple_gla_fwd as cpu_chunk_simple_gla_fwd +from tests.utils import compare_tensor + +# ============================================================================ +# Test configs +# +# Constraints: +# - B = 1 (varlen requires packed layout with batch=1) +# - T % chunk_size == 0 +# - K % 128 == 0, V % 128 == 0 +# - All cu_seqlens entries must be multiples of chunk_size +# ============================================================================ + +CHUNK_SIZE = 64 + +VARLEN_FWD_CASES = [ + # ── 2 equal segments ── + dict(T=128, H=4, K=128, V=128, cu_seqlens=[0, 64, 128], seed=100, gate="g_gamma"), + # ── 2 unequal segments ── + dict(T=256, H=4, K=128, V=128, cu_seqlens=[0, 64, 256], seed=101, gate="g_gamma"), + # ── 3 segments ── + dict(T=192, H=2, K=128, V=128, cu_seqlens=[0, 64, 128, 192], seed=102, gate="g_gamma"), + # ── single segment (varlen with 1 seq = regular forward) ── + dict(T=128, H=4, K=128, V=128, cu_seqlens=[0, 128], seed=103, gate="g_gamma"), + # ── with h0 ── + dict(T=128, H=4, K=128, V=128, cu_seqlens=[0, 64, 128], seed=110, gate="g_gamma", h0=True), + dict(T=256, H=2, K=128, V=128, cu_seqlens=[0, 64, 256], seed=111, gate="g_gamma", h0=True), + dict(T=192, H=2, K=128, V=128, cu_seqlens=[0, 64, 128, 192], seed=112, gate="g_gamma", h0=True), + # ── K != V ── + dict(T=128, H=2, K=128, V=256, cu_seqlens=[0, 64, 128], seed=120, 
gate="g_gamma"), + dict(T=128, H=2, K=256, V=128, cu_seqlens=[0, 64, 128], seed=121, gate="g_gamma"), + # ── single head ── + dict(T=128, H=1, K=128, V=128, cu_seqlens=[0, 64, 128], seed=130, gate="g_gamma"), + # ── many heads ── + dict(T=128, H=16, K=128, V=128, cu_seqlens=[0, 64, 128], seed=131, gate="g_gamma"), + # ── longer sequences ── + dict(T=512, H=2, K=128, V=128, cu_seqlens=[0, 128, 384, 512], seed=140, gate="g_gamma"), + dict(T=512, H=4, K=128, V=128, cu_seqlens=[0, 256, 512], seed=141, gate="g_gamma", h0=True), + # ── custom scale ── + dict(T=128, H=4, K=128, V=128, cu_seqlens=[0, 64, 128], seed=150, gate="g_gamma", scale=0.1), + dict(T=128, H=4, K=128, V=128, cu_seqlens=[0, 64, 128], seed=151, gate="g_gamma", scale=0.1, h0=True), + # ── chunk_size=128 ── + dict(T=256, H=2, K=128, V=128, cu_seqlens=[0, 128, 256], seed=160, gate="g_gamma", chunk_size=128), + dict(T=384, H=4, K=128, V=128, cu_seqlens=[0, 128, 256, 384], seed=161, gate="g_gamma", chunk_size=128), + # ── many short segments ── + dict(T=256, H=4, K=128, V=128, cu_seqlens=[0, 64, 128, 192, 256], seed=170, gate="g_gamma"), + dict(T=256, H=2, K=128, V=128, cu_seqlens=[0, 64, 128, 192, 256], seed=171, gate="g_gamma", h0=True), + # ── no gate (g_gamma=None) ── + dict(T=128, H=4, K=128, V=128, cu_seqlens=[0, 64, 128], seed=180, gate="none"), + dict(T=128, H=4, K=128, V=128, cu_seqlens=[0, 64, 128], seed=181, gate="none", h0=True), +] + + +def _varlen_fwd_case_id(c): + n_segs = len(c["cu_seqlens"]) - 1 + parts = [f"T{c['T']}_segs{n_segs}_H{c['H']}_K{c['K']}_V{c['V']}"] + gate = c.get("gate", "none") + if gate != "none": + parts.append(f"gate={gate}") + if c.get("h0"): + parts.append("h0") + if c.get("scale") is not None: + parts.append(f"scale={c['scale']}") + if c.get("chunk_size") is not None: + parts.append(f"C={c['chunk_size']}") + return "-".join(parts) + + +# ============================================================================ +# Helpers +# 
============================================================================ + + +def _make_varlen_inputs(cfg, *, dtype=jnp.bfloat16): + """Generate random inputs for varlen forward test. + + Args: + cfg: Config dict with T, H, K, V, cu_seqlens, seed, gate, h0, scale. + dtype: Data type for q, k, v, h0. + + Returns: + (q, k, v, g_gamma, h0, cu_seqlens_dev) -- all JAX arrays. + """ + T, H, K, V = cfg["T"], cfg["H"], cfg["K"], cfg["V"] + cu_seqlens = cfg["cu_seqlens"] + N = len(cu_seqlens) - 1 # number of sequences + + key = jax.random.PRNGKey(cfg["seed"]) + keys = jax.random.split(key, 5) + + q = jax.random.normal(keys[0], (1, T, H, K), dtype=dtype) + k = jax.random.normal(keys[1], (1, T, H, K), dtype=dtype) + v = jax.random.normal(keys[2], (1, T, H, V), dtype=dtype) + + gate = cfg.get("gate", "none") + g_gamma = None + if gate == "g_gamma": + g_gamma = -jnp.abs(jax.random.normal(keys[3], (H,), dtype=jnp.float32)) * 0.5 + + h0 = None + if cfg.get("h0"): + h0 = jax.random.normal(keys[4], (N, H, K, V), dtype=dtype) + + cu_seqlens_dev = jnp.array(cu_seqlens, dtype=jnp.int32) + + return q, k, v, g_gamma, h0, cu_seqlens_dev + + +def _run_cpu_ref_per_seq(q, k, v, cu_seqlens, *, g_gamma=None, h0=None, + scale=None, chunk_size=CHUNK_SIZE): + """Run cpu_chunk_simple_gla_fwd independently per sequence and concatenate. + + This is the ground-truth reference: each sub-sequence is processed + independently with its own initial state, guaranteeing no cross-sequence + contamination. + + Args: + q: [1, T, H, K] packed queries. + k: [1, T, H, K] packed keys. + v: [1, T, H, V] packed values. + cu_seqlens: list or array of cumulative sequence lengths. + g_gamma: [H] per-head log-decay, or None. + h0: [N, H, K, V] per-sequence initial state, or None. + scale: query scaling factor. + chunk_size: chunk size. + + Returns: + (o_cat, ht_cat) -- concatenated output and stacked final states. 
+ """ + K_dim = q.shape[-1] + s = scale if scale is not None else K_dim ** -0.5 + N = len(cu_seqlens) - 1 + + outputs = [] + final_states = [] + for n in range(N): + bos, eos = int(cu_seqlens[n]), int(cu_seqlens[n + 1]) + h0_n = h0[n:n + 1] if h0 is not None else None + + o_n, ht_n = cpu_chunk_simple_gla_fwd( + q[:, bos:eos], + k[:, bos:eos], + v[:, bos:eos], + g=None, + g_gamma=g_gamma, + scale=s, + initial_state=h0_n, + output_final_state=True, + chunk_size=chunk_size, + ) + outputs.append(o_n) + if ht_n is not None: + final_states.append(ht_n[0]) + + o_cat = jnp.concatenate(outputs, axis=1) + ht_cat = jnp.stack(final_states, axis=0) if final_states else None + return o_cat, ht_cat + + +def _run_pallas_varlen_fwd(q, k, v, cu_seqlens_dev, *, g_gamma=None, h0=None, + scale=None, chunk_size=CHUNK_SIZE): + """Run chunk_simple_gla_fwd_varlen (Pallas varlen kernel). + + Args: + q: [1, T, H, K] packed queries. + k: [1, T, H, K] packed keys. + v: [1, T, H, V] packed values. + cu_seqlens_dev: [N+1] cumulative sequence lengths (device array). + g_gamma: [H] per-head log-decay, or None. + h0: [N, H, K, V] per-sequence initial state, or None. + scale: query scaling factor. + chunk_size: chunk size. + + Returns: + (o, ht) -- output [1, T, H, V] and final state [N, H, K, V] or None. 
+ """ + o, ht = chunk_simple_gla_fwd_varlen( + q, k, v, + g_gamma=g_gamma, + scale=scale, + h0=h0, + use_ht=True, + cu_seqlens_dev=cu_seqlens_dev, + chunk_size=chunk_size, + ) + return o, ht + + +# ============================================================================ +# Forward: Pallas varlen vs CPU per-sequence reference +# ============================================================================ + + +@pytest.mark.parametrize( + "cfg", VARLEN_FWD_CASES, + ids=[_varlen_fwd_case_id(c) for c in VARLEN_FWD_CASES], +) +def test_varlen_fwd_vs_cpu(cfg): + """chunk_simple_gla_fwd_varlen (Pallas) should match per-sequence CPU reference.""" + T = cfg["T"] + C = cfg.get("chunk_size", CHUNK_SIZE) + NT = T // C + scale = cfg.get("scale", None) + cu_seqlens = cfg["cu_seqlens"] + + # Tolerance: bf16 matmul accumulation errors compound across chunks. + atol = cfg.get("atol", min(5e-2, 2e-2 + 1e-2 * max(NT, 1))) + rtol = cfg.get("rtol", 5e-2) + max_ulp = 4 + + q, k, v, g_gamma, h0, cu_seqlens_dev = _make_varlen_inputs(cfg, dtype=jnp.bfloat16) + + # CPU reference: run each sub-sequence independently + o_ref, ht_ref = _run_cpu_ref_per_seq( + q, k, v, cu_seqlens, + g_gamma=g_gamma, h0=h0, scale=scale, chunk_size=C, + ) + + # Pallas varlen kernel + o_pl, ht_pl = _run_pallas_varlen_fwd( + q, k, v, cu_seqlens_dev, + g_gamma=g_gamma, h0=h0, scale=scale, chunk_size=C, + ) + + assert compare_tensor("output", o_ref, o_pl, atol=atol, rtol=rtol, max_ulp=max_ulp) + + if ht_ref is not None and ht_pl is not None: + ht_atol = max(atol, 5e-2) + assert compare_tensor( + "final_state", ht_ref, ht_pl, atol=ht_atol, rtol=rtol, max_ulp=max_ulp, + ) + + +# ============================================================================ +# Cross-validation: varlen should match non-varlen for single-sequence input +# +# When cu_seqlens = [0, T] (one sequence spanning entire T), the varlen +# kernel should produce identical results to the regular forward path. 
+# ============================================================================ + +SINGLE_SEQ_CASES = [ + dict(T=128, H=4, K=128, V=128, seed=200, gate="g_gamma"), + dict(T=256, H=2, K=128, V=128, seed=201, gate="g_gamma"), + dict(T=128, H=4, K=128, V=128, seed=202, gate="g_gamma", h0=True), + dict(T=256, H=2, K=128, V=128, seed=203, gate="g_gamma", h0=True), +] + + +def _single_seq_case_id(c): + parts = [f"T{c['T']}_H{c['H']}_K{c['K']}_V{c['V']}"] + if c.get("h0"): + parts.append("h0") + return "-".join(parts) + + +@pytest.mark.parametrize( + "cfg", SINGLE_SEQ_CASES, + ids=[_single_seq_case_id(c) for c in SINGLE_SEQ_CASES], +) +def test_varlen_single_seq_matches_regular(cfg): + """Varlen with cu_seqlens=[0,T] should match non-varlen CPU forward exactly.""" + T, H, K, V = cfg["T"], cfg["H"], cfg["K"], cfg["V"] + C = cfg.get("chunk_size", CHUNK_SIZE) + NT = T // C + + key = jax.random.PRNGKey(cfg["seed"]) + keys = jax.random.split(key, 5) + + q = jax.random.normal(keys[0], (1, T, H, K), dtype=jnp.bfloat16) + k = jax.random.normal(keys[1], (1, T, H, K), dtype=jnp.bfloat16) + v = jax.random.normal(keys[2], (1, T, H, V), dtype=jnp.bfloat16) + g_gamma = -jnp.abs(jax.random.normal(keys[3], (H,), dtype=jnp.float32)) * 0.5 + + h0 = None + if cfg.get("h0"): + h0 = jax.random.normal(keys[4], (1, H, K, V), dtype=jnp.bfloat16) + + s = K ** -0.5 + cu_seqlens_dev = jnp.array([0, T], dtype=jnp.int32) + + # Non-varlen CPU reference (single sequence, no cu_seqlens) + o_ref, ht_ref = cpu_chunk_simple_gla_fwd( + q, k, v, + g=None, g_gamma=g_gamma, scale=s, + initial_state=h0, output_final_state=True, chunk_size=C, + ) + + # Varlen Pallas kernel with cu_seqlens=[0, T] + o_pl, ht_pl = chunk_simple_gla_fwd_varlen( + q, k, v, + g_gamma=g_gamma, scale=s, h0=h0, + use_ht=True, cu_seqlens_dev=cu_seqlens_dev, chunk_size=C, + ) + + atol = min(5e-2, 2e-2 + 1e-2 * max(NT, 1)) + rtol = 5e-2 + max_ulp = 4 + + assert compare_tensor("output", o_ref, o_pl, atol=atol, rtol=rtol, max_ulp=max_ulp) + if 
ht_ref is not None and ht_pl is not None: + assert compare_tensor( + "final_state", ht_ref, ht_pl, atol=max(atol, 5e-2), rtol=rtol, max_ulp=max_ulp, + ) + + +# ============================================================================ +# Sequence isolation: varlen should prevent cross-sequence information leakage +# +# Verify that the output for each sub-sequence in the packed tensor is +# identical regardless of what other sequences are packed alongside it. +# ============================================================================ + + +def test_varlen_sequence_isolation(): + """Output of each sub-sequence should be independent of other packed sequences.""" + H, K, V, C = 4, 128, 128, 64 + key = jax.random.PRNGKey(42) + keys = jax.random.split(key, 6) + + # Create two independent sequences of length 64 + q1 = jax.random.normal(keys[0], (1, 64, H, K), dtype=jnp.bfloat16) + k1 = jax.random.normal(keys[1], (1, 64, H, K), dtype=jnp.bfloat16) + v1 = jax.random.normal(keys[2], (1, 64, H, V), dtype=jnp.bfloat16) + + q2 = jax.random.normal(keys[3], (1, 64, H, K), dtype=jnp.bfloat16) + k2 = jax.random.normal(keys[4], (1, 64, H, K), dtype=jnp.bfloat16) + v2 = jax.random.normal(keys[5], (1, 64, H, V), dtype=jnp.bfloat16) + + g_gamma = jnp.full((H,), -0.3, dtype=jnp.float32) + s = K ** -0.5 + cu = jnp.array([0, 64, 128], dtype=jnp.int32) + + # Pack [seq1, seq2] and run varlen + q_packed = jnp.concatenate([q1, q2], axis=1) + k_packed = jnp.concatenate([k1, k2], axis=1) + v_packed = jnp.concatenate([v1, v2], axis=1) + + o_packed, _ = chunk_simple_gla_fwd_varlen( + q_packed, k_packed, v_packed, + g_gamma=g_gamma, scale=s, + cu_seqlens_dev=cu, chunk_size=C, + ) + + # Run seq1 alone (non-varlen reference) + o_seq1_ref, _ = cpu_chunk_simple_gla_fwd( + q1, k1, v1, + g=None, g_gamma=g_gamma, scale=s, + initial_state=None, output_final_state=False, chunk_size=C, + ) + + # Now pack [seq1, different_seq2] -- seq1's output should be the same + q3 = 
jax.random.normal(jax.random.PRNGKey(999), (1, 64, H, K), dtype=jnp.bfloat16) + k3 = jax.random.normal(jax.random.PRNGKey(998), (1, 64, H, K), dtype=jnp.bfloat16) + v3 = jax.random.normal(jax.random.PRNGKey(997), (1, 64, H, V), dtype=jnp.bfloat16) + + q_packed2 = jnp.concatenate([q1, q3], axis=1) + k_packed2 = jnp.concatenate([k1, k3], axis=1) + v_packed2 = jnp.concatenate([v1, v3], axis=1) + + o_packed2, _ = chunk_simple_gla_fwd_varlen( + q_packed2, k_packed2, v_packed2, + g_gamma=g_gamma, scale=s, + cu_seqlens_dev=cu, chunk_size=C, + ) + + # seq1 output should be identical regardless of what seq2 is + assert compare_tensor( + "isolation_packed_vs_ref", o_seq1_ref, o_packed[:, :64], atol=5e-2, rtol=5e-2, max_ulp=4, + ) + assert compare_tensor( + "isolation_cross_pack", o_packed[:, :64], o_packed2[:, :64], atol=1e-6, rtol=1e-6, max_ulp=1, + ) + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/tops/ops/common/chunk_h.py b/tops/ops/common/chunk_h.py index 5b5c66e6..6ef1624a 100644 --- a/tops/ops/common/chunk_h.py +++ b/tops/ops/common/chunk_h.py @@ -25,7 +25,7 @@ def _chunk_fwd_h_kernel( h0_ref, # [1, 1, BK, BV] gk_ref, # [1, 1, BT, BK] g_ref, # [1, 1, BT, 128] - g_gamma, # [H] + g_gamma_ref, # [H] h_ref, # [1, NS, 1, BK, BV] outputs ht_ref, # [1, 1, BK , BV] scratch_ref, #[BK, BV] @@ -41,8 +41,8 @@ def _chunk_fwd_h_kernel( T = NT * BT i_b, i_h, i_k, i_v, i_t = pl.program_id(0), pl.program_id(1), pl.program_id(2),pl.program_id(3),pl.program_id(4) - if g_gamma is not None: - b_g = g_gamma[i_h].astype(jnp.float32) * (jnp.arange(0, BT) + 1) + if g_gamma_ref is not None: + b_g = g_gamma_ref[i_h].astype(jnp.float32) * (jnp.arange(0, BT) + 1) @pl.when(i_t == 0) def init(): @@ -65,9 +65,9 @@ def store_fn(): scratch_ref[...] 
*= exp(b_g_scalar_last) # uniform decay v_tile = (v_tile * exp(b_g_scalar_last - b_g_scalar)[:, None]).astype(v_tile.dtype) - if g_gamma is not None: + if g_gamma_ref is not None: # tpu not support scalar bf16 mul - b_g_last = (g_gamma[i_h].astype(jnp.float32) * jnp.minimum(BT, T - i_t * BT)).astype(g_gamma.dtype) + b_g_last = (g_gamma_ref[i_h].astype(jnp.float32) * jnp.minimum(BT, T - i_t * BT)).astype(g_gamma_ref.dtype) scratch_ref[...] *= exp(b_g_last) v_tile = (v_tile * exp(b_g_last - b_g)[:, None]).astype(v_tile.dtype) @@ -883,4 +883,232 @@ def idx_map_state(h, k, v, c): return 0, h, k, v dh0 = dh0.reshape(N, H, K, V) if dh0 is not None else None return dh_all, dh0 + +def _chunk_fwd_h_kernel_varlen( + k_ref, # [1, BT, BK] + v_ref, # [1, BT, BV] + h0_ref, # [N, 1, BK, BV] + gk_ref, # [1, BT, BK] + g_gamma_ref, # [H,] + cu_seqlens_ref, # [num_seq+1] + chunk_to_seq, # [T_sum/BT] + h_ref, # [NS, 1, BK, BV] + ht_ref, # [N, 1, BK , BV] + scratch_ref, # [BK, BV] + *, + BT, + BS, +): + BT, BK = k_ref.shape[1], k_ref.shape[2] + BV = v_ref.shape[2] + + NTS = BS // BT + b_h_start = jnp.zeros((BK, BV), dtype=jnp.float32) + + i_h, i_k, i_v, i_t = pl.program_id(0), pl.program_id(1), pl.program_id(2), pl.program_id(3) + + if g_gamma_ref is not None: + b_g = g_gamma_ref[i_h].astype(jnp.float32) * (jnp.arange(0, BT) + 1) + t0 = i_t * BT + + seq_idx = chunk_to_seq[i_t] + + bos = cu_seqlens_ref[seq_idx] + eos = cu_seqlens_ref[seq_idx + 1] + @pl.when(bos != eos) + def _(): + # reset h state + @pl.when(t0 == bos) + def reset_state(): + if h0_ref is not None: + scratch_ref[...] = h0_ref[seq_idx, 0].astype(scratch_ref.dtype) + else: + scratch_ref[...] 
= b_h_start + + # store intermediate state + @pl.when(i_t % NTS == 0) + def store_fn(): + s_i = i_t // NTS + h_ref[s_i, 0] = scratch_ref[...].astype(h_ref.dtype) + return None + + k_tile = k_ref[(0, slice(None), slice(None))] # [BT,BK] + v_tile = v_ref[(0, slice(None), slice(None))] # [BT,BV] + + if g_gamma_ref is not None: + # tpu not support scalar bf16 mul + b_g_last = (g_gamma_ref[i_h].astype(jnp.float32) * jnp.minimum(BT, eos - i_t * BT)).astype(g_gamma_ref.dtype) + scratch_ref[...] *= exp(b_g_last) + v_tile = (v_tile * exp(b_g_last - b_g)[:, None]).astype(v_tile.dtype) + + if gk_ref is not None: + gk_tile = gk_ref[(0, slice(None), slice(None))] # BT * BK + g_last = gk_tile[-1, :] + decay = exp(g_last) + scratch_ref[...] = scratch_ref[...] * decay[:, None] # [BK, BV] * [BK,1] + k_tile = (k_tile * exp(g_last[None, :] - gk_tile)).astype(k_tile.dtype) + + # state update + scratch_ref[...] = scratch_ref[...] + jax.lax.dot( + k_tile.astype(jnp.float32).T, + v_tile.astype(jnp.float32), + precision=lax.Precision.HIGHEST, + preferred_element_type=jnp.float32, + ) + + @pl.when(t0 + BT >= eos) + def write_final(): + if ht_ref is not None: + ht_ref[seq_idx, 0] = scratch_ref[...].astype(jnp.float32) + + +def check_chunk_fwd(x): + assert x is None, "x should be None." + + +# note: The precision difference between this kernel on the TPU and FLA on the GPU is 5e-2. 
+@functools.partial( + jax.jit, + static_argnames=[ + "output_final_state", + "chunk_size", + "split_size", + "states_in_fp32", + "interpret", + ], +) +def chunk_fwd_h_kernel_varlen( + k: jax.Array, # [B,T,H,K] + v: jax.Array, # [B,T,H,V] + g: jax.Array | None = None, # [B,T,H] + g_gamma: jax.Array | None = None, # (H,) + gk: jax.Array | None = None, # [B,T,H,K] + gv: jax.Array | None = None, # [B,T,H,V] + h0: jax.Array | None = None, # [N,H,K,V] + output_final_state: bool = False, + cu_seqlens_dev: jax.Array | None = None, + chunk_size: int = 128, + split_size: int | None = None, + states_in_fp32: bool = False, + interpret: bool = False, +): + check_chunk_fwd(g) + check_chunk_fwd(gv) + # todo: tune bk and bv for bast performance + BK = 128 + BV = 128 + B, T, H, K, V = *k.shape, v.shape[-1] + assert K % 128 == 0, "K % 128 must equal to 0." + assert V % 128 == 0, "V % 128 must equal to 0." + assert T % chunk_size == 0, "T mod chunk_size must equal to 0." + + BT = chunk_size + BS = BT if split_size is None else split_size + assert BS % BT == 0, ( + f"The `split_size` (got {BS}) must be a multiple of `chunk_size` {BT}" + ) + # N: the actual number of sequences in the batch with either equal or variable lengths + + T_sum = B * T + chunk_to_seq = _build_chunk_map(cu_seqlens=cu_seqlens_dev, T_sum=T_sum, BT=BT) + + N, NS = ( + len(cu_seqlens_dev) - 1, + T_sum // BS, + ) # split_offsets[-1] # NS number of chunk_size + + k = jnp.reshape(k, (T_sum, H, K)) + v = jnp.reshape(v, (T_sum, H, V)) + + k = jnp.transpose(k, (1, 0, 2)) # (H,B*T,K) + v = jnp.transpose(v, (1, 0, 2)) # (H,B*T,V) + if gk is not None: + gk = jnp.reshape(gk, (T_sum, H, K)) + gk = jnp.transpose(gk, (1, 0, 2)) # (H,B*T,K) + + grid = (H, pl.cdiv(K, BK), pl.cdiv(V, BV), T_sum//BT) + + def k_index_map(head_index, k_index, _, t_index): + return head_index, t_index, k_index + + def gk_index_map(head_index, k_index, _, t_index): + return head_index, t_index, k_index + + def v_index_map(head_index, _, v_index, 
t_index): + return head_index, t_index, v_index + + def h0_index_map(head_index, k_index, v_index, t_index): + return 0, head_index, k_index, v_index + + def ht_index_map(head_index, k_index, v_index, t_index): + return 0, head_index, k_index, v_index + + def h_index_map(head_index, k_index, v_index, t_index): + return 0, head_index, k_index, v_index + + out_shape = [ + jax.ShapeDtypeStruct( + shape=(NS, H, K, V), dtype=k.dtype if not states_in_fp32 else jnp.float32 + ) + ] + out_specs = [pl.BlockSpec((NS, 1, BK, BV), h_index_map)] + if output_final_state: + out_shape.append(jax.ShapeDtypeStruct(shape=(N, H, K, V), dtype=jnp.float32)) + out_specs.append(pl.BlockSpec((N, 1, BK, BV), ht_index_map)) + else: + out_shape.append(None) + out_specs.append(None) + + in_specs = [ + pl.BlockSpec((1, BT, BK), k_index_map), + pl.BlockSpec((1, BT, BV), v_index_map), + ] + if h0 is not None: + in_specs.append(pl.BlockSpec((N, 1, BK, BV), h0_index_map)) + else: + in_specs.append(None) + if gk is not None: + in_specs.append(pl.BlockSpec((1, BT, BK), gk_index_map)) + else: + in_specs.append(None) + + if g_gamma is not None: + in_specs.append(pl.BlockSpec(memory_space=pltpu.SMEM)) + else: + in_specs.append(None) + + in_specs.append(pl.BlockSpec(memory_space=pltpu.SMEM)) + in_specs.append(pl.BlockSpec(memory_space=pltpu.SMEM)) + scratch = pltpu.VMEM((BK, BV), jnp.float32) + scratch_shapes = [scratch] + kernel = functools.partial( + _chunk_fwd_h_kernel_varlen, + BT=BT, + BS=BS, + ) + h, ht = pl.pallas_call( + kernel, + grid_spec=pltpu.PrefetchScalarGridSpec( + num_scalar_prefetch=0, + grid=grid, + in_specs=in_specs, + out_specs=out_specs, + scratch_shapes = scratch_shapes, + ), + out_shape=out_shape, + interpret=interpret, + compiler_params=pltpu.CompilerParams( + dimension_semantics=( + "parallel", + "parallel", + "parallel", + "arbitrary", + ), + vmem_limit_bytes=128 * 1024 * 1024, + ), + )(k, v, h0, gk, g_gamma, cu_seqlens_dev, chunk_to_seq) + if output_final_state: + return h, ht + 
return h, None + __all__ = export_public(globals()) diff --git a/tops/ops/common/chunk_o.py b/tops/ops/common/chunk_o.py index a9219b70..5022f2f5 100644 --- a/tops/ops/common/chunk_o.py +++ b/tops/ops/common/chunk_o.py @@ -471,7 +471,6 @@ def chunk_fwd_o( assert_shape_or_none(g_gamma, (H,)) assert T % C == 0, f"Sequence length T={T} must be divisible by chunk_size={C}" assert (cu_seqlens_cpu is None) or (cu_seqlens_cpu % chunk_size == 0).all(), "All sequence lengths must be divisible by chunk_size" - assert cu_seqlens_dev is None or (cu_seqlens_dev % chunk_size == 0).all(), "All device sequence lengths must be divisible by chunk_size" if cu_seqlens_cpu is not None or cu_seqlens_dev is not None: assert B == 1, f"Packed varlen chunk_fwd_o expects B=1, got B={B}" assert scale is not None diff --git a/tops/ops/simple_gla/__init__.py b/tops/ops/simple_gla/__init__.py index 9363ef4c..2a564845 100644 --- a/tops/ops/simple_gla/__init__.py +++ b/tops/ops/simple_gla/__init__.py @@ -6,7 +6,7 @@ import jax from tops.utils import assert_shape, assert_shape_or_none -from .chunk import chunk_simple_gla, chunk_simple_gla_bwd, chunk_simple_gla_fwd +from .chunk import chunk_simple_gla, chunk_simple_gla_bwd, chunk_simple_gla_fwd, chunk_simple_gla_fwd_varlen from .fused_chunk import fused_chunk_simple_gla, fused_chunk_simple_gla_fwd, fused_chunk_simple_gla_bwd from .fused_recurrent import fused_recurrent_simple_gla from .naive import simple_gla_naive @@ -180,10 +180,12 @@ def simple_gla_fwd( mode: SimpleGLAKernelMode = SimpleGLAKernelMode.FUSED_CHUNK ): fn = None - if mode == SimpleGLAKernelMode.CHUNK: + if mode == SimpleGLAKernelMode.CHUNK and cu_seqlens_dev is None: fn = chunk_simple_gla_fwd - elif mode == SimpleGLAKernelMode.FUSED_CHUNK: + elif mode == SimpleGLAKernelMode.FUSED_CHUNK and cu_seqlens_dev is None: fn = fused_chunk_simple_gla_fwd + elif cu_seqlens_dev is not None: + fn = chunk_simple_gla_fwd_varlen else: raise Exception(f"mode {mode} not support") return fn( diff --git 
a/tops/ops/simple_gla/chunk.py b/tops/ops/simple_gla/chunk.py
index 34a7e3c2..1426cb69 100644
--- a/tops/ops/simple_gla/chunk.py
+++ b/tops/ops/simple_gla/chunk.py
@@ -6,7 +6,7 @@ import jax.experimental.pallas.tpu as pltpu
 
 
 from tops.ops.common.chunk_h import chunk_fwd_h_kernel as chunk_fwd_h
-from tops.ops.common.chunk_h import chunk_bwd_dh_kernel as chunk_bwd_dh
+from tops.ops.common.chunk_h import chunk_bwd_dh_kernel as chunk_bwd_dh, chunk_fwd_h_kernel_varlen as chunk_fwd_h_varlen
 from tops.ops.common.chunk_o import chunk_fwd_o, chunk_simple_gla_bwd_o_pl
 from tops.ops.gla.chunk import chunk_gla_fwd_intra_gk_ref
 
@@ -519,6 +519,92 @@ def chunk_simple_gla_fwd(
     )
     return o, ht
 
+def chunk_simple_gla_fwd_varlen(
+    q: jax.Array,
+    k: jax.Array,
+    v: jax.Array,
+    *,
+    g: jax.Array | None = None,
+    g_gamma: jax.Array | None = None,
+    scale: float | None = None,
+    h0: jax.Array | None = None,
+    use_ht: bool = False,
+    cu_seqlens_cpu: jax.Array | None = None,
+    cu_seqlens_dev: jax.Array | None = None,
+    chunk_size: int = 64,
+    interpret: bool | None = None,
+) -> tuple[jax.Array, jax.Array | None]:
+    """Run the chunked simple-GLA forward pass on a packed varlen batch.
+
+    Sequences are packed along T with batch size 1 and delimited by
+    ``cu_seqlens_dev``.
+
+    Args:
+        q: Queries, shape (1, T, H, K).
+        k: Keys, shape (1, T, H, K).
+        v: Values, shape (1, T, H, V).
+        g: Optional per-token gate, shape (1, T, H); forwarded as-is to
+            ``chunk_fwd_h_varlen`` and ``chunk_fwd_o``.
+        g_gamma: Optional per-head gate, shape (H,).
+        scale: Query scale; ``K ** -0.5`` is used when None.
+        h0: Optional initial states, shape (N, H, K, V), where N is
+            ``cu_seqlens_dev.shape[0] - 1``.
+        use_ht: Whether the final states are requested from the state kernel.
+        cu_seqlens_cpu: Must be None on this path.
+        cu_seqlens_dev: Required device array of cumulative sequence offsets.
+        chunk_size: Chunk length; T must be divisible by it.
+        interpret: Accepted for signature compatibility; not read here.
+
+    Returns:
+        Tuple ``(o, ht)`` as produced by ``chunk_fwd_o`` and
+        ``chunk_fwd_h_varlen`` respectively.
+    """
+    B, T, H, K, V = *q.shape, v.shape[-1]
+    N = cu_seqlens_dev.shape[0] - 1 if cu_seqlens_dev is not None else B
+
+    assert_shape(q, (B, T, H, K))
+    assert_shape(k, (B, T, H, K))
+    assert_shape(v, (B, T, H, V))
+    assert_shape_or_none(g, (B, T, H))
+    assert_shape_or_none(g_gamma, (H,))
+    assert_shape_or_none(h0, (N, H, K, V))
+    assert T % chunk_size == 0
+    assert cu_seqlens_cpu is None, "cu_seqlens_cpu must be None."
+    assert cu_seqlens_dev is not None, "cu_seqlens_dev must not be None."
+    assert (K % 128 == 0) and (V % 128 == 0)
+    assert B == 1, "B must be 1."
+
+    # chunk_fwd_o asserts ``scale is not None``, so resolve the default here.
+    if scale is None:
+        scale = K ** -0.5
+
+    h, ht = chunk_fwd_h_varlen(
+        k=k,
+        v=v,
+        g=g,
+        g_gamma=g_gamma,
+        gk=None,
+        gv=None,
+        h0=h0,
+        output_final_state=use_ht,
+        states_in_fp32=False,
+        cu_seqlens_dev=cu_seqlens_dev,
+        chunk_size=chunk_size,
+    )
+    o = chunk_fwd_o(
+        q=q,
+        k=k,
+        v=v,
+        g=g,
+        g_gamma=g_gamma,
+        h=h,
+        scale=scale,
+        cu_seqlens_cpu=cu_seqlens_cpu,
+        cu_seqlens_dev=cu_seqlens_dev,
+        chunk_size=chunk_size,
+    )
+    return o, ht
+
 def chunk_simple_gla_bwd(
     q: jax.Array,
     k: jax.Array,