diff --git a/tops/ops/common/fused_chunk.py b/tops/ops/common/fused_chunk.py index bb2e7cb7..5c7c0e76 100644 --- a/tops/ops/common/fused_chunk.py +++ b/tops/ops/common/fused_chunk.py @@ -4,11 +4,12 @@ (BK, BV) tile of the hidden state stays in VMEM scratch across chunks, eliminating the HBM round-trip for the full ``h: [B, NT, H, K, V]`` tensor. -Grid: (B, H, K // BK, V // BV, NT) - - First 4 dims parallel, NT arbitrary (sequential). +Grid: (B, H, K // BK, V // BV, NT_OUTER) + - First 4 dims parallel, NT_OUTER arbitrary (sequential). - Each grid point handles one (BK, BV) tile of the hidden state. + - NT_OUTER = NT // num_steps; each iteration processes *num_steps* chunks. -At each time step the kernel: +At each time step the kernel loops over *num_steps* sub-chunks: 1. Computes a *partial* output contribution from this K tile: partial = q_k @ h_k * gate + causal(q_k @ k_k^T * gate) @ v 2. Updates this K tile's hidden state: @@ -34,39 +35,40 @@ # --------------------------------------------------------------------------- def _fused_chunk_fwd_kernel( - q_ref, # (1, 1, BT, BK) — K-tiled - k_ref, # (1, 1, BT, BK) — K-tiled - v_ref, # (1, 1, BT, BV) — V-tiled + q_ref, # (1, 1, NUM_STEPS*BT, BK) — K-tiled + k_ref, # (1, 1, NUM_STEPS*BT, BK) — K-tiled + v_ref, # (1, 1, NUM_STEPS*BT, BV) — V-tiled h0_ref, # (1, 1, BK, BV) or None - g_ref, # (1, 1, BT, 128) or None + g_ref, # (1, 1, NUM_STEPS*BT, 128) or None g_gamma_ref, # [H] via SMEM/ANY, or None - o_ref, # (1, 1, 1, BT, BV) — partial output per K tile + o_ref, # (1, 1, 1, NUM_STEPS*BT, BV) — partial output per K tile ht_ref, # (1, 1, BK, BV) — output, or None scratch_ref, # (BK, BV) VMEM float32 *, BT: int, NT: int, + NUM_STEPS: int, + NT_OUTER: int, ): """Fused Pallas kernel: hidden-state propagation + output computation. - Grid: (B, H, NK, NV, NT) - - First 4 dims are parallel; NT is arbitrary (sequential over time). - - NK = K // BK, NV = V // BV. 
+ Grid: (B, H, NK, NV, NT_OUTER) + - First 4 dims are parallel; NT_OUTER is arbitrary (sequential over time). + - NK = K // BK, NV = V // BV, NT_OUTER = NT // NUM_STEPS. - Each grid point processes one (BK, BV) tile of the hidden state - across all time steps. At each time step: - 1. Compute partial output contribution from this K tile. - 2. Update this K tile's hidden state. + Each grid point processes one (BK, BV) tile of the hidden state. + At each iteration the kernel loops over NUM_STEPS sub-chunks, + computing partial output and updating the hidden state for each. The partial outputs are summed across K tiles and scaled by the launcher. Refs (after block-spec indexing): - q_ref / k_ref : (1, 1, BT, BK) — K-tiled - v_ref : (1, 1, BT, BV) — V-tiled + q_ref / k_ref : (1, 1, NUM_STEPS*BT, BK) — K-tiled + v_ref : (1, 1, NUM_STEPS*BT, BV) — V-tiled h0_ref : (1, 1, BK, BV) — initial state tile, or None - g_ref : (1, 1, BT, 128) — scalar gate (broadcast to 4D for TPU alignment), or None + g_ref : (1, 1, NUM_STEPS*BT, 128) — scalar gate, or None g_gamma_ref : [H] — per-head fixed decay, or None - o_ref : (1, 1, 1, BT, BV) — partial output tile + o_ref : (1, 1, 1, NUM_STEPS*BT, BV) — partial output tile ht_ref : (1, 1, BK, BV) — final-state tile, or None scratch_ref : (BK, BV) — running hidden state in VMEM """ @@ -74,13 +76,18 @@ def _fused_chunk_fwd_kernel( BV = v_ref.shape[3] i_t = pl.program_id(4) - # ---- Precompute g_gamma ramp (constant across chunks) ---- + # ---- Precompute constants (invariant across sub-steps) ---- + causal_mask = jnp.arange(BT)[:, None] >= jnp.arange(BT)[None, :] + if g_gamma_ref is not None: head_idx = pl.program_id(1) b_gamma = g_gamma_ref[head_idx].astype(jnp.float32) b_g_gamma = b_gamma * (jnp.arange(BT) + 1).astype(jnp.float32) # [BT] + b_g_gamma_last = b_gamma * BT + g_gamma_diff = b_g_gamma[:, None] - b_g_gamma[None, :] + safe_g_gamma_diff = jnp.where(causal_mask, g_gamma_diff, 0.0) - # ---- Initialize hidden state on first chunk 
---- + # ---- Initialize hidden state on first iteration ---- @pl.when(i_t == 0) def _init(): if h0_ref is not None: @@ -88,79 +95,75 @@ def _init(): else: scratch_ref[:, :] = jnp.zeros((BK, BV), dtype=jnp.float32) - # ---- Load current chunk ---- - b_q = q_ref[0, 0] # (BT, BK) - b_k = k_ref[0, 0] # (BT, BK) - b_v = v_ref[0, 0] # (BT, BV) - b_h = scratch_ref[...] # (BK, BV) — state *before* this chunk's update - - # ===== Stage 1: Partial output for this K tile ===== - # Partial inter-chunk: q_k @ h_k - partial_o = jnp.dot( - b_q, b_h, - preferred_element_type=jnp.float32, - ) # (BT, BV) - - # Partial intra-chunk attention: q_k @ k_k^T - partial_A = jnp.dot( - b_q, b_k.T, - preferred_element_type=jnp.float32, - ) # (BT, BT) - - # Apply scalar gate g - if g_ref is not None: - b_g = g_ref[0, 0, :, 0].astype(jnp.float32) # (BT,) - partial_o = partial_o * exp(b_g)[:, None] - g_diff = b_g[:, None] - b_g[None, :] - fwd_mask = jnp.arange(BT)[:, None] >= jnp.arange(BT)[None, :] - safe_g_diff = jnp.where(fwd_mask, g_diff, 0.0) - partial_A = partial_A * exp(safe_g_diff) - - # Apply per-head fixed decay g_gamma - if g_gamma_ref is not None: - partial_o = partial_o * exp(b_g_gamma)[:, None] - g_gamma_diff = b_g_gamma[:, None] - b_g_gamma[None, :] - fwd_mask = jnp.arange(BT)[:, None] >= jnp.arange(BT)[None, :] - safe_g_gamma_diff = jnp.where(fwd_mask, g_gamma_diff, 0.0) - partial_A = partial_A * exp(safe_g_gamma_diff) - - # Causal mask (lower triangular: i >= j) - mask = jnp.arange(BT)[:, None] >= jnp.arange(BT)[None, :] - partial_A = jnp.where(mask, partial_A, 0.0) - - # Partial contribution (inter + intra); scale applied after K reduction - partial = partial_o + jnp.dot( - partial_A, b_v.astype(jnp.float32), - precision=jax.lax.Precision.HIGHEST, - preferred_element_type=jnp.float32, - ) # (BT, BV) - o_ref[0, 0, 0] = partial.astype(o_ref.dtype) - - # ===== Stage 2: Update hidden state for this K tile ===== - b_v_upd = b_v - - # Decay state and adjust v for scalar gate g - if 
g_ref is not None: - b_g_last = b_g[BT - 1] - scratch_ref[...] *= exp(b_g_last) - b_v_upd = (b_v_upd * exp(b_g_last - b_g)[:, None]).astype(b_v_upd.dtype) - - # Decay state and adjust v for g_gamma - if g_gamma_ref is not None: - b_g_gamma_last = b_gamma * BT - scratch_ref[...] *= exp(b_g_gamma_last) - b_v_upd = (b_v_upd * exp(b_g_gamma_last - b_g_gamma)[:, None]).astype(b_v_upd.dtype) - - # State update: h_k += k_k^T @ v_upd - scratch_ref[...] = scratch_ref[...] + jnp.dot( - b_k.astype(jnp.float32).T, - b_v_upd.astype(jnp.float32), - precision=jax.lax.Precision.HIGHEST, - preferred_element_type=jnp.float32, - ) - - # Store final hidden state - @pl.when(i_t == NT - 1) + # ---- Process NUM_STEPS sub-chunks per grid iteration ---- + for step in range(NUM_STEPS): + offset = step * BT + + # Load current sub-chunk + b_q = q_ref[0, 0, pl.ds(offset, BT), :] # (BT, BK) + b_k = k_ref[0, 0, pl.ds(offset, BT), :] # (BT, BK) + b_v = v_ref[0, 0, pl.ds(offset, BT), :] # (BT, BV) + b_h = scratch_ref[...] 
# (BK, BV) + + # ===== Stage 1: Partial output for this K tile ===== + partial_o = jnp.dot( + b_q, b_h, + preferred_element_type=jnp.float32, + ) # (BT, BV) + + partial_A = jnp.dot( + b_q, b_k.T, + preferred_element_type=jnp.float32, + ) # (BT, BT) + + # Apply scalar gate g + if g_ref is not None: + b_g = g_ref[0, 0, pl.ds(offset, BT), 0].astype(jnp.float32) # (BT,) + partial_o = partial_o * exp(b_g)[:, None] + g_diff = b_g[:, None] - b_g[None, :] + safe_g_diff = jnp.where(causal_mask, g_diff, 0.0) + partial_A = partial_A * exp(safe_g_diff) + + # Apply per-head fixed decay g_gamma + if g_gamma_ref is not None: + partial_o = partial_o * exp(b_g_gamma)[:, None] + partial_A = partial_A * exp(safe_g_gamma_diff) + + # Causal mask (lower triangular: i >= j) + partial_A = jnp.where(causal_mask, partial_A, 0.0) + + # Partial contribution (inter + intra); scale applied after K reduction + partial = partial_o + jnp.dot( + partial_A, b_v.astype(jnp.float32), + precision=jax.lax.Precision.HIGHEST, + preferred_element_type=jnp.float32, + ) # (BT, BV) + o_ref[0, 0, 0, pl.ds(offset, BT), :] = partial.astype(o_ref.dtype) + + # ===== Stage 2: Update hidden state for this K tile ===== + b_v_upd = b_v + + # Decay state and adjust v for scalar gate g + if g_ref is not None: + b_g_last = b_g[BT - 1] + scratch_ref[...] *= exp(b_g_last) + b_v_upd = (b_v_upd * exp(b_g_last - b_g)[:, None]).astype(b_v_upd.dtype) + + # Decay state and adjust v for g_gamma + if g_gamma_ref is not None: + scratch_ref[...] *= exp(b_g_gamma_last) + b_v_upd = (b_v_upd * exp(b_g_gamma_last - b_g_gamma)[:, None]).astype(b_v_upd.dtype) + + # State update: h_k += k_k^T @ v_upd + scratch_ref[...] = scratch_ref[...] 
+ jnp.dot( + b_k.astype(jnp.float32).T, + b_v_upd.astype(jnp.float32), + precision=jax.lax.Precision.HIGHEST, + preferred_element_type=jnp.float32, + ) + + # Store final hidden state at the last outer iteration + @pl.when(i_t == NT_OUTER - 1) def _store_ht(): if ht_ref is not None: ht_ref[0, 0] = scratch_ref[...].astype(ht_ref.dtype) @@ -172,7 +175,7 @@ def _store_ht(): @functools.partial( jax.jit, - static_argnames=("output_final_state", "chunk_size", "interpret"), + static_argnames=("output_final_state", "chunk_size", "num_steps", "interpret"), ) def fused_chunk_fwd_kernel( q: jax.Array, @@ -185,6 +188,7 @@ def fused_chunk_fwd_kernel( scale: float, output_final_state: bool = False, chunk_size: int = 64, + num_steps: int = 4, interpret: bool = False, ) -> tuple[jax.Array, jax.Array | None]: """Pallas launcher for the fused chunk forward kernel. @@ -192,6 +196,9 @@ def fused_chunk_fwd_kernel( Reshapes inputs to (B, H, T, D) layout, launches the fused kernel with K-tiled and V-tiled grid, reduces partial outputs across K tiles, and transposes the result back to (B, T, H, D). + + Each grid iteration processes ``num_steps`` consecutive chunks, reducing + grid overhead by that factor compared to processing one chunk at a time.
""" B, T, H, K = q.shape V = v.shape[-1] @@ -202,6 +209,12 @@ def fused_chunk_fwd_kernel( NK = K // BK NV = V // BV + # Auto-adjust num_steps: halve until it divides NT (worst case 1); not necessarily the largest divisor ≤ requested + while num_steps > 1 and NT % num_steps != 0: + num_steps //= 2 + NT_OUTER = NT // num_steps + WINDOW = num_steps * BT # time-dimension tile size + # Transpose to (B, H, T, D) layout for the kernel _q = q.transpose(0, 2, 1, 3) # (B, H, T, K) _k = k.transpose(0, 2, 1, 3) # (B, H, T, K) @@ -212,7 +225,7 @@ def fused_chunk_fwd_kernel( _g = g.transpose(0, 2, 1) # (B, H, T) _g = jnp.broadcast_to(_g[:, :, :, None], (B, H, T, 128)) # (B, H, T, 128) - grid = (B, H, NK, NV, NT) + grid = (B, H, NK, NV, NT_OUTER) # ---- Index maps ---- def qk_map(b, h, ik, iv, t): @@ -230,17 +243,17 @@ def g_map(b, h, ik, iv, t): def o_map(b, h, ik, iv, t): return (b, h, ik, t, iv) - # ---- Specs ---- + # ---- Specs (time dimension tiles WINDOW = num_steps * BT) ---- smem = pltpu.ANY if interpret else pltpu.SMEM - spec_qk = pl.BlockSpec((1, 1, BT, BK), qk_map) - spec_v = pl.BlockSpec((1, 1, BT, BV), v_map) + spec_qk = pl.BlockSpec((1, 1, WINDOW, BK), qk_map) + spec_v = pl.BlockSpec((1, 1, WINDOW, BV), v_map) spec_h0 = pl.BlockSpec((1, 1, BK, BV), state_map) if h0 is not None else None - spec_g = pl.BlockSpec((1, 1, BT, 128), g_map) if _g is not None else None + spec_g = pl.BlockSpec((1, 1, WINDOW, 128), g_map) if _g is not None else None spec_gamma = pl.BlockSpec(memory_space=smem) if g_gamma is not None else None # o_partial: (B, H, NK, T, V) — partial contribution per K tile, float32 - spec_o = pl.BlockSpec((1, 1, 1, BT, BV), o_map) + spec_o = pl.BlockSpec((1, 1, 1, WINDOW, BV), o_map) spec_ht = pl.BlockSpec((1, 1, BK, BV), state_map) if output_final_state else None # ---- Output shapes ---- @@ -252,7 +265,10 @@ def o_map(b, h, ik, iv, t): # ---- Launch ---- o_partial, ht = pl.pallas_call( - functools.partial(_fused_chunk_fwd_kernel, BT=BT, NT=NT), + functools.partial( + _fused_chunk_fwd_kernel,
+ BT=BT, NT=NT, NUM_STEPS=num_steps, NT_OUTER=NT_OUTER, + ), grid_spec=pltpu.PrefetchScalarGridSpec( num_scalar_prefetch=0, grid=grid, @@ -294,6 +310,7 @@ def fused_chunk_fwd( scale: float | None = None, use_ht: bool = False, chunk_size: int = 64, + num_steps: int = 4, interpret: bool | None = None, ) -> tuple[jax.Array, jax.Array | None]: """Fused chunk forward: compute output *o* and optionally final state *ht*. @@ -307,25 +324,30 @@ def fused_chunk_fwd( partial output contributions; the launcher reduces across K tiles and applies the attention scale. + Each grid iteration processes ``num_steps`` consecutive chunks via a + for loop inside the kernel, reducing grid overhead by that factor. + Recurrence per chunk c (chunk_size positions): h_c = h_{c-1} * decay(g, g_gamma) + k_c^T @ v_c o_c = (q_c @ h_{c-1} + causal(q_c @ k_c^T) @ v_c) * scale Args: - q: [B, T, H, K] — queries. - k: [B, T, H, K] — keys. - v: [B, T, H, V] — values. - g: [B, T, H] — chunk-local cumsum of scalar gate (optional). - g_gamma: [H] — per-head fixed decay rate (optional). - h0: [B, H, K, V] — initial hidden state (optional). + q: [B, T, H, K] -- queries. + k: [B, T, H, K] -- keys. + v: [B, T, H, V] -- values. + g: [B, T, H] -- chunk-local cumsum of scalar gate (optional). + g_gamma: [H] -- per-head fixed decay rate (optional). + h0: [B, H, K, V] -- initial hidden state (optional). scale: attention scale. Defaults to K ** -0.5. - output_final_state: if True, also return the final hidden state. + use_ht: if True, also return the final hidden state. chunk_size: block size. T must be divisible by chunk_size. + num_steps: number of consecutive chunks processed per grid iteration. + Halved automatically until it divides NT = T // chunk_size (worst case 1). Default 4 (from PR #74). interpret: Pallas interpret mode. None = auto-detect. Returns: - o: [B, T, H, V] — output tensor. - ht: [B, H, K, V] or None — final hidden state. + o: [B, T, H, V] -- output tensor. + ht: [B, H, K, V] or None -- final hidden state.
""" B, T, H, K = q.shape V = v.shape[-1] @@ -345,6 +367,10 @@ def fused_chunk_fwd( assert T % chunk_size == 0, ( f"Sequence length T={T} must be divisible by chunk_size={chunk_size}" ) + NT = T // chunk_size + # Auto-adjust num_steps: halve until it divides NT (worst case 1); not necessarily the largest divisor ≤ requested + while num_steps > 1 and NT % num_steps != 0: + num_steps //= 2 assert scale is not None # =================== assert kernel requirements done =================== @@ -361,5 +387,6 @@ def fused_chunk_fwd( scale=scale, output_final_state=use_ht, chunk_size=chunk_size, + num_steps=num_steps, interpret=interpret, ) diff --git a/tops/ops/simple_gla/fused_chunk.py b/tops/ops/simple_gla/fused_chunk.py index 27384915..c493e719 100644 --- a/tops/ops/simple_gla/fused_chunk.py +++ b/tops/ops/simple_gla/fused_chunk.py @@ -1,10 +1,14 @@ """Fused chunk kernels for Simple GLA. Forward: delegates to ``fused_chunk_fwd`` (common module). -Backward: fuses ``chunk_bwd_dh`` and ``chunk_simple_gla_bwd_o`` into a -single Pallas kernel so that the full dh tensor [B, NT, H, K, V] never -materialises in HBM. The hidden-state gradient is carried across chunks -in VMEM scratch, analogous to the forward fusion for h. +Backward: split into two passes to reduce register pressure: + Pass 1 (``_bwd_dv_dh_kernel``): computes dv and propagates dh state + backwards (sequential, "arbitrary" time dim). + Pass 2 (``_bwd_dq_dk_kernel``): computes dq and dk using pre-computed h + and dh_states (fully parallel time dim). + +This split avoids materialising 9+ live large matrices simultaneously, +cutting register spills compared to the monolithic backward kernel.
""" import functools @@ -55,22 +59,20 @@ def fused_chunk_simple_gla_fwd( # ========================================================================= -# Fused backward: dh propagation + dq/dk/dv in a single Pallas kernel +# Backward Pass 1: dv + dh state propagation (sequential) # ========================================================================= -def _fused_chunk_bwd_kernel( +def _bwd_dv_dh_kernel( q_ref, # (1, 1, BT, K) — reversed chunk order k_ref, # (1, 1, BT, K) — reversed chunk order v_ref, # (1, 1, BT, V) — reversed chunk order - h_ref, # (1, 1, 1, K, V) — from forward, reversed chunk order do_ref, # (1, 1, BT, V) — reversed chunk order g_gamma_ref, # [H] via SMEM/ANY dht_ref, # (1, 1, K, V) or None — terminal state gradient # --- outputs --- - dq_ref, # (1, 1, BT, K) — reversed chunk order - dk_ref, # (1, 1, BT, K) — reversed chunk order dv_ref, # (1, 1, BT, V) — reversed chunk order + dh_states_ref, # (1, 1, 1, K, V) — dh state for this chunk dh0_ref, # (1, 1, K, V) or None — initial state gradient # --- scratch --- scratch_ref, # (K, V) VMEM float32 — carries dh across chunks @@ -79,26 +81,21 @@ def _fused_chunk_bwd_kernel( NT: int, scale: float, ): - """Fused backward kernel: dh propagation + dq/dk/dv computation. + """Pass 1: Compute dv and propagate dh state backwards. - Merges ``chunk_bwd_dh`` and ``chunk_simple_gla_bwd_o`` so that the - full dh tensor [B, NT, H, K, V] never materialises in HBM. + Iterates chunks in reverse time order. For each chunk: + 1. Compute dv = A^T @ do (intra) + k_decay @ dh (inter) + 2. Store dh state for this time step (used by Pass 2) + 3. Update dh: decay + q_hat^T @ do accumulation Grid: (B, H, NT) — B, H parallel, NT arbitrary (backward). - Iteration: i_t = 0 processes the *last* chunk (NT-1), i_t = NT-1 - processes chunk 0. BlockSpec index maps reverse the chunk order. - - At each step: - 1. Load dh from VMEM scratch (gradient flowing from future chunks). - 2. 
Compute dq, dk, dv using both h (from HBM) and dh (from scratch). - 3. Update dh for the previous chunk and save to scratch. """ K = q_ref.shape[3] V = v_ref.shape[3] - i_t = pl.program_id(2) # 0 → last chunk, NT-1 → first chunk + i_t = pl.program_id(2) head_idx = pl.program_id(1) - # ---- Initialize dh on first backward step ---- + # Initialize dh from dht or zeros @pl.when(i_t == 0) def _init(): if dht_ref is not None: @@ -106,77 +103,37 @@ def _init(): else: scratch_ref[:, :] = jnp.zeros((K, V), dtype=jnp.float32) - # ---- Load data ---- - b_q = q_ref[0, 0] # (BT, K) - b_k = k_ref[0, 0] # (BT, K) - b_v = v_ref[0, 0] # (BT, V) - b_h = h_ref[0, 0, 0].astype(jnp.float32) # (K, V) - b_do = do_ref[0, 0] # (BT, V) - b_dh = scratch_ref[...].astype(jnp.float32) # (K, V) + # Load tile data + b_q = q_ref[0, 0] + b_k = k_ref[0, 0] + b_v = v_ref[0, 0] + b_do = do_ref[0, 0] + b_dh = scratch_ref[...].astype(jnp.float32) - # ---- Per-position decay from g_gamma ---- + # Compute gating factors b_gamma = g_gamma_ref[head_idx].astype(jnp.float32) - b_g = b_gamma * (jnp.arange(BT) + 1).astype(jnp.float32) # [BT] - b_gn = b_g[BT - 1] # g_gamma * BT + b_g = b_gamma * (jnp.arange(BT) + 1).astype(jnp.float32) + b_gn = b_g[BT - 1] - # ---- Recompute A in-kernel ---- + # Build causal attention matrix with decay pos = (jnp.arange(BT) + 1).astype(jnp.float32) mask = jnp.arange(BT)[:, None] >= jnp.arange(BT)[None, :] safe_diff = jnp.where(mask, b_gamma * (pos[:, None] - pos[None, :]), 0.0) decay = jnp.exp(safe_diff) - b_a = jnp.dot( - b_q, b_k.T, - preferred_element_type=jnp.float32, - ) * scale * decay - - # ---- dA = (do @ v^T) * scale, lower-triangular ---- - b_dA = jnp.dot( - b_do, b_v.T, - preferred_element_type=jnp.float32, - ) * scale - b_dA = jnp.where(mask, b_dA, 0.0) - - # ---- dv = A^T @ do + k_decay @ dh ---- + b_a = jnp.dot(b_q, b_k.T, preferred_element_type=jnp.float32) * scale * decay b_a_masked = jnp.where(mask, b_a, 0.0).astype(b_do.dtype) - b_dv_intra = jnp.dot( - 
b_a_masked.T, b_do, - preferred_element_type=jnp.float32, - ) + + # dv = A^T @ do (intra-chunk) + k_decay @ dh (inter-chunk) + b_dv_intra = jnp.dot(b_a_masked.T, b_do, preferred_element_type=jnp.float32) k_decay = (b_k * jnp.exp(b_gn - b_g)[:, None]).astype(b_k.dtype) - b_dv_inter = jnp.dot( - k_decay, b_dh.astype(b_k.dtype), - preferred_element_type=jnp.float32, - ) + b_dv_inter = jnp.dot(k_decay, b_dh.astype(b_k.dtype), preferred_element_type=jnp.float32) dv_ref[0, 0] = (b_dv_intra + b_dv_inter).astype(dv_ref.dtype) - # ---- Gated dA: dA * exp(g_i - g_j) ---- - safe_g_diff = jnp.where(mask, b_g[:, None] - b_g[None, :], 0.0) - b_dA_gated = b_dA * jnp.exp(safe_g_diff) + # Store dh state BEFORE updating (Pass 2 needs dh at this time step) + dh_states_ref[0, 0, 0] = scratch_ref[...].astype(dh_states_ref.dtype) - # ---- dq = gated_dA @ k + do @ h^T * scale * exp(g) ---- - b_dq_intra = jnp.dot( - b_dA_gated.astype(b_k.dtype), b_k, - preferred_element_type=jnp.float32, - ) - b_dq_inter = jnp.dot( - b_do, b_h.astype(b_do.dtype).T, - preferred_element_type=jnp.float32, - ) * (scale * jnp.exp(b_g)[:, None]) - dq_ref[0, 0] = (b_dq_intra + b_dq_inter).astype(dq_ref.dtype) - - # ---- dk = gated_dA^T @ q + v @ dh^T * exp(g_last - g) ---- - b_dk_intra = jnp.dot( - b_dA_gated.T.astype(b_q.dtype), b_q, - preferred_element_type=jnp.float32, - ) - b_dk_inter = jnp.dot( - b_v, b_dh.astype(b_v.dtype).T, - preferred_element_type=jnp.float32, - ) * jnp.exp(b_gn - b_g)[:, None] - dk_ref[0, 0] = (b_dk_intra + b_dk_inter).astype(dk_ref.dtype) - - # ---- Update dh for previous chunk ---- - # Recurrence: dh_n = dh_{n+1} * exp(g_gamma * BT) + (q*scale*exp(g))^T @ do + # Update dh for next (earlier) time step: + # dh_{t-1} = dh_t * exp(gamma * BT) + q_hat^T @ do b_dh = b_dh * jnp.exp(b_gn) b_q_hat = (b_q * (scale * jnp.exp(b_g)[:, None])).astype(jnp.float32) b_dh = b_dh + jnp.dot( @@ -186,13 +143,77 @@ def _init(): ) scratch_ref[...] 
= b_dh - # ---- Store dh0 at last backward step (chunk 0) ---- + # Store dh0 at last backward step (chunk 0) @pl.when(i_t == NT - 1) def _store_dh0(): if dh0_ref is not None: dh0_ref[0, 0] = scratch_ref[...].astype(dh0_ref.dtype) +# ========================================================================= +# Backward Pass 2: dq + dk (parallel, dh_states fully materialized) +# ========================================================================= + + +def _bwd_dq_dk_kernel( + q_ref, # (1, 1, BT, K) — reversed chunk order + k_ref, # (1, 1, BT, K) — reversed chunk order + v_ref, # (1, 1, BT, V) — reversed chunk order + h_ref, # (1, 1, 1, K, V) — from forward, reversed chunk order + do_ref, # (1, 1, BT, V) — reversed chunk order + g_gamma_ref, # [H] via SMEM/ANY + dh_states_ref, # (1, 1, 1, K, V) — from Pass 1, reversed chunk order + # --- outputs --- + dq_ref, # (1, 1, BT, K) — reversed chunk order + dk_ref, # (1, 1, BT, K) — reversed chunk order + *, + BT: int, + NT: int, + scale: float, +): + """Pass 2: Compute dq and dk using pre-computed h and dh_states. + + No sequential state dependency — dh_states are fully materialized from + Pass 1, so this kernel uses parallel semantics on the time dimension. + + Grid: (B, H, NT) — all dimensions parallel. 
+ """ + K = q_ref.shape[3] + V = v_ref.shape[3] + head_idx = pl.program_id(1) + + # Load tile data + b_q = q_ref[0, 0] + b_k = k_ref[0, 0] + b_v = v_ref[0, 0] + b_h = h_ref[0, 0, 0].astype(jnp.float32) + b_do = do_ref[0, 0] + b_dh = dh_states_ref[0, 0, 0].astype(jnp.float32) + + # Compute gating factors + b_gamma = g_gamma_ref[head_idx].astype(jnp.float32) + b_g = b_gamma * (jnp.arange(BT) + 1).astype(jnp.float32) + b_gn = b_g[BT - 1] + + # Build causal mask and gated dA + mask = jnp.arange(BT)[:, None] >= jnp.arange(BT)[None, :] + b_dA = jnp.dot(b_do, b_v.T, preferred_element_type=jnp.float32) * scale + b_dA = jnp.where(mask, b_dA, 0.0) + + safe_g_diff = jnp.where(mask, b_g[:, None] - b_g[None, :], 0.0) + b_dA_gated = b_dA * jnp.exp(safe_g_diff) + + # dq = dA_gated @ k (intra) + do @ h^T * scale * exp(g) (inter) + b_dq_intra = jnp.dot(b_dA_gated.astype(b_k.dtype), b_k, preferred_element_type=jnp.float32) + b_dq_inter = jnp.dot(b_do, b_h.astype(b_do.dtype).T, preferred_element_type=jnp.float32) * (scale * jnp.exp(b_g)[:, None]) + dq_ref[0, 0] = (b_dq_intra + b_dq_inter).astype(dq_ref.dtype) + + # dk = dA_gated^T @ q (intra) + v @ dh^T * exp(gn - g) (inter) + b_dk_intra = jnp.dot(b_dA_gated.T.astype(b_q.dtype), b_q, preferred_element_type=jnp.float32) + b_dk_inter = jnp.dot(b_v, b_dh.astype(b_v.dtype).T, preferred_element_type=jnp.float32) * jnp.exp(b_gn - b_g)[:, None] + dk_ref[0, 0] = (b_dk_intra + b_dk_inter).astype(dk_ref.dtype) + + # ========================================================================= # Pallas launcher # ========================================================================= @@ -216,10 +237,10 @@ def _fused_chunk_bwd_launcher( chunk_size: int, interpret: bool, ) -> tuple[jax.Array, jax.Array, jax.Array, jax.Array | None]: - """Pallas launcher for the fused backward kernel. + """Pallas launcher for the split backward kernel (2 passes). 
- Transposes inputs to (B, H, T, D) layout, launches the kernel with - reversed chunk order, and transposes results back to (B, T, H, D). + Pass 1: dv + dh propagation (sequential over time). + Pass 2: dq + dk (parallel, dh_states fully materialized from Pass 1). """ B, T, H, K = q.shape V = v.shape[-1] @@ -250,46 +271,86 @@ def state_map(b, h, t): smem = pltpu.ANY if interpret else pltpu.SMEM - in_specs = [ - pl.BlockSpec((1, 1, BT, K), rev_qk_map), # q - pl.BlockSpec((1, 1, BT, K), rev_qk_map), # k - pl.BlockSpec((1, 1, BT, V), rev_v_map), # v - pl.BlockSpec((1, 1, 1, K, V), rev_h_map), # h - pl.BlockSpec((1, 1, BT, V), rev_v_map), # do - pl.BlockSpec(memory_space=smem), # g_gamma + # ===================================================================== + # Pass 1: dv + dh propagation (sequential, "arbitrary" time dim) + # ===================================================================== + pass1_in_specs = [ + pl.BlockSpec((1, 1, BT, K), rev_qk_map), # q + pl.BlockSpec((1, 1, BT, K), rev_qk_map), # k + pl.BlockSpec((1, 1, BT, V), rev_v_map), # v + pl.BlockSpec((1, 1, BT, V), rev_v_map), # do + pl.BlockSpec(memory_space=smem), # g_gamma pl.BlockSpec((1, 1, K, V), state_map) if dht is not None else None, # dht ] - out_specs = [ - pl.BlockSpec((1, 1, BT, K), rev_qk_map), # dq - pl.BlockSpec((1, 1, BT, K), rev_qk_map), # dk - pl.BlockSpec((1, 1, BT, V), rev_v_map), # dv - pl.BlockSpec((1, 1, K, V), state_map) if output_dh0 else None, # dh0 + pass1_out_specs = [ + pl.BlockSpec((1, 1, BT, V), rev_v_map), # dv + pl.BlockSpec((1, 1, 1, K, V), rev_h_map), # dh_states + pl.BlockSpec((1, 1, K, V), state_map) if output_dh0 else None, # dh0 ] - out_shapes = [ - jax.ShapeDtypeStruct((B, H, T, K), q.dtype), # dq - jax.ShapeDtypeStruct((B, H, T, K), k.dtype), # dk - jax.ShapeDtypeStruct((B, H, T, V), v.dtype), # dv - jax.ShapeDtypeStruct((B, H, K, V), jnp.float32) # dh0 + pass1_out_shapes = [ + jax.ShapeDtypeStruct((B, H, T, V), v.dtype), # dv + jax.ShapeDtypeStruct((B, H, 
NT, K, V), jnp.float32), # dh_states + jax.ShapeDtypeStruct((B, H, K, V), jnp.float32) # dh0 if output_dh0 else None, ] - dq, dk, dv, dh0 = pl.pallas_call( - functools.partial(_fused_chunk_bwd_kernel, BT=BT, NT=NT, scale=scale), + dv, dh_states, dh0 = pl.pallas_call( + functools.partial(_bwd_dv_dh_kernel, BT=BT, NT=NT, scale=scale), grid_spec=pltpu.PrefetchScalarGridSpec( num_scalar_prefetch=0, grid=grid, - in_specs=in_specs, - out_specs=out_specs, + in_specs=pass1_in_specs, + out_specs=pass1_out_specs, scratch_shapes=[pltpu.VMEM((K, V), jnp.float32)], ), - out_shape=out_shapes, + out_shape=pass1_out_shapes, compiler_params=pltpu.CompilerParams( dimension_semantics=("parallel", "parallel", "arbitrary"), ), interpret=interpret, - )(_q, _k, _v, _h, _do, g_gamma, dht) + )(_q, _k, _v, _do, g_gamma, dht) + + # ===================================================================== + # Pass 2: dq + dk (parallel, dh_states fully materialized) + # ===================================================================== + pass2_in_specs = [ + pl.BlockSpec((1, 1, BT, K), rev_qk_map), # q + pl.BlockSpec((1, 1, BT, K), rev_qk_map), # k + pl.BlockSpec((1, 1, BT, V), rev_v_map), # v + pl.BlockSpec((1, 1, 1, K, V), rev_h_map), # h (pre-computed fwd states) + pl.BlockSpec((1, 1, BT, V), rev_v_map), # do + pl.BlockSpec(memory_space=smem), # g_gamma + pl.BlockSpec((1, 1, 1, K, V), rev_h_map), # dh_states from Pass 1 + ] + + pass2_out_specs = [ + pl.BlockSpec((1, 1, BT, K), rev_qk_map), # dq + pl.BlockSpec((1, 1, BT, K), rev_qk_map), # dk + ] + + pass2_out_shapes = [ + jax.ShapeDtypeStruct((B, H, T, K), q.dtype), # dq + jax.ShapeDtypeStruct((B, H, T, K), k.dtype), # dk + ] + + dq, dk = pl.pallas_call( + functools.partial(_bwd_dq_dk_kernel, BT=BT, NT=NT, scale=scale), + grid_spec=pltpu.PrefetchScalarGridSpec( + num_scalar_prefetch=0, + grid=grid, + in_specs=pass2_in_specs, + out_specs=pass2_out_specs, + scratch_shapes=[], + ), + out_shape=pass2_out_shapes, + 
compiler_params=pltpu.CompilerParams( + # Time dimension is now parallel since dh_states are materialized + dimension_semantics=("parallel", "parallel", "parallel"), + ), + interpret=interpret, + )(_q, _k, _v, _h, _do, g_gamma, dh_states) # Transpose back: (B, H, T, D) -> (B, T, H, D) dq = dq.transpose(0, 2, 1, 3) @@ -322,10 +383,11 @@ def fused_chunk_simple_gla_bwd( ) -> tuple[jax.Array, jax.Array, jax.Array, jax.Array | None]: """Fused chunk backward for simple GLA (g_gamma only). - Merges hidden-state gradient propagation (``chunk_bwd_dh``) and per-chunk - gradient computation (``chunk_simple_gla_bwd_o``) into a single Pallas - kernel. The dh tensor [B, NT, H, K, V] stays in VMEM scratch, avoiding - HBM materialisation. + Split into two passes to reduce register pressure: + Pass 1: dv + dh propagation (sequential, "arbitrary" time dim). + Carries dh in VMEM scratch, outputs dv and materialized dh_states. + Pass 2: dq + dk (parallel time dim). + Uses pre-computed h (from forward) and dh_states (from Pass 1). The forward hidden states h are recomputed via ``chunk_fwd_h``. @@ -383,7 +445,7 @@ def fused_chunk_simple_gla_bwd( states_in_fp32=True, ) - # 2. Fused dh + dq/dk/dv backward kernel + # 2. Split backward: Pass 1 (dv + dh) then Pass 2 (dq + dk) dq, dk, dv, dh0 = _fused_chunk_bwd_launcher( q, k, v, h, do, g_gamma, dht, scale=scale,