From 0723f0be8d5437cb5cb3eb093c656330704c7353 Mon Sep 17 00:00:00 2001 From: cl Date: Thu, 18 Jun 2026 02:25:49 +0800 Subject: [PATCH] fix(scheduler): contain worker fatal errors Surface worker-domain failures as typed execution errors instead of letting scheduler paths wedge on closed channels. Qwen3 rank workers now catch panics, report a fatal WorkerPanic once, and exit; the scheduler marks the shared engine health unhealthy, fails bound work, and rejects later submissions explicitly. Keep recoverable step failures request-local where the worker domain remains trustworthy, and expose unhealthy engine state through frontend /health as HTTP 503. Also route obvious worker hot-path GEMM launch failures through Result propagation instead of unchecked panic wrappers. This is intentionally containment, not full fail-safe recovery. Follow-up work should replace remaining state-machine anyhow::Result usage with layered typed errors and design retry/reschedule semantics for failed execution domains. 
--- Cargo.lock | 1 + docs/index.md | 3 +- docs/subsystems/scheduler/scheduler.md | 4 +- .../scheduler/worker-fatal-containment.md | 212 ++++++++++++ openinfer-engine/Cargo.toml | 1 + openinfer-engine/src/engine.rs | 167 ++++++++- openinfer-engine/src/engine/error.rs | 68 ++++ openinfer-qwen3-4b/src/batch_decode.rs | 16 +- openinfer-qwen3-4b/src/batch_decode_dag.rs | 35 +- openinfer-qwen3-4b/src/executor.rs | 269 +++++++++++---- openinfer-qwen3-4b/src/lora.rs | 4 +- openinfer-qwen3-4b/src/prefill.rs | 29 +- openinfer-qwen3-4b/src/scheduler.rs | 317 ++++++++++++++++-- openinfer-qwen3-4b/src/scheduler/plan.rs | 5 +- openinfer-qwen3-4b/src/scheduler/resolve.rs | 113 ++++--- openinfer-qwen3-4b/src/unified_forward.rs | 29 +- openinfer-vllm-frontend/src/health.rs | 95 ++++++ openinfer-vllm-frontend/src/lib.rs | 12 +- 18 files changed, 1198 insertions(+), 182 deletions(-) create mode 100644 docs/subsystems/scheduler/worker-fatal-containment.md create mode 100644 openinfer-engine/src/engine/error.rs create mode 100644 openinfer-vllm-frontend/src/health.rs diff --git a/Cargo.lock b/Cargo.lock index 83b470e2..495ed51c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3797,6 +3797,7 @@ dependencies = [ name = "openinfer-engine" version = "0.1.0" dependencies = [ + "thiserror 2.0.18", "tokio", ] diff --git a/docs/index.md b/docs/index.md index 076347c5..1673b165 100644 --- a/docs/index.md +++ b/docs/index.md @@ -101,7 +101,8 @@ Organized by domain (model line / subsystem / playbook / lesson) instead of by l | Path | TL;DR | | --- | --- | -| `subsystems/scheduler/scheduler.md` | Single dedicated thread owns GPU; FCFS prefill-priority, paged KV, bucket CUDA Graphs, unified forward for mixed prefill+decode. Qwen3-4B at QPS=2 is within 2% of vLLM throughput while winning TTFT (-16%), TPOT (-3%), and latency stability. Open: ITL p99 tail, Qwen3.5 full-paged prefill, and high-concurrency wedge triage. 
| +| `subsystems/scheduler/scheduler.md` | Single dedicated thread owns GPU; FCFS prefill-priority, paged KV, bucket CUDA Graphs, unified forward for mixed prefill+decode. Qwen3-4B at QPS=2 is within 2% of vLLM throughput while winning TTFT (-16%), TPOT (-3%), and latency stability. Open: ITL p99 tail, Qwen3.5 full-paged prefill, and the Qwen3 high-concurrency cudarc copy root crash (containment now lives in `worker-fatal-containment.md`). | +| `subsystems/scheduler/worker-fatal-containment.md` | Worker-fatal containment: shared typed `ExecutionError` exposes `recovery()`, `EngineHealth` drives `/health` plus admission gating, and Qwen3 catches worker panics / rejects post-fatal work explicitly instead of silently wedging. | | `subsystems/scheduler/output-dispatch.md` | GPU bubble study + token-dispatch redesign (**landed 2026-06**). Single-thread CPU↔GPU(sync) alternation idles the GPU through scheduling; bubble ≈3µs×batch (bs=128 → ~380µs, 2% of an 18ms step on 5070 Ti), dominated by N per-request `token_tx.send` wakeups. Fix shipped: `token_tx` is a `TokenSink` drop-in over one request-tagged channel + one bridge demux loop (N→1 wakeups/tasks/ZMQ msgs); cancellation rides an `Arc` flag, not a separate channel. Bubble target ~150µs (exec_cpu floor). Trigger: fast GPUs (→10–15%) or N≫128. | | `subsystems/scheduler/qwen-batched-sampling.md` | Issue #284 record: Qwen3/Qwen3.5 mixed greedy/non-greedy token selection compacts non-greedy rows into one batched FlashInfer sampling call per step, with greedy rows staying on indexed batched argmax. | diff --git a/docs/subsystems/scheduler/scheduler.md b/docs/subsystems/scheduler/scheduler.md index 2b8312b3..d8c4e591 100644 --- a/docs/subsystems/scheduler/scheduler.md +++ b/docs/subsystems/scheduler/scheduler.md @@ -2,7 +2,7 @@ > **TL;DR:** Single dedicated thread owns all GPU resources. 
Continuous batching with FCFS prefill-priority, paged KV cache, bucket CUDA Graphs for batch decode, and a unified forward pass when prefill and decode coexist. On Qwen3-4B (varied-length Poisson QPS=2, RTX 5070 Ti) within 2% of vLLM throughput while winning TTFT (−16%), TPOT (−3%), and latency stability across the board. Remaining gap is ITL p99 tail from prefill stalls. > -> **Last touched:** 2026-05. +> **Last touched:** 2026-06. ## Why this shape @@ -95,7 +95,7 @@ Related: FlashInfer's f32 fused RoPE rounds differently from a precomputed bf16 - **ITL p99 tail (291 vs vLLM 211ms).** Large prefills block in-flight decode. Chunked prefill would fix it. Low priority — varied-length workloads break the waves naturally, and fixed-length ITL p99 already beats vLLM. - **9 failures at QPS=2.** Needs root-cause — likely KV pressure or empty-prompt rejection from the random dataset. -- **Worker-panic wedge at high concurrency (RTX 5070 Ti, verified 2026-06).** Under sustained high load (c≈128–192, mixed prefill+decode), the GPU worker thread `qwen3-tp-rank-0` panics in a cudarc copy: `assertion failed: dst.len() >= src.len()` (`cudarc .../driver/safe/core.rs:1607`) — a pre-allocated buffer is too small for some batch/shape the step produces. Two distinct faults compound: (1) the undersized buffer (proximate crash); (2) **blast radius + no recovery** — `fail_touched_requests` clears the *entire* active batch on any step `Err`, the worker thread is never restarted, so every subsequent step returns "worker step channel closed" and **all** in-flight and future requests fail. `/health` still returns 200 (separate layer), so it reads as a silent permanent wedge until process restart. Not KV admission: a 160-request single burst defers cleanly with 0 errors, so full-lifetime admission works. Independent of the output-dispatch change (the panic is in the executor copy path; `main` fails identically). 
Next: `RUST_BACKTRACE=1` to capture the call site — prime suspects are the prefill/unified or batched step-tail logits/sampling readback (not the decode token H2D, which is sized to the 256 bucket). +- **Worker-panic containment vs root crash.** The high-concurrency Qwen3 worker panic root cause is still the undersized cudarc copy (`dst.len() >= src.len()`) and needs a backtrace/shape fix in the executor path. Containment is now separate and landed in `worker-fatal-containment.md`: recoverable execution errors stay local, worker-channel death marks the execution domain unhealthy, `/health` reports 503 for unhealthy, and post-fatal submissions are rejected explicitly instead of silently wedging on a dead worker. - **Batched Qwen sampling regression guard.** #284 removes the Qwen per-row sampling path: greedy rows use indexed batched argmax, and non-greedy rows compact into one FlashInfer batched sampling call per step. Keep the release HF/nsys/TPOT gate when sampling params or decode batching change. - **Qwen3.5 partial paged migration.** Decode is fully paged via the scheduler. Prefill still scatters from contiguous HND staging into paged KV before attention. Migration mirrors Qwen3's step 2 with HD256 + partial RoPE (rotary_dim=64 of head_dim=256) wrinkles. diff --git a/docs/subsystems/scheduler/worker-fatal-containment.md b/docs/subsystems/scheduler/worker-fatal-containment.md new file mode 100644 index 00000000..ff690d57 --- /dev/null +++ b/docs/subsystems/scheduler/worker-fatal-containment.md @@ -0,0 +1,212 @@ +# Worker Fatal Containment + +> **TL;DR:** Worker-fatal containment landed at the shared engine boundary: typed `ExecutionError` variants describe the cause and expose `recovery()` as a policy property, `EngineHealth` drives `/health` and admission gating, Qwen3 catches worker panics as fatal, preserves recoverable recovery, and rejects post-fatal work explicitly instead of wedging. 
Follow-up issues are needed for fully typed model errors and retry/reschedule after domain failure. +> +> **Last touched:** 2026-06 + +## Preparation + +- **Read**: + - `docs/index.md` - routed the task to Qwen3 plus scheduler/frontend/engine boundaries. + - `docs/subsystems/scheduler/scheduler.md` - current scheduler design and the high-concurrency worker-panic wedge note; the containment fault is that worker death is not promoted to engine readiness/fatal state. + - `docs/models/qwen3/model-crate.md` - Qwen3 owns worker threads, executor, scheduler, and tests; root/frontend should stay on `EngineHandle`. + - `openinfer-qwen3-4b/src/executor.rs` - `RankWorker` owns `qwen3-tp-rank-*` threads; panic currently closes the worker response/channel and is surfaced as ordinary execution errors. + - `openinfer-qwen3-4b/src/scheduler.rs` - `execute_plan` errors call `fail_touched_requests`, which sends request errors and clears active state but does not mark engine/worker fatal. + - `openinfer-engine/src/engine.rs` - `EngineHandle` exposes submission/control/capacity but no worker health or fatal state. + - `openinfer-vllm-frontend/src/lib.rs` and `openinfer-sim/tests/frontend_e2e.rs` - frontend readiness is currently tied to server reachability/engine load; health tests only require a successful `/health` response. +- **Relevant history**: + - `docs/subsystems/scheduler/scheduler.md` - records the same containment shape: worker panic causes channel-closed errors and `/health` remains green. +- **Plan**: + 1. Add a deterministic worker-panic test hook under `#[cfg(test)]` in the Qwen3 executor path, so a unit/integration test can force the worker thread to panic without needing the intermittent cudarc shape. + 2. Split execution failures into recovery tiers at the shared engine/runtime boundary instead of inside Qwen3 only. Request-local errors should fail only that request. 
Step-local recoverable errors should fail only touched requests and preserve unrelated active/deferred long-running work. Execution-domain-fatal errors include worker panic, worker response channel closed, worker command channel closed, or any CUDA-worker state that makes future execution untrustworthy. + 3. Write a failing regression test that starts the real Qwen3 scheduler with a fake/test executor or a controlled Qwen3 worker path, triggers the panic, observes the channel-closed state, and proves future requests currently keep being admitted/fail incorrectly. + 4. Add engine-level execution error and fatal/readiness primitives to `openinfer-engine` that can be reused by every model crate and shared with `EngineHandle` clones. + 5. Teach the Qwen3 scheduler as the first adopter: handle recoverable errors with bounded blast radius and preserve unaffected long-running requests where state remains trustworthy, but handle execution-domain-fatal failures by failing work bound to the dead domain, stopping future admission to that domain, and making subsequent submissions receive a clear fatal error rather than entering the dead worker loop. + 6. Expose the fatal state to frontend readiness. Prefer an OpenInfer-owned `/health` override if the vLLM router extension allows it; otherwise keep the engine-side fatal API and add direct tests now, then route-level health as the next patch. + 7. Verify with targeted release tests: error-classification tests, engine fatal-state tests, Qwen3 worker-panic containment test, and frontend health behavior if route override lands. Use `--release` for Qwen3/CUDA-bound tests per repo convention. +- **Risks / open questions**: + - If vLLM's `/health` route cannot be overridden cleanly from `openinfer-vllm-frontend`, this patch may prove containment through `EngineHandle`/scheduler tests first and leave HTTP health wiring as a follow-up. 
+ - Restarting a CUDA worker in-process is out of scope for the first containment patch unless the restart boundary can rebuild a clean CUDA/model/KV execution domain. The first safe target is best-effort request preservation for recoverable errors, and fail-closed only for state-unsafe worker failures. + - The deterministic panic hook must be test-only and must not add runtime overhead or user-triggerable behavior in release serving. + - Avoid stringly typed fatal detection if possible; prefer a typed error boundary from worker/executor to scheduler. The type should live in shared engine/runtime code, not as a Qwen3-only convention. + - Transparent retry/reschedule after an execution-domain fatal is a separate capability, not part of this containment patch. It needs a defined retry boundary: requests with no emitted tokens may be recomputable; streaming requests with emitted tokens need deterministic replay or explicit client-visible restart semantics; cross-domain migration needs KV/page ownership transfer or cold recompute. + +## Execution Log + +### Step 1: Current error boundary +- Current Qwen3 executor/scheduler path uses `anyhow::Result` for all execution failures: + - `ModelExecutor::{execute_prefill, execute_decode, execute_unified}` returns `anyhow::Result`. + - `scheduler::plan::execute_plan` propagates that `Err`. + - scheduler loops call `fail_touched_requests` for every `Err` and then continue. +- Worker panic surfaces as channel closure: + - primary response drop maps to `primary worker dropped step response`. + - command-channel closure maps to `tensor-parallel worker step channel closed`. + - TP peer response drop maps to `tensor-parallel worker dropped`. +- Design adjustment from review: + - Use `thiserror` for typed executor errors. + - Treat worker channel/response closure and TP protocol violations as worker-fatal. 
+ - Treat worker-returned step `Err` as recoverable for the first patch because the worker thread is still alive; this preserves long-running work where state is still trustworthy. + +### Step 2: Typed containment path +- Added shared engine/runtime primitives in `openinfer-engine`: + - `ExecutionError` typed variants in `openinfer-engine/src/engine/error.rs` plus `ExecutionResult`. + - `ExecutionRecovery` as a property returned by `ExecutionError::recovery()`; recoverability is not encoded as the error variant itself. + - `EngineReadiness::{Healthy, Degraded, Unhealthy}`. + - `EngineHealth`, shared across `EngineHandle` clones. +- Qwen3 is the first adopter: + - `ModelExecutor::{execute_prefill, execute_decode, execute_unified}` now returns `ExecutionResult`. + - Worker command/response channel closure maps to `DomainFatal`. + - Worker-returned execution errors map to `Recoverable`. + - Recoverable scheduler errors fail touched requests and continue. + - Domain-fatal scheduler errors mark engine unhealthy, fail active/deferred/loading/prefilling work, and keep the scheduler alive to reject future requests with a clear `TokenEvent::Error`. +- Tests run: + - `cargo test --release -p openinfer-engine --lib engine_ -- --nocapture` (health clone test passed; one test matched the filter). + - `cargo test --release -p openinfer-engine --lib execution_error_separates_recoverable_from_domain_fatal -- --nocapture` + - `cargo test --release -p openinfer-qwen3-4b --lib fatal_worker_error_marks_engine_unhealthy_and_rejects_future_work -- --nocapture` + - `cargo test --release -p openinfer-qwen3-4b --lib decode_error_drops_request_state_and_scheduler_recovers -- --nocapture` + +### Step 3: Frontend readiness surface +- Added `openinfer-vllm-frontend/src/health.rs`: + - Middleware intercepts `/health`. + - `Healthy` returns HTTP 200 with `{"status":"ok"}`. + - `Degraded` returns HTTP 200 with `{"status":"degraded","reason":...}`. 
+ - `Unhealthy` returns HTTP 503 with `{"status":"unhealthy","reason":...}`. +- `serve_model_on_host_with_router_extension` now stores the loaded engine's shared `EngineHealth` in the middleware state before the bridge starts. +- Tests run: + - `cargo test --release -p openinfer-vllm-frontend health_guard --lib -- --nocapture` + - Re-ran the Qwen3 fatal containment test and engine health clone test after wiring the frontend. + - `cargo test --release -p openinfer-vllm-frontend --lib` + - `cargo test --release -p openinfer-sim --test frontend_e2e simulated_engine_serves_openai_completions_over_http -- --nocapture` + +### Step 4: Broader verification +- `cargo fmt --check` initially reported formatting diffs; ran `cargo fmt`. +- `cargo fmt --check` +- `cargo test --release -p openinfer-engine --lib` +- `cargo test --release -p openinfer-qwen3-4b --lib scheduler -- --nocapture` +- `cargo test --release -p openinfer-qwen3-4b --lib` + +### Step 5: Error typing cleanup +- Moved shared execution error definitions out of `engine.rs` into `openinfer-engine/src/engine/error.rs`. +- Replaced policy-shaped variants (`Recoverable` / `DomainFatal`) with cause-shaped variants: + - `StepFailed` + - `WorkerCommandChannelClosed` + - `WorkerResponseDropped` + - `WorkerPanic` + - `UnexpectedWorkerResponse` +- `ExecutionError::recovery()` is now the policy property that returns `ExecutionRecovery::{Recoverable, DomainFatal}`. +- Qwen3 scheduler now branches on `e.recovery()` instead of matching a recoverable/fatal variant name. +- Tests/checks run after the cleanup: + - `cargo test --release -p openinfer-engine --lib` + - `cargo test --release -p openinfer-qwen3-4b --lib scheduler -- --nocapture` + - `cargo fmt --check` + - `git diff --check` + +### Step 6: Fail-safe infrastructure pass +- Added an `EngineHandle` admission gate: + - `submit` checks shared readiness before enqueueing work. 
+ - unhealthy engines send a request-local `TokenEvent::Error` immediately and do not enqueue into the scheduler. + - LoRA control calls also reject before entering a dead engine. +- Moved Qwen3 worker step responses from `anyhow::Result` to `ExecutionResult`, so the worker boundary itself is typed. +- Wrapped Qwen3 worker step execution in `catch_unwind`: + - a real worker panic is converted to `ExecutionError::WorkerPanic`; + - the worker reports the fatal error once and exits; + - later dispatch observes `WorkerCommandChannelClosed`. +- Added a no-GPU worker lifecycle test that triggers an actual panic inside a worker thread and verifies it returns `DomainFatal` then exits. +- Tests/checks run after this pass: + - `cargo test --release -p openinfer-engine --lib` + - `cargo test --release -p openinfer-qwen3-4b --lib worker_panic_is_reported_as_domain_fatal_then_worker_exits -- --nocapture` + - `cargo test --release -p openinfer-qwen3-4b --lib scheduler -- --nocapture` + - `cargo fmt --check` + - `git diff --check` + +### Step 7: Worker hot-path panic audit +- Audited Qwen3 worker step paths for `panic!` / `assert!` / `unwrap` / `expect` that could turn recoverable execution failures into worker death. +- Converted the obvious worker-hot-path GEMM launch points from unchecked wrappers to checked `Result` propagation: + - Qwen3 prefill, decode DAG, unified forward, and LoRA projection deltas now use checked GEMM calls where the checked kernel API already exists. + - These failures now flow back through the worker `ExecutionResult` path instead of panicking first. +- Converted state/protocol boundary panics into typed execution errors: + - zero-token prefill chunks and missing local `RequestKv` now return step errors; + - worker results for unknown request ids or mismatched result sets return `UnexpectedWorkerResponse`, which is domain-fatal because applying those results would corrupt scheduler state. 
+- Remaining panic surface: + - several low-level kernel/shape wrappers still use assertions or unchecked `()` APIs, especially elementwise/attention helpers. Those need a separate typed-error cleanup instead of opportunistic conversion in this containment patch. +- Tests/checks run after this audit: + - `cargo check --release -p openinfer-qwen3-4b --lib` + - `cargo test --release -p openinfer-engine --lib` + - `cargo test --release -p openinfer-qwen3-4b --lib worker_panic_is_reported_as_domain_fatal_then_worker_exits -- --nocapture` + - `cargo test --release -p openinfer-qwen3-4b --lib scheduler -- --nocapture` + - `cargo test --release -p openinfer-qwen3-4b --lib` + - `cargo test --release -p openinfer-vllm-frontend --lib` + - `cargo fmt --check` + - `git diff --check` + +### Step 8: Complexity review cleanup +- Applied the over-engineering review cuts that preserved behavior: + - `EngineHealth` is now a one-way fatal latch backed by `OnceLock` instead of a mutable `Mutex`. + - Removed the unused `Degraded` readiness state and frontend degraded `/health` response. + - Replaced the frontend `HealthProbe` newtype with `Arc>` directly. + - Removed scheduler-local `fatal_reason`; the scheduler now reads the shared `EngineHealth` fatal latch. + - Collapsed duplicate execute/resolve error handling branches in both scheduler loops. 
+- Tests/checks run after this cleanup: + - `cargo check --release -p openinfer-engine --lib` + - `cargo check --release -p openinfer-qwen3-4b --lib` + - `cargo test --release -p openinfer-engine --lib` + - `cargo test --release -p openinfer-vllm-frontend --lib` + - `cargo test --release -p openinfer-qwen3-4b --lib scheduler -- --nocapture` + - `cargo test --release -p openinfer-qwen3-4b --lib` + - `cargo fmt --check` + - `git diff --check` + +## Follow-up Issues + +### Replace state-machine `anyhow::Result` with typed model/runtime errors +- **Problem**: `anyhow::Result` is still present inside Qwen3 executor internals and some model/kernel glue. It is acceptable for tests, CLI/startup wiring, and outer diagnostics, but it makes scheduler/worker state-machine contracts stale quickly because callers cannot match causes without string inspection. +- **Reference shape**: + - Databend's current `databend_common_exception` uses a workspace-level `ErrorCode` with stable numeric codes, names, display text, backtrace/context frames, and explicit `map_err_to_code` conversion at foreign-error boundaries. + - Databend's own error-handling RFC calls out that a single flat error layer is not enough for high-level reasoning and proposes layered error types plus explicit context conversion (`change_context`) instead of implicit `From` propagation. + - OpenInfer should copy the discipline, not the exact shape: stable shared boundary errors for engine/frontend contracts, smaller domain errors inside model/runtime/kernel crates, explicit conversion between layers, and typed metadata for recovery/admission decisions. 
+- **Scope**: + - introduce model/runtime error enums for Qwen3 execution internals and convert them into shared `ExecutionError` only at the engine boundary; + - remove `anyhow::Result` from scheduler plan/resolve paths and worker-step hot-path helpers where recovery policy matters; + - keep recoverability as an error property (`recovery()`), not as the variant shape; + - convert kernel launch/shape failures that already return `Result` into typed variants instead of wrapping everything as generic `StepFailed`; + - attach operation context as typed fields or explicit frames (`op`, `request_id`, `rank`, `domain`, `layer`, `kernel`) instead of concatenating strings early. +- **Acceptance**: + - no `anyhow::Result` crosses a worker/executor/scheduler boundary where the caller must decide recover/reject/fatal; + - tests cover at least one typed recoverable model-step error and one typed domain-fatal protocol/state error; + - remaining `anyhow` uses are intentionally limited to startup, CLI/test glue, or leaf code that is immediately mapped into a typed error; + - user-facing errors can be rendered from the typed error without losing machine-readable code/category/recovery policy. + +### Retry/reschedule work after execution-domain failure +- **Problem**: current containment is fail-closed. It prevents wedging and preserves unrelated work for recoverable errors, but it does not retry or migrate requests bound to a failed execution domain. +- **Scope**: + - define domain identity for TP groups and future DP lanes; + - specify which requests are retryable: no emitted tokens can be recomputed, streamed requests need deterministic replay or explicit client-visible restart semantics; + - define whether KV can be migrated, cold-recomputed, or must be discarded; + - add scheduler APIs for requeueing retryable work and degrading capacity when only one DP lane dies. 
+- **Acceptance**: + - a domain-fatal failure can reschedule eligible work without double-emitting tokens; + - unretryable work receives a clear terminal error; + - TP group failure remains fail-closed unless a clean group restart boundary exists; + - DP-lane failure can degrade capacity while healthy lanes keep serving when the model line supports DP isolation. + +## Debrief + +- **Outcome**: + - Added shared `openinfer-engine` primitives for typed execution errors, recovery classification, and engine readiness. + - Added a frontend-facing admission gate so unhealthy engines do not continue queueing work. + - Added worker panic capture in the Qwen3 rank worker loop, with typed fatal reporting and worker exit. + - Qwen3 scheduler now treats recoverable execution errors as request/step-local and continues serving. + - Qwen3 scheduler now treats worker-domain fatal errors as engine-unhealthy, fails bound work, and keeps the scheduler alive only to reject future submissions with a clear error. + - Frontend `/health` now reflects `EngineHealth`: healthy 200, unhealthy 503. +- **Pitfalls encountered**: + - The first plan was too Qwen3-local. The error/readiness concepts belong in shared engine/runtime code; Qwen3 is only the first adopter. + - `anyhow::Result` was too weak at the scheduler boundary because the caller needed to classify recovery policy. Lower-level code still uses `anyhow` internally, but the execution boundary is now typed. + - Encoding recoverable/fatal directly as error variants was also too weak: it hid the actual cause. The shared error now uses cause-shaped variants and exposes recoverability through `recovery()`. + - A domain-fatal scheduler should not exit immediately; staying alive lets it reject future submissions explicitly instead of turning them into generic channel-closed failures. +- **Lessons learned**: + - Worker panic containment is not transparent recovery. 
Requests already streaming from a failed domain still need an explicit error unless a future retry/reschedule design defines deterministic replay or KV migration. + - Shared readiness should be model-agnostic; `/health` should only consume aggregate engine state, not know TP/DP/Qwen3 details. +- **Follow-ups**: + - File and implement "Replace state-machine `anyhow::Result` with typed model/runtime errors" from the follow-up issue draft above. + - File and implement "Retry/reschedule work after execution-domain failure" from the follow-up issue draft above. + - Adopt `ExecutionError`/`EngineHealth` in other model schedulers so containment semantics are consistent beyond Qwen3. diff --git a/openinfer-engine/Cargo.toml b/openinfer-engine/Cargo.toml index f1d6664f..d91efcdf 100644 --- a/openinfer-engine/Cargo.toml +++ b/openinfer-engine/Cargo.toml @@ -5,6 +5,7 @@ version = "0.1.0" edition = "2024" [dependencies] +thiserror = { workspace = true } tokio = { workspace = true, features = ["sync"] } [dev-dependencies] diff --git a/openinfer-engine/src/engine.rs b/openinfer-engine/src/engine.rs index 5867ad53..d683ce59 100644 --- a/openinfer-engine/src/engine.rs +++ b/openinfer-engine/src/engine.rs @@ -3,7 +3,7 @@ use std::{ fmt, path::PathBuf, sync::{ - Arc, + Arc, OnceLock, atomic::{AtomicBool, Ordering}, }, thread::{self, JoinHandle}, @@ -14,6 +14,9 @@ use tokio::sync::{mpsc, oneshot}; use crate::parallel::ParallelConfig; use crate::sampler::SamplingParams; +pub mod error; +pub use error::{ExecutionError, ExecutionRecovery, ExecutionResult}; + #[derive(Clone, Debug)] pub struct EngineLoadOptions { pub enable_cuda_graph: bool, @@ -135,6 +138,57 @@ impl Error for EngineControlError {} pub type EngineControlResult<T> = std::result::Result<T, EngineControlError>; +#[derive(Clone, Debug, Eq, PartialEq)] +pub enum EngineReadiness { + Healthy, + Unhealthy { reason: String }, +} + +impl EngineReadiness { + pub fn unservable_reason(&self) -> Option<&str> { + match self { + Self::Healthy => None, + 
Self::Unhealthy { reason } => Some(reason), + } + } +} + +#[derive(Clone)] +pub struct EngineHealth { + unhealthy_reason: Arc<OnceLock<String>>, +} + +impl Default for EngineHealth { + fn default() -> Self { + Self::new() + } +} + +impl EngineHealth { + pub fn new() -> Self { + Self { + unhealthy_reason: Arc::new(OnceLock::new()), + } + } + + pub fn readiness(&self) -> EngineReadiness { + match self.unhealthy_reason() { + Some(reason) => EngineReadiness::Unhealthy { + reason: reason.to_string(), + }, + None => EngineReadiness::Healthy, + } + } + + pub fn mark_unhealthy(&self, reason: impl Into<String>) { + let _ = self.unhealthy_reason.set(reason.into()); + } + + pub fn unhealthy_reason(&self) -> Option<&str> { + self.unhealthy_reason.get().map(String::as_str) + } +} + pub enum TokenEvent { Scheduled { queued_at_unix_s: f64, @@ -291,6 +345,7 @@ pub struct EngineHandle { /// KV pool capacity in blocks + block size, or `None` if the engine did not /// report it. See [`KvCapacity`]. kv_capacity: Option<KvCapacity>, + health: EngineHealth, } struct EngineInner { @@ -340,9 +395,16 @@ impl EngineHandle { }), servable_len: None, kv_capacity: None, + health: EngineHealth::new(), } } + #[must_use] + pub fn with_health(mut self, health: EngineHealth) -> Self { + self.health = health; + self + } + #[must_use] pub fn with_servable_len(mut self, servable_len: u32) -> Self { self.servable_len = Some(servable_len); @@ -366,11 +428,28 @@ impl EngineHandle { self.kv_capacity } + pub fn readiness(&self) -> EngineReadiness { + self.health.readiness() + } + + pub fn health(&self) -> EngineHealth { + self.health.clone() + } + #[allow(clippy::result_large_err)] pub fn submit( &self, req: GenerateRequest, ) -> std::result::Result<(), mpsc::error::SendError<GenerateRequest>> { + let readiness = self.readiness(); + if let Some(reason) = readiness.unservable_reason() { + let _ = req.token_tx.send(TokenEvent::Error { + message: reason.to_string(), + prompt_tokens: req.prompt_tokens.len(), + completion_tokens: 0, + }); + return Ok(()); + } match 
self.inner.submit_tx.as_ref() { Some(submit_tx) => submit_tx.send(req), None => match self.inner.command_tx.as_ref() { @@ -393,6 +472,9 @@ impl EngineHandle { &self, request: LoadLoraAdapterRequest, ) -> EngineControlResult<()> { + if let Some(reason) = self.readiness().unservable_reason() { + return Err(EngineControlError::OperationFailed(reason.to_string())); + } match self.inner.command_tx.as_ref() { Some(command_tx) => { let (response_tx, response_rx) = oneshot::channel(); @@ -417,6 +499,9 @@ impl EngineHandle { } pub async fn list_lora_adapters(&self) -> EngineControlResult> { + if let Some(reason) = self.readiness().unservable_reason() { + return Err(EngineControlError::OperationFailed(reason.to_string())); + } match self.inner.command_tx.as_ref() { Some(command_tx) => { let (response_tx, response_rx) = oneshot::channel(); @@ -441,6 +526,9 @@ impl EngineHandle { &self, request: UnloadLoraAdapterRequest, ) -> EngineControlResult<()> { + if let Some(reason) = self.readiness().unservable_reason() { + return Err(EngineControlError::OperationFailed(reason.to_string())); + } match self.inner.command_tx.as_ref() { Some(command_tx) => { let (response_tx, response_rx) = oneshot::channel(); @@ -516,6 +604,83 @@ mod tests { assert!(handle.supports_lora_control()); } + #[test] + fn engine_health_is_shared_across_handle_clones() { + let (submit_tx, _submit_rx) = mpsc::unbounded_channel::(); + let handle = EngineHandle::new(submit_tx); + let clone = handle.clone(); + + assert_eq!(handle.readiness(), EngineReadiness::Healthy); + clone.health().mark_unhealthy("worker died"); + + assert_eq!( + handle.readiness(), + EngineReadiness::Unhealthy { + reason: "worker died".to_string() + } + ); + assert_eq!(handle.readiness().unservable_reason(), Some("worker died")); + } + + #[test] + fn execution_error_separates_recoverable_from_domain_fatal() { + let recoverable = ExecutionError::step_failed("request cleanup failed"); + assert_eq!(recoverable.recovery(), 
ExecutionRecovery::Recoverable); + assert!(!recoverable.is_domain_fatal()); + assert_eq!(recoverable.to_string(), "request cleanup failed"); + + let fatal = ExecutionError::worker_command_channel_closed("decode"); + assert_eq!(fatal.recovery(), ExecutionRecovery::DomainFatal); + assert!(fatal.is_domain_fatal()); + assert_eq!( + fatal.to_string(), + "worker command channel closed during decode" + ); + } + + #[test] + fn unhealthy_engine_rejects_submit_without_enqueueing() { + let (submit_tx, mut submit_rx) = mpsc::unbounded_channel::(); + let handle = EngineHandle::new(submit_tx); + handle.health().mark_unhealthy("worker died"); + + let (token_tx, mut token_rx) = TokenSink::standalone(); + let req = GenerateRequest { + request_id: None, + queued_at_unix_s: None, + prompt_tokens: vec![1, 2, 3], + params: SamplingParams::default(), + max_tokens: 1, + lora_adapter: None, + token_tx, + logprobs: 0, + echo: false, + }; + + handle + .submit(req) + .expect("submit should be explicitly rejected"); + assert!( + submit_rx.try_recv().is_err(), + "unhealthy submit must not enqueue work" + ); + match token_rx.blocking_recv() { + Some(( + _, + TokenEvent::Error { + message, + prompt_tokens, + completion_tokens, + }, + )) => { + assert_eq!(message, "worker died"); + assert_eq!(prompt_tokens, 3); + assert_eq!(completion_tokens, 0); + } + _ => panic!("expected explicit submit rejection"), + } + } + #[tokio::test] async fn load_lora_adapter_sends_control_command() { let (command_tx, mut command_rx) = mpsc::unbounded_channel::(); diff --git a/openinfer-engine/src/engine/error.rs b/openinfer-engine/src/engine/error.rs new file mode 100644 index 00000000..a236f364 --- /dev/null +++ b/openinfer-engine/src/engine/error.rs @@ -0,0 +1,68 @@ +#[derive(Clone, Copy, Debug, Eq, PartialEq)] +pub enum ExecutionRecovery { + Recoverable, + DomainFatal, +} + +#[derive(Debug, thiserror::Error, Eq, PartialEq)] +pub enum ExecutionError { + #[error("{message}")] + StepFailed { message: String }, + 
#[error("worker command channel closed during {op}")] + WorkerCommandChannelClosed { op: String }, + #[error("{worker} worker dropped {op} response")] + WorkerResponseDropped { worker: String, op: String }, + #[error("worker {worker} panicked: {message}")] + WorkerPanic { worker: String, message: String }, + #[error("unexpected worker response during {op}: {got}")] + UnexpectedWorkerResponse { op: String, got: String }, +} + +impl ExecutionError { + pub fn step_failed(message: impl Into) -> Self { + Self::StepFailed { + message: message.into(), + } + } + + pub fn worker_command_channel_closed(op: impl Into) -> Self { + Self::WorkerCommandChannelClosed { op: op.into() } + } + + pub fn worker_response_dropped(worker: impl Into, op: impl Into) -> Self { + Self::WorkerResponseDropped { + worker: worker.into(), + op: op.into(), + } + } + + pub fn worker_panic(worker: impl Into, message: impl Into) -> Self { + Self::WorkerPanic { + worker: worker.into(), + message: message.into(), + } + } + + pub fn unexpected_worker_response(op: impl Into, got: impl Into) -> Self { + Self::UnexpectedWorkerResponse { + op: op.into(), + got: got.into(), + } + } + + pub fn recovery(&self) -> ExecutionRecovery { + match self { + Self::StepFailed { .. } => ExecutionRecovery::Recoverable, + Self::WorkerCommandChannelClosed { .. } + | Self::WorkerResponseDropped { .. } + | Self::WorkerPanic { .. } + | Self::UnexpectedWorkerResponse { .. 
} => ExecutionRecovery::DomainFatal, + } + } + + pub fn is_domain_fatal(&self) -> bool { + self.recovery() == ExecutionRecovery::DomainFatal + } +} + +pub type ExecutionResult = std::result::Result; diff --git a/openinfer-qwen3-4b/src/batch_decode.rs b/openinfer-qwen3-4b/src/batch_decode.rs index de51aa0f..e977f9c8 100644 --- a/openinfer-qwen3-4b/src/batch_decode.rs +++ b/openinfer-qwen3-4b/src/batch_decode.rs @@ -193,7 +193,7 @@ impl Qwen3Model { self.output_projection(), &bufs.normed, &mut bufs.logits, - ); + )?; Ok(()) } @@ -220,7 +220,7 @@ impl Qwen3Model { q_dim, &bufs.normed, &mut bufs.q, - ); + )?; dag.gemm_rows::( dag_label!(format!("L{layer_idx}.attn.k_proj")), &layer.attention.qkv_proj, @@ -228,7 +228,7 @@ impl Qwen3Model { kv_dim, &bufs.normed, &mut bufs.k, - ); + )?; dag.gemm_rows::( dag_label!(format!("L{layer_idx}.attn.v_proj")), &layer.attention.qkv_proj, @@ -236,7 +236,7 @@ impl Qwen3Model { kv_dim, &bufs.normed, &mut bufs.v, - ); + )?; self.apply_decode_lora_projection_group3( layer_idx, LoraProjectionKind::Q, @@ -273,7 +273,7 @@ impl Qwen3Model { &layer.attention.o_proj, &bufs.attn_out, &mut bufs.attn_proj, - ); + )?; self.apply_decode_lora_projection( layer_idx, LoraProjectionKind::O, @@ -303,13 +303,13 @@ impl Qwen3Model { &layer.mlp.gate_up_proj, &bufs.normed, &mut bufs.gate_out, - ); + )?; dag.mlp_up_proj( dag_label!(format!("L{layer_idx}.mlp.up_proj")), &layer.mlp.gate_up_proj, &bufs.normed, &mut bufs.up_out, - ); + )?; self.apply_decode_lora_projection_group2( layer_idx, LoraProjectionKind::Gate, @@ -331,7 +331,7 @@ impl Qwen3Model { &layer.mlp.down_proj, &bufs.mlp_act, &mut bufs.mlp_out, - ); + )?; self.apply_decode_lora_projection( layer_idx, LoraProjectionKind::Down, diff --git a/openinfer-qwen3-4b/src/batch_decode_dag.rs b/openinfer-qwen3-4b/src/batch_decode_dag.rs index e3ea6f91..e1c76d34 100644 --- a/openinfer-qwen3-4b/src/batch_decode_dag.rs +++ b/openinfer-qwen3-4b/src/batch_decode_dag.rs @@ -134,7 +134,7 @@ impl<'a> 
BatchDecodeDag<'a> { rows: usize, x: &HiddenStates, out: &mut HiddenStates, - ) { + ) -> Result<()> { #[cfg(feature = "kernel-call-trace")] Self::record(gemm_rows_call::( label, @@ -144,7 +144,14 @@ impl<'a> BatchDecodeDag<'a> { row_offset, x.seq_len, )); - openinfer_kernels::ops::gemm_rows_into(&self.model.ctx, weight, row_offset, rows, x, out); + openinfer_kernels::ops::gemm_rows_into_checked( + &self.model.ctx, + weight, + row_offset, + rows, + x, + out, + ) } // `Out`/`In` label the kernel-call-trace record; unused without the feature. @@ -158,7 +165,7 @@ impl<'a> BatchDecodeDag<'a> { weight: &DeviceMatrix, x: &HiddenStates, out: &mut HiddenStates, - ) { + ) -> Result<()> { #[cfg(feature = "kernel-call-trace")] Self::record(gemm_call::( label, @@ -166,7 +173,7 @@ impl<'a> BatchDecodeDag<'a> { weight.cols, x.seq_len, )); - openinfer_kernels::ops::gemm_into(&self.model.ctx, weight, x, out); + openinfer_kernels::ops::gemm_into_checked(&self.model.ctx, weight, x, out) } pub(crate) fn qk_norm_rope( @@ -305,8 +312,8 @@ impl<'a> BatchDecodeDag<'a> { weight: &DeviceMatrix, x: &HiddenStates, out: &mut HiddenStates, - ) { - self.gemm::(label, weight, x, out); + ) -> Result<()> { + self.gemm::(label, weight, x, out) } pub(crate) fn mlp_gate_proj( @@ -315,8 +322,8 @@ impl<'a> BatchDecodeDag<'a> { weight: &DeviceMatrix, x: &HiddenStates, out: &mut HiddenStates, - ) { - self.gemm_rows::(label, weight, 0, out.hidden_dim, x, out); + ) -> Result<()> { + self.gemm_rows::(label, weight, 0, out.hidden_dim, x, out) } pub(crate) fn mlp_up_proj( @@ -325,8 +332,8 @@ impl<'a> BatchDecodeDag<'a> { weight: &DeviceMatrix, x: &HiddenStates, out: &mut HiddenStates, - ) { - self.gemm_rows::(label, weight, out.hidden_dim, out.hidden_dim, x, out); + ) -> Result<()> { + self.gemm_rows::(label, weight, out.hidden_dim, out.hidden_dim, x, out) } pub(crate) fn silu_mul_split( @@ -351,8 +358,8 @@ impl<'a> BatchDecodeDag<'a> { weight: &DeviceMatrix, x: &HiddenStates, out: &mut HiddenStates, - ) { - 
self.gemm::(label, weight, x, out); + ) -> Result<()> { + self.gemm::(label, weight, x, out) } pub(crate) fn lm_head( @@ -361,8 +368,8 @@ impl<'a> BatchDecodeDag<'a> { weight: &DeviceMatrix, x: &HiddenStates, out: &mut HiddenStates, - ) { - self.gemm::(label, weight, x, out); + ) -> Result<()> { + self.gemm::(label, weight, x, out) } #[cfg(feature = "kernel-call-trace")] diff --git a/openinfer-qwen3-4b/src/executor.rs b/openinfer-qwen3-4b/src/executor.rs index 12a34d4e..dbfd61d8 100644 --- a/openinfer-qwen3-4b/src/executor.rs +++ b/openinfer-qwen3-4b/src/executor.rs @@ -1,4 +1,6 @@ +use std::any::Any; use std::collections::{HashMap, HashSet}; +use std::panic::{self, AssertUnwindSafe}; use std::thread; use anyhow::Result; @@ -8,7 +10,9 @@ use crate::batch_decode_buffers::{BATCH_BUCKETS, BatchDecodeBuffers}; use crate::config::{Config, TensorParallelConfig}; use crate::weights::{ModelRuntimeConfig, Qwen3Model}; use crate::{Qwen3LoraOptions, Qwen3OffloadOptions}; -use openinfer_core::engine::{LoadLoraAdapterRequest, TokenLogprob, UnloadLoraAdapterRequest}; +use openinfer_core::engine::{ + ExecutionError, ExecutionResult, LoadLoraAdapterRequest, TokenLogprob, UnloadLoraAdapterRequest, +}; use openinfer_core::kv_pool::KvLayout; use openinfer_core::ops; use openinfer_core::sampler::SamplingParams; @@ -504,9 +508,9 @@ pub(crate) trait ModelExecutor: Send { fn is_stop_token(&self, token_id: u32) -> bool; fn drop_request(&mut self, request_id: RequestId) -> Result<()>; - fn execute_prefill(&mut self, plan: PrefillPlan<'_>) -> Result; - fn execute_decode(&mut self, plan: DecodePlan<'_>) -> Result; - fn execute_unified(&mut self, plan: UnifiedPlan<'_>) -> Result; + fn execute_prefill(&mut self, plan: PrefillPlan<'_>) -> ExecutionResult; + fn execute_decode(&mut self, plan: DecodePlan<'_>) -> ExecutionResult; + fn execute_unified(&mut self, plan: UnifiedPlan<'_>) -> ExecutionResult; fn load_lora_adapter(&mut self, request: &LoadLoraAdapterRequest) -> Result<()> { 
anyhow::bail!( @@ -829,15 +833,15 @@ impl Qwen3Executor { } pub fn execute_prefill(&mut self, plan: PrefillPlan<'_>) -> Result { - ::execute_prefill(self, plan) + ::execute_prefill(self, plan).map_err(anyhow::Error::new) } pub fn execute_decode(&mut self, plan: DecodePlan<'_>) -> Result { - ::execute_decode(self, plan) + ::execute_decode(self, plan).map_err(anyhow::Error::new) } pub fn execute_unified(&mut self, plan: UnifiedPlan<'_>) -> Result { - ::execute_unified(self, plan) + ::execute_unified(self, plan).map_err(anyhow::Error::new) } pub fn load_lora_adapter(&mut self, request: &LoadLoraAdapterRequest) -> Result<()> { @@ -987,9 +991,20 @@ impl Qwen3Executor { let rkv = self .request_kvs .get_mut(&req.request_id) - .expect("inserted above"); + .ok_or_else(|| anyhow::anyhow!("missing RequestKv for {:?}", req.request_id))?; req.chunk_start = rkv.kv_position(); - let remaining = req.prompt_tokens.len() - req.chunk_start; + let remaining = req + .prompt_tokens + .len() + .checked_sub(req.chunk_start) + .ok_or_else(|| { + anyhow::anyhow!( + "prefill position {} exceeds prompt length {} for {:?}", + req.chunk_start, + req.prompt_tokens.len(), + req.request_id + ) + })?; // Echo must produce all-position logits in a single forward, so it is // exempt from chunking (the scheduler never splits echo requests). req.chunk_tokens = if req.echo { @@ -997,7 +1012,7 @@ impl Qwen3Executor { } else { remaining.min(req.chunk_budget) }; - assert!( + anyhow::ensure!( req.chunk_tokens > 0, "zero-token prefill chunk for {:?} (budget {})", req.request_id, @@ -1010,16 +1025,21 @@ impl Qwen3Executor { /// Register a finished prefill step on the request's KV: the final chunk /// carries the first generated token, non-final chunks only advance the /// KV position. 
- fn apply_prefill_result(&mut self, result: &PrefillRequestResult) -> Result<()> { + fn apply_prefill_result(&mut self, result: &PrefillRequestResult) -> ExecutionResult<()> { let rkv = self .request_kvs .get_mut(&result.request_id) - .expect("request must exist after prefill"); - if result.completed { + .ok_or_else(|| { + ExecutionError::unexpected_worker_response( + "prefill", + format!("unknown request id {:?}", result.request_id), + ) + })?; + recoverable(if result.completed { rkv.apply_prefill(result.first_token, self.kv_mgr.pool()) } else { rkv.apply_prefill_chunk(self.kv_mgr.pool()) - } + }) } // ── KV-offload LOAD (async CPU-tier prefetch) ────────────────────── @@ -1060,19 +1080,19 @@ impl Qwen3Executor { } fn wait_for_step_ack( - pending: Vec>>, + pending: Vec>>, op_name: &'static str, - ) -> Result<()> { + ) -> ExecutionResult<()> { for recv in pending { - match recv - .recv() - .map_err(|_| anyhow::anyhow!("tensor-parallel {op_name} worker dropped"))?? - { + let outcome = recv.recv().map_err(|_| { + ExecutionError::worker_response_dropped("tensor-parallel peer", op_name) + })?; + match outcome? 
{ WorkerStepOutcome::Ack => {} other => { - return Err(anyhow::anyhow!( - "tensor-parallel {op_name} worker returned unexpected payload: {}", - other.kind() + return Err(ExecutionError::unexpected_worker_response( + op_name, + other.kind(), )); } } @@ -1080,7 +1100,7 @@ impl Qwen3Executor { Ok(()) } - fn run_step(&self, step: &StepCommand) -> Result { + fn run_step(&self, step: &StepCommand) -> ExecutionResult { let primary = self.primary.run_step(step.clone(), true)?; let mut pending = Vec::with_capacity(self.workers.len()); for worker in &self.workers { @@ -1088,12 +1108,31 @@ impl Qwen3Executor { } let primary_result = primary .recv() - .map_err(|_| anyhow::anyhow!("primary worker dropped step response"))??; + .map_err(|_| ExecutionError::worker_response_dropped("primary", step.kind()))?; + let primary_result = primary_result?; Self::wait_for_step_ack(pending, step.kind())?; Ok(primary_result) } } +fn recoverable_error(error: anyhow::Error) -> ExecutionError { + ExecutionError::step_failed(error.to_string()) +} + +fn recoverable(result: Result) -> ExecutionResult { + result.map_err(recoverable_error) +} + +fn panic_message(payload: Box) -> String { + if let Some(message) = payload.downcast_ref::<&str>() { + (*message).to_string() + } else if let Some(message) = payload.downcast_ref::() { + message.clone() + } else { + "non-string panic payload".to_string() + } +} + /// Build the KV-offload engine for the single-GPU path, or `None` when offload /// is disabled. Registers the fused KV buffer with pegaflow against the model's /// device/stream — must be called while that stream is still owned by the model @@ -1321,12 +1360,12 @@ impl ModelExecutor for Qwen3Executor { done } - fn execute_prefill(&mut self, plan: PrefillPlan<'_>) -> Result { + fn execute_prefill(&mut self, plan: PrefillPlan<'_>) -> ExecutionResult { // 1. 
Create RequestKvs (first chunk only), clamp chunk budgets, // schedule KV for this step's tokens let mut requests = plan.requests.to_vec(); for req in &mut requests { - self.schedule_prefill_chunk(req)?; + recoverable(self.schedule_prefill_chunk(req))?; } // 2. Build KvViews (seq_len = chunk_start + this chunk) @@ -1348,9 +1387,9 @@ impl ModelExecutor for Qwen3Executor { let result = match outcome { WorkerStepOutcome::Prefill(result) => result, other => { - return Err(anyhow::anyhow!( - "prefill returned unexpected: {}", - other.kind() + return Err(ExecutionError::unexpected_worker_response( + "prefill", + other.kind(), )); } }; @@ -1365,16 +1404,19 @@ impl ModelExecutor for Qwen3Executor { Ok(result) } - fn execute_decode(&mut self, plan: DecodePlan<'_>) -> Result { + fn execute_decode(&mut self, plan: DecodePlan<'_>) -> ExecutionResult { // 1. Schedule decode for all active requests for req in plan.requests { let rkv = self .request_kvs .get_mut(&req.request_id) - .ok_or_else(|| anyhow::anyhow!("missing RequestKv for {:?}", req.request_id))?; - rkv.schedule_decode(self.kv_mgr.pool()).map_err(|e| { - anyhow::anyhow!("schedule_decode failed for {:?}: {e}", req.request_id) - })?; + .ok_or_else(|| anyhow::anyhow!("missing RequestKv for {:?}", req.request_id)) + .map_err(recoverable_error)?; + rkv.schedule_decode(self.kv_mgr.pool()) + .map_err(|e| { + anyhow::anyhow!("schedule_decode failed for {:?}: {e}", req.request_id) + }) + .map_err(recoverable_error)?; } // 2. 
Build KvViews @@ -1396,9 +1438,9 @@ impl ModelExecutor for Qwen3Executor { let result = match outcome { WorkerStepOutcome::Decode(result) => result, other => { - return Err(anyhow::anyhow!( - "decode returned unexpected: {}", - other.kind() + return Err(ExecutionError::unexpected_worker_response( + "decode", + other.kind(), )); } }; @@ -1406,8 +1448,13 @@ impl ModelExecutor for Qwen3Executor { let rkv = self .request_kvs .get_mut(&req_result.request_id) - .expect("request must exist after decode"); - rkv.apply_decode(req_result.token, self.kv_mgr.pool())?; + .ok_or_else(|| { + ExecutionError::unexpected_worker_response( + "decode", + format!("unknown request id {:?}", req_result.request_id), + ) + })?; + recoverable(rkv.apply_decode(req_result.token, self.kv_mgr.pool()))?; } // 5. Offload any block this decode step just sealed (post-step-sync). for req_result in &result.requests { @@ -1417,12 +1464,12 @@ impl ModelExecutor for Qwen3Executor { Ok(result) } - fn execute_unified(&mut self, plan: UnifiedPlan<'_>) -> Result { + fn execute_unified(&mut self, plan: UnifiedPlan<'_>) -> ExecutionResult { // 1. 
Create RequestKvs for prefill requests (first chunk only), clamp // chunk budgets, schedule KV for this step's tokens let mut prefill_requests = plan.prefill_requests.to_vec(); for req in &mut prefill_requests { - self.schedule_prefill_chunk(req)?; + recoverable(self.schedule_prefill_chunk(req))?; } // Schedule decode for active requests @@ -1430,10 +1477,13 @@ impl ModelExecutor for Qwen3Executor { let rkv = self .request_kvs .get_mut(&req.request_id) - .ok_or_else(|| anyhow::anyhow!("missing RequestKv for {:?}", req.request_id))?; - rkv.schedule_decode(self.kv_mgr.pool()).map_err(|e| { - anyhow::anyhow!("schedule_decode failed for {:?}: {e}", req.request_id) - })?; + .ok_or_else(|| anyhow::anyhow!("missing RequestKv for {:?}", req.request_id)) + .map_err(recoverable_error)?; + rkv.schedule_decode(self.kv_mgr.pool()) + .map_err(|e| { + anyhow::anyhow!("schedule_decode failed for {:?}: {e}", req.request_id) + }) + .map_err(recoverable_error)?; } // 2. Build KvViews @@ -1461,9 +1511,9 @@ impl ModelExecutor for Qwen3Executor { let result = match outcome { WorkerStepOutcome::Unified(result) => result, other => { - return Err(anyhow::anyhow!( - "unified returned unexpected: {}", - other.kind() + return Err(ExecutionError::unexpected_worker_response( + "unified", + other.kind(), )); } }; @@ -1474,8 +1524,13 @@ impl ModelExecutor for Qwen3Executor { let rkv = self .request_kvs .get_mut(&req_result.request_id) - .expect("request must exist after unified decode"); - rkv.apply_decode(req_result.token, self.kv_mgr.pool())?; + .ok_or_else(|| { + ExecutionError::unexpected_worker_response( + "unified decode", + format!("unknown request id {:?}", req_result.request_id), + ) + })?; + recoverable(rkv.apply_decode(req_result.token, self.kv_mgr.pool()))?; } // 5. Offload sealed blocks from both halves (post-step-sync). 
for req_result in &result.prefill_requests { @@ -1937,7 +1992,7 @@ enum WorkerCommand { RunStep { step: StepCommand, collect_result: bool, - resp: channel::Sender>, + resp: channel::Sender>, }, LoadLoraAdapter { name: String, @@ -1997,9 +2052,21 @@ impl RankWorker { collect_result, resp, } => { - let result = - execute_step_on_lane(&mut lane, &step, collect_result); + let result = panic::catch_unwind(AssertUnwindSafe(|| { + execute_step_on_lane(&mut lane, &step, collect_result) + })); + let should_exit = result.is_err(); + let result = match result { + Ok(result) => result.map_err(recoverable_error), + Err(payload) => Err(ExecutionError::worker_panic( + format!("qwen3-tp-rank-{rank}"), + panic_message(payload), + )), + }; let _ = resp.send(result); + if should_exit { + break; + } } WorkerCommand::LoadLoraAdapter { name, @@ -2042,7 +2109,7 @@ impl RankWorker { &self, step: StepCommand, collect_result: bool, - ) -> Result>> { + ) -> ExecutionResult>> { let (resp_tx, resp_rx) = channel::bounded(1); self.tx .send(WorkerCommand::RunStep { @@ -2050,7 +2117,7 @@ impl RankWorker { collect_result, resp: resp_tx, }) - .map_err(|_| anyhow::anyhow!("tensor-parallel worker step channel closed"))?; + .map_err(|_| ExecutionError::worker_command_channel_closed("step"))?; Ok(resp_rx) } @@ -2103,3 +2170,93 @@ impl RankWorker { } } } + +#[cfg(test)] +mod worker_tests { + use super::*; + use openinfer_core::engine::ExecutionRecovery; + + fn empty_decode_step() -> StepCommand { + StepCommand::Decode { + requests: Vec::new(), + kv_views: Vec::new(), + sample_seed: 0, + } + } + + #[test] + fn worker_panic_is_reported_as_domain_fatal_then_worker_exits() { + let mut worker = RankWorker::spawn_test_panic_worker(); + + let response = worker + .run_step(empty_decode_step(), true) + .expect("dispatch panic step"); + let err = match response.recv().expect("panic response") { + Ok(_) => panic!("worker panic must be returned as an error"), + Err(err) => err, + }; + assert_eq!(err.recovery(), 
ExecutionRecovery::DomainFatal); + assert!(err.to_string().contains("injected worker panic")); + + let deadline = std::time::Instant::now() + std::time::Duration::from_secs(1); + while !worker.handle.as_ref().expect("worker handle").is_finished() { + assert!( + std::time::Instant::now() < deadline, + "panic worker did not exit" + ); + std::thread::sleep(std::time::Duration::from_millis(10)); + } + + let dispatch = worker.run_step(empty_decode_step(), true); + assert!( + matches!( + dispatch, + Err(ExecutionError::WorkerCommandChannelClosed { .. }) + ), + "panic worker exits after reporting the fatal error" + ); + worker.shutdown(); + } +} + +#[cfg(test)] +impl RankWorker { + fn spawn_test_panic_worker() -> Self { + let (tx, rx) = channel::unbounded(); + let handle = thread::Builder::new() + .name("qwen3-test-panic-rank".into()) + .spawn(move || { + while let Ok(cmd) = rx.recv() { + match cmd { + WorkerCommand::RunStep { resp, .. } => { + let result = panic::catch_unwind(AssertUnwindSafe(|| { + panic!("injected worker panic") + })); + let result = match result { + Ok(()) => unreachable!("test worker always panics"), + Err(payload) => Err(ExecutionError::worker_panic( + "qwen3-test-panic-rank", + panic_message(payload), + )), + }; + let _ = resp.send(result); + break; + } + WorkerCommand::Shutdown => break, + WorkerCommand::LoadLoraAdapter { resp, .. } + | WorkerCommand::UnloadLoraAdapter { resp, .. } + | WorkerCommand::DiscardLoraAdapter { resp, .. 
} => { + let _ = resp.send(Err(anyhow::anyhow!( + "test panic worker does not handle LoRA commands" + ))); + } + } + } + }) + .expect("spawn test panic worker"); + Self { + tx, + handle: Some(handle), + } + } +} diff --git a/openinfer-qwen3-4b/src/lora.rs b/openinfer-qwen3-4b/src/lora.rs index 7227e74f..839c0b23 100644 --- a/openinfer-qwen3-4b/src/lora.rs +++ b/openinfer-qwen3-4b/src/lora.rs @@ -389,7 +389,7 @@ pub(crate) fn apply_lora_projection_delta_range( let mut rank_out = HiddenStates::zeros(ctx, projection.a.rows, token_len)?; ops::gemm_token_range_into_checked(ctx, &projection.a, input, token_offset, &mut rank_out)?; let mut delta = HiddenStates::zeros(ctx, projection.b.rows, token_len)?; - ops::gemm_into(ctx, &projection.b, &rank_out, &mut delta); + ops::gemm_into_checked(ctx, &projection.b, &rank_out, &mut delta)?; ops::scaled_add_rows_token_range_into(ctx, &delta, scale, out, row_offset, token_offset) } @@ -412,7 +412,7 @@ pub(crate) fn apply_lora_projection_delta_indexed( let mut rank_out = HiddenStates::zeros(ctx, projection.a.rows, token_count)?; ops::gemm_into_checked(ctx, &projection.a, &compact_input, &mut rank_out)?; let mut delta = HiddenStates::zeros(ctx, projection.b.rows, token_count)?; - ops::gemm_into(ctx, &projection.b, &rank_out, &mut delta); + ops::gemm_into_checked(ctx, &projection.b, &rank_out, &mut delta)?; ops::scaled_add_rows_indexed_into( ctx, &delta, diff --git a/openinfer-qwen3-4b/src/prefill.rs b/openinfer-qwen3-4b/src/prefill.rs index a0d65cd5..8bdc9312 100644 --- a/openinfer-qwen3-4b/src/prefill.rs +++ b/openinfer-qwen3-4b/src/prefill.rs @@ -9,6 +9,7 @@ use openinfer_core::kv_pool::KvLayout; use openinfer_core::ops; use openinfer_core::ops::PrefillPagedPlan; use openinfer_core::tensor::{DeviceContext, HiddenStates}; +use openinfer_kernels::ops::gemm_rows_into_checked; use openinfer_kv_cache::KvView; /// Pre-allocated scratch buffers for one prefill forward pass. @@ -103,14 +104,14 @@ impl Qwen3Model { // 2. 
QKV projections from fused qkv_proj let q_dim = layer.attention.q_dim; let kv_dim = layer.attention.kv_dim; - ops::gemm_rows_into( + gemm_rows_into_checked( &self.ctx, &layer.attention.qkv_proj, 0, q_dim, &bufs.normed, &mut bufs.q_batch, - ); + )?; self.apply_lora_projection_ranges( layer_idx, lora_groups, @@ -119,14 +120,14 @@ impl Qwen3Model { &mut bufs.q_batch, 0, )?; - ops::gemm_rows_into( + gemm_rows_into_checked( &self.ctx, &layer.attention.qkv_proj, q_dim, kv_dim, &bufs.normed, &mut bufs.k_batch, - ); + )?; self.apply_lora_projection_ranges( layer_idx, lora_groups, @@ -135,14 +136,14 @@ impl Qwen3Model { &mut bufs.k_batch, 0, )?; - ops::gemm_rows_into( + gemm_rows_into_checked( &self.ctx, &layer.attention.qkv_proj, q_dim + kv_dim, kv_dim, &bufs.normed, &mut bufs.v_batch, - ); + )?; self.apply_lora_projection_ranges( layer_idx, lora_groups, @@ -174,12 +175,12 @@ impl Qwen3Model { )?; // 4. O projection → bufs.o_buf (as o_batch) - ops::gemm_into( + ops::gemm_into_checked( &self.ctx, &layer.attention.o_proj, &bufs.attn_output, &mut bufs.o_buf, - ); + )?; self.apply_lora_projection_ranges( layer_idx, lora_groups, @@ -202,22 +203,22 @@ impl Qwen3Model { // 7. 
MLP: split gate/up GEMMs → silu_mul → down → bufs.o_buf let inter_dim = self.local_intermediate_size(); - ops::gemm_rows_into( + gemm_rows_into_checked( &self.ctx, &layer.mlp.gate_up_proj, 0, inter_dim, &bufs.normed, &mut bufs.gate_out, - ); - ops::gemm_rows_into( + )?; + gemm_rows_into_checked( &self.ctx, &layer.mlp.gate_up_proj, inter_dim, inter_dim, &bufs.normed, &mut bufs.up_out, - ); + )?; self.apply_lora_projection_ranges( layer_idx, lora_groups, @@ -235,12 +236,12 @@ impl Qwen3Model { 0, )?; ops::silu_mul_batch_into(&self.ctx, &bufs.gate_out, &bufs.up_out, &mut bufs.act_out)?; - ops::gemm_into( + ops::gemm_into_checked( &self.ctx, &layer.mlp.down_proj, &bufs.act_out, &mut bufs.o_buf, - ); + )?; self.apply_lora_projection_ranges( layer_idx, lora_groups, diff --git a/openinfer-qwen3-4b/src/scheduler.rs b/openinfer-qwen3-4b/src/scheduler.rs index 42b6c23a..5f94fa25 100644 --- a/openinfer-qwen3-4b/src/scheduler.rs +++ b/openinfer-qwen3-4b/src/scheduler.rs @@ -21,8 +21,8 @@ use tokio::sync::mpsc; use crate::executor::{ModelExecutor, Qwen3Executor, RequestId}; use crate::{Qwen3LoraOptions, Qwen3OffloadOptions}; use openinfer_core::engine::{ - EngineCommand, EngineControlRequest, EngineHandle, GenerateRequest, KvCapacity, TokenEvent, - TokenSink, + EngineCommand, EngineControlRequest, EngineHandle, EngineHealth, ExecutionRecovery, + GenerateRequest, KvCapacity, TokenEvent, TokenSink, }; use openinfer_core::sampler::SamplingParams; @@ -210,15 +210,24 @@ where block_size: executor.block_size(), }; let (submit_tx, submit_rx) = mpsc::unbounded_channel(); + let health = EngineHealth::new(); + let scheduler_health = health.clone(); thread::Builder::new() .name("scheduler".into()) .spawn(move || { - scheduler_loop(executor, submit_rx, seed, max_prefill_tokens); + scheduler_loop( + executor, + submit_rx, + scheduler_health, + seed, + max_prefill_tokens, + ); }) .expect("failed to spawn scheduler thread"); EngineHandle::new(submit_tx) + .with_health(health) 
.with_servable_len(servable) .with_kv_capacity(kv_capacity) } @@ -248,15 +257,24 @@ where block_size: executor.block_size(), }; let (command_tx, command_rx) = mpsc::unbounded_channel(); + let health = EngineHealth::new(); + let scheduler_health = health.clone(); thread::Builder::new() .name("scheduler".into()) .spawn(move || { - scheduler_loop_with_lora_control(executor, command_rx, seed, max_prefill_tokens); + scheduler_loop_with_lora_control( + executor, + command_rx, + scheduler_health, + seed, + max_prefill_tokens, + ); }) .expect("failed to spawn scheduler thread"); EngineHandle::new_with_command_channel(command_tx) + .with_health(health) .with_servable_len(servable) .with_kv_capacity(kv_capacity) } @@ -349,11 +367,73 @@ fn release_rejected(executor: &mut E, req: &PendingRequest) { } } +fn reject_submissions_after_fatal( + submit_rx: &mut mpsc::UnboundedReceiver, + reason: &str, +) -> bool { + let Some(req) = submit_rx.blocking_recv() else { + return false; + }; + send_generation_fatal(req, reason); + while let Ok(req) = submit_rx.try_recv() { + send_generation_fatal(req, reason); + } + true +} + +fn reject_commands_after_fatal( + command_rx: &mut mpsc::UnboundedReceiver, + reason: &str, +) -> bool { + let Some(command) = command_rx.blocking_recv() else { + return false; + }; + reject_engine_command(command, reason); + while let Ok(command) = command_rx.try_recv() { + reject_engine_command(command, reason); + } + true +} + +fn reject_engine_command(command: EngineCommand, reason: &str) { + match command { + EngineCommand::Generate(req) => send_generation_fatal(req, reason), + EngineCommand::Control(control) => fail_control_request(control, reason), + } +} + +fn send_generation_fatal(req: GenerateRequest, reason: &str) { + let _ = req.token_tx.send(TokenEvent::Error { + message: reason.to_string(), + prompt_tokens: req.prompt_tokens.len(), + completion_tokens: 0, + }); +} + +fn fail_control_requests(controls: impl IntoIterator, reason: &str) { + for control in 
controls { + fail_control_request(control, reason); + } +} + +fn fail_control_request(control: EngineControlRequest, reason: &str) { + match control { + EngineControlRequest::LoadLoraAdapter { response_tx, .. } + | EngineControlRequest::UnloadLoraAdapter { response_tx, .. } => { + let _ = response_tx.send(Err(reason.to_string())); + } + EngineControlRequest::ListLoraAdapters { response_tx } => { + let _ = response_tx.send(Err(reason.to_string())); + } + } +} + // ── Main loop ─────────────────────────────────────────────────────────── fn scheduler_loop( mut executor: E, mut submit_rx: mpsc::UnboundedReceiver, + health: EngineHealth, seed: u64, max_prefill_tokens: usize, ) where @@ -374,6 +454,14 @@ fn scheduler_loop( info!("Scheduler ready"); loop { + if let Some(reason) = health.unhealthy_reason() { + if !reject_submissions_after_fatal(&mut submit_rx, reason) { + info!("Scheduler: all handles dropped after fatal state, exiting"); + return; + } + continue; + } + // 1. Drain all incoming requests into deferred. 
while let Ok(req) = submit_rx.try_recv() { deferred.push(PendingRequest::from_scheduler_request( @@ -445,15 +533,38 @@ fn scheduler_loop( continue; }; let failure_targets = failure_targets_for(&active, &plan); - let artifacts = match execute_plan(&mut executor, &mut active, plan, &mut rng) { + let effects = match execute_plan(&mut executor, &mut active, plan, &mut rng) + .and_then(|artifacts| resolve_step(&executor, &active, artifacts)) + { Ok(v) => v, Err(e) => { warn!("Execution step failed: {e}"); - fail_touched_requests(&mut executor, &mut active, failure_targets, &e.to_string()); + match e.recovery() { + ExecutionRecovery::DomainFatal => { + let reason = e.to_string(); + health.mark_unhealthy(reason.clone()); + fail_execution_domain( + &mut executor, + &mut active, + &mut deferred, + &mut loading, + &mut prefilling, + failure_targets, + &reason, + ); + } + ExecutionRecovery::Recoverable => { + fail_touched_requests( + &mut executor, + &mut active, + failure_targets, + &e.to_string(), + ); + } + } continue; } }; - let effects = resolve_step(&executor, &active, artifacts); apply_effects(&mut executor, &mut active, &mut prefilling, effects); } } @@ -461,6 +572,7 @@ fn scheduler_loop( fn scheduler_loop_with_lora_control( mut executor: E, mut command_rx: mpsc::UnboundedReceiver, + health: EngineHealth, seed: u64, max_prefill_tokens: usize, ) where @@ -478,6 +590,14 @@ fn scheduler_loop_with_lora_control( info!("Scheduler ready with LoRA control"); loop { + if let Some(reason) = health.unhealthy_reason() { + if !reject_commands_after_fatal(&mut command_rx, reason) { + info!("Scheduler: all handles dropped after fatal state, exiting"); + return; + } + continue; + } + // 1. Drain incoming commands. Generation submitted after a pending // control command waits until that control command is handled at idle. 
while let Ok(command) = command_rx.try_recv() { @@ -594,15 +714,44 @@ fn scheduler_loop_with_lora_control( continue; }; let failure_targets = failure_targets_for(&active, &plan); - let artifacts = match execute_plan(&mut executor, &mut active, plan, &mut rng) { + let effects = match execute_plan(&mut executor, &mut active, plan, &mut rng) + .and_then(|artifacts| resolve_step(&executor, &active, artifacts)) + { Ok(v) => v, Err(e) => { warn!("Execution step failed: {e}"); - fail_touched_requests(&mut executor, &mut active, failure_targets, &e.to_string()); + match e.recovery() { + ExecutionRecovery::DomainFatal => { + let reason = e.to_string(); + health.mark_unhealthy(reason.clone()); + fail_execution_domain( + &mut executor, + &mut active, + &mut deferred, + &mut loading, + &mut prefilling, + failure_targets, + &reason, + ); + fail_pending_requests( + &mut executor, + post_control_deferred.drain(..), + &reason, + ); + fail_control_requests(pending_control.drain(..), &reason); + } + ExecutionRecovery::Recoverable => { + fail_touched_requests( + &mut executor, + &mut active, + failure_targets, + &e.to_string(), + ); + } + } continue; } }; - let effects = resolve_step(&executor, &active, artifacts); apply_effects(&mut executor, &mut active, &mut prefilling, effects); } } @@ -983,7 +1132,9 @@ fn fail_touched_requests( targets: Vec, message: &str, ) { + let mut failed_ids = HashSet::with_capacity(targets.len()); for target in targets { + failed_ids.insert(target.request_id); let _ = target.token_tx.send(TokenEvent::Error { message: message.to_string(), prompt_tokens: target.prompt_tokens, @@ -996,7 +1147,44 @@ fn fail_touched_requests( ); } } - active.clear(); + active.retain(|req| !failed_ids.contains(&req.request_id)); +} + +fn fail_pending_requests( + executor: &mut impl ModelExecutor, + requests: impl IntoIterator, + message: &str, +) { + for req in requests { + let _ = req.token_tx.send(TokenEvent::Error { + message: message.to_string(), + prompt_tokens: 
req.prompt_tokens.len(), + completion_tokens: 0, + }); + if let Err(error) = executor.drop_request(req.request_id) { + warn!( + "failed to drop pending request state after fatal error for {:?}: {error}", + req.request_id + ); + } + } +} + +fn fail_execution_domain( + executor: &mut impl ModelExecutor, + active: &mut Vec, + deferred: &mut Vec, + loading: &mut Vec, + prefilling: &mut Vec, + failure_targets: Vec, + message: &str, +) { + fail_touched_requests(executor, active, failure_targets, message); + let remaining_active: Vec<_> = active.iter().map(active_failure_target).collect(); + fail_touched_requests(executor, active, remaining_active, message); + fail_pending_requests(executor, deferred.drain(..), message); + fail_pending_requests(executor, loading.drain(..), message); + fail_pending_requests(executor, prefilling.drain(..), message); } #[cfg(test)] @@ -1007,7 +1195,8 @@ mod tests { use anyhow::Result; use openinfer_core::engine::{ - EngineControlError, LoadLoraAdapterRequest, UnloadLoraAdapterRequest, + EngineControlError, ExecutionError, ExecutionResult, LoadLoraAdapterRequest, + UnloadLoraAdapterRequest, }; use openinfer_kv_cache::BlockPool; @@ -1017,6 +1206,14 @@ mod tests { PrefillStepItem, UnifiedPlan, UnifiedResult, }; + fn recoverable_error(error: anyhow::Error) -> ExecutionError { + ExecutionError::step_failed(error.to_string()) + } + + fn recoverable(result: Result) -> ExecutionResult { + result.map_err(recoverable_error) + } + struct FakeExecutor { block_size: usize, cached_tokens: usize, @@ -1028,6 +1225,7 @@ mod tests { // executor's kv_position so multi-chunk scheduling is exercised). 
prefill_positions: HashMap, fail_decode_once: bool, + fatal_decode_once: bool, decode_delay: Duration, loaded_lora_adapters: HashSet, dropped: Arc>>, @@ -1049,6 +1247,7 @@ mod tests { held_tokens: HashMap::new(), prefill_positions: HashMap::new(), fail_decode_once: false, + fatal_decode_once: false, decode_delay: Duration::ZERO, loaded_lora_adapters: HashSet::new(), dropped, @@ -1070,6 +1269,11 @@ mod tests { self } + fn with_fatal_decode_failure(mut self) -> Self { + self.fatal_decode_once = true; + self + } + fn with_max_context_tokens(mut self, max_context_tokens: usize) -> Self { self.max_context_tokens = max_context_tokens; self @@ -1211,7 +1415,7 @@ mod tests { Ok(()) } - fn execute_prefill(&mut self, plan: PrefillPlan<'_>) -> Result { + fn execute_prefill(&mut self, plan: PrefillPlan<'_>) -> ExecutionResult { self.prefill_batches.lock().unwrap().push( plan.requests .iter() @@ -1225,7 +1429,7 @@ mod tests { .collect(), ); for req in plan.requests { - self.ensure_request_tokens(req.request_id, req.prompt_tokens.len())?; + recoverable(self.ensure_request_tokens(req.request_id, req.prompt_tokens.len()))?; } Ok(PrefillResult { requests: plan @@ -1239,13 +1443,19 @@ mod tests { fn execute_decode( &mut self, plan: DecodePlan<'_>, - ) -> Result { + ) -> ExecutionResult { if !self.decode_delay.is_zero() { std::thread::sleep(self.decode_delay); } + if self.fatal_decode_once { + self.fatal_decode_once = false; + return Err(ExecutionError::worker_panic("fake", "fake worker panic")); + } if self.fail_decode_once { self.fail_decode_once = false; - anyhow::bail!("fake decode KV capacity exhausted"); + return Err(ExecutionError::step_failed( + "fake decode KV capacity exhausted", + )); } self.decode_batches.lock().unwrap().push( @@ -1265,8 +1475,9 @@ mod tests { .held_tokens .get(&req.request_id) .copied() - .ok_or_else(|| anyhow::anyhow!("missing fake request state"))?; - self.ensure_request_tokens(req.request_id, current_tokens + 1)?; + .ok_or_else(|| 
anyhow::anyhow!("missing fake request state")) + .map_err(recoverable_error)?; + recoverable(self.ensure_request_tokens(req.request_id, current_tokens + 1))?; } Ok(crate::executor::DecodeResult { @@ -1282,7 +1493,7 @@ mod tests { }) } - fn execute_unified(&mut self, plan: UnifiedPlan<'_>) -> Result { + fn execute_unified(&mut self, plan: UnifiedPlan<'_>) -> ExecutionResult { self.prefill_batches.lock().unwrap().push( plan.prefill_requests .iter() @@ -1308,15 +1519,16 @@ mod tests { .collect(), ); for req in plan.prefill_requests { - self.ensure_request_tokens(req.request_id, req.prompt_tokens.len())?; + recoverable(self.ensure_request_tokens(req.request_id, req.prompt_tokens.len()))?; } for req in plan.decode_requests { let current_tokens = self .held_tokens .get(&req.request_id) .copied() - .ok_or_else(|| anyhow::anyhow!("missing fake request state"))?; - self.ensure_request_tokens(req.request_id, current_tokens + 1)?; + .ok_or_else(|| anyhow::anyhow!("missing fake request state")) + .map_err(recoverable_error)?; + recoverable(self.ensure_request_tokens(req.request_id, current_tokens + 1))?; } Ok(UnifiedResult { @@ -2262,6 +2474,67 @@ mod tests { ); } + #[test] + fn fatal_worker_error_marks_engine_unhealthy_and_rejects_future_work() { + let dropped = Arc::new(Mutex::new(Vec::new())); + let executor = FakeExecutor::new(4, Arc::clone(&dropped)).with_fatal_decode_failure(); + let handle = start_with_executor(executor, 42, DEFAULT_MAX_PREFILL_TOKENS); + + let (will_fatal, mut fatal_rx) = request(16, 2); + handle.submit(will_fatal).expect("submit will_fatal"); + assert!( + matches!( + recv_skipping_scheduled(&mut fatal_rx), + Some(TokenEvent::Token { id: 100, .. 
}) + ), + "prefill should emit before the worker-domain fatal decode" + ); + match recv_skipping_scheduled(&mut fatal_rx) { + Some(TokenEvent::Error { + message, + prompt_tokens, + completion_tokens, + }) => { + assert!(message.contains("fake worker panic")); + assert_eq!(prompt_tokens, 16); + assert_eq!(completion_tokens, 1); + } + _ => panic!("domain fatal should surface as TokenEvent::Error"), + } + assert!( + wait_until(Duration::from_secs(1), || dropped + .lock() + .unwrap() + .contains(&0)), + "fatal request state should be dropped" + ); + assert!( + matches!( + handle.readiness(), + openinfer_core::engine::EngineReadiness::Unhealthy { ref reason } + if reason.contains("fake worker panic") + ), + "engine readiness should expose the fatal execution domain" + ); + + let (after_fatal, mut after_rx) = request(16, 1); + handle + .submit(after_fatal) + .expect("scheduler should stay alive to reject explicitly after fatal"); + match recv_skipping_scheduled(&mut after_rx) { + Some(TokenEvent::Error { + message, + prompt_tokens, + completion_tokens, + }) => { + assert!(message.contains("fake worker panic")); + assert_eq!(prompt_tokens, 16); + assert_eq!(completion_tokens, 0); + } + _ => panic!("post-fatal work should be rejected with a clear error"), + } + } + #[test] fn active_receiver_drop_releases_request_state() { let dropped = Arc::new(Mutex::new(Vec::new())); diff --git a/openinfer-qwen3-4b/src/scheduler/plan.rs b/openinfer-qwen3-4b/src/scheduler/plan.rs index a4794f4b..7b6a27f2 100644 --- a/openinfer-qwen3-4b/src/scheduler/plan.rs +++ b/openinfer-qwen3-4b/src/scheduler/plan.rs @@ -1,6 +1,7 @@ -use anyhow::Result; use rand::rngs::StdRng; +use openinfer_core::engine::ExecutionResult; + use crate::executor::{ DecodePlan, DecodeResult, DecodeStepItem, ModelExecutor, PrefillPlan, PrefillResult, PrefillStepItem, UnifiedPlan, UnifiedResult, @@ -53,7 +54,7 @@ pub(super) fn execute_plan( active: &mut [ActiveRequestState], plan: ExecutionPlan, rng: &mut StdRng, -) -> 
Result { +) -> ExecutionResult { match plan { ExecutionPlan::Prefill { pending } => { let scheduled_at_unix_s = openinfer_core::engine::unix_now_s(); diff --git a/openinfer-qwen3-4b/src/scheduler/resolve.rs b/openinfer-qwen3-4b/src/scheduler/resolve.rs index 1f91dcd5..48af8b36 100644 --- a/openinfer-qwen3-4b/src/scheduler/resolve.rs +++ b/openinfer-qwen3-4b/src/scheduler/resolve.rs @@ -1,5 +1,5 @@ use crate::executor::{DecodeRequestResult, ModelExecutor, PrefillRequestResult}; -use openinfer_core::engine::FinishReason; +use openinfer_core::engine::{ExecutionError, ExecutionResult, FinishReason}; use super::effects::{DecodeEffect, PendingEffect, PromptEchoEffect, ScheduledEffect, StepEffects}; use super::plan::ExecutionArtifacts; @@ -9,19 +9,19 @@ pub(super) fn resolve_step( executor: &impl ModelExecutor, active: &[ActiveRequestState], artifacts: ExecutionArtifacts, -) -> StepEffects { +) -> ExecutionResult { match artifacts { ExecutionArtifacts::Prefill { pending, result, scheduled_at_unix_s, } => resolve_prefill_outputs(executor, pending, result.requests, scheduled_at_unix_s), - ExecutionArtifacts::Decode { result } => StepEffects { + ExecutionArtifacts::Decode { result } => Ok(StepEffects { scheduled: Vec::new(), prompt_echoes: Vec::new(), pending: Vec::new(), - decode: resolve_decode_outputs(executor, active, &result.requests), - }, + decode: resolve_decode_outputs(executor, active, &result.requests)?, + }), ExecutionArtifacts::Unified { pending, result, @@ -32,9 +32,9 @@ pub(super) fn resolve_step( pending, result.prefill_requests, scheduled_at_unix_s, - ); - effects.decode = resolve_decode_outputs(executor, active, &result.decode_requests); - effects + )?; + effects.decode = resolve_decode_outputs(executor, active, &result.decode_requests)?; + Ok(effects) } } } @@ -44,13 +44,32 @@ fn resolve_prefill_outputs( pending: Vec, request_results: Vec, scheduled_at_unix_s: f64, -) -> StepEffects { +) -> ExecutionResult { + if pending.len() != request_results.len() { + 
return Err(ExecutionError::unexpected_worker_response( + "prefill resolve", + format!( + "result count {} does not match request count {}", + request_results.len(), + pending.len() + ), + )); + } + let mut effects = StepEffects::empty(); for (mut req, result) in pending.into_iter().zip(request_results) { // Results are matched to requests positionally; a misalignment here // would deliver request A's tokens to request B, so fail loudly in // release builds too. - assert_eq!(req.request_id, result.request_id); + if req.request_id != result.request_id { + return Err(ExecutionError::unexpected_worker_response( + "prefill resolve", + format!( + "result request id {:?} does not match pending {:?}", + result.request_id, req.request_id + ), + )); + } let prompt_len = req.prompt_tokens.len(); // Fire Scheduled on the request's first chunk only: queue time ends @@ -124,46 +143,50 @@ fn resolve_prefill_outputs( }); } - effects + Ok(effects) } fn resolve_decode_outputs( executor: &impl ModelExecutor, active: &[ActiveRequestState], request_results: &[DecodeRequestResult], -) -> Vec { - request_results - .iter() - .map(|result| { - let req = active - .iter() - .find(|req| req.request_id == result.request_id) - .expect("decode request_id must exist in active set"); - let completion_tokens = req.generated_count + 1; - let is_eos = !req.params.ignore_eos && executor.is_stop_token(result.token); - let at_limit = completion_tokens >= req.max_tokens; - if is_eos { - DecodeEffect::Finish { - request_id: result.request_id, - finish_reason: FinishReason::Stop, - completion_tokens, - } - } else if at_limit { - DecodeEffect::EmitAndFinish { - request_id: result.request_id, - token: result.token, - logprob: result.logprob.clone(), - finish_reason: FinishReason::Length, - completion_tokens, - } - } else { - DecodeEffect::EmitAndContinue { - request_id: result.request_id, - token: result.token, - logprob: result.logprob.clone(), - completion_tokens, - } +) -> ExecutionResult> { + let mut 
effects = Vec::with_capacity(request_results.len()); + for result in request_results { + let req = active + .iter() + .find(|req| req.request_id == result.request_id) + .ok_or_else(|| { + ExecutionError::unexpected_worker_response( + "decode resolve", + format!("unknown request id {:?}", result.request_id), + ) + })?; + let completion_tokens = req.generated_count + 1; + let is_eos = !req.params.ignore_eos && executor.is_stop_token(result.token); + let at_limit = completion_tokens >= req.max_tokens; + effects.push(if is_eos { + DecodeEffect::Finish { + request_id: result.request_id, + finish_reason: FinishReason::Stop, + completion_tokens, } - }) - .collect() + } else if at_limit { + DecodeEffect::EmitAndFinish { + request_id: result.request_id, + token: result.token, + logprob: result.logprob.clone(), + finish_reason: FinishReason::Length, + completion_tokens, + } + } else { + DecodeEffect::EmitAndContinue { + request_id: result.request_id, + token: result.token, + logprob: result.logprob.clone(), + completion_tokens, + } + }); + } + Ok(effects) } diff --git a/openinfer-qwen3-4b/src/unified_forward.rs b/openinfer-qwen3-4b/src/unified_forward.rs index db44e155..ee01dfad 100644 --- a/openinfer-qwen3-4b/src/unified_forward.rs +++ b/openinfer-qwen3-4b/src/unified_forward.rs @@ -17,6 +17,7 @@ use openinfer_core::kv_pool::KvLayout; use openinfer_core::ops; use openinfer_core::ops::PrefillPagedPlan; use openinfer_core::tensor::HiddenStates; +use openinfer_kernels::ops::gemm_rows_into_checked; use openinfer_kv_cache::KvView; impl Qwen3Model { @@ -195,14 +196,14 @@ impl Qwen3Model { // ── 2. 
QKV projections from fused qkv_proj [all tokens] ───── let q_dim_l = layer.attention.q_dim; let kv_dim_l = layer.attention.kv_dim; - ops::gemm_rows_into( + gemm_rows_into_checked( &self.ctx, &layer.attention.qkv_proj, 0, q_dim_l, &bufs.normed, &mut bufs.q_batch, - ); + )?; self.apply_lora_projection_ranges( layer_idx, lora_groups, @@ -211,14 +212,14 @@ impl Qwen3Model { &mut bufs.q_batch, 0, )?; - ops::gemm_rows_into( + gemm_rows_into_checked( &self.ctx, &layer.attention.qkv_proj, q_dim_l, kv_dim_l, &bufs.normed, &mut bufs.k_batch, - ); + )?; self.apply_lora_projection_ranges( layer_idx, lora_groups, @@ -227,14 +228,14 @@ impl Qwen3Model { &mut bufs.k_batch, 0, )?; - ops::gemm_rows_into( + gemm_rows_into_checked( &self.ctx, &layer.attention.qkv_proj, q_dim_l + kv_dim_l, kv_dim_l, &bufs.normed, &mut bufs.v_batch, - ); + )?; self.apply_lora_projection_ranges( layer_idx, lora_groups, @@ -268,12 +269,12 @@ impl Qwen3Model { )?; // ── 6. O projection [all tokens] ───────────────────────────── - ops::gemm_into( + ops::gemm_into_checked( &self.ctx, &layer.attention.o_proj, &bufs.attn_output, &mut bufs.o_buf, - ); + )?; self.apply_lora_projection_ranges( layer_idx, lora_groups, @@ -294,22 +295,22 @@ impl Qwen3Model { &mut bufs.normed, )?; - ops::gemm_rows_into( + gemm_rows_into_checked( &self.ctx, &layer.mlp.gate_up_proj, 0, self.local_intermediate_size(), &bufs.normed, &mut bufs.gate_out, - ); - ops::gemm_rows_into( + )?; + gemm_rows_into_checked( &self.ctx, &layer.mlp.gate_up_proj, self.local_intermediate_size(), self.local_intermediate_size(), &bufs.normed, &mut bufs.up_out, - ); + )?; self.apply_lora_projection_ranges( layer_idx, lora_groups, @@ -327,12 +328,12 @@ impl Qwen3Model { 0, )?; ops::silu_mul_batch_into(&self.ctx, &bufs.gate_out, &bufs.up_out, &mut bufs.act_out)?; - ops::gemm_into( + ops::gemm_into_checked( &self.ctx, &layer.mlp.down_proj, &bufs.act_out, &mut bufs.o_buf, - ); + )?; self.apply_lora_projection_ranges( layer_idx, lora_groups, diff --git 
a/openinfer-vllm-frontend/src/health.rs b/openinfer-vllm-frontend/src/health.rs new file mode 100644 index 00000000..2157873d --- /dev/null +++ b/openinfer-vllm-frontend/src/health.rs @@ -0,0 +1,95 @@ +use std::sync::{Arc, OnceLock}; + +use axum::Json; +use axum::extract::{Request, State}; +use axum::http::StatusCode; +use axum::middleware::Next; +use axum::response::{IntoResponse, Response}; +use openinfer_engine::engine::{EngineHealth, EngineReadiness}; + +pub(crate) type HealthProbe = Arc>; + +pub(crate) async fn guard_health_request( + State(probe): State, + req: Request, + next: Next, +) -> Response { + if req.uri().path() != "/health" { + return next.run(req).await; + } + + match probe.get().map(EngineHealth::readiness) { + Some(EngineReadiness::Healthy) => { + Json(serde_json::json!({ "status": "ok" })).into_response() + } + Some(EngineReadiness::Unhealthy { reason }) => ( + StatusCode::SERVICE_UNAVAILABLE, + Json(serde_json::json!({ + "status": "unhealthy", + "reason": reason, + })), + ) + .into_response(), + None => next.run(req).await, + } +} + +#[cfg(test)] +mod tests { + use axum::Router; + use axum::body::{Body, to_bytes}; + use axum::http::{Request, StatusCode}; + use axum::middleware::from_fn_with_state; + use axum::routing::get; + use tower::ServiceExt; + + use super::*; + + async fn get_health(router: Router) -> (StatusCode, serde_json::Value) { + let response = router + .oneshot( + Request::builder() + .uri("/health") + .body(Body::empty()) + .unwrap(), + ) + .await + .unwrap(); + let status = response.status(); + let bytes = to_bytes(response.into_body(), usize::MAX).await.unwrap(); + let body = serde_json::from_slice(&bytes).unwrap(); + (status, body) + } + + fn router(probe: HealthProbe) -> Router { + Router::new() + .route( + "/health", + get(|| async { Json(serde_json::json!({"status": "upstream"})) }), + ) + .layer(from_fn_with_state(probe, guard_health_request)) + } + + #[tokio::test] + async fn health_guard_reports_unhealthy_engine() { + 
let probe = HealthProbe::default(); + let health = EngineHealth::new(); + health.mark_unhealthy("worker died"); + assert!(probe.set(health).is_ok()); + + let (status, body) = get_health(router(probe)).await; + assert_eq!(status, StatusCode::SERVICE_UNAVAILABLE); + assert_eq!(body["status"], "unhealthy"); + assert_eq!(body["reason"], "worker died"); + } + + #[tokio::test] + async fn health_guard_reports_healthy_engine() { + let probe = HealthProbe::default(); + assert!(probe.set(EngineHealth::new()).is_ok()); + + let (status, body) = get_health(router(probe)).await; + assert_eq!(status, StatusCode::OK); + assert_eq!(body["status"], "ok"); + } +} diff --git a/openinfer-vllm-frontend/src/lib.rs b/openinfer-vllm-frontend/src/lib.rs index 8d0549a1..3ab98891 100644 --- a/openinfer-vllm-frontend/src/lib.rs +++ b/openinfer-vllm-frontend/src/lib.rs @@ -20,11 +20,13 @@ use vllm_server::{ use openinfer_engine::engine::EngineHandle; mod bridge; +mod health; mod lora; mod sampling_guard; mod wire; use bridge::{LocalEngineBridge, ipc_endpoint, local_ipc_namespace}; +use health::{HealthProbe, guard_health_request}; use lora::{bad_request, load_startup_lora_modules, lora_openai_routes, lora_routes}; use sampling_guard::{ServableCap, guard_generation_request}; @@ -157,10 +159,12 @@ where // failure it cancels the server so the error surfaces instead of hanging // in the registration wait. 
let servable_cap = ServableCap::default(); + let health_probe = HealthProbe::default(); let server_shutdown = shutdown.child_token(); let bridge_shutdown = shutdown.child_token(); let engine_task = tokio::spawn({ let servable_cap = servable_cap.clone(); + let health_probe = health_probe.clone(); let server_shutdown = server_shutdown.clone(); let bridge_shutdown = bridge_shutdown.clone(); let input_address = input_address.clone(); @@ -175,6 +179,10 @@ where }; let servable_limit = handle.servable_len().map(|cap| max_model_len.min(cap)); servable_cap.set(servable_limit); + assert!( + health_probe.set(handle.health()).is_ok(), + "engine health probe must be set exactly once" + ); let bridge = LocalEngineBridge { input_address, output_address, @@ -223,7 +231,9 @@ where }; let extend_router = move |router: Router| { - extend_router(router).layer(from_fn_with_state(servable_cap, guard_generation_request)) + extend_router(router) + .layer(from_fn_with_state(servable_cap, guard_generation_request)) + .layer(from_fn_with_state(health_probe, guard_health_request)) }; let result = vllm_server::serve_with_router_extension(config, server_shutdown, extend_router).await;