Description
Feature request
Add an `async_stopping_criteria` flag to `GenerationConfig` that performs stopping criteria checks asynchronously on a separate CUDA stream. This reduces GPU-CPU synchronization overhead during autoregressive text generation by allowing the model to continue generating tokens while stopping criteria (EOS detection, `max_length`, custom criteria) are evaluated in parallel.
Key implementation details:
- Uses a separate CUDA stream for stopping criteria evaluation
- Employs pinned (page-locked) CPU memory for efficient GPU-CPU communication without explicit synchronization
- Implements batched polling to minimize overhead (only polls for async results every N tokens)
- Gracefully falls back to synchronous behavior on CPU
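The batched-polling idea in the list above can be illustrated without a GPU. The following is a minimal sketch, not the PR's implementation: a single worker thread stands in for the separate CUDA stream, and `toy_model`, `EOS`, `generate_sync`, `generate_async`, and `poll_interval` are hypothetical names introduced only for this example. It assumes that tokens produced after the true stopping point are discarded once the stop is detected, which is one way to keep async and sync outputs identical.

```python
from concurrent.futures import ThreadPoolExecutor

EOS = 0  # hypothetical end-of-sequence token id


def toy_model(step):
    """Stand-in for a forward pass: emits EOS at step 7, a nonzero id otherwise."""
    return EOS if step == 7 else step + 1


def generate_sync(max_new_tokens):
    """Baseline: check the stopping condition after every token."""
    tokens = []
    for step in range(max_new_tokens):
        tokens.append(toy_model(step))
        if tokens[-1] == EOS:
            break
    return tokens


def generate_async(max_new_tokens, poll_interval=4):
    """Hand checks to a worker and poll its result only every poll_interval tokens."""
    tokens = []
    stop_at = []  # the worker appends the index of the first EOS it sees

    def check(index, token):
        if token == EOS and not stop_at:
            stop_at.append(index)

    # One worker keeps checks in submission order, like work queued on one stream.
    with ThreadPoolExecutor(max_workers=1) as pool:
        for step in range(max_new_tokens):
            token = toy_model(step)
            tokens.append(token)
            pool.submit(check, step, token)
            if (step + 1) % poll_interval == 0 and stop_at:
                break
    # Leaving the with-block waits for all pending checks to finish.
    if stop_at:
        tokens = tokens[: stop_at[0] + 1]  # drop tokens generated past the stop
    return tokens
```

In this sketch the async loop may run a few steps past the EOS token before a poll notices the stop, so the cost of fewer checks is some wasted forward passes. The source does not say how the PR handles overshoot, so the truncation step here is an assumption.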
Motivation
GPU-CPU synchronization during stopping criteria checks creates a significant bottleneck in autoregressive text generation. Each iteration of the generation loop currently requires:
- Running the forward pass on GPU
- Synchronizing to check stopping criteria (e.g., `unfinished_sequences.max() == 0`)
- Only then proceeding to the next token
This synchronization blocks the GPU while the CPU evaluates stopping conditions. By moving these checks to a separate CUDA stream, we can overlap the stopping criteria evaluation with subsequent forward passes, significantly improving throughput.
Benchmark Results (using utf8-lm-tiny model, 200 new tokens):
| Mode | Tokens/sec | Speedup |
|---|---|---|
| Sync (baseline) | 80.92 | 1.00x |
| Async | 137.06 | 1.69x |
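As a quick sanity check on the table, the speedup and the implied wall-clock time for the 200-token run can be recomputed from the reported throughput. This is plain arithmetic on the figures above, assuming throughput was roughly constant over the run. The variable names are illustrative only.

```python
# Figures copied from the benchmark table above.
NEW_TOKENS = 200
SYNC_TOKENS_PER_SEC = 80.92
ASYNC_TOKENS_PER_SEC = 137.06

# Speedup is the ratio of the two throughputs.
speedup = ASYNC_TOKENS_PER_SEC / SYNC_TOKENS_PER_SEC

# Implied generation time for the full run in each mode.
sync_seconds = NEW_TOKENS / SYNC_TOKENS_PER_SEC
async_seconds = NEW_TOKENS / ASYNC_TOKENS_PER_SEC

print(f"speedup: {speedup:.2f}x")
print(f"sync:  {sync_seconds:.2f} s for {NEW_TOKENS} tokens")
print(f"async: {async_seconds:.2f} s for {NEW_TOKENS} tokens")
```

The ratio rounds to the 1.69x shown in the table, and the implied saving is about one second over the 200-token run.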
Your contribution
I have implemented this feature and submitted a PR: #43085
The implementation includes:
- New `async_stopping_criteria` parameter in `GenerationConfig`
- `AsyncStoppingCriteriaList` class that wraps `StoppingCriteriaList` for async execution
- Integration with the generation loop in `utils.py`
- Comprehensive test coverage (10 tests) for:
  - Async/sync output equivalence
  - Various stopping criteria (EOS, `max_length`, custom)
  - Different batch sizes
  - CPU fallback behavior