Update spec_decode_worker
sroy745 committed Jul 2, 2024
1 parent 85d464f commit b3e1dda
Showing 1 changed file with 5 additions and 21 deletions.
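At a glance, this commit drops the disable_bonus_tokens_in_kv_cache flag and the dedicated rejection_sampler argument: SpecDecodeWorker now receives a single spec_decode_sampler (any SpecDecodeBaseSampler, e.g. RejectionSampler), and bonus-token tracking is switched on for every MultiStepWorker proposer. Below is a minimal sketch of the construction path after this change, assuming proposer_worker and scorer_worker already exist and that the import paths shown are current; only the keyword names that appear in the diff below are taken from the commit.

from vllm.model_executor.layers.rejection_sampler import RejectionSampler
from vllm.spec_decode.spec_decode_worker import SpecDecodeWorker

# Any SpecDecodeBaseSampler can be passed here; RejectionSampler implements the
# modified rejection sampling acceptance method mentioned in the docstring.
spec_decode_sampler = RejectionSampler()

worker = SpecDecodeWorker(proposer_worker,   # e.g. a MultiStepWorker draft worker
                          scorer_worker,     # the target-model worker
                          disable_by_batch_size=None,
                          spec_decode_sampler=spec_decode_sampler)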
26 changes: 5 additions & 21 deletions vllm/spec_decode/spec_decode_worker.py
@@ -60,8 +60,6 @@ def create_spec_worker(*args, **kwargs) -> "SpecDecodeWorker":
        draft_worker_kwargs=draft_worker_kwargs,
        disable_by_batch_size=speculative_config.
        speculative_disable_by_batch_size,
        disable_bonus_tokens_in_kv_cache=speculative_config.
        disable_bonus_tokens_in_kv_cache,
        draft_token_acceptance_method=speculative_config.
        draft_token_acceptance_method,
        typical_acceptance_sampler_posterior_threshold=speculative_config.
@@ -104,7 +102,6 @@ def create_worker(
        scorer_worker: Worker,
        draft_worker_kwargs: Dict[str, Any],
        disable_by_batch_size: Optional[int],
        disable_bonus_tokens_in_kv_cache: bool,
        draft_token_acceptance_method: str,
        typical_acceptance_sampler_posterior_threshold: float,
        typical_acceptance_sampler_posterior_alpha: float,
@@ -114,18 +111,16 @@ def create_worker(
            draft_worker_kwargs.pop("ngram_prompt_lookup_max"))
        ngram_prompt_lookup_min = (
            draft_worker_kwargs.pop("ngram_prompt_lookup_min"))


        disable_bonus_tokens = False
        if ngram_prompt_lookup_max > 0:
            disable_bonus_tokens = False
            proposer_worker = NGramWorker(**draft_worker_kwargs)
            proposer_worker.set_ngram_window_size(ngram_prompt_lookup_min,
                                                  ngram_prompt_lookup_max)
        elif draft_worker_kwargs[
                "model_config"].hf_config.model_type == "mlp_speculator":
            proposer_worker = MLPSpeculatorWorker(**draft_worker_kwargs)
            disable_bonus_tokens = False
        else:
            disable_bonus_tokens = disable_bonus_tokens_in_kv_cache
            draft_parallel_config: ParallelConfig = draft_worker_kwargs[
                'parallel_config']
            draft_tp = draft_parallel_config.tensor_parallel_size
@@ -157,18 +152,12 @@ def create_worker(
        return SpecDecodeWorker(proposer_worker,
                                scorer_worker,
                                disable_by_batch_size=disable_by_batch_size,
                                rejection_sampler=RejectionSampler(
                                    disable_bonus_tokens=disable_bonus_tokens),
                                disable_bonus_tokens_in_kv_cache=\
                                    disable_bonus_tokens_in_kv_cache)
                                spec_decode_sampler=spec_decode_sampler)

    def __init__(
        self,
        proposer_worker: ProposerWorkerBase,
        scorer_worker: WorkerBase,
        rejection_sampler: RejectionSampler,
        disable_bonus_tokens_in_kv_cache: bool,
        spec_decode_sampler: SpecDecodeBaseSampler,
        metrics_collector: Optional[AsyncMetricsCollector] = None,
        disable_by_batch_size: Optional[int] = None,
@@ -184,10 +173,6 @@ def __init__(
                Worker.
            rejection_sampler: A Torch module used to perform modified rejection
                sampling for speculative decoding.
            disable_bonus_tokens_in_kv_cache: A boolean flag to control the use
                of bonus tokens during speculative decoding in models that rely
                on KV cache. If set to True, bonus tokens will be disabled and
                if set to False, bonus tokens will be enabled.
            spec_decode_sampler: A Torch module used to perform acceptance
                sampling of the draft tokens in the verification step of
                speculative decoding. Currently we support two different
@@ -207,13 +192,12 @@ def __init__(
            self.spec_decode_sampler
        ) if metrics_collector is None else metrics_collector

        self.probs_dtype = self.rejection_sampler.probs_dtype
        self.token_id_dtype = self.rejection_sampler.token_id_dtype
        self.probs_dtype = self.spec_decode_sampler.probs_dtype
        self.token_id_dtype = self.spec_decode_sampler.token_id_dtype
        # Tracks the sequence IDs that received a bonus token ID in
        # their last forward pass. Needed only if KV cache is being
        # used for token generation such as in the case of MultiStepWorker.
        if (isinstance(self.proposer_worker, MultiStepWorker)
                and not disable_bonus_tokens_in_kv_cache):
        if (isinstance(self.proposer_worker, MultiStepWorker)):
            self.seq_with_bonus_token_in_last_step = set()
        else:
            self.seq_with_bonus_token_in_last_step = None
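The last hunk above is the behavioral core of the change: the set tracking which sequences received a bonus token in their previous step is now allocated whenever the proposer is a MultiStepWorker, with no flag to opt out. A small standalone sketch of that logic follows, assuming the vllm import path below is current; the wrapper function itself is hypothetical, and only the condition and names inside it come from the diff.

from typing import Optional, Set

from vllm.spec_decode.multi_step_worker import MultiStepWorker


def init_bonus_token_tracking(proposer_worker) -> Optional[Set[int]]:
    # A MultiStepWorker generates draft tokens with its own KV cache, so the
    # SpecDecodeWorker must remember which sequence IDs received a bonus token
    # in their last forward pass; other proposers need no such bookkeeping.
    if isinstance(proposer_worker, MultiStepWorker):
        return set()
    return None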
