[Feature] Complete DeepSeek V4 FLASH main-model operator path through lm_head #460

@high-cloud

Summary

Complete the DeepSeek V4 FLASH main-model operator path end to end, from input token ids through embedding, the full main-model layer stack, lm_head, full logits gather across TP shards, and token output.

The target deployment concept follows the CANN DeepSeek V4 layout:

  • attention layers use data parallelism (DP);
  • MoE layers use expert parallelism (EP);
  • lm_head uses tensor parallelism (TP).

These modes are not multiplied into a DP x EP rank grid; each is a layer- or module-level parallel mode applied over the same available rank set. The first end-to-end acceptance target should include attention DP=16, MoE EP=16, and lm_head TP=16. A one-logical-attention-DP-shard mode can remain as a debugging subpath, but it is not sufficient for the first end-to-end acceptance target.
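As a rough illustration of "module-level modes over one rank set", the sketch below gives every rank one role per module instead of a coordinate in a DP x EP grid. The `RankRoles` type and the identity mapping from global rank to each role are assumptions made for illustration only; the real mapping depends on how the runtime builds its communication groups.

```python
from dataclasses import dataclass

WORLD_SIZE = 16  # one shared rank set: attention DP=16, MoE EP=16, lm_head TP=16


@dataclass(frozen=True)
class RankRoles:
    """Hypothetical per-rank view of the three module-level parallel modes."""
    rank: int
    attn_dp_rank: int
    moe_ep_rank: int
    lm_head_tp_rank: int


def roles_for(rank: int, world_size: int = WORLD_SIZE) -> RankRoles:
    if not 0 <= rank < world_size:
        raise ValueError(f"rank {rank} outside [0, {world_size})")
    # Assumption: each mode simply reuses the global rank id.
    return RankRoles(rank, attn_dp_rank=rank, moe_ep_rank=rank, lm_head_tp_rank=rank)


roles = [roles_for(r) for r in range(WORLD_SIZE)]
```

The point of the sketch is that the world size stays 16, not 16 x 16: the same rank plays a different role depending on which module is executing.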

MTP and DeepSeek V4 Pro are out of scope for this issue. Target FLASH only.

Motivation / Use Case

The repository already has substantial DeepSeek V4 single-kernel and stage-level coverage under models/deepseek/v4/, including SWA/CSA/HCA attention paths, token-major HC/MoE helpers, a prefill MoE wrapper, a 2-rank moe_ep.py bring-up, and lm_head TP-shard compute.

Those pieces are not yet sufficient for FLASH main-model end-to-end validation:

  • moe_ep.py is still a 2-rank DEMO distributed bring-up, not an EP=16 FLASH runtime path.
  • The current graph does not run all FLASH layers from token ids through embedding and final lm_head.
  • Real FLASH model weight loading and checkpoint-to-operator tensor mapping are not yet part of the harness.
  • Metadata construction is still scattered across fixed-shape stage wrappers instead of being a first-class serving-contract builder.
  • Some prefill paths still assume fixed short shapes, one window page, or PREFILL_BATCH * PREFILL_SEQ == DECODE_BATCH * DECODE_SEQ.
  • LM_HEAD_TP_SIZE is currently 8 in config.py, but the target FLASH lm_head TP size should be 16.
  • Local TP-shard validation is not enough for token output. With lmhead_tp_size = 16, each rank only owns one vocabulary shard; a real token decision needs gathered logits across all 16 TP shards.

This issue tracks the operator-side work needed to produce end-to-end token output for the main FLASH model. A short natural-language prompt such as 介绍下北京故宫 ("Tell me about the Forbidden City in Beijing") is enough for the first token-output validation, but the metadata builder should support arbitrary variable-length prompts so coverage can expand later.

Proposed API / Behavior

Required contract surface

Add a contract-driven metadata builder under models/deepseek/v4/. It should construct and validate the metadata that a serving layer would provide, including:

  • token-major prefill/decode views;
  • per-request query lengths;
  • cumulative query offsets, equivalent to query_start_loc / cu_q_lens;
  • per-request context lengths;
  • absolute start positions for RoPE and compressor/indexer state updates;
  • token-to-request and token-to-cache-slot mapping;
  • original KV block tables;
  • compressed KV block tables;
  • current-token KV write slots;
  • sliding-window sparse indices;
  • compressed sparse top-k indices;
  • padded/invalid sparse entries for fixed-shape contracts;
  • MoE global expert ids, EP-local expert mapping, dispatch counts, route-to-token mapping, and combine metadata;
  • attention DP rank or per-shard invocation contract;
  • MoE EP rank and EP group membership;
  • lm_head TP rank, vocabulary shard range, and logits-gather contract.
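A minimal sketch of a few of the fields above, assuming a plain-Python builder: cumulative query offsets, token-to-request mapping, absolute start positions, and padded sliding-window indices. The function names, the dictionary layout, the convention that a context length already includes the new query tokens, and the `-1` pad value are all assumptions for illustration, not the repository's actual contract.

```python
from itertools import accumulate


def build_query_metadata(query_lens, context_lens):
    """Build token-major metadata for a batch of requests.

    Assumption: context_lens[i] counts all tokens of request i visible after
    this step, including its query_lens[i] new tokens.
    """
    if len(query_lens) != len(context_lens):
        raise ValueError("query_lens and context_lens must have equal length")
    for q, c in zip(query_lens, context_lens):
        if q <= 0 or c < q:
            raise ValueError(f"invalid request: query_len={q}, context_len={c}")

    cu_q_lens = [0] + list(accumulate(query_lens))  # query_start_loc equivalent
    start_positions, token_to_request, positions = [], [], []
    for req, (q, c) in enumerate(zip(query_lens, context_lens)):
        start = c - q  # absolute position of the first new token
        start_positions.append(start)
        token_to_request.extend([req] * q)
        positions.extend(range(start, c))
    return {
        "cu_q_lens": cu_q_lens,
        "start_positions": start_positions,
        "token_to_request": token_to_request,
        "positions": positions,
    }


def sliding_window_indices(position, window, pad=-1):
    """Fixed-width list of the positions visible to `position`, padded with `pad`."""
    first = max(0, position - window + 1)
    valid = list(range(first, position + 1))
    return valid + [pad] * (window - len(valid))


# One 3-token prefill request and one decode step at absolute position 7.
meta = build_query_metadata(query_lens=[3, 1], context_lens=[3, 8])
```

Here `meta["cu_q_lens"]` is `[0, 3, 4]` and `meta["positions"]` is `[0, 1, 2, 7]`, so a mixed prefill/decode batch shares one token-major view.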

Main-model graph

Build a FLASH main-model prefill/decode harness that runs:

  1. input token ids;
  2. real weight-backed embedding;
  3. all FLASH layers using the configured compress_ratios schedule;
  4. per-layer HC pre, attention path selection, HC post, MoE, and state handoff;
  5. final hidden states into lm_head;
  6. TP=16 full logits gather across all vocabulary shards;
  7. greedy token selection for the first validation path.

Model weights

Add minimal FLASH model weight loading from a configured local model path:

  • map checkpoint tensor names to models/deepseek/v4/ operator inputs;
  • apply expected quantized-weight layouts and scale tensors;
  • shard or replicate weights according to module-local parallel mode:
    • attention weights replicated across attention DP=16 ranks;
    • local MoE experts for EP=16;
    • lm_head vocabulary shards for TP=16;
    • shared/replicated weights where required;
  • add a smoke check that verifies required tensors, shapes, and dtypes before NPU execution.
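The smoke check in the last bullet could look roughly like the sketch below, which compares a table of expected tensor specs against what a checkpoint reader reports, before any NPU work starts. `TensorSpec`, `check_weights`, the tensor name `embed.weight`, and the toy shapes and dtypes are hypothetical; only `lm_head.weight` is a name taken from this issue.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class TensorSpec:
    """Hypothetical description of one operator input: shape plus dtype name."""
    shape: tuple
    dtype: str


def check_weights(expected, found):
    """Return a list of human-readable problems; an empty list means the check passed."""
    problems = []
    for name, spec in expected.items():
        got = found.get(name)
        if got is None:
            problems.append(f"missing tensor: {name}")
        elif got.shape != spec.shape:
            problems.append(f"{name}: shape {got.shape}, expected {spec.shape}")
        elif got.dtype != spec.dtype:
            problems.append(f"{name}: dtype {got.dtype}, expected {spec.dtype}")
    return problems


# Toy shapes for illustration only; real FLASH dimensions come from the model config.
expected = {
    "embed.weight": TensorSpec((1024, 64), "bfloat16"),
    "lm_head.weight": TensorSpec((64, 64), "bfloat16"),  # one TP shard of a 1024-row matrix
}
found = {
    "embed.weight": TensorSpec((1024, 64), "bfloat16"),
    "lm_head.weight": TensorSpec((128, 64), "bfloat16"),  # wrong shard size
}
problems = check_weights(expected, found)
```

Collecting every problem instead of failing on the first one makes a single run report all mapping mistakes in a checkpoint.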

Attention DP=16

Bring up the attention-layer data-parallel path as part of the first end-to-end target:

  • run the FLASH attention layers with attention DP=16 across the same 16-rank set;
  • construct per-DP-rank token-major metadata, including local query lengths, cumulative query offsets, context lengths, cache block tables, sparse indices, and KV write slots;
  • replicate or shard inputs according to the attention-DP contract, with attention weights replicated as required;
  • define the handoff from attention DP local token rows into the MoE EP=16 routing group;
  • preserve a one-logical-DP-shard debug mode only as an intermediate diagnostic path, not as acceptance.
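One way to picture the per-DP-rank metadata is the sketch below: requests are assigned to the 16 attention DP ranks, and each rank builds its own local query lengths and cumulative offsets. The round-robin assignment and both function names are assumptions for illustration; the issue does not specify the assignment policy.

```python
def assign_requests_to_dp_ranks(num_requests, dp_size=16):
    """Round-robin assignment of request ids to attention DP ranks (assumed policy)."""
    per_rank = [[] for _ in range(dp_size)]
    for req_id in range(num_requests):
        per_rank[req_id % dp_size].append(req_id)
    return per_rank


def local_query_metadata(query_lens, request_ids):
    """Token-major metadata for the requests owned by one DP rank."""
    local_q = [query_lens[r] for r in request_ids]
    cu_q_lens = [0]
    for q in local_q:
        cu_q_lens.append(cu_q_lens[-1] + q)
    return {"request_ids": list(request_ids), "query_lens": local_q, "cu_q_lens": cu_q_lens}


# 18 requests with query lengths 1..18 spread over 16 DP ranks.
query_lens = list(range(1, 19))
assignment = assign_requests_to_dp_ranks(len(query_lens))
rank0 = local_query_metadata(query_lens, assignment[0])
rank2 = local_query_metadata(query_lens, assignment[2])
```

In this example rank 0 owns requests 0 and 16 and rank 2 owns only request 2, so the local token counts differ per rank; the DP-to-EP handoff has to cope with that imbalance.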

EP=16 MoE

Convert the current distributed MoE bring-up into a FLASH EP=16 path:

  • remove the DEMO-only override from moe_ep.py;
  • support N_RANKS = 16;
  • route over 256 global experts;
  • chunk gate computation over the expert dimension so FLASH routing fits;
  • scale dispatch/combine windows, barriers, counts, and route buffers to EP=16;
  • validate under a real distributed runtime, not just compile-only.
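The routing arithmetic behind these bullets can be sketched in plain Python: mapping a global expert id to an EP rank and local expert index, counting dispatches per rank, and selecting top-k experts by scanning the expert dimension in chunks. The contiguous 16-experts-per-rank placement, the function names, the chunk size, and the value of `k` are assumptions for illustration; only the 256 experts and EP=16 come from this issue.

```python
import heapq

N_EXPERTS = 256
EP_SIZE = 16
EXPERTS_PER_RANK = N_EXPERTS // EP_SIZE  # 16 local experts per EP rank


def ep_location(global_expert_id):
    """Return (ep_rank, local_expert_id), assuming contiguous expert placement."""
    if not 0 <= global_expert_id < N_EXPERTS:
        raise ValueError(f"expert id {global_expert_id} out of range")
    return divmod(global_expert_id, EXPERTS_PER_RANK)


def dispatch_counts(topk_ids):
    """Count how many (token, expert) routes land on each EP rank."""
    counts = [0] * EP_SIZE
    for token_experts in topk_ids:
        for expert in token_experts:
            counts[ep_location(expert)[0]] += 1
    return counts


def chunked_topk(scores, k, chunk):
    """Top-k expert ids for one token, scanning `chunk` experts at a time.

    Keeps a running best-k so only one chunk of gate scores is live at once.
    Ties are broken toward the lower expert id.
    """
    best = []
    for start in range(0, len(scores), chunk):
        candidates = best + [(s, start + i) for i, s in enumerate(scores[start:start + chunk])]
        best = heapq.nlargest(k, candidates, key=lambda p: (p[0], -p[1]))
    return [expert for _, expert in best]


# Toy gate scores: a permutation of 0..255 scaled into [0, 1).
scores = [((i * 37) % 256) / 256 for i in range(N_EXPERTS)]
selected = chunked_topk(scores, k=6, chunk=64)
counts = dispatch_counts([[0, 17, 255], [16, 18, 31]])
```

Because the running best-k is merged with each chunk, the result matches an unchunked top-k over all 256 scores while bounding the working set to one chunk plus `k` entries.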

Prefill and metadata generality

Support arbitrary variable-length prompts at the metadata/orchestration level, but use short prompts for the first precision and token-output validation matrix. The first natural-language validation prompt can be 介绍下北京故宫.

The prefill orchestration should eventually handle prompts longer than one compression/window chunk and support multiple original/compressed pages. Internal kernels may remain fixed-shape where necessary if chunked orchestration covers the variable-length contract.
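The chunked orchestration described above can be sketched as a small planner that splits a variable-length prompt into fixed-shape chunks with tail padding and counts the KV pages needed. The function names and the chunk and page sizes are hypothetical; how many compressed entries a prompt produces for a given compression ratio is not specified here, so the caller is assumed to supply that count.

```python
def chunk_prompt(prompt_len, chunk_size):
    """Split a prompt into (start, valid_tokens, pad_tokens) fixed-shape chunks."""
    if prompt_len <= 0 or chunk_size <= 0:
        raise ValueError("prompt_len and chunk_size must be positive")
    chunks = []
    for start in range(0, prompt_len, chunk_size):
        valid = min(chunk_size, prompt_len - start)
        chunks.append((start, valid, chunk_size - valid))
    return chunks


def pages_needed(num_entries, page_size):
    """Ceiling division: KV pages required to hold `num_entries` cache entries."""
    return -(-num_entries // page_size)


# A 300-token prompt run through a kernel that only accepts 128-token chunks.
plan = chunk_prompt(prompt_len=300, chunk_size=128)
original_pages = pages_needed(300, page_size=128)
```

Each tuple in `plan` tells the fixed-shape kernel where the chunk starts, how many rows are real, and how many are padding, which is enough to drive the block-table and write-slot construction per chunk.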

lm_head TP=16 and logits gather

lm_head must be included in this issue because end-to-end validation needs token output.

Required behavior:

  • set LM_HEAD_TP_SIZE / target lmhead_tp_size to 16;
  • load and shard lm_head.weight for TP=16;
  • implement full logits gather across all 16 TP vocabulary shards;
  • run token selection on gathered global logits, not a local shard;
  • initially use greedy argmax if that keeps validation simpler than production sampling.
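The difference between a local shard decision and a global one can be shown with a toy sketch: the vocabulary is split into 16 contiguous shards, the per-rank logits are concatenated in rank order, and greedy argmax runs on the gathered vector. The function names, the contiguous shard layout, and the 64-entry toy vocabulary are assumptions for illustration; on real hardware the concatenation would be a collective gather across the TP group.

```python
TP_SIZE = 16


def shard_range(tp_rank, vocab_size, tp_size=TP_SIZE):
    """Contiguous [start, end) vocabulary range owned by one TP rank."""
    per_shard = -(-vocab_size // tp_size)  # ceil; trailing shards may be short
    start = min(tp_rank * per_shard, vocab_size)
    end = min(start + per_shard, vocab_size)
    return start, end


def gather_logits(shards):
    """Concatenate per-rank logits in TP-rank order into one global vector."""
    gathered = []
    for shard in shards:
        gathered.extend(shard)
    return gathered


def greedy_token(logits):
    """Index of the largest logit; the first index wins on ties."""
    return max(range(len(logits)), key=logits.__getitem__)


# Toy vocabulary of 64 entries whose largest logit sits at token id 9.
vocab_size = 64
full_logits = [float((i * 7) % 64) for i in range(vocab_size)]
shards = [full_logits[slice(*shard_range(r, vocab_size))] for r in range(TP_SIZE)]

local_choice_rank0 = greedy_token(shards[0])         # best token inside shard 0 only
global_choice = greedy_token(gather_logits(shards))  # best token over the whole vocabulary
```

Rank 0 alone would pick token 3, while the gathered logits select token 9, which lives in the shard owned by rank 2. This is the gap that local TP-shard validation cannot catch.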

Alternatives Considered

  • Validate only local lm_head TP shards. This is insufficient for end-to-end token output because it cannot choose the global vocabulary token.
  • Start from synthetic hidden states instead of token ids. This is useful for kernel bring-up, but not enough for end-to-end token-output validation; the first full graph should include embedding.
  • Accept only a one-logical-attention-DP-shard path first. This is useful for debugging but is no longer sufficient for the first end-to-end acceptance target because the target deployment requires attention DP=16.
  • Target Pro at the same time. This issue should target FLASH only first.
  • Treat MTP as part of the first end-to-end target. MTP remains out of scope for this issue.

Additional Context

Initial acceptance criteria:

  • Target model is DeepSeek V4 FLASH only.
  • MTP is excluded.
  • Attention layers run with DP=16 in the first end-to-end acceptance target.
  • The metadata builder lives under models/deepseek/v4/ and supports arbitrary variable-length prompts.
  • The first validation matrix may use short prompts such as 介绍下北京故宫.
  • Real FLASH weights can be loaded and mapped to all required operator inputs.
  • The graph starts from token ids and includes embedding.
  • Attention DP=16 local token rows hand off correctly into MoE EP=16 dispatch/combine under real distributed runtime.
  • The full FLASH main-model layer stack runs through lm_head.
  • lm_head uses TP=16 and token selection is based on gathered global logits across all 16 TP shards.
  • The harness emits logits and at least one token output.
  • The validated shape matrix records prompt lengths, batch size, compression-ratio coverage, attention DP size, EP size, lm_head TP size, and the DP-to-EP handoff mode.

Suggested milestones:

  1. Add the models/deepseek/v4/ metadata builder and shape checks.
  2. Add minimal FLASH weight loading and tensor-name/shape validation.
  3. Build a FLASH main-model harness that starts from token ids, runs embedding, and supports attention DP=16 metadata.
  4. Convert the 2-rank moe_ep.py bring-up into a parameterized EP harness and get EP=16 compile working.
  5. Fix gate expert-dimension chunking for FLASH global routing.
  6. Run EP=16 MoE distributed runtime on real NPU.
  7. Validate the attention-DP16 to MoE-EP16 rank-mode handoff in the main-model harness.
  8. Set lm_head TP to 16, gather logits across all 16 TP shards, and emit token output.
  9. Expand prompt-length coverage toward arbitrary variable-length prompts.

Related: #378, #382, #410, #456
