Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 16 additions & 4 deletions models/deepseek/v4/decode_attention_csa.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,11 @@ def attention_csa(

rope_cos_t = pl.create_tensor([T, ROPE_HEAD_DIM], dtype=pl.BF16)
rope_sin_t = pl.create_tensor([T, ROPE_HEAD_DIM], dtype=pl.BF16)
# Half-width unsigned inverse-RoPE snapshots for the split-half sparse_attn: the first
# HALF columns of the per-token cos/sin (one value per frequency), cast BF16->FP32 once per
# token so sparse_attn's per-head rope loop rotates with no in-loop cast and no gather.
rope_cos_half_t = pl.create_tensor([T, HALF_ROPE], dtype=pl.FP32)
rope_sin_half_t = pl.create_tensor([T, HALF_ROPE], dtype=pl.FP32)
step_cos = pl.create_tensor([B, HALF_ROPE], dtype=pl.FP32)
step_sin = pl.create_tensor([B, HALF_ROPE], dtype=pl.FP32)
with pl.at(level=pl.Level.CORE_GROUP, name_hint="csa_rope_step"):
Expand All @@ -199,6 +204,10 @@ def attention_csa(
sin_row = pl.cast(pl.slice(freqs_sin, [1, ROPE_HEAD_DIM], [pos_b, 0]), target_type=pl.FP32)
rope_cos_t = pl.assemble(rope_cos_t, pl.cast(cos_row, target_type=pl.BF16), [t, 0])
rope_sin_t = pl.assemble(rope_sin_t, pl.cast(sin_row, target_type=pl.BF16), [t, 0])
rope_cos_half_t = pl.assemble(
rope_cos_half_t, pl.cast(pl.slice(freqs_cos, [1, HALF_ROPE], [pos_b, 0]), target_type=pl.FP32), [t, 0])
rope_sin_half_t = pl.assemble(
rope_sin_half_t, pl.cast(pl.slice(freqs_sin, [1, HALF_ROPE], [pos_b, 0]), target_type=pl.FP32), [t, 0])
Comment on lines +207 to +210

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Instead of performing redundant pl.slice and pl.cast operations on freqs_cos and freqs_sin again, you can slice cos_row and sin_row directly, as they are already sliced and cast to pl.FP32 on lines 203-204. This avoids unnecessary overhead and is consistent with the implementation in decode_attention_swa.py.

Suggested change
rope_cos_half_t = pl.assemble(
rope_cos_half_t, pl.cast(pl.slice(freqs_cos, [1, HALF_ROPE], [pos_b, 0]), target_type=pl.FP32), [t, 0])
rope_sin_half_t = pl.assemble(
rope_sin_half_t, pl.cast(pl.slice(freqs_sin, [1, HALF_ROPE], [pos_b, 0]), target_type=pl.FP32), [t, 0])
rope_cos_half_t = pl.assemble(
rope_cos_half_t, cos_row[0 : 1, 0 : HALF_ROPE], [t, 0])
rope_sin_half_t = pl.assemble(
rope_sin_half_t, sin_row[0 : 1, 0 : HALF_ROPE], [t, 0])

step_cos = pl.assemble(step_cos, pl.cast(pl.slice(freqs_cos, [1, HALF_ROPE], [step_pos_b, 0]), target_type=pl.FP32), [b, 0])
step_sin = pl.assemble(step_sin, pl.cast(pl.slice(freqs_sin, [1, HALF_ROPE], [step_pos_b, 0]), target_type=pl.FP32), [b, 0])

Expand Down Expand Up @@ -357,8 +366,8 @@ def attention_csa(
cmp_block_table,
cmp_sparse_indices,
attn_sink,
rope_cos_t,
rope_sin_t,
rope_cos_half_t,
rope_sin_half_t,
wo_a,
wo_b,
wo_b_scale,
Expand Down Expand Up @@ -523,6 +532,9 @@ def golden_attention_csa(tensors):
freqs_sin = tensors["freqs_sin"]
rope_cos_t = freqs_cos[position_ids].contiguous()
rope_sin_t = freqs_sin[position_ids].contiguous()
# Half-width unsigned inverse-RoPE tables (first HALF columns, FP32) for the split-half sparse_attn golden.
rope_cos_half_t = freqs_cos[position_ids, :HALF_ROPE].float().contiguous()
rope_sin_half_t = freqs_sin[position_ids, :HALF_ROPE].float().contiguous()
first_pos = position_ids.reshape(B, S)[:, 0]
step_cos = freqs_cos[first_pos, :HALF_ROPE].float().contiguous()
step_sin = freqs_sin[first_pos, :HALF_ROPE].float().contiguous()
Expand Down Expand Up @@ -644,8 +656,8 @@ def golden_attention_csa(tensors):
"cmp_block_table": cmp_block_table,
"cmp_sparse_indices": sparse_topk,
"attn_sink": tensors["attn_sink"],
"freqs_cos": rope_cos_t,
"freqs_sin": rope_sin_t,
"rope_cos_half": rope_cos_half_t,
"rope_sin_half": rope_sin_half_t,
"wo_a": tensors["wo_a"],
"wo_b": tensors["wo_b"],
"wo_b_scale": tensors["wo_b_scale"],
Expand Down
23 changes: 19 additions & 4 deletions models/deepseek/v4/decode_attention_hca.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,12 @@ def attention_hca(

rope_cos_t = pl.create_tensor([T, ROPE_HEAD_DIM], dtype=pl.BF16)
rope_sin_t = pl.create_tensor([T, ROPE_HEAD_DIM], dtype=pl.BF16)
# Half-width unsigned inverse-RoPE snapshots for the split-half sparse_attn_hca: the first
# HALF columns of the per-token cos/sin (one value per frequency), cast BF16->FP32 once per
# token so the per-head rope loop rotates with no in-loop cast and no gather. Same first-HALF
# columns qkv_proj_rope's forward rope consumes -> a single rope profile, no separate il table.
rope_cos_half_t = pl.create_tensor([T, ROPE_HEAD_DIM // 2], dtype=pl.FP32)
rope_sin_half_t = pl.create_tensor([T, ROPE_HEAD_DIM // 2], dtype=pl.FP32)
cmp_cos = pl.create_tensor([B, ROPE_HEAD_DIM // 2], dtype=pl.FP32)
cmp_sin = pl.create_tensor([B, ROPE_HEAD_DIM // 2], dtype=pl.FP32)
with pl.at(level=pl.Level.CORE_GROUP, name_hint="hca_rope"):
Expand All @@ -160,6 +166,10 @@ def attention_hca(
step_sin_row = pl.cast(freqs_sin[pos_b : pos_b + 1, 0 : ROPE_HEAD_DIM], target_type=pl.FP32)
rope_cos_t[t : t + 1, 0 : ROPE_HEAD_DIM] = pl.cast(step_cos_row, target_type=pl.BF16, mode="rint")
rope_sin_t[t : t + 1, 0 : ROPE_HEAD_DIM] = pl.cast(step_sin_row, target_type=pl.BF16, mode="rint")
rope_cos_half_t[t : t + 1, 0 : ROPE_HEAD_DIM // 2] = pl.cast(
freqs_cos[pos_b : pos_b + 1, 0 : ROPE_HEAD_DIM // 2], target_type=pl.FP32)
rope_sin_half_t[t : t + 1, 0 : ROPE_HEAD_DIM // 2] = pl.cast(
freqs_sin[pos_b : pos_b + 1, 0 : ROPE_HEAD_DIM // 2], target_type=pl.FP32)
Comment on lines +169 to +172

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Instead of performing redundant slicing and casting on freqs_cos and freqs_sin again, you can slice step_cos_row and step_sin_row directly, as they are already sliced and cast to pl.FP32 on lines 165-166. This avoids unnecessary overhead.

Suggested change
rope_cos_half_t[t : t + 1, 0 : ROPE_HEAD_DIM // 2] = pl.cast(
freqs_cos[pos_b : pos_b + 1, 0 : ROPE_HEAD_DIM // 2], target_type=pl.FP32)
rope_sin_half_t[t : t + 1, 0 : ROPE_HEAD_DIM // 2] = pl.cast(
freqs_sin[pos_b : pos_b + 1, 0 : ROPE_HEAD_DIM // 2], target_type=pl.FP32)
rope_cos_half_t[t : t + 1, 0 : ROPE_HEAD_DIM // 2] = step_cos_row[0 : 1, 0 : ROPE_HEAD_DIM // 2]
rope_sin_half_t[t : t + 1, 0 : ROPE_HEAD_DIM // 2] = step_sin_row[0 : 1, 0 : ROPE_HEAD_DIM // 2]


q = pl.create_tensor([T, H, HEAD_DIM], dtype=pl.BF16)
kv = pl.create_tensor([T, HEAD_DIM], dtype=pl.BF16)
Expand Down Expand Up @@ -262,8 +272,8 @@ def attention_hca(
cmp_block_table,
topk_all,
attn_sink,
rope_cos_t,
rope_sin_t,
rope_cos_half_t,
rope_sin_half_t,
wo_a,
wo_b,
wo_b_scale,
Expand Down Expand Up @@ -384,10 +394,15 @@ def golden_attention_hca(tensors):
freqs_sin = tensors["freqs_sin"]
rope_cos_T = torch.empty(T, rd, dtype=freqs_cos.dtype)
rope_sin_T = torch.empty(T, rd, dtype=freqs_sin.dtype)
# Half-width unsigned inverse-RoPE tables for the split-half sparse_attn_hca golden.
rope_cos_half_T = torch.empty(T, rd // 2, dtype=torch.float32)
rope_sin_half_T = torch.empty(T, rd // 2, dtype=torch.float32)
for t in range(T):
pos = int(position_ids[t].item())
rope_cos_T[t] = freqs_cos[pos]
rope_sin_T[t] = freqs_sin[pos]
rope_cos_half_T[t] = freqs_cos[pos, : rd // 2].float()
rope_sin_half_T[t] = freqs_sin[pos, : rd // 2].float()

# q + win kv (W8A8 q_proj)
q = torch.zeros(T, H, HEAD_DIM, dtype=torch.bfloat16)
Expand Down Expand Up @@ -480,8 +495,8 @@ def golden_attention_hca(tensors):
"cmp_block_table": cmp_block_table,
"cmp_sparse_indices": topk_all,
"attn_sink": tensors["attn_sink"],
"freqs_cos": rope_cos_T,
"freqs_sin": rope_sin_T,
"rope_cos_half": rope_cos_half_T,
"rope_sin_half": rope_sin_half_T,
"wo_a": tensors["wo_a"],
"wo_b": tensors["wo_b"],
"wo_b_scale": tensors["wo_b_scale"],
Expand Down
23 changes: 19 additions & 4 deletions models/deepseek/v4/decode_attention_swa.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,12 @@ def attention_swa(

rope_cos_t = pl.create_tensor([T, ROPE_HEAD_DIM], dtype=pl.BF16)
rope_sin_t = pl.create_tensor([T, ROPE_HEAD_DIM], dtype=pl.BF16)
# Half-width unsigned inverse-RoPE snapshots for the split-half sparse_attn_swa: the first
# HALF columns of the per-token cos/sin (one value per frequency), cast BF16->FP32 once per
# token so the per-head rope loop rotates with no in-loop cast and no gather. Same first-HALF
# columns qkv_proj_rope's forward rope consumes -> a single rope profile, no separate il table.
rope_cos_half_t = pl.create_tensor([T, ROPE_HEAD_DIM // 2], dtype=pl.FP32)
rope_sin_half_t = pl.create_tensor([T, ROPE_HEAD_DIM // 2], dtype=pl.FP32)
with pl.at(level=pl.Level.CORE_GROUP, name_hint="swa_rope_step"):
for b in pl.parallel(B):
for s_idx in pl.range(S):
Expand All @@ -120,6 +126,10 @@ def attention_swa(
sin_row = pl.cast(pl.slice(freqs_sin, [1, ROPE_HEAD_DIM], [pos_b, 0]), target_type=pl.FP32)
rope_cos_t = pl.assemble(rope_cos_t, pl.cast(cos_row, target_type=pl.BF16, mode="rint"), [t, 0])
rope_sin_t = pl.assemble(rope_sin_t, pl.cast(sin_row, target_type=pl.BF16, mode="rint"), [t, 0])
rope_cos_half_t = pl.assemble(
rope_cos_half_t, cos_row[0 : 1, 0 : ROPE_HEAD_DIM // 2], [t, 0])
rope_sin_half_t = pl.assemble(
rope_sin_half_t, sin_row[0 : 1, 0 : ROPE_HEAD_DIM // 2], [t, 0])

q = pl.create_tensor([T, H, HEAD_DIM], dtype=pl.BF16)
kv = pl.create_tensor([T, HEAD_DIM], dtype=pl.BF16)
Expand Down Expand Up @@ -185,8 +195,8 @@ def attention_swa(
cmp_block_table,
sparse_topk,
attn_sink,
rope_cos_t,
rope_sin_t,
rope_cos_half_t,
rope_sin_half_t,
wo_a,
wo_b,
wo_b_scale,
Expand Down Expand Up @@ -300,10 +310,15 @@ def golden_attention_swa(tensors):
freqs_sin = tensors["freqs_sin"]
rope_cos_T = torch.empty(T, rd, dtype=freqs_cos.dtype)
rope_sin_T = torch.empty(T, rd, dtype=freqs_sin.dtype)
# Half-width unsigned inverse-RoPE tables for the split-half sparse_attn_swa golden.
rope_cos_half_T = torch.empty(T, rd // 2, dtype=torch.float32)
rope_sin_half_T = torch.empty(T, rd // 2, dtype=torch.float32)
for t in range(T):
pos = int(position_ids[t].item())
rope_cos_T[t] = freqs_cos[pos]
rope_sin_T[t] = freqs_sin[pos]
rope_cos_half_T[t] = freqs_cos[pos, : rd // 2].float()
rope_sin_half_T[t] = freqs_sin[pos, : rd // 2].float()

# q + win kv (model.py:495-504)
q = torch.zeros(T, H, HEAD_DIM, dtype=torch.bfloat16)
Expand Down Expand Up @@ -363,8 +378,8 @@ def golden_attention_swa(tensors):
"cmp_block_table": cmp_block_table_dummy,
"cmp_sparse_indices": sparse_topk_all,
"attn_sink": tensors["attn_sink"],
"freqs_cos": rope_cos_T,
"freqs_sin": rope_sin_T,
"rope_cos_half": rope_cos_half_T,
"rope_sin_half": rope_sin_half_T,
"wo_a": tensors["wo_a"],
"wo_b": tensors["wo_b"],
"wo_b_scale": tensors["wo_b_scale"],
Expand Down
48 changes: 23 additions & 25 deletions models/deepseek/v4/decode_compressor_ratio128.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,26 +194,22 @@ def compressor_ratio128(

kv_rope_norm = pooled_kv[batch_base : batch_base + RMS_TILE, NOPE_HEAD_DIM : HEAD_DIM]
gamma_rope = norm_w_2d[:, NOPE_HEAD_DIM : HEAD_DIM]
# A3 interleaved swap-gather (same form as kv_rope_fused in qkv_proj_rope),
# replacing the de-interleave gather + rotate + re-interleave scatter. gamma+inv_rms
# are folded into rope_normed BEFORE the swap, so the swapped lane n[j^1] correctly
# carries gamma[j^1]; inv_rms is per-row so it commutes. swap_idx (j^1), sign
# ([-1,+1,...]) and dup_idx (j>>1) are built IN-KERNEL from pl.arange; cos_il/sin_il
# are dup-gathered from the per-batch cos/sin rows. normed_kv is FP32 -> write directly.
# out[j] = n[j]*cos_il[j] + n[j^1]*sign[j]*sin_il[j]
rope_normed = pl.col_expand_mul(pl.row_expand_mul(kv_rope_norm, inv_rms), gamma_rope)
rope_ones = pl.full([RMS_TILE, ROPE_HEAD_DIM], dtype=pl.FP32, value=1.0)
rope_col = pl.col_expand_mul(rope_ones, pl.cast(pl.arange(0, [1, ROPE_HEAD_DIM], dtype=pl.INT32), target_type=pl.FP32))
rope_dup_f = pl.cast(pl.cast(pl.mul(rope_col, 0.5), target_type=pl.INT32, mode="trunc"), target_type=pl.FP32)
rope_dup_idx = pl.cast(rope_dup_f, target_type=pl.INT32) # j>>1
rope_lane = pl.sub(rope_col, pl.mul(rope_dup_f, 2.0)) # j%2
rope_swap_idx = pl.cast(pl.sub(pl.add(rope_col, 1.0), pl.mul(rope_lane, 2.0)), target_type=pl.INT32) # j^1
rope_sign = pl.sub(pl.mul(rope_lane, 2.0), 1.0) # [-1,+1,...]
cos_il = pl.gather(cos_b, dim=-1, index=rope_dup_idx)
sin_il = pl.gather(sin_b, dim=-1, index=rope_dup_idx)
swapped = pl.gather(rope_normed, dim=-1, index=rope_swap_idx)
rope_rot = pl.add(pl.mul(rope_normed, cos_il), pl.mul(pl.mul(swapped, rope_sign), sin_il))
normed_kv[batch_base : batch_base + RMS_TILE, NOPE_HEAD_DIM : HEAD_DIM] = rope_rot
# Split-half (NeoX) forward RoPE, gather-free. rope segment = [x_lo | x_hi] = dims
# [0:rh | rh:ROPE_HEAD_DIM]; partner of lane k is k+rh (contiguous slice, no j^1 gather
# and no j>>1 dup-gather -- cos_b/sin_b are already half-width). gamma is per-column so
# it does NOT commute with the rotation: fold gamma_lo onto x_lo, gamma_hi onto x_hi
# BEFORE rotating; inv_rms is per-row and commutes.
# out_lo = x_lo*cos - x_hi*sin ; out_hi = x_lo*sin + x_hi*cos
gamma_lo = gamma_rope[0:1, 0 : ROPE_HEAD_DIM // 2]
gamma_hi = gamma_rope[0:1, ROPE_HEAD_DIM // 2 : ROPE_HEAD_DIM]
kv_lo = kv_rope_norm[0 : RMS_TILE, 0 : ROPE_HEAD_DIM // 2]
kv_hi = kv_rope_norm[0 : RMS_TILE, ROPE_HEAD_DIM // 2 : ROPE_HEAD_DIM]
lo_n = pl.col_expand_mul(pl.row_expand_mul(kv_lo, inv_rms), gamma_lo)
hi_n = pl.col_expand_mul(pl.row_expand_mul(kv_hi, inv_rms), gamma_hi)
out_lo = pl.sub(pl.mul(lo_n, cos_b), pl.mul(hi_n, sin_b)) # x_lo*c - x_hi*s
out_hi = pl.add(pl.mul(lo_n, sin_b), pl.mul(hi_n, cos_b)) # x_lo*s + x_hi*c
normed_kv[batch_base : batch_base + RMS_TILE, NOPE_HEAD_DIM : NOPE_HEAD_DIM + ROPE_HEAD_DIM // 2] = out_lo
normed_kv[batch_base : batch_base + RMS_TILE, NOPE_HEAD_DIM + ROPE_HEAD_DIM // 2 : HEAD_DIM] = out_hi

kv_flat = pl.reshape(kv, [bs, HEAD_DIM])
cmp_flat_rows = cmp_block_num * BLOCK_SIZE
Expand Down Expand Up @@ -367,13 +363,15 @@ def rmsnorm(x, w):
continue
kv_b = rmsnorm(pooled[b : b + 1], norm_w)

x_pair = kv_b[..., -rd:].unflatten(-1, (-1, 2))
x0, x1 = x_pair[..., 0], x_pair[..., 1]
# Split-half (NeoX) forward RoPE: lo = first rd/2 rope dims, hi = last rd/2.
rope = kv_b[..., -rd:]
x_lo = rope[..., : rd // 2]
x_hi = rope[..., rd // 2 :]
cos_v, sin_v = cos[b].view(-1), sin[b].view(-1)
y0 = x0 * cos_v - x1 * sin_v
y1 = x0 * sin_v + x1 * cos_v
y_lo = x_lo * cos_v - x_hi * sin_v
y_hi = x_lo * sin_v + x_hi * cos_v

kv_b = torch.cat([kv_b[..., :-rd], torch.stack([y0, y1], dim=-1).flatten(-2)], dim=-1)
kv_b = torch.cat([kv_b[..., :-rd], y_lo, y_hi], dim=-1)

# Kernel writes pooled result only to kv[:, 0, :]; leave kv[:, 1:, :] = 0.
tensors["kv"][b : b + 1, 0:1, :] = kv_b
Expand Down
48 changes: 23 additions & 25 deletions models/deepseek/v4/decode_compressor_ratio4.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,26 +195,22 @@ def compressor_ratio4(

kv_rope_norm = pooled_kv[batch_base : batch_base + RMS_TILE, NOPE_HEAD_DIM : HEAD_DIM]
gamma_rope = norm_w_2d[:, NOPE_HEAD_DIM : HEAD_DIM]
# A3 interleaved swap-gather (same form as kv_rope_fused in qkv_proj_rope),
# replacing the de-interleave gather + rotate + re-interleave scatter. gamma+inv_rms
# are folded into rope_normed BEFORE the swap, so the swapped lane n[j^1] correctly
# carries gamma[j^1]; inv_rms is per-row so it commutes. swap_idx (j^1), sign
# ([-1,+1,...]) and dup_idx (j>>1) are built IN-KERNEL from pl.arange; cos_il/sin_il
# are dup-gathered from the per-batch cos/sin rows. normed_kv is FP32 -> write directly.
# out[j] = n[j]*cos_il[j] + n[j^1]*sign[j]*sin_il[j]
rope_normed = pl.col_expand_mul(pl.row_expand_mul(kv_rope_norm, inv_rms), gamma_rope)
rope_ones = pl.full([RMS_TILE, ROPE_HEAD_DIM], dtype=pl.FP32, value=1.0)
rope_col = pl.col_expand_mul(rope_ones, pl.cast(pl.arange(0, [1, ROPE_HEAD_DIM], dtype=pl.INT32), target_type=pl.FP32))
rope_dup_f = pl.cast(pl.cast(pl.mul(rope_col, 0.5), target_type=pl.INT32, mode="trunc"), target_type=pl.FP32)
rope_dup_idx = pl.cast(rope_dup_f, target_type=pl.INT32) # j>>1
rope_lane = pl.sub(rope_col, pl.mul(rope_dup_f, 2.0)) # j%2
rope_swap_idx = pl.cast(pl.sub(pl.add(rope_col, 1.0), pl.mul(rope_lane, 2.0)), target_type=pl.INT32) # j^1
rope_sign = pl.sub(pl.mul(rope_lane, 2.0), 1.0) # [-1,+1,...]
cos_il = pl.gather(cos_b, dim=-1, index=rope_dup_idx)
sin_il = pl.gather(sin_b, dim=-1, index=rope_dup_idx)
swapped = pl.gather(rope_normed, dim=-1, index=rope_swap_idx)
rope_rot = pl.add(pl.mul(rope_normed, cos_il), pl.mul(pl.mul(swapped, rope_sign), sin_il))
normed_kv[batch_base : batch_base + RMS_TILE, NOPE_HEAD_DIM : HEAD_DIM] = rope_rot
# Split-half (NeoX) forward RoPE, gather-free. rope segment = [x_lo | x_hi] = dims
# [0:rh | rh:ROPE_HEAD_DIM]; partner of lane k is k+rh (contiguous slice, no j^1 gather
# and no j>>1 dup-gather -- cos_b/sin_b are already half-width). gamma is per-column so
# it does NOT commute with the rotation: fold gamma_lo onto x_lo, gamma_hi onto x_hi
# BEFORE rotating; inv_rms is per-row and commutes.
# out_lo = x_lo*cos - x_hi*sin ; out_hi = x_lo*sin + x_hi*cos
gamma_lo = gamma_rope[0:1, 0 : ROPE_HEAD_DIM // 2]
gamma_hi = gamma_rope[0:1, ROPE_HEAD_DIM // 2 : ROPE_HEAD_DIM]
kv_lo = kv_rope_norm[0 : RMS_TILE, 0 : ROPE_HEAD_DIM // 2]
kv_hi = kv_rope_norm[0 : RMS_TILE, ROPE_HEAD_DIM // 2 : ROPE_HEAD_DIM]
lo_n = pl.col_expand_mul(pl.row_expand_mul(kv_lo, inv_rms), gamma_lo)
hi_n = pl.col_expand_mul(pl.row_expand_mul(kv_hi, inv_rms), gamma_hi)
out_lo = pl.sub(pl.mul(lo_n, cos_b), pl.mul(hi_n, sin_b)) # x_lo*c - x_hi*s
out_hi = pl.add(pl.mul(lo_n, sin_b), pl.mul(hi_n, cos_b)) # x_lo*s + x_hi*c
normed_kv[batch_base : batch_base + RMS_TILE, NOPE_HEAD_DIM : NOPE_HEAD_DIM + ROPE_HEAD_DIM // 2] = out_lo
normed_kv[batch_base : batch_base + RMS_TILE, NOPE_HEAD_DIM + ROPE_HEAD_DIM // 2 : HEAD_DIM] = out_hi

for batch_base_idx in pl.spmd(B // RMS_TILE, name_hint="kv_and_cache_write"):
batch_base = batch_base_idx * RMS_TILE
Expand Down Expand Up @@ -361,13 +357,15 @@ def rmsnorm(x, w):
boundary_s = ratio - 1 - (first_pos % ratio)
kv_b = rmsnorm(pooled[b : b + 1], norm_w)

x_pair = kv_b[..., -rd:].unflatten(-1, (-1, 2))
x0, x1 = x_pair[..., 0], x_pair[..., 1]
# Split-half (NeoX) forward RoPE: lo = first rd/2 rope dims, hi = last rd/2.
rope = kv_b[..., -rd:]
x_lo = rope[..., : rd // 2]
x_hi = rope[..., rd // 2 :]
cos_v, sin_v = cos[b].view(-1), sin[b].view(-1)
y0 = x0 * cos_v - x1 * sin_v
y1 = x0 * sin_v + x1 * cos_v
y_lo = x_lo * cos_v - x_hi * sin_v
y_hi = x_lo * sin_v + x_hi * cos_v

kv_b = torch.cat([kv_b[..., :-rd], torch.stack([y0, y1], dim=-1).flatten(-2)], dim=-1)
kv_b = torch.cat([kv_b[..., :-rd], y_lo, y_hi], dim=-1)

# Kernel writes pooled result only to kv[:, 0, :]; leave kv[:, 1:, :] = 0
# so the golden matches its [B, S, HEAD_DIM] zero-init.
Expand Down
4 changes: 4 additions & 0 deletions models/deepseek/v4/decode_layer_dp_ep.py
Original file line number Diff line number Diff line change
Expand Up @@ -652,6 +652,10 @@ def build_tensor_specs(start_pos=None, layer_id=10):
"csa": csa_specs,
}[attention_kind]

# Split-half (NeoX) RoPE: the sparse_attn kernels read the half-width unsigned cos/sin
# straight from the first HALF columns of freqs_cos/freqs_sin (the caller slices them), so
# there is no separate interleaved/sign-folded table -- one rope profile for qkv + sparse_attn.

replicated_attention = {
"hc_attn_fn",
"hc_attn_scale",
Expand Down
Loading