TorchScript greedy decode step + CUDA GC tuning #10

Merged: TroyHernandez merged 1 commit into main from jit-decode on Jun 17, 2026

@TroyHernandez (Contributor)

Applies the GC and JIT speedups from the recent chatterbox work to whisper.

JIT greedy decode (R/decode_jit.R)

Each generated token's full decoder forward (per layer: self-attention, cross-attention, FFN; plus the final LayerNorm) now runs as one jit_compile'd TorchScript call. R keeps the eager prefill, the sampling/timestamp logic, and the loop shell.

Same motivation as chatterbox's t3_inference_jit: even optimally-written eager R hits a per-op R->lantern dispatch floor (hundreds of ops/token). Collapsing the per-token forward into one libtorch call removes it without compiled code or linking against torch's private libraries.
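
The split described above (loop shell and token selection in R, the whole per-token forward behind one call) can be sketched in base R. This is only an illustration: `toy_step` and `greedy_decode` are hypothetical names, and `toy_step` is a stand-in for the compiled step, not the code in R/decode_jit.R.

```r
# Toy stand-in for the compiled per-token forward (hypothetical): returns
# logits over a 5-token vocabulary that favour (last token + 1).
toy_step <- function(token, state) {
  logits <- rep(0, 5)
  logits[min(token + 1, 5)] <- 1
  list(logits = logits, state = state)
}

# Loop shell kept in R: one step call per generated token.
greedy_decode <- function(step_fn, start_token, eot_token, max_len = 10L) {
  tokens <- start_token
  state <- NULL  # a real decoder would carry its KV cache here
  for (i in seq_len(max_len)) {
    out <- step_fn(tokens[length(tokens)], state)  # single call per token
    state <- out$state
    next_token <- which.max(out$logits)            # greedy pick stays in R
    tokens <- c(tokens, next_token)
    if (next_token == eot_token) break             # stop at end-of-text
  }
  tokens
}

tokens <- greedy_decode(toy_step, start_token = 1L, eot_token = 5L)
```

In this shape the number of R-level calls per token is constant, however many ops the step function runs internally.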

Whisper's step differs from a Llama step in three ways, all handled: LayerNorm with bias (hand-rolled in float for fp16 parity), biases on every projection, and a cross-attention block whose encoder K/V are cached once and passed in as stacked tensors.
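
For reference, the LayerNorm-with-bias arithmetic can be written out by hand as below. This is a minimal base R sketch of the formula, not the TorchScript used in the PR: base R numerics are doubles, so the fp16-to-float upcast is not modelled, and `eps = 1e-5` is an assumed value (it matches PyTorch's LayerNorm default).

```r
# Hand-rolled LayerNorm over one feature vector, with scale and bias.
layer_norm <- function(x, weight, bias, eps = 1e-5) {
  mu <- mean(x)
  sigma2 <- mean((x - mu)^2)  # biased (population) variance
  (x - mu) / sqrt(sigma2 + eps) * weight + bias
}

x <- c(1, 2, 3, 4)
y <- layer_norm(x, weight = rep(1, 4), bias = rep(0.5, 4))
```

With unit weights, the output has mean equal to the bias and variance just under 1 (the `eps` term keeps it slightly below).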

  • On by default via the new jit arg to transcribe() / whisper_pipeline().
  • Gated to CUDA greedy non-word-timestamp runs. CPU, beam search, and word timestamps stay on the eager decoder, so R CMD check (CPU) is unchanged.
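
The gating rule in the list above amounts to a single predicate. A sketch follows, assuming hypothetical names: `use_jit_decoder` and its `device`, `beam_size`, and `word_timestamps` parameters are illustrative, and only the `jit` argument is named in this PR.

```r
# TRUE only for CUDA + greedy + no word timestamps; everything else
# falls back to the eager decoder.
use_jit_decoder <- function(jit, device, beam_size = 1L, word_timestamps = FALSE) {
  isTRUE(jit) &&
    identical(device, "cuda") &&
    beam_size <= 1L &&          # beam search stays eager
    !isTRUE(word_timestamps)    # word timestamps stay eager
}

use_jit_decoder(TRUE, "cuda")                  # JIT path
use_jit_decoder(TRUE, "cpu")                   # eager path
use_jit_decoder(TRUE, "cuda", beam_size = 5L)  # eager path
```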

GC tuning (whisper_tune_gc())

Opt-in helper that raises torch's CUDA allocator GC floor (torch.cuda_allocator_reserved_rate) to the model's footprint as a fraction of VRAM and lifts torch.threshold_call_gc off its default, so GC stops firing on nearly every allocation during inference. Call it before load_whisper_model(). It is a no-op off CUDA and only sets options that aren't already set.
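
The "fraction of VRAM" and "only if unset" behaviour can be sketched in base R as follows. This is not the body of whisper_tune_gc(): `tune_gc_sketch`, its parameters, and the example byte counts and threshold are hypothetical, and only the two option names come from this PR.

```r
# Sets the two torch options named above, skipping any the user already set.
tune_gc_sketch <- function(model_bytes, vram_bytes, threshold, cuda = TRUE) {
  if (!cuda) return(invisible(character(0)))  # no-op off CUDA
  wanted <- list(
    # model footprint as a fraction of total VRAM, capped at 1
    torch.cuda_allocator_reserved_rate = min(model_bytes / vram_bytes, 1),
    # caller-supplied value; no particular default is assumed here
    torch.threshold_call_gc = threshold
  )
  unset <- vapply(names(wanted), function(nm) is.null(getOption(nm)), logical(1))
  if (any(unset)) options(wanted[unset])      # options() accepts a named list
  invisible(names(wanted)[unset])             # names of the options changed
}

# Illustrative numbers only: a 3 GB model on a 12 GB card.
changed <- tune_gc_sketch(model_bytes = 3e9, vram_bytes = 12e9, threshold = 8000)
```

Returning the names of the options that were changed makes it easy to confirm that a user's own settings were left alone.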

Verification

  • Token-for-token identical to eager greedy, with and without timestamp rules (logprob diff 0.006, pure fp16 rounding). New test_decode_jit.R (CUDA + at_home gated).
  • ~2.5x faster end-to-end on the medium model for the jfk clip; more on longer transcripts, where per-token dispatch dominates. Transcribed text is identical.
  • Full suite: 199 results, all OK.
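
An equivalence check of the kind listed above could look like the sketch below. It is not the contents of test_decode_jit.R: `check_equivalent`, the toy token and logprob values, and the 0.01 tolerance are all invented for illustration.

```r
# Require identical tokens; report the largest logprob gap and compare it
# against a tolerance.
check_equivalent <- function(eager, jit, tol) {
  same_tokens <- identical(eager$tokens, jit$tokens)
  max_diff <- if (same_tokens) max(abs(eager$logprobs - jit$logprobs)) else NA_real_
  list(same_tokens = same_tokens,
       max_logprob_diff = max_diff,
       ok = same_tokens && max_diff <= tol)
}

# Toy values, not measurements from this PR.
eager <- list(tokens = c(7L, 3L, 9L), logprobs = c(-0.10, -0.52, -1.30))
jit   <- list(tokens = c(7L, 3L, 9L), logprobs = c(-0.104, -0.518, -1.305))
res <- check_equivalent(eager, jit, tol = 0.01)
```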

Version 0.3.0 -> 0.3.0.1.

Greedy decoding on CUDA now runs each token's full decoder forward (all
layers' self-attention, cross-attention, FFN, and the final LayerNorm) as
one jit_compile'd TorchScript call instead of dozens of dispatched
R->torch calls per token. Same motivation as chatterbox's
t3_inference_jit: even lean eager R hits a per-op dispatch floor, and
collapsing the per-token forward into one libtorch call removes it
without compiled code. Token-for-token equivalent to the eager path
(test_decode_jit.R); ~2.5x faster end-to-end on medium for a short clip.
On by default via the new jit arg to transcribe()/whisper_pipeline();
gated to CUDA greedy non-word-timestamp runs, so CPU/beam/word-timestamp
paths and R CMD check are unchanged.

Also adds whisper_tune_gc(): opt-in helper raising torch's CUDA allocator
GC floor to the model's footprint as a fraction of VRAM, so GC stops
firing on nearly every allocation during inference. No-op off CUDA,
only-if-unset. Bumps to 0.3.0.1.

@TroyHernandez merged commit 9263a32 into main on Jun 17, 2026
2 checks passed
@TroyHernandez deleted the jit-decode branch on June 17, 2026 13:32