Skip to content

feat(kernels): verify-side target probs + chain rejection sampling primitives (#512)#534

Open
n-WN wants to merge 3 commits into
openinfer-project:mainfrom
n-WN:feat/spec-rejection-sampling
Open

feat(kernels): verify-side target probs + chain rejection sampling primitives (#512)#534
n-WN wants to merge 3 commits into
openinfer-project:mainfrom
n-WN:feat/spec-rejection-sampling

Conversation

@n-WN

@n-WN n-WN commented Jul 3, 2026

Copy link
Copy Markdown
Contributor

Slice 1 of #512 — the two GPU primitives non-greedy speculative decoding needs, as a self-contained kernels-layer PR. Executor wiring (verify-path integration, gate relaxation, the statistical equivalence gate) follows as the next slice on top.

gpu_verify_probs_into — the target distribution itself

The batched sampling pipeline's gather + softmax + top-k/top-p renorm, stopped before the terminal sampling kernel: the post-filter probabilities land in the caller's n_rows × vocab buffer, filtered tokens as exact zeros (that matters for the rejection residual). Distribution-equivalent to the fused sampling fast path — it filters at draw time, this filters then draws; the law over tokens is identical, which is all rejection sampling's correctness needs.

min_p rows are rejected fail-loud: min_p is applied inside the sampling kernel, not as a renorm, so that target distribution is not representable here — such requests stay off the speculative path (gate keeps them out in the next slice).

gpu_spec_accept_into — chain rejection sampling

Wraps the vendored FlashInfer ChainSpeculativeSampling: accept draft i with min(1, p_target/q_draft); the first rejection resamples from relu(target − draft) renormalized and stops; full acceptance emits the bonus token from the target at position K; the tail is -1-filled — exactly the "longest accepted prefix + one model token" contract accept_greedy already has, so the executor seam stays shape-compatible.

onehot_draft: DFlash proposes greedily, i.e. the proposal is the degenerate q(x) = δ(x − draft). Under that proposal, acceptance reduces to min(1, p_target(draft)) and the residual is the target with the draft token's mass removed — still distribution-exact. Deriving the one-hot rows on device means no draft-side proposal-probability plumbing is needed at all for the current proposer (real q support is already in the signature for a future sampled drafter).

A semantic trap, pinned in doc + test

The kernel returns two counters with different meanings: emitted is the accepted-prefix length — the commit signal; accepted keeps counting hypothetical acceptances past the first rejection (FlashInfer's acceptance-rate telemetry). Committing on accepted would emit wrong tokens; the wrapper doc says so and the test asserts the distinction.

Verification (H100, CUDA 12.9) — 3/3, corner-point deterministic (not statistical)

  • verify_probs_renorm_matches_the_sampling_law: closed-form softmax of {2,1,0} matches to 5e-3; top_k=2 renorm leaves the filtered token an exact 0.0 and renormalizes survivors to {0.7311, 0.2689}; rows sum to 1.
  • verify_probs_rejects_min_p_rows: fail-loud path.
  • spec_accept_full_acceptance_and_certain_rejection: a target with mass 1.0 on each draft forces full acceptance + bonus ([100, 200, 777], emitted=2); a target with zero mass on the draft forces first-step rejection where the residual is a single token, making the resample exact ([555, -1, -1], emitted=0).

🤖 Generated with Claude Code

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: f162433592

ℹ️ About Codex in GitHub

Codex has been enabled to automatically review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

When you sign up for Codex through ChatGPT, Codex can also answer questions or update the PR, like "@codex address that feedback".

Comment thread openinfer-kernels/src/ops/sampling.rs Outdated
/// kernel, not as a renorm, so a min_p target distribution is not
/// representable here yet — callers must keep such requests off the
/// speculative path.
pub fn gpu_verify_probs_into(

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Re-export the speculative sampling entry points

These new primitives are declared pub inside the private sampling module, but openinfer-kernels/src/ops.rs only re-exports the older sampling APIs (gpu_sample_batch_into, etc.). In that setup model crates such as the upcoming Qwen verify-path integration cannot call either gpu_verify_probs_into or gpu_spec_accept_into; they are only reachable from this module’s tests. Please add them to the public ops re-export list (and core re-exports if needed) so the kernels-layer API added here is usable outside openinfer-kernels.

Useful? React with 👍 / 👎.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Right — they were pub only inside the private module, unreachable from model crates. Re-exported through ops in d0de3b1 (the upcoming verify-path slice consumes them from there).

@xiaguan

xiaguan commented Jul 4, 2026

Copy link
Copy Markdown
Collaborator

Thanks for this — the approach is sound: reusing FlashInfer's renorm pipeline + ChainSpeculativeSampling instead of reinventing rejection sampling, and the onehot_draft idea elegantly avoids q-probability plumbing for a greedy proposer. The emitted vs accepted semantic trap is well documented and tested. A few things that would make this land cleaner:

Suggested changes

1. sampling.rs exceeds 1k lines (would benefit from a split)

openinfer-kernels/src/ops/sampling.rs goes from 996 → 1234 lines with this PR, past our 1k soft limit. The two new functions + their tests (~336 lines) are self-contained spec-decoding concerns and would read better in a dedicated module like ops/spec_sampling.rs. Happy to help with the split if useful — it's mostly cut/paste + re-export.

2. C + Rust both duplicate the existing gpu_sample_batch renorm pipeline

The gather + softmax + top-k/top-p renorm sequence appears twice at each layer:

  • C: gpu_verify_probs_flashinfer_cuda (flashinfer_sampling.cu:120-162) and gpu_sample_batch_flashinfer_cuda (flashinfer_sampling.cu:164-242) share the same gather-grid construction, gather_cast_logits_f32_kernel launch, OnlineSoftmax, RadixTopKRenormProbMultiCTA, and TopPRenormProb — roughly 30 lines duplicated.
  • Rust: gpu_verify_probs_into (sampling.rs:134-245) and sample_uniform_batch_into (sampling.rs:334-493) share the same capacity checks, row/temperature/top_p validation, top_k clamping (if r.top_k > 0 && r.top_k < vocab), and the four memcpy_htod calls — roughly 35 lines duplicated.

A shared C helper renorm_probs_flashinfer(...) and a Rust helper prepare_sampling_params(...) returning (row_indices, temperature, top_k, top_p, has_top_k_filter, has_top_p_filter) would let both entry points share the pipeline. This also makes future renorm-stage changes (e.g. min_p support) a single edit.

3. onehot_draft hardcoded to 1

sampling.rs:316 passes a literal 1 for onehot_draft, even though the FFI signature already accepts it (ffi/shared.rs:140). The PR description notes real-q support is intended for a future sampled drafter — exposing onehot_draft: bool in the Rust API now would avoid a signature change in slice 2, and true as i32 is at least self-documenting.

Minor notes (optional)

  • gpu_spec_accept_into doesn't check vocab > 0 (sampling.rs:284-294): with vocab == 0 the buffer-size checks all pass (batch * K * 0 = 0), but ChainSpeculativeSampling's behavior on vocab_size=0 is unverified. A one-line ensure!(vocab > 0, ...) would match gpu_verify_probs_into's implicit guarantee via logits.hidden_dim == scratch.vocab.
  • Test coverage is corner-point only: spec_accept_full_acceptance_and_certain_rejection proves the kernel doesn't crash and handles the deterministic extremes, but rejection sampling's core guarantee is distributional correctness. A partial-acceptance test (draft with ~50% target probability, asserting acceptance frequency converges over many seeds) would strengthen confidence — though understandable if deferred given the difficulty of statistical tests on GPU.

The seam design (shape-compatible with accept_greedy's "longest accepted prefix + one model token" contract) is clean. Once the file split and the duplication are addressed, this looks good to merge.

n-WN added 2 commits July 5, 2026 01:38
…imitives (openinfer-project#512)

The two GPU primitives non-greedy speculative decoding needs, verified on
H100 (CUDA 12.9), 3/3 deterministic tests:

gpu_verify_probs_into — the batched sampling pipeline's gather + softmax +
top-k/top-p renorm WITHOUT the terminal sampling kernel: the post-filter
target distribution lands in the caller's buffer, filtered tokens as exact
zeros. Distribution-equivalent to the fused sampling fast path (it filters
at draw time, this filters then draws — the law over tokens is identical).
min_p rows are rejected fail-loud: min_p is a sampling-time mask, not a
renorm, so that target distribution is not representable here yet.

gpu_spec_accept_into — wraps FlashInfer's ChainSpeculativeSampling:
accept draft i with min(1, p_target/q_draft), first rejection resamples
from relu(target - draft) renormalized, full acceptance emits the bonus
token, tail is -1-filled. onehot_draft derives q = delta(x - draft) on
device for a greedy/argmax proposer — the degenerate proposal under which
acceptance is min(1, p_target(draft)) and rejection sampling stays
distribution-exact, so DFlash's greedy drafts need no proposal-probability
plumbing.

One semantic trap pinned in doc + test: the kernel returns two counters
with different meanings — emitted is the accepted-prefix length (the
commit signal); accepted keeps counting hypothetical acceptances past the
first rejection (FlashInfer's acceptance-rate telemetry). Committing on
accepted would emit wrong tokens.

Tests are corner-point deterministic, not statistical: a target that puts
mass 1.0 on each draft forces full acceptance + bonus; a target with zero
mass on the draft forces first-step rejection with a single-token residual,
making the resample exact.
gpu_verify_probs_into / gpu_spec_accept_into were pub only inside the
private sampling module; the ops re-export list is what model crates see.
@n-WN n-WN force-pushed the feat/spec-rejection-sampling branch from d0de3b1 to e151669 Compare July 4, 2026 19:55
n-WN added a commit to n-WN/openinfer that referenced this pull request Jul 4, 2026
…view)

Address @xiaguan's openinfer-project#534 review:
- Move gpu_verify_probs_into + gpu_spec_accept_into (and their tests) into
  ops/spec_sampling.rs; sampling.rs 1336 -> 1060 lines, so the speculative
  code no longer drives the file over the soft size limit.
- Extract shared C helper gather_softmax_renorm_flashinfer (gather + softmax
  + optional top-k/top-p renorm), used by both the verify path and the min_p
  sampling path; extract Rust prepare_sampling_params shared by the batched
  sampler and the verify path. min_p stays caller-side (sampler packs it,
  verify rejects it).
- Expose onehot_draft as a bool parameter on gpu_spec_accept_into instead of
  hardcoding 1, so the slice-2 verify path can pass a real drafter's probs.
- Guard vocab > 0 in gpu_spec_accept_into.
- Add spec_accept_partial_acceptance_matches_the_rate_law: a 512-row
  Bernoulli(0.7) convergence test for the p in (0,1) acceptance law.

Verified on H100: cargo test -p openinfer-kernels --lib (sampling + spec)
passes 11/11.
@n-WN

n-WN commented Jul 4, 2026

Copy link
Copy Markdown
Contributor Author

Thanks for the careful review — all four points addressed, plus the statistical test you suggested. Rebased onto main (dd7019e).

1. File split. Moved gpu_verify_probs_into + gpu_spec_accept_into (and their tests) into a new ops/spec_sampling.rs. sampling.rs drops from 1336 → 1060; the remainder is essentially the pre-existing sampler + upstream's new argmax_bf16_* — the speculative code no longer contributes to the overage.

2. Duplication.

  • C: extracted gather_softmax_renorm_flashinfer(...) (gather → softmax → optional top-k/top-p renorm). gpu_verify_probs_flashinfer_cuda now is that call; gpu_sample_batch_flashinfer_cuda's min_p path calls it with the filters on, the fast path with them off — behaviour-identical, ~30 duplicated lines gone.
  • Rust: extracted prepare_sampling_params(rows, seq_len, vocab) -> SamplingParams (validated row_indices/temperature/top_k/top_p + filter flags), shared by the batched sampler and the verify path. min_p stays caller-side since the two paths disagree on it (sampler packs it, verify rejects it).

3. onehot_draft. Now a bool parameter on gpu_spec_accept_into instead of the hardcoded 1, so the slice-2 verify path can pass a real drafter's draft_probs (onehot_draft=false) without touching the kernel.

4. vocab > 0. Added to the gpu_spec_accept_into guard.

5. Statistical test. Added spec_accept_partial_acceptance_matches_the_rate_law: a onehot proposer whose draft token holds 0.7 of the target mass, run over 512 i.i.d. rows (per-row philox decorrelates them), asserting the empirical acceptance rate lands in [0.63, 0.77] — 3σ around the exact min(1, 0.7) law, and checking accepted rows emit [A, bonus] while rejected rows emit the residual [B, -1]. This is the p∈(0,1) case the corner-point tests didn't cover.

…view)

Address @xiaguan's openinfer-project#534 review:
- Move gpu_verify_probs_into + gpu_spec_accept_into (and their tests) into
  ops/spec_sampling.rs; sampling.rs 1336 -> 1060 lines, so the speculative
  code no longer drives the file over the soft size limit.
- Extract shared C helper gather_softmax_renorm_flashinfer (gather + softmax
  + optional top-k/top-p renorm), used by both the verify path and the min_p
  sampling path; extract Rust prepare_sampling_params shared by the batched
  sampler and the verify path. min_p stays caller-side (sampler packs it,
  verify rejects it).
- Expose onehot_draft as a bool parameter on gpu_spec_accept_into instead of
  hardcoding 1, so the slice-2 verify path can pass a real drafter's probs.
- Guard vocab > 0 in gpu_spec_accept_into.
- Add spec_accept_partial_acceptance_matches_the_rate_law: a 512-row
  Bernoulli(0.7) convergence test for the p in (0,1) acceptance law.

Verified on H100: cargo test -p openinfer-kernels --lib (sampling + spec)
passes 11/11.
@n-WN n-WN force-pushed the feat/spec-rejection-sampling branch from e151669 to 92f9887 Compare July 4, 2026 20:02
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants