feat(kernels): verify-side target probs + chain rejection sampling primitives (#512) #534
n-WN wants to merge 3 commits into
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: f162433592
/// kernel, not as a renorm, so a min_p target distribution is not
/// representable here yet — callers must keep such requests off the
/// speculative path.
pub fn gpu_verify_probs_into(
Re-export the speculative sampling entry points
These new primitives are declared pub inside the private sampling module, but openinfer-kernels/src/ops.rs only re-exports the older sampling APIs (gpu_sample_batch_into, etc.). In that setup model crates such as the upcoming Qwen verify-path integration cannot call either gpu_verify_probs_into or gpu_spec_accept_into; they are only reachable from this module’s tests. Please add them to the public ops re-export list (and core re-exports if needed) so the kernels-layer API added here is usable outside openinfer-kernels.
Right — they were pub only inside the private module, unreachable from model crates. Re-exported through ops in d0de3b1 (the upcoming verify-path slice consumes them from there).
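For readers unfamiliar with the visibility issue raised above, here is a minimal sketch of it. The module layout and the stand-in function body are hypothetical (the real `gpu_verify_probs_into` has a different signature and runs on the GPU); only the `pub use` mechanics are the point.

```rust
// Hypothetical layout: `sampling` is a private module, so its `pub` items are
// visible inside `ops` but not to other crates unless `ops` re-exports them.
pub mod ops {
    mod sampling {
        /// Stand-in for the real GPU entry point (assumed signature, CPU only).
        /// Zeroes the output buffer and reports how many entries it wrote.
        pub fn gpu_verify_probs_into(out: &mut [f32]) -> usize {
            out.iter_mut().for_each(|p| *p = 0.0);
            out.len()
        }
    }

    // Without this line, `ops::gpu_verify_probs_into` does not resolve from
    // outside `ops`, even though the function itself is declared `pub`.
    pub use sampling::gpu_verify_probs_into;
}
```

With the re-export in place, a caller outside `ops` can write `ops::gpu_verify_probs_into(&mut buf)`; removing the `pub use` line turns that call into a privacy error.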
Thanks for this — the approach is sound: reusing FlashInfer's renorm pipeline + Suggested changes1.
…imitives (openinfer-project#512)

The two GPU primitives non-greedy speculative decoding needs, verified on H100 (CUDA 12.9), 3/3 deterministic tests:

gpu_verify_probs_into: the batched sampling pipeline's gather + softmax + top-k/top-p renorm WITHOUT the terminal sampling kernel. The post-filter target distribution lands in the caller's buffer, filtered tokens as exact zeros. Distribution-equivalent to the fused sampling fast path (it filters at draw time, this filters then draws; the law over tokens is identical). min_p rows are rejected fail-loud: min_p is a sampling-time mask, not a renorm, so that target distribution is not representable here yet.

gpu_spec_accept_into: wraps FlashInfer's ChainSpeculativeSampling. Accept draft i with min(1, p_target/q_draft), first rejection resamples from relu(target - draft) renormalized, full acceptance emits the bonus token, tail is -1-filled. onehot_draft derives q = delta(x - draft) on device for a greedy/argmax proposer, the degenerate proposal under which acceptance is min(1, p_target(draft)) and rejection sampling stays distribution-exact, so DFlash's greedy drafts need no proposal-probability plumbing.

One semantic trap pinned in doc + test: the kernel returns two counters with different meanings. emitted is the accepted-prefix length (the commit signal); accepted keeps counting hypothetical acceptances past the first rejection (FlashInfer's acceptance-rate telemetry). Committing on accepted would emit wrong tokens.

Tests are corner-point deterministic, not statistical: a target that puts mass 1.0 on each draft forces full acceptance + bonus; a target with zero mass on the draft forces first-step rejection with a single-token residual, making the resample exact.
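The acceptance rule and the emitted/accepted distinction described in this commit message can be sketched on the CPU as follows. This is an illustrative reference under stated assumptions, not the FlashInfer kernel: `ChainOutcome`, `one_hot`, `sample_from`, and `chain_accept_onehot` are hypothetical names, the proposal is fixed to the one-hot case, and the way uniforms are supplied (one per draft plus one for the final draw) is a simplification.

```rust
/// Result of one chain-verification pass over K drafts (hypothetical type).
pub struct ChainOutcome {
    pub tokens: Vec<i32>, // length K + 1, unused tail filled with -1
    pub emitted: usize,   // accepted-prefix length: the commit signal
    pub accepted: usize,  // also counts acceptances past the first rejection
}

/// Row of length `vocab` with all mass on `token`.
pub fn one_hot(vocab: usize, token: usize) -> Vec<f32> {
    let mut row = vec![0.0f32; vocab];
    row[token] = 1.0;
    row
}

/// Inverse-CDF draw from unnormalized non-negative weights; -1 if all are zero.
fn sample_from(weights: &[f32], u: f32) -> i32 {
    let total: f32 = weights.iter().sum();
    let mut acc = 0.0f32;
    let mut last_nonzero = -1i32;
    for (tok, &w) in weights.iter().enumerate() {
        if w > 0.0 {
            acc += w;
            last_nonzero = tok as i32;
            if u * total < acc {
                return tok as i32;
            }
        }
    }
    last_nonzero
}

/// Chain rejection sampling with a one-hot proposal q = delta(x - draft).
/// `target` holds K draft rows plus one bonus row; `accept_u` holds K uniforms.
pub fn chain_accept_onehot(
    drafts: &[usize],
    target: &[Vec<f32>],
    accept_u: &[f32],
    final_u: f32,
) -> ChainOutcome {
    let k = drafts.len();
    assert_eq!(target.len(), k + 1, "need K draft rows plus one bonus row");
    assert_eq!(accept_u.len(), k, "need one uniform per draft");
    let mut tokens = vec![-1i32; k + 1];
    let mut emitted = k; // full acceptance unless a rejection is seen
    let mut accepted = 0usize;
    let mut rejected = false;
    for i in 0..k {
        // q(draft) = 1, so the ratio p_target / q_draft is just p_target(draft).
        let ok = accept_u[i] < target[i][drafts[i]].min(1.0);
        if ok {
            accepted += 1; // telemetry: keeps counting after a rejection
            if !rejected {
                tokens[i] = drafts[i] as i32;
            }
        } else if !rejected {
            rejected = true;
            emitted = i;
            // relu(target - one_hot): only the draft token's mass is removed.
            let mut residual = target[i].clone();
            residual[drafts[i]] = 0.0;
            tokens[i] = sample_from(&residual, final_u);
        }
    }
    if !rejected {
        tokens[k] = sample_from(&target[k], final_u); // bonus token
    }
    ChainOutcome { tokens, emitted, accepted }
}
```

In this sketch, a first-step rejection followed by a draft the target would have accepted yields `emitted == 0` but `accepted == 1`, which is why committing on `accepted` would emit tokens past the rejection point.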
gpu_verify_probs_into / gpu_spec_accept_into were pub only inside the private sampling module; the ops re-export list is what model crates see.
d0de3b1 to e151669
…view)

Address @xiaguan's openinfer-project#534 review:

- Move gpu_verify_probs_into + gpu_spec_accept_into (and their tests) into ops/spec_sampling.rs; sampling.rs 1336 -> 1060 lines, so the speculative code no longer drives the file over the soft size limit.
- Extract shared C helper gather_softmax_renorm_flashinfer (gather + softmax + optional top-k/top-p renorm), used by both the verify path and the min_p sampling path; extract Rust prepare_sampling_params shared by the batched sampler and the verify path. min_p stays caller-side (sampler packs it, verify rejects it).
- Expose onehot_draft as a bool parameter on gpu_spec_accept_into instead of hardcoding 1, so the slice-2 verify path can pass a real drafter's probs.
- Guard vocab > 0 in gpu_spec_accept_into.
- Add spec_accept_partial_acceptance_matches_the_rate_law: a 512-row Bernoulli(0.7) convergence test for the p in (0,1) acceptance law.

Verified on H100: cargo test -p openinfer-kernels --lib (sampling + spec) passes 11/11.
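The rate law that the new convergence test targets can be illustrated with a small CPU simulation: with a one-hot proposal and target mass p on the draft token, each row accepts with probability min(1, p). The sketch below is an assumption-laden stand-in, not the actual GPU test: `SplitMix64` and `empirical_accept_rate` are hypothetical names, the generator is a SplitMix64-style mixer chosen only to keep the example dependency-free, and the tolerance is mine rather than the one the real test uses.

```rust
/// Small SplitMix64-style generator, used so the sketch needs no external crate.
pub struct SplitMix64(pub u64);

impl SplitMix64 {
    pub fn next_u64(&mut self) -> u64 {
        self.0 = self.0.wrapping_add(0x9E37_79B9_7F4A_7C15);
        let mut z = self.0;
        z = (z ^ (z >> 30)).wrapping_mul(0xBF58_476D_1CE4_E5B9);
        z = (z ^ (z >> 27)).wrapping_mul(0x94D0_49BB_1331_11EB);
        z ^ (z >> 31)
    }

    /// Uniform in [0, 1), built from the top 24 bits so it is exact in f32.
    pub fn next_f32(&mut self) -> f32 {
        (self.next_u64() >> 40) as f32 / (1u64 << 24) as f32
    }
}

/// Fraction of `rows` independent single-draft rows that accept when the
/// target puts mass `p` on the draft token and the proposal is one-hot.
pub fn empirical_accept_rate(p: f32, rows: usize, seed: u64) -> f32 {
    let mut rng = SplitMix64(seed);
    // Acceptance test per row: u < min(1, p_target(draft) / 1).
    let hits = (0..rows).filter(|_| rng.next_f32() < p.min(1.0)).count();
    hits as f32 / rows as f32
}
```

For 512 rows at p = 0.7 the standard deviation of the empirical rate is about 0.02, so a check such as `(empirical_accept_rate(0.7, 512, seed) - 0.7).abs() < 0.1` leaves roughly five standard deviations of slack.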
Thanks for the careful review: all four points addressed, plus the statistical test you suggested. Rebased onto …
1. File split. Moved …
2. Duplication. …
3. …
4. …
5. Statistical test. Added …
e151669 to 92f9887
Slice 1 of #512 — the two GPU primitives non-greedy speculative decoding needs, as a self-contained kernels-layer PR. Executor wiring (verify-path integration, gate relaxation, the statistical equivalence gate) follows as the next slice on top.
gpu_verify_probs_into: the target distribution itself

The batched sampling pipeline's gather + softmax + top-k/top-p renorm, stopped before the terminal sampling kernel: the post-filter probabilities land in the caller's n_rows × vocab buffer, filtered tokens as exact zeros (that matters for the rejection residual). Distribution-equivalent to the fused sampling fast path: it filters at draw time, this filters then draws; the law over tokens is identical, which is all rejection sampling's correctness needs.

min_p rows are rejected fail-loud: min_p is applied inside the sampling kernel, not as a renorm, so that target distribution is not representable here. Such requests stay off the speculative path (gate keeps them out in the next slice).

gpu_spec_accept_into: chain rejection sampling

Wraps the vendored FlashInfer ChainSpeculativeSampling: accept draft i with min(1, p_target/q_draft); the first rejection resamples from relu(target − draft) renormalized and stops; full acceptance emits the bonus token from the target at position K; the tail is -1-filled. This is exactly the "longest accepted prefix + one model token" contract accept_greedy already has, so the executor seam stays shape-compatible.

onehot_draft: DFlash proposes greedily, i.e. the proposal is the degenerate q(x) = δ(x − draft). Under that proposal, acceptance reduces to min(1, p_target(draft)) and the residual is the target with the draft token's mass removed, which is still distribution-exact. Deriving the one-hot rows on device means no draft-side proposal-probability plumbing is needed at all for the current proposer (real q support is already in the signature for a future sampled drafter).

A semantic trap, pinned in doc + test

The kernel returns two counters with different meanings: emitted is the accepted-prefix length (the commit signal); accepted keeps counting hypothetical acceptances past the first rejection (FlashInfer's acceptance-rate telemetry). Committing on accepted would emit wrong tokens; the wrapper doc says so and the test asserts the distinction.

Verification (H100, CUDA 12.9): 3/3, corner-point deterministic (not statistical)

- verify_probs_renorm_matches_the_sampling_law: closed-form softmax of {2,1,0} matches to 5e-3; top_k=2 renorm leaves the filtered token an exact 0.0 and renormalizes survivors to {0.7311, 0.2689}; rows sum to 1.
- verify_probs_rejects_min_p_rows: fail-loud path.
- spec_accept_full_acceptance_and_certain_rejection: a target with mass 1.0 on each draft forces full acceptance + bonus ([100, 200, 777], emitted=2); a target with zero mass on the draft forces first-step rejection where the residual is a single token, making the resample exact ([555, -1, -1], emitted=0).

🤖 Generated with Claude Code
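The closed-form numbers quoted in the verification list can be reproduced with a small CPU reference. This is a sketch for checking the arithmetic, not the FlashInfer renorm kernel: `softmax` and `top_k_renorm` are hypothetical helper names, and ties are broken by index here, which may differ from how the GPU path handles equal probabilities.

```rust
/// Numerically stable softmax over one row of logits.
pub fn softmax(logits: &[f32]) -> Vec<f32> {
    let max = logits.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = logits.iter().map(|&x| (x - max).exp()).collect();
    let sum: f32 = exps.iter().sum();
    exps.iter().map(|&e| e / sum).collect()
}

/// Keep the `k` largest probabilities, leave the rest as exact zeros,
/// and renormalize the survivors so the row sums to 1.
pub fn top_k_renorm(probs: &[f32], k: usize) -> Vec<f32> {
    // Indices sorted by descending probability (stable, so ties keep index order).
    let mut order: Vec<usize> = (0..probs.len()).collect();
    order.sort_by(|&a, &b| probs[b].partial_cmp(&probs[a]).unwrap());
    let mut out = vec![0.0f32; probs.len()];
    let kept: f32 = order.iter().take(k).map(|&i| probs[i]).sum();
    for &i in order.iter().take(k) {
        out[i] = probs[i] / kept;
    }
    out
}
```

For logits {2, 1, 0}, `softmax` gives approximately {0.6652, 0.2447, 0.0900}; `top_k_renorm` with k = 2 then divides the two survivors by 0.9100, giving approximately {0.7311, 0.2689} and an exact 0.0 for the filtered token.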