
[Feature] Support DeepSeek V4 decode for 8k context + 512 decode steps #657

Description

@sjduan

Summary

We need to extend the DeepSeek V4 decode path to support the following serving target:

  • Up to 8k prompt/input context.
  • Then generate 512 output tokens.
  • For now, treat this as 512 separate decode_fwd invocations. Although the current static decode shape has DECODE_SEQ = 2 for MTP, the upper layer cannot consume the second MTP token yet, so the first milestone should commit only one accepted token per decode call.
  • The maximum absolute position that must be valid is therefore (the prompt occupies positions 0..8191 and the 512 decode calls occupy 8192..8703):
8192 + 512 - 1 = 8703

To avoid repeatedly resizing around this boundary, the proposed static capacity target is:

MAX_SEQ_LEN = 16384

This issue is not asking for 16k full raw KV. Decode should continue to keep only sliding-window raw KV. The long-context requirement is mainly about compressed history, indexer history, RoPE capacity, and state/cache mapping correctness for high absolute positions.

Current Code Path

The current decode full-forward path is roughly:

models/deepseek/v4/decode_fwd.py
  -> layer 0/1: SWA attention + MoE
  -> layer 2/4/.../42: CSA attention + MoE
  -> layer 3/5/.../41: HCA attention + MoE
  -> hc_head

Each attention module is roughly:

hc_pre
  -> rmsnorm
  -> qkv_proj_rope
  -> path-specific compressor/indexer/topk/sparse_attn
  -> delayed KV/cache writeback
  -> hc_post

Relevant files include:

  • models/deepseek/v4/config.py
  • models/deepseek/v4/decode_fwd.py
  • models/deepseek/v4/decode_attention_swa.py
  • models/deepseek/v4/decode_attention_hca.py
  • models/deepseek/v4/decode_attention_csa.py
  • models/deepseek/v4/decode_sparse_attn.py
  • models/deepseek/v4/decode_sparse_attn_hca.py
  • models/deepseek/v4/decode_indexer.py
  • models/deepseek/v4/decode_indexer_compressor.py
  • models/deepseek/v4/decode_compressor_ratio4.py
  • models/deepseek/v4/decode_compressor_ratio128.py

Required Capabilities

1. Increase decode position capacity to 16384

FLASH.max_position_embeddings and all decode-side shapes derived from MAX_SEQ_LEN should support at least 16384 positions.

This needs to cover:

  • freqs_cos / freqs_sin
  • position_ids
  • IDX_KV_LEN = MAX_SEQ_LEN // 4
  • SCORE_LEN
  • sparse-attention topk width and padded width
  • compressor state block-table width
  • golden/fixture high-position cases
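As a quick cross-check of the arithmetic, here is a minimal sketch that derives these static capacities from MAX_SEQ_LEN. It is not the config.py implementation: CMP_BLOCK_SIZE = 128 is inferred from "CMP_MAX_BLOCKS = 4096 / 128 = 32", treating SCORE_LEN as equal to IDX_KV_LEN is an assumption, and derive_capacities is a hypothetical helper.

```python
import math

# Hypothetical sketch: derive decode-side static capacities from MAX_SEQ_LEN.
# CMP_BLOCK_SIZE = 128 is inferred from "CMP_MAX_BLOCKS = 4096 / 128 = 32";
# SCORE_LEN == IDX_KV_LEN is an assumption, not a confirmed contract.
CSA_RATIO = 4         # ratio-4 compressor and indexer (CSA layers)
HCA_RATIO = 128       # ratio-128 compressor (HCA layers)
CMP_BLOCK_SIZE = 128  # rows per compressed-KV block (assumed)


def derive_capacities(max_seq_len: int = 16384) -> dict:
    csa_slots = max_seq_len // CSA_RATIO
    return {
        "rope_rows": max_seq_len,            # rows needed in freqs_cos / freqs_sin
        "max_position_id": max_seq_len - 1,  # largest valid position_ids entry
        "IDX_KV_LEN": csa_slots,             # indexer compressed rows
        "SCORE_LEN": csa_slots,              # one indexer score per slot (assumed)
        "CMP_MAX_BLOCKS": math.ceil(csa_slots / CMP_BLOCK_SIZE),
        "HCA_CMP_SLOTS": max_seq_len // HCA_RATIO,
    }


caps = derive_capacities(16384)           # static capacity target
target = derive_capacities(8192 + 512)    # exact 8k+512 footprint
print(caps["CMP_MAX_BLOCKS"], target["CMP_MAX_BLOCKS"])  # 32 17
```

The 8k+512 footprint needs only 17 compressed blocks, which is why sizing for 32 is a deliberate choice to match MAX_SEQ_LEN rather than the immediate target.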

2. Support one committed token per decode call

The first milestone should support 512 separate decode calls, each committing exactly one accepted token.

There are two possible approaches:

  1. Add/support a DECODE_SEQ = 1 decode configuration.
  2. Keep the static S = 2 shape, but make the second token a non-committed placeholder via metadata:
ori_slot_mapping[second_token] = -1
cmp_slot_mapping[second_token] = -1
idx_slot_mapping[second_token] = -1
state_slot_mapping[second_token] = -1
inner_state_slot_mapping[second_token] = -1

The second token must not mutate raw KV, compressed KV, indexer KV, or compressor state. Sparse-attention overlay/topk visibility must also respect the single committed token semantics.
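A minimal host-side sketch of option 2, assuming per-request Python lists stand in for the *_slot_mapping tensors. SlotMappings, mask_uncommitted, and scatter_rows are hypothetical names. The sketch covers only the write side (the -1 placeholder skips every cache/state write); overlay and top-k visibility still need separate handling in the kernels.

```python
from dataclasses import dataclass
from typing import List

INACTIVE = -1  # slot value meaning "this token must not write anything"


@dataclass
class SlotMappings:
    # Hypothetical container mirroring the *_slot_mapping tensors for one
    # request in an S = 2 decode step (one entry per token).
    ori: List[int]
    cmp: List[int]
    idx: List[int]
    state: List[int]
    inner_state: List[int]


def mask_uncommitted(m: SlotMappings, committed: int = 1) -> SlotMappings:
    """Turn every token at index >= committed into a non-committed placeholder."""
    def cut(xs: List[int]) -> List[int]:
        return [x if t < committed else INACTIVE for t, x in enumerate(xs)]
    return SlotMappings(cut(m.ori), cut(m.cmp), cut(m.idx),
                        cut(m.state), cut(m.inner_state))


def scatter_rows(cache: List[float], slots: List[int], values: List[float]) -> None:
    """Write values[t] into cache[slots[t]], skipping placeholder tokens (slot == -1)."""
    for slot, value in zip(slots, values):
        if slot != INACTIVE:
            cache[slot] = value


m = mask_uncommitted(SlotMappings([5, 6], [1, 2], [1, 2], [10, 11], [20, 21]))
cache = [0.0] * 8
scatter_rows(cache, m.ori, [1.5, 2.5])  # only token 0 lands in cache[5]
```

The same rule would apply to every writeback path (raw ring, cmp_kv, idx cache, and both compressor states), so a single "slot == -1 means skip" convention keeps the kernels uniform.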

3. Increase CSA compressed KV capacity

CSA ratio-4 compressed KV needs enough slots for the full target capacity:

compressed_slots = 16384 / 4 = 4096
CMP_MAX_BLOCKS = 4096 / 128 = 32

So CSA cmp_kv / cmp_block_table should support at least CMP_MAX_BLOCKS = 32.

For the immediate 8k+512 target, 8704 / 4 = 2176 compressed slots (17 blocks of 128) are needed, but sizing for 32 blocks keeps the contract aligned with MAX_SEQ_LEN = 16384.

HCA ratio-128 needs far fewer compressed slots for this target, but its sparse-attention and block-table semantics should still work at high absolute positions.

4. Synchronize indexer cache and score capacity with MAX_SEQ_LEN

CSA indexer capacity at 16k should be:

IDX_KV_LEN = 16384 / 4 = 4096

The existing IDX_CACHE_MAX_BLOCKS = 64 provides 8192 rows (64 blocks of 128 rows), which should be physically large enough, but the following must be audited and updated:

  • SCORE_LEN
  • score [B, S, SCORE_LEN]
  • topk_idxs [B, S, SCORE_LEN]
  • sort / merge-sort stages for SCORE_LEN = 4096
  • topk selection with IDX_TOPK = 512
  • visibility clamping by kv_seq_lens

If the current indexer sort pipeline assumes SCORE_LEN = 2048, it needs a 4096-capable path.
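The intended top-k semantics can be pinned down with a small reference sketch, useful as a golden check against the sort pipeline. It assumes that a ratio-4 indexer exposes kv_seq_len // 4 complete slots and that entries past that index in the SCORE_LEN-wide row are future or padding. indexer_topk is a hypothetical helper; it returns a variable-length list, whereas a static-shape kernel would presumably pad to IDX_TOPK.

```python
import heapq
from typing import List


def indexer_topk(scores: List[float], kv_seq_len: int,
                 ratio: int = 4, topk: int = 512) -> List[int]:
    """Pick up to `topk` compressed slots from one query's score row.

    Only slots whose whole ratio-sized group lies inside the visible prefix
    are eligible (kv_seq_len // ratio of them, assumed); later entries of the
    SCORE_LEN-wide row are never selected.
    """
    visible = min(kv_seq_len // ratio, len(scores))
    k = min(topk, visible)
    # heapq.nlargest(k, it, key=f) is equivalent to sorted(it, key=f, reverse=True)[:k]
    best = heapq.nlargest(k, range(visible), key=scores.__getitem__)
    return sorted(best)  # ascending slot ids


scores = [float(i) for i in range(4096)]   # SCORE_LEN = 4096 row
sel = indexer_topk(scores, kv_seq_len=8704)
print(len(sel), sel[0], sel[-1])           # 512 1664 2175
```

With increasing scores, the selection at kv_seq_len = 8704 lands on slots 1664..2175, several of which lie beyond the old 2048-score range, while slot 2176 and later stay invisible.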

5. Implement full compressor state paging semantics

This is the most important contract gap.

We should not store full 16k raw compressor state physically. Instead, the natural serving contract is:

logical state block table covers absolute-position space
physical state cache is a small rolling/paged pool
state_slot_mapping points each current token to its physical state row

The current implementation tends to use one COMPRESS_STATE_MAX_BLOCKS value for both:

  • logical block-table width, and
  • physical state tensor size.

For 8k/16k decode, these should be decoupled.

Suggested contract:

STATE_TABLE_MAX_BLOCKS = ceil(MAX_SEQ_LEN / STATE_BLOCK_SIZE)
STATE_PHYSICAL_BLOCKS  = enough for the active rolling compressor window

Then:

compress_state_block_table[b, logical_block] -> physical_block or -1
state_slot_mapping[t] -> physical state row for the current token, or -1

This applies to:

  • HCA ratio-128 main compressor
    • compress_state_block_table
    • state_slot_mapping
  • CSA ratio-4 main compressor
    • compress_state_block_table
    • state_slot_mapping
  • CSA inner indexer compressor
    • inner_compress_state_block_table
    • inner_state_slot_mapping

The compressor only needs to read the compression boundary window that is relevant for the current decode step, but it must be able to address that window by high absolute positions such as 8192 or 8703.
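A minimal sketch of the decoupled contract, assuming STATE_BLOCK_SIZE = 128 and a simple ring allocator in which logical block lb reuses physical block lb % num_physical. Both the block size and the allocator policy are assumptions for illustration, as are the helper names; the real num_physical depends on how far back each compressor's boundary window reaches.

```python
from typing import List

STATE_BLOCK_SIZE = 128  # rows per state block (assumed)
MAX_SEQ_LEN = 16384
STATE_TABLE_MAX_BLOCKS = -(-MAX_SEQ_LEN // STATE_BLOCK_SIZE)  # ceil division -> 128


def rolling_state_table(cur_pos: int, num_physical: int,
                        width: int = STATE_TABLE_MAX_BLOCKS,
                        block_size: int = STATE_BLOCK_SIZE) -> List[int]:
    """Hypothetical ring allocator: only the newest `num_physical` logical
    blocks are resident; every other logical block maps to -1."""
    table = [-1] * width
    last = cur_pos // block_size
    for lb in range(max(0, last - num_physical + 1), last + 1):
        table[lb] = lb % num_physical
    return table


def state_slot_for(pos: int, table: List[int],
                   block_size: int = STATE_BLOCK_SIZE) -> int:
    """Absolute position -> physical state row, or -1 if its block is not resident."""
    logical_block, offset = divmod(pos, block_size)
    if not 0 <= logical_block < len(table):
        raise IndexError(f"position {pos} is outside the logical state table")
    physical_block = table[logical_block]
    return -1 if physical_block < 0 else physical_block * block_size + offset


table = rolling_state_table(8703, num_physical=2)
print(state_slot_for(8703, table))  # 255: physical row, not the absolute position
```

The logical table is 128 entries wide (covering all of 0..16383) while the physical pool here holds only two blocks, which is exactly the "state_slot_mapping != absolute_pos" case the state paging tests target.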

6. Ensure sparse-attention raw indices map correctly at long context

The decode sparse-attention raw-index contract is currently roughly:

[0, WIN)              historical sliding-window ring KV
[WIN, WIN + S)        current MTP overlay KV
[WIN + S, ...)        compressed KV slot

For 8k+512 decode we need to ensure:

  • ori_block_table still manages only the sliding-window ring KV.
  • cmp_block_table covers compressed slots up to 4095.
  • idx_block_table covers indexer compressed slots up to 4095.
  • compressed raw indices are slot-based, not confused with absolute positions.
  • position_ids and kv_seq_lens correctly limit visible compressed slots.
  • future compressed slots are never visible.
  • future/current overlay tokens do not incorrectly overwrite historical ring slots inside the same decode step.
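A reference decoder for the raw-index layout can make these checks concrete. This is a sketch under stated assumptions: ring indices are returned as ring slots directly (ori_block_table paging is omitted), compressed blocks hold 128 slots, and a compressed slot counts as visible only when slot < kv_seq_len // ratio. resolve_raw_index and compressed_slot_visible are hypothetical names.

```python
from typing import List, Tuple


def resolve_raw_index(raw: int, win: int, s: int, cmp_block_table: List[int],
                      block_size: int = 128) -> Tuple[str, int]:
    """Decode one sparse-attention raw index into (source, row)."""
    if raw < 0:
        raise ValueError("negative raw index")
    if raw < win:
        return ("ring", raw)             # historical sliding-window ring slot
    if raw < win + s:
        return ("overlay", raw - win)    # token from the current decode step
    slot = raw - (win + s)               # compressed SLOT id, not an absolute position
    logical_block, offset = divmod(slot, block_size)
    physical_block = cmp_block_table[logical_block]
    if physical_block < 0:
        raise ValueError(f"compressed slot {slot} has no physical block")
    return ("cmp", physical_block * block_size + offset)


def compressed_slot_visible(slot: int, kv_seq_len: int, ratio: int = 4) -> bool:
    """Visible only once the slot's whole ratio-sized group is in the prefix (assumed)."""
    return 0 <= slot < kv_seq_len // ratio


permuted = [31 - b for b in range(32)]   # permuted cmp_block_table, 32 blocks
print(resolve_raw_index(128 + 2 + 1024, win=128, s=2, cmp_block_table=permuted))
# ('cmp', 2944): slot 1024 -> logical block 8 -> physical block 23
```

Keeping the slot-to-row translation in one place (after subtracting WIN + S) is what prevents compressed slot ids from being mistaken for absolute positions.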

7. Use repeated decode_fwd calls first

The first milestone should be a host/Python loop:

for step in range(512):
    prepare metadata for one accepted token
    run decode_fwd once
    carry mutated caches/states into the next step

A monolithic 512-step JIT graph is explicitly out of scope for this milestone.
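The host loop can be sketched as runnable Python against a stand-in. fake_decode_fwd is a hypothetical stub, not the real decode_fwd signature; it only records which absolute position was written and returns a dummy token, so the sketch demonstrates the loop contract (one committed token per call, positions 8192..8703, state carried across calls), not model behavior.

```python
from dataclasses import dataclass, field
from typing import Dict, List

MAX_SEQ_LEN = 16384


@dataclass
class DecodeState:
    # Hypothetical host-side state carried from one decode_fwd call to the next.
    next_pos: int                                   # absolute position of the next token
    committed: List[int] = field(default_factory=list)
    caches: Dict[str, List[int]] = field(default_factory=lambda: {"written_pos": []})


def fake_decode_fwd(token: int, pos: int, caches: Dict[str, List[int]]) -> int:
    """Stand-in for decode_fwd: records the cache write and returns a dummy token."""
    caches["written_pos"].append(pos)
    return (token * 31 + pos) % 50_000


def run_decode_loop(prompt_len: int, steps: int = 512, first_token: int = 1) -> DecodeState:
    st = DecodeState(next_pos=prompt_len)
    token = first_token
    for _ in range(steps):
        pos = st.next_pos                            # metadata for ONE accepted token
        if pos >= MAX_SEQ_LEN:
            raise IndexError(f"position {pos} exceeds MAX_SEQ_LEN")
        token = fake_decode_fwd(token, pos, st.caches)  # one decode_fwd invocation
        st.committed.append(token)                   # commit exactly one token
        st.next_pos += 1                             # caches/state carry into next step
    return st


st = run_decode_loop(prompt_len=8192)
print(st.caches["written_pos"][0], st.caches["written_pos"][-1])  # 8192 8703
```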

8. Add final logits/sampling later

The first milestone can validate hidden output and cache/state side effects. The final closed-loop generation path can be added later:

final RMSNorm -> lm_head -> logits -> argmax/sampling -> next input_ids
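For reference, the greedy variant of this tail is small. The sketch below uses plain lists; the hidden size, eps value, and the [vocab, hidden] lm_head layout are assumptions for illustration, and sampling beyond argmax is left out.

```python
import math
from typing import List


def rms_norm(x: List[float], weight: List[float], eps: float = 1e-6) -> List[float]:
    """x / sqrt(mean(x^2) + eps), scaled elementwise by weight."""
    rms = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [v / rms * w for v, w in zip(x, weight)]


def greedy_next_token(hidden: List[float], norm_weight: List[float],
                      lm_head: List[List[float]]) -> int:
    """final RMSNorm -> lm_head -> logits -> argmax (greedy)."""
    h = rms_norm(hidden, norm_weight)
    logits = [sum(a * b for a, b in zip(row, h)) for row in lm_head]  # [vocab]
    return max(range(len(logits)), key=logits.__getitem__)
```

The returned id would become the next step's input_ids entry once the closed-loop path is added.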

Proposed Implementation Phases

Phase 1: Capacity constants

  • Set decode MAX_SEQ_LEN capacity to 16384.
  • Set CSA compressed KV capacity to CMP_MAX_BLOCKS = 32.
  • Update IDX_KV_LEN, SCORE_LEN, and all derived static shapes.
  • Keep raw KV sliding-window-only.

Phase 2: Long-position metadata lowering

Add or update metadata generation for high absolute positions:

  • position_ids
  • kv_seq_lens
  • ori_slot_mapping
  • cmp_slot_mapping
  • idx_slot_mapping
  • state_slot_mapping
  • inner_state_slot_mapping
  • ori_block_table
  • cmp_block_table
  • idx_block_table
  • compress_state_block_table
  • inner_compress_state_block_table

Phase 3: Compressor state paging

Decouple logical state-table width from physical state-cache size.

Implement and test:

  • ratio-128 state paging
  • ratio-4 main state paging
  • ratio-4 inner indexer state paging

Phase 4: Single-token commit mode

Support one committed token per decode invocation, either with S=1 or with S=2 plus inactive metadata for the second token.

Phase 5: 512-step decode driver

Add a test/driver that repeatedly calls decode_fwd 512 times while carrying cache/state tensors across calls.

Required Test Matrix

Basic position-capacity tests

  • start_pos = 0
  • start_pos = 127
  • start_pos = 128
  • start_pos = 8191
  • start_pos = 8192
  • start_pos = 8703
  • start_pos = 16383

All should compile/run without RoPE, position, cache, or state mapping OOB.
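A host-side precheck like the following sketch could flag which static capacity a start_pos would overflow before a kernel ever runs. capacity_violations is a hypothetical helper, and CMP_BLOCK_SIZE = 128 is inferred from the 4096 / 128 = 32 arithmetic above. Passing cmp_max_blocks=8 reproduces the old 1024-slot limit.

```python
MAX_SEQ_LEN = 16384
CSA_RATIO = 4
CMP_BLOCK_SIZE = 128  # inferred from "CMP_MAX_BLOCKS = 4096 / 128 = 32"


def capacity_violations(start_pos: int, max_seq_len: int = MAX_SEQ_LEN,
                        cmp_max_blocks: int = 32) -> list:
    """Names of static capacities that `start_pos` would overflow ([] means OK)."""
    problems = []
    if not 0 <= start_pos < max_seq_len:
        problems.append("rope/position")       # freqs_cos / freqs_sin / position_ids
    slot = start_pos // CSA_RATIO              # CSA compressed / indexer slot
    if slot >= max_seq_len // CSA_RATIO:
        problems.append("IDX_KV_LEN")
    if slot // CMP_BLOCK_SIZE >= cmp_max_blocks:
        problems.append("cmp_block_table")
    return problems


for p in [0, 127, 128, 8191, 8192, 8703, 16383]:
    print(p, capacity_violations(p))           # every list is empty
```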

CSA ratio-4 compressed boundary tests

  • start_pos = 3 -> 4: first compressed write
  • start_pos = 127 -> 128: compressed slot 31/32 boundary
  • start_pos = 511 -> 512: compressed KV block boundary
  • start_pos = 4095 -> 4096: old 1024-slot boundary
  • start_pos = 8191 -> 8192
  • start_pos = 8703
  • start_pos = 16383

HCA ratio-128 compressed boundary tests

  • start_pos = 127 -> 128: first ratio-128 compressed write
  • start_pos = 255 -> 256
  • start_pos = 8191 -> 8192
  • start_pos = 8703
  • start_pos = 16383
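To keep the boundary classes in both lists unambiguous, a test generator could label each step pos -> pos + 1 by the compressed slot each position falls in. boundary_kind is a hypothetical helper, and using a 128-slot block for both ratio-4 and ratio-128 compressed KV is an assumption.

```python
def boundary_kind(pos: int, ratio: int, block_size: int = 128) -> str:
    """Classify the step pos -> pos + 1 for a ratio-`ratio` compressor."""
    slot_a, slot_b = pos // ratio, (pos + 1) // ratio
    if slot_a == slot_b:
        return "same_slot"          # both positions feed the same compressed slot
    if slot_a // block_size != slot_b // block_size:
        return "block_boundary"     # next slot lives in a new compressed KV block
    return "slot_boundary"          # new slot, same block


for pos in (3, 127, 511, 4095, 8191):
    print("CSA", pos, boundary_kind(pos, ratio=4))
for pos in (127, 255, 8191):
    print("HCA", pos, boundary_kind(pos, ratio=128))
```

Under these assumptions, 511 -> 512, 4095 -> 4096, and 8191 -> 8192 are CSA block boundaries, while every listed HCA case is a plain slot boundary.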

Indexer long-context tests

  • kv_seq_len = 8192
  • kv_seq_len = 8704
  • kv_seq_len = 16384

Required checks:

  • visible compressed slots are computed correctly
  • score/topk buffers do not OOB
  • IDX_TOPK = 512 still returns only top 512 slots
  • compressed slots beyond the old 2048-score range are valid

State paging tests

  • non-identity state block table
  • ring-reused physical state blocks
  • state_slot_mapping != absolute_pos
  • state_slot_mapping = -1 no-write token
  • ratio-4 main state block-size boundary
  • ratio-4 inner state block-size boundary
  • ratio-128 state block-size boundary
  • high-position state reads at 8192, 8704, and 16383

Sparse-attention raw-index tests

  • historical ring raw index [0, WIN)
  • current overlay raw index [WIN, WIN + S)
  • compressed raw index [WIN + S, ...)
  • future overlay is not visible
  • future compressed slot is not visible
  • compressed slot >= 1024 (beyond the old 1024-slot capacity) is readable
  • permuted cmp_block_table still works
  • permuted idx_block_table still works

512-step decode driver tests

  • single request, prompt length 8192, run 512 decode calls
  • mixed batch with different start positions per request
  • requests near different boundary classes:
    • 127/128
    • 511/512
    • 4095/4096
    • 8191/8192
    • 8703
  • after 512 calls, cache/state metadata should remain valid and deterministic
  • each call commits exactly one token

Non-goals For The First Milestone

  • full raw KV for 8k/16k context
  • monolithic 512-step JIT graph
  • fully dynamic serving allocator
  • final logits/sampling loop
  • performance optimization

The first milestone is correctness and metadata-contract completeness for long-position decode.

Metadata

Labels

No labels

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests
