diff --git a/docs/index.md b/docs/index.md
index 4ab29f06..1650a968 100644
--- a/docs/index.md
+++ b/docs/index.md
@@ -43,9 +43,10 @@ Organized by domain (model line / subsystem / playbook / lesson) instead of by l

 | Path | TL;DR |
 | --- | --- |
-| `models/qwen35/roadmap.md` | Qwen3.5-4B roadmap (2026-06 review): decode-tuning refresh improves direct TPOT by 2-3%, while vLLM still leads 1024/256 HTTP decode and high-concurrency throughput. Open items: HND prefill staging, prefix-cache design, serving concurrency. |
+| `models/qwen35/roadmap.md` | Qwen3.5-4B roadmap: DFlash is opt-in and batched for multi-active decode, with RTX 5090 in-process c4/c8/c16 gains across decode-heavy, medium, and long prompts. Open items: HTTP pressure validation for DFlash, HND prefill staging, prefix-cache design, and broader feature coverage. |
 | `models/qwen35/kv-admission.md` | Issue #254 complete: Qwen3.5 now uses full-lifetime KV admission, deferred pressure handling, impossible-request rejection, explicit error semantics, direct rejection-event coverage, RTX 5090 e2e, and real HTTP pressure/post-pressure validation. |
 | `models/qwen35/optimization.md` | Hybrid 24 linear + 8 full attn optimization ledger. Decode-tuning refresh fuses MLP gate/up and tunes decode cublasLt buckets, improving direct TPOT by 2-3%; vLLM still leads 1024/256 HTTP decode. |
+| `models/qwen35/dflash-speculative-decoding.md` | Qwen3.5 DFlash speculative decoding is opt-in and batched: hybrid-state transaction covers KV + recurrent + conv state, correctness gates pass, and RTX 5090 in-process c4/c8/c16 A/B improves `prompt_len=1` by `+16.7%/+15.4%/+14.0%`, `1024` by `+209.9%/+168.6%/+45.3%`, and `4096` by `+135.9%/+35.7%/+25.6%`. |
 | `models/qwen35/accuracy.md` | Qwen3.5-4B HF bf16 logits goldens through `past_key_values`: short replay covers sequential graph, bucket-straddling batched graph, and slot-compaction; long replay covers 4097/8192-token prompts; full GSM8K 8-shot now matches the HF baseline within 0.15 percentage points. |
 | `models/qwen35/model-crate.md` | `openinfer-qwen35-4b` owns Qwen3.5 model/scheduler/recurrent ops/tests/benches; feature-gated behind `qwen35-4b` (Triton AOT is the only Python build dependency); root loads it through `EngineHandle`. Build/check/clippy, root bench sanity check, historical Qwen3.5 e2e, and scheduler e2e records live here. |
 | `models/qwen35/kernel-plan.md` | Qwen3.5-4B has a `openinfer_qwen35_4b::kernel_plan()` static descriptor mirroring the qwen3 module — enumerates every prefill/decode/unified op with its Rust call site, backend, and notes, so you can dump the active kernel mix without reading call sites. Pure refactor (issue #256), no kernel behavior change. |
diff --git a/docs/models/qwen35/dflash-speculative-decoding.md b/docs/models/qwen35/dflash-speculative-decoding.md
new file mode 100644
index 00000000..88164d2c
--- /dev/null
+++ b/docs/models/qwen35/dflash-speculative-decoding.md
@@ -0,0 +1,100 @@
+# DFlash Speculative Decoding (Qwen3.5-4B)
+
+> **TL;DR:** Qwen3.5-4B DFlash speculative decoding is implemented behind `--dflash-draft-model-path`, default-off, greedy-only, single-GPU only, and now supports multi-active decode batches with a fixed-buffer batched verifier. Same-host RTX 5090 A/B on `output_len=256` shows clear throughput wins at c4/c8/c16: decode-heavy `prompt_len=1` improves `+16.7%/+15.4%/+14.0%`, medium `prompt_len=1024` improves `+209.9%/+168.6%/+45.3%`, and long `prompt_len=4096` improves `+135.9%/+35.7%/+25.6%`.
+
+Last touched: 2026-07
+
+## How To Enable
+
+Use a Qwen3.5 target model with a matching DFlash draft checkpoint:
+
+```bash
+cargo run --release -p openinfer-server --features qwen35-4b -- \
+  --model-path \
+  --dflash-draft-model-path
+```
+
+The flag is rejected for unsupported model lines. Qwen3.5 DFlash is incompatible with LoRA, KV offload, tensor parallelism, and decode-overlap modes. Non-greedy requests and logprobs use normal decode.
+
+## Runtime Contract
+
+- The drafter emits `[current_token, draft...]`; the target verifies that span and commits the longest greedy-matching prefix plus one bonus token.
+- Verification uses preallocated `VerifyBuffers35` storage for token ids, hidden/logit buffers, GDR scratch, full-attention scratch, paged prefill plans, and sampling scratch. Decode steps reuse fixed buffers instead of allocating on the hot path.
+- Qwen3.5 verification is a hybrid transaction over full-attention KV, recurrent state, convolution state, and sequence length. Verify writes to scratch state; commit preserves full-span accepts directly and replays only truncated accepted spans after rolling KV back to the canonical boundary.
+- Batched verify handles active batches up to the scheduler bucket size. Complete fixed shapes can use captured graph-compatible paths; truncated or heterogeneous spans use eager verify.
+- The scheduler captures target hidden context only on DFlash-eligible prefill paths. If a request falls back to normal decode, its DFlash state is dropped because normal decode does not capture the hidden context needed by the drafter.
+- Per-request low-acceptance statistics disable DFlash after enough poor draft tokens, so incompatible prompts return to baseline decode.
+- DFlash reserves memory for draft weights, draft KV/cache, verify buffers, and batch scratch before target KV sizing. Admission also reserves draft block headroom, so a near-window request accepted without DFlash can be rejected when DFlash is enabled.
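The acceptance rule in the first contract bullet (commit the longest greedy-matching prefix plus one bonus token) can be sketched as below. This is a minimal illustration, not the crate's verifier: `accept_greedy`, the `u32` token type, and the slice layout are assumptions made for the example. `target_argmax[i]` stands for the target's greedy token after consuming `current_token` followed by the first `i` draft tokens, so it has one more entry than `draft`.

```rust
/// Hypothetical helper (not an openinfer API): tokens committed by one
/// greedy speculative step.
///
/// `draft` holds the drafter's proposed tokens, excluding `current_token`.
/// `target_argmax[i]` is the target's greedy choice at span position `i`,
/// so `target_argmax.len() == draft.len() + 1`.
fn accept_greedy(draft: &[u32], target_argmax: &[u32]) -> Vec<u32> {
    assert_eq!(target_argmax.len(), draft.len() + 1);

    // Count the leading draft tokens that agree with the target's greedy pick.
    let matched = draft
        .iter()
        .zip(target_argmax)
        .take_while(|(d, t)| d == t)
        .count();

    // Commit the matched prefix, then the target's own token at the first
    // mismatch (or just past the full span): the "bonus" token.
    let mut committed = draft[..matched].to_vec();
    committed.push(target_argmax[matched]);
    committed
}
```

Under these assumptions a fully accepted span commits `draft.len() + 1` tokens, while a mismatch at position `matched` truncates the span. The truncated case is the one the contract describes as needing a KV rollback and replay.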
+
+## Validation
+
+Commands below passed on an RTX 5090 validation host with driver `580.105.08`, CUDA 13.3, Triton Python `3.7.1`, and `OPENINFER_CUDA_SM=120`. The source snapshot is the PR branch after the benchmark-shaped gate cleanup.
+
+```bash
+cargo fmt --all --check
+git diff --check
+OPENINFER_TRITON_PYTHON= OPENINFER_TEST_MODEL_PATH= \
+  OPENINFER_DFLASH_TEST_MODEL_PATH= \
+  cargo test --release -p openinfer-qwen35-4b --features qwen35-4b \
+  --test dflash_speculative_gate -- --nocapture --test-threads=1
+OPENINFER_TRITON_PYTHON= OPENINFER_TEST_MODEL_PATH= \
+  OPENINFER_DFLASH_TEST_MODEL_PATH= \
+  cargo test --release -p openinfer-qwen35-4b --features qwen35-4b \
+  --test speculative_verify -- --nocapture --test-threads=1
+OPENINFER_TRITON_PYTHON= OPENINFER_TEST_MODEL_PATH= \
+  cargo test --release -p openinfer-qwen35-4b --features qwen35-4b \
+  --test hf_golden_gate -- --nocapture
+OPENINFER_TRITON_PYTHON= OPENINFER_TEST_MODEL_PATH= \
+  cargo test --release -p openinfer-qwen35-4b --features qwen35-4b \
+  --test e2e_scheduler -- --nocapture
+```
+
+The DFlash scheduler gates check single request, multi-active batch, heterogeneous `max_tokens`, mixed concurrent requests, and the benchmark-shaped synthetic cases that exposed hash differences in the raw sweep (`1024/c16`, `4096/c8`, `4096/c16`). The benchmark-shaped follow-up passed: `1024/c16` was exact for 16/16 requests; `4096/c8` and `4096/c16` were exact except for near-ties accepted by the regret oracle (`regret 0.000` and `0.125 <= 0.20`).
+
+## Benchmark
+
+Same host, same PR branch snapshot, in-process `bench_serving request`, greedy synthetic distinct prompts, `output_len=256`, warmup `3`, iters `8`.
+
+| Prompt | Concurrency | Baseline tok/s | DFlash tok/s | Delta | Baseline effective TPOT p50 | DFlash effective TPOT p50 | Baseline raw ITL p99 | DFlash raw ITL p99 | Baseline TTFT p50 | DFlash TTFT p50 |
+| ---: | ---: | ---: | ---: | ---: | ---: | ---: | ---: | ---: | ---: | ---: |
+| 1 | 1 | 151.214 | 149.808 | -0.9% | 6.593 ms | 6.645 ms | 6.682 ms | 6.699 ms | 9.122 ms | 9.374 ms |
+| 1 | 4 | 110.906 | 129.388 | +16.7% | 8.907 ms | 8.682 ms | 8.988 ms | 21.412 ms | 39.045 ms | 32.756 ms |
+| 1 | 8 | 89.977 | 103.856 | +15.4% | 10.889 ms | 9.679 ms | 10.969 ms | 33.776 ms | 69.832 ms | 63.528 ms |
+| 1 | 16 | 64.930 | 73.990 | +14.0% | 14.925 ms | 13.851 ms | 15.054 ms | 57.610 ms | 131.570 ms | 125.421 ms |
+| 1024 | 1 | 135.543 | 134.699 | -0.6% | 7.220 ms | 7.270 ms | 7.295 ms | 7.297 ms | 46.715 ms | 46.797 ms |
+| 1024 | 4 | 97.293 | 301.482 | +209.9% | 9.911 ms | 2.980 ms | 9.695 ms | 19.231 ms | 153.577 ms | 137.400 ms |
+| 1024 | 8 | 76.916 | 206.606 | +168.6% | 12.217 ms | 4.062 ms | 54.032 ms | 27.162 ms | 263.906 ms | 229.206 ms |
+| 1024 | 16 | 53.404 | 77.602 | +45.3% | 16.963 ms | 11.597 ms | 60.353 ms | 52.408 ms | 492.746 ms | 414.550 ms |
+| 4096 | 1 | 102.477 | 101.745 | -0.7% | 9.039 ms | 9.122 ms | 9.139 ms | 9.134 ms | 189.535 ms | 189.955 ms |
+| 4096 | 4 | 68.916 | 162.581 | +135.9% | 12.830 ms | 4.635 ms | 57.351 ms | 22.665 ms | 640.507 ms | 567.550 ms |
+| 4096 | 8 | 50.473 | 68.502 | +35.7% | 16.238 ms | 11.677 ms | 61.075 ms | 35.653 ms | 1106.224 ms | 941.275 ms |
+| 4096 | 16 | 32.875 | 41.304 | +25.6% | 23.239 ms | 17.710 ms | 65.017 ms | 63.604 ms | 2070.313 ms | 1696.621 ms |
+
+`effective_tpot_ms` is the amortized per-request decode time. Raw token-event ITL can spike under speculative decode because accepted spans emit multiple tokens in one scheduler step; keep both metrics visible when reviewing tails.
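To see why the two latency metrics can move in opposite directions (for example the `1024`/c4 row, where effective TPOT p50 drops while raw ITL p99 rises), here is a small sketch. It is an assumed model, not the `bench_serving` implementation: the `Step` struct, both function names, and the formulas are illustrative, and the step timings in the usage note are made up rather than measured.

```rust
/// Hypothetical record of one scheduler step for a single request:
/// the step took `step_ms` and emitted `tokens` tokens in one event.
struct Step {
    step_ms: f64,
    tokens: usize,
}

/// Amortized decode time per emitted token. This is an assumed reading of
/// "effective TPOT"; the bench's exact formula is not shown in the doc.
fn effective_tpot_ms(steps: &[Step]) -> f64 {
    let total_ms: f64 = steps.iter().map(|s| s.step_ms).sum();
    let total_tokens: usize = steps.iter().map(|s| s.tokens).sum();
    total_ms / total_tokens as f64
}

/// Longest gap between consecutive token events, which in this model is
/// simply the duration of the slowest step.
fn max_event_gap_ms(steps: &[Step]) -> f64 {
    steps.iter().map(|s| s.step_ms).fold(0.0, f64::max)
}
```

With illustrative inputs, four baseline steps of 10 ms emitting one token each give an amortized 10 ms per token and a 10 ms worst gap. Two speculative steps of 16 ms emitting four tokens each give an amortized 4 ms per token but a 16 ms worst gap, so the amortized metric improves while the event-gap tail grows.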
+
+## Profile
+
+Profiles used `nsys profile --trace=cuda,nvtx,osrt --cuda-graph-trace=node` and `nsys stats` on the same host. The final c8/c16 traces show that the previous per-request verifier bottleneck is gone: DFlash uses batched prefill verify kernels and partial-only replay instead of singleton target-prefill verification.
+
+| Shape | Baseline dominant work | DFlash dominant work | Profile conclusion |
+| --- | --- | --- | --- |
+| `prompt=1,c=8` | `gated_delta_rule_decode_kernel` `2.04s`, batch decode attention `72.6ms` | GDR verify kernels plus lower target decode counts; batch decode attention `71.2ms` | Draft/verify overhead is below the throughput saved by multi-token accepts. |
+| `prompt=1024,c=8` | `gated_delta_rule_decode_kernel` `2.06s`, batch decode attention `550.2ms` | GDR verify kernels, `SinglePrefillWithKVCacheKernel` `75.3ms`, batch prefill verify `49.6ms`, batch decode attention `71.8ms` | Verifier no longer runs target prefill per request; c8 decode throughput improves `+192.90%`. |
+| `prompt=4096,c=16` | `gated_delta_rule_decode_kernel` `4.41s`, batch decode attention `2.44s`, batch prefill `398.5ms` | batch decode attention `1.68s`, batch prefill verify `537.9ms`, GDR verify kernels visible but not dominant | Commit/replay/copy is not the leading bottleneck; long c16 still improves `+24.89%`. |
+
+## Claim Boundaries
+
+- This is an opt-in Qwen3.5 DFlash path with real c4/c8/c16 in-process benchmark wins. Token sanity is exact where stable; prompt-length-1 and a few long high-concurrency synthetic cases are covered by the same regret oracle used by the scheduler gate for bf16 near-tie / prefill-vs-decode boundary flips.
+- The performance table is in-process benchmark evidence. Do not read it as an HTTP serving pressure claim.
+- Single-concurrency random synthetic prompts remain flat to slightly slower. The multi-active path is the supported performance claim for this slice.
+- Multi-GPU, LoRA, KV offload, decode overlap, non-greedy sampling, and logprobs intentionally use normal decode or fail closed.
+
+## Remaining Risks And Follow-ups
+
+- No blocker-level implementation risk is known from the current local, GPU, benchmark, and profile evidence. Keep CI state and new reviewer comments as the final merge gate because they can change after the local evidence snapshot.
+- Single-request `c1` runs are flat to slightly slower (`-0.6%` to `-0.9%` in the benchmark table). DFlash should be described as a multi-active throughput path, not as an all-shape latency win.
+- Raw token-event ITL p99 can increase under short-prompt speculative decode because accepted spans emit multiple tokens in one scheduler step. Keep raw ITL visible next to effective TPOT and output throughput when reviewing tail latency.
+- The benchmark table is in-process serving evidence. A production-style HTTP pressure sweep remains useful before making broader OpenAI-compatible serving claims.
+- A few synthetic high-concurrency shapes are validated by the regret oracle or prefill-vs-decode boundary check instead of exact raw hash equality. This is covered by scheduler gates, but reviewer discussion should keep the oracle boundary explicit.
+- DFlash reserves extra memory for draft weights, draft KV/cache, verify buffers, and batch scratch. Near-window or near-memory requests may be admitted by baseline decode and rejected with DFlash enabled.
+- Unsupported modes remain intentional scope boundaries: tensor parallelism, LoRA, KV offload, decode overlap, non-greedy sampling, and logprobs use normal decode or fail closed.
diff --git a/docs/models/qwen35/roadmap.md b/docs/models/qwen35/roadmap.md index b8950bd3..fb8bd36e 100644 --- a/docs/models/qwen35/roadmap.md +++ b/docs/models/qwen35/roadmap.md @@ -1,8 +1,8 @@ # Qwen3.5-4B Roadmap -> **TL;DR:** Qwen3.5-4B is decode-correct and still improving: the decode-tuning refresh improves direct TPOT by `2.1-3.2%`, while vLLM still leads 1024/256 HTTP decode and high-concurrency throughput. Long-prompt HF logits and GSM8K gates cover the old 4096-position RoPE boundary. Remaining structural items are HND prefill staging, prefix-cache design, and the serving-level concurrency gap. +> **TL;DR:** Qwen3.5-4B is decode-correct and still improving: DFlash speculative decoding is now opt-in, default-off, and batched for multi-active decode. Same-host RTX 5090 A/B shows c4/c8/c16 gains for decode-heavy, medium, and long prompts; token sanity is exact where stable and covered by the DFlash regret oracle for bf16 near-tie cases. Remaining structural items are HND prefill staging, prefix-cache design, serving-level HTTP pressure validation for DFlash, and broader non-greedy/TP feature coverage. > -> **Last touched:** 2026-06 +> **Last touched:** 2026-07 Tracking issue: see the `[Model] Qwen3.5-4B roadmap` GitHub issue. Sibling doc: `docs/models/qwen3/roadmap.md` — batched sampling is shared and #284 now routes Qwen3.5 decode through the same compact batched sampler; Qwen3.5 now has its own model-level non-greedy behavior gate, while qwen3 keeps the sibling gate on its side. @@ -20,6 +20,7 @@ Tracking issue: see the `[Model] Qwen3.5-4B roadmap` GitHub issue. 
Sibling doc: | Admission | ✓ existing full-lifetime KV admission and explicit `Rejected` events cover impossible KV requests; #253 adds the context-window rejection reason before prefill/decode | `scheduler.rs`, `src/scheduler/plan.rs`, `docs/models/qwen35/kv-admission.md` | | Scheduler tests | Partial: current plan selection, full-lifetime admission, context-window rejection, slot assignment, and slot-compaction decisions are CPU-tested; GPU execution remains coupled to the production scheduler | `src/scheduler/plan.rs` | | Step tail | Local branch verified: #353 batches the prefill final norm/lm_head tail, samples decode/unified rows from batched logits, and keeps host full-vocab copies only for requested logprobs; HF/e2e gates pass, short-output serving A/B shows TTFT benefit, long-decode TPOT remains a no-claim diagnostic | `docs/models/qwen35/batched-step-tail.md` | +| DFlash speculative decode | Opt-in batched path: hybrid-state verify/commit covers KV + recurrent + conv state; correctness and scheduler e2e gates pass. Same-host in-process A/B improves `prompt_len=1` c4/c8/c16 by `+16.7%/+15.4%/+14.0%`, `prompt_len=1024` by `+209.9%/+168.6%/+45.3%`, and `prompt_len=4096` by `+135.9%/+35.7%/+25.6%`. | `docs/models/qwen35/dflash-speculative-decoding.md` | | TP | ✗ absent (single GPU only) | — | | Prefix cache | ✗ absent; recurrent GDR state (~48MB per boundary snapshot) makes "prefix hit" itself a design question | — | @@ -34,10 +35,11 @@ Tracking issue: see the `[Model] Qwen3.5-4B roadmap` GitHub issue. Sibling doc: ### Next -5. **Prefill full-paged migration.** Replace the HND staging copy with direct paged writes: removes the ~640MB transient and the extra D2D pass. Chain dependency: paged-direct prefill → per-token position plumbing → RoPE/context-window invariants → opens the door to prefix-cache design. -6. 
**Serving-level concurrency profiling.** Add a measured-only server-side range, then split the 1024/256 concurrency-16 gap across scheduler wait, event sync, request dispatch, and model execution. Also teach the Qwen3.5 direct decode bench to prove cached-token exclusion before it reports pure decode TPOT. -7. **Scheduler logic seam follow-through.** The current admission/slot/compaction decisions have a CPU-tested seam. Keep future admission and rejection changes in that seam instead of re-embedding them in GPU execution. -8. **Prefix-cache design note.** Linear-attention layers carry recurrent state, not KV blocks — a "prefix hit" must restore both the full-attention KV *and* a recurrent-state snapshot at a block boundary (~48MB per boundary at bf16). Whether to snapshot per block, per N blocks, or only at request end is an open trade; write the design note before any code. Depends on 5. +5. **DFlash serving pressure validation.** The batched in-process DFlash path is now positive for c4/c8/c16. Next evidence step is HTTP/OpenAI-compatible pressure with the same baseline-vs-DFlash contract, including completed/failed counts, TTFT, effective TPOT, raw ITL, and token hash sanity. +6. **Prefill full-paged migration.** Replace the HND staging copy with direct paged writes: removes the ~640MB transient and the extra D2D pass. Chain dependency: paged-direct prefill → per-token position plumbing → RoPE/context-window invariants → opens the door to prefix-cache design. +7. **Serving-level concurrency profiling.** Add a measured-only server-side range, then split the 1024/256 concurrency-16 gap across scheduler wait, event sync, request dispatch, and model execution. Also teach the Qwen3.5 direct decode bench to prove cached-token exclusion before it reports pure decode TPOT. +8. **Scheduler logic seam follow-through.** The current admission/slot/compaction decisions have a CPU-tested seam. 
Keep future admission and rejection changes in that seam instead of re-embedding them in GPU execution. +9. **Prefix-cache design note.** Linear-attention layers carry recurrent state, not KV blocks — a "prefix hit" must restore both the full-attention KV *and* a recurrent-state snapshot at a block boundary (~48MB per boundary at bf16). Whether to snapshot per block, per N blocks, or only at request end is an open trade; write the design note before any code. Depends on 6. ### Later diff --git a/openinfer-core/src/kv_pool.rs b/openinfer-core/src/kv_pool.rs index f1e170fa..2cd0e142 100644 --- a/openinfer-core/src/kv_pool.rs +++ b/openinfer-core/src/kv_pool.rs @@ -213,6 +213,20 @@ impl KvState { self.seq_len += count; } + /// Roll this request's logical KV length back to `token_count`, returning + /// any now-unused tail pages to the pool. + pub fn truncate_to(&mut self, token_count: usize) -> Result<()> { + anyhow::ensure!( + token_count <= self.seq_len, + "KvState cannot truncate from {} up to {token_count}", + self.seq_len + ); + let needed = pages_needed(token_count, self.pool.inner.layout.page_size); + self.permit.truncate(needed); + self.seq_len = token_count; + Ok(()) + } + /// Build kernel-facing metadata for this request's KV. pub fn desc(&self) -> KvDesc<'_> { KvDesc { @@ -348,12 +362,39 @@ mod tests { assert_eq!(desc.last_page_len(), 1); assert_eq!(pool.available_pages(), 2); + // Truncate back into the first page: tail page returns immediately. + kv.truncate_to(15).unwrap(); + assert_eq!(kv.seq_len(), 15); + let desc = kv.desc(); + assert_eq!(desc.num_pages(), 1); + assert_eq!(desc.last_page_len(), 15); + assert_eq!(pool.available_pages(), 3); + + // Truncate to zero releases all request pages. 
+ kv.truncate_to(0).unwrap(); + assert_eq!(kv.seq_len(), 0); + assert_eq!(kv.desc().num_pages(), 0); + assert_eq!(pool.available_pages(), 4); + // Reset returns all pages + kv.ensure_capacity(17).unwrap(); + kv.advance(17); kv.reset(); assert_eq!(kv.seq_len(), 0); assert_eq!(pool.available_pages(), 4); } + #[test] + fn kv_state_rejects_truncate_forward() { + let pool = test_pool(16, 3); + let mut kv = pool.alloc(); + kv.ensure_capacity(4).unwrap(); + kv.advance(4); + + let err = kv.truncate_to(5).unwrap_err().to_string(); + assert!(err.contains("cannot truncate from 4 up to 5")); + } + #[test] fn kv_state_out_of_pages() { // 3 pages total: 1 padding, 2 available → 32 tokens max diff --git a/openinfer-core/src/ops.rs b/openinfer-core/src/ops.rs index 389ea51a..9916f2f8 100644 --- a/openinfer-core/src/ops.rs +++ b/openinfer-core/src/ops.rs @@ -25,7 +25,7 @@ pub use openinfer_kernels::ops::{ rms_norm_gated_batch_into, rms_norm_into, rms_norm_offset_into, scale_f32_in_place, scaled_add_batch_into, scaled_add_rows_indexed_into, scaled_add_rows_into, scaled_add_rows_token_range_into, silu_mul_batch, silu_mul_batch_into, - single_prefill_nhd_noncausal_into, write_vec_into, + single_prefill_nhd_causal_window_into, single_prefill_nhd_noncausal_into, write_vec_into, }; #[cfg(not(feature = "kernel-call-trace"))] pub use openinfer_kernels::ops::{ diff --git a/openinfer-core/src/page_pool.rs b/openinfer-core/src/page_pool.rs index 2b702e23..1ff1304d 100644 --- a/openinfer-core/src/page_pool.rs +++ b/openinfer-core/src/page_pool.rs @@ -107,6 +107,24 @@ impl OwnedPagePermit { } true } + + /// Return tail pages until the permit holds exactly `new_len` pages. + /// + /// Prefix page order is preserved. Pages beyond `new_len` are returned to + /// the same pool immediately, matching the drop-time LIFO reuse order. 
+ pub fn truncate(&mut self, new_len: usize) { + debug_assert!( + new_len <= self.pages.len(), + "cannot grow an OwnedPagePermit via truncate" + ); + if new_len >= self.pages.len() { + return; + } + + let returned = self.pages.split_off(new_len); + let mut free_list = self.inner.free_list.lock(); + free_list.extend(returned.into_iter().rev()); + } } impl Drop for OwnedPagePermit { @@ -182,4 +200,27 @@ mod tests { // all 4 pages back after drop assert_eq!(pool.available_pages(), 4); } + + #[test] + fn truncate_returns_tail_pages_and_preserves_prefix() { + let pool = PagePool::new(5); + + { + let mut permit = pool.try_acquire_many(4).expect("initial acquire"); + assert_eq!( + permit.pages(), + &[PageId(0), PageId(1), PageId(2), PageId(3)] + ); + assert_eq!(pool.available_pages(), 1); + + permit.truncate(2); + assert_eq!(permit.pages(), &[PageId(0), PageId(1)]); + assert_eq!(pool.available_pages(), 3); + + let next = pool.try_acquire_many(2).expect("tail pages reusable"); + assert_eq!(next.pages(), &[PageId(2), PageId(3)]); + } + + assert_eq!(pool.available_pages(), 5); + } } diff --git a/openinfer-kernels/csrc/qwen35/prefill_attention_hd256.cu b/openinfer-kernels/csrc/qwen35/prefill_attention_hd256.cu index cc485088..4f34c909 100644 --- a/openinfer-kernels/csrc/qwen35/prefill_attention_hd256.cu +++ b/openinfer-kernels/csrc/qwen35/prefill_attention_hd256.cu @@ -168,6 +168,160 @@ __global__ void prefill_v_cache_write_hd256_paged_kernel( kv_data[dst] = v_batch[src]; } +__device__ __forceinline__ int request_for_token_hd256( + int token, + const int* __restrict__ q_indptr, + int batch_size) { + for (int i = 0; i < batch_size; ++i) { + if (token < q_indptr[i + 1]) return i; + } + return batch_size - 1; +} + +__global__ void prefill_qk_norm_rope_hd256_paged_batch_kernel( + const __nv_bfloat16* __restrict__ q_full_batch, + const __nv_bfloat16* __restrict__ k_batch, + const __nv_bfloat16* __restrict__ q_norm_weight, + const __nv_bfloat16* __restrict__ k_norm_weight, + const 
__nv_bfloat16* __restrict__ cos_cache, + const __nv_bfloat16* __restrict__ sin_cache, + __nv_bfloat16* __restrict__ q_batch_out, + __nv_bfloat16* __restrict__ kv_data, + int64_t k_offset_elems, + const int* __restrict__ page_indices, + const int* __restrict__ page_indptr, + const int* __restrict__ q_indptr, + const int* __restrict__ positions, + int num_q_heads, + int num_kv_heads, + int total_tokens, + int batch_size, + int rotary_dim, + float rms_eps, + int page_size, + int64_t stride_page +) { + int token = blockIdx.x; + int head_global = blockIdx.y; + int d = threadIdx.x; + if (token >= total_tokens) return; + + bool is_q = head_global < num_q_heads; + int head_local = is_q ? head_global : (head_global - num_q_heads); + int q_full_dim = num_q_heads * HD256 * 2; + int q_dim = num_q_heads * HD256; + int kv_dim = num_kv_heads * HD256; + + int src_offset = is_q + ? token * q_full_dim + head_local * 2 * HD256 + d + : token * kv_dim + head_local * HD256 + d; + __nv_bfloat16 x = is_q ? q_full_batch[src_offset] : k_batch[src_offset]; + const __nv_bfloat16* norm_w = is_q ? 
q_norm_weight : k_norm_weight; + + float sq = __bfloat162float(x); + sq *= sq; + float sq_sum = warp_reduce_sum(sq); + + int warp_id = d / WARP_SIZE; + int lane_id = d % WARP_SIZE; + __shared__ float warp_sums[NUM_WARPS_HD256]; + __shared__ float inv_rms; + __shared__ __nv_bfloat16 smem[HD256]; + + if (lane_id == 0) warp_sums[warp_id] = sq_sum; + __syncthreads(); + + if (d == 0) { + float total = 0.0f; + for (int i = 0; i < NUM_WARPS_HD256; i++) total += warp_sums[i]; + inv_rms = 1.0f / sqrtf(total / HD256 + rms_eps); + } + __syncthreads(); + + smem[d] = rms_norm_elem_offset_hd256(x, inv_rms, norm_w[d]); + __syncthreads(); + + int pos = positions[token]; + int half_rotary = rotary_dim / 2; + int req = request_for_token_hd256(token, q_indptr, batch_size); + int local_pos = pos % page_size; + int page_list_start = page_indptr[req]; + int page_id = page_indices[page_list_start + pos / page_size]; + + if (d < half_rotary) { + __nv_bfloat16 lo = smem[d]; + __nv_bfloat16 hi = smem[d + half_rotary]; + apply_rope_pair_hd256( + lo, + hi, + cos_cache[pos * rotary_dim + d], + sin_cache[pos * rotary_dim + d] + ); + + if (is_q) { + int dst = token * q_dim + head_local * HD256; + q_batch_out[dst + d] = lo; + q_batch_out[dst + d + half_rotary] = hi; + } else { + int64_t dst = static_cast(page_id) * stride_page + + k_offset_elems + + static_cast(local_pos) * num_kv_heads * HD256 + + static_cast(head_local) * HD256 + + d; + kv_data[dst] = lo; + kv_data[dst + half_rotary] = hi; + } + } + + if (d >= rotary_dim) { + if (is_q) { + int dst = token * q_dim + head_local * HD256; + q_batch_out[dst + d] = smem[d]; + } else { + int64_t dst = static_cast(page_id) * stride_page + + k_offset_elems + + static_cast(local_pos) * num_kv_heads * HD256 + + static_cast(head_local) * HD256 + + d; + kv_data[dst] = smem[d]; + } + } +} + +__global__ void prefill_v_cache_write_hd256_paged_batch_kernel( + const __nv_bfloat16* __restrict__ v_batch, + __nv_bfloat16* __restrict__ kv_data, + int64_t 
v_offset_elems, + const int* __restrict__ page_indices, + const int* __restrict__ page_indptr, + const int* __restrict__ q_indptr, + const int* __restrict__ positions, + int num_kv_heads, + int total_tokens, + int batch_size, + int page_size, + int64_t stride_page +) { + int token = blockIdx.x; + int kv_head = blockIdx.y; + int d = threadIdx.x; + if (token >= total_tokens) return; + + int pos = positions[token]; + int req = request_for_token_hd256(token, q_indptr, batch_size); + int page_list_start = page_indptr[req]; + int page_id = page_indices[page_list_start + pos / page_size]; + int local_pos = pos % page_size; + int kv_dim = num_kv_heads * HD256; + int src = token * kv_dim + kv_head * HD256 + d; + int64_t dst = static_cast(page_id) * stride_page + + v_offset_elems + + static_cast(local_pos) * num_kv_heads * HD256 + + static_cast(kv_head) * HD256 + + d; + kv_data[dst] = v_batch[src]; +} + __global__ void attention_gate_batch_hd256_kernel( const __nv_bfloat16* __restrict__ q_full_batch, // [q_full_dim, seq_len] __nv_bfloat16* __restrict__ attn_out, // [q_dim, seq_len] @@ -384,6 +538,74 @@ void prefill_attention_hd256_prep_paged_cuda( ); } +void prefill_attention_hd256_prep_paged_batch_cuda( + const __nv_bfloat16* q_full_batch, + const __nv_bfloat16* k_batch, + const __nv_bfloat16* v_batch, + const __nv_bfloat16* q_norm_weight, + const __nv_bfloat16* k_norm_weight, + const __nv_bfloat16* cos_cache, + const __nv_bfloat16* sin_cache, + __nv_bfloat16* q_batch_out, + __nv_bfloat16* kv_data, + int64_t k_offset_elems, + int64_t v_offset_elems, + const int* page_indices, + const int* page_indptr, + const int* q_indptr, + const int* positions, + int num_q_heads, + int num_kv_heads, + int total_tokens, + int batch_size, + int rotary_dim, + float rms_eps, + int page_size, + int64_t stride_page, + cudaStream_t stream +) { + dim3 prep_grid(total_tokens, num_q_heads + num_kv_heads); + prefill_qk_norm_rope_hd256_paged_batch_kernel<<>>( + q_full_batch, + k_batch, + 
q_norm_weight, + k_norm_weight, + cos_cache, + sin_cache, + q_batch_out, + kv_data, + k_offset_elems, + page_indices, + page_indptr, + q_indptr, + positions, + num_q_heads, + num_kv_heads, + total_tokens, + batch_size, + rotary_dim, + rms_eps, + page_size, + stride_page + ); + + dim3 v_grid(total_tokens, num_kv_heads); + prefill_v_cache_write_hd256_paged_batch_kernel<<>>( + v_batch, + kv_data, + v_offset_elems, + page_indices, + page_indptr, + q_indptr, + positions, + num_kv_heads, + total_tokens, + batch_size, + page_size, + stride_page + ); +} + void attention_gate_batch_hd256_cuda( const __nv_bfloat16* q_full_batch, __nv_bfloat16* attn_out, diff --git a/openinfer-kernels/csrc/shared/paged_attention.cu b/openinfer-kernels/csrc/shared/paged_attention.cu index 01f4e82c..8f3cdd1f 100644 --- a/openinfer-kernels/csrc/shared/paged_attention.cu +++ b/openinfer-kernels/csrc/shared/paged_attention.cu @@ -671,6 +671,71 @@ int single_prefill_nhd_noncausal_cuda( reinterpret_cast(stream))); } +int single_prefill_nhd_causal_window_cuda( + // Q and output (HiddenStates token-major: [seq_len, q_dim]) + void* q, + void* output, + // Contiguous KV cache (HiddenStates token-major: [max_seq_len, kv_dim]) + void* k_cache, + void* v_cache, + int32_t num_qo_heads, + int32_t num_kv_heads, + int32_t head_dim, + int32_t seq_len, + int32_t kv_len, + int32_t max_seq_len, + int32_t window_left, + float sm_scale, + void* stream) +{ + if (q == nullptr || output == nullptr || k_cache == nullptr || v_cache == nullptr || + num_qo_heads <= 0 || num_kv_heads <= 0 || head_dim != 128 || + seq_len <= 0 || kv_len <= 0 || max_seq_len < kv_len || window_left < 0) { + return static_cast(cudaErrorInvalidValue); + } + + uint32_t q_stride_n = num_qo_heads * head_dim; + uint32_t q_stride_h = head_dim; + uint32_t kv_stride_n = num_kv_heads * head_dim; + uint32_t kv_stride_h = head_dim; + + PrefillParamsT params( + reinterpret_cast(q), + reinterpret_cast(k_cache), + reinterpret_cast(v_cache), + 
/*maybe_custom_mask=*/nullptr, + reinterpret_cast(output), + /*lse=*/nullptr, + /*maybe_alibi_slopes=*/nullptr, + num_qo_heads, + num_kv_heads, + static_cast(seq_len), + static_cast(kv_len), + q_stride_n, + q_stride_h, + kv_stride_n, + kv_stride_h, + static_cast(head_dim), + window_left, + /*logits_soft_cap=*/0.0f, + sm_scale, + /*rope_scale=*/1.0f, + /*rope_theta=*/1e6f); + + return static_cast( + SinglePrefillWithKVCacheDispatched< + /*HEAD_DIM_QK=*/128, + /*HEAD_DIM_VO=*/128, + PosEncodingMode::kNone, + /*USE_FP16_QK_REDUCTION=*/false, + MaskMode::kCausal, + Variant, + PrefillParamsT>( + params, + /*tmp=*/nullptr, + reinterpret_cast(stream))); +} + // --------------------------------------------------------------------------- // Single-request prefill for HEAD_DIM=256 — wraps FlashInfer SinglePrefillWithKVCache. // diff --git a/openinfer-kernels/src/ffi/qwen35.rs b/openinfer-kernels/src/ffi/qwen35.rs index 74df598c..55c65b1e 100644 --- a/openinfer-kernels/src/ffi/qwen35.rs +++ b/openinfer-kernels/src/ffi/qwen35.rs @@ -31,6 +31,33 @@ unsafe extern "C" { stream: CUstream, ); + pub fn prefill_attention_hd256_prep_paged_batch_cuda( + q_full_batch: *const Half, + k_batch: *const Half, + v_batch: *const Half, + q_norm_weight: *const Half, + k_norm_weight: *const Half, + cos_cache: *const Half, + sin_cache: *const Half, + q_batch_out: *mut Half, + kv_data: *mut Half, + k_offset_elems: i64, + v_offset_elems: i64, + page_indices: *const i32, + page_indptr: *const i32, + q_indptr: *const i32, + positions: *const i32, + num_q_heads: i32, + num_kv_heads: i32, + total_tokens: i32, + batch_size: i32, + rotary_dim: i32, + rms_eps: f32, + page_size: i32, + stride_page: i64, + stream: CUstream, + ); + // Apply sigmoid(gate) from interleaved q_full onto attention output in-place. 
pub fn attention_gate_batch_hd256_cuda( q_full_batch: *const Half, diff --git a/openinfer-kernels/src/ffi/shared.rs b/openinfer-kernels/src/ffi/shared.rs index 0f9c547a..68ea36b9 100644 --- a/openinfer-kernels/src/ffi/shared.rs +++ b/openinfer-kernels/src/ffi/shared.rs @@ -449,6 +449,22 @@ unsafe extern "C" { stream: CUstream, ) -> i32; + pub fn single_prefill_nhd_causal_window_cuda( + q: *const Half, + output: *mut Half, + k_cache: *const Half, + v_cache: *const Half, + num_qo_heads: i32, + num_kv_heads: i32, + head_dim: i32, + seq_len: i32, + kv_len: i32, + max_seq_len: i32, + window_left: i32, + sm_scale: f32, + stream: CUstream, + ) -> i32; + // Paged attention decode (FlashInfer BatchDecode, no partition-KV). pub fn paged_attention_decode_cuda( q: *const Half, diff --git a/openinfer-kernels/src/ops.rs b/openinfer-kernels/src/ops.rs index a8b843f2..73a749ad 100644 --- a/openinfer-kernels/src/ops.rs +++ b/openinfer-kernels/src/ops.rs @@ -20,7 +20,8 @@ pub use attention::{ PrefillPagedPlan, dflash_qk_norm_rope_into, paged_attention_batch_decode_hd256_into, paged_attention_batch_decode_into, paged_attention_batch_decode_split_kv_into, prefill_attention_paged_into, qk_norm_partial_rope_batched_decode_hd256_into, - qk_norm_rope_batch_decode_into, single_prefill_nhd_noncausal_into, + qk_norm_rope_batch_decode_into, single_prefill_nhd_causal_window_into, + single_prefill_nhd_noncausal_into, }; #[cfg(feature = "moe")] pub use deepep::{ diff --git a/openinfer-kernels/src/ops/attention.rs b/openinfer-kernels/src/ops/attention.rs index 55b301d3..179bc92a 100644 --- a/openinfer-kernels/src/ops/attention.rs +++ b/openinfer-kernels/src/ops/attention.rs @@ -836,6 +836,70 @@ pub fn single_prefill_nhd_noncausal_into( Ok(()) } +#[allow(clippy::too_many_arguments)] +pub fn single_prefill_nhd_causal_window_into( + ctx: &DeviceContext, + q: &HiddenStates, + row_offset: usize, + q_seq_len: usize, + k_cache: &HiddenStates, + v_cache: &HiddenStates, + output: &mut HiddenStates, + 
num_q_heads: usize, + num_kv_heads: usize, + head_dim: usize, + kv_len: usize, + window_left: usize, +) -> Result<()> { + anyhow::ensure!( + head_dim == 128, + "single_prefill_nhd_causal_window_into supports head_dim=128, got {head_dim}" + ); + assert_eq!(q.hidden_dim, num_q_heads * head_dim); + assert_eq!(output.hidden_dim, q.hidden_dim); + assert_eq!(output.seq_len, q.seq_len); + assert_eq!(k_cache.hidden_dim, num_kv_heads * head_dim); + assert_eq!(v_cache.hidden_dim, k_cache.hidden_dim); + assert_eq!(v_cache.seq_len, k_cache.seq_len); + assert!(kv_len <= k_cache.seq_len); + assert!( + row_offset + q_seq_len <= q.seq_len, + "single_prefill causal-window row range [{}..{}) exceeds seq_len {}", + row_offset, + row_offset + q_seq_len, + q.seq_len + ); + + let byte_offset = (row_offset * q.hidden_dim * std::mem::size_of::()) as u64; + let (q_ptr, _gq) = q.data.device_ptr(&ctx.stream); + let q_ptr = q_ptr + byte_offset; + let (k_ptr, _gk) = k_cache.data.device_ptr(&ctx.stream); + let (v_ptr, _gv) = v_cache.data.device_ptr(&ctx.stream); + let (out_ptr, _go) = output.data.device_ptr_mut(&ctx.stream); + let out_ptr = out_ptr + byte_offset; + let result = unsafe { + ffi::single_prefill_nhd_causal_window_cuda( + q_ptr as *const ffi::Half, + out_ptr as *mut ffi::Half, + k_ptr as *const ffi::Half, + v_ptr as *const ffi::Half, + num_q_heads as i32, + num_kv_heads as i32, + head_dim as i32, + q_seq_len as i32, + kv_len as i32, + k_cache.seq_len as i32, + window_left as i32, + 1.0f32 / (head_dim as f32).sqrt(), + crate::tensor::active_cu_stream(ctx), + ) + }; + if result != 0 { + anyhow::bail!("single_prefill_nhd_causal_window_cuda failed with error {result}"); + } + Ok(()) +} + /// Batched QK RMSNorm + partial RoPE for Qwen3.5 HD256 decode. 
/// /// Reads Q from interleaved `q_full` ([q, gate] per head), writes prepared Q into `q`, diff --git a/openinfer-qwen35-4b/Cargo.toml b/openinfer-qwen35-4b/Cargo.toml index 4b413cfe..4176cc09 100644 --- a/openinfer-qwen35-4b/Cargo.toml +++ b/openinfer-qwen35-4b/Cargo.toml @@ -49,6 +49,14 @@ required-features = ["qwen35-4b"] name = "chunked_prefill" required-features = ["qwen35-4b"] +[[test]] +name = "dflash_speculative_gate" +required-features = ["qwen35-4b"] + +[[test]] +name = "dflash_speculative_perf" +required-features = ["qwen35-4b"] + [[bench]] name = "qwen35_ops" harness = false diff --git a/openinfer-qwen35-4b/src/batch_decode.rs b/openinfer-qwen35-4b/src/batch_decode.rs index a2359622..6c60b110 100644 --- a/openinfer-qwen35-4b/src/batch_decode.rs +++ b/openinfer-qwen35-4b/src/batch_decode.rs @@ -157,6 +157,16 @@ impl Qwen35Model { token_ids: &[u32], kv_states: &mut [&mut KvState], graph_state: &mut BatchDecodeGraphState, + ) -> Result<()> { + self.batch_decode_graph_with_capture(token_ids, kv_states, graph_state, None) + } + + pub(crate) fn batch_decode_graph_with_capture( + &self, + token_ids: &[u32], + kv_states: &mut [&mut KvState], + graph_state: &mut BatchDecodeGraphState, + capture_layer_ids: Option<&[usize]>, ) -> Result<()> { let bs = token_ids.len(); anyhow::ensure!(bs > 0, "batch_decode_graph requires at least one request"); @@ -166,6 +176,16 @@ impl Qwen35Model { "batch size {bs} exceeds MAX_BATCH={}", super::batch_decode_graph::MAX_BATCH ); + if let Some(ids) = capture_layer_ids { + anyhow::ensure!( + ids.windows(2).all(|pair| pair[0] < pair[1]), + "Qwen3.5 decode capture layer ids must be strictly increasing" + ); + anyhow::ensure!( + ids.iter().all(|&idx| idx < self.config.num_hidden_layers), + "Qwen3.5 decode capture layer id out of range" + ); + } let padded_bs = bucket_for(bs); @@ -182,6 +202,9 @@ impl Qwen35Model { } graph_state.buffers.set_batch_size(padded_bs); + if let Some(ids) = capture_layer_ids { + 
graph_state.buffers.captured_hidden.hidden_dim = self.config.hidden_size * ids.len(); + } // H2D: token_ids and positions — zero-padded to bucket size. let mut token_ids_padded = token_ids.to_vec(); @@ -204,18 +227,35 @@ impl Qwen35Model { let layout = *kv_states[0].layout(); let bucket_idx = BATCH_BUCKETS.iter().position(|&b| b == padded_bs).unwrap(); - // Take graphs out of graph_state to avoid split-borrow in the closure. - let mut graphs = std::mem::take(&mut graph_state.graphs); - let result = graphs[bucket_idx].run_or_capture(&self.ctx, || { - self.batch_decode_kernels_graph( - kv_buffer, - &layout, - padded_bs, - &mut graph_state.slot_states, - &mut graph_state.buffers, - ) - }); - graph_state.graphs = graphs; + let result = if let Some(capture_layer_ids) = capture_layer_ids { + let mut capture_graphs = std::mem::take(&mut graph_state.capture_graphs); + let result = capture_graphs[bucket_idx].run_or_capture(&self.ctx, || { + self.batch_decode_kernels_graph( + kv_buffer, + &layout, + padded_bs, + &mut graph_state.slot_states, + &mut graph_state.buffers, + Some(capture_layer_ids), + ) + }); + graph_state.capture_graphs = capture_graphs; + result + } else { + let mut graphs = std::mem::take(&mut graph_state.graphs); + let result = graphs[bucket_idx].run_or_capture(&self.ctx, || { + self.batch_decode_kernels_graph( + kv_buffer, + &layout, + padded_bs, + &mut graph_state.slot_states, + &mut graph_state.buffers, + None, + ) + }); + graph_state.graphs = graphs; + result + }; result } @@ -226,8 +266,10 @@ impl Qwen35Model { padded_bs: usize, slot_states: &mut [RecurrentState], bufs: &mut BatchDecodeBuffers35, + capture_layer_ids: Option<&[usize]>, ) -> Result<()> { let eps = self.config.rms_norm_eps; + let capture_layer_ids = capture_layer_ids.unwrap_or(&[]); ops::embedding_batch( &self.ctx, @@ -238,7 +280,7 @@ impl Qwen35Model { let mut linear_idx = 0usize; let mut full_idx = 0usize; - for layer in &self.layers { + for (layer_idx, layer) in 
self.layers.iter().enumerate() { ops::rms_norm_batch_offset_into( &self.ctx, &bufs.hidden, @@ -296,6 +338,15 @@ impl Qwen35Model { ); ops::add_batch_into(&self.ctx, &bufs.hidden_mid, &bufs.mlp_out, &mut bufs.hidden)?; + + if let Some(capture_slot) = capture_layer_ids.iter().position(|&idx| idx == layer_idx) { + ops::copy_hidden_rows_into( + &self.ctx, + &bufs.hidden, + &mut bufs.captured_hidden, + capture_slot * self.config.hidden_size, + )?; + } } ops::rms_norm_batch_offset_into( diff --git a/openinfer-qwen35-4b/src/batch_decode_graph.rs b/openinfer-qwen35-4b/src/batch_decode_graph.rs index 31e4712c..959b3842 100644 --- a/openinfer-qwen35-4b/src/batch_decode_graph.rs +++ b/openinfer-qwen35-4b/src/batch_decode_graph.rs @@ -59,6 +59,9 @@ pub(crate) struct BatchDecodeGraphState { pub(crate) slot_states: Vec, /// One `CudaGraphState` per BATCH_BUCKETS entry (indexed by position). pub(crate) graphs: Vec, + /// One capture-enabled decode graph per bucket. DFlash hidden capture adds + /// copy kernels to the decode body, so it must not reuse the plain graph. 
+ pub(crate) capture_graphs: Vec, } impl BatchDecodeGraphState { @@ -88,11 +91,16 @@ impl BatchDecodeGraphState { .iter() .map(|_| CudaGraphState::new()) .collect(); + let capture_graphs = BATCH_BUCKETS + .iter() + .map(|_| CudaGraphState::new()) + .collect(); Ok(Self { buffers, slot_states, graphs, + capture_graphs, }) } @@ -107,17 +115,67 @@ impl BatchDecodeGraphState { src: &RecurrentState, slot_idx: usize, ) -> Result<()> { - debug_assert!(slot_idx < MAX_BATCH, "slot_idx {slot_idx} out of range"); + anyhow::ensure!( + slot_idx < self.slot_states.len(), + "Qwen3.5 graph slot {slot_idx} exceeds capacity {}", + self.slot_states.len() + ); let dst = &mut self.slot_states[slot_idx]; - for (dst_layer, src_layer) in dst.layers.iter_mut().zip(src.layers.iter()) { - ctx.stream - .memcpy_dtod(&src_layer.state, &mut dst_layer.state) - .map_err(|e| anyhow::anyhow!("copy recurrent state to slot {slot_idx}: {e}"))?; - ctx.stream - .memcpy_dtod(&src_layer.conv_state.data, &mut dst_layer.conv_state.data) - .map_err(|e| anyhow::anyhow!("copy conv state to slot {slot_idx}: {e}"))?; - } - dst.seq_len = src.seq_len; + dst.copy_from(ctx, src) + .map_err(|e| anyhow::anyhow!("copy recurrent state to slot {slot_idx}: {e}"))?; + Ok(()) + } + + /// D2D copy slot `slot_idx` recurrent state into a standalone state. + pub(crate) fn copy_slot_to_state( + &self, + ctx: &DeviceContext, + slot_idx: usize, + dst: &mut RecurrentState, + ) -> Result<()> { + anyhow::ensure!( + slot_idx < self.slot_states.len(), + "Qwen3.5 graph slot {slot_idx} exceeds capacity {}", + self.slot_states.len() + ); + dst.copy_from(ctx, &self.slot_states[slot_idx]) + .map_err(|e| anyhow::anyhow!("copy recurrent slot {slot_idx} to state: {e}"))?; Ok(()) } + + /// D2D copy one graph slot's recurrent/conv state into another slot. 
+ pub(crate) fn copy_slot_to_slot( + &mut self, + ctx: &DeviceContext, + src_slot_idx: usize, + dst_slot_idx: usize, + ) -> Result<()> { + anyhow::ensure!( + src_slot_idx < self.slot_states.len(), + "Qwen3.5 recurrent source slot {src_slot_idx} out of range {}", + self.slot_states.len() + ); + anyhow::ensure!( + dst_slot_idx < self.slot_states.len(), + "Qwen3.5 recurrent destination slot {dst_slot_idx} out of range {}", + self.slot_states.len() + ); + if src_slot_idx == dst_slot_idx { + return Ok(()); + } + if src_slot_idx < dst_slot_idx { + let (left, right) = self.slot_states.split_at_mut(dst_slot_idx); + let src = &left[src_slot_idx]; + let dst = &mut right[0]; + dst.copy_from(ctx, src) + } else { + let (left, right) = self.slot_states.split_at_mut(src_slot_idx); + let dst = &mut left[dst_slot_idx]; + let src = &right[0]; + dst.copy_from(ctx, src) + } + .map_err(|e| { + anyhow::anyhow!("copy Qwen3.5 recurrent slot {src_slot_idx} to {dst_slot_idx}: {e}") + }) + } } diff --git a/openinfer-qwen35-4b/src/decode_buffers.rs b/openinfer-qwen35-4b/src/decode_buffers.rs index 6d8c6daf..edc241f2 100644 --- a/openinfer-qwen35-4b/src/decode_buffers.rs +++ b/openinfer-qwen35-4b/src/decode_buffers.rs @@ -21,6 +21,7 @@ pub(crate) struct BatchDecodeBuffers35 { pub(crate) act_out: HiddenStates, pub(crate) mlp_out: HiddenStates, pub(crate) logits: HiddenStates, + pub(crate) captured_hidden: HiddenStates, // Full attention [dim, batch] pub(crate) q_full: HiddenStates, @@ -90,6 +91,7 @@ impl BatchDecodeBuffers35 { act_out: HiddenStates::zeros(ctx, config.intermediate_size, bs)?, mlp_out: HiddenStates::zeros(ctx, h, bs)?, logits: HiddenStates::zeros(ctx, config.vocab_size, bs)?, + captured_hidden: HiddenStates::zeros(ctx, h * config.num_hidden_layers, bs)?, q_full: HiddenStates::zeros(ctx, q_proj_dim, bs)?, q_attn: HiddenStates::zeros(ctx, q_dim, bs)?, @@ -136,6 +138,7 @@ impl BatchDecodeBuffers35 { self.act_out.seq_len = bs; self.mlp_out.seq_len = bs; self.logits.seq_len = bs; + 
self.captured_hidden.seq_len = bs; self.q_full.seq_len = bs; self.q_attn.seq_len = bs; diff --git a/openinfer-qwen35-4b/src/dflash.rs b/openinfer-qwen35-4b/src/dflash.rs new file mode 100644 index 00000000..4850141f --- /dev/null +++ b/openinfer-qwen35-4b/src/dflash.rs @@ -0,0 +1,597 @@ +use anyhow::{Context, Result}; + +use crate::dflash::config::{DFlashConfig, DFlashLayerType}; +use crate::dflash::state::DFlashContextScratch; +use crate::weights::Qwen35Model; +use openinfer_core::ops; +use openinfer_core::tensor::{DeviceContext, DeviceMatrix, DeviceVec, HiddenStates}; + +pub(crate) mod config; +mod loading; +mod reservation; +mod scratch; +mod state; + +pub(crate) use reservation::DFlashMemoryReservation; +pub(crate) use scratch::DFlashBatchScratch; +pub(crate) use state::DFlashRequestState; + +pub(crate) struct DFlashDraftModel { + config: DFlashConfig, + layers: Vec, + norm: DeviceVec, + hidden_norm: DeviceVec, + fc: DeviceMatrix, + cos_cache: DeviceVec, + sin_cache: DeviceVec, +} + +pub(crate) struct DFlashAttention { + pub(super) qkv_proj: DeviceMatrix, + pub(super) o_proj: DeviceMatrix, + pub(super) q_norm: DeviceVec, + pub(super) k_norm: DeviceVec, +} + +pub(crate) struct DFlashMlp { + pub(super) gate_up_proj: DeviceMatrix, + pub(super) down_proj: DeviceMatrix, +} + +pub(crate) struct DFlashBlock { + pub(super) layer_type: DFlashLayerType, + pub(super) input_layernorm: DeviceVec, + pub(super) attention: DFlashAttention, + pub(super) post_attention_layernorm: DeviceVec, + pub(super) mlp: DFlashMlp, +} + +impl DFlashDraftModel { + pub(crate) fn block_size(&self) -> usize { + self.config.block_size + } + + /// Largest sequence position the draft can cache. `validate_for_target` + /// guarantees this is `>=` the target's, but the draft's per-step in-fill + /// block writes `block_size` transient positions past the committed length, + /// so the usable context is `max_position_embeddings - block_size`. 
+ pub(crate) fn max_position_embeddings(&self) -> usize { + self.config.max_position_embeddings + } + + pub(crate) fn mask_token_id(&self) -> u32 { + self.config.mask_token_id + } + + pub(crate) fn target_layer_ids(&self) -> &[usize] { + &self.config.target_layer_ids + } + + /// Anchor-first block layout (a checkpoint property, see + /// [`DFlashConfig::anchor_first`]) — drives both the verify span and the + /// draft-block slice start, independently of the markov head. + pub(crate) fn anchor_first(&self) -> bool { + self.config.anchor_first() + } + + pub(crate) fn verify_span(&self) -> usize { + if self.anchor_first() { + self.block_size() + 1 + } else { + self.block_size() + } + } + + pub(crate) fn tune_gemm_algos(&self, target: &Qwen35Model) -> Result<()> { + let ctx = target.device_ctx(); + let block_size = self.block_size().min(ops::GEMM_LT_MAX_N); + let hidden = self.config.hidden_size; + let q_dim = self.config.num_attention_heads * self.config.head_dim; + let kv_dim = self.config.num_key_value_heads * self.config.head_dim; + let context_dim = self.context_feature_dim(); + + let fc_samples = [(&self.fc, 0)]; + for n in 1..=block_size { + ops::gemm_lt_tune(ctx, &fc_samples, hidden, n)?; + } + + let kv_samples: Vec<_> = self + .layers + .iter() + .flat_map(|layer| { + [ + (&layer.attention.qkv_proj, q_dim), + (&layer.attention.qkv_proj, q_dim + kv_dim), + ] + }) + .collect(); + let min_tail_n = self.block_size() + 1; + let max_tail_n = (self.block_size() * 2).min(ops::GEMM_LT_MAX_N); + for n in min_tail_n..=max_tail_n { + ops::gemm_lt_tune(ctx, &kv_samples, kv_dim, n)?; + } + + log::info!( + "Qwen3.5 DFlash cublasLt tuned: fc M={} K={} N=1..{}, kv M={} K={} N={}..{}", + hidden, + context_dim, + block_size, + kv_dim, + hidden, + min_tail_n, + max_tail_n, + ); + Ok(()) + } + + /// Allocate the lane-level batched draft scratch once, sized for the whole + /// decode batch. The per-request `DFlashRequestState` no longer owns scratch. 
+ pub(crate) fn new_batch_scratch( + &self, + ctx: &DeviceContext, + max_decode_batch_size: usize, + ) -> Result { + DFlashBatchScratch::new(ctx, &self.config, max_decode_batch_size) + } + + pub(crate) fn new_request_state( + &self, + ctx: &DeviceContext, + max_cache_len: usize, + ) -> Result { + anyhow::ensure!( + max_cache_len <= self.config.max_position_embeddings, + "DFlash request cache length {} exceeds max_position_embeddings {}", + max_cache_len, + self.config.max_position_embeddings + ); + let kv_dim = self.config.num_key_value_heads * self.config.head_dim; + DFlashRequestState::new( + ctx, + self.layers.len(), + kv_dim, + self.context_feature_dim(), + self.config.hidden_size, + self.config.block_size, + max_cache_len, + ) + } + + pub(crate) fn append_pending_context( + &self, + ctx: &DeviceContext, + state: &mut DFlashRequestState, + captured_hidden: &HiddenStates, + token_offset: usize, + token_count: usize, + ) -> Result<()> { + anyhow::ensure!(token_count > 0, "DFlash context append needs tokens"); + anyhow::ensure!( + captured_hidden.hidden_dim == self.context_feature_dim(), + "DFlash captured hidden dim {} does not match expected {}", + captured_hidden.hidden_dim, + self.context_feature_dim() + ); + anyhow::ensure!( + token_offset + token_count <= captured_hidden.seq_len, + "DFlash captured hidden token range exceeds source" + ); + let required_committed_len = state + .committed_len + .checked_add(state.pending_context.len) + .and_then(|len| len.checked_add(token_count)) + .and_then(|len| len.checked_add(self.block_size())) + .context("DFlash pending context cache length overflow")?; + anyhow::ensure!( + required_committed_len <= state.max_cache_len, + "DFlash pending context would exceed cache: committed={}, pending={}, append={}, block={}, max={}", + state.committed_len, + state.pending_context.len, + token_count, + self.block_size(), + state.max_cache_len + ); + state.pending_context.append_from( + ctx, + captured_hidden, + token_offset, + 
token_count, + state.max_cache_len, + )?; + Ok(()) + } + + /// Batched draft forward over all active requests at once. + /// + /// The *dense* ops (embedding, rms_norm, q / o / gate_up / down GEMMs, silu, + /// add, fused_add_rms_norm, logits) run ONCE over an `active_batch * + /// block_size` batched buffer. The *varlen* ops (context projection, tail + /// concat, k/v GEMMs, rope, KV copy, attention) still loop per request, + /// slicing each request's `block_size` rows at offset `i * block_size` in the + /// batched buffers — those are Step 2/3's job to batch via CUDA-kernel changes. + /// + /// Returns the batched logits (`active_batch * block_size` rows): request `i` + /// owns rows `[i * block_size, (i + 1) * block_size)`. + pub(crate) fn draft_logits_batched<'a>( + &self, + target: &Qwen35Model, + states: &mut [&mut DFlashRequestState], + current_tokens: &[u32], + scratch: &'a mut DFlashBatchScratch, + ) -> Result<&'a HiddenStates> { + let ctx = target.device_ctx(); + let active_batch = states.len(); + anyhow::ensure!( + active_batch > 0, + "DFlash batched draft needs active requests" + ); + anyhow::ensure!( + states.len() == current_tokens.len(), + "DFlash batched draft: {} states vs {} current tokens", + states.len(), + current_tokens.len() + ); + let block_size = self.block_size(); + let batch_block_rows = active_batch * block_size; + + // Each request's committed context length for this round; advancing + // `committed_len` is deferred until after the layer loop (the rope start + // positions and KV write offsets read the pre-advance value). 
+ let mut context_lens = Vec::with_capacity(active_batch); + for (i, state) in states.iter().enumerate() { + let Some(context_len) = state.pending_context_len() else { + anyhow::bail!( + "DFlash draft requested before target hidden context is available (request slot {i})" + ); + }; + let tail_len = context_len + block_size; + anyhow::ensure!( + state.committed_len + tail_len <= state.max_cache_len, + "DFlash draft cache overflow: committed={}, tail={}, max={}", + state.committed_len, + tail_len, + state.max_cache_len + ); + context_lens.push(context_len); + } + + scratch.activate_dense(batch_block_rows); + + // Build the batched token id buffer: each request's block is + // [current_token, mask, mask, ...]. + scratch.block_token_ids_h[..batch_block_rows].fill(self.mask_token_id()); + for (i, &current_token) in current_tokens.iter().enumerate() { + scratch.block_token_ids_h[i * block_size] = current_token; + } + // token_ids_d holds `max_batch * block_size` ids; copy only the active + // prefix. The embedding kernel reads `out.seq_len = batch_block_rows` ids + // from the buffer start, so the active prefix is what it consumes. + let mut token_ids_dst = scratch.token_ids_d.slice_mut(..batch_block_rows); + ctx.stream.memcpy_htod( + &scratch.block_token_ids_h[..batch_block_rows], + &mut token_ids_dst, + )?; + target.get_embeddings_batch_into(&scratch.token_ids_d, &mut scratch.hidden)?; + + // Per-request context projection: varlen (each request's committed + // prefix differs), persisted in the request so every layer can read it. 
+ for (i, state) in states.iter_mut().enumerate() { + let context_len = context_lens[i]; + state + .context + .ensure_capacity(ctx, self.config.hidden_size, context_len)?; + state.pending_context.activate_for_read(); + self.project_context_into(ctx, &state.pending_context.buffer, &mut state.context)?; + state.pending_context.clear(); + } + + let hidden_size = self.config.hidden_size; + let q_dim = self.config.num_attention_heads * self.config.head_dim; + let kv_dim = self.config.num_key_value_heads * self.config.head_dim; + let inter_dim = self.config.intermediate_size; + debug_assert_eq!(scratch.hidden.hidden_dim, hidden_size); + debug_assert_eq!(scratch.q_batch.hidden_dim, q_dim); + debug_assert_eq!(scratch.k_tail.hidden_dim, kv_dim); + debug_assert_eq!(scratch.gate_out.hidden_dim, inter_dim); + + for (layer_idx, layer) in self.layers.iter().enumerate() { + // Dense: input layernorm over the whole batch. + ops::rms_norm_batch_into( + ctx, + &scratch.hidden, + &layer.input_layernorm, + self.config.rms_norm_eps, + &mut scratch.normed, + ); + + // Dense: Q projection over the whole batch (per-token, no cross-request + // mixing). Computed before the per-request loop reads `normed`, and + // before the post-attention norm overwrites it. + ops::gemm_rows_into( + ctx, + &layer.attention.qkv_proj, + 0, + q_dim, + &scratch.normed, + &mut scratch.q_batch, + ); + + // Per-request varlen attention: tail concat, k/v GEMMs, rope, KV copy, + // single-request prefill. Each request slices its `block_size` rows at + // offset `i * block_size` of the batched `normed`/`q_batch`/`attn_output`. + for (i, state) in states.iter_mut().enumerate() { + let context_len = context_lens[i]; + let tail_len = context_len + block_size; + let row_offset = i * block_size; + scratch.ensure_tail_capacity(ctx, &self.config, tail_len)?; + + // tail_input = [context_hidden(context_len) | normed_block(block_size)]. 
+ ops::copy_hidden_token_range_into( + ctx, + &state.context.context_hidden, + 0, + &mut scratch.tail_input, + 0, + context_len, + )?; + ops::copy_hidden_token_range_into( + ctx, + &scratch.normed, + row_offset, + &mut scratch.tail_input, + context_len, + block_size, + )?; + + ops::gemm_rows_into( + ctx, + &layer.attention.qkv_proj, + q_dim, + kv_dim, + &scratch.tail_input, + &mut scratch.k_tail, + ); + ops::gemm_rows_into( + ctx, + &layer.attention.qkv_proj, + q_dim + kv_dim, + kv_dim, + &scratch.tail_input, + &mut scratch.v_tail, + ); + + ops::dflash_qk_norm_rope_into( + ctx, + &mut scratch.q_batch, + row_offset, + block_size, + &mut scratch.k_tail, + &layer.attention.q_norm, + &layer.attention.k_norm, + &self.cos_cache, + &self.sin_cache, + self.config.num_attention_heads, + self.config.num_key_value_heads, + self.config.head_dim, + state.committed_len + context_len, + state.committed_len, + self.config.rms_norm_eps, + )?; + + let cache = &mut state.layers[layer_idx]; + ops::copy_hidden_token_range_into( + ctx, + &scratch.k_tail, + 0, + &mut cache.k, + state.committed_len, + tail_len, + )?; + ops::copy_hidden_token_range_into( + ctx, + &scratch.v_tail, + 0, + &mut cache.v, + state.committed_len, + tail_len, + )?; + let cache_len = state.committed_len + tail_len; + match layer.layer_type { + DFlashLayerType::FullAttention => { + ops::single_prefill_nhd_noncausal_into( + ctx, + &scratch.q_batch, + row_offset, + block_size, + &cache.k, + &cache.v, + &mut scratch.attn_output, + self.config.num_attention_heads, + self.config.num_key_value_heads, + self.config.head_dim, + cache_len, + )?; + } + DFlashLayerType::SlidingAttention => { + // The z-lab Qwen3.5 DFlash checkpoint is trained with the + // draft block forwarded as non-causal (`is_causal=false`). + // `layer_types=sliding_attention` is still a checkpoint + // descriptor, but applying a causal sliding mask here + // collapses acceptance against the reference drafter. 
+ ops::single_prefill_nhd_noncausal_into( + ctx, + &scratch.q_batch, + row_offset, + block_size, + &cache.k, + &cache.v, + &mut scratch.attn_output, + self.config.num_attention_heads, + self.config.num_key_value_heads, + self.config.head_dim, + cache_len, + )?; + } + } + } + + // Dense: o_proj + residual + post-attention norm + MLP over the batch. + ops::gemm_into( + ctx, + &layer.attention.o_proj, + &scratch.attn_output, + &mut scratch.o_buf, + ); + openinfer_kernels::ops::fused_add_rms_norm_round_batch_into( + ctx, + &mut scratch.hidden, + &scratch.o_buf, + &layer.post_attention_layernorm, + self.config.rms_norm_eps, + &mut scratch.normed, + )?; + + ops::gemm_rows_into( + ctx, + &layer.mlp.gate_up_proj, + 0, + inter_dim, + &scratch.normed, + &mut scratch.gate_out, + ); + ops::gemm_rows_into( + ctx, + &layer.mlp.gate_up_proj, + inter_dim, + inter_dim, + &scratch.normed, + &mut scratch.up_out, + ); + ops::silu_mul_batch_into( + ctx, + &scratch.gate_out, + &scratch.up_out, + &mut scratch.act_out, + )?; + ops::gemm_into( + ctx, + &layer.mlp.down_proj, + &scratch.act_out, + &mut scratch.o_buf, + ); + ops::add_batch_into( + ctx, + &scratch.hidden, + &scratch.o_buf, + &mut scratch.hidden_out, + )?; + std::mem::swap(&mut scratch.hidden, &mut scratch.hidden_out); + } + + for (i, state) in states.iter_mut().enumerate() { + state.committed_len += context_lens[i]; + } + self.compute_logits_with_target_head_into(target, scratch)?; + Ok(&scratch.logits) + } + + fn context_feature_dim(&self) -> usize { + self.config.hidden_size * self.target_layer_ids().len() + } + + fn project_context_into( + &self, + ctx: &DeviceContext, + context_features: &HiddenStates, + context: &mut DFlashContextScratch, + ) -> Result<()> { + ops::gemm_into( + ctx, + &self.fc, + context_features, + &mut context.context_projected, + ); + ops::rms_norm_batch_into( + ctx, + &context.context_projected, + &self.hidden_norm, + self.config.rms_norm_eps, + &mut context.context_hidden, + ); + Ok(()) + } + + fn 
compute_logits_with_target_head_into( + &self, + target: &Qwen35Model, + scratch: &mut DFlashBatchScratch, + ) -> Result<()> { + let ctx = target.device_ctx(); + ops::rms_norm_batch_into( + ctx, + &scratch.hidden, + &self.norm, + self.config.rms_norm_eps, + &mut scratch.logits_normed, + ); + ops::gemm_into( + ctx, + target.output_projection_tied(), + &scratch.logits_normed, + &mut scratch.logits, + ); + Ok(()) + } +} + +#[cfg(test)] +pub(crate) fn validate_dflash_config_for_target( + dflash_path: &str, + target_config: &crate::config::Config35, +) -> Result { + let config = DFlashConfig::from_file(dflash_path)?; + config.validate_for_target(target_config)?; + Ok(config) +} + +#[cfg(test)] +mod tests { + use super::validate_dflash_config_for_target; + use crate::config::Config35; + use std::path::Path; + + #[test] + fn env_dflash_config_matches_qwen35_target() { + let Ok(target_path) = std::env::var("OPENINFER_TEST_MODEL_PATH") else { + eprintln!("skipping DFlash config test; set OPENINFER_TEST_MODEL_PATH"); + return; + }; + let Ok(dflash_path) = std::env::var("OPENINFER_DFLASH_TEST_MODEL_PATH") else { + eprintln!("skipping DFlash config test; set OPENINFER_DFLASH_TEST_MODEL_PATH"); + return; + }; + if !Path::new(&target_path).join("config.json").exists() + || !Path::new(&dflash_path).join("config.json").exists() + { + eprintln!( + "skipping DFlash config test; set OPENINFER_TEST_MODEL_PATH and OPENINFER_DFLASH_TEST_MODEL_PATH" + ); + return; + } + + let target = Config35::from_file(&target_path).expect("target config"); + let dflash = validate_dflash_config_for_target(&dflash_path, &target) + .expect("DFlash config should match target"); + + assert!(dflash.block_size >= 2); + assert!(dflash.mask_token_id < target.vocab_size as u32); + assert_eq!(dflash.target_layer_ids.len(), dflash.num_hidden_layers); + + let reservation = + super::DFlashMemoryReservation::from_config(&dflash, /*max_decode_batch*/ 256); + assert!( + reservation.kv_bytes_per_token > 0 && 
reservation.fixed_bytes > 0, + "DFlash reservation must reserve both per-token and fixed memory" + ); + } +} diff --git a/openinfer-qwen35-4b/src/dflash/config.rs b/openinfer-qwen35-4b/src/dflash/config.rs new file mode 100644 index 00000000..f90e14f6 --- /dev/null +++ b/openinfer-qwen35-4b/src/dflash/config.rs @@ -0,0 +1,228 @@ +use anyhow::{Context, Result}; +use serde::Deserialize; +use std::fs; + +use crate::config::Config35; + +#[derive(Clone, Copy, Debug, Deserialize, Eq, PartialEq)] +#[serde(rename_all = "snake_case")] +pub(crate) enum DFlashLayerType { + SlidingAttention, + FullAttention, +} + +#[derive(Clone, Debug)] +pub(crate) struct DFlashConfig { + pub(crate) hidden_size: usize, + pub(crate) intermediate_size: usize, + pub(crate) num_hidden_layers: usize, + pub(crate) num_attention_heads: usize, + pub(crate) num_key_value_heads: usize, + pub(crate) num_target_layers: usize, + pub(crate) head_dim: usize, + pub(crate) vocab_size: usize, + pub(crate) rms_norm_eps: f32, + pub(crate) rope_theta: f32, + pub(crate) max_position_embeddings: usize, + pub(crate) block_size: usize, + pub(crate) mask_token_id: u32, + pub(crate) target_layer_ids: Vec, + pub(crate) layer_types: Vec, + pub(crate) sliding_window: usize, + pub(crate) anchor_first: bool, +} + +#[derive(Clone, Debug, Deserialize)] +struct DFlashInnerConfig { + block_size: usize, + mask_token_id: u32, + target_layer_ids: Vec, +} + +#[derive(Debug, Deserialize)] +struct RopeParameters { + rope_theta: f32, +} + +#[derive(Deserialize)] +struct RawDFlashConfig { + hidden_size: usize, + intermediate_size: usize, + num_hidden_layers: usize, + num_attention_heads: usize, + num_key_value_heads: usize, + num_target_layers: usize, + head_dim: usize, + vocab_size: usize, + rms_norm_eps: f32, + #[serde(default)] + rope_theta: Option, + #[serde(default)] + rope_parameters: Option, + #[serde(default = "default_max_position_embeddings")] + max_position_embeddings: usize, + #[serde(default)] + block_size: Option, + 
#[serde(default)] + dflash_config: Option, + #[serde(default)] + mask_token_id: Option, + #[serde(default)] + target_layer_ids: Option>, + #[serde(default)] + layer_types: Option>, + #[serde(default)] + sliding_window: Option, + #[serde(default)] + num_anchors: Option, +} + +fn default_max_position_embeddings() -> usize { + 40960 +} + +impl DFlashConfig { + pub(crate) fn from_file(model_path: &str) -> Result { + let config_path = format!("{}/config.json", model_path); + let content = fs::read_to_string(&config_path)?; + let raw: RawDFlashConfig = serde_json::from_str(&content)?; + let rope_theta = raw + .rope_theta + .or(raw.rope_parameters.map(|r| r.rope_theta)) + .context("Qwen3.5 DFlash config missing rope_theta / rope_parameters.rope_theta")?; + let (block_size, mask_token_id, target_layer_ids) = match raw.dflash_config { + Some(inner) => ( + inner.block_size, + inner.mask_token_id, + inner.target_layer_ids, + ), + None => ( + raw.block_size + .context("Qwen3.5 DFlash config missing block_size (no dflash_config block)")?, + raw.mask_token_id.context( + "Qwen3.5 DFlash config missing mask_token_id (no dflash_config block)", + )?, + raw.target_layer_ids.context( + "Qwen3.5 DFlash config missing target_layer_ids (no dflash_config block)", + )?, + ), + }; + let layer_types = raw + .layer_types + .unwrap_or_else(|| vec![DFlashLayerType::FullAttention; raw.num_hidden_layers]); + let sliding_window = raw.sliding_window.unwrap_or(raw.max_position_embeddings); + + Ok(Self { + hidden_size: raw.hidden_size, + intermediate_size: raw.intermediate_size, + num_hidden_layers: raw.num_hidden_layers, + num_attention_heads: raw.num_attention_heads, + num_key_value_heads: raw.num_key_value_heads, + num_target_layers: raw.num_target_layers, + head_dim: raw.head_dim, + vocab_size: raw.vocab_size, + rms_norm_eps: raw.rms_norm_eps, + rope_theta, + max_position_embeddings: raw.max_position_embeddings, + block_size, + mask_token_id, + target_layer_ids, + layer_types, + sliding_window, + 
anchor_first: raw.num_anchors.is_some(), + }) + } + + pub(crate) fn anchor_first(&self) -> bool { + self.anchor_first + } + + pub(crate) fn validate_for_target(&self, target: &Config35) -> Result<()> { + anyhow::ensure!( + self.hidden_size == target.hidden_size, + "Qwen3.5 DFlash hidden_size {} does not match target {}", + self.hidden_size, + target.hidden_size + ); + anyhow::ensure!( + self.num_target_layers == target.num_hidden_layers, + "Qwen3.5 DFlash num_target_layers {} does not match target layers {}", + self.num_target_layers, + target.num_hidden_layers + ); + anyhow::ensure!( + self.num_attention_heads > 0 + && self.num_key_value_heads > 0 + && self.head_dim > 0 + && self.num_attention_heads % self.num_key_value_heads == 0, + "Qwen3.5 DFlash attention geometry is invalid: heads={}, kv_heads={}, head_dim={}", + self.num_attention_heads, + self.num_key_value_heads, + self.head_dim + ); + anyhow::ensure!( + self.vocab_size == target.vocab_size, + "Qwen3.5 DFlash vocab_size {} does not match target {}", + self.vocab_size, + target.vocab_size + ); + anyhow::ensure!( + (self.rope_theta - target.rope_theta).abs() < f32::EPSILON, + "Qwen3.5 DFlash rope_theta {} does not match target {}", + self.rope_theta, + target.rope_theta + ); + anyhow::ensure!( + self.max_position_embeddings >= target.max_position_embeddings, + "Qwen3.5 DFlash max_position_embeddings {} is smaller than target {}", + self.max_position_embeddings, + target.max_position_embeddings + ); + anyhow::ensure!( + self.block_size >= 2, + "Qwen3.5 DFlash block_size must be >= 2, got {}", + self.block_size + ); + anyhow::ensure!( + self.mask_token_id < target.vocab_size as u32, + "Qwen3.5 DFlash mask_token_id {} is outside target vocab_size {}", + self.mask_token_id, + target.vocab_size + ); + anyhow::ensure!( + !self.target_layer_ids.is_empty(), + "Qwen3.5 DFlash target_layer_ids must not be empty" + ); + anyhow::ensure!( + self.target_layer_ids + .iter() + .all(|&layer| layer < 
target.num_hidden_layers), + "Qwen3.5 DFlash target_layer_ids must be within target layer count" + ); + anyhow::ensure!( + self.target_layer_ids + .windows(2) + .all(|pair| pair[0] < pair[1]), + "Qwen3.5 DFlash target_layer_ids must be strictly increasing" + ); + anyhow::ensure!( + self.layer_types.len() == self.num_hidden_layers, + "Qwen3.5 DFlash layer_types length {} does not match draft layers {}", + self.layer_types.len(), + self.num_hidden_layers + ); + if self + .layer_types + .iter() + .any(|layer| *layer == DFlashLayerType::SlidingAttention) + { + anyhow::ensure!( + self.sliding_window >= self.block_size, + "Qwen3.5 DFlash sliding_window {} must cover block_size {}", + self.sliding_window, + self.block_size + ); + } + Ok(()) + } +} diff --git a/openinfer-qwen35-4b/src/dflash/loading.rs b/openinfer-qwen35-4b/src/dflash/loading.rs new file mode 100644 index 00000000..47aba6c6 --- /dev/null +++ b/openinfer-qwen35-4b/src/dflash/loading.rs @@ -0,0 +1,145 @@ +use anyhow::{Context, Result}; +use log::debug; + +use crate::dflash::config::DFlashConfig; +use crate::dflash::{DFlashAttention, DFlashBlock, DFlashMlp}; +use crate::weights::Qwen35Model; +use openinfer_core::tensor::{DeviceContext, DeviceMatrix}; +use openinfer_core::weight_loader::{ + deserialize_shards, load_shard_info, load_tensor_1d, load_tensor_2d, mmap_shards, + precompute_rope, +}; + +use super::DFlashDraftModel; + +impl DFlashDraftModel { + pub(crate) fn from_safetensors_for_target( + ctx: &DeviceContext, + model_path: &str, + target: &Qwen35Model, + ) -> Result<Self> { + let config = DFlashConfig::from_file(model_path) + .with_context(|| format!("load DFlash config from {model_path}"))?; + config.validate_for_target(target.config())?; + + let (shard_paths, weight_map) = load_shard_info(model_path)?; + debug!( + "Loading DFlash drafter from {model_path}: {} shard(s)", + shard_paths.len() + ); + let mmaps = mmap_shards(&shard_paths)?; + let shards = deserialize_shards(&mmaps)?; + + let mut layers = 
Vec::with_capacity(config.num_hidden_layers); + for layer_idx in 0..config.num_hidden_layers { + let prefix = format!("layers.{layer_idx}"); + + let q_proj = load_tensor_2d( + ctx, + &shards, + &weight_map, + &format!("{prefix}.self_attn.q_proj.weight"), + )?; + let k_proj = load_tensor_2d( + ctx, + &shards, + &weight_map, + &format!("{prefix}.self_attn.k_proj.weight"), + )?; + let v_proj = load_tensor_2d( + ctx, + &shards, + &weight_map, + &format!("{prefix}.self_attn.v_proj.weight"), + )?; + let qkv_proj = DeviceMatrix::vstack(ctx, &[&q_proj, &k_proj, &v_proj])?; + drop(q_proj); + drop(k_proj); + drop(v_proj); + + let gate_proj = load_tensor_2d( + ctx, + &shards, + &weight_map, + &format!("{prefix}.mlp.gate_proj.weight"), + )?; + let up_proj = load_tensor_2d( + ctx, + &shards, + &weight_map, + &format!("{prefix}.mlp.up_proj.weight"), + )?; + let gate_up_proj = DeviceMatrix::vstack(ctx, &[&gate_proj, &up_proj])?; + drop(gate_proj); + drop(up_proj); + + layers.push(DFlashBlock { + layer_type: config.layer_types[layer_idx], + input_layernorm: load_tensor_1d( + ctx, + &shards, + &weight_map, + &format!("{prefix}.input_layernorm.weight"), + )?, + attention: DFlashAttention { + qkv_proj, + o_proj: load_tensor_2d( + ctx, + &shards, + &weight_map, + &format!("{prefix}.self_attn.o_proj.weight"), + )?, + q_norm: load_tensor_1d( + ctx, + &shards, + &weight_map, + &format!("{prefix}.self_attn.q_norm.weight"), + )?, + k_norm: load_tensor_1d( + ctx, + &shards, + &weight_map, + &format!("{prefix}.self_attn.k_norm.weight"), + )?, + }, + post_attention_layernorm: load_tensor_1d( + ctx, + &shards, + &weight_map, + &format!("{prefix}.post_attention_layernorm.weight"), + )?, + mlp: DFlashMlp { + gate_up_proj, + down_proj: load_tensor_2d( + ctx, + &shards, + &weight_map, + &format!("{prefix}.mlp.down_proj.weight"), + )?, + }, + }); + } + + let norm = load_tensor_1d(ctx, &shards, &weight_map, "norm.weight")?; + let hidden_norm = load_tensor_1d(ctx, &shards, &weight_map, 
"hidden_norm.weight")?; + let fc = load_tensor_2d(ctx, &shards, &weight_map, "fc.weight")?; + + let (cos_cache, sin_cache) = precompute_rope( + ctx, + config.head_dim, + config.max_position_embeddings, + config.rope_theta, + )?; + ctx.sync()?; + + Ok(Self { + config, + layers, + norm, + hidden_norm, + fc, + cos_cache, + sin_cache, + }) + } +} diff --git a/openinfer-qwen35-4b/src/dflash/reservation.rs b/openinfer-qwen35-4b/src/dflash/reservation.rs new file mode 100644 index 00000000..2f1e9eb6 --- /dev/null +++ b/openinfer-qwen35-4b/src/dflash/reservation.rs @@ -0,0 +1,84 @@ +use anyhow::Result; + +use crate::dflash::config::DFlashConfig; + +/// GPU memory DFlash needs on top of the target KV pool, derived from the draft +/// config so the KV budget can reserve it *before* the draft model loads (the +/// draft buffers live outside the paged `KvCacheManager`). Split by how it scales: +/// +/// - `kv_bytes_per_token` scales with the KV pool (billed by shrinking the target +/// block count): the draft's own KV cache plus the per-request context-projection +/// and pending-context buffers, which currently persist at prompt length per +/// request (see `dflash-speculative-decoding.md` — collapsing that persistence +/// is a tracked follow-up that would shrink this term to the draft KV alone). +/// - `fixed_bytes` does not scale with the pool (billed via the memory margin): +/// the draft weights plus the lane-level batched scratch sized for the whole +/// decode batch. +/// +/// The scratch accounting intentionally stays conservative: the live draft path +/// allocates one lane-level `DFlashBatchScratch` for dense block rows, while the +/// durable request cache still keeps per-request prefix and tail buffers. Keeping +/// `tail_scratch` in both the pool-scaled term and the one-block headroom term +/// slightly over-reserves memory, but it prevents late OOM when admission accepts +/// a request close to the context limit. 
+pub(crate) struct DFlashMemoryReservation { + pub(crate) kv_bytes_per_token: usize, + pub(crate) fixed_bytes: usize, +} + +impl DFlashMemoryReservation { + pub(crate) fn from_path(draft_path: &str, max_decode_batch_size: usize) -> Result<Self> { + let config = DFlashConfig::from_file(draft_path)?; + Ok(Self::from_config(&config, max_decode_batch_size)) + } + + pub(crate) fn from_config(config: &DFlashConfig, max_decode_batch_size: usize) -> Self { + const BF16: usize = 2; + let hidden = config.hidden_size; + let kv_dim = config.num_key_value_heads * config.head_dim; + let q_dim = config.num_attention_heads * config.head_dim; + let inter = config.intermediate_size; + let capture_layers = config.target_layer_ids.len(); + + // Per-sequence-token, pool-scaling buffers. + let draft_kv = config.num_hidden_layers * 2 * kv_dim * BF16; // DFlashLayerCache k+v + // Scratch split by what it tracks: `context_*` grows with the committed + // prefix; `tail_*` (tail_input + k_tail + v_tail) grows with the in-fill + // tail, which is one block past the prefix. + let context_scratch = 2 * hidden * BF16; // context_projected + context_hidden + let tail_scratch = (hidden + 2 * kv_dim) * BF16; // tail_input + k_tail + v_tail + let pending = hidden * capture_layers * BF16; // context_feature_dim + let kv_bytes_per_token = draft_kv + context_scratch + tail_scratch + pending; + + // Lane-level batched dense scratch: every dense buffer is sized for the + // whole decode batch (`max_batch * block_size` rows), allocated once. + // Same total magnitude as the old per-request scratch summed over the + // batch, but now one contiguous allocation. + let dense_scratch_per_block_row = + BF16 * (config.vocab_size + 5 * hidden + 2 * q_dim + 3 * inter); + let scratch_total = dense_scratch_per_block_row * config.block_size * max_decode_batch_size; + + // Draft weights (5 transformer layers + the context projection), +10% slack + // for norms, rope caches, and allocator alignment. 
+ let per_layer = BF16 + * (hidden * (q_dim + 2 * kv_dim) // qkv_proj + + q_dim * hidden // o_proj + + hidden * 2 * inter // gate_up_proj + + inter * hidden); // down_proj + let fc = BF16 * hidden * (hidden * capture_layers); // context projection + let weights = per_layer * config.num_hidden_layers + fc; + let weights = weights + weights / 10; + + // The durable draft KV and the tail scratch are sized to `context + + // block_size` — one in-fill block past the lifetime the KV pool reserves + // for the request. The per-token term bills only the pool's tokens, so + // reserve that one-block headroom per concurrently decoding request to + // keep the reservation an upper bound. + let block_headroom = max_decode_batch_size * config.block_size * (draft_kv + tail_scratch); + + Self { + kv_bytes_per_token, + fixed_bytes: weights + scratch_total + block_headroom, + } + } +} diff --git a/openinfer-qwen35-4b/src/dflash/scratch.rs b/openinfer-qwen35-4b/src/dflash/scratch.rs new file mode 100644 index 00000000..12397e48 --- /dev/null +++ b/openinfer-qwen35-4b/src/dflash/scratch.rs @@ -0,0 +1,122 @@ +use anyhow::Result; +use cudarc::driver::CudaSlice; + +use crate::dflash::config::DFlashConfig; +use openinfer_core::tensor::{DeviceContext, HiddenStates}; + +/// Lane-level batched draft scratch, allocated once for the whole decode batch. +/// +/// Dense buffers (`hidden`, `normed`, `q_batch`, `attn_output`, the MLP buffers, +/// and `logits`) hold `max_batch * block_size` rows so the GEMM / rms_norm / +/// silu / add / logits / embedding ops run once over the batched buffer. The +/// varlen tail buffers (`tail_input`, `k_tail`, `v_tail`) stay sized for a single +/// request and are reused inside the per-request loop. 
+pub(crate) struct DFlashBatchScratch { + max_batch_block_rows: usize, + max_tail_len: usize, + pub(super) block_token_ids_h: Vec<u32>, + pub(super) token_ids_d: CudaSlice<u32>, + pub(super) hidden: HiddenStates, + pub(super) hidden_out: HiddenStates, + pub(super) normed: HiddenStates, + pub(super) q_batch: HiddenStates, + pub(super) attn_output: HiddenStates, + pub(super) o_buf: HiddenStates, + pub(super) gate_out: HiddenStates, + pub(super) up_out: HiddenStates, + pub(super) act_out: HiddenStates, + pub(super) logits_normed: HiddenStates, + pub(super) logits: HiddenStates, + // Shared single-request varlen tail scratch (reused inside the per-request loop). + pub(super) tail_input: HiddenStates, + pub(super) k_tail: HiddenStates, + pub(super) v_tail: HiddenStates, +} + +impl DFlashBatchScratch { + pub(crate) fn new( + ctx: &DeviceContext, + config: &DFlashConfig, + max_decode_batch_size: usize, + ) -> Result<Self> { + anyhow::ensure!( + max_decode_batch_size > 0, + "DFlash batch scratch needs a non-zero batch size" + ); + let block_size = config.block_size; + let hidden_size = config.hidden_size; + let q_dim = config.num_attention_heads * config.head_dim; + let kv_dim = config.num_key_value_heads * config.head_dim; + let inter_dim = config.intermediate_size; + // Dense buffers span the whole decode batch so the dense ops run once. + let batch_rows = block_size * max_decode_batch_size; + // The shared varlen tail starts at one block (no committed context yet) + // and grows on demand via `ensure_tail_capacity`. 
+ let tail_capacity = block_size; + Ok(Self { + max_batch_block_rows: batch_rows, + max_tail_len: tail_capacity, + block_token_ids_h: vec![config.mask_token_id; batch_rows], + token_ids_d: ctx.stream.alloc_zeros(batch_rows)?, + hidden: HiddenStates::zeros(ctx, hidden_size, batch_rows)?, + hidden_out: HiddenStates::zeros(ctx, hidden_size, batch_rows)?, + normed: HiddenStates::zeros(ctx, hidden_size, batch_rows)?, + q_batch: HiddenStates::zeros(ctx, q_dim, batch_rows)?, + attn_output: HiddenStates::zeros(ctx, q_dim, batch_rows)?, + o_buf: HiddenStates::zeros(ctx, hidden_size, batch_rows)?, + gate_out: HiddenStates::zeros(ctx, inter_dim, batch_rows)?, + up_out: HiddenStates::zeros(ctx, inter_dim, batch_rows)?, + act_out: HiddenStates::zeros(ctx, inter_dim, batch_rows)?, + logits_normed: HiddenStates::zeros(ctx, hidden_size, batch_rows)?, + logits: HiddenStates::zeros(ctx, config.vocab_size, batch_rows)?, + tail_input: HiddenStates::zeros(ctx, hidden_size, tail_capacity)?, + k_tail: HiddenStates::zeros(ctx, kv_dim, tail_capacity)?, + v_tail: HiddenStates::zeros(ctx, kv_dim, tail_capacity)?, + }) + } + + /// Point every dense buffer at the active `batch_block_rows = active_batch * + /// block_size` prefix. Allocated for the max decode batch, so this only + /// shrinks `seq_len`; it never reallocates. 
+ pub(super) fn activate_dense(&mut self, batch_block_rows: usize) { + assert!( + batch_block_rows <= self.max_batch_block_rows, + "DFlash batched draft {} block rows exceeds scratch capacity {}", + batch_block_rows, + self.max_batch_block_rows + ); + self.hidden.seq_len = batch_block_rows; + self.hidden_out.seq_len = batch_block_rows; + self.normed.seq_len = batch_block_rows; + self.q_batch.seq_len = batch_block_rows; + self.attn_output.seq_len = batch_block_rows; + self.o_buf.seq_len = batch_block_rows; + self.gate_out.seq_len = batch_block_rows; + self.up_out.seq_len = batch_block_rows; + self.act_out.seq_len = batch_block_rows; + self.logits_normed.seq_len = batch_block_rows; + self.logits.seq_len = batch_block_rows; + } + + /// Size the shared varlen tail buffers for one request's `tail_len = + /// context_len + block_size`, growing the allocation if needed. + pub(super) fn ensure_tail_capacity( + &mut self, + ctx: &DeviceContext, + config: &DFlashConfig, + tail_len: usize, + ) -> Result<()> { + if tail_len > self.max_tail_len { + let hidden_size = config.hidden_size; + let kv_dim = config.num_key_value_heads * config.head_dim; + self.tail_input = HiddenStates::zeros(ctx, hidden_size, tail_len)?; + self.k_tail = HiddenStates::zeros(ctx, kv_dim, tail_len)?; + self.v_tail = HiddenStates::zeros(ctx, kv_dim, tail_len)?; + self.max_tail_len = tail_len; + } + self.tail_input.seq_len = tail_len; + self.k_tail.seq_len = tail_len; + self.v_tail.seq_len = tail_len; + Ok(()) + } +} diff --git a/openinfer-qwen35-4b/src/dflash/state.rs b/openinfer-qwen35-4b/src/dflash/state.rs new file mode 100644 index 00000000..d2a9f962 --- /dev/null +++ b/openinfer-qwen35-4b/src/dflash/state.rs @@ -0,0 +1,184 @@ +use anyhow::{Context, Result}; + +use openinfer_core::ops; +use openinfer_core::tensor::{DeviceContext, HiddenStates}; + +pub(crate) struct DFlashRequestState { + pub(super) layers: Vec<DFlashLayerCache>, + pub(super) pending_context: DFlashPendingContext, + /// Projected target context for the 
current draft round. Computed once from + /// `pending_context` and read by every layer's tail concat, so it lives with + /// the request (the batched scratch only holds one request's varlen tail). + pub(super) context: DFlashContextScratch, + pub(super) committed_len: usize, + pub(super) max_cache_len: usize, +} + +pub(super) struct DFlashLayerCache { + pub(super) k: HiddenStates, + pub(super) v: HiddenStates, +} + +pub(super) struct DFlashPendingContext { + pub(super) buffer: HiddenStates, + pub(super) len: usize, + capacity: usize, +} + +/// Per-request projected context. The fc projection + hidden_norm turn the +/// captured target hidden context into draft hidden space once per draft round; +/// every layer's tail concat reads `context_hidden`, so it must persist across +/// the layer loop and therefore lives in the request (not the shared scratch). +pub(super) struct DFlashContextScratch { + max_context_len: usize, + pub(super) context_projected: HiddenStates, + pub(super) context_hidden: HiddenStates, +} + +impl DFlashRequestState { + pub(crate) fn new( + ctx: &DeviceContext, + num_layers: usize, + kv_dim: usize, + context_feature_dim: usize, + hidden_size: usize, + block_size: usize, + max_cache_len: usize, + ) -> Result<Self> { + let mut layers = Vec::with_capacity(num_layers); + for _ in 0..num_layers { + layers.push(DFlashLayerCache { + k: HiddenStates::zeros(ctx, kv_dim, max_cache_len)?, + v: HiddenStates::zeros(ctx, kv_dim, max_cache_len)?, + }); + } + Ok(Self { + layers, + pending_context: DFlashPendingContext::new( + ctx, + context_feature_dim, + block_size.min(max_cache_len), + )?, + context: DFlashContextScratch::new(ctx, hidden_size, block_size)?, + committed_len: 0, + max_cache_len, + }) + } + + pub(crate) fn pending_context_len(&self) -> Option<usize> { + (self.pending_context.len > 0).then_some(self.pending_context.len) + } +} + +impl DFlashPendingContext { + fn new(ctx: &DeviceContext, hidden_dim: usize, capacity: usize) -> Result<Self> { + anyhow::ensure!( + 
capacity > 0, + "DFlash pending context capacity must be non-zero" + ); + let mut buffer = HiddenStates::zeros(ctx, hidden_dim, capacity)?; + buffer.seq_len = 0; + Ok(Self { + buffer, + len: 0, + capacity, + }) + } + + pub(super) fn append_from( + &mut self, + ctx: &DeviceContext, + src: &HiddenStates, + src_token_offset: usize, + token_count: usize, + max_capacity: usize, + ) -> Result<()> { + let required_len = self + .len + .checked_add(token_count) + .context("DFlash pending context length overflow")?; + anyhow::ensure!( + required_len <= max_capacity, + "DFlash pending context length {} exceeds request capacity {}", + required_len, + max_capacity + ); + self.ensure_capacity(ctx, required_len, max_capacity)?; + self.buffer.seq_len = self.capacity; + ops::copy_hidden_token_range_into( + ctx, + src, + src_token_offset, + &mut self.buffer, + self.len, + token_count, + )?; + self.len = required_len; + self.buffer.seq_len = self.len; + Ok(()) + } + + fn ensure_capacity( + &mut self, + ctx: &DeviceContext, + required_len: usize, + max_capacity: usize, + ) -> Result<()> { + if required_len <= self.capacity { + return Ok(()); + } + let doubled = self + .capacity + .checked_mul(2) + .context("DFlash pending context capacity overflow")?; + let new_capacity = required_len.max(doubled).min(max_capacity); + anyhow::ensure!( + new_capacity >= required_len, + "DFlash pending context capacity {} cannot fit {} tokens", + new_capacity, + required_len + ); + let mut next = HiddenStates::zeros(ctx, self.buffer.hidden_dim, new_capacity)?; + if self.len > 0 { + self.buffer.seq_len = self.capacity; + ops::copy_hidden_token_range_into(ctx, &self.buffer, 0, &mut next, 0, self.len)?; + } + next.seq_len = self.len; + self.buffer = next; + self.capacity = new_capacity; + Ok(()) + } + + pub(super) fn activate_for_read(&mut self) { + self.buffer.seq_len = self.len; + } + + pub(super) fn clear(&mut self) { + self.len = 0; + self.buffer.seq_len = 0; + } +} + +impl DFlashContextScratch { + fn 
new(ctx: &DeviceContext, hidden_size: usize, max_context_len: usize) -> Result { + Ok(Self { + max_context_len, + context_projected: HiddenStates::zeros(ctx, hidden_size, max_context_len)?, + context_hidden: HiddenStates::zeros(ctx, hidden_size, max_context_len)?, + }) + } + + pub(super) fn ensure_capacity( + &mut self, + ctx: &DeviceContext, + hidden_size: usize, + context_len: usize, + ) -> Result<()> { + if context_len > self.max_context_len { + *self = Self::new(ctx, hidden_size, context_len)?; + } + self.context_projected.seq_len = context_len; + self.context_hidden.seq_len = context_len; + Ok(()) + } +} diff --git a/openinfer-qwen35-4b/src/executor.rs b/openinfer-qwen35-4b/src/executor.rs index d4244ca1..98410a56 100644 --- a/openinfer-qwen35-4b/src/executor.rs +++ b/openinfer-qwen35-4b/src/executor.rs @@ -97,16 +97,16 @@ pub struct DecodeResult { pub requests: Vec, } -struct ActiveRequest { - request_id: RequestId, - kv: KvState, - graph_slot_idx: usize, +pub(crate) struct ActiveRequest { + pub(crate) request_id: RequestId, + pub(crate) kv: KvState, + pub(crate) graph_slot_idx: usize, } pub struct Qwen35Executor { - model: Qwen35Model, - graph_state: BatchDecodeGraphState, - active: Vec, + pub(crate) model: Qwen35Model, + pub(crate) graph_state: BatchDecodeGraphState, + pub(crate) active: Vec, } impl Qwen35Executor { @@ -300,40 +300,34 @@ impl Qwen35Executor { self.active[idx].graph_slot_idx, last ); - for layer_idx in 0..self.graph_state.slot_states[last].layers.len() { - let (src_part, dst_part) = if idx < last { - let (left, right) = self.graph_state.slot_states.split_at_mut(last); - ( - &right[0].layers[layer_idx], - &mut left[idx].layers[layer_idx], - ) - } else { - unreachable!("idx < active.len() <= last"); - }; - self.model - .device_ctx() - .stream - .memcpy_dtod(&src_part.state, &mut dst_part.state) - .map_err(|e| { - anyhow::anyhow!("compact Qwen3.5 logits executor state copy failed: {e}") - })?; - self.model - .device_ctx() - .stream - 
.memcpy_dtod(&src_part.conv_state.data, &mut dst_part.conv_state.data) - .map_err(|e| { - anyhow::anyhow!( - "compact Qwen3.5 logits executor conv_state copy failed: {e}" - ) - })?; - } - self.graph_state.slot_states[idx].seq_len = self.graph_state.slot_states[last].seq_len; + self.graph_state + .copy_slot_to_slot(self.model.device_ctx(), last, idx)?; self.active[idx].graph_slot_idx = idx; } Ok(()) } } +#[derive(Clone, Debug, Eq, PartialEq)] +pub struct ExecutorStateSummary { + pub request_id: RequestId, + pub kv_seq_len: usize, + pub recurrent_seq_len: usize, +} + +impl Qwen35Executor { + pub fn debug_state_summary(&self) -> Vec<ExecutorStateSummary> { + self.active + .iter() + .map(|active| ExecutorStateSummary { + request_id: active.request_id, + kv_seq_len: active.kv.seq_len(), + recurrent_seq_len: self.graph_state.slot_states[active.graph_slot_idx].seq_len, + }) + .collect() + } +} + fn select_default_tokens_from_logits( model: &Qwen35Model, logits: &HiddenStates, diff --git a/openinfer-qwen35-4b/src/kernel_plan.rs b/openinfer-qwen35-4b/src/kernel_plan.rs index 6a00321e..12d7195f 100644 --- a/openinfer-qwen35-4b/src/kernel_plan.rs +++ b/openinfer-qwen35-4b/src/kernel_plan.rs @@ -57,7 +57,7 @@ pub static KERNEL_PLAN: KernelPlan = KernelPlan { // shared prefill prologue KernelOp { id: "embedding_prefill", - rust: "prefill::prefill_last_hidden -> ops::embedding_batch", + rust: "prefill::prefill_chunk_forward_with_capture -> ops::embedding_batch", backend: "CUDA", notes: "prompt tokens to hidden states", }, diff --git a/openinfer-qwen35-4b/src/lib.rs b/openinfer-qwen35-4b/src/lib.rs index cda0d183..22f542e6 100644 --- a/openinfer-qwen35-4b/src/lib.rs +++ b/openinfer-qwen35-4b/src/lib.rs @@ -10,6 +10,7 @@ mod batch_decode; pub(crate) mod batch_decode_graph; pub(crate) mod config; mod decode_buffers; +mod dflash; mod executor; mod ffi; mod logprobs; @@ -19,10 +20,13 @@ pub mod prefill_buffers; pub(crate) mod recurrent; pub(crate) mod recurrent_state; mod scheduler; +pub mod speculative; 
mod unified_forward; +mod verify_buffers; mod weights; use std::path::Path; +use std::path::PathBuf; use anyhow::{Result, anyhow}; use openinfer_core::engine::{EngineHandle, EngineLoadOptions, EpBackend}; @@ -37,10 +41,14 @@ pub use scheduler::DEFAULT_MAX_PREFILL_TOKENS; pub mod runtime { pub use crate::batch_decode_graph::MAX_BATCH; pub use crate::executor::{ - DecodePlan, DecodeRequestResult, DecodeResult, DecodeStepItem, PrefillPlan, - PrefillRequestResult, PrefillResult, PrefillStepItem, Qwen35Executor, RequestId, + DecodePlan, DecodeRequestResult, DecodeResult, DecodeStepItem, ExecutorStateSummary, + PrefillPlan, PrefillRequestResult, PrefillResult, PrefillStepItem, Qwen35Executor, + RequestId, }; pub use crate::scheduler::start_with_capacity; + pub use crate::speculative::{ + VerifiedToken, VerifyPlan, VerifyRequestResult, VerifyResult, VerifyStepItem, + }; pub use crate::weights::Qwen35Model; } @@ -68,8 +76,9 @@ pub fn launch( device_ordinal: usize, cuda_graph: bool, max_prefill_tokens: usize, + dflash_draft_model_path: Option<PathBuf>, ) -> Result<EngineHandle> { - start_engine_with_capacity( + start_engine_with_capacity_and_dflash( model_path, EngineLoadOptions { enable_cuda_graph: cuda_graph, @@ -81,6 +90,7 @@ pub fn launch( }, batch_decode_graph::MAX_BATCH, max_prefill_tokens, + dflash_draft_model_path, ) } @@ -89,6 +99,16 @@ pub fn start_engine_with_capacity( options: EngineLoadOptions, max_batch: usize, max_prefill_tokens: usize, +) -> Result<EngineHandle> { + start_engine_with_capacity_and_dflash(model_path, options, max_batch, max_prefill_tokens, None) +} + +pub fn start_engine_with_capacity_and_dflash( + model_path: &Path, + options: EngineLoadOptions, + max_batch: usize, + max_prefill_tokens: usize, + dflash_draft_model_path: Option<PathBuf>, ) -> Result<EngineHandle> { let EngineLoadOptions { enable_cuda_graph, @@ -109,10 +129,27 @@ pub fn start_engine_with_capacity( let model_path = model_path .to_str() .ok_or_else(|| anyhow!("model path must be valid UTF-8"))?; - let model = 
weights::Qwen35Model::from_safetensors_with_device_options( + let dflash_path = dflash_draft_model_path + .as_ref() + .map(|path| { + path.to_str() + .ok_or_else(|| anyhow!("DFlash draft model path must be valid UTF-8")) + }) + .transpose()?; + let dflash_reservation = dflash_path + .map(|path| dflash::DFlashMemoryReservation::from_path(path, max_batch)) + .transpose()?; + let model = weights::Qwen35Model::from_safetensors_with_device_options_and_reservation( model_path, enable_cuda_graph, device_ordinal, + dflash_reservation.as_ref(), )?; - scheduler::start_with_capacity(model, seed, max_batch, max_prefill_tokens) + scheduler::start_with_capacity_and_dflash( + model, + seed, + max_batch, + max_prefill_tokens, + dflash_draft_model_path, + ) } diff --git a/openinfer-qwen35-4b/src/ops.rs b/openinfer-qwen35-4b/src/ops.rs index 120688df..f04463b9 100644 --- a/openinfer-qwen35-4b/src/ops.rs +++ b/openinfer-qwen35-4b/src/ops.rs @@ -2,10 +2,10 @@ pub(crate) use openinfer_core::ops::PrefillPagedPlan; pub(crate) use openinfer_core::ops::{ - GEMM_LT_MAX_N, add_batch, add_batch_into, embedding_batch, extract_vec, extract_vec_into, gemm, - gemm_into, gemm_lt_tune, paged_attention_batch_decode_hd256_into, - qk_norm_partial_rope_batched_decode_hd256_into, rms_norm_gated_batch_into, - silu_mul_fused_batch_into, write_vec_into, + GEMM_LT_MAX_N, add_batch, add_batch_into, copy_hidden_rows_into, copy_hidden_token_range_into, + embedding_batch, extract_vec, extract_vec_into, gemm, gemm_into, gemm_lt_tune, + paged_attention_batch_decode_hd256_into, qk_norm_partial_rope_batched_decode_hd256_into, + rms_norm_gated_batch_into, silu_mul_fused_batch_into, write_vec_into, }; pub use openinfer_core::ops::{rms_norm_batch_offset_into, rms_norm_offset_into}; pub use recurrent::gated_delta_rule_prefill_chunkwise_into; diff --git a/openinfer-qwen35-4b/src/prefill.rs b/openinfer-qwen35-4b/src/prefill.rs index a8af94f6..d0444bb4 100644 --- a/openinfer-qwen35-4b/src/prefill.rs +++ 
b/openinfer-qwen35-4b/src/prefill.rs @@ -6,7 +6,7 @@ use cudarc::driver::{CudaSlice, DevicePtr, DevicePtrMut}; /// This is not an admission cap. Actual prompt admission is governed by the /// paged KV pool, RoPE cache coverage, and allocation success. Prompts longer /// than this are handled by chunking prefill at `PREFILL_CHUNK_LEN` rather than -/// being rejected (see `prefill_last_hidden`). +/// being rejected (see `prefill_chunk_forward_with_capture`). pub(crate) const SCRATCH_ESTIMATE_SEQ: usize = 20_000; /// Maximum number of tokens processed in a single prefill forward pass. @@ -20,6 +20,7 @@ const HEAD_DIM: usize = 256; use super::prefill_buffers::GdrChunkwiseScratch35; use super::recurrent_state::RecurrentState; +use super::verify_buffers::VerifyBuffers35; use super::weights::{ FullAttentionLayer, LayerKind, LinearAttentionLayer, Qwen35Model, TransformerBlock35, }; @@ -45,45 +46,6 @@ fn checked_prefill_end_pos( } impl Qwen35Model { - pub(super) fn prefill_last_hidden( - &self, - token_ids: &[u32], - kv_state: &mut KvState, - recurrent: &mut RecurrentState, - ) -> Result { - let seq_len = token_ids.len(); - anyhow::ensure!( - seq_len > 0, - "Qwen3.5 prefill_last_hidden requires at least one token" - ); - let c = &self.config; - - // Validate the full target range up front (position overflow + RoPE cache - // coverage) so an out-of-range prompt is rejected before any chunk mutates - // the KV / recurrent state, rather than failing partway through. - let base_pos = kv_state.seq_len(); - let end_pos = checked_prefill_end_pos(base_pos, seq_len, c.max_position_embeddings)?; - self.ensure_rope_cache_covers(end_pos)?; - - // Run prefill in serial chunks of at most `PREFILL_CHUNK_LEN` tokens. Each - // chunk advances the paged KV and linear-attention recurrent/conv state in - // place, so the next chunk continues from the previous one. 
This caps the - // per-pass GDR scratch (which grows with the pass length) at the budget - // reserved at startup, so prompts longer than one chunk prefill without OOM. - let mut hidden_batch: Option = None; - for chunk in token_ids.chunks(PREFILL_CHUNK_LEN) { - // Free the previous chunk's hidden states before allocating the next - // chunk's scratch so peak memory stays within one chunk's reservation. - drop(hidden_batch.take()); - hidden_batch = Some(self.prefill_chunk_forward(chunk, kv_state, recurrent)?); - } - // `seq_len > 0` guarantees at least one chunk produced hidden states. - let hidden_batch = hidden_batch.expect("prefill produced no chunk despite seq_len > 0"); - - // Last-token logic runs once, on the final chunk's output. - ops::extract_vec(&self.ctx, &hidden_batch, hidden_batch.seq_len - 1) - } - pub(super) fn batch_last_hidden_logits( &self, last_hiddens: &[DeviceVec], @@ -115,6 +77,62 @@ impl Qwen35Model { Ok(logits) } + pub(crate) fn hidden_logits(&self, hidden_batch: &HiddenStates) -> Result { + anyhow::ensure!( + hidden_batch.seq_len > 0, + "Qwen3.5 hidden_logits requires at least one row" + ); + let mut normed = + HiddenStates::zeros(&self.ctx, hidden_batch.hidden_dim, hidden_batch.seq_len)?; + ops::rms_norm_batch_offset_into( + &self.ctx, + hidden_batch, + &self.norm, + self.config.rms_norm_eps, + &mut normed, + )?; + ops::gemm(&self.ctx, &self.embed_tokens, &normed) + } + + pub(crate) fn prefill_logits_all( + &self, + token_ids: &[u32], + kv_state: &mut KvState, + recurrent: &mut RecurrentState, + ) -> Result { + anyhow::ensure!( + token_ids.len() <= PREFILL_CHUNK_LEN, + "Qwen3.5 all-position prefill logits only supports one chunk: requested {}, max {}", + token_ids.len(), + PREFILL_CHUNK_LEN + ); + let hidden = self.prefill_chunk_forward(token_ids, kv_state, recurrent)?; + self.hidden_logits(&hidden) + } + + #[allow(dead_code)] + pub(crate) fn prefill_logits_all_with_capture( + &self, + token_ids: &[u32], + kv_state: &mut KvState, + 
recurrent: &mut RecurrentState, + capture_layer_ids: Option<&[usize]>, + ) -> Result<(HiddenStates, Option)> { + anyhow::ensure!( + token_ids.len() <= PREFILL_CHUNK_LEN, + "Qwen3.5 all-position prefill logits only supports one chunk: requested {}, max {}", + token_ids.len(), + PREFILL_CHUNK_LEN + ); + let (hidden, captured) = self.prefill_chunk_forward_with_capture( + token_ids, + kv_state, + recurrent, + capture_layer_ids, + )?; + Ok((self.hidden_logits(&hidden)?, captured)) + } + /// Forward one prefill chunk through all layers, advancing the paged KV state /// and the linear-attention recurrent/conv state in place. /// @@ -127,6 +145,17 @@ impl Qwen35Model { kv_state: &mut KvState, recurrent: &mut RecurrentState, ) -> Result { + self.prefill_chunk_forward_with_capture(token_ids, kv_state, recurrent, None) + .map(|(hidden, _)| hidden) + } + + pub(crate) fn prefill_chunk_forward_with_capture( + &self, + token_ids: &[u32], + kv_state: &mut KvState, + recurrent: &mut RecurrentState, + capture_layer_ids: Option<&[usize]>, + ) -> Result<(HiddenStates, Option)> { let seq_len = token_ids.len(); debug_assert!( seq_len > 0 && seq_len <= PREFILL_CHUNK_LEN, @@ -172,6 +201,28 @@ impl Qwen35Model { c.head_dim, )?; + let capture_layer_ids = capture_layer_ids.unwrap_or(&[]); + anyhow::ensure!( + capture_layer_ids.windows(2).all(|pair| pair[0] < pair[1]), + "Qwen3.5 DFlash capture layer ids must be strictly increasing" + ); + anyhow::ensure!( + capture_layer_ids + .iter() + .all(|&layer_idx| layer_idx < self.config.num_hidden_layers), + "Qwen3.5 DFlash capture layer id out of range" + ); + let mut captured_hidden = if capture_layer_ids.is_empty() { + None + } else { + Some(HiddenStates::zeros( + &self.ctx, + c.hidden_size * capture_layer_ids.len(), + seq_len, + )?) 
+ }; + let mut next_capture = 0usize; + // Process layers let mut linear_idx = 0usize; let mut full_idx = 0usize; @@ -188,13 +239,143 @@ impl Qwen35Model { &prefill_plan, recurrent, )?; + if capture_layer_ids.get(next_capture) == Some(&layer_idx) { + let out = captured_hidden + .as_mut() + .expect("capture buffer exists when ids are non-empty"); + ops::copy_hidden_rows_into( + &self.ctx, + &hidden_batch, + out, + next_capture * c.hidden_size, + )?; + next_capture += 1; + } } // Advance recurrent token count for the next chunk / decode step; the // paged KV position is tracked by `kv_state` (advanced above). recurrent.seq_len += seq_len; - Ok(hidden_batch) + Ok((hidden_batch, captured_hidden)) + } + + pub(crate) fn prefill_verify_into( + &self, + spans: &[&[u32]], + kv_states: &mut [&mut KvState], + recurrent_states: &mut [&mut RecurrentState], + capture_layer_ids: &[usize], + bufs: &mut VerifyBuffers35, + ) -> Result<()> { + anyhow::ensure!(!spans.is_empty(), "Qwen3.5 verify needs at least one span"); + anyhow::ensure!( + spans.len() == kv_states.len() && spans.len() == recurrent_states.len(), + "Qwen3.5 verify spans/KV/recurrent mismatch: spans={}, kv={}, recurrent={}", + spans.len(), + kv_states.len(), + recurrent_states.len() + ); + anyhow::ensure!( + spans.len() <= bufs.max_batch(), + "Qwen3.5 verify batch {} exceeds buffer capacity {}", + spans.len(), + bufs.max_batch() + ); + anyhow::ensure!( + capture_layer_ids.windows(2).all(|pair| pair[0] < pair[1]), + "Qwen3.5 verify capture layer ids must be strictly increasing" + ); + anyhow::ensure!( + capture_layer_ids + .iter() + .all(|&layer_idx| layer_idx < self.config.num_hidden_layers), + "Qwen3.5 verify capture layer id out of range" + ); + anyhow::ensure!( + bufs.captured_hidden.hidden_dim + == self.config.hidden_size * capture_layer_ids.len().max(1), + "Qwen3.5 verify capture buffer dimension mismatch" + ); + for span in spans { + anyhow::ensure!( + !span.is_empty() && span.len() <= PREFILL_CHUNK_LEN, + 
"Qwen3.5 verify span len {} out of range", + span.len() + ); + } + + let total_rows = bufs.stage_tokens(&self.ctx, spans)?; + let seq_lens: Vec = spans.iter().map(|span| span.len()).collect(); + let start_positions: Vec = kv_states.iter().map(|kv| kv.seq_len()).collect(); + for (kv, (&base_pos, &seq_len)) in kv_states + .iter_mut() + .zip(start_positions.iter().zip(seq_lens.iter())) + { + let end_pos = + checked_prefill_end_pos(base_pos, seq_len, self.config.max_position_embeddings)?; + self.ensure_rope_cache_covers(end_pos)?; + kv.ensure_capacity(end_pos)?; + kv.advance(seq_len); + } + + let page_indices: Vec> = + kv_states.iter().map(|kv| kv.page_indices_i32()).collect(); + let last_page_lens: Vec = kv_states.iter().map(|kv| kv.last_page_len()).collect(); + bufs.plan.update_batch_with_cta_tile_q( + &self.ctx, + &page_indices, + &last_page_lens, + &start_positions, + &seq_lens, + self.config.num_attention_heads, + self.config.num_key_value_heads, + self.config.head_dim, + 0, + )?; + + ops::embedding_batch( + &self.ctx, + &self.embed_tokens, + &bufs.token_ids_d, + &mut bufs.hidden, + )?; + + let mut linear_idx = 0usize; + let mut full_idx = 0usize; + for (layer_idx, layer) in self.layers.iter().enumerate() { + self.prefill_verify_layer_into( + layer_idx, + layer, + &seq_lens, + kv_states, + recurrent_states, + &mut linear_idx, + &mut full_idx, + capture_layer_ids, + bufs, + )?; + } + + for (recurrent, &seq_len) in recurrent_states.iter_mut().zip(seq_lens.iter()) { + recurrent.seq_len += seq_len; + } + + ops::rms_norm_batch_offset_into( + &self.ctx, + &bufs.hidden, + &self.norm, + self.config.rms_norm_eps, + &mut bufs.logits_normed, + )?; + ops::gemm_into( + &self.ctx, + &self.embed_tokens, + &bufs.logits_normed, + &mut bufs.logits, + ); + debug_assert_eq!(bufs.logits.seq_len, total_rows); + Ok(()) } /// Process one layer during prefill. Returns updated hidden_batch. 
@@ -265,6 +446,321 @@ impl Qwen35Model { ops::add_batch(&self.ctx, &hidden_plus_attn, &mlp_out) } + #[allow(clippy::too_many_arguments)] + fn prefill_verify_layer_into( + &self, + layer_idx: usize, + layer: &TransformerBlock35, + seq_lens: &[usize], + kv_states: &[&mut KvState], + recurrent_states: &mut [&mut RecurrentState], + linear_idx: &mut usize, + full_idx: &mut usize, + capture_layer_ids: &[usize], + bufs: &mut VerifyBuffers35, + ) -> Result<()> { + let eps = self.config.rms_norm_eps; + ops::rms_norm_batch_offset_into( + &self.ctx, + &bufs.hidden, + &layer.input_layernorm, + eps, + &mut bufs.normed, + )?; + + match &layer.attn { + LayerKind::FullAttention(attn) => { + self.prefill_verify_full_attention_into( + attn, seq_lens, kv_states, *full_idx, bufs, + )?; + *full_idx += 1; + } + LayerKind::LinearAttention(attn) => { + self.prefill_verify_linear_attention_into( + attn, + seq_lens, + recurrent_states, + *linear_idx, + bufs, + )?; + *linear_idx += 1; + } + } + + ops::add_batch_into( + &self.ctx, + &bufs.hidden, + &bufs.attn_results, + &mut bufs.hidden_mid, + )?; + ops::rms_norm_batch_offset_into( + &self.ctx, + &bufs.hidden_mid, + &layer.post_attention_layernorm, + eps, + &mut bufs.normed, + )?; + ops::gemm_into( + &self.ctx, + &layer.mlp.gate_up_proj, + &bufs.normed, + &mut bufs.gate_up_out, + ); + ops::silu_mul_fused_batch_into(&self.ctx, &bufs.gate_up_out, &mut bufs.act_out)?; + ops::gemm_into( + &self.ctx, + &layer.mlp.down_proj, + &bufs.act_out, + &mut bufs.mlp_out, + ); + ops::add_batch_into( + &self.ctx, + &bufs.hidden_mid, + &bufs.mlp_out, + &mut bufs.hidden_next, + )?; + std::mem::swap(&mut bufs.hidden, &mut bufs.hidden_next); + + if let Some(slot) = capture_layer_ids.iter().position(|&idx| idx == layer_idx) { + ops::copy_hidden_rows_into( + &self.ctx, + &bufs.hidden, + &mut bufs.captured_hidden, + slot * self.config.hidden_size, + )?; + } + Ok(()) + } + + #[allow(clippy::too_many_arguments)] + fn prefill_verify_full_attention_into( + &self, + 
attn: &FullAttentionLayer, + _seq_lens: &[usize], + kv_states: &[&mut KvState], + full_idx: usize, + bufs: &mut VerifyBuffers35, + ) -> Result<()> { + let c = &self.config; + let eps = c.rms_norm_eps; + ops::gemm_into(&self.ctx, &attn.q_proj, &bufs.normed, &mut bufs.q_full); + ops::gemm_into(&self.ctx, &attn.k_proj, &bufs.normed, &mut bufs.k_full); + ops::gemm_into(&self.ctx, &attn.v_proj, &bufs.normed, &mut bufs.v_full); + + let layout = kv_states[0].layout(); + let layer_k_off = (full_idx * layout.layer_stride) as i64; + let layer_v_off = layer_k_off + layout.kv_block_len as i64; + let stride_page = layout.page_stride as i64; + unsafe { + let (qf_ptr, _) = bufs.q_full.data.device_ptr(&self.ctx.stream); + let (k_ptr, _) = bufs.k_full.data.device_ptr(&self.ctx.stream); + let (v_ptr, _) = bufs.v_full.data.device_ptr(&self.ctx.stream); + let (qn_ptr, _) = attn.q_norm.data.device_ptr(&self.ctx.stream); + let (kn_ptr, _) = attn.k_norm.data.device_ptr(&self.ctx.stream); + let (cos_ptr, _) = self.cos_cache.data.device_ptr(&self.ctx.stream); + let (sin_ptr, _) = self.sin_cache.data.device_ptr(&self.ctx.stream); + let (qp_ptr, _) = bufs.q_prepped.data.device_ptr_mut(&self.ctx.stream); + let (buf_ptr, _) = kv_states[0].buffer().device_ptr(&self.ctx.stream); + let (pi_ptr, _) = bufs.plan.page_indices_d().device_ptr(&self.ctx.stream); + let (pip_ptr, _) = bufs.plan.page_indptr_d().device_ptr(&self.ctx.stream); + let (qi_ptr, _) = bufs.plan.q_indptr_d().device_ptr(&self.ctx.stream); + let (pos_ptr, _) = bufs.plan.positions_d().device_ptr(&self.ctx.stream); + ffi::prefill_attention_hd256_prep_paged_batch_cuda( + qf_ptr as *const ffi::Half, + k_ptr as *const ffi::Half, + v_ptr as *const ffi::Half, + qn_ptr as *const ffi::Half, + kn_ptr as *const ffi::Half, + cos_ptr as *const ffi::Half, + sin_ptr as *const ffi::Half, + qp_ptr as *mut ffi::Half, + buf_ptr as *mut ffi::Half, + layer_k_off, + layer_v_off, + pi_ptr as *const i32, + pip_ptr as *const i32, + qi_ptr as *const i32, + 
pos_ptr as *const i32, + c.num_attention_heads as i32, + c.num_key_value_heads as i32, + bufs.q_prepped.seq_len as i32, + kv_states.len() as i32, + c.rotary_dim as i32, + eps, + layout.page_size as i32, + stride_page, + self.ctx.stream.cu_stream(), + ); + } + + let sm_scale = 1.0f32 / f32::sqrt(HEAD_DIM as f32); + { + let (buf_ptr, _gbuf) = kv_states[0].buffer().device_ptr(&self.ctx.stream); + let (qp_ptr, _gqp) = bufs.q_prepped.data.device_ptr(&self.ctx.stream); + let (out_ptr, _go) = bufs.attn_out_full.data.device_ptr_mut(&self.ctx.stream); + let (pi_ptr, _gpi) = bufs.plan.page_indices_d().device_ptr(&self.ctx.stream); + let (pip_ptr, _gpip) = bufs.plan.page_indptr_d().device_ptr(&self.ctx.stream); + let (lpl_ptr, _glpl) = bufs.plan.last_page_len_d().device_ptr(&self.ctx.stream); + let (qi_ptr, _gqi) = bufs.plan.q_indptr_d().device_ptr(&self.ctx.stream); + let (ri_ptr, _gri) = bufs.plan.request_indices_d().device_ptr(&self.ctx.stream); + let (qti_ptr, _gqti) = bufs.plan.qo_tile_indices_d().device_ptr(&self.ctx.stream); + let (kti_ptr, _gkti) = bufs.plan.kv_tile_indices_d().device_ptr(&self.ctx.stream); + let (kcs_ptr, _gkcs) = bufs.plan.kv_chunk_size_d().device_ptr(&self.ctx.stream); + let (tnr_ptr, _gtnr) = bufs.plan.total_num_rows_d().device_ptr(&self.ctx.stream); + let result = unsafe { + ffi::batch_prefill_paged_cuda_hd256( + qp_ptr as *const ffi::Half, + out_ptr as *mut ffi::Half, + buf_ptr as *const ffi::Half, + layer_k_off, + layer_v_off, + pi_ptr as *const i32, + pip_ptr as *const i32, + lpl_ptr as *const i32, + qi_ptr as *const i32, + ri_ptr as *const i32, + qti_ptr as *const i32, + kti_ptr as *const i32, + kcs_ptr as *const i32, + tnr_ptr as *const u32, + c.num_attention_heads as i32, + c.num_key_value_heads as i32, + HEAD_DIM as i32, + layout.page_size as i32, + bufs.q_prepped.seq_len as i32, + bufs.plan.batch_size(), + bufs.plan.num_tiles(), + stride_page, + sm_scale, + self.ctx.stream.cu_stream(), + ) + }; + anyhow::ensure!( + result == 0, + "Qwen3.5 
verify batch_prefill_paged_cuda_hd256 failed: {result}" + ); + } + + unsafe { + let (qf_ptr, _gqf) = bufs.q_full.data.device_ptr(&self.ctx.stream); + let (out_ptr, _go) = bufs.attn_out_full.data.device_ptr_mut(&self.ctx.stream); + ffi::attention_gate_batch_hd256_cuda( + qf_ptr as *const ffi::Half, + out_ptr as *mut ffi::Half, + c.num_attention_heads as i32, + bufs.logits.seq_len as i32, + self.ctx.stream.cu_stream(), + ); + } + ops::gemm_into( + &self.ctx, + &attn.o_proj, + &bufs.attn_out_full, + &mut bufs.attn_results, + ); + Ok(()) + } + + fn prefill_verify_linear_attention_into( + &self, + attn: &LinearAttentionLayer, + seq_lens: &[usize], + recurrent_states: &mut [&mut RecurrentState], + linear_idx: usize, + bufs: &mut VerifyBuffers35, + ) -> Result<()> { + let c = &self.config; + ops::gemm_into(&self.ctx, &attn.in_proj_qkv, &bufs.normed, &mut bufs.qkv); + ops::gemm_into(&self.ctx, &attn.in_proj_z, &bufs.normed, &mut bufs.z); + ops::gemm_into(&self.ctx, &attn.in_proj_b, &bufs.normed, &mut bufs.b_proj); + ops::gemm_into(&self.ctx, &attn.in_proj_a, &bufs.normed, &mut bufs.a_proj); + + let mut row_offset = 0usize; + for (recurrent, &seq_len) in recurrent_states.iter_mut().zip(seq_lens.iter()) { + let layer_state = &mut recurrent.layers[linear_idx]; + bufs.set_compact_rows(seq_len); + + ops::copy_hidden_token_range_into( + &self.ctx, + &bufs.qkv, + row_offset, + &mut bufs.compact_qkv, + 0, + seq_len, + )?; + + ops::conv1d_prefill_batch_into( + &self.ctx, + &bufs.compact_qkv, + &attn.conv1d_weight, + &mut layer_state.conv_state, + &mut bufs.compact_qkv_conv, + c.linear_conv_kernel_dim, + ); + + ops::copy_hidden_token_range_into( + &self.ctx, + &bufs.b_proj, + row_offset, + &mut bufs.compact_b, + 0, + seq_len, + )?; + ops::copy_hidden_token_range_into( + &self.ctx, + &bufs.a_proj, + row_offset, + &mut bufs.compact_a, + 0, + seq_len, + )?; + + ops::gated_delta_rule_prefill_chunkwise_into( + &self.ctx, + &bufs.compact_qkv_conv, + &bufs.compact_b, + &bufs.compact_a, + 
&attn.dt_bias, + &attn.a_log, + &mut layer_state.state, + &mut bufs.gdr_scratch, + &mut bufs.compact_gdr, + c.linear_num_key_heads, + c.linear_num_value_heads, + c.linear_key_head_dim, + c.linear_value_head_dim, + )?; + ops::copy_hidden_token_range_into( + &self.ctx, + &bufs.compact_gdr, + 0, + &mut bufs.gdr_out, + row_offset, + seq_len, + )?; + row_offset += seq_len; + } + bufs.gdr_scratch.set_rows(bufs.qkv.seq_len); + + ops::rms_norm_gated_batch_into( + &self.ctx, + &bufs.gdr_out, + &attn.norm_weight, + &bufs.z, + &mut bufs.normed_gated, + c.linear_num_value_heads, + c.linear_value_head_dim, + c.rms_norm_eps, + ); + ops::gemm_into( + &self.ctx, + &attn.out_proj, + &bufs.normed_gated, + &mut bufs.attn_results, + ); + Ok(()) + } + #[allow(clippy::too_many_arguments)] fn prefill_full_attention( &self, diff --git a/openinfer-qwen35-4b/src/prefill_buffers.rs b/openinfer-qwen35-4b/src/prefill_buffers.rs index 81caa092..649c47a4 100644 --- a/openinfer-qwen35-4b/src/prefill_buffers.rs +++ b/openinfer-qwen35-4b/src/prefill_buffers.rs @@ -44,6 +44,10 @@ pub struct GdrChunkwiseScratch35 { /// Per-chunk recurrent state snapshots, fp32: [num_chunks, num_value_heads, key_dim, value_dim] pub chunk_state: CudaSlice, + max_seq_len: usize, + num_value_heads: usize, + key_dim: usize, + value_dim: usize, } impl GdrChunkwiseScratch35 { @@ -103,9 +107,39 @@ impl GdrChunkwiseScratch35 { u: HiddenStates::zeros(ctx, vv_hidden_dim, seq_len)?, v_new: HiddenStates::zeros(ctx, vv_hidden_dim, seq_len)?, chunk_state, + max_seq_len: seq_len, + num_value_heads, + key_dim, + value_dim, }) } + pub(crate) fn set_rows(&mut self, seq_len: usize) { + assert!( + seq_len <= self.max_seq_len, + "Qwen3.5 GDR scratch rows {seq_len} exceeds capacity {}", + self.max_seq_len + ); + self.q_expanded.seq_len = seq_len; + self.k_expanded.seq_len = seq_len; + self.v_raw.seq_len = seq_len; + self.w.seq_len = seq_len; + self.u.seq_len = seq_len; + self.v_new.seq_len = seq_len; + } + + pub(crate) fn 
gate_capacity(&self) -> usize { + self.max_seq_len * self.num_value_heads + } + + pub(crate) fn chunk_a_capacity(&self) -> usize { + self.max_seq_len * self.num_value_heads * Self::CHUNK_SIZE + } + + pub(crate) fn chunk_state_capacity(&self) -> usize { + Self::num_chunks(self.max_seq_len) * self.num_value_heads * self.value_dim * self.key_dim + } + pub fn num_chunks(seq_len: usize) -> usize { seq_len.div_ceil(Self::CHUNK_SIZE) } diff --git a/openinfer-qwen35-4b/src/recurrent.rs b/openinfer-qwen35-4b/src/recurrent.rs index 9bc207fa..34f04ff6 100644 --- a/openinfer-qwen35-4b/src/recurrent.rs +++ b/openinfer-qwen35-4b/src/recurrent.rs @@ -398,11 +398,14 @@ pub fn gated_delta_rule_prefill_chunkwise_into( let expected_chunk_ai_len = expected_chunk_a_len; let expected_chunk_state_len = GdrChunkwiseScratch35::num_chunks(qkv.seq_len) * num_value_heads * val_dim * key_dim; - assert_eq!(scratch.g_cumsum.len(), expected_gate_len); - assert_eq!(scratch.beta.len(), expected_gate_len); - assert_eq!(scratch.a_tril.len(), expected_chunk_a_len); - assert_eq!(scratch.a_inv.len(), expected_chunk_ai_len); - assert_eq!(scratch.chunk_state.len(), expected_chunk_state_len); + assert!(scratch.g_cumsum.len() >= expected_gate_len); + assert!(scratch.beta.len() >= expected_gate_len); + assert!(scratch.a_tril.len() >= expected_chunk_a_len); + assert!(scratch.a_inv.len() >= expected_chunk_ai_len); + assert!(scratch.chunk_state.len() >= expected_chunk_state_len); + debug_assert!(scratch.gate_capacity() >= expected_gate_len); + debug_assert!(scratch.chunk_a_capacity() >= expected_chunk_a_len); + debug_assert!(scratch.chunk_state_capacity() >= expected_chunk_state_len); gated_delta_rule_prefill_chunk_prepare_into( ctx, diff --git a/openinfer-qwen35-4b/src/recurrent_state.rs b/openinfer-qwen35-4b/src/recurrent_state.rs index df32960f..3ca3077c 100644 --- a/openinfer-qwen35-4b/src/recurrent_state.rs +++ b/openinfer-qwen35-4b/src/recurrent_state.rs @@ -52,4 +52,38 @@ impl RecurrentState { Ok(Self { 
layers, seq_len: 0 }) } + + /// D2D copy all recurrent and convolution state from `src`. + pub(crate) fn copy_from(&mut self, ctx: &DeviceContext, src: &RecurrentState) -> Result<()> { + anyhow::ensure!( + self.layers.len() == src.layers.len(), + "Qwen3.5 recurrent copy layer mismatch: dst={}, src={}", + self.layers.len(), + src.layers.len() + ); + for (layer_idx, (dst_layer, src_layer)) in + self.layers.iter_mut().zip(src.layers.iter()).enumerate() + { + anyhow::ensure!( + dst_layer.state.len() == src_layer.state.len(), + "Qwen3.5 recurrent state length mismatch at layer {layer_idx}: dst={}, src={}", + dst_layer.state.len(), + src_layer.state.len() + ); + anyhow::ensure!( + dst_layer.conv_state.len == src_layer.conv_state.len, + "Qwen3.5 conv state length mismatch at layer {layer_idx}: dst={}, src={}", + dst_layer.conv_state.len, + src_layer.conv_state.len + ); + ctx.stream + .memcpy_dtod(&src_layer.state, &mut dst_layer.state) + .map_err(|e| anyhow::anyhow!("copy recurrent state layer {layer_idx}: {e}"))?; + ctx.stream + .memcpy_dtod(&src_layer.conv_state.data, &mut dst_layer.conv_state.data) + .map_err(|e| anyhow::anyhow!("copy conv state layer {layer_idx}: {e}"))?; + } + self.seq_len = src.seq_len; + Ok(()) + } } diff --git a/openinfer-qwen35-4b/src/scheduler.rs b/openinfer-qwen35-4b/src/scheduler.rs index 85bba5eb..2e5e64a3 100644 --- a/openinfer-qwen35-4b/src/scheduler.rs +++ b/openinfer-qwen35-4b/src/scheduler.rs @@ -6,6 +6,8 @@ mod plan; +use std::collections::HashMap; +use std::path::PathBuf; use std::sync::mpsc as std_mpsc; use std::thread; @@ -16,8 +18,11 @@ use rand::rngs::StdRng; use tokio::sync::mpsc; use crate::batch_decode_graph::BatchDecodeGraphState; +use crate::dflash::{DFlashBatchScratch, DFlashDraftModel, DFlashRequestState}; use crate::logprobs::snapshot_requested_logprobs; use crate::recurrent_state::RecurrentState; +use crate::speculative::{VerifiedToken, accept_greedy}; +use crate::verify_buffers::VerifyBuffers35; use 
crate::weights::Qwen35Model; use openinfer_core::engine::{ EngineHandle as SchedulerHandle, FinishReason, GenerateRequest as SchedulerRequest, KvCapacity, @@ -33,11 +38,18 @@ use self::plan::{ slot_for_new_request, }; +const DFLASH_MIN_STATS_DRAFT_TOKENS: usize = 4; +const DFLASH_MIN_ACCEPT_RATE: f64 = 0.15; +const DFLASH_PROBE_DRAFT_TOKENS: usize = 4; +const DFLASH_FULL_BLOCK_ACCEPT_RATE: f64 = 0.60; +const DFLASH_MIN_CONTEXT_TOKENS: usize = 16; + // ── Internal types ────────────────────────────────────────────────────── /// An in-flight request being decoded. Recurrent state lives in the /// `BatchDecodeGraphState` at `graph_slot_idx` — NOT owned here. struct ActiveRequest35 { + local_id: usize, request_id: Option, token_tx: TokenSink, kv: KvState, @@ -52,10 +64,128 @@ struct ActiveRequest35 { logprobs: usize, } +struct DFlashSchedulerState { + model: DFlashDraftModel, + requests: HashMap, + stats: HashMap, + scratch: DFlashBatchScratch, + sample: openinfer_sample::SampleScratch, + verify_bufs: VerifyBuffers35, + backup_states: Vec, + verify_scratch_states: Vec, + commit_scratch_states: Vec, + verified_draft_tokens: usize, + accepted_draft_tokens: usize, + commit_stats: DFlashCommitStats, +} + +#[derive(Clone, Copy, Default)] +struct DFlashRequestStats { + verified_draft_tokens: usize, + accepted_draft_tokens: usize, +} + +#[derive(Clone, Copy, Default)] +struct DFlashCommitStats { + full_span_rounds: usize, + same_partial_rounds: usize, + batched_replay_rounds: usize, + batched_replay_rows: usize, + prefix_state_rounds: usize, + prefix_state_rows: usize, +} + +impl DFlashSchedulerState { + fn new(target: &Qwen35Model, draft_path: &str, max_batch: usize) -> Result { + let model = + DFlashDraftModel::from_safetensors_for_target(target.device_ctx(), draft_path, target)?; + if std::env::var_os("OPENINFER_QWEN35_DFLASH_TUNE_GEMM").is_some() { + model.tune_gemm_algos(target)?; + } + let scratch = model.new_batch_scratch(target.device_ctx(), max_batch)?; + let sample 
= openinfer_sample::SampleScratch::new( + target.device_ctx(), + target.config().vocab_size, + max_batch * model.verify_span(), + )?; + let verify_bufs = VerifyBuffers35::new( + target.device_ctx(), + target.config(), + max_batch, + model.verify_span(), + model.target_layer_ids().len(), + target.kv_pool().capacity_pages(), + )?; + Ok(Self { + model, + requests: HashMap::new(), + stats: HashMap::new(), + scratch, + sample, + verify_bufs, + backup_states: Vec::new(), + verify_scratch_states: Vec::new(), + commit_scratch_states: Vec::new(), + verified_draft_tokens: 0, + accepted_draft_tokens: 0, + commit_stats: DFlashCommitStats::default(), + }) + } + + fn capture_layer_ids(&self) -> &[usize] { + self.model.target_layer_ids() + } + + fn usable_context_tokens(&self, target_max_position_embeddings: usize) -> usize { + target_max_position_embeddings.min( + self.model + .max_position_embeddings() + .saturating_sub(self.model.block_size()), + ) + } + + fn drop_request(&mut self, local_id: usize) { + self.requests.remove(&local_id); + self.stats.remove(&local_id); + } + + fn pending_context_len(&self, local_id: usize) -> Option { + self.requests + .get(&local_id) + .and_then(DFlashRequestState::pending_context_len) + } + + fn ready_for_draft(&self, local_id: usize) -> bool { + self.pending_context_len(local_id) + .is_some_and(|len| len >= DFLASH_MIN_CONTEXT_TOKENS) + } + + fn ensure_state_scratch( + &mut self, + ctx: &openinfer_core::tensor::DeviceContext, + config: &crate::config::Config35, + batch: usize, + ) -> Result<()> { + while self.backup_states.len() < batch { + self.backup_states.push(RecurrentState::new(ctx, config)?); + } + while self.verify_scratch_states.len() < batch { + self.verify_scratch_states + .push(RecurrentState::new(ctx, config)?); + } + while self.commit_scratch_states.len() < batch { + self.commit_scratch_states + .push(RecurrentState::new(ctx, config)?); + } + Ok(()) + } +} + /// A request whose prompt is being prefilled across multiple scheduler 
steps. /// It owns its growing KV and recurrent state until the prompt is exhausted, /// at which point it is promoted into the decode batch. struct PrefillingRequest35 { + local_id: usize, req: SchedulerRequest, kv: KvState, rec: RecurrentState, @@ -78,6 +208,16 @@ pub fn start_with_capacity( seed: u64, max_batch: usize, max_prefill_tokens: usize, +) -> Result { + start_with_capacity_and_dflash(model, seed, max_batch, max_prefill_tokens, None) +} + +pub(crate) fn start_with_capacity_and_dflash( + model: Qwen35Model, + seed: u64, + max_batch: usize, + max_prefill_tokens: usize, + dflash_draft_model_path: Option, ) -> Result { assert!( max_prefill_tokens > 0, @@ -101,8 +241,32 @@ pub fn start_with_capacity( .name("scheduler-qwen35".into()) .spawn(move || match bind_model_thread(&model) { Ok(_guard) => { + let dflash = match dflash_draft_model_path + .as_ref() + .map(|path| { + path.to_str() + .ok_or_else(|| { + anyhow::anyhow!("DFlash draft model path must be valid UTF-8") + }) + .and_then(|path| DFlashSchedulerState::new(&model, path, max_batch)) + }) + .transpose() + { + Ok(dflash) => dflash, + Err(err) => { + let _ = startup_tx.send(Err(err)); + return; + } + }; let _ = startup_tx.send(Ok(())); - scheduler_loop(model, graph_state, submit_rx, seed, max_prefill_tokens); + scheduler_loop( + model, + graph_state, + submit_rx, + seed, + max_prefill_tokens, + dflash, + ); } Err(err) => { let _ = startup_tx.send(Err(err)); @@ -175,12 +339,14 @@ fn scheduler_loop( mut submit_rx: mpsc::UnboundedReceiver, seed: u64, prefill_budget: usize, + mut dflash: Option, ) { let mut rng = StdRng::seed_from_u64(seed); let mut active: Vec = Vec::new(); let mut deferred: Vec = Vec::new(); let mut prefilling: Vec = Vec::new(); let max_batch = graph_state.slot_states.len(); + let mut next_local_id = 0usize; info!("scheduler ready (max_batch={})", max_batch); @@ -238,7 +404,11 @@ fn scheduler_loop( // KvPool capacity includes the CUDA Graph padding page reserved at // construction, so a 
real request can use at most the remaining pages. model.kv_pool().capacity_pages().saturating_sub(1), - model.config().max_position_embeddings, + dflash + .as_ref() + .map_or(model.config().max_position_embeddings, |state| { + state.usable_context_tokens(model.config().max_position_embeddings) + }), |req| req.prompt_tokens.len(), |req| req.max_tokens, ); @@ -256,6 +426,13 @@ fn scheduler_loop( ); match RecurrentState::new(model.device_ctx(), model.config()) { Ok(rec) => prefilling.push(PrefillingRequest35 { + local_id: { + let id = next_local_id; + next_local_id = next_local_id + .checked_add(1) + .expect("Qwen3.5 scheduler local request id exhausted"); + id + }, kv: model.alloc_kv(), rec, cursor: 0, @@ -278,7 +455,15 @@ fn scheduler_loop( // 5. Take this step's budgeted prefill chunk off the front of the queue, // then dispatch by plan. let scheduled = take_prefill_chunks(&mut prefilling, prefill_budget); - if let Some(plan) = plan::build_next_plan(!active.is_empty(), scheduled) { + let may_capture_dflash_prefill = true; + let force_prefill_for_dflash = dflash.is_some() + && may_capture_dflash_prefill + && scheduled + .iter() + .any(|pending| should_capture_dflash_prefill_context(&pending.req)); + if let Some(plan) = + plan::build_next_plan(!active.is_empty() && !force_prefill_for_dflash, scheduled) + { match plan { ExecutionPlan::Unified { pending } => unified_step_sched( &model, @@ -287,6 +472,7 @@ fn scheduler_loop( &mut prefilling, &mut graph_state, &mut rng, + dflash.as_mut(), ), ExecutionPlan::Prefill { pending } => prefill_batch( &model, @@ -295,9 +481,23 @@ fn scheduler_loop( &mut prefilling, &mut graph_state, &mut rng, + dflash.as_mut(), ), ExecutionPlan::Decode => { - decode_step(&model, &mut active, &mut graph_state, &mut rng); + if !decode_step_speculative( + &model, + &mut active, + &mut graph_state, + dflash.as_mut(), + ) { + decode_step( + &model, + &mut active, + &mut graph_state, + &mut rng, + dflash.as_mut(), + ); + } } } } @@ -336,16 +536,30 @@ 
fn prefill_batch( prefilling: &mut Vec, graph_state: &mut BatchDecodeGraphState, rng: &mut StdRng, + dflash: Option<&mut DFlashSchedulerState>, ) { let mut chunk = ScheduledChunk::from(scheduled); + let may_capture_dflash_prefill = true; + let should_capture_dflash = dflash.is_some() + && may_capture_dflash_prefill + && chunk.reqs.iter().any(should_capture_dflash_prefill_context); + let capture_layer_ids = dflash + .as_ref() + .filter(|_| should_capture_dflash) + .map(|d| d.capture_layer_ids()); // Scope the borrows of `chunk` to the executor call so the error path can // move `chunk` into `fail_chunk`. let result = { let window_refs: Vec<&[u32]> = chunk.windows.iter().map(|w| w.as_slice()).collect(); let mut rec_refs: Vec<&mut RecurrentState> = chunk.recs.iter_mut().collect(); - model.batch_prefill_logits(&window_refs, &mut chunk.kvs, &mut rec_refs) + model.batch_prefill_logits_with_capture( + &window_refs, + &mut chunk.kvs, + &mut rec_refs, + capture_layer_ids, + ) }; - let logits = match result { + let (logits, captured_hidden) = match result { Ok(v) => v, Err(e) => { warn!("batch prefill failed: {e}"); @@ -353,6 +567,22 @@ fn prefill_batch( return; } }; + let mut dflash = dflash; + if let Some(dflash) = dflash.as_mut() { + if should_capture_dflash { + if let Err(e) = + record_dflash_prefill_context(model, &mut chunk, dflash, captured_hidden.as_ref()) + { + warn!("DFlash prefill context failed: {e}"); + fail_chunk(chunk, &e.to_string()); + return; + } + } else { + for local_id in &chunk.local_ids { + dflash.drop_request(*local_id); + } + } + } let (tokens, logprobs_vec) = match sample_prefill_logits(model, &chunk.reqs, &logits, graph_state, rng) { @@ -372,6 +602,7 @@ fn prefill_batch( chunk, &tokens, &logprobs_vec, + dflash, ); } @@ -414,6 +645,62 @@ fn sample_prefill_logits( Ok((tokens, logprobs)) } +fn record_dflash_prefill_context( + model: &Qwen35Model, + chunk: &mut ScheduledChunk, + dflash: &mut DFlashSchedulerState, + captured_hidden: 
Option<&HiddenStates>, +) -> Result<()> { + let captured_hidden = captured_hidden.ok_or_else(|| { + anyhow::anyhow!("DFlash prefill capture requested but no hidden returned") + })?; + let expected_tokens: usize = chunk.windows.iter().map(Vec::len).sum(); + anyhow::ensure!( + captured_hidden.seq_len == expected_tokens, + "Qwen3.5 DFlash captured {} hidden rows for {} scheduled tokens", + captured_hidden.seq_len, + expected_tokens + ); + let mut token_offset = 0usize; + for (i, req) in chunk.reqs.iter().enumerate() { + let local_id = chunk.local_ids[i]; + let chunk_start = chunk.ends[i] - chunk.windows[i].len(); + let capture_supported = req.logprobs == 0 && !req.echo && req.params.is_greedy(); + if capture_supported { + let max_cache_len = + (req.prompt_tokens.len() + req.max_tokens + dflash.model.block_size()) + .min(dflash.model.max_position_embeddings()); + let mut state = match dflash.requests.remove(&local_id) { + Some(state) => state, + None => dflash + .model + .new_request_state(model.device_ctx(), max_cache_len)?, + }; + let pending_len = state.pending_context_len().unwrap_or(0); + anyhow::ensure!( + pending_len == chunk_start, + "Qwen3.5 DFlash prefill context for local request {local_id} is discontinuous: pending={pending_len}, chunk_start={chunk_start}" + ); + dflash.model.append_pending_context( + model.device_ctx(), + &mut state, + captured_hidden, + token_offset, + chunk.windows[i].len(), + )?; + dflash.requests.insert(local_id, state); + } else { + dflash.requests.remove(&local_id); + } + token_offset += chunk.windows[i].len(); + } + Ok(()) +} + +fn should_capture_dflash_prefill_context(req: &SchedulerRequest) -> bool { + req.logprobs == 0 && !req.echo && req.params.is_greedy() +} + // ── Unified step (prefill chunk + decode in one forward pass) ────────────── fn unified_step_sched( @@ -423,6 +710,7 @@ fn unified_step_sched( prefilling: &mut Vec, graph_state: &mut BatchDecodeGraphState, rng: &mut StdRng, + mut dflash: Option<&mut 
DFlashSchedulerState>, ) { let mut chunk = ScheduledChunk::from(scheduled); // Scope the borrows of `chunk` / `active` to the executor call so the error @@ -461,7 +749,16 @@ fn unified_step_sched( // Process decode results FIRST (it may retire requests and free graph slots // that promotion then fills densely). if output.decoded { - process_decode_logits(model, active, graph_state, rng); + // Unified decode currently reuses the normal graph decode path, which does + // not capture hidden context for DFlash. Drop any active DFlash state before + // dispatch so a mixed prefill/decode step cannot leave a one-token gap in + // the drafter context. + if let Some(dflash) = dflash.as_mut() { + for req in active.iter() { + dflash.drop_request(req.local_id); + } + } + process_decode_logits(model, active, graph_state, rng, dflash); } let prefill_logits = output @@ -486,6 +783,7 @@ fn unified_step_sched( chunk, &tokens, &logprobs_vec, + None, ); } @@ -496,11 +794,32 @@ fn decode_step( active: &mut Vec, graph_state: &mut BatchDecodeGraphState, rng: &mut StdRng, + mut dflash: Option<&mut DFlashSchedulerState>, ) { let token_ids: Vec = active.iter().map(|r| r.last_token).collect(); + let may_capture_dflash_decode = true; + let capture_layer_ids = dflash + .as_ref() + .filter(|d| { + may_capture_dflash_decode + && active + .iter() + .any(|req| d.requests.contains_key(&req.local_id)) + }) + .map(|d| d.capture_layer_ids().to_vec()); let mut kv_refs: Vec<&mut KvState> = active.iter_mut().map(|r| &mut r.kv).collect(); - if let Err(e) = model.batch_decode_graph(&token_ids, &mut kv_refs, graph_state) { + let decode_result = if let Some(capture_layer_ids) = capture_layer_ids.as_deref() { + model.batch_decode_graph_with_capture( + &token_ids, + &mut kv_refs, + graph_state, + Some(capture_layer_ids), + ) + } else { + model.batch_decode_graph(&token_ids, &mut kv_refs, graph_state) + }; + if let Err(e) = decode_result { warn!("batch_decode_graph error: {e}"); let message = e.to_string(); 
for req in active.drain(..) { @@ -565,7 +884,568 @@ fn decode_step( }) .collect(); - dispatch_decode_tokens(model, active, &tokens, &logprobs_vec, graph_state); + if let Some(dflash) = dflash.as_mut() { + if capture_layer_ids.is_some() { + if let Err(e) = record_dflash_decode_context( + model, + active, + dflash, + &graph_state.buffers.captured_hidden, + ) { + warn!("DFlash decode context failed: {e}"); + for req in active.iter() { + dflash.drop_request(req.local_id); + } + } + } + } + dispatch_decode_tokens(model, active, &tokens, &logprobs_vec, graph_state, dflash); +} + +fn decode_step_speculative( + model: &Qwen35Model, + active: &mut Vec, + graph_state: &mut BatchDecodeGraphState, + dflash: Option<&mut DFlashSchedulerState>, +) -> bool { + let Some(dflash) = dflash else { + return false; + }; + // Multi-active Qwen3.5 DFlash is enabled for greedy/logprobs-free rows once + // each row has enough captured hidden context for the drafter. + if active.iter().any(|req| { + req.logprobs != 0 + || !req.params.is_greedy() + || req.max_tokens.saturating_sub(req.generated_count) <= 1 + || !dflash.ready_for_draft(req.local_id) + }) { + return false; + } + let draft_spans = match execute_dflash_draft(model, active, dflash) { + Ok(spans) => spans, + Err(e) => { + warn!("Qwen3.5 DFlash draft failed, falling back to decode: {e}"); + return false; + } + }; + let verify = match verify_dflash_spans(model, active, graph_state, dflash, &draft_spans) { + Ok(verify) => verify, + Err(e) => { + warn!("Qwen3.5 DFlash verify failed: {e}"); + let message = e.to_string(); + for req in active.drain(..) 
{ + dflash.drop_request(req.local_id); + let _ = req.token_tx.send(TokenEvent::Error { + message: message.clone(), + prompt_tokens: req.prompt_len, + completion_tokens: req.generated_count, + }); + } + return true; + } + }; + dispatch_speculative_tokens(model, active, &verify, graph_state, dflash); + true +} + +fn record_dflash_decode_context( + model: &Qwen35Model, + active: &[ActiveRequest35], + dflash: &mut DFlashSchedulerState, + captured_hidden: &HiddenStates, +) -> Result<()> { + anyhow::ensure!( + captured_hidden.seq_len >= active.len(), + "Qwen3.5 DFlash decode captured {} rows for {} active requests", + captured_hidden.seq_len, + active.len() + ); + for (slot_idx, req) in active.iter().enumerate() { + if req.logprobs != 0 || !req.params.is_greedy() { + dflash.drop_request(req.local_id); + continue; + } + let Some(state) = dflash.requests.get_mut(&req.local_id) else { + continue; + }; + dflash.model.append_pending_context( + model.device_ctx(), + state, + captured_hidden, + slot_idx, + 1, + )?; + } + Ok(()) +} + +fn execute_dflash_draft( + model: &Qwen35Model, + active: &[ActiveRequest35], + dflash: &mut DFlashSchedulerState, +) -> Result>> { + let block_size = dflash.model.block_size(); + let current_tokens: Vec = active.iter().map(|req| req.last_token).collect(); + let mut taken = Vec::with_capacity(active.len()); + for req in active { + let state = dflash + .requests + .remove(&req.local_id) + .ok_or_else(|| anyhow::anyhow!("missing Qwen3.5 DFlash state for {}", req.local_id))?; + taken.push((req.local_id, state)); + } + let result = (|| -> Result>> { + let mut state_refs: Vec<&mut DFlashRequestState> = + taken.iter_mut().map(|(_, state)| state).collect(); + let logits = dflash.model.draft_logits_batched( + model, + &mut state_refs, + &current_tokens, + &mut dflash.scratch, + )?; + anyhow::ensure!( + logits.seq_len == active.len() * block_size, + "Qwen3.5 DFlash logits rows {} != active {} x block {}", + logits.seq_len, + active.len(), + block_size + ); + let
greedy = SamplingParams::default(); + let params: Vec<&SamplingParams> = vec![&greedy; logits.seq_len]; + let sampled = openinfer_sample::select_batch( + model.device_ctx(), + logits, + &params, + 0, + &mut dflash.sample, + )?; + let drafts_start = if dflash.model.anchor_first() { 0 } else { 1 }; + let mut spans = Vec::with_capacity(active.len()); + for (i, req) in active.iter().enumerate() { + let remaining = req.max_tokens.saturating_sub(req.generated_count); + if remaining == 0 { + spans.push(vec![req.last_token]); + continue; + } + let block = &sampled[i * block_size..(i + 1) * block_size]; + let drafts = &block[drafts_start..]; + // `accept_greedy` can commit one target bonus token beyond accepted + // drafts. Keep speculative verify disabled for the final output token + // so KV/recurrent state cannot advance past the completion budget. + let draft_budget = remaining.saturating_sub(1); + if draft_budget == 0 { + spans.push(vec![req.last_token]); + continue; + } + let draft_limit = + dflash_verify_draft_limit(dflash, req.local_id, drafts.len()).min(draft_budget); + let mut span = Vec::with_capacity(drafts.len() + 1); + span.push(req.last_token); + span.extend(drafts.iter().take(draft_limit).copied()); + spans.push(span); + } + Ok(spans) + })(); + for (local_id, state) in taken { + dflash.requests.insert(local_id, state); + } + result +} + +fn dflash_verify_draft_limit( + dflash: &DFlashSchedulerState, + local_id: usize, + max_drafts: usize, +) -> usize { + if max_drafts == 0 { + return 0; + } + let Some(stats) = dflash.stats.get(&local_id) else { + return max_drafts.min(DFLASH_PROBE_DRAFT_TOKENS); + }; + if stats.verified_draft_tokens < DFLASH_MIN_STATS_DRAFT_TOKENS { + return max_drafts.min(DFLASH_PROBE_DRAFT_TOKENS); + } + let rate = if stats.verified_draft_tokens == 0 { + 0.0 + } else { + stats.accepted_draft_tokens as f64 / stats.verified_draft_tokens as f64 + }; + if rate >= DFLASH_FULL_BLOCK_ACCEPT_RATE { + max_drafts + } else { +
max_drafts.min(DFLASH_PROBE_DRAFT_TOKENS) + } +} + +fn verify_dflash_spans( + model: &Qwen35Model, + active: &mut [ActiveRequest35], + graph_state: &mut BatchDecodeGraphState, + dflash: &mut DFlashSchedulerState, + spans: &[Vec], +) -> Result>> { + dflash.ensure_state_scratch(model.device_ctx(), model.config(), active.len())?; + copy_recurrent_states_into( + model, + active, + graph_state, + &mut dflash.backup_states[..active.len()], + )?; + let capture_layer_ids = dflash.capture_layer_ids().to_vec(); + let original_seq_lens: Vec = active.iter().map(|req| req.kv.seq_len()).collect(); + let result = (|| -> Result>> { + for (slot_idx, span) in spans.iter().enumerate() { + anyhow::ensure!( + span.len() >= 2, + "Qwen3.5 DFlash verify span for local request {} is too short", + active[slot_idx].local_id + ); + dflash.verify_scratch_states[slot_idx] + .copy_from(model.device_ctx(), &dflash.backup_states[slot_idx])?; + } + + let span_refs: Vec<&[u32]> = spans.iter().map(Vec::as_slice).collect(); + let active_len = active.len(); + let mut kv_refs: Vec<&mut KvState> = active.iter_mut().map(|req| &mut req.kv).collect(); + let mut rec_refs: Vec<&mut RecurrentState> = dflash.verify_scratch_states[..active_len] + .iter_mut() + .collect(); + model.prefill_verify_into( + &span_refs, + &mut kv_refs, + &mut rec_refs, + &capture_layer_ids, + &mut dflash.verify_bufs, + )?; + + let total_rows: usize = spans.iter().map(Vec::len).sum(); + let greedy = SamplingParams::default(); + let params: Vec<&SamplingParams> = vec![&greedy; total_rows]; + let target_tokens = openinfer_sample::select_batch( + model.device_ctx(), + &dflash.verify_bufs.logits, + &params, + 0, + &mut dflash.verify_bufs.sample, + )?; + + let mut accepted_rows = Vec::with_capacity(active.len()); + let mut matched_draft_rows = Vec::with_capacity(active.len()); + let mut row_offset = 0usize; + for (slot_idx, req) in active.iter_mut().enumerate() { + let span = &spans[slot_idx]; + let row_end = row_offset + span.len(); + let
target_slice = &target_tokens[row_offset..row_end]; + let (matched, accepted_ids) = accept_greedy(&span[1..], target_slice); + if std::env::var_os("OPENINFER_QWEN35_DFLASH_TRACE").is_some() { + log::info!( + "Qwen3.5 DFlash trace local_request={} span_head={:?} target_head={:?} matched_drafts={} accepted_head={:?}", + req.local_id, + &span[..span.len().min(12)], + &target_slice[..target_slice.len().min(12)], + matched, + &accepted_ids[..accepted_ids.len().min(12)], + ); + } + let accepted: Vec = accepted_ids + .into_iter() + .map(|token| VerifiedToken { + token, + logprob: None, + }) + .collect(); + matched_draft_rows.push(matched); + accepted_rows.push(accepted); + row_offset = row_end; + } + + let mut local_ids = Vec::with_capacity(active.len()); + let mut span_lens = Vec::with_capacity(active.len()); + for (slot_idx, req) in active.iter_mut().enumerate() { + let span = &spans[slot_idx]; + let accepted_len = accepted_rows[slot_idx].len(); + anyhow::ensure!( + accepted_len > 0 && accepted_len <= span.len(), + "Qwen3.5 DFlash accepted span {} outside verify span {} for local request {}", + accepted_len, + span.len(), + req.local_id + ); + let committed_drafts = matched_draft_rows[slot_idx]; + dflash.verified_draft_tokens += span.len().saturating_sub(1); + dflash.accepted_draft_tokens += committed_drafts; + let stats = dflash.stats.entry(req.local_id).or_default(); + stats.verified_draft_tokens += span.len().saturating_sub(1); + stats.accepted_draft_tokens += committed_drafts; + local_ids.push(req.local_id); + span_lens.push(span.len()); + } + + let accepted_lens: Vec = accepted_rows.iter().map(Vec::len).collect(); + let accepted_token_ids: Vec> = accepted_rows + .iter() + .map(|row| row.iter().map(|token| token.token).collect()) + .collect(); + let commit_input_token_ids: Vec> = spans + .iter() + .zip(accepted_token_ids.iter()) + .zip(accepted_lens.iter()) + .map(|((span, accepted_ids), &accepted_len)| { + let mut ids = Vec::with_capacity(accepted_len); + 
ids.push(span[0]); + ids.extend( + accepted_ids + .iter() + .take(accepted_len.saturating_sub(1)) + .copied(), + ); + ids + }) + .collect(); + let mut commit_captured_offsets: Vec = (0..active.len()) + .map(|slot_idx| row_offset_for_span(spans, slot_idx)) + .collect(); + + let batch_keep_speculating = if active.len() > 1 { + let batch_verified: usize = active + .iter() + .filter_map(|req| dflash.stats.get(&req.local_id)) + .map(|stats| stats.verified_draft_tokens) + .sum(); + let batch_accepted: usize = active + .iter() + .filter_map(|req| dflash.stats.get(&req.local_id)) + .map(|stats| stats.accepted_draft_tokens) + .sum(); + batch_verified < DFLASH_MIN_STATS_DRAFT_TOKENS * active.len() + || (batch_accepted as f64 / batch_verified as f64) >= DFLASH_MIN_ACCEPT_RATE + } else { + true + }; + let keep_speculating_by_slot: Vec = (0..active.len()) + .map(|slot_idx| { + if active.len() > 1 { + batch_keep_speculating + } else { + match dflash.stats.get(&local_ids[slot_idx]) { + Some(stats) + if stats.verified_draft_tokens >= DFLASH_MIN_STATS_DRAFT_TOKENS => + { + let rate = stats.accepted_draft_tokens as f64 + / stats.verified_draft_tokens as f64; + rate >= DFLASH_MIN_ACCEPT_RATE + } + _ => true, + } + } + }) + .collect(); + + let full_span_commit = accepted_lens + .iter() + .zip(spans.iter()) + .all(|(&accepted_len, span)| accepted_len == span.len()); + let mut append_after_commit = vec![true; active.len()]; + let commit_path = if full_span_commit { + dflash.commit_stats.full_span_rounds += 1; + for (slot_idx, req) in active.iter_mut().enumerate() { + graph_state.copy_state_to_slot( + model.device_ctx(), + &dflash.verify_scratch_states[slot_idx], + req.graph_slot_idx, + )?; + } + "full_span_state_commit" + } else { + dflash.commit_stats.prefix_state_rounds += 1; + dflash.commit_stats.prefix_state_rows += accepted_lens.iter().sum::(); + let partial_indices: Vec = accepted_lens + .iter() + .zip(spans.iter()) + .enumerate() + .filter_map(|(slot_idx, (&accepted_len, span))| 
{ + (accepted_len < span.len()).then_some(slot_idx) + }) + .collect(); + for slot_idx in 0..active.len() { + if accepted_lens[slot_idx] == spans[slot_idx].len() { + append_after_commit[slot_idx] = false; + if keep_speculating_by_slot[slot_idx] + && let Some(state) = dflash.requests.get_mut(&local_ids[slot_idx]) + { + dflash.model.append_pending_context( + model.device_ctx(), + state, + &dflash.verify_bufs.captured_hidden, + row_offset_for_span(spans, slot_idx), + accepted_lens[slot_idx], + )?; + } + } + } + let mut offset = 0usize; + for &slot_idx in &partial_indices { + commit_captured_offsets[slot_idx] = offset; + offset += accepted_lens[slot_idx]; + dflash.commit_scratch_states[slot_idx] + .copy_from(model.device_ctx(), &dflash.backup_states[slot_idx])?; + active[slot_idx] + .kv + .truncate_to(original_seq_lens[slot_idx])?; + } + if !partial_indices.is_empty() { + let prefix_refs: Vec<&[u32]> = partial_indices + .iter() + .map(|&slot_idx| commit_input_token_ids[slot_idx].as_slice()) + .collect(); + let mut kv_refs: Vec<&mut KvState> = active + .iter_mut() + .enumerate() + .filter_map(|(slot_idx, req)| { + partial_indices.contains(&slot_idx).then_some(&mut req.kv) + }) + .collect(); + let mut rec_refs: Vec<&mut RecurrentState> = dflash + .commit_scratch_states + .iter_mut() + .enumerate() + .filter_map(|(slot_idx, state)| { + partial_indices.contains(&slot_idx).then_some(state) + }) + .collect(); + model.prefill_verify_into( + &prefix_refs, + &mut kv_refs, + &mut rec_refs, + &capture_layer_ids, + &mut dflash.verify_bufs, + )?; + } + for (slot_idx, req) in active.iter_mut().enumerate() { + let state = if accepted_lens[slot_idx] == spans[slot_idx].len() { + &dflash.verify_scratch_states[slot_idx] + } else { + &dflash.commit_scratch_states[slot_idx] + }; + graph_state.copy_state_to_slot(model.device_ctx(), state, req.graph_slot_idx)?; + } + if accepted_lens + .first() + .is_some_and(|&first| accepted_lens.iter().all(|&len| len == first)) + { + 
dflash.commit_stats.same_partial_rounds += 1; + "same_len_prefix_state_commit" + } else { + dflash.commit_stats.batched_replay_rounds += 1; + dflash.commit_stats.batched_replay_rows += partial_indices + .iter() + .map(|&slot_idx| accepted_lens[slot_idx]) + .sum::(); + "heterogeneous_prefix_state_commit" + } + }; + for slot_idx in 0..active.len() { + if append_after_commit[slot_idx] + && keep_speculating_by_slot[slot_idx] + && let Some(state) = dflash.requests.get_mut(&local_ids[slot_idx]) + { + dflash.model.append_pending_context( + model.device_ctx(), + state, + &dflash.verify_bufs.captured_hidden, + commit_captured_offsets[slot_idx], + accepted_lens[slot_idx], + )?; + } + } + + for slot_idx in 0..active.len() { + let accepted_len = accepted_lens[slot_idx]; + let local_id = local_ids[slot_idx]; + let span_len = span_lens[slot_idx]; + let keep_speculating = keep_speculating_by_slot[slot_idx]; + if !keep_speculating { + dflash.drop_request(local_id); + } + let rate = if dflash.verified_draft_tokens == 0 { + 0.0 + } else { + dflash.accepted_draft_tokens as f64 / dflash.verified_draft_tokens as f64 + }; + let request_rate = dflash.stats.get(&local_id).map_or(0.0, |stats| { + if stats.verified_draft_tokens == 0 { + 0.0 + } else { + stats.accepted_draft_tokens as f64 / stats.verified_draft_tokens as f64 + } + }); + log::debug!( + "Qwen3.5 DFlash local_request={} accepted_tokens={} verified_span={} request_accept_rate={:.3} cumulative_accept_rate={:.3} keep_speculating={} captured_offset={}", + local_id, + accepted_len, + span_len, + request_rate, + rate, + keep_speculating, + commit_captured_offsets[slot_idx], + ); + } + if std::env::var_os("OPENINFER_QWEN35_DFLASH_COMMIT_TRACE").is_some() { + log::info!( + "Qwen3.5 DFlash commit path={} active={} accepted_lens={:?} span_lens={:?} full={} same_partial={} batched_replay={} batched_replay_rows={} prefix_state={} prefix_state_rows={}", + commit_path, + active.len(), + accepted_lens, + span_lens, + 
dflash.commit_stats.full_span_rounds, + dflash.commit_stats.same_partial_rounds, + dflash.commit_stats.batched_replay_rounds, + dflash.commit_stats.batched_replay_rows, + dflash.commit_stats.prefix_state_rounds, + dflash.commit_stats.prefix_state_rows, + ); + } + + Ok(accepted_rows) + })(); + if result.is_err() { + for ((req, backup_state), &seq_len) in active + .iter_mut() + .zip(dflash.backup_states.iter()) + .zip(original_seq_lens.iter()) + { + let _ = req.kv.truncate_to(seq_len); + let _ = graph_state.copy_state_to_slot( + model.device_ctx(), + backup_state, + req.graph_slot_idx, + ); + } + } + result +} + +fn row_offset_for_span(spans: &[Vec], slot_idx: usize) -> usize { + spans[..slot_idx].iter().map(Vec::len).sum() +} + +fn copy_recurrent_states_into( + model: &Qwen35Model, + active: &[ActiveRequest35], + graph_state: &BatchDecodeGraphState, + states: &mut [RecurrentState], +) -> Result<()> { + anyhow::ensure!( + states.len() >= active.len(), + "Qwen3.5 DFlash backup state capacity {} < active {}", + states.len(), + active.len() + ); + for (req, state) in active.iter().zip(states.iter_mut()) { + graph_state.copy_slot_to_state(model.device_ctx(), req.graph_slot_idx, state)?; + } + Ok(()) } /// Process decode logits from unified step: sample, extract logprobs, dispatch. @@ -574,6 +1454,7 @@ fn process_decode_logits( active: &mut Vec, graph_state: &mut BatchDecodeGraphState, rng: &mut StdRng, + dflash: Option<&mut DFlashSchedulerState>, ) { let requested_logprobs: Vec = active.iter().map(|r| r.logprobs).collect(); let cpu_logits = match snapshot_requested_logprobs( @@ -626,7 +1507,7 @@ fn process_decode_logits( }) .collect(); - dispatch_decode_tokens(model, active, &tokens, &logprobs_vec, graph_state); + dispatch_decode_tokens(model, active, &tokens, &logprobs_vec, graph_state, dflash); } /// Dispatch sampled decode tokens: send events, check EOS/limits, retire finished. 
@@ -639,6 +1520,7 @@ fn dispatch_decode_tokens( tokens: &[u32], logprobs: &[Option], graph_state: &mut BatchDecodeGraphState, + mut dflash: Option<&mut DFlashSchedulerState>, ) { let n = active.len(); let mut to_retire = Vec::new(); @@ -698,6 +1580,87 @@ fn dispatch_decode_tokens( // Remove in reverse order so compact_slot indices stay valid for &i in to_retire.iter().rev() { + if let Some(dflash) = dflash.as_mut() { + dflash.drop_request(active[i].local_id); + } + compact_slot(model, active, graph_state, i); + } +} + +fn dispatch_speculative_tokens( + model: &Qwen35Model, + active: &mut Vec, + accepted: &[Vec], + graph_state: &mut BatchDecodeGraphState, + dflash: &mut DFlashSchedulerState, +) { + let n = active.len(); + let mut to_retire = Vec::new(); + + for i in 0..n { + let req = &mut active[i]; + let mut should_retire = None; + for token in &accepted[i] { + req.generated_count += 1; + let is_eos = !req.params.ignore_eos && model.is_stop_token(token.token); + let at_limit = req.generated_count >= req.max_tokens; + if is_eos { + debug!( + "request finished: request_id={:?} prompt_tokens={} completion_tokens={} finish_reason={:?}", + req.request_id, + req.prompt_len, + req.generated_count, + FinishReason::Stop + ); + let _ = req.token_tx.send(TokenEvent::Finished { + finish_reason: FinishReason::Stop, + prompt_tokens: req.prompt_len, + completion_tokens: req.generated_count, + }); + should_retire = Some(i); + break; + } + if req + .token_tx + .send(TokenEvent::Token { + id: token.token, + logprob: token.logprob.clone(), + }) + .is_err() + { + debug!( + "request dropped: client disconnected: request_id={:?} tokens_generated={}", + req.request_id, req.generated_count + ); + should_retire = Some(i); + break; + } + if at_limit { + debug!( + "request finished: request_id={:?} prompt_tokens={} completion_tokens={} finish_reason={:?}", + req.request_id, + req.prompt_len, + req.generated_count, + FinishReason::Length + ); + let _ = req.token_tx.send(TokenEvent::Finished 
{ + finish_reason: FinishReason::Length, + prompt_tokens: req.prompt_len, + completion_tokens: req.generated_count, + }); + should_retire = Some(i); + break; + } + req.last_token = token.token; + } + if let Some(idx) = should_retire { + to_retire.push(idx); + } + } + + for &i in to_retire.iter().rev() { + let local_id = active[i].local_id; + dflash.drop_request(local_id); compact_slot(model, active, graph_state, i); } } @@ -722,31 +1685,13 @@ fn compact_slot( let src_slot = active[idx].graph_slot_idx; debug_assert_eq!(src_slot, compaction.moved_from); - // D2D copy: graph_state.slot_states[src] -> graph_state.slot_states[dst] - // We can't borrow two slots mutably at once, so use raw index copy. - let ctx = model.device_ctx(); - let src = &graph_state.slot_states[compaction.moved_from]; - // Copy layer by layer using the public fields - for layer_idx in 0..src.layers.len() { - let (src_part, dst_part) = if compaction.moved_to < compaction.moved_from { - let (left, right) = graph_state.slot_states.split_at_mut(compaction.moved_from); - ( - &right[0].layers[layer_idx], - &mut left[compaction.moved_to].layers[layer_idx], - ) - } else { - unreachable!("idx < active.len() <= last"); - }; - - ctx.stream - .memcpy_dtod(&src_part.state, &mut dst_part.state) - .expect("compact slot state copy failed"); - ctx.stream - .memcpy_dtod(&src_part.conv_state.data, &mut dst_part.conv_state.data) - .expect("compact slot conv_state copy failed"); - } - graph_state.slot_states[compaction.moved_to].seq_len = - graph_state.slot_states[compaction.moved_from].seq_len; + graph_state + .copy_slot_to_slot( + model.device_ctx(), + compaction.moved_from, + compaction.moved_to, + ) + .expect("compact slot recurrent state copy failed"); active[compaction.moved_to].graph_slot_idx = compaction.moved_to; } @@ -756,6 +1701,7 @@ fn compact_slot( /// Step's scheduled prefill set struct ScheduledChunk { + local_ids: Vec, reqs: Vec, kvs: Vec, recs: Vec, @@ -769,6 +1715,7 @@ impl From> for ScheduledChunk 
{ fn from(scheduled: Vec) -> Self { let n = scheduled.len(); let mut chunk = ScheduledChunk { + local_ids: Vec::with_capacity(n), reqs: Vec::with_capacity(n), kvs: Vec::with_capacity(n), recs: Vec::with_capacity(n), @@ -780,6 +1727,7 @@ impl From> for ScheduledChunk { chunk .windows .push(p.req.prompt_tokens[p.cursor..end].to_vec()); + chunk.local_ids.push(p.local_id); chunk.ends.push(end); chunk.reqs.push(p.req); chunk.kvs.push(p.kv); @@ -830,8 +1778,10 @@ fn promote_or_requeue( chunk: ScheduledChunk, tokens: &[u32], logprobs: &[Option], + dflash: Option<&mut DFlashSchedulerState>, ) { let ScheduledChunk { + local_ids, reqs, kvs, recs, @@ -839,11 +1789,20 @@ fn promote_or_requeue( .. } = chunk; let mut still_prefilling: Vec = Vec::new(); + let mut dflash = dflash; - for (i, (((req, kv), rec), end)) in reqs.into_iter().zip(kvs).zip(recs).zip(ends).enumerate() { + for (i, ((((req, kv), rec), end), local_id)) in reqs + .into_iter() + .zip(kvs) + .zip(recs) + .zip(ends) + .zip(local_ids) + .enumerate() + { // Not finished: re-queue with the advanced cursor if end < req.prompt_tokens.len() { still_prefilling.push(PrefillingRequest35 { + local_id, req, kv, rec, @@ -878,6 +1837,9 @@ fn promote_or_requeue( prompt_tokens: prompt_len, completion_tokens: 0, }); + if let Some(dflash) = dflash.as_mut() { + dflash.drop_request(local_id); + } continue; } @@ -893,6 +1855,9 @@ fn promote_or_requeue( "request dropped: client disconnected: request_id={:?} tokens_generated={}", req.request_id, 0 ); + if let Some(dflash) = dflash.as_mut() { + dflash.drop_request(local_id); + } continue; } @@ -909,6 +1874,9 @@ fn promote_or_requeue( prompt_tokens: prompt_len, completion_tokens: 1, }); + if let Some(dflash) = dflash.as_mut() { + dflash.drop_request(local_id); + } continue; } @@ -919,6 +1887,7 @@ fn promote_or_requeue( .copy_state_to_slot(model.device_ctx(), &rec, slot_idx) .expect("copy recurrent state to slot failed"); active.push(ActiveRequest35 { + local_id, request_id: 
req.request_id, token_tx: req.token_tx, kv, diff --git a/openinfer-qwen35-4b/src/speculative.rs b/openinfer-qwen35-4b/src/speculative.rs new file mode 100644 index 00000000..914d4515 --- /dev/null +++ b/openinfer-qwen35-4b/src/speculative.rs @@ -0,0 +1,338 @@ +use anyhow::Result; +use openinfer_core::engine::TokenLogprob; +use openinfer_core::sampler::SamplingParams; + +use crate::executor::{Qwen35Executor, RequestId}; +use crate::logprobs::snapshot_requested_logprobs; +use crate::recurrent_state::RecurrentState; +use crate::verify_buffers::VerifyBuffers35; + +#[derive(Clone, Debug)] +pub struct VerifyStepItem { + pub request_id: RequestId, + pub token_ids: Vec, + pub logprobs: usize, +} + +impl VerifyStepItem { + pub fn new(request_id: RequestId, token_ids: Vec, logprobs: usize) -> Self { + Self { + request_id, + token_ids, + logprobs, + } + } +} + +#[derive(Clone, Copy)] +pub struct VerifyPlan<'a> { + pub requests: &'a [VerifyStepItem], +} + +#[derive(Clone, Debug, PartialEq)] +pub struct VerifiedToken { + pub token: u32, + pub logprob: Option, +} + +#[derive(Clone, Debug, PartialEq)] +pub struct VerifyRequestResult { + pub request_id: RequestId, + pub matched_draft_tokens: usize, + pub accepted_tokens: Vec, +} + +pub struct VerifyResult { + pub requests: Vec, +} + +#[must_use] +pub(crate) fn accept_greedy(proposed: &[u32], target_argmax: &[u32]) -> (usize, Vec) { + debug_assert_eq!( + target_argmax.len(), + proposed.len() + 1, + "verify must produce one posterior token per draft plus one bonus" + ); + let mut matched = 0usize; + while matched < proposed.len() && proposed[matched] == target_argmax[matched] { + matched += 1; + } + let mut accepted = Vec::with_capacity(matched + 1); + accepted.extend_from_slice(&proposed[..matched]); + accepted.push(target_argmax[matched]); + (matched, accepted) +} + +impl Qwen35Executor { + pub fn execute_speculative_verify(&mut self, plan: VerifyPlan<'_>) -> Result { + self.validate_speculative_verify(plan)?; + let 
original_seq_lens: Vec = self + .active + .iter() + .map(|active| active.kv.seq_len()) + .collect(); + match self.execute_speculative_verify_inner(plan) { + Ok(result) => Ok(result), + Err(err) => { + if let Err(rollback_err) = self.rollback_kv_states(&original_seq_lens) { + anyhow::bail!( + "{err}; additionally failed to roll back Qwen3.5 KV state: {rollback_err}" + ); + } + Err(err) + } + } + } + + fn validate_speculative_verify(&self, plan: VerifyPlan<'_>) -> Result<()> { + anyhow::ensure!( + !plan.requests.is_empty(), + "Qwen3.5 speculative verify plan requires at least one request" + ); + anyhow::ensure!( + plan.requests.len() == self.active.len(), + "Qwen3.5 speculative verify must include all active requests in slot order" + ); + for (slot_idx, req) in plan.requests.iter().enumerate() { + anyhow::ensure!( + self.active[slot_idx].request_id == req.request_id, + "Qwen3.5 speculative verify request order differs from active slot order" + ); + anyhow::ensure!( + req.token_ids.len() >= 2, + "Qwen3.5 speculative verify request {} needs [current, draft...]", + req.request_id.get() + ); + } + Ok(()) + } + + fn execute_speculative_verify_inner(&mut self, plan: VerifyPlan<'_>) -> Result { + let backup_states = self.copy_canonical_recurrent_states()?; + let original_seq_lens: Vec = self + .active + .iter() + .map(|active| active.kv.seq_len()) + .collect(); + + let mut verify_states = Vec::with_capacity(self.active.len()); + for backup in &backup_states { + let mut state = RecurrentState::new(self.model.device_ctx(), self.model.config())?; + state.copy_from(self.model.device_ctx(), backup)?; + verify_states.push(state); + } + + let max_span = plan + .requests + .iter() + .map(|req| req.token_ids.len()) + .max() + .unwrap_or(1); + let mut verify_bufs = VerifyBuffers35::new( + self.model.device_ctx(), + self.model.config(), + plan.requests.len(), + max_span, + 0, + self.model.kv_pool().capacity_pages(), + )?; + let spans: Vec<&[u32]> = plan + .requests + .iter() + 
.map(|req| req.token_ids.as_slice()) + .collect(); + { + let mut kv_refs: Vec<_> = self + .active + .iter_mut() + .map(|active| &mut active.kv) + .collect(); + let mut rec_refs: Vec<_> = verify_states.iter_mut().collect(); + self.model.prefill_verify_into( + &spans, + &mut kv_refs, + &mut rec_refs, + &[], + &mut verify_bufs, + )?; + } + + let requested_logprobs: Vec = plan + .requests + .iter() + .flat_map(|req| std::iter::repeat_n(req.logprobs, req.token_ids.len())) + .collect(); + let cpu_logits = snapshot_requested_logprobs( + self.model.device_ctx(), + &verify_bufs.logits, + &requested_logprobs, + )?; + let params = vec![SamplingParams::default(); verify_bufs.logits.seq_len]; + let params_refs: Vec<&SamplingParams> = params.iter().collect(); + let target_tokens = openinfer_sample::select_batch( + self.model.device_ctx(), + &verify_bufs.logits, + &params_refs, + 0, + &mut verify_bufs.sample, + )?; + + let mut request_outputs = Vec::with_capacity(plan.requests.len()); + let mut row_offset = 0usize; + for req in plan.requests { + let row_end = row_offset + req.token_ids.len(); + let target_slice = &target_tokens[row_offset..row_end]; + let target_logprobs = target_slice + .iter() + .enumerate() + .map(|(i, &token)| { + cpu_logits[row_offset + i].as_ref().and_then(|row| { + openinfer_sample::token_logprob_from_row(row, token, req.logprobs) + }) + }) + .collect::>(); + let (matched, accepted_ids) = accept_greedy(&req.token_ids[1..], target_slice); + let accepted_tokens: Vec = accepted_ids + .iter() + .enumerate() + .map(|(i, &token)| VerifiedToken { + token, + logprob: target_logprobs.get(i).cloned().unwrap_or(None), + }) + .collect(); + request_outputs.push(VerifyRequestResult { + request_id: req.request_id, + matched_draft_tokens: matched, + accepted_tokens, + }); + row_offset = row_end; + } + + for (active, &seq_len) in self.active.iter_mut().zip(original_seq_lens.iter()) { + active.kv.truncate_to(seq_len)?; + } + + if let Err(err) = +
self.commit_speculative_states(plan.requests, &request_outputs, &backup_states) + { + if let Err(restore_err) = + self.restore_canonical_states(&backup_states, &original_seq_lens) + { + anyhow::bail!( + "{err}; additionally failed to restore Qwen3.5 recurrent/conv state after speculative commit failure: {restore_err}" + ); + } + return Err(err); + } + + Ok(VerifyResult { + requests: request_outputs, + }) + } + + fn copy_canonical_recurrent_states(&self) -> Result> { + let mut scratch_states = Vec::with_capacity(self.active.len()); + for active in &self.active { + let mut state = RecurrentState::new(self.model.device_ctx(), self.model.config())?; + self.graph_state.copy_slot_to_state( + self.model.device_ctx(), + active.graph_slot_idx, + &mut state, + )?; + scratch_states.push(state); + } + Ok(scratch_states) + } + + fn commit_speculative_states( + &mut self, + requests: &[VerifyStepItem], + results: &[VerifyRequestResult], + backup_states: &[RecurrentState], + ) -> Result<()> { + for (slot_idx, ((req, result), backup_state)) in requests + .iter() + .zip(results.iter()) + .zip(backup_states.iter()) + .enumerate() + { + let accepted_len = result.accepted_tokens.len(); + anyhow::ensure!( + accepted_len <= req.token_ids.len(), + "Qwen3.5 speculative accepted span {} exceeds verify span {}", + accepted_len, + req.token_ids.len() + ); + let mut replay_state = + RecurrentState::new(self.model.device_ctx(), self.model.config())?; + replay_state.copy_from(self.model.device_ctx(), backup_state)?; + let replay_tokens = &req.token_ids[..accepted_len]; + let _ = self.model.prefill_logits_all( + replay_tokens, + &mut self.active[slot_idx].kv, + &mut replay_state, + )?; + let graph_slot_idx = self.active[slot_idx].graph_slot_idx; + self.graph_state.copy_state_to_slot( + self.model.device_ctx(), + &replay_state, + graph_slot_idx, + )?; + } + Ok(()) + } + + fn restore_canonical_states( + &mut self, + backup_states: &[RecurrentState], + original_seq_lens: &[usize], + ) -> Result<()> 
{ + for ((active, backup_state), &seq_len) in self + .active + .iter_mut() + .zip(backup_states.iter()) + .zip(original_seq_lens.iter()) + { + active.kv.truncate_to(seq_len)?; + self.graph_state.copy_state_to_slot( + self.model.device_ctx(), + backup_state, + active.graph_slot_idx, + )?; + } + Ok(()) + } + + fn rollback_kv_states(&mut self, original_seq_lens: &[usize]) -> Result<()> { + for (active, &seq_len) in self.active.iter_mut().zip(original_seq_lens.iter()) { + active.kv.truncate_to(seq_len)?; + } + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn accepts_full_run_plus_bonus() { + let (matched, accepted) = accept_greedy(&[10, 11, 12], &[10, 11, 12, 13]); + assert_eq!(matched, 3); + assert_eq!(accepted, vec![10, 11, 12, 13]); + } + + #[test] + fn accepts_prefix_then_correction() { + let (matched, accepted) = accept_greedy(&[10, 11, 99], &[10, 11, 22, 33]); + assert_eq!(matched, 2); + assert_eq!(accepted, vec![10, 11, 22]); + } + + #[test] + fn rejects_first_candidate_commits_one() { + let (matched, accepted) = accept_greedy(&[10, 11, 12], &[7, 8, 9, 10]); + assert_eq!(matched, 0); + assert_eq!(accepted, vec![7]); + } +} diff --git a/openinfer-qwen35-4b/src/unified_forward.rs b/openinfer-qwen35-4b/src/unified_forward.rs index 825532f3..20930ebd 100644 --- a/openinfer-qwen35-4b/src/unified_forward.rs +++ b/openinfer-qwen35-4b/src/unified_forward.rs @@ -14,6 +14,7 @@ use super::batch_decode_graph::BatchDecodeGraphState; use super::recurrent_state::RecurrentState; use super::weights::Qwen35Model; use openinfer_core::kv_pool::KvState; +use openinfer_core::ops; use openinfer_core::tensor::HiddenStates; pub(crate) struct UnifiedStepOutput { @@ -32,6 +33,17 @@ impl Qwen35Model { kv_states: &mut [KvState], recurrent_states: &mut [&mut RecurrentState], ) -> Result { + self.batch_prefill_logits_with_capture(prompts, kv_states, recurrent_states, None) + .map(|(logits, _)| logits) + } + + pub(crate) fn batch_prefill_logits_with_capture( + &self, 
+ prompts: &[&[u32]], + kv_states: &mut [KvState], + recurrent_states: &mut [&mut RecurrentState], + capture_layer_ids: Option<&[usize]>, + ) -> Result<(HiddenStates, Option)> { let n = prompts.len(); anyhow::ensure!(n > 0, "batch_prefill requires at least one prompt"); anyhow::ensure!(n == kv_states.len(), "prompts / kv_states len mismatch"); @@ -41,16 +53,46 @@ impl Qwen35Model { ); let mut last_hiddens = Vec::with_capacity(n); + let capture_hidden_dim = capture_layer_ids + .map(|ids| ids.len() * self.config.hidden_size) + .unwrap_or(0); + let capture_tokens: usize = prompts.iter().map(|prompt| prompt.len()).sum(); + let mut captured_all = if capture_layer_ids.is_some() { + Some(HiddenStates::zeros( + &self.ctx, + capture_hidden_dim, + capture_tokens, + )?) + } else { + None + }; + let mut token_offset = 0usize; for i in 0..n { - let last_hidden = - self.prefill_last_hidden(prompts[i], &mut kv_states[i], recurrent_states[i])?; + let (hidden, captured) = self.prefill_chunk_forward_with_capture( + prompts[i], + &mut kv_states[i], + recurrent_states[i], + capture_layer_ids, + )?; + let last_hidden = ops::extract_vec(&self.ctx, &hidden, hidden.seq_len - 1)?; debug_assert_eq!( last_hidden.len, self.config.hidden_size, "Qwen3.5 prefill last hidden row must match request {i}" ); last_hiddens.push(last_hidden); + if let (Some(captured), Some(out)) = (captured.as_ref(), captured_all.as_mut()) { + ops::copy_hidden_token_range_into( + &self.ctx, + captured, + 0, + out, + token_offset, + prompts[i].len(), + )?; + token_offset += prompts[i].len(); + } } - self.batch_last_hidden_logits(&last_hiddens) + Ok((self.batch_last_hidden_logits(&last_hiddens)?, captured_all)) } /// Unified step: prefill new requests and decode existing requests in one call. diff --git a/openinfer-qwen35-4b/src/verify_buffers.rs b/openinfer-qwen35-4b/src/verify_buffers.rs new file mode 100644 index 00000000..39b28a2e --- /dev/null +++ b/openinfer-qwen35-4b/src/verify_buffers.rs @@ -0,0 +1,205 @@ +//! 
Fixed scratch for Qwen3.5 DFlash target verification. + +use anyhow::Result; +use cudarc::driver::CudaSlice; + +use crate::config::Config35; +use crate::ops::PrefillPagedPlan; +use crate::prefill_buffers::GdrChunkwiseScratch35; +use openinfer_core::tensor::{DeviceContext, HiddenStates}; + +pub(crate) struct VerifyBuffers35 { + max_batch: usize, + span: usize, + max_rows: usize, + token_ids_h: Vec, + pub(crate) token_ids_d: CudaSlice, + + pub(crate) hidden: HiddenStates, + pub(crate) hidden_next: HiddenStates, + pub(crate) normed: HiddenStates, + pub(crate) attn_results: HiddenStates, + pub(crate) hidden_mid: HiddenStates, + pub(crate) gate_up_out: HiddenStates, + pub(crate) act_out: HiddenStates, + pub(crate) mlp_out: HiddenStates, + pub(crate) logits_normed: HiddenStates, + pub(crate) logits: HiddenStates, + pub(crate) captured_hidden: HiddenStates, + + pub(crate) q_full: HiddenStates, + pub(crate) k_full: HiddenStates, + pub(crate) v_full: HiddenStates, + pub(crate) q_prepped: HiddenStates, + pub(crate) attn_out_full: HiddenStates, + + pub(crate) qkv: HiddenStates, + pub(crate) z: HiddenStates, + pub(crate) b_proj: HiddenStates, + pub(crate) a_proj: HiddenStates, + pub(crate) qkv_conv: HiddenStates, + pub(crate) gdr_out: HiddenStates, + pub(crate) normed_gated: HiddenStates, + pub(crate) compact_qkv: HiddenStates, + pub(crate) compact_b: HiddenStates, + pub(crate) compact_a: HiddenStates, + pub(crate) compact_qkv_conv: HiddenStates, + pub(crate) compact_gdr: HiddenStates, + pub(crate) gdr_scratch: GdrChunkwiseScratch35, + + pub(crate) plan: PrefillPagedPlan, + pub(crate) sample: openinfer_sample::SampleScratch, +} + +impl VerifyBuffers35 { + pub(crate) fn new( + ctx: &DeviceContext, + config: &Config35, + max_batch: usize, + span: usize, + num_capture_layers: usize, + max_total_pages: usize, + ) -> Result { + anyhow::ensure!(max_batch > 0, "Qwen3.5 verify buffers need max_batch > 0"); + anyhow::ensure!(span > 0, "Qwen3.5 verify buffers need span > 0"); + let 
max_rows = max_batch * span; + let hidden = config.hidden_size; + let q_proj_dim = config.full_attn_q_proj_dim(); + let q_dim = config.full_attn_q_dim(); + let kv_dim = config.full_attn_kv_dim(); + let qkv_dim = config.linear_attn_qkv_dim(); + let z_dim = config.linear_attn_z_dim(); + let group_size = config.num_attention_heads / config.num_key_value_heads; + let max_tiles = max_batch * span * group_size.max(1); + + Ok(Self { + max_batch, + span, + max_rows, + token_ids_h: vec![0; max_rows], + token_ids_d: ctx.stream.alloc_zeros(max_rows)?, + + hidden: HiddenStates::zeros(ctx, hidden, max_rows)?, + hidden_next: HiddenStates::zeros(ctx, hidden, max_rows)?, + normed: HiddenStates::zeros(ctx, hidden, max_rows)?, + attn_results: HiddenStates::zeros(ctx, hidden, max_rows)?, + hidden_mid: HiddenStates::zeros(ctx, hidden, max_rows)?, + gate_up_out: HiddenStates::zeros(ctx, 2 * config.intermediate_size, max_rows)?, + act_out: HiddenStates::zeros(ctx, config.intermediate_size, max_rows)?, + mlp_out: HiddenStates::zeros(ctx, hidden, max_rows)?, + logits_normed: HiddenStates::zeros(ctx, hidden, max_rows)?, + logits: HiddenStates::zeros(ctx, config.vocab_size, max_rows)?, + captured_hidden: HiddenStates::zeros( + ctx, + hidden * num_capture_layers.max(1), + max_rows, + )?, + + q_full: HiddenStates::zeros(ctx, q_proj_dim, max_rows)?, + k_full: HiddenStates::zeros(ctx, kv_dim, max_rows)?, + v_full: HiddenStates::zeros(ctx, kv_dim, max_rows)?, + q_prepped: HiddenStates::zeros(ctx, q_dim, max_rows)?, + attn_out_full: HiddenStates::zeros(ctx, q_dim, max_rows)?, + + qkv: HiddenStates::zeros(ctx, qkv_dim, max_rows)?, + z: HiddenStates::zeros(ctx, z_dim, max_rows)?, + b_proj: HiddenStates::zeros(ctx, config.linear_num_value_heads, max_rows)?, + a_proj: HiddenStates::zeros(ctx, config.linear_num_value_heads, max_rows)?, + qkv_conv: HiddenStates::zeros(ctx, qkv_dim, max_rows)?, + gdr_out: HiddenStates::zeros(ctx, z_dim, max_rows)?, + normed_gated: HiddenStates::zeros(ctx, z_dim, 
max_rows)?, + compact_qkv: HiddenStates::zeros(ctx, qkv_dim, span)?, + compact_b: HiddenStates::zeros(ctx, config.linear_num_value_heads, span)?, + compact_a: HiddenStates::zeros(ctx, config.linear_num_value_heads, span)?, + compact_qkv_conv: HiddenStates::zeros(ctx, qkv_dim, span)?, + compact_gdr: HiddenStates::zeros(ctx, z_dim, span)?, + gdr_scratch: GdrChunkwiseScratch35::new(ctx, config, max_rows)?, + + plan: PrefillPagedPlan::new_preallocated( + ctx, + max_rows, + max_total_pages, + max_batch, + max_tiles, + )?, + sample: openinfer_sample::SampleScratch::new(ctx, config.vocab_size, max_rows)?, + }) + } + + pub(crate) fn max_batch(&self) -> usize { + self.max_batch + } + + pub(crate) fn set_rows(&mut self, rows: usize) { + assert!( + rows <= self.max_rows, + "Qwen3.5 verify rows {rows} exceeds capacity {}", + self.max_rows + ); + self.hidden.seq_len = rows; + self.hidden_next.seq_len = rows; + self.normed.seq_len = rows; + self.attn_results.seq_len = rows; + self.hidden_mid.seq_len = rows; + self.gate_up_out.seq_len = rows; + self.act_out.seq_len = rows; + self.mlp_out.seq_len = rows; + self.logits_normed.seq_len = rows; + self.logits.seq_len = rows; + self.captured_hidden.seq_len = rows; + self.q_full.seq_len = rows; + self.k_full.seq_len = rows; + self.v_full.seq_len = rows; + self.q_prepped.seq_len = rows; + self.attn_out_full.seq_len = rows; + self.qkv.seq_len = rows; + self.z.seq_len = rows; + self.b_proj.seq_len = rows; + self.a_proj.seq_len = rows; + self.qkv_conv.seq_len = rows; + self.gdr_out.seq_len = rows; + self.normed_gated.seq_len = rows; + self.gdr_scratch.set_rows(rows); + } + + pub(crate) fn set_compact_rows(&mut self, rows: usize) { + assert!( + rows <= self.span, + "Qwen3.5 compact verify rows {rows} exceeds span {}", + self.span + ); + self.compact_qkv.seq_len = rows; + self.compact_b.seq_len = rows; + self.compact_a.seq_len = rows; + self.compact_qkv_conv.seq_len = rows; + self.compact_gdr.seq_len = rows; + self.gdr_scratch.set_rows(rows); 
+ } + + pub(crate) fn stage_tokens(&mut self, ctx: &DeviceContext, spans: &[&[u32]]) -> Result { + anyhow::ensure!( + spans.len() <= self.max_batch, + "Qwen3.5 verify batch {} exceeds capacity {}", + spans.len(), + self.max_batch + ); + let total_rows: usize = spans.iter().map(|span| span.len()).sum(); + anyhow::ensure!( + total_rows <= self.max_rows, + "Qwen3.5 verify rows {total_rows} exceeds capacity {}", + self.max_rows + ); + self.token_ids_h.clear(); + self.token_ids_h.reserve(total_rows); + for span in spans { + self.token_ids_h.extend_from_slice(span); + } + self.set_rows(total_rows); + if total_rows > 0 { + let mut token_ids_d = self.token_ids_d.slice_mut(..total_rows); + ctx.stream + .memcpy_htod(&self.token_ids_h, &mut token_ids_d)?; + } + Ok(total_rows) + } +} diff --git a/openinfer-qwen35-4b/src/weights.rs b/openinfer-qwen35-4b/src/weights.rs index 67c7efb8..1f1fd8ca 100644 --- a/openinfer-qwen35-4b/src/weights.rs +++ b/openinfer-qwen35-4b/src/weights.rs @@ -94,6 +94,20 @@ impl Qwen35Model { model_path: &str, enable_cuda_graph: bool, device_ordinal: usize, + ) -> Result { + Self::from_safetensors_with_device_options_and_reservation( + model_path, + enable_cuda_graph, + device_ordinal, + None, + ) + } + + pub(crate) fn from_safetensors_with_device_options_and_reservation( + model_path: &str, + enable_cuda_graph: bool, + device_ordinal: usize, + dflash_reservation: Option<&crate::dflash::DFlashMemoryReservation>, ) -> Result { info!("Loading Qwen3.5 model from: {}", model_path); debug!("Initializing GPU"); @@ -333,13 +347,20 @@ impl Qwen35Model { let max_prefill_len = super::prefill::SCRATCH_ESTIMATE_SEQ; let scratch_reserve = super::prefill_buffers::GdrChunkwiseScratch35::estimate_bytes(&config, max_prefill_len); - let available = free_bytes.saturating_sub(scratch_reserve); + let dflash_fixed_reserve = dflash_reservation.map_or(0, |r| r.fixed_bytes); + let dflash_per_page = + dflash_reservation.map_or(0, |r| 
r.kv_bytes_per_token.saturating_mul(page_size)); + let effective_bytes_per_page = bytes_per_page.saturating_add(dflash_per_page); + let available = free_bytes + .saturating_sub(scratch_reserve) + .saturating_sub(dflash_fixed_reserve); let kv_budget = (available as f64 * 0.85) as usize; - let num_pages = (kv_budget / bytes_per_page).max(64); + let num_pages = (kv_budget / effective_bytes_per_page).max(64); let kv_mb = num_pages * bytes_per_page / (1024 * 1024); let scratch_mb = scratch_reserve / (1024 * 1024); + let dflash_fixed_mb = dflash_fixed_reserve / (1024 * 1024); info!( - "Qwen3.5 KV cache: {num_pages} pages ({kv_mb} MB), prefill scratch reserve: {scratch_mb} MB, {:.0}% of {:.0} MB free", + "Qwen3.5 KV cache: {num_pages} pages ({kv_mb} MB), prefill scratch reserve: {scratch_mb} MB, DFlash fixed reserve: {dflash_fixed_mb} MB, {:.0}% of {:.0} MB free", kv_budget as f64 / free_bytes as f64 * 100.0, free_bytes as f64 / 1024.0 / 1024.0 ); @@ -382,6 +403,18 @@ impl Qwen35Model { &self.ctx } + pub(crate) fn output_projection_tied(&self) -> &DeviceMatrix { + &self.embed_tokens + } + + pub(crate) fn get_embeddings_batch_into( + &self, + token_ids_gpu: &CudaSlice, + out: &mut openinfer_core::tensor::HiddenStates, + ) -> Result<()> { + crate::ops::embedding_batch(&self.ctx, &self.embed_tokens, token_ids_gpu, out) + } + pub(crate) fn alloc_kv(&self) -> openinfer_core::kv_pool::KvState { self.kv_pool.alloc() } diff --git a/openinfer-qwen35-4b/tests/dflash_speculative_gate.rs b/openinfer-qwen35-4b/tests/dflash_speculative_gate.rs new file mode 100644 index 00000000..7b14a110 --- /dev/null +++ b/openinfer-qwen35-4b/tests/dflash_speculative_gate.rs @@ -0,0 +1,514 @@ +//! Qwen3.5 DFlash scheduler losslessness gate. +//! +//! The live DFlash path must be opt-in and greedy-lossless over Qwen3.5's +//! hybrid target state: full-attention KV plus recurrent and convolution state. +//! Exact token equality is the first preference. At a divergence we use the +//! 
same regret rule as the Qwen3 gate: the speculative token must be either the +//! prefill-kernel greedy pick for the shared context or sit within a small +//! near-tie band in that prefill distribution. This avoids false failures from +//! the known decode-vs-prefill bf16 tie boundary while still catching real +//! verify/commit/capture bugs. + +use std::path::{Path, PathBuf}; +use std::time::Duration; + +use openinfer_core::engine::{ + EngineHandle, EngineLoadOptions, GenerateRequest, TokenEvent, TokenSink, +}; +use openinfer_core::sampler::SamplingParams; +use rand::rngs::StdRng; +use rand::{RngExt, SeedableRng}; +use vllm_text::tokenizer::DynTokenizer; + +mod common; + +const MODEL_PATH: &str = concat!(env!("CARGO_MANIFEST_DIR"), "/../models/Qwen3.5-4B"); +const DRAFT_PATH: &str = concat!(env!("CARGO_MANIFEST_DIR"), "/../models/Qwen3.5-4B-DFlash"); +const LOGPROBS: usize = 20; +const MARGIN_TOL: f32 = 0.20; +const MAX_BATCH: usize = 16; +const SYNTHETIC_TOKEN_LO: u32 = 100; +const SYNTHETIC_TOKEN_HI: u32 = 100_000; + +static GPU: std::sync::Mutex<()> = std::sync::Mutex::new(()); + +struct Step { + id: u32, + top_logprobs: Vec<(u32, f32)>, +} + +fn target_path_or_skip() -> Option { + match std::env::var("OPENINFER_TEST_MODEL_PATH") { + Ok(path) => Some(path), + Err(_) if Path::new(MODEL_PATH).join("config.json").exists() => { + Some(MODEL_PATH.to_string()) + } + Err(_) => { + eprintln!( + "skipping qwen35 DFlash gate: {MODEL_PATH}/config.json missing; set OPENINFER_TEST_MODEL_PATH" + ); + None + } + } +} + +fn draft_path_or_skip() -> Option { + match std::env::var("OPENINFER_DFLASH_TEST_MODEL_PATH") { + Ok(path) => Some(path), + Err(_) if Path::new(DRAFT_PATH).join("config.json").exists() => { + Some(DRAFT_PATH.to_string()) + } + Err(_) => { + eprintln!( + "skipping qwen35 DFlash gate: {DRAFT_PATH}/config.json missing; set OPENINFER_DFLASH_TEST_MODEL_PATH" + ); + None + } + } +} + +fn engine_options() -> EngineLoadOptions { + EngineLoadOptions { + 
enable_cuda_graph: true, + enable_prefill_profile: false, + device_ordinals: vec![0], + seed: 42, + ..EngineLoadOptions::default() + } +} + +fn launch(model_path: &str, draft_path: Option) -> EngineHandle { + openinfer_qwen35_4b::start_engine_with_capacity_and_dflash( + Path::new(model_path), + engine_options(), + MAX_BATCH, + openinfer_qwen35_4b::DEFAULT_MAX_PREFILL_TOKENS, + draft_path, + ) + .expect("failed to start Qwen3.5 engine") +} + +fn greedy_params() -> SamplingParams { + SamplingParams { + ignore_eos: true, + ..SamplingParams::default() + } +} + +fn generate( + handle: &EngineHandle, + prompt_tokens: Vec, + logprobs: usize, + max_tokens: usize, +) -> Vec { + let (token_tx, mut rx) = TokenSink::standalone(); + handle + .submit(GenerateRequest { + request_id: None, + queued_at_unix_s: None, + prompt_tokens, + params: greedy_params(), + max_tokens, + lora_adapter: None, + token_tx, + logprobs, + echo: false, + }) + .expect("submit failed"); + + collect_steps(&mut rx) +} + +fn generate_concurrent(handle: &EngineHandle, requests: Vec<(Vec, usize)>) -> Vec> { + let receivers: Vec<_> = requests + .into_iter() + .map(|(prompt_tokens, max_tokens)| { + let (token_tx, rx) = TokenSink::standalone(); + handle + .submit(GenerateRequest { + request_id: None, + queued_at_unix_s: None, + prompt_tokens, + params: greedy_params(), + max_tokens, + lora_adapter: None, + token_tx, + logprobs: 0, + echo: false, + }) + .expect("submit failed"); + rx + }) + .collect(); + + receivers + .into_iter() + .map(|mut rx| collect_steps(&mut rx)) + .collect() +} + +fn collect_steps(rx: &mut openinfer_core::engine::TokenStreamReceiver) -> Vec { + let mut steps = Vec::new(); + loop { + match rx.blocking_recv().map(|(_, event)| event) { + Some(TokenEvent::Token { id, logprob }) => steps.push(Step { + id, + top_logprobs: logprob.map(|lp| lp.top_logprobs).unwrap_or_default(), + }), + Some(TokenEvent::PromptTokens { .. } | TokenEvent::Scheduled { .. }) => {} + Some(TokenEvent::Finished { .. 
}) => return steps, + Some(TokenEvent::Error { message, .. }) => panic!("generation failed: {message}"), + Some(TokenEvent::Rejected { message, .. }) => panic!("generation rejected: {message}"), + None => panic!("scheduler channel closed without Finished"), + } + } +} + +fn prefill_next(handle: &EngineHandle, context: Vec) -> Step { + let (token_tx, mut rx) = TokenSink::standalone(); + handle + .submit(GenerateRequest { + request_id: None, + queued_at_unix_s: None, + prompt_tokens: context, + params: greedy_params(), + max_tokens: 1, + lora_adapter: None, + token_tx, + logprobs: LOGPROBS, + echo: true, + }) + .expect("submit failed"); + + loop { + match rx.blocking_recv().map(|(_, event)| event) { + Some(TokenEvent::Token { id, logprob }) => { + return Step { + id, + top_logprobs: logprob.map(|lp| lp.top_logprobs).unwrap_or_default(), + }; + } + Some(TokenEvent::PromptTokens { .. } | TokenEvent::Scheduled { .. }) => {} + Some(TokenEvent::Finished { .. }) => panic!("prefill_next finished without a token"), + Some(TokenEvent::Error { message, .. }) => panic!("prefill_next failed: {message}"), + Some(TokenEvent::Rejected { message, .. 
}) => { + panic!("prefill_next rejected: {message}") + } + None => panic!("scheduler channel closed without prefill_next token"), + } + } +} + +fn check_lossless( + handle: &EngineHandle, + tokenizer: &DynTokenizer, + label: &str, + prompt_tokens: &[u32], + base: &[Step], + spec: &[Step], +) -> Result<(), String> { + let matched = base + .iter() + .zip(spec.iter()) + .take_while(|(b, s)| b.id == s.id) + .count(); + if matched == base.len().min(spec.len()) { + eprintln!("{label}: {matched}/{} tokens identical", base.len()); + return Ok(()); + } + if matched >= spec.len() { + return Err(format!( + "{label}: speculative output ended before baseline at token {matched}" + )); + } + if base[matched].top_logprobs.is_empty() { + return Err(format!( + "{label}: missing baseline logprobs at first divergence {matched}" + )); + } + + let spec_id = spec[matched].id; + let decode_argmax = base[matched].top_logprobs[0].0; + let mut context = prompt_tokens.to_vec(); + context.extend(base[..matched].iter().map(|step| step.id)); + let prefill_ref = prefill_next(handle, context); + + if prefill_ref.id == spec_id { + eprintln!( + "{label}: prefill/decode kernel-gap flip at token {matched}; spec matches prefill greedy" + ); + return Ok(()); + } + + let prefill_regret = prefill_ref + .top_logprobs + .iter() + .find(|(token, _)| *token == spec_id) + .map(|(_, lp)| prefill_ref.top_logprobs[0].1 - lp); + if prefill_regret.is_some_and(|regret| regret <= MARGIN_TOL) { + eprintln!( + "{label}: near-tie at token {matched}; regret {:.3} <= {MARGIN_TOL}", + prefill_regret.unwrap() + ); + return Ok(()); + } + + let decode_regret = base[matched] + .top_logprobs + .iter() + .find(|(token, _)| *token == spec_id) + .map(|(_, lp)| base[matched].top_logprobs[0].1 - lp); + let lo = matched.saturating_sub(2); + let hi = (matched + 4).min(base.len()).min(spec.len()); + let base_ids: Vec = base[lo..hi].iter().map(|step| step.id).collect(); + let spec_ids: Vec = spec[lo..hi].iter().map(|step| 
step.id).collect(); + Err(format!( + "{label}: real divergence at token {matched}: spec={spec_id}, prefill_argmax={}, decode_argmax={decode_argmax}, prefill_regret={prefill_regret:?}, decode_regret={decode_regret:?}, base_window={:?} ({:?}), spec_window={:?} ({:?})", + prefill_ref.id, + base_ids, + tokenizer.decode(&base_ids, false).unwrap_or_default(), + spec_ids, + tokenizer.decode(&spec_ids, false).unwrap_or_default(), + )) +} + +fn long_prompt(seed: &str) -> String { + format!( + "{seed}\n{}\n{}\n{}\n{}\n{}\n{}", + "Explain the implementation carefully, include state ownership, scheduler batching, and why every accepted token must be target-verified.", + "Use compact technical prose with concrete examples and avoid changing topic.", + "Then continue with a deterministic continuation that has enough context for a draft model to use hidden-state features.", + "Repeat the key point: KV cache, recurrent state, and convolution state must move together.", + "Describe the rollback path, the replay path, the fixed verification buffers, and the reason heterogeneous output budgets can shorten verify spans.", + "Finally, restate the same mechanism in a second paragraph with slightly different wording so the prompt is long enough to exercise DFlash capture." 
+ ) +} + +fn synthetic_random_prompt(len: usize, seed: u64, request_idx: usize) -> Vec { + let mut rng = + StdRng::seed_from_u64(seed ^ (request_idx as u64).wrapping_mul(0x9E37_79B9_7F4A_7C15)); + (0..len) + .map(|_| rng.random_range(SYNTHETIC_TOKEN_LO..SYNTHETIC_TOKEN_HI)) + .collect() +} + +#[test] +fn qwen35_dflash_single_and_concurrent_greedy_are_lossless() { + let (Some(model_path), Some(draft_path)) = (target_path_or_skip(), draft_path_or_skip()) else { + return; + }; + let _gpu = GPU.lock().unwrap_or_else(|p| p.into_inner()); + let tokenizer = common::load_tokenizer(&model_path); + + let cases: Vec<(String, usize)> = vec![ + ( + long_prompt("Write a Rust function for paged attention."), + 48, + ), + (long_prompt("Summarize a GPU scheduler benchmark."), 32), + ( + long_prompt("Explain speculative decoding for a hybrid model."), + 40, + ), + ( + long_prompt("Draft a short guide for CUDA Graph verification."), + 24, + ), + ]; + let encoded: Vec> = cases + .iter() + .map(|(prompt, _)| tokenizer.encode(prompt, false).expect("encode failed")) + .collect(); + for (idx, tokens) in encoded.iter().enumerate() { + assert!( + tokens.len() >= 128, + "case {idx} prompt must exceed the DFlash capture threshold, got {} tokens", + tokens.len() + ); + } + + let baselines: Vec> = { + let handle = launch(&model_path, None); + let out = encoded + .iter() + .zip(cases.iter()) + .map(|(tokens, (_, max_tokens))| { + generate(&handle, tokens.clone(), LOGPROBS, *max_tokens) + }) + .collect(); + drop(handle); + std::thread::sleep(Duration::from_secs(2)); + out + }; + + let handle = launch(&model_path, Some(PathBuf::from(&draft_path))); + let single_spec = generate(&handle, encoded[0].clone(), 0, cases[0].1); + let concurrent_specs = generate_concurrent( + &handle, + encoded + .iter() + .zip(cases.iter()) + .map(|(tokens, (_, max_tokens))| (tokens.clone(), *max_tokens)) + .collect(), + ); + + let mut failures = Vec::new(); + if let Err(err) = check_lossless( + &handle, + &tokenizer, + 
"single", + &encoded[0], + &baselines[0], + &single_spec, + ) { + failures.push(err); + } + for (idx, spec) in concurrent_specs.iter().enumerate() { + if let Err(err) = check_lossless( + &handle, + &tokenizer, + &format!("concurrent-{idx}"), + &encoded[idx], + &baselines[idx], + spec, + ) { + failures.push(err); + } + } + drop(handle); + + assert!( + failures.is_empty(), + "Qwen3.5 DFlash speculative decode is not lossless:\n{}", + failures.join("\n") + ); +} + +#[test] +fn qwen35_dflash_short_prompt_concurrent_random_is_within_oracle() { + let (Some(model_path), Some(draft_path)) = (target_path_or_skip(), draft_path_or_skip()) else { + return; + }; + let _gpu = GPU.lock().unwrap_or_else(|p| p.into_inner()); + let tokenizer = common::load_tokenizer(&model_path); + let output_len = 256; + let prompts: Vec> = (0..MAX_BATCH) + .map(|idx| synthetic_random_prompt(1, 0, idx)) + .collect(); + + let baselines: Vec> = { + let handle = launch(&model_path, None); + let out = prompts + .iter() + .map(|tokens| generate(&handle, tokens.clone(), LOGPROBS, output_len)) + .collect(); + drop(handle); + std::thread::sleep(Duration::from_secs(2)); + out + }; + + let handle = launch(&model_path, Some(PathBuf::from(&draft_path))); + let specs = generate_concurrent( + &handle, + prompts + .iter() + .map(|tokens| (tokens.clone(), output_len)) + .collect(), + ); + + let mut failures = Vec::new(); + for (idx, spec) in specs.iter().enumerate() { + if let Err(err) = check_lossless( + &handle, + &tokenizer, + &format!("short-random-c{MAX_BATCH}-{idx}"), + &prompts[idx], + &baselines[idx], + spec, + ) { + failures.push(err); + } + } + drop(handle); + + assert!( + failures.is_empty(), + "Qwen3.5 DFlash short-prompt concurrent decode is outside the oracle:\n{}", + failures.join("\n") + ); +} + +fn check_random_concurrent_case( + model_path: &str, + draft_path: &str, + tokenizer: &DynTokenizer, + prompt_len: usize, + concurrency: usize, +) -> Vec { + let output_len = 256; + let prompts: Vec> = 
(0..concurrency) + .map(|idx| synthetic_random_prompt(prompt_len, 42, idx)) + .collect(); + + let baselines: Vec> = { + let handle = launch(model_path, None); + let out = prompts + .iter() + .map(|tokens| generate(&handle, tokens.clone(), LOGPROBS, output_len)) + .collect(); + drop(handle); + std::thread::sleep(Duration::from_secs(2)); + out + }; + + let handle = launch(model_path, Some(PathBuf::from(draft_path))); + let specs = generate_concurrent( + &handle, + prompts + .iter() + .map(|tokens| (tokens.clone(), output_len)) + .collect(), + ); + + let mut failures = Vec::new(); + for (idx, spec) in specs.iter().enumerate() { + if let Err(err) = check_lossless( + &handle, + tokenizer, + &format!("bench-random-p{prompt_len}-c{concurrency}-{idx}"), + &prompts[idx], + &baselines[idx], + spec, + ) { + failures.push(err); + } + } + drop(handle); + failures +} + +#[test] +fn qwen35_dflash_benchmark_random_concurrency_is_within_oracle() { + let (Some(model_path), Some(draft_path)) = (target_path_or_skip(), draft_path_or_skip()) else { + return; + }; + let _gpu = GPU.lock().unwrap_or_else(|p| p.into_inner()); + let tokenizer = common::load_tokenizer(&model_path); + + let mut failures = Vec::new(); + for (prompt_len, concurrency) in [(1024, 16), (4096, 8), (4096, 16)] { + failures.extend(check_random_concurrent_case( + &model_path, + &draft_path, + &tokenizer, + prompt_len, + concurrency, + )); + } + + assert!( + failures.is_empty(), + "Qwen3.5 DFlash benchmark-shaped concurrent decode is outside the oracle:\n{}", + failures.join("\n") + ); +} diff --git a/openinfer-qwen35-4b/tests/dflash_speculative_perf.rs b/openinfer-qwen35-4b/tests/dflash_speculative_perf.rs new file mode 100644 index 00000000..9813ae06 --- /dev/null +++ b/openinfer-qwen35-4b/tests/dflash_speculative_perf.rs @@ -0,0 +1,162 @@ +//! DFlash speculative-decoding single-stream latency A/B for Qwen3.5. +//! +//! This mirrors the Qwen3 DFlash perf harness: fixed 256-token greedy decode, +//! 
speculative OFF vs ON, same prompts and hardware, one warm-up discarded. +//! Qwen3.5 is a hybrid 24-linear + 8-full-attention model, so this harness is +//! the explicit evidence source for the single-stream boundary instead of +//! inferring it from the concurrent throughput sweep. +//! +//! Requires CUDA, Qwen3.5 target weights, and the Qwen3.5 DFlash drafter. Set +//! `OPENINFER_TEST_MODEL_PATH` and `OPENINFER_DFLASH_TEST_MODEL_PATH`; skips +//! when either model is unavailable. Use `--nocapture` to read the numbers. + +use std::path::{Path, PathBuf}; +use std::time::{Duration, Instant}; + +use openinfer_core::engine::{ + EngineHandle, EngineLoadOptions, GenerateRequest, TokenEvent, TokenSink, +}; +use openinfer_core::sampler::SamplingParams; + +mod common; + +const MODEL_PATH: &str = concat!(env!("CARGO_MANIFEST_DIR"), "/../models/Qwen3.5-4B"); +const DRAFT_PATH: &str = concat!(env!("CARGO_MANIFEST_DIR"), "/../models/Qwen3.5-4B-DFlash"); +const GENERATED_TOKENS: usize = 256; +const MAX_BATCH: usize = 16; + +static GPU: std::sync::Mutex<()> = std::sync::Mutex::new(()); + +fn target_path_or_skip() -> Option<String> { + match std::env::var("OPENINFER_TEST_MODEL_PATH") { + Ok(path) => Some(path), + Err(_) if Path::new(MODEL_PATH).join("config.json").exists() => { + Some(MODEL_PATH.to_string()) + } + Err(_) => { + eprintln!("skipping qwen35 DFlash perf A/B: set OPENINFER_TEST_MODEL_PATH"); + None + } + } +} + +fn draft_path_or_skip() -> Option<String> { + match std::env::var("OPENINFER_DFLASH_TEST_MODEL_PATH") { + Ok(path) => Some(path), + Err(_) if Path::new(DRAFT_PATH).join("config.json").exists() => { + Some(DRAFT_PATH.to_string()) + } + Err(_) => { + eprintln!("skipping qwen35 DFlash perf A/B: set OPENINFER_DFLASH_TEST_MODEL_PATH"); + None + } + } +} + +fn engine_options() -> EngineLoadOptions { + EngineLoadOptions { + enable_cuda_graph: true, + enable_prefill_profile: false, + device_ordinals: vec![0], + seed: 42, + ..EngineLoadOptions::default() + } +} + +fn 
launch(model_path: &str, draft_path: Option<PathBuf>) -> EngineHandle { + openinfer_qwen35_4b::start_engine_with_capacity_and_dflash( + Path::new(model_path), + engine_options(), + MAX_BATCH, + openinfer_qwen35_4b::DEFAULT_MAX_PREFILL_TOKENS, + draft_path, + ) + .expect("failed to start Qwen3.5 engine") +} + +fn timed_generate(handle: &EngineHandle, prompt_tokens: Vec<u32>) -> (usize, Duration) { + let (token_tx, mut rx) = TokenSink::standalone(); + let start = Instant::now(); + handle + .submit(GenerateRequest { + request_id: None, + queued_at_unix_s: None, + prompt_tokens, + params: SamplingParams { + ignore_eos: true, + ..SamplingParams::default() + }, + max_tokens: GENERATED_TOKENS, + lora_adapter: None, + token_tx, + logprobs: 0, + echo: false, + }) + .expect("submit failed"); + + let mut count = 0usize; + loop { + match rx.blocking_recv().map(|(_, event)| event) { + Some(TokenEvent::Token { .. }) => count += 1, + Some(TokenEvent::Scheduled { .. } | TokenEvent::PromptTokens { .. }) => {} + Some(TokenEvent::Finished { .. }) => return (count, start.elapsed()), + Some(TokenEvent::Error { message, .. }) => panic!("generation failed: {message}"), + Some(TokenEvent::Rejected { message, .. 
}) => panic!("generation rejected: {message}"), + None => panic!("scheduler channel closed without Finished"), + } + } +} + +fn measure(handle: &EngineHandle, prompts: &[Vec<u32>]) -> f64 { + let _ = timed_generate(handle, prompts[0].clone()); + let mut tokens = 0usize; + let mut elapsed = Duration::ZERO; + for prompt in prompts { + let (n, dt) = timed_generate(handle, prompt.clone()); + tokens += n; + elapsed += dt; + } + tokens as f64 / elapsed.as_secs_f64() +} + +#[test] +fn qwen35_dflash_single_stream_speedup() { + let (Some(model_path), Some(draft_path)) = (target_path_or_skip(), draft_path_or_skip()) else { + return; + }; + let _gpu = GPU.lock().unwrap_or_else(|p| p.into_inner()); + let tokenizer = common::load_tokenizer(&model_path); + let prompts: Vec<Vec<u32>> = [ + "Write a compact explanation of Qwen3.5 hybrid recurrent state.", + "Explain speculative decoding for a model with recurrent and KV state.", + "List the tradeoffs of batched verification in a serving scheduler.", + ] + .iter() + .map(|prompt| tokenizer.encode(prompt, false).expect("encode failed")) + .collect(); + + let baseline_tps = { + let handle = launch(&model_path, None); + let tps = measure(&handle, &prompts); + drop(handle); + std::thread::sleep(Duration::from_secs(2)); + tps + }; + + let spec_tps = { + let handle = launch(&model_path, Some(PathBuf::from(&draft_path))); + measure(&handle, &prompts) + }; + + let speedup = spec_tps / baseline_tps; + eprintln!("──────── Qwen3.5 DFlash single-stream decode A/B (bs=1) ────────"); + eprintln!(" spec OFF (plain decode): {baseline_tps:7.1} tok/s"); + eprintln!(" spec ON (DFlash): {spec_tps:7.1} tok/s"); + eprintln!(" speedup: {speedup:7.2}×"); + eprintln!("────────────────────────────────────────────────────────────────"); + + assert!( + speedup > 0.8, + "Qwen3.5 DFlash single-stream is catastrophically slower ({speedup:.2}×)" + ); +} diff --git a/openinfer-qwen35-4b/tests/speculative_verify.rs b/openinfer-qwen35-4b/tests/speculative_verify.rs new file mode 
100644 index 00000000..59e9de89 --- /dev/null +++ b/openinfer-qwen35-4b/tests/speculative_verify.rs @@ -0,0 +1,289 @@ +use std::path::Path; + +use openinfer_qwen35_4b::runtime::{ + DecodePlan, DecodeStepItem, PrefillPlan, PrefillStepItem, Qwen35Executor, RequestId, + VerifiedToken, VerifyPlan, VerifyStepItem, +}; + +const MODEL_PATH: &str = concat!(env!("CARGO_MANIFEST_DIR"), "/../models/Qwen3.5-4B"); +const LOGPROBS: usize = 1; + +#[derive(Clone)] +struct CaseSpec { + request_id: RequestId, + prompt_tokens: Vec, + draft_len: usize, + reject_at: Option, +} + +#[derive(Clone)] +struct CaseExpectation { + request_id: RequestId, + first_token: u32, + draft_tokens: Vec, + expected_matched: usize, + expected_accepted: Vec, + followup_token: u32, +} + +fn model_path_or_skip() -> Option { + match std::env::var("OPENINFER_TEST_MODEL_PATH") { + Ok(path) => Some(path), + Err(_) if Path::new(MODEL_PATH).join("config.json").exists() => { + Some(MODEL_PATH.to_string()) + } + Err(_) => { + eprintln!( + "skipping qwen35 speculative_verify: {MODEL_PATH}/config.json is missing; set OPENINFER_TEST_MODEL_PATH to run it" + ); + None + } + } +} + +fn build_executor(model_path: &str, capacity: usize) -> Qwen35Executor { + let capacity = [1usize, 2, 4, 8, 16, 32, 64] + .into_iter() + .find(|bucket| *bucket >= capacity) + .expect("test batch exceeds Qwen3.5 decode bucket capacity"); + Qwen35Executor::from_runtime_with_capacity(model_path, false, &[0], capacity) + .expect("load Qwen3.5 executor") +} + +fn prefill(exec: &mut Qwen35Executor, cases: &[CaseSpec]) -> Vec { + let reqs: Vec<_> = cases + .iter() + .map(|case| PrefillStepItem::new(case.request_id, case.prompt_tokens.clone(), LOGPROBS)) + .collect(); + exec.execute_prefill(PrefillPlan { requests: &reqs }) + .expect("prefill") + .requests + .into_iter() + .map(|result| result.first_token) + .collect() +} + +fn decode_once(exec: &mut Qwen35Executor, tokens: &[u32], cases: &[CaseSpec]) -> Vec { + let reqs: Vec<_> = cases + .iter() + 
.zip(tokens.iter()) + .map(|(case, &token)| DecodeStepItem::new(case.request_id, token, LOGPROBS)) + .collect(); + exec.execute_decode(DecodePlan { requests: &reqs }) + .expect("decode") + .requests + .into_iter() + .map(|result| result.token) + .collect() +} + +fn build_expectations(model_path: &str, cases: &[CaseSpec]) -> Vec<CaseExpectation> { + let mut exec = build_executor(model_path, cases.len()); + let first_tokens = prefill(&mut exec, cases); + let max_len = cases + .iter() + .map(|case| case.draft_len + 3) + .max() + .expect("at least one case"); + let mut generated: Vec<Vec<u32>> = first_tokens.into_iter().map(|token| vec![token]).collect(); + while generated.iter().any(|tokens| tokens.len() < max_len) { + let fed: Vec<u32> = generated + .iter() + .map(|tokens| *tokens.last().expect("prefill token")) + .collect(); + for (tokens, next) in generated + .iter_mut() + .zip(decode_once(&mut exec, &fed, cases).into_iter()) + { + tokens.push(next); + } + } + + cases + .iter() + .zip(generated.iter()) + .map(|(case, generated)| { + let first = generated[0]; + + let mut draft_tokens = generated[1..=case.draft_len].to_vec(); + let mut expected_matched = case.draft_len; + if let Some(reject_at) = case.reject_at { + draft_tokens[reject_at] = draft_tokens[reject_at].wrapping_add(17); + expected_matched = reject_at; + } + let mut accepted_ids = draft_tokens[..expected_matched].to_vec(); + accepted_ids.push(generated[expected_matched + 1]); + + let mut expected_accepted = Vec::with_capacity(accepted_ids.len()); + for expected in &accepted_ids { + expected_accepted.push(VerifiedToken { + token: *expected, + logprob: None, + }); + } + let followup_token = generated[expected_matched + 2]; + + CaseExpectation { + request_id: case.request_id, + first_token: first, + draft_tokens, + expected_matched, + expected_accepted, + followup_token, + } + }) + .collect() +} + +fn run_speculative_case(model_path: &str, cases: Vec<CaseSpec>) { + let expectations = build_expectations(model_path, &cases); + let mut exec = 
build_executor(model_path, cases.len()); + let first_tokens = prefill(&mut exec, &cases); + assert_eq!( + first_tokens, + expectations + .iter() + .map(|expect| expect.first_token) + .collect::<Vec<_>>() + ); + let before_state = exec.debug_state_summary(); + + let verify_items: Vec<_> = expectations + .iter() + .map(|expect| { + let mut token_ids = Vec::with_capacity(expect.draft_tokens.len() + 1); + token_ids.push(expect.first_token); + token_ids.extend_from_slice(&expect.draft_tokens); + VerifyStepItem::new(expect.request_id, token_ids, LOGPROBS) + }) + .collect(); + let result = exec + .execute_speculative_verify(VerifyPlan { + requests: &verify_items, + }) + .expect("speculative verify"); + let after_state = exec.debug_state_summary(); + + assert_eq!(result.requests.len(), expectations.len()); + for (row, expect) in result.requests.iter().zip(expectations.iter()) { + assert_eq!(row.request_id, expect.request_id); + assert_eq!(row.matched_draft_tokens, expect.expected_matched); + assert_eq!( + row.accepted_tokens + .iter() + .map(|token| token.token) + .collect::<Vec<_>>(), + expect + .expected_accepted + .iter() + .map(|token| token.token) + .collect::<Vec<_>>() + ); + } + for ((before, after), expect) in before_state + .iter() + .zip(after_state.iter()) + .zip(expectations.iter()) + { + assert_eq!(before.request_id, after.request_id); + assert_eq!( + before.kv_seq_len + expect.expected_accepted.len(), + after.kv_seq_len + ); + assert_eq!( + before.recurrent_seq_len + expect.expected_accepted.len(), + after.recurrent_seq_len + ); + } + + let last_tokens: Vec<u32> = result + .requests + .iter() + .map(|row| row.accepted_tokens.last().expect("accepted token").token) + .collect(); + let followup_reqs: Vec<_> = cases + .iter() + .zip(last_tokens.iter()) + .map(|(case, &token)| DecodeStepItem::new(case.request_id, token, LOGPROBS)) + .collect(); + let followup = exec + .execute_decode(DecodePlan { + requests: &followup_reqs, + }) + .expect("post-spec followup") + .requests; + for (actual, 
expect) in followup.iter().zip(expectations.iter()) { + assert_eq!(actual.request_id, expect.request_id); + assert_eq!(actual.token, expect.followup_token); + } +} + +#[test] +fn qwen35_speculative_accept_all_state_matches_plain_decode() { + let Some(model_path) = model_path_or_skip() else { + return; + }; + run_speculative_case( + &model_path, + vec![CaseSpec { + request_id: RequestId::new(1), + prompt_tokens: vec![9707], + draft_len: 3, + reject_at: None, + }], + ); +} + +#[test] +fn qwen35_speculative_accept_prefix_and_reject_first() { + let Some(model_path) = model_path_or_skip() else { + return; + }; + run_speculative_case( + &model_path, + vec![ + CaseSpec { + request_id: RequestId::new(1), + prompt_tokens: vec![3838, 374, 220, 17, 10, 17], + draft_len: 4, + reject_at: Some(2), + }, + CaseSpec { + request_id: RequestId::new(2), + prompt_tokens: vec![9707], + draft_len: 4, + reject_at: Some(0), + }, + ], + ); +} + +#[test] +fn qwen35_speculative_mixed_batch_state_matches_plain_decode() { + let Some(model_path) = model_path_or_skip() else { + return; + }; + run_speculative_case( + &model_path, + vec![ + CaseSpec { + request_id: RequestId::new(1), + prompt_tokens: vec![9707], + draft_len: 2, + reject_at: None, + }, + CaseSpec { + request_id: RequestId::new(2), + prompt_tokens: vec![3838, 374, 220, 17, 10, 17], + draft_len: 4, + reject_at: Some(1), + }, + CaseSpec { + request_id: RequestId::new(3), + prompt_tokens: vec![785, 9282, 374, 3565], + draft_len: 3, + reject_at: Some(0), + }, + ], + ); +} diff --git a/openinfer-server/src/bin/bench_serving/cli.rs b/openinfer-server/src/bin/bench_serving/cli.rs index 77a182e9..3b15ea93 100644 --- a/openinfer-server/src/bin/bench_serving/cli.rs +++ b/openinfer-server/src/bin/bench_serving/cli.rs @@ -154,6 +154,10 @@ pub(crate) struct Cli { #[arg(long)] pub(crate) max_prefill_tokens: Option, + /// Enable Qwen3/Qwen3.5 DFlash speculative decoding with this drafter model path. 
+ #[arg(long = "dflash-draft-model-path")] + pub(crate) dflash_draft_model_path: Option, + #[command(subcommand)] pub(crate) command: Command, } diff --git a/openinfer-server/src/bin/bench_serving/exec.rs b/openinfer-server/src/bin/bench_serving/exec.rs index ad93a982..aa110a99 100644 --- a/openinfer-server/src/bin/bench_serving/exec.rs +++ b/openinfer-server/src/bin/bench_serving/exec.rs @@ -9,6 +9,7 @@ use rand::rngs::StdRng; use openinfer::sampler::SamplingParams; use openinfer::scheduler::{SchedulerHandle, SchedulerRequest, TokenEvent, TokenSink}; +use openinfer_core::engine::TokenStreamReceiver; pub(crate) struct GenTimings { pub(crate) ttft: Duration, @@ -155,6 +156,57 @@ pub(crate) fn run_scheduler_stream( } } +fn drain_timed_scheduler_stream( + mut token_rx: TokenStreamReceiver, + start: Instant, + max_new_tokens: usize, +) -> Result { + let mut first_at: Option = None; + let mut prev_at: Option = None; + let mut emitted_tokens = 0usize; + let mut tbt = Vec::with_capacity(max_new_tokens.saturating_sub(1)); + let mut generated_tokens = Vec::with_capacity(max_new_tokens); + + loop { + match token_rx.blocking_recv().map(|(_, event)| event) { + Some(TokenEvent::Token { id, .. }) => { + let now = Instant::now(); + emitted_tokens += 1; + generated_tokens.push(id); + if first_at.is_none() { + first_at = Some(now); + } else if let Some(prev) = prev_at { + tbt.push(now - prev); + } + prev_at = Some(now); + } + Some(TokenEvent::PromptTokens { .. } | TokenEvent::Scheduled { .. }) => {} + Some(TokenEvent::Finished { .. }) => { + let total = start.elapsed(); + let ttft = first_at.map_or(total, |t| t - start); + let decode_tokens_for_rate = emitted_tokens.saturating_sub(1); + let decode_time_for_rate = tbt.iter().copied().sum(); + return Ok(GenTimings { + ttft, + tbt, + total, + emitted_tokens, + generated_tokens, + decode_tokens_for_rate, + decode_time_for_rate, + }); + } + Some(TokenEvent::Error { message, .. 
}) => { + anyhow::bail!("scheduler request failed: {message}"); + } + Some(TokenEvent::Rejected { message, .. }) => { + anyhow::bail!("scheduler request rejected: {message}"); + } + None => anyhow::bail!("scheduler channel closed"), + } + } +} + pub(crate) struct SchedulerBenchModel { pub(crate) handle: SchedulerHandle, } @@ -186,27 +238,31 @@ impl BenchModel for SchedulerBenchModel { ) -> Vec { let mut workers = Vec::with_capacity(prompts.len()); for (idx, prompt) in prompts.iter().enumerate() { - let handle = self.handle.clone(); - let prompt_tokens = prompt.clone(); - let sampling = *sampling; - workers.push(thread::spawn(move || { - run_timed(&prompt_tokens, max_new_tokens, |toks, n, cb| { - run_scheduler_stream( - &handle, - Some(format!("bench-serving-{idx}")), - toks.to_vec(), - sampling, - n, - |id| cb(id), - )?; - Ok(()) + let (token_tx, token_rx) = TokenSink::standalone(); + let start = Instant::now(); + let worker = thread::spawn(move || { + drain_timed_scheduler_stream(token_rx, start, max_new_tokens) + .expect("generation failed") + }); + self.handle + .submit(SchedulerRequest { + request_id: Some(format!("bench-serving-{idx}")), + queued_at_unix_s: None, + prompt_tokens: prompt.clone(), + params: *sampling, + max_tokens: max_new_tokens, + lora_adapter: None, + token_tx, + logprobs: 0, + echo: false, }) - })); + .expect("scheduler submit failed"); + workers.push(worker); } workers .into_iter() - .map(|worker| worker.join().expect("bench request worker panicked")) + .map(|worker| worker.join().expect("bench drain worker panicked")) .collect() } } diff --git a/openinfer-server/src/bin/bench_serving/main.rs b/openinfer-server/src/bin/bench_serving/main.rs index ff2bc3ba..e6459c60 100644 --- a/openinfer-server/src/bin/bench_serving/main.rs +++ b/openinfer-server/src/bin/bench_serving/main.rs @@ -21,6 +21,7 @@ )] use std::path::Path; +use std::path::PathBuf; use std::time::Instant; use anyhow::{Context, Result}; @@ -195,6 +196,7 @@ fn main() -> Result<()> { 
.max_prefill_tokens .filter(|&v| v > 0) .unwrap_or(openinfer_qwen3::DEFAULT_MAX_PREFILL_TOKENS); + let dflash_draft_model_path = cli.dflash_draft_model_path.as_ref().map(PathBuf::from); let handle = openinfer_qwen3::start_engine_with_offload( Path::new(&cli.model_path), EngineLoadOptions { @@ -211,7 +213,7 @@ fn main() -> Result<()> { openinfer_qwen3::Qwen3MemoryOptions::default(), openinfer_qwen3::DecodeOverlap::Off, false, - None, + dflash_draft_model_path.as_deref(), false, )?; finish(handle, cli.cuda_graph) @@ -224,7 +226,8 @@ fn main() -> Result<()> { .max_prefill_tokens .filter(|&v| v > 0) .unwrap_or(openinfer_qwen35_4b::DEFAULT_MAX_PREFILL_TOKENS); - let handle = openinfer_qwen35_4b::start_engine_with_capacity( + let dflash_draft_model_path = cli.dflash_draft_model_path.as_ref().map(PathBuf::from); + let handle = openinfer_qwen35_4b::start_engine_with_capacity_and_dflash( Path::new(&cli.model_path), EngineLoadOptions { enable_cuda_graph: cli.cuda_graph, @@ -234,8 +237,9 @@ fn main() -> Result<()> { ep_backend: EpBackend::Nccl, seed: command_seed(&cli), }, - 4, + openinfer_qwen35_4b::runtime::MAX_BATCH, max_prefill_tokens, + dflash_draft_model_path, )?; finish(handle, cli.cuda_graph) } @@ -330,6 +334,7 @@ mod tests { let metrics = build_request_metrics(&timings); assert_eq!(metrics.steady_tpot_ms.unwrap().p50_ms, 18.0); + assert_eq!(metrics.effective_tpot_ms.unwrap().p50_ms, 9.25); assert!( metrics.decode_tok_s.unwrap() > 100.0, "batched decode tok/s should use one shared step duration instead of duplicating it per row" diff --git a/openinfer-server/src/bin/bench_serving/metrics.rs b/openinfer-server/src/bin/bench_serving/metrics.rs index 98a1ab02..934d7801 100644 --- a/openinfer-server/src/bin/bench_serving/metrics.rs +++ b/openinfer-server/src/bin/bench_serving/metrics.rs @@ -66,9 +66,13 @@ pub(crate) fn generated_token_hash(tokens: &[u32]) -> String { } pub(crate) fn generated_token_trace(tokens: &[u32]) -> GeneratedTokenTrace { + let full = 
std::env::var_os("OPENINFER_BENCH_FULL_TOKEN_TRACE") + .is_some() + .then(|| tokens.to_vec()); GeneratedTokenTrace { hash: generated_token_hash(tokens), prefix: tokens.iter().copied().take(16).collect(), + full, len: tokens.len(), } } diff --git a/openinfer-server/src/bin/bench_serving/report.rs b/openinfer-server/src/bin/bench_serving/report.rs index c41c91b4..b8060b45 100644 --- a/openinfer-server/src/bin/bench_serving/report.rs +++ b/openinfer-server/src/bin/bench_serving/report.rs @@ -41,6 +41,8 @@ pub(crate) struct CountStats { pub(crate) struct GeneratedTokenTrace { pub(crate) hash: String, pub(crate) prefix: Vec, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub(crate) full: Option>, pub(crate) len: usize, } @@ -58,6 +60,11 @@ pub(crate) struct RequestWorkload { pub(crate) struct RequestMetrics { pub(crate) ttft_ms: DurationStats, pub(crate) first_decode_step_ms: Option, + /// Per-request decode elapsed time divided by generated decode tokens. + /// This complements raw token-event intervals for speculative decoding, + /// where accepted spans can be emitted as a burst from one verify round. 
+ #[serde(default, skip_serializing_if = "Option::is_none")] + pub(crate) effective_tpot_ms: Option, pub(crate) steady_tpot_ms: Option, pub(crate) e2e_ms: DurationStats, pub(crate) generated_tokens: CountStats, @@ -72,6 +79,8 @@ pub(crate) struct RequestIterationTiming { pub(crate) index: usize, pub(crate) ttft_ms: f64, pub(crate) first_decode_step_ms: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub(crate) effective_tpot_ms: Option, pub(crate) steady_tpot_ms: Option, pub(crate) e2e_ms: f64, pub(crate) generated_tokens: usize, @@ -204,6 +213,8 @@ pub(crate) struct MatrixCell { pub(crate) ttft_ms: DurationStats, pub(crate) e2e_ms: DurationStats, pub(crate) first_decode_step_ms: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub(crate) effective_tpot_ms: Option, pub(crate) steady_tpot_ms: Option, pub(crate) generated_tokens: CountStats, pub(crate) request_tok_s: Option, diff --git a/openinfer-server/src/bin/bench_serving/runners.rs b/openinfer-server/src/bin/bench_serving/runners.rs index 358c64cd..877fa963 100644 --- a/openinfer-server/src/bin/bench_serving/runners.rs +++ b/openinfer-server/src/bin/bench_serving/runners.rs @@ -91,6 +91,13 @@ pub(crate) fn build_request_metrics(timings: &[GenTimings]) -> RequestMetrics { .iter() .flat_map(|t| t.tbt.iter().skip(1).copied()) .collect(); + let effective_tpot: Vec = timings + .iter() + .filter_map(|timing| { + (timing.decode_tokens_for_rate > 0) + .then(|| timing.decode_time_for_rate / timing.decode_tokens_for_rate as u32) + }) + .collect(); let generated: Vec = timings.iter().map(|t| t.emitted_tokens).collect(); let generated_token_traces: Vec = timings .iter() @@ -105,6 +112,8 @@ pub(crate) fn build_request_metrics(timings: &[GenTimings]) -> RequestMetrics { RequestMetrics { ttft_ms: summarize_durations(&ttfts), first_decode_step_ms: (!first_steps.is_empty()).then(|| summarize_durations(&first_steps)), + effective_tpot_ms: (!effective_tpot.is_empty()) + .then(|| 
summarize_durations(&effective_tpot)), steady_tpot_ms: (!steady.is_empty()).then(|| summarize_durations(&steady)), e2e_ms: summarize_durations(&e2e), generated_tokens: summarize_counts(&generated), @@ -124,6 +133,9 @@ pub(crate) fn build_request_iterations(timings: &[GenTimings]) -> Vec 0).then(|| { + dur_ms(timing.decode_time_for_rate) / timing.decode_tokens_for_rate as f64 + }), steady_tpot_ms: (!steady.is_empty()).then(|| summarize_durations(&steady)), e2e_ms: dur_ms(timing.total), generated_tokens: timing.emitted_tokens, @@ -258,6 +270,7 @@ pub(crate) fn bench_matrix( ttft_ms: metrics.ttft_ms, e2e_ms: metrics.e2e_ms, first_decode_step_ms: metrics.first_decode_step_ms, + effective_tpot_ms: metrics.effective_tpot_ms, steady_tpot_ms: metrics.steady_tpot_ms, generated_tokens: metrics.generated_tokens, request_tok_s: metrics.request_tok_s, @@ -378,6 +391,14 @@ pub(crate) fn render_text(report: &BenchReport) -> String { .into_iter() .map(|stats| ("first_decode_step_ms".to_string(), stats)), ) + .chain( + report + .metrics + .effective_tpot_ms + .clone() + .into_iter() + .map(|stats| ("effective_tpot_ms".to_string(), stats)), + ) .chain( report .metrics diff --git a/openinfer-server/src/config.rs b/openinfer-server/src/config.rs index 25dc9440..7af73b80 100644 --- a/openinfer-server/src/config.rs +++ b/openinfer-server/src/config.rs @@ -90,9 +90,9 @@ pub(crate) struct Args { #[arg(long, default_value_t = false)] pub no_prefix_cache: bool, - /// Enable Qwen3 DFlash speculative decoding with this drafter model path. - /// Single-GPU greedy only; incompatible with --enable-lora and --kv-offload, - /// and forces the prefix cache off (it needs clean target hidden states). + /// Enable Qwen3/Qwen3.5 DFlash speculative decoding with this drafter model path. + /// Single-GPU greedy only; incompatible with --enable-lora, --kv-offload, + /// tensor parallelism, and decode overlap. 
#[arg(long = "dflash-draft-model-path")] pub dflash_draft_model_path: Option, @@ -149,7 +149,7 @@ impl From for EpBackend { /// CLI selector for prefill/decode overlap. Mapped to /// [`openinfer_qwen3::DecodeOverlap`] together with `--decode-sm-pct`. -#[derive(Clone, Copy, Debug, ValueEnum)] +#[derive(Clone, Copy, Debug, Eq, PartialEq, ValueEnum)] pub(crate) enum CliDecodeOverlap { /// One stream; prefill and decode serialize. Off, @@ -183,9 +183,16 @@ impl Args { let is_qwen3 = matches!(model_type, ModelType::Qwen3); #[cfg(not(feature = "qwen3"))] let is_qwen3 = false; + #[cfg(feature = "qwen35-4b")] + let is_qwen35 = matches!(model_type, ModelType::Qwen35); + #[cfg(not(feature = "qwen35-4b"))] + let is_qwen35 = false; if self.enable_lora && !is_qwen3 { bail!("--enable-lora is currently supported only for Qwen3"); } + if self.dflash_draft_model_path.is_some() && !(is_qwen3 || is_qwen35) { + bail!("--dflash-draft-model-path is currently supported only for Qwen3/Qwen3.5"); + } if self.batch_invariant && !is_qwen3 { bail!("--batch-invariant is currently supported only for Qwen3"); } diff --git a/openinfer-server/src/main.rs b/openinfer-server/src/main.rs index 64becf61..51b9aaf7 100644 --- a/openinfer-server/src/main.rs +++ b/openinfer-server/src/main.rs @@ -109,15 +109,19 @@ async fn main() -> anyhow::Result<()> { // defaults, capability constraints, cross-arg validation). The server only // picks the crate by detected model type and forwards the relevant CLI knobs. fn load_engine(args: &Args, model_type: ModelType) -> anyhow::Result { - // Only Qwen3 wires the DFlash drafter; fail loud rather than silently - // ignoring the flag for another model line. + // Only model lines with a wired drafter may accept DFlash; fail loud rather + // than silently ignoring the flag for another model line. 
#[cfg(feature = "qwen3")] let is_qwen3 = matches!(model_type, ModelType::Qwen3); #[cfg(not(feature = "qwen3"))] let is_qwen3 = false; + #[cfg(feature = "qwen35-4b")] + let is_qwen35 = matches!(model_type, ModelType::Qwen35); + #[cfg(not(feature = "qwen35-4b"))] + let is_qwen35 = false; anyhow::ensure!( - args.dflash_draft_model_path.is_none() || is_qwen3, - "--dflash-draft-model-path is only supported for Qwen3 (got {model_type:?})" + args.dflash_draft_model_path.is_none() || is_qwen3 || is_qwen35, + "--dflash-draft-model-path is only supported for Qwen3/Qwen3.5 (got {model_type:?})" ); let handle = match model_type { #[cfg(feature = "deepseek-v4")] @@ -219,14 +223,39 @@ fn load_engine(args: &Args, model_type: ModelType) -> anyhow::Result openinfer_qwen35_4b::launch( - &args.model_path, - args.device_ordinal, - args.cuda_graph, - args.max_prefill_tokens - .unwrap_or(openinfer_qwen35_4b::DEFAULT_MAX_PREFILL_TOKENS), - ) - .context("failed to start Qwen3.5 engine")?, + ModelType::Qwen35 => { + let dflash_draft_model_path = match args.dflash_draft_model_path.clone() { + Some(path) => { + anyhow::ensure!( + !args.enable_lora, + "--dflash-draft-model-path is not supported with --enable-lora" + ); + anyhow::ensure!( + !args.kv_offload, + "--dflash-draft-model-path is not supported with --kv-offload" + ); + anyhow::ensure!( + args.tp_size == 1, + "--dflash-draft-model-path currently requires --tp-size=1" + ); + anyhow::ensure!( + matches!(args.decode_overlap, config::CliDecodeOverlap::Off), + "--dflash-draft-model-path is not supported with --decode-overlap" + ); + Some(path) + } + None => None, + }; + openinfer_qwen35_4b::launch( + &args.model_path, + args.device_ordinal, + args.cuda_graph, + args.max_prefill_tokens + .unwrap_or(openinfer_qwen35_4b::DEFAULT_MAX_PREFILL_TOKENS), + dflash_draft_model_path, + ) + .context("failed to start Qwen3.5 engine")? + } }; Ok(handle)