feat(dsv4): single-token decode config (B=8, S=1, T=8)#638
feat(dsv4): single-token decode config (B=8, S=1, T=8)#638zhangqi-chen wants to merge 2 commits into
Conversation
|
Important Review skippedDraft detected. Please check the settings in the CodeRabbit UI or the ⚙️ Run configurationConfiguration used: Organization UI Review profile: CHILL Plan: Pro Run ID: You can disable this status message by setting the Use the checkbox below for a quick retry:
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
There was a problem hiding this comment.
Code Review
This pull request updates the DeepSeek v4 model configuration and kernels to support single-token decoding by setting DECODE_SEQ = 1 and reducing several tile sizes from 128 to 64. It also adapts the sliding window attention (decode_sparse_attn_swa.py) to handle overlay sizes T smaller than ATTN_K_TILE by zero-padding the overlay. Feedback on these changes suggests optimizing the padding logic to conditionally execute only when T < ATTN_K_TILE, avoiding unnecessary memory bandwidth and SPMD launch overhead on the main production path where T == ATTN_K_TILE.
Important
The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.
9734429 to
bb43b9f
Compare
Switch the decode config from MTP (B=4, S=2) to single-token decode (B=8, S=1), keeping T = B*S = 8 so the MoE pipeline width is unchanged. The kernel-side enablement for small/single-token decode already lives on main: qkv_proj_rope pads T to MATMUL_T_TILE=16, decode_sparse_attn_swa zero-pads the MTP overlay up to ATTN_K_TILE, and hw-native-sys#647 drives the MoE pipeline natively at T = MOE_TOKENS (prefill entries override to PREFILL_TOKENS). So this is now a pure config flip. Validated a2a3 ep2: decode_fwd (T=8) and prefill_fwd (T=128) run end-to-end.
bb43b9f to
2a4828d
Compare
… MoE ring-heap config: PREFILL_RECV_MAX=1024 -- real gate routing skews past the RECV_SAFETY=4 uniform bound, overflowing the default recv depth and deadlocking dispatch/combine on real-weight ep8 prefill. moe: mark `moe` @pl.jit.inline(auto_scope=False) and use a plain pl.scope() around expert_routed + combine + hc_post so the compiler places AUTO runtime scopes across the whole MoE instead of one hand-placed MANUAL scope. The auto placement recycles the large recv buffers per stage and spreads the MoE over two rings, dropping the worst-ring MoE ring-heap high-water from 126% (unscoped) / ~90% (single MANUAL scope) to ~71% (keeps the larger RECV_MAX within the 1G ring).
2a4828d to
d184688
Compare
## Summary Splits the **MoE auto-scope** change out of #638 so it can land independently of that PR's decode config flip (`B=8, S=1`) and prefill `RECV_MAX` sizing. Mark `moe` `@pl.jit.inline(auto_scope=False)` and wrap `expert_routed` + `combine` + `hc_post` in a plain `pl.scope()` so the compiler places AUTO runtime scopes across the whole MoE instead of one hand-placed `pl.scope(mode=MANUAL)`. The auto placement recycles the large recv buffers per stage and spreads the MoE across two rings, dropping the **worst-ring** MoE ring-heap high-water from **126%** (unscoped) / **~90%** (single MANUAL scope) to **~71%** of the 1G ring (measured via `scope_stats`). This keeps a larger `RECV_MAX` within the ring budget. ## Changes - **moe.py**: `moe` → `@pl.jit.inline(auto_scope=False)`; wrap `expert_routed` + `combine` + `hc_post` in `pl.scope()`. ## Validation (a2a3) As validated in #638: `prefill_fwd.py` ep2 re-run after the auto-scope change PASSes, `scope_stats` worst-ring heap high-water 90.6% → 71.2%, all rings `fatal=False / dropped=0`. Real-weight ep8 prefill (cards 8-15) runs end-to-end with the default 1G ring heap.
Summary
Switches the DeepSeek-V4 Flash decode config from MTP (
B=4, S=2) to single-token decode (B=8, S=1), keepingT = B*S = 8so the MoE pipeline width is unchanged.The kernel-side enablement for small/single-token decode already lives on main —
qkv_proj_ropepadsTtoMATMUL_T_TILE=16,decode_sparse_attn_swazero-pads the MTP overlay up toATTN_K_TILE, and #647 drives the MoE pipeline natively atT = MOE_TOKENS(prefill entries override toPREFILL_TOKENS). So the decode switch is now a pure config flip.Also sizes the prefill MoE recv depth and auto-scopes the MoE so the real-weight ep8 prefill path runs end-to-end.
Changes
DECODE_BATCH/SEQ→B=8, S=1(T=8). Pure config flip; the kernel enablement is already on main.PREFILL_RECV_MAX = 1024. With real gate weights the routing skews well past theRECV_SAFETY=4uniform bound, so the formula-derived depth (96) overflows and deadlocks dispatch/combine on real-weight ep8 prefill. A bisection on the real-weight ep8 path puts the minimal passing depth in(768, 1024].moe@pl.jit.inline(auto_scope=False)and wrapexpert_routed+combine+hc_postin a plainpl.scope()so the compiler places AUTO runtime scopes across the whole MoE instead of one hand-placedpl.scope(mode=MANUAL). The auto placement recycles the large recv buffers per stage and spreads the MoE across two rings, dropping the worst-ring MoE ring-heap high-water from 126% (unscoped) / ~90% (single MANUAL scope) to ~71% of the 1G ring (measured viascope_stats), keeping the largerRECV_MAXwithin budget.Validation (a2a3)
decode_fwd.py/prefill_fwd.pyep2 (T=8 / T=128): run end-to-end.decode_fwd.py/prefill_fwd.pyep8 (cards 8–15), real W8A8 weights: run end-to-end. Real-weight prefill needsRECV_MAX=1024(else dispatch/combine deadlock); the default 1G ring heap suffices with the MoE auto-scope.prefill_fwd.pyep2 re-run after the auto-scope change: PASS,scope_statsworst-ring heap high-water 90.6% → 71.2%, all ringsfatal=False / dropped=0.Notes
decode_fwd/prefill_fwdhave no built-in golden, so their "PASS" = ran-to-completion.sparse_blk_*__phi_v7 not found in MLIR mappinglogs from the prefill sparse-attn codegen (pre-existing, unrelated to this change); compile + runtime both complete.