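In models/deepseek/v4/decode_attention_csa.py, the main attention path builds per-token RoPE rows, but the cos/sin tensors built for the CSA indexer are only batch-major. Those tensors are passed to indexer(...) as step_cos / step_sin.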
In models/deepseek/v4/decode_indexer.py, the indexer API also declares cos and sin as [B, ROPE_HEAD_DIM // 2]. During QR RoPE it derives only batch_idx = token_idx // S and reads the cos/sin row at batch_idx.
Therefore all S tokens in the same batch row use the RoPE angle for start_pos[b]. For S > 1, token s = 1 should use start_pos[b] + 1, token s = 2 should use start_pos[b] + 2, etc. The learned sparse-indexer QR path instead rotates all of them with the same position.
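The two lookups can be modelled on the host in plain Python. This is a sketch, not the pypto kernel code: current_rope_positions mirrors the batch_idx = token_idx // S lookup described above, intended_rope_positions is a hypothetical helper for start_pos[b] + s, and the flattened layout token_idx = b * S + s is an assumption.

```python
from typing import List


def current_rope_positions(start_pos: List[int], S: int) -> List[int]:
    """Position each token gets today: the indexer reads row batch_idx = token_idx // S."""
    T = len(start_pos) * S
    return [start_pos[token_idx // S] for token_idx in range(T)]


def intended_rope_positions(start_pos: List[int], S: int) -> List[int]:
    """Hypothetical per-token positions: token s of batch row b sits at start_pos[b] + s."""
    T = len(start_pos) * S
    return [start_pos[token_idx // S] + token_idx % S for token_idx in range(T)]


start_pos = [127, 40]                            # B = 2
print(current_rope_positions(start_pos, S=2))    # [127, 127, 40, 40]
print(intended_rope_positions(start_pos, S=2))   # [127, 128, 40, 41]
```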
Concrete example:
B = 1
S = 2
start_pos[0] = 127
Expected indexer RoPE positions:
token 0 -> 127
token 1 -> 128
Current indexer RoPE positions:
token 0 -> 127
token 1 -> 127
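To see why a single position matters, the sketch below uses textbook RoPE (base 10000, no YaRN-style scaling, split-halves pairing, hypothetical ROPE_HEAD_DIM = 64; none of these are taken from the DeepSeek V4 config) and a plain q·k dot product as a stand-in for an indexer score. Because RoPE scores depend only on the query/key distance, rotating token 1 at 127 instead of 128 shifts its distance to every key by one, which can change which entries the sparse index selection keeps.

```python
import math
import random
from typing import List, Tuple

ROPE_HEAD_DIM = 64              # hypothetical; the real value comes from the model config
HALF_ROPE = ROPE_HEAD_DIM // 2
BASE = 10000.0                  # assumed plain RoPE base, no YaRN-style scaling


def rope_cos_sin(pos: int) -> Tuple[List[float], List[float]]:
    """cos/sin row of length HALF_ROPE for one position (textbook RoPE frequencies)."""
    angles = [pos * BASE ** (-i / HALF_ROPE) for i in range(HALF_ROPE)]
    return [math.cos(a) for a in angles], [math.sin(a) for a in angles]


def apply_rope(x: List[float], pos: int) -> List[float]:
    """Rotate each pair (x[i], x[i + HALF_ROPE]) by its angle; the pairing is an assumption."""
    cos, sin = rope_cos_sin(pos)
    x1, x2 = x[:HALF_ROPE], x[HALF_ROPE:]
    return ([a * c - b * s for a, b, c, s in zip(x1, x2, cos, sin)]
            + [a * s + b * c for a, b, c, s in zip(x1, x2, cos, sin)])


def score(q: List[float], q_pos: int, k: List[float], k_pos: int) -> float:
    """Dot product of rotated q and k, standing in for one indexer q.k score."""
    return sum(a * b for a, b in zip(apply_rope(q, q_pos), apply_rope(k, k_pos)))


random.seed(0)
q = [random.uniform(-1.0, 1.0) for _ in range(ROPE_HEAD_DIM)]
k = [random.uniform(-1.0, 1.0) for _ in range(ROPE_HEAD_DIM)]

# Token 1 of the example should be rotated at 128; the current path uses 127.
print(score(q, 128, k, 100), score(q, 127, k, 100))
# Scores depend only on k_pos - q_pos, so 127 vs key 100 equals 128 vs key 101.
print(math.isclose(score(q, 127, k, 100), score(q, 128, k, 101), abs_tol=1e-9))  # True
```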
This can make CSA sparse index selection wrong for multi-token decode/MTP/chunk decode. It may be hidden by tests that only cover S = 1 or by goldens that share the same shortcut.
#351 covered moving from scalar start_pos to per-batch decode positions. This issue is narrower: even with start_pos: [B], the CSA indexer still needs per-token RoPE positions within each batch row when S > 1.
#383 defines the decode-side metadata contract with position_ids[MAX_TOKENS]. That is the cleaner long-term fix: indexer RoPE should use token-major position_ids[t], not only batch-major start_pos[b].
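One way to express a token-major contract is a plain gather: each token t reads row position_ids[t] of a global cos/sin table, producing cos/sin of shape [T, HALF_ROPE]. The helper names (build_freqs, gather_rope_rows), the table size and the textbook RoPE frequency formula are assumptions for illustration, not the pypto-lib API.

```python
import math
from typing import List, Tuple

Table = List[List[float]]


def build_freqs(max_pos: int, half_rope: int, base: float = 10000.0) -> Tuple[Table, Table]:
    """Hypothetical global freqs_cos/freqs_sin tables, shape [max_pos, half_rope]."""
    inv_freq = [base ** (-i / half_rope) for i in range(half_rope)]
    freqs_cos = [[math.cos(p * f) for f in inv_freq] for p in range(max_pos)]
    freqs_sin = [[math.sin(p * f) for f in inv_freq] for p in range(max_pos)]
    return freqs_cos, freqs_sin


def gather_rope_rows(position_ids: List[int], freqs_cos: Table, freqs_sin: Table) -> Tuple[Table, Table]:
    """Token-major cos/sin: row t is indexed by position_ids[t], not by start_pos[t // S]."""
    cos = [freqs_cos[p] for p in position_ids]
    sin = [freqs_sin[p] for p in position_ids]
    return cos, sin  # each is [T, HALF_ROPE]


freqs_cos, freqs_sin = build_freqs(max_pos=256, half_rope=32)
position_ids = [127, 128, 40, 41]          # B = 2, S = 2, flattened token-major
cos, sin = gather_rope_rows(position_ids, freqs_cos, freqs_sin)
print(len(cos), len(cos[0]))               # 4 32
print(cos[0] == cos[1])                    # False: tokens 0 and 1 get different rows
```

Because the lookup is keyed only by position_ids[t], rows of a batch are free to sit at unrelated positions.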
Proposed fix
Change the indexer RoPE input contract from batch-major cos/sin: [B, HALF_ROPE] to token-major RoPE metadata, for example:
position_ids: [T] plus global freqs_cos/freqs_sin, or
precomputed cos/sin: [T, HALF_ROPE].
In decode_attention_csa.py, pass per-token RoPE rows to indexer, consistent with the main attention RoPE path.
In decode_indexer.py, use token_idx / token-major position metadata instead of batch_idx = token_idx // S for QR RoPE.
Add a boundary test with S = 2 and start_pos near a visible RoPE boundary, e.g. start_pos = 127, and assert token 0 and token 1 consume different RoPE rows.
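The boundary test could look roughly like this pytest-style sketch. It runs against a hypothetical host reference (indexer_qr_rope), not the real NPU kernel; token_major=False reproduces the current batch_idx = token_idx // S behavior, under which the assertion fails. Feeding identical q rows to tokens 0 and 1 means any difference in the output can only come from the RoPE row. The head size, RoPE base and pairing convention are assumptions.

```python
import math
from typing import List, Tuple

ROPE_HEAD_DIM = 64              # hypothetical head dim for the sketch
HALF_ROPE = ROPE_HEAD_DIM // 2


def rope_row(pos: int) -> Tuple[List[float], List[float]]:
    """cos/sin row for one position, textbook RoPE frequencies (assumed base 10000)."""
    angles = [pos * 10000.0 ** (-i / HALF_ROPE) for i in range(HALF_ROPE)]
    return [math.cos(a) for a in angles], [math.sin(a) for a in angles]


def rotate(x: List[float], cos: List[float], sin: List[float]) -> List[float]:
    """Split-halves RoPE rotation of one q row (pairing convention assumed)."""
    x1, x2 = x[:HALF_ROPE], x[HALF_ROPE:]
    return ([a * c - b * s for a, b, c, s in zip(x1, x2, cos, sin)]
            + [a * s + b * c for a, b, c, s in zip(x1, x2, cos, sin)])


def indexer_qr_rope(q: List[List[float]], start_pos: List[int], S: int,
                    token_major: bool = True) -> List[List[float]]:
    """Hypothetical host reference for the indexer QR RoPE over T = B * S tokens."""
    out = []
    for token_idx, row in enumerate(q):
        batch_idx = token_idx // S
        # token_major=False mimics the current lookup: every token in row b uses start_pos[b].
        offset = token_idx % S if token_major else 0
        out.append(rotate(row, *rope_row(start_pos[batch_idx] + offset)))
    return out


def test_indexer_rope_uses_per_token_rows():
    S, start_pos = 2, [127]                         # B = 1, tokens should sit at 127 and 128
    q = [[1.0] * ROPE_HEAD_DIM for _ in range(S)]   # identical input for both tokens
    out = indexer_qr_rope(q, start_pos, S)
    assert out[0] != out[1], "token 0 and token 1 must consume different RoPE rows"


if __name__ == "__main__":
    test_indexer_rope_uses_per_token_rows()
    print("ok")
```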
Diagnosis
pypto-lib DeepSeek V4 CSA decode/indexer path: this is a static correctness issue in the CSA learned-indexer RoPE inputs when S > 1.
Relation to existing issues
Related: #351, #383.
Following the #383 metadata contract, the indexer fix should ideally read position_ids[t], so it also works for heterogeneous/continuous batching.
Environment
2c2cc08 (branch: temp)
468a51eb (branch: main)
Host Platform
Linux (aarch64)
Additional Context
This issue is based on static source inspection. No NPU/device repro was run. The affected path is the CSA learned indexer QR RoPE path; the main attention RoPE setup in decode_attention_csa.py already has per-token rows, so the fix should make the indexer path match that behavior.