Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
249 changes: 138 additions & 111 deletions tops/ops/common/fused_chunk.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,12 @@
(BK, BV) tile of the hidden state stays in VMEM scratch across chunks,
eliminating the HBM round-trip for the full ``h: [B, NT, H, K, V]`` tensor.

Grid: (B, H, K // BK, V // BV, NT)
- First 4 dims parallel, NT arbitrary (sequential).
Grid: (B, H, K // BK, V // BV, NT_OUTER)
- First 4 dims parallel, NT_OUTER arbitrary (sequential).
- Each grid point handles one (BK, BV) tile of the hidden state.
- NT_OUTER = NT // num_steps; each iteration processes *num_steps* chunks.

At each time step the kernel:
At each outer grid iteration the kernel loops over *num_steps* sub-chunks; for each sub-chunk it:
1. Computes a *partial* output contribution from this K tile:
partial = q_k @ h_k * gate + causal(q_k @ k_k^T * gate) @ v
2. Updates this K tile's hidden state:
Expand All @@ -34,133 +35,135 @@
# ---------------------------------------------------------------------------

def _fused_chunk_fwd_kernel(
q_ref, # (1, 1, BT, BK) — K-tiled
k_ref, # (1, 1, BT, BK) — K-tiled
v_ref, # (1, 1, BT, BV) — V-tiled
q_ref, # (1, 1, NUM_STEPS*BT, BK) — K-tiled
k_ref, # (1, 1, NUM_STEPS*BT, BK) — K-tiled
v_ref, # (1, 1, NUM_STEPS*BT, BV) — V-tiled
h0_ref, # (1, 1, BK, BV) or None
g_ref, # (1, 1, BT, 128) or None
g_ref, # (1, 1, NUM_STEPS*BT, 128) or None
g_gamma_ref, # [H] via SMEM/ANY, or None
o_ref, # (1, 1, 1, BT, BV) — partial output per K tile
o_ref, # (1, 1, 1, NUM_STEPS*BT, BV) — partial output per K tile
ht_ref, # (1, 1, BK, BV) — output, or None
scratch_ref, # (BK, BV) VMEM float32
*,
BT: int,
NT: int,
NUM_STEPS: int,
NT_OUTER: int,
):
"""Fused Pallas kernel: hidden-state propagation + output computation.

Grid: (B, H, NK, NV, NT)
- First 4 dims are parallel; NT is arbitrary (sequential over time).
- NK = K // BK, NV = V // BV.
Grid: (B, H, NK, NV, NT_OUTER)
- First 4 dims are parallel; NT_OUTER is arbitrary (sequential over time).
- NK = K // BK, NV = V // BV, NT_OUTER = NT // NUM_STEPS.

Each grid point processes one (BK, BV) tile of the hidden state
across all time steps. At each time step:
1. Compute partial output contribution from this K tile.
2. Update this K tile's hidden state.
Each grid point processes one (BK, BV) tile of the hidden state.
At each iteration the kernel loops over NUM_STEPS sub-chunks,
computing partial output and updating the hidden state for each.

The partial outputs are summed across K tiles and scaled by the launcher.

Refs (after block-spec indexing):
q_ref / k_ref : (1, 1, BT, BK) — K-tiled
v_ref : (1, 1, BT, BV) — V-tiled
q_ref / k_ref : (1, 1, NUM_STEPS*BT, BK) — K-tiled
v_ref : (1, 1, NUM_STEPS*BT, BV) — V-tiled
h0_ref : (1, 1, BK, BV) — initial state tile, or None
g_ref : (1, 1, BT, 128) — scalar gate (broadcast to 4D for TPU alignment), or None
g_ref : (1, 1, NUM_STEPS*BT, 128) — scalar gate, or None
g_gamma_ref : [H] — per-head fixed decay, or None
o_ref : (1, 1, 1, BT, BV) — partial output tile
o_ref : (1, 1, 1, NUM_STEPS*BT, BV) — partial output tile
ht_ref : (1, 1, BK, BV) — final-state tile, or None
scratch_ref : (BK, BV) — running hidden state in VMEM
"""
BK = q_ref.shape[3]
BV = v_ref.shape[3]
i_t = pl.program_id(4)

# ---- Precompute g_gamma ramp (constant across chunks) ----
# ---- Precompute constants (invariant across sub-steps) ----
causal_mask = jnp.arange(BT)[:, None] >= jnp.arange(BT)[None, :]

if g_gamma_ref is not None:
head_idx = pl.program_id(1)
b_gamma = g_gamma_ref[head_idx].astype(jnp.float32)
b_g_gamma = b_gamma * (jnp.arange(BT) + 1).astype(jnp.float32) # [BT]
b_g_gamma_last = b_gamma * BT
g_gamma_diff = b_g_gamma[:, None] - b_g_gamma[None, :]
safe_g_gamma_diff = jnp.where(causal_mask, g_gamma_diff, 0.0)

# ---- Initialize hidden state on first chunk ----
# ---- Initialize hidden state on first iteration ----
@pl.when(i_t == 0)
def _init():
if h0_ref is not None:
scratch_ref[:, :] = h0_ref[0, 0].astype(jnp.float32)
else:
scratch_ref[:, :] = jnp.zeros((BK, BV), dtype=jnp.float32)

# ---- Load current chunk ----
b_q = q_ref[0, 0] # (BT, BK)
b_k = k_ref[0, 0] # (BT, BK)
b_v = v_ref[0, 0] # (BT, BV)
b_h = scratch_ref[...] # (BK, BV) — state *before* this chunk's update

# ===== Stage 1: Partial output for this K tile =====
# Partial inter-chunk: q_k @ h_k
partial_o = jnp.dot(
b_q, b_h,
preferred_element_type=jnp.float32,
) # (BT, BV)

# Partial intra-chunk attention: q_k @ k_k^T
partial_A = jnp.dot(
b_q, b_k.T,
preferred_element_type=jnp.float32,
) # (BT, BT)

# Apply scalar gate g
if g_ref is not None:
b_g = g_ref[0, 0, :, 0].astype(jnp.float32) # (BT,)
partial_o = partial_o * exp(b_g)[:, None]
g_diff = b_g[:, None] - b_g[None, :]
fwd_mask = jnp.arange(BT)[:, None] >= jnp.arange(BT)[None, :]
safe_g_diff = jnp.where(fwd_mask, g_diff, 0.0)
partial_A = partial_A * exp(safe_g_diff)

# Apply per-head fixed decay g_gamma
if g_gamma_ref is not None:
partial_o = partial_o * exp(b_g_gamma)[:, None]
g_gamma_diff = b_g_gamma[:, None] - b_g_gamma[None, :]
fwd_mask = jnp.arange(BT)[:, None] >= jnp.arange(BT)[None, :]
safe_g_gamma_diff = jnp.where(fwd_mask, g_gamma_diff, 0.0)
partial_A = partial_A * exp(safe_g_gamma_diff)

# Causal mask (lower triangular: i >= j)
mask = jnp.arange(BT)[:, None] >= jnp.arange(BT)[None, :]
partial_A = jnp.where(mask, partial_A, 0.0)

# Partial contribution (inter + intra); scale applied after K reduction
partial = partial_o + jnp.dot(
partial_A, b_v.astype(jnp.float32),
precision=jax.lax.Precision.HIGHEST,
preferred_element_type=jnp.float32,
) # (BT, BV)
o_ref[0, 0, 0] = partial.astype(o_ref.dtype)

# ===== Stage 2: Update hidden state for this K tile =====
b_v_upd = b_v

# Decay state and adjust v for scalar gate g
if g_ref is not None:
b_g_last = b_g[BT - 1]
scratch_ref[...] *= exp(b_g_last)
b_v_upd = (b_v_upd * exp(b_g_last - b_g)[:, None]).astype(b_v_upd.dtype)

# Decay state and adjust v for g_gamma
if g_gamma_ref is not None:
b_g_gamma_last = b_gamma * BT
scratch_ref[...] *= exp(b_g_gamma_last)
b_v_upd = (b_v_upd * exp(b_g_gamma_last - b_g_gamma)[:, None]).astype(b_v_upd.dtype)

# State update: h_k += k_k^T @ v_upd
scratch_ref[...] = scratch_ref[...] + jnp.dot(
b_k.astype(jnp.float32).T,
b_v_upd.astype(jnp.float32),
precision=jax.lax.Precision.HIGHEST,
preferred_element_type=jnp.float32,
)

# Store final hidden state
@pl.when(i_t == NT - 1)
# ---- Process NUM_STEPS sub-chunks per grid iteration ----
for step in range(NUM_STEPS):
offset = step * BT

# Load current sub-chunk
b_q = q_ref[0, 0, pl.ds(offset, BT), :] # (BT, BK)
b_k = k_ref[0, 0, pl.ds(offset, BT), :] # (BT, BK)
b_v = v_ref[0, 0, pl.ds(offset, BT), :] # (BT, BV)
b_h = scratch_ref[...] # (BK, BV)

# ===== Stage 1: Partial output for this K tile =====
partial_o = jnp.dot(
b_q, b_h,
preferred_element_type=jnp.float32,
) # (BT, BV)

partial_A = jnp.dot(
b_q, b_k.T,
preferred_element_type=jnp.float32,
) # (BT, BT)

# Apply scalar gate g
if g_ref is not None:
b_g = g_ref[0, 0, pl.ds(offset, BT), 0].astype(jnp.float32) # (BT,)
partial_o = partial_o * exp(b_g)[:, None]
g_diff = b_g[:, None] - b_g[None, :]
safe_g_diff = jnp.where(causal_mask, g_diff, 0.0)
partial_A = partial_A * exp(safe_g_diff)

# Apply per-head fixed decay g_gamma
if g_gamma_ref is not None:
partial_o = partial_o * exp(b_g_gamma)[:, None]
partial_A = partial_A * exp(safe_g_gamma_diff)

# Causal mask (lower triangular: i >= j)
partial_A = jnp.where(causal_mask, partial_A, 0.0)

# Partial contribution (inter + intra); scale applied after K reduction
partial = partial_o + jnp.dot(
partial_A, b_v.astype(jnp.float32),
precision=jax.lax.Precision.HIGHEST,
preferred_element_type=jnp.float32,
) # (BT, BV)
o_ref[0, 0, 0, pl.ds(offset, BT), :] = partial.astype(o_ref.dtype)

# ===== Stage 2: Update hidden state for this K tile =====
b_v_upd = b_v

# Decay state and adjust v for scalar gate g
if g_ref is not None:
b_g_last = b_g[BT - 1]
scratch_ref[...] *= exp(b_g_last)
b_v_upd = (b_v_upd * exp(b_g_last - b_g)[:, None]).astype(b_v_upd.dtype)

# Decay state and adjust v for g_gamma
if g_gamma_ref is not None:
scratch_ref[...] *= exp(b_g_gamma_last)
b_v_upd = (b_v_upd * exp(b_g_gamma_last - b_g_gamma)[:, None]).astype(b_v_upd.dtype)

# State update: h_k += k_k^T @ v_upd
scratch_ref[...] = scratch_ref[...] + jnp.dot(
b_k.astype(jnp.float32).T,
b_v_upd.astype(jnp.float32),
precision=jax.lax.Precision.HIGHEST,
preferred_element_type=jnp.float32,
)

# Store final hidden state at the last outer iteration
@pl.when(i_t == NT_OUTER - 1)
def _store_ht():
if ht_ref is not None:
ht_ref[0, 0] = scratch_ref[...].astype(ht_ref.dtype)
Expand All @@ -172,7 +175,7 @@ def _store_ht():

@functools.partial(
jax.jit,
static_argnames=("output_final_state", "chunk_size", "interpret"),
static_argnames=("output_final_state", "chunk_size", "num_steps", "interpret"),
)
def fused_chunk_fwd_kernel(
q: jax.Array,
Expand All @@ -185,13 +188,17 @@ def fused_chunk_fwd_kernel(
scale: float,
output_final_state: bool = False,
chunk_size: int = 64,
num_steps: int = 4,
interpret: bool = False,
) -> tuple[jax.Array, jax.Array | None]:
"""Pallas launcher for the fused chunk forward kernel.

Transposes inputs to (B, H, T, D) layout, launches the fused kernel with
K-tiled and V-tiled grid, reduces partial outputs across K tiles, and
transposes the result back to (B, T, H, D).

Each grid iteration processes ``num_steps`` consecutive chunks, reducing
grid overhead by that factor compared to processing one chunk at a time.
"""
B, T, H, K = q.shape
V = v.shape[-1]
Expand All @@ -202,6 +209,12 @@ def fused_chunk_fwd_kernel(
NK = K // BK
NV = V // BV

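# Auto-adjust num_steps: halve it (integer division) until it divides NT, reaching 1 at worst.
# Note: halving can skip a larger valid divisor (e.g. NT=8, num_steps=5 -> 2, not 4).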
while num_steps > 1 and NT % num_steps != 0:
num_steps //= 2
NT_OUTER = NT // num_steps
WINDOW = num_steps * BT # time-dimension tile size

# Transpose to (B, H, T, D) layout for the kernel
_q = q.transpose(0, 2, 1, 3) # (B, H, T, K)
_k = k.transpose(0, 2, 1, 3) # (B, H, T, K)
Expand All @@ -212,7 +225,7 @@ def fused_chunk_fwd_kernel(
_g = g.transpose(0, 2, 1) # (B, H, T)
_g = jnp.broadcast_to(_g[:, :, :, None], (B, H, T, 128)) # (B, H, T, 128)

grid = (B, H, NK, NV, NT)
grid = (B, H, NK, NV, NT_OUTER)

# ---- Index maps ----
def qk_map(b, h, ik, iv, t):
Expand All @@ -230,17 +243,17 @@ def g_map(b, h, ik, iv, t):
def o_map(b, h, ik, iv, t):
return (b, h, ik, t, iv)

# ---- Specs ----
# ---- Specs (time dimension tiles WINDOW = num_steps * BT) ----
smem = pltpu.ANY if interpret else pltpu.SMEM

spec_qk = pl.BlockSpec((1, 1, BT, BK), qk_map)
spec_v = pl.BlockSpec((1, 1, BT, BV), v_map)
spec_qk = pl.BlockSpec((1, 1, WINDOW, BK), qk_map)
spec_v = pl.BlockSpec((1, 1, WINDOW, BV), v_map)
spec_h0 = pl.BlockSpec((1, 1, BK, BV), state_map) if h0 is not None else None
spec_g = pl.BlockSpec((1, 1, BT, 128), g_map) if _g is not None else None
spec_g = pl.BlockSpec((1, 1, WINDOW, 128), g_map) if _g is not None else None
spec_gamma = pl.BlockSpec(memory_space=smem) if g_gamma is not None else None

# o_partial: (B, H, NK, T, V) — partial contribution per K tile, float32
spec_o = pl.BlockSpec((1, 1, 1, BT, BV), o_map)
spec_o = pl.BlockSpec((1, 1, 1, WINDOW, BV), o_map)
spec_ht = pl.BlockSpec((1, 1, BK, BV), state_map) if output_final_state else None

# ---- Output shapes ----
Expand All @@ -252,7 +265,10 @@ def o_map(b, h, ik, iv, t):

# ---- Launch ----
o_partial, ht = pl.pallas_call(
functools.partial(_fused_chunk_fwd_kernel, BT=BT, NT=NT),
functools.partial(
_fused_chunk_fwd_kernel,
BT=BT, NT=NT, NUM_STEPS=num_steps, NT_OUTER=NT_OUTER,
),
grid_spec=pltpu.PrefetchScalarGridSpec(
num_scalar_prefetch=0,
grid=grid,
Expand Down Expand Up @@ -294,6 +310,7 @@ def fused_chunk_fwd(
scale: float | None = None,
use_ht: bool = False,
chunk_size: int = 64,
num_steps: int = 4,
interpret: bool | None = None,
) -> tuple[jax.Array, jax.Array | None]:
"""Fused chunk forward: compute output *o* and optionally final state *ht*.
Expand All @@ -307,25 +324,30 @@ def fused_chunk_fwd(
partial output contributions; the launcher reduces across K tiles and
applies the attention scale.

Each grid iteration processes ``num_steps`` consecutive chunks via a
for loop inside the kernel, reducing grid overhead by that factor.

Recurrence per chunk c (chunk_size positions):
h_c = h_{c-1} * decay(g, g_gamma) + k_c^T @ v_c
o_c = (q_c @ h_{c-1} + causal(q_c @ k_c^T) @ v_c) * scale

Args:
q: [B, T, H, K] queries.
k: [B, T, H, K] keys.
v: [B, T, H, V] values.
g: [B, T, H] chunk-local cumsum of scalar gate (optional).
g_gamma: [H] per-head fixed decay rate (optional).
h0: [B, H, K, V] initial hidden state (optional).
q: [B, T, H, K] -- queries.
k: [B, T, H, K] -- keys.
v: [B, T, H, V] -- values.
g: [B, T, H] -- chunk-local cumsum of scalar gate (optional).
g_gamma: [H] -- per-head fixed decay rate (optional).
h0: [B, H, K, V] -- initial hidden state (optional).
scale: attention scale. Defaults to K ** -0.5.
output_final_state: if True, also return the final hidden state.
use_ht: if True, also return the final hidden state.
chunk_size: block size. T must be divisible by chunk_size.
num_steps: number of consecutive chunks processed per grid iteration.
If NT (= T // chunk_size) is not divisible by num_steps, num_steps is
halved (integer division) until it divides NT. Default 4 (from PR #74).
interpret: Pallas interpret mode. None = auto-detect.

Returns:
o: [B, T, H, V] output tensor.
ht: [B, H, K, V] or None final hidden state.
o: [B, T, H, V] -- output tensor.
ht: [B, H, K, V] or None -- final hidden state.
"""
B, T, H, K = q.shape
V = v.shape[-1]
Expand All @@ -345,6 +367,10 @@ def fused_chunk_fwd(
assert T % chunk_size == 0, (
f"Sequence length T={T} must be divisible by chunk_size={chunk_size}"
)
NT = T // chunk_size
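# Auto-adjust num_steps: halve it (integer division) until it divides NT, reaching 1 at worst.
# Note: halving can skip a larger valid divisor (e.g. NT=8, num_steps=5 -> 2, not 4).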
while num_steps > 1 and NT % num_steps != 0:
num_steps //= 2
Comment on lines +372 to +373
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The manual fallback logic for num_steps is duplicated in both the kernel launcher and the public API function. This should be refactored into a shared utility function to ensure consistency and maintainability.

assert scale is not None
# =================== assert kernel requirements done ===================

Expand All @@ -361,5 +387,6 @@ def fused_chunk_fwd(
scale=scale,
output_final_state=use_ht,
chunk_size=chunk_size,
num_steps=num_steps,
interpret=interpret,
)
Loading
Loading