Summary
Complete the DeepSeek V4 MTP execution contract in pypto-lib, covering both kernel-side MTP computation and serving-facing prefill/decode state construction.
The goal is to make the remaining MTP work reviewable as one coherent feature:
- define the exact DeepSeek V4 MTP projection shape contract;
- define the MTP decoder-layer and logits boundaries;
- define prefill/decode input construction helpers;
- define speculative-token state updates;
- provide one independently runnable full-chain validation case.
Motivation / Use Case
DeepSeek V4 MTP requires more than a standalone projection kernel. The full path needs a clear contract for how target-model hidden states, MTP input tokens, decoder-layer execution, logits computation, and speculative-token state are connected.
The current gap is that individual pieces can be implemented or validated separately, but the repository still needs a complete, reviewable MTP execution contract that explains:
- what hidden state should be passed into MTP;
- how MTP projection handles HC lanes;
- where the MTP forward boundary ends;
- where logits are computed;
- how prefill and decode construct MTP inputs;
- how speculative tokens are verified and updated.
Without this contract, it is easy to accidentally use post-hc_head hidden states, treat lane-aware MTP projection as a normal [T, D] projection, or mix forward and logits behavior into one unclear boundary.
Proposed API / Behavior
Add the remaining DeepSeek V4 MTP implementation and validation pieces under models/deepseek/v4.
Required files:
models/deepseek/v4/mtp.py
models/deepseek/v4/mtp_inputs.py
Required MTP kernel/validation functions:
mtp_projection_impl
mtp_seed_hc_stack
mtp_decoder_layer_tail
mtp_forward_tail
mtp_compute_logits_tail
mtp_full_chain
Required serving-facing helpers:
build_mtp_prefill_input_ids
restore_cp_prefill_next_tokens
build_mtp_prefill_input_ids_cp
build_main_decode_input_ids_for_mtp
build_mtp_decode_input_ids
verify_mtp_spec_tokens
update_mtp_state_after_step
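The issue names verify_mtp_spec_tokens but does not spell out its acceptance rule. As a rough illustration only, the sketch below uses plain greedy prefix matching, a common rule in speculative decoding. The function name verify_spec_tokens_greedy, its arguments, and its return shape are hypothetical and are not the pypto-lib signature.

```python
from typing import List, Tuple


def verify_spec_tokens_greedy(
    draft_tokens: List[int],
    target_tokens: List[int],
) -> Tuple[List[int], int]:
    """Hypothetical greedy verifier (assumed rule, not the pypto-lib API).

    draft_tokens: K tokens proposed by MTP.
    target_tokens: K + 1 tokens chosen by the target model at the same
        positions; the extra entry is the bonus token after the last draft.
    Returns (tokens to commit, number of accepted draft tokens).
    """
    if len(target_tokens) != len(draft_tokens) + 1:
        raise ValueError("target_tokens must have one more entry than draft_tokens")
    accepted = 0
    for draft, target in zip(draft_tokens, target_tokens):
        if draft != target:
            break
        accepted += 1
    # Commit the accepted drafts plus the target token at the first mismatch,
    # or the bonus token when every draft was accepted.
    committed = draft_tokens[:accepted] + [target_tokens[accepted]]
    return committed, accepted
```

Under this assumed rule every step commits at least one token, so a state update after the step can always advance the sequence by len(committed).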
Expected MTP projection behavior:
h_branch:
previous_hidden_states
-> reshape to [T, HC_MULT, D]
-> hnorm
-> h_proj
e_branch:
input embedding / current hidden state
-> mask positions == 0
-> enorm
-> e_proj
-> unsqueeze/broadcast to HC lane dimension
projected_mtp_hidden = h_branch + e_branch
projected_mtp_hidden: [T, HC_MULT, D]
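The projection flow above can be sketched in NumPy to make the shapes concrete. This is a reference-style sketch under stated assumptions, not the kernel: hnorm and enorm are assumed to be RMSNorm, h_proj and e_proj are assumed to be bias-free D -> D linear maps, and the flat layout is assumed to be lane-major. The name mtp_projection_sketch and all parameter names are hypothetical.

```python
import numpy as np


def rms_norm(x, weight, eps=1e-6):
    # Assumed norm type; the issue only names hnorm and enorm.
    var = np.mean(x * x, axis=-1, keepdims=True)
    return x / np.sqrt(var + eps) * weight


def mtp_projection_sketch(prev_hidden, embeds, positions,
                          hnorm_w, enorm_w, h_proj_w, e_proj_w, hc_mult):
    t, d = embeds.shape
    # Accept flat [T, HC_MULT * D] or lane-shaped [T, HC_MULT, D] input.
    lanes = prev_hidden.reshape(t, hc_mult, d)
    # h branch: normalise and project each HC lane independently.
    h_branch = rms_norm(lanes, hnorm_w) @ h_proj_w            # [T, HC_MULT, D]
    # e branch: zero the embedding where positions == 0, then norm and project.
    masked = np.where((positions == 0)[:, None], 0.0, embeds)
    e_branch = rms_norm(masked, enorm_w) @ e_proj_w           # [T, D]
    # Broadcast the single embedding row across the HC lane dimension.
    return h_branch + e_branch[:, None, :]                    # [T, HC_MULT, D]
```

With these assumptions, a token at position 0 contributes nothing through the e branch, so its output equals the h branch alone. Treating the input as a plain [T, D] tensor would skip the per-lane reshape entirely, which is the mistake the contract is meant to prevent.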
Expected MTP decoder-layer boundary:
projected_mtp_hidden [T, HC_MULT, D]
-> MHC pre / attention norm boundary
-> SWA attention
-> MHC post/pre / FFN norm boundary
-> MoE/FFN
-> MHC post
-> pre_hc_residual [T, HC_MULT * D]
Expected logits boundary:
pre_hc_residual
-> hc_head
-> shared_head_norm
-> lm_head logits
MTP forward should return pre_hc_residual [T, HC_MULT * D]. Logits should be computed by a separate logits tail.
For multi-step MTP, the same step index must select the same MTP layer for forward and logits. Each proposal step should feed the previous step's flat pre-hc_head residual back as the next previous_hidden_states.
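The two multi-step rules above can be sketched as a small proposal loop. The sketch assumes each MTP layer is exposed as a (forward_tail, logits_tail) pair of callables, that the step index maps to a layer by modulo, and that the next token is picked greedily. All three are assumptions for illustration, and propose_tokens is a hypothetical name.

```python
from typing import Callable, List, Sequence, Tuple

Forward = Callable[[list, int], list]   # (prev_hidden, token) -> pre_hc_residual
Logits = Callable[[list], List[float]]  # pre_hc_residual -> logits


def propose_tokens(
    layers: Sequence[Tuple[Forward, Logits]],
    target_pre_hc_residual: list,
    first_token: int,
    num_steps: int,
) -> List[int]:
    prev_hidden = target_pre_hc_residual
    token = first_token
    proposals: List[int] = []
    for step in range(num_steps):
        # One index picks both halves, so forward and logits cannot diverge.
        forward_tail, logits_tail = layers[step % len(layers)]
        pre_hc_residual = forward_tail(prev_hidden, token)
        logits = logits_tail(pre_hc_residual)
        token = max(range(len(logits)), key=logits.__getitem__)  # greedy argmax
        proposals.append(token)
        # Recycle the flat pre-hc_head residual, not any logits-tail output.
        prev_hidden = pre_hc_residual
    return proposals
```

Keeping the forward and logits callables paired per layer makes it structurally hard to run step i's forward with step j's head, and the loop only ever feeds the forward output back as previous_hidden_states.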
Only one runnable validation case is required: the full-chain case.
The full-chain case should cover:
prefill/decode input construction
-> lane-aware projection
-> decoder-layer tail
-> logits tail
-> candidate logits / speculative token state
Alternatives Considered
One alternative is to only validate the projection kernel first and leave decoder tail, logits tail, and serving-state helpers for later. This makes the first change smaller, but it does not provide a complete MTP execution contract and makes review harder because later code has to infer the missing boundaries.
Another alternative is to put all MTP behavior directly into serving code. That would make the runtime path less reusable and harder to validate independently in pypto-lib.
The proposed approach keeps kernel-side MTP computation and CPU-validatable serving-state helpers separated, while still providing one full-chain validation case to verify the complete path.
Additional Context
Important model contract details:
- MTP must consume the target model's pre-hc_head flat residual, not post-hc_head or final normalized hidden states.
- previous_hidden_states should support [T, HC_MULT * D] or [T, HC_MULT, D].
- MTP projection is lane-aware and should produce [T, HC_MULT, D].
- The positions == 0 embedding mask is part of the MTP input contract.
- MTP forward and MTP logits should remain separate boundaries.
- The decode proposal loop must recycle the previous step's flat pre-hc_head residual.
- Partitioned prefill token restore and packed input construction should be covered.
- Only one full-chain validation case is required for this issue.
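The two accepted previous_hidden_states layouts can be checked with a small shape guard. The sketch below is one possible form, assuming NumPy arrays and a lane-major flat layout; as_hc_lanes is a hypothetical helper name, not a function listed in this issue.

```python
import numpy as np


def as_hc_lanes(prev_hidden: np.ndarray, hc_mult: int, d: int) -> np.ndarray:
    """Return a [T, HC_MULT, D] view of a flat or lane-shaped residual."""
    if prev_hidden.ndim == 3 and prev_hidden.shape[1:] == (hc_mult, d):
        return prev_hidden
    if prev_hidden.ndim == 2 and prev_hidden.shape[1] == hc_mult * d:
        # Assumed lane-major layout: lane index varies slower than D.
        return prev_hidden.reshape(prev_hidden.shape[0], hc_mult, d)
    raise ValueError(
        f"expected [T, {hc_mult * d}] or [T, {hc_mult}, {d}], "
        f"got {prev_hidden.shape}"
    )
```

When HC_MULT is greater than 1, a [T, D] tensor such as a post-hc_head hidden state fails this check, so the wrong input is rejected at the boundary instead of being silently projected.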