
[Bug] DeepSeek V4 CSA decode indexer reuses start_pos RoPE for all S tokens #462

@high-cloud

Diagnosis

pypto-lib DeepSeek V4 CSA decode/indexer path. This is a static correctness issue in the RoPE inputs of the CSA learned indexer when S > 1.

Description

models/deepseek/v4/decode_attention_csa.py correctly builds per-token RoPE rows for the main attention path:

for s in pl.range(S):
    t = b * S + s
    pos_b = pl.cast(start_pos_b + s, pl.INDEX)
    cos_row = pl.cast(pl.slice(freqs_cos, [1, ROPE_HEAD_DIM], [pos_b, 0]), target_type=pl.FP32)
    sin_row = pl.cast(pl.slice(freqs_sin, [1, ROPE_HEAD_DIM], [pos_b, 0]), target_type=pl.FP32)
    rope_cos_t = pl.assemble(rope_cos_t, pl.cast(cos_row, target_type=pl.BF16), [t, 0])
    rope_sin_t = pl.assemble(rope_sin_t, pl.cast(sin_row, target_type=pl.BF16), [t, 0])
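
For reference, the index arithmetic of this loop can be mirrored in plain Python. The sketch below is a hypothetical host-side model, not pypto code: `build_token_rope_rows` is an illustrative name, tables are lists of rows indexed by absolute position (as in the slices above), and the FP32/BF16 casts are omitted.

```python
from typing import List, Sequence, Tuple

Row = List[float]


def build_token_rope_rows(
    freqs_cos: Sequence[Row],
    freqs_sin: Sequence[Row],
    start_pos: Sequence[int],
    S: int,
) -> Tuple[List[Row], List[Row]]:
    """Return token-major RoPE rows: row t = b * S + s holds position start_pos[b] + s.

    Hypothetical host-side model of the main-attention loop (dtype casts omitted).
    """
    B = len(start_pos)
    cos_rows: List[Row] = [[] for _ in range(B * S)]
    sin_rows: List[Row] = [[] for _ in range(B * S)]
    for b in range(B):
        for s in range(S):
            t = b * S + s            # flattened token index, as in the loop above
            pos = start_pos[b] + s   # absolute position of token s in batch row b
            cos_rows[t] = list(freqs_cos[pos])
            sin_rows[t] = list(freqs_sin[pos])
    return cos_rows, sin_rows


# Toy table whose row p is [p], so each output row reveals the position it came from.
table = [[float(p)] for p in range(256)]
cos_rows, _ = build_token_rope_rows(table, table, start_pos=[127, 5], S=2)
print([row[0] for row in cos_rows])  # [127.0, 128.0, 5.0, 6.0]
```

The output has T = B * S rows, one per token, which is the shape the indexer path would also need.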

The RoPE rows built for the CSA indexer, however, are only batch-major (one row per batch element b):

step_cos = pl.assemble(step_cos, pl.cast(pl.slice(freqs_cos, [1, HALF_ROPE], [step_pos_b, 0]), target_type=pl.FP32), [b, 0])
step_sin = pl.assemble(step_sin, pl.cast(pl.slice(freqs_sin, [1, HALF_ROPE], [step_pos_b, 0]), target_type=pl.FP32), [b, 0])

Those tensors are passed to indexer(...) as step_cos / step_sin.

In models/deepseek/v4/decode_indexer.py, the indexer API also declares cos and sin as [B, ROPE_HEAD_DIM // 2]. During QR RoPE it derives only batch_idx = token_idx // S and reads:

cos_b = cos[batch_idx : batch_idx + 1, 0 : ROPE_HEAD_DIM // 2]
sin_b = sin[batch_idx : batch_idx + 1, 0 : ROPE_HEAD_DIM // 2]

Therefore all S tokens in the same batch row use the RoPE angle for start_pos[b]. For S > 1, token s = 1 should use start_pos[b] + 1, token s = 2 should use start_pos[b] + 2, etc. The learned sparse-indexer QR path instead rotates all of them with the same position.
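
The mismatch can be written as two index maps over the flattened token index token_idx = b * S + s. The sketch below models which absolute position each token's QR RoPE ends up using. The function names are hypothetical, and only the index arithmetic (token_idx // S versus start_pos[b] + s) is taken from the snippets above; it also assumes, as this issue states, that the batch-major row b holds the angle for start_pos[b].

```python
from typing import List, Sequence


def current_indexer_positions(start_pos: Sequence[int], S: int) -> List[int]:
    """Position used today: row batch_idx = token_idx // S of a [B, HALF_ROPE] table."""
    return [start_pos[token_idx // S] for token_idx in range(len(start_pos) * S)]


def expected_positions(start_pos: Sequence[int], S: int) -> List[int]:
    """Position each token should use: start_pos[b] + s, where (b, s) = divmod(token_idx, S)."""
    positions = []
    for token_idx in range(len(start_pos) * S):
        b, s = divmod(token_idx, S)
        positions.append(start_pos[b] + s)
    return positions


print(current_indexer_positions([10, 40], S=3))  # [10, 10, 10, 40, 40, 40]
print(expected_positions([10, 40], S=3))         # [10, 11, 12, 40, 41, 42]
```

For S = 1, divmod(token_idx, 1) always gives s = 0, so both maps return start_pos unchanged; a test suite that only exercises S = 1 cannot distinguish them.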

Concrete example:

B = 1
S = 2
start_pos[0] = 127

Expected indexer RoPE positions:

token 0 -> 127
token 1 -> 128

Current indexer RoPE positions:

token 0 -> 127
token 1 -> 127

This can make CSA sparse index selection wrong for multi-token decode (MTP, chunked decode). It may go unnoticed in tests that only cover S = 1, or when golden references take the same batch-major shortcut.

Relation to existing issues

Related: #351, #383.

#351 covered moving from scalar start_pos to per-batch decode positions. This issue is narrower: even with start_pos: [B], the CSA indexer still needs per-token RoPE positions within each batch row when S > 1.

#383 defines the decode-side metadata contract with position_ids[MAX_TOKENS]. That is the cleaner long-term fix: indexer RoPE should use token-major position_ids[t], not only batch-major start_pos[b].
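
To illustrate why a token-major position_ids array generalizes beyond a fixed S, the sketch below flattens per-request (start position, token count) pairs into one array padded to MAX_TOKENS. This is only an assumption about the layout: `build_position_ids`, its argument names, and the padding value are hypothetical and are not the contract defined in #383.

```python
from typing import List, Sequence


def build_position_ids(
    start_pos: Sequence[int],
    num_tokens: Sequence[int],
    max_tokens: int,
    pad_value: int = 0,
) -> List[int]:
    """Flatten per-request (start_pos, num_tokens) into token-major position_ids (hypothetical)."""
    if len(start_pos) != len(num_tokens):
        raise ValueError("start_pos and num_tokens must have the same length")
    position_ids: List[int] = []
    for start, n in zip(start_pos, num_tokens):
        position_ids.extend(range(start, start + n))  # consecutive positions per request
    if len(position_ids) > max_tokens:
        raise ValueError("more tokens than max_tokens")
    # Unused slots are padded; kernels would have to ignore them (assumption).
    position_ids.extend([pad_value] * (max_tokens - len(position_ids)))
    return position_ids


print(build_position_ids([127, 5, 300], [2, 1, 3], max_tokens=8))
# [127, 128, 5, 300, 301, 302, 0, 0]
```

With every token count equal to S this reduces to start_pos[b] + s in token order, but it also covers requests with different token counts, which a batch-major [B, HALF_ROPE] table cannot express.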

Proposed fix

  1. Change the indexer RoPE input contract from batch-major cos/sin: [B, HALF_ROPE] to token-major RoPE metadata, for example:
    • position_ids: [T] plus global freqs_cos/freqs_sin, or
    • precomputed cos/sin: [T, HALF_ROPE].
  2. In decode_attention_csa.py, pass per-token RoPE rows to indexer, consistent with the main attention RoPE path.
  3. In decode_indexer.py, use token_idx / token-major position metadata instead of batch_idx = token_idx // S for QR RoPE.
  4. Add a boundary test with S = 2 and start_pos near a visible RoPE boundary, e.g. start_pos = 127, and assert token 0 and token 1 consume different RoPE rows.
  5. Prefer aligning with #383 ([Feature] Define DeepSeek V4 decode KV cache paged metadata contract for vLLM-style serving) by using position_ids[t], so the fix also works for heterogeneous/continuous batching.
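
Steps 3 and 4 can be prototyped as a plain-Python reference model before touching the kernel. The sketch assumes the "position_ids plus global freqs_cos/freqs_sin" option: each token reads row freqs[position_ids[t]] instead of cos[token_idx // S]. It uses the rotate-half pairing (first half of the RoPE slice against the second half); that layout is an assumption, and decode_indexer.py may pair dimensions differently. All function names are hypothetical and dtypes are ignored.

```python
import math
from typing import List, Sequence, Tuple

Table = List[List[float]]


def make_freqs(max_pos: int, half_rope: int, base: float = 10000.0) -> Tuple[Table, Table]:
    """Toy [max_pos, half_rope] cos/sin tables using standard RoPE inverse frequencies."""
    rope_dim = 2 * half_rope
    inv_freq = [base ** (-2.0 * i / rope_dim) for i in range(half_rope)]
    cos = [[math.cos(p * f) for f in inv_freq] for p in range(max_pos)]
    sin = [[math.sin(p * f) for f in inv_freq] for p in range(max_pos)]
    return cos, sin


def rope_token_major(
    x: Sequence[Sequence[float]],
    position_ids: Sequence[int],
    freqs_cos: Table,
    freqs_sin: Table,
) -> Table:
    """Rotate token t's RoPE slice with row position_ids[t] (rotate-half pairing, assumed)."""
    out: Table = []
    for t, vec in enumerate(x):
        c = freqs_cos[position_ids[t]]  # token-major lookup, not row t // S
        s = freqs_sin[position_ids[t]]
        half = len(c)
        x1, x2 = vec[:half], vec[half:2 * half]
        first = [a * ci - b * si for a, b, ci, si in zip(x1, x2, c, s)]
        second = [a * si + b * ci for a, b, ci, si in zip(x1, x2, c, s)]
        out.append(first + second)
    return out


# Boundary check from step 4: B = 1, S = 2, start_pos = 127.
cos, sin = make_freqs(max_pos=256, half_rope=4)
q = [[1.0, 0.5, -0.25, 2.0, 0.0, 1.0, 3.0, -1.0]] * 2  # identical input for both tokens
fixed = rope_token_major(q, [127, 128], cos, sin)      # per-token positions
buggy = rope_token_major(q, [127, 127], cos, sin)      # what a batch-major lookup yields
print(fixed[0] != fixed[1], buggy[0] == buggy[1])       # True True
```

Feeding identical inputs to both tokens isolates the RoPE lookup: with per-token rows the outputs differ, while the batch-major lookup makes them identical, which is the property the proposed boundary test would assert.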

Environment

Component   Version
pypto-lib   2c2cc08 (branch: temp)
pypto       468a51eb (branch: main)
simpler     not detected
ptoas       ptoas 0.43
CANN        not detected

Host Platform

Linux (aarch64)

Additional Context

This issue is based on static source inspection. No NPU/device repro was run. The affected path is the CSA learned indexer QR RoPE path; the main attention RoPE setup in decode_attention_csa.py already has per-token rows, so the fix should make the indexer path match that behavior.
