[Feature] Complete DeepSeek V4 MTP prefill/decode execution contract #678

Description

@hashiqiqixian

Summary

Complete the DeepSeek V4 MTP (multi-token prediction) execution contract in pypto-lib, covering both kernel-side MTP computation and serving-facing prefill/decode state construction.

This issue should make the remaining MTP work reviewable as one coherent feature. It should:

  1. define the exact DeepSeek V4 MTP projection shape contract;
  2. define the MTP decoder-layer and logits boundaries;
  3. define prefill/decode input construction helpers;
  4. define speculative-token state updates;
  5. provide one independently runnable full-chain validation case.

Motivation / Use Case

DeepSeek V4 MTP requires more than a standalone projection kernel. The full path needs a clear contract for how target-model hidden states, MTP input tokens, decoder-layer execution, logits computation, and speculative-token state are connected.

The current gap is that individual pieces can be implemented or validated separately, but the repository still needs a complete, reviewable MTP execution contract that explains:

  1. what hidden state should be passed into MTP;
  2. how MTP projection handles HC lanes;
  3. where the MTP forward boundary ends;
  4. where logits are computed;
  5. how prefill and decode construct MTP inputs;
  6. how speculative tokens are verified and updated.

Without this contract, it is easy to accidentally use post-hc_head hidden states, treat lane-aware MTP projection as a normal [T, D] projection, or mix forward and logits behavior into one unclear boundary.

Proposed API / Behavior

Add the remaining DeepSeek V4 MTP implementation and validation pieces under models/deepseek/v4.

Required files:

  • models/deepseek/v4/mtp.py
  • models/deepseek/v4/mtp_inputs.py

Required MTP kernel/validation functions:

  • mtp_projection_impl
  • mtp_seed_hc_stack
  • mtp_decoder_layer_tail
  • mtp_forward_tail
  • mtp_compute_logits_tail
  • mtp_full_chain

Required serving-facing helpers:

  • build_mtp_prefill_input_ids
  • restore_cp_prefill_next_tokens
  • build_mtp_prefill_input_ids_cp
  • build_main_decode_input_ids_for_mtp
  • build_mtp_decode_input_ids
  • verify_mtp_spec_tokens
  • update_mtp_state_after_step

Expected MTP projection behavior:

previous_hidden_states
    -> reshape to [T, HC_MULT, D]
    -> hnorm
    -> h_proj

input embedding / current hidden state
    -> mask positions == 0
    -> enorm
    -> e_proj
    -> unsqueeze/broadcast to HC lane dimension

projected_mtp_hidden = h_branch + e_branch
projected_mtp_hidden: [T, HC_MULT, D]
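
The projection above can be sketched in NumPy to make the shapes concrete. This is a minimal illustration, not the kernel: it assumes hnorm and enorm are unweighted RMSNorm, that h_proj and e_proj are bias-free [D, D] matrices, that the flat residual is laid out lane-major, and that "mask positions == 0" means zeroing the embedding rows whose position is 0. The names `rms_norm` and `mtp_projection_sketch` are hypothetical.

```python
import numpy as np


def rms_norm(x, eps=1e-6):
    # Assumed normalization: unweighted RMSNorm over the last axis.
    return x / np.sqrt(np.mean(x * x, axis=-1, keepdims=True) + eps)


def mtp_projection_sketch(prev_hidden, embeds, positions, w_h, w_e, hc_mult):
    t, d = embeds.shape
    # h branch: view the flat residual as lanes, then norm and project
    # each lane independently.
    lanes = prev_hidden.reshape(t, hc_mult, d)
    h_branch = rms_norm(lanes) @ w_h                 # [T, HC_MULT, D]
    # e branch: zero the embedding where position == 0, then norm and project.
    masked = np.where((positions == 0)[:, None], 0.0, embeds)
    e_branch = rms_norm(masked) @ w_e                # [T, D]
    # Broadcast the e branch across the HC lane dimension and add.
    return h_branch + e_branch[:, None, :]           # [T, HC_MULT, D]


T, HC_MULT, D = 3, 4, 8
rng = np.random.default_rng(0)
prev = rng.standard_normal((T, HC_MULT * D))
emb = rng.standard_normal((T, D))
pos = np.array([0, 1, 2])
w_h = rng.standard_normal((D, D))
w_e = rng.standard_normal((D, D))
out = mtp_projection_sketch(prev, emb, pos, w_h, w_e, HC_MULT)
print(out.shape)  # (3, 4, 8)
```

Under these assumptions the row at position 0 receives only the h branch, which is why the mask belongs to the input contract rather than being an implementation detail.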

Expected MTP decoder-layer boundary:

projected_mtp_hidden [T, HC_MULT, D]
    -> MHC pre / attention norm boundary
    -> SWA attention
    -> MHC post/pre / FFN norm boundary
    -> MoE/FFN
    -> MHC post
    -> pre_hc_residual [T, HC_MULT * D]

Expected logits boundary:

pre_hc_residual
    -> hc_head
    -> shared_head_norm
    -> lm_head logits

MTP forward should return pre_hc_residual [T, HC_MULT * D]. Logits should be computed by a separate logits tail.

For multi-step MTP, the same step index must select the same MTP layer for forward and logits. Each proposal step should feed the previous step's flat pre-hc_head residual back as the next previous_hidden_states.
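
The multi-step rule can be sketched as a small driver loop. The callables `forward_fn`, `logits_fn`, and `pick_fn` are hypothetical stand-ins for the forward tail, the logits tail, and token selection. Feeding the picked token back as the next step's input is also an assumption, since this issue only fixes how the hidden state is recycled.

```python
def propose_tokens(prev_hidden, input_ids, num_steps,
                   forward_fn, logits_fn, pick_fn):
    """Run num_steps MTP proposal steps.

    forward_fn(step, prev_hidden, input_ids) -> flat pre-hc_head residual
    logits_fn(step, residual)                -> logits
    pick_fn(logits)                          -> proposed token ids
    """
    proposals = []
    for step in range(num_steps):
        # The same step index selects the MTP layer for both calls.
        residual = forward_fn(step, prev_hidden, input_ids)
        logits = logits_fn(step, residual)
        input_ids = pick_fn(logits)      # assumption: next step's input
        proposals.append(input_ids)
        # Recycle the flat pre-hc_head residual, not the logits-side state.
        prev_hidden = residual
    return proposals


# Toy stand-ins that only record how they were called.
calls = []
toy_forward = lambda step, h, ids: (calls.append(("forward", step, h)) or h + 1)
toy_logits = lambda step, r: (calls.append(("logits", step, r)) or r)
toy_pick = lambda logits: logits * 10
result = propose_tokens(0, None, 3, toy_forward, toy_logits, toy_pick)
```

Keeping the two tails as separate callables makes it hard to pass the output of the logits path back into the next forward step by accident.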

Only one runnable validation case is required:

full-chain

The full-chain case should cover:

prefill/decode input construction
    -> lane-aware projection
    -> decoder-layer tail
    -> logits tail
    -> candidate logits / speculative token state

Alternatives Considered

One alternative is to only validate the projection kernel first and leave decoder tail, logits tail, and serving-state helpers for later. This makes the first change smaller, but it does not provide a complete MTP execution contract and makes review harder because later code has to infer the missing boundaries.

Another alternative is to put all MTP behavior directly into serving code. That would make the runtime path less reusable and harder to validate independently in pypto-lib.

The proposed approach keeps kernel-side MTP computation and CPU-validatable serving-state helpers separated, while still providing one full-chain validation case to verify the complete path.

Additional Context

Important model contract details:

  • MTP must consume the target model's pre-hc_head flat residual, not post-hc_head or final normalized hidden states.
  • previous_hidden_states should accept either the flat layout [T, HC_MULT * D] or the lane layout [T, HC_MULT, D].
  • MTP projection is lane-aware and should produce [T, HC_MULT, D].
  • The positions == 0 embedding mask is part of the MTP input contract.
  • MTP forward and MTP logits should remain separate boundaries.
  • The decode proposal loop must recycle the previous step's flat pre-hc_head residual.
  • Partitioned prefill token restore and packed input construction should be covered.
  • Only one full-chain validation case is required for this issue.
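
The two accepted layouts for previous_hidden_states can be normalized with a small shape guard. The sketch below is illustrative only: the helper names `as_lanes` and `as_flat` are hypothetical, and it assumes the flat layout is lane-major, so a plain reshape converts between the two.

```python
import numpy as np


def as_lanes(prev_hidden, hc_mult, d):
    """Return previous_hidden_states as [T, HC_MULT, D], or raise."""
    if prev_hidden.ndim == 2 and prev_hidden.shape[1] == hc_mult * d:
        return prev_hidden.reshape(prev_hidden.shape[0], hc_mult, d)
    if prev_hidden.ndim == 3 and prev_hidden.shape[1:] == (hc_mult, d):
        return prev_hidden
    raise ValueError(
        f"expected [T, {hc_mult * d}] or [T, {hc_mult}, {d}], "
        f"got {prev_hidden.shape}"
    )


def as_flat(lanes):
    """Return a [T, HC_MULT, D] tensor as the flat [T, HC_MULT * D] residual."""
    t, hc_mult, d = lanes.shape
    return lanes.reshape(t, hc_mult * d)


flat = np.arange(2 * 3 * 4, dtype=np.float32).reshape(2, 12)
lanes = as_lanes(flat, hc_mult=3, d=4)
```

With HC_MULT greater than 1, a hidden state that has already been collapsed to [T, D] fails this check, which catches the post-hc_head mistake early instead of letting it reshape silently.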


Labels: enhancement
