From b3e1ddabea20d8bcf8c3749d6968eec5ebfacb4e Mon Sep 17 00:00:00 2001
From: Sourashis Roy
Date: Tue, 2 Jul 2024 06:02:08 +0000
Subject: [PATCH] Update spec_decode_worker

---
 vllm/spec_decode/spec_decode_worker.py | 26 +++++---------------------
 1 file changed, 5 insertions(+), 21 deletions(-)

diff --git a/vllm/spec_decode/spec_decode_worker.py b/vllm/spec_decode/spec_decode_worker.py
index f9864da804b..135c90ae705 100644
--- a/vllm/spec_decode/spec_decode_worker.py
+++ b/vllm/spec_decode/spec_decode_worker.py
@@ -60,8 +60,6 @@ def create_spec_worker(*args, **kwargs) -> "SpecDecodeWorker":
         draft_worker_kwargs=draft_worker_kwargs,
         disable_by_batch_size=speculative_config.
         speculative_disable_by_batch_size,
-        disable_bonus_tokens_in_kv_cache=speculative_config.
-        disable_bonus_tokens_in_kv_cache,
         draft_token_acceptance_method=speculative_config.
         draft_token_acceptance_method,
         typical_acceptance_sampler_posterior_threshold=speculative_config.
@@ -104,7 +102,6 @@ def create_worker(
         scorer_worker: Worker,
         draft_worker_kwargs: Dict[str, Any],
         disable_by_batch_size: Optional[int],
-        disable_bonus_tokens_in_kv_cache: bool,
         draft_token_acceptance_method: str,
         typical_acceptance_sampler_posterior_threshold: float,
         typical_acceptance_sampler_posterior_alpha: float,
@@ -114,18 +111,16 @@ def create_worker(
             draft_worker_kwargs.pop("ngram_prompt_lookup_max"))
         ngram_prompt_lookup_min = (
             draft_worker_kwargs.pop("ngram_prompt_lookup_min"))
-
+
+        disable_bonus_tokens = False
         if ngram_prompt_lookup_max > 0:
-            disable_bonus_tokens = False
             proposer_worker = NGramWorker(**draft_worker_kwargs)
             proposer_worker.set_ngram_window_size(ngram_prompt_lookup_min,
                                                   ngram_prompt_lookup_max)
         elif draft_worker_kwargs[
                 "model_config"].hf_config.model_type == "mlp_speculator":
             proposer_worker = MLPSpeculatorWorker(**draft_worker_kwargs)
-            disable_bonus_tokens = False
         else:
-            disable_bonus_tokens = disable_bonus_tokens_in_kv_cache
             draft_parallel_config: ParallelConfig = draft_worker_kwargs[
                 'parallel_config']
             draft_tp = draft_parallel_config.tensor_parallel_size
@@ -157,18 +152,12 @@ def create_worker(
         return SpecDecodeWorker(proposer_worker,
                                 scorer_worker,
                                 disable_by_batch_size=disable_by_batch_size,
-                                rejection_sampler=RejectionSampler(
-                                    disable_bonus_tokens=disable_bonus_tokens),
-                                disable_bonus_tokens_in_kv_cache=\
-                                    disable_bonus_tokens_in_kv_cache)
                                 spec_decode_sampler=spec_decode_sampler)
 
     def __init__(
         self,
         proposer_worker: ProposerWorkerBase,
         scorer_worker: WorkerBase,
-        rejection_sampler: RejectionSampler,
-        disable_bonus_tokens_in_kv_cache: bool,
         spec_decode_sampler: SpecDecodeBaseSampler,
         metrics_collector: Optional[AsyncMetricsCollector] = None,
         disable_by_batch_size: Optional[int] = None,
@@ -184,10 +173,6 @@ def __init__(
                 Worker.
             rejection_sampler: A Torch module used to perform modified rejection
                 sampling for speculative decoding.
-            disable_bonus_tokens_in_kv_cache: A boolean flag to control the use
-                of bonus tokens during speculative decoding in models that rely on KV
-                cache. If set to True, bonus tokens will be disabled and if set to False,
-                bonus tokens will be enabled.
             spec_decode_sampler: A Torch module used to perform acceptance
                 sampling of the draft tokens in the verification step of
                 speculative decoding. Currently we support two different
@@ -207,13 +192,12 @@ def __init__(
             self.spec_decode_sampler
         ) if metrics_collector is None else metrics_collector
 
-        self.probs_dtype = self.rejection_sampler.probs_dtype
-        self.token_id_dtype = self.rejection_sampler.token_id_dtype
+        self.probs_dtype = self.spec_decode_sampler.probs_dtype
+        self.token_id_dtype = self.spec_decode_sampler.token_id_dtype
         # Tracks the sequence IDs that received a bonus token ID in
         # their last forward pass. Needed only if KV cache is being
         # used for token generation such as in the case of MultiStepWorker.
-        if (isinstance(self.proposer_worker, MultiStepWorker)
-                and not disable_bonus_tokens_in_kv_cache):
+        if (isinstance(self.proposer_worker, MultiStepWorker)):
             self.seq_with_bonus_token_in_last_step = set()
         else:
             self.seq_with_bonus_token_in_last_step = None
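
Note for reviewers: below is a minimal, self-contained sketch of the bookkeeping that
`seq_with_bonus_token_in_last_step` enables once the `disable_bonus_tokens_in_kv_cache`
flag is gone, assuming the proposer is a KV-cache-based MultiStepWorker. The helper name
`update_bonus_token_tracking` and the `accepted_token_ids` shape used here are illustrative
only, not part of this patch or the vLLM API: after each verification step, sequences whose
full draft was accepted (and therefore received a bonus token) are added to the set, and
stale entries are dropped.

from typing import Dict, List, Optional, Set

# Illustrative sketch only: maintains a set like
# `seq_with_bonus_token_in_last_step` after each verification step.
# Helper names below are hypothetical and not vLLM API.


def update_bonus_token_tracking(
        seq_with_bonus_token_in_last_step: Optional[Set[int]],
        accepted_token_ids: Dict[int, List[int]],
        num_speculative_tokens: int) -> None:
    """Record which sequences kept their bonus token in this step.

    A sequence receives the bonus token only when all of its
    `num_speculative_tokens` draft tokens were accepted, i.e. the
    accepted list carries one extra (bonus) token at the end.
    """
    if seq_with_bonus_token_in_last_step is None:
        # Tracking is disabled (proposer does not rely on a KV cache).
        return
    for seq_id, accepted in accepted_token_ids.items():
        if len(accepted) == num_speculative_tokens + 1:
            # All draft tokens accepted -> bonus token was emitted.
            seq_with_bonus_token_in_last_step.add(seq_id)
        else:
            # Bonus token not emitted; drop any stale entry.
            seq_with_bonus_token_in_last_step.discard(seq_id)


if __name__ == "__main__":
    tracked: Set[int] = set()
    # Sequence 0 accepted all 3 drafts plus the bonus token (4 ids);
    # sequence 1 was cut off after 2 accepted drafts.
    update_bonus_token_tracking(tracked,
                                {0: [11, 12, 13, 14], 1: [21, 22]},
                                num_speculative_tokens=3)
    print(tracked)  # {0}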