Summary
We need to extend the DeepSeek V4 decode path to support the following serving target:
- Up to 8k prompt/input context.
- Then generate 512 output tokens.
- For now, treat this as 512 separate decode_fwd invocations. Although the current static decode shape has DECODE_SEQ = 2 for MTP, the upper layer cannot consume the second MTP token yet, so the first milestone should only commit one accepted token per decode call.
- The maximum absolute position that must be valid is therefore 8192 + 512 - 1 = 8703.
To avoid repeatedly resizing around this boundary, the proposed static capacity target is MAX_SEQ_LEN = 16384.
This issue is not asking for 16k full raw KV. Decode should continue to keep only sliding-window raw KV. The long-context requirement is mainly about compressed history, indexer history, RoPE capacity, and state/cache mapping correctness for high absolute positions.
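The position arithmetic behind this target can be checked with a short sketch. It assumes "8k" means exactly 8192 prompt tokens and 0-based positions; the names PROMPT_LEN, NUM_DECODE_STEPS, and next_pow2 are illustrative and not taken from the codebase. Rounding up to a power of two happens to land on 16384, but the issue does not state that this is how the target was chosen.

```python
# Values assumed from the text of this issue; names are illustrative only.
PROMPT_LEN = 8192        # "up to 8k" prompt, assumed to mean 8192 tokens
NUM_DECODE_STEPS = 512   # one committed token per decode_fwd call
MAX_SEQ_LEN = 16384      # proposed static capacity target


def next_pow2(n: int) -> int:
    """Smallest power of two that is >= n, for n >= 1."""
    return 1 << (n - 1).bit_length()


# Positions are 0-based, so the last generated token sits at this index.
max_abs_pos = PROMPT_LEN + NUM_DECODE_STEPS - 1
required_positions = max_abs_pos + 1

# The static capacity must cover every position that decode can touch.
assert required_positions <= MAX_SEQ_LEN

print(max_abs_pos, required_positions, next_pow2(required_positions))
```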
Current Code Path
The current decode full-forward path is roughly:
models/deepseek/v4/decode_fwd.py
-> layer 0/1: SWA attention + MoE
-> layer 2/4/.../42: CSA attention + MoE
-> layer 3/5/.../41: HCA attention + MoE
-> hc_head
Each attention module is roughly:
hc_pre
-> rmsnorm
-> qkv_proj_rope
-> path-specific compressor/indexer/topk/sparse_attn
-> delayed KV/cache writeback
-> hc_post
Relevant files include:
models/deepseek/v4/config.py
models/deepseek/v4/decode_fwd.py
models/deepseek/v4/decode_attention_swa.py
models/deepseek/v4/decode_attention_hca.py
models/deepseek/v4/decode_attention_csa.py
models/deepseek/v4/decode_sparse_attn.py
models/deepseek/v4/decode_sparse_attn_hca.py
models/deepseek/v4/decode_indexer.py
models/deepseek/v4/decode_indexer_compressor.py
models/deepseek/v4/decode_compressor_ratio4.py
models/deepseek/v4/decode_compressor_ratio128.py
Required Capabilities
1. Increase decode position capacity to 16384
FLASH.max_position_embeddings and all decode-side shapes derived from MAX_SEQ_LEN should support at least 16384 positions.
This needs to cover:
- freqs_cos / freqs_sin
- position_ids
- IDX_KV_LEN = MAX_SEQ_LEN // 4
- SCORE_LEN
- sparse-attention topk width and padded width
- compressor state block-table width
- golden/fixture high-position cases
2. Support one committed token per decode call
The first milestone should support 512 separate decode calls, each committing exactly one accepted token.
There are two possible approaches:
- Add/support a DECODE_SEQ = 1 decode configuration.
- Keep the static S = 2 shape, but make the second token a non-committed placeholder via metadata:
ori_slot_mapping[second_token] = -1
cmp_slot_mapping[second_token] = -1
idx_slot_mapping[second_token] = -1
state_slot_mapping[second_token] = -1
inner_state_slot_mapping[second_token] = -1
The second token must not mutate raw KV, compressed KV, indexer KV, or compressor state. Sparse-attention overlay/topk visibility must also respect the single committed token semantics.
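A minimal sketch of the S = 2 placeholder approach, assuming each slot mapping is a flat per-request list of length S. The helper name mask_uncommitted_tokens and the list layout are hypothetical; the real metadata tensors may be batched or laid out differently.

```python
NO_WRITE = -1  # sentinel this issue uses for "do not commit this token"

# The five per-token mappings named in this issue.
MAPPING_NAMES = (
    "ori_slot_mapping",
    "cmp_slot_mapping",
    "idx_slot_mapping",
    "state_slot_mapping",
    "inner_state_slot_mapping",
)


def mask_uncommitted_tokens(mappings, num_committed=1):
    """Return copies of the slot mappings with every token at index
    >= num_committed replaced by NO_WRITE.

    mappings: dict of name -> list of length S for a single request
    (hypothetical layout, chosen for illustration).
    """
    masked = {}
    for name in MAPPING_NAMES:
        slots = list(mappings[name])  # copy so the caller's list is untouched
        for t in range(num_committed, len(slots)):
            slots[t] = NO_WRITE
        masked[name] = slots
    return masked


# Usage: S = 2, only the first token is committed.
example = {name: [100, 101] for name in MAPPING_NAMES}
masked = mask_uncommitted_tokens(example)
print(masked["ori_slot_mapping"])  # [100, -1]
```

Masking the write targets alone does not cover the read side: the overlay and topk visibility rules still have to ignore the placeholder token.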
3. Increase CSA compressed KV capacity
CSA ratio-4 compressed KV needs enough slots for the full target capacity:
compressed_slots = 16384 / 4 = 4096
CMP_MAX_BLOCKS = 4096 / 128 = 32
So CSA cmp_kv / cmp_block_table should support at least CMP_MAX_BLOCKS = 32.
For the immediate 8k+512 target, about 8704 / 4 = 2176 compressed slots are needed, but using 32 blocks keeps the contract aligned with MAX_SEQ_LEN = 16384.
HCA ratio-128 needs far fewer compressed slots for this target, but its sparse-attention and block-table semantics should still work at high absolute positions.
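The slot and block counts above follow from two divisions, sketched here. The block size of 128 slots is taken from the 4096 / 128 = 32 line above; applying the same block size to the HCA ratio-128 path is an assumption, and compressed_capacity is an illustrative name.

```python
def ceil_div(a: int, b: int) -> int:
    """Integer ceiling division for non-negative a and positive b."""
    return -(-a // b)


def compressed_capacity(seq_len: int, ratio: int, block_size: int = 128):
    """Return (compressed_slots, blocks) needed to cover seq_len positions.

    ratio is the compression ratio (4 for CSA, 128 for HCA).
    block_size is the number of compressed slots per block (assumed 128).
    """
    slots = seq_len // ratio
    blocks = ceil_div(slots, block_size)
    return slots, blocks


print(compressed_capacity(16384, 4))    # CSA at full static capacity
print(compressed_capacity(8704, 4))     # CSA for the immediate 8k+512 target
print(compressed_capacity(16384, 128))  # HCA at full static capacity
```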
4. Synchronize indexer cache and score capacity with MAX_SEQ_LEN
CSA indexer capacity at 16k should be:
IDX_KV_LEN = 16384 / 4 = 4096
The existing IDX_CACHE_MAX_BLOCKS = 64 provides 8192 rows and should be physically large enough, but the following must be audited and updated:
- SCORE_LEN
- score [B, S, SCORE_LEN]
- topk_idxs [B, S, SCORE_LEN]
- sort / merge-sort stages for SCORE_LEN = 4096
- topk selection with IDX_TOPK = 512
- visibility clamping by kv_seq_lens
If the current indexer sort pipeline assumes SCORE_LEN = 2048, it needs a 4096-capable path.
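As a reference for what a 4096-capable path should compute, here is a pure-Python sketch of visibility clamping followed by topk selection. It assumes a compressed slot becomes visible only once all of the positions it covers lie below kv_seq_len, that is, visible = kv_seq_len // ratio; the issue does not spell out that rule. visible_topk is a hypothetical helper, and the real kernel's padding of topk_idxs is not modeled.

```python
import heapq


def visible_topk(scores, kv_seq_len, ratio=4, topk=512):
    """Return indices of the top-k visible compressed slots, best first.

    scores: list of length SCORE_LEN with one score per compressed slot.
    A slot is treated as visible when its index is < kv_seq_len // ratio
    (assumed rule; see the lead-in above).
    """
    num_visible = min(kv_seq_len // ratio, len(scores))
    k = min(topk, num_visible)
    # Rank only the visible slot indices by their score.
    return heapq.nlargest(k, range(num_visible), key=scores.__getitem__)


# Usage: SCORE_LEN = 4096, with later slots scoring higher.
SCORE_LEN = 4096
scores = [float(i) for i in range(SCORE_LEN)]
picked = visible_topk(scores, kv_seq_len=8704)
print(len(picked), picked[0], min(picked))
```

With these inputs the selection stays inside the 2176 visible slots even though slots 2176 to 4095 carry higher scores, which is the "future compressed slots are never visible" property.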
5. Implement full compressor state paging semantics
This is the most important contract gap.
We should not store full 16k raw compressor state physically. Instead, the natural serving contract is:
logical state block table covers absolute-position space
physical state cache is a small rolling/paged pool
state_slot_mapping points each current token to its physical state row
The current implementation tends to use one COMPRESS_STATE_MAX_BLOCKS value for both:
- logical block-table width, and
- physical state tensor size.
For 8k/16k decode, these should be decoupled.
Suggested contract:
STATE_TABLE_MAX_BLOCKS = ceil(MAX_SEQ_LEN / STATE_BLOCK_SIZE)
STATE_PHYSICAL_BLOCKS = enough for the active rolling compressor window
Then:
compress_state_block_table[b, logical_block] -> physical_block or -1
state_slot_mapping[t] -> physical state row for the current token, or -1
This applies to:
- HCA ratio-128 main compressor: compress_state_block_table, state_slot_mapping
- CSA ratio-4 main compressor: compress_state_block_table, state_slot_mapping
- CSA inner indexer compressor: inner_compress_state_block_table, inner_state_slot_mapping
The compressor only needs to read the compression boundary window that is relevant for the current decode step, but it must be able to address that window by high absolute positions such as 8192 or 8703.
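The suggested contract can be illustrated with a small lookup sketch. STATE_BLOCK_SIZE = 128 and STATE_PHYSICAL_BLOCKS = 2 are hypothetical values picked for the example, and resolve_state_row is not an existing function; the point is only that the logical table spans the whole position space while the physical pool stays small.

```python
def ceil_div(a: int, b: int) -> int:
    """Integer ceiling division for non-negative a and positive b."""
    return -(-a // b)


MAX_SEQ_LEN = 16384
STATE_BLOCK_SIZE = 128     # hypothetical block size for this example
STATE_PHYSICAL_BLOCKS = 2  # hypothetical small rolling pool
STATE_TABLE_MAX_BLOCKS = ceil_div(MAX_SEQ_LEN, STATE_BLOCK_SIZE)


def resolve_state_row(block_table_row, abs_pos, block_size=STATE_BLOCK_SIZE):
    """Map an absolute position to a physical state row, or -1 if unmapped.

    block_table_row: list of length STATE_TABLE_MAX_BLOCKS for one request,
    holding a physical block id or -1 per logical block.
    """
    logical_block = abs_pos // block_size
    physical_block = block_table_row[logical_block]
    if physical_block < 0:
        return -1
    return physical_block * block_size + abs_pos % block_size


# Usage: a non-identity table in which only two logical blocks are resident.
table = [-1] * STATE_TABLE_MAX_BLOCKS
table[8192 // STATE_BLOCK_SIZE] = 1  # logical block 64 -> physical block 1
table[8703 // STATE_BLOCK_SIZE] = 0  # logical block 67 -> physical block 0

print(resolve_state_row(table, 8192))  # 128
print(resolve_state_row(table, 8703))  # 127
print(resolve_state_row(table, 0))     # -1
```

Every resolved row stays below STATE_PHYSICAL_BLOCKS * STATE_BLOCK_SIZE, so the physical tensor size is independent of MAX_SEQ_LEN.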
6. Ensure sparse-attention raw indices map correctly at long context
The decode sparse-attention raw-index contract is currently roughly:
[0, WIN) historical sliding-window ring KV
[WIN, WIN + S) current MTP overlay KV
[WIN + S, ...) compressed KV slot
For 8k+512 decode we need to ensure:
- ori_block_table still manages only the sliding-window ring KV.
- cmp_block_table covers compressed slots up to 4095.
- idx_block_table covers indexer compressed slots up to 4095.
- compressed raw indices are slot-based, not confused with absolute positions.
- position_ids and kv_seq_lens correctly limit visible compressed slots.
- future compressed slots are never visible.
- future/current overlay tokens do not incorrectly overwrite historical ring slots inside the same decode step.
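The three raw-index ranges can be decoded with a small helper, sketched below. classify_raw_index is a hypothetical name, and WIN = 128 in the usage lines is a placeholder value because the issue does not state the window size.

```python
def classify_raw_index(raw_idx: int, win: int, s: int):
    """Split a sparse-attention raw index into (region, local_index).

    [0, win)          -> ("ring", slot in the sliding-window ring KV)
    [win, win + s)    -> ("overlay", token index in the current MTP overlay)
    [win + s, ...)    -> ("compressed", compressed KV slot, not a position)
    """
    if raw_idx < 0:
        raise ValueError("raw index must be non-negative")
    if raw_idx < win:
        return ("ring", raw_idx)
    if raw_idx < win + s:
        return ("overlay", raw_idx - win)
    return ("compressed", raw_idx - win - s)


# Usage with placeholder WIN = 128 and S = 2.
WIN, S = 128, 2
print(classify_raw_index(127, WIN, S))             # ('ring', 127)
print(classify_raw_index(129, WIN, S))             # ('overlay', 1)
print(classify_raw_index(WIN + S + 4095, WIN, S))  # ('compressed', 4095)
```

The compressed local index is a slot number; converting it to the absolute positions it covers would multiply by the compression ratio, which is exactly the confusion the list above warns against.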
7. Use repeated decode_fwd calls first
The first milestone should be a host/Python loop:
for step in range(512):
    prepare metadata for one accepted token
    run decode_fwd once
    carry mutated caches/states into the next step
A monolithic 512-step JIT graph is explicitly out of scope for this milestone.
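A runnable skeleton of that host loop, with a stub in place of the real kernel, might look like the following. decode_fwd_stub, run_decode_loop, and the dict-based state are all hypothetical; the real decode_fwd signature and cache tensors are not shown in this issue.

```python
def decode_fwd_stub(state, position):
    """Stand-in for decode_fwd: records the committed position and
    returns the (mutated) state so it can be carried to the next step."""
    state["committed_positions"].append(position)
    return state


def run_decode_loop(prompt_len, num_steps, decode_fn):
    """Call decode_fn once per step, committing one token per call."""
    state = {"committed_positions": []}
    for step in range(num_steps):
        # Metadata for exactly one accepted token: its absolute position.
        position = prompt_len + step
        # Run one decode call and carry caches/states forward.
        state = decode_fn(state, position)
    return state


final_state = run_decode_loop(prompt_len=8192, num_steps=512,
                              decode_fn=decode_fwd_stub)
positions = final_state["committed_positions"]
print(len(positions), positions[0], positions[-1])
```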
8. Add final logits/sampling later
The first milestone can validate hidden output and cache/state side effects. The final closed-loop generation path can be added later:
final RMSNorm -> lm_head -> logits -> argmax/sampling -> next input_ids
Proposed Implementation Phases
Phase 1: Capacity constants
- Set decode MAX_SEQ_LEN capacity to 16384.
- Set CSA compressed KV capacity to CMP_MAX_BLOCKS = 32.
- Update IDX_KV_LEN, SCORE_LEN, and all derived static shapes.
- Keep raw KV sliding-window-only.
Phase 2: Long-position metadata lowering
Add or update metadata generation for high absolute positions:
- position_ids
- kv_seq_lens
- ori_slot_mapping
- cmp_slot_mapping
- idx_slot_mapping
- state_slot_mapping
- inner_state_slot_mapping
- ori_block_table
- cmp_block_table
- idx_block_table
- compress_state_block_table
- inner_compress_state_block_table
Phase 3: Compressor state paging
Decouple logical state-table width from physical state-cache size.
Implement and test:
- ratio-128 state paging
- ratio-4 main state paging
- ratio-4 inner indexer state paging
Phase 4: Single-token commit mode
Support one committed token per decode invocation, either with S=1 or with S=2 plus inactive metadata for the second token.
Phase 5: 512-step decode driver
Add a test/driver that repeatedly calls decode_fwd 512 times while carrying cache/state tensors across calls.
Required Test Matrix
Basic position-capacity tests
- start_pos = 0
- start_pos = 127
- start_pos = 128
- start_pos = 8191
- start_pos = 8192
- start_pos = 8703
- start_pos = 16383
All should compile/run without RoPE, position, cache, or state mapping OOB.
CSA ratio-4 compressed boundary tests
- start_pos = 3 -> 4: first compressed write
- start_pos = 127 -> 128: compressed slot 31/32 boundary
- start_pos = 511 -> 512: compressed KV block boundary
- start_pos = 4095 -> 4096: old 1024-slot boundary
- start_pos = 8191 -> 8192
- start_pos = 8703
- start_pos = 16383
HCA ratio-128 compressed boundary tests
- start_pos = 127 -> 128: first ratio-128 compressed write
- start_pos = 255 -> 256
- start_pos = 8191 -> 8192
- start_pos = 8703
- start_pos = 16383
Indexer long-context tests
- kv_seq_len = 8192
- kv_seq_len = 8704
- kv_seq_len = 16384
Required checks:
- visible compressed slots are computed correctly
- score/topk buffers do not OOB
- IDX_TOPK = 512 still returns only top 512 slots
- compressed slots beyond the old 2048-score range are valid
State paging tests
- non-identity state block table
- ring-reused physical state blocks
- state_slot_mapping != absolute_pos
- state_slot_mapping = -1 no-write token
- ratio-4 main state block-size boundary
- ratio-4 inner state block-size boundary
- ratio-128 state block-size boundary
- high-position state reads at 8192, 8704, and 16383
Sparse-attention raw-index tests
- historical ring raw index [0, WIN)
- current overlay raw index [WIN, WIN + S)
- compressed raw index [WIN + S, ...)
- future overlay is not visible
- future compressed slot is not visible
- compressed slot > 1024 is readable
- permuted cmp_block_table still works
- permuted idx_block_table still works
512-step decode driver tests
- single request, prompt length 8192, run 512 decode calls
- mixed batch with different start positions per request
- requests near different boundary classes:
- 127/128
- 511/512
- 4095/4096
- 8191/8192
- 8703
- after 512 calls, cache/state metadata should remain valid and deterministic
- each call commits exactly one token
Non-goals For The First Milestone
- full raw KV for 8k/16k context
- monolithic 512-step JIT graph
- fully dynamic serving allocator
- final logits/sampling loop
- performance optimization
The first milestone is correctness and metadata-contract completeness for long-position decode.