diff --git a/tests/spec_decode/e2e/conftest.py b/tests/spec_decode/e2e/conftest.py
index 60dfe33f2918..740a227efcad 100644
--- a/tests/spec_decode/e2e/conftest.py
+++ b/tests/spec_decode/e2e/conftest.py
@@ -282,6 +282,6 @@ def run_greedy_equality_correctness_test(baseline_llm_generator,
         if print_tokens:
             print(f'{i=} {baseline_tokens=}')
             print(f'{i=} {spec_tokens=}')
-        print(f'{i=} {baseline_token_ids=}')
-        print(f'{i=} {spec_token_ids=}')
+        #print(f'{i=} {baseline_token_ids=}')
+        #print(f'{i=} {spec_token_ids=}')
         assert baseline_token_ids == spec_token_ids
diff --git a/tests/spec_decode/e2e/test_multistep_correctness.py b/tests/spec_decode/e2e/test_multistep_correctness.py
index 94cc36f22875..5f870d6b4e23 100644
--- a/tests/spec_decode/e2e/test_multistep_correctness.py
+++ b/tests/spec_decode/e2e/test_multistep_correctness.py
@@ -643,12 +643,12 @@ def test_many_k(baseline_llm_generator, test_llm_generator, batch_size: int,
         # Try a range of common k.
         for k in [1, 2, 3]
     ])
-@pytest.mark.parametrize("batch_size", [1, 32])
+@pytest.mark.parametrize("batch_size", [1, 8, 64])
 @pytest.mark.parametrize(
     "output_len",
     [
         # Use smaller output len for fast test.
-        32,
+        256,
     ])
 @pytest.mark.parametrize("seed", [1])
 def test_typical_acceptance_sampling(baseline_llm_generator,
diff --git a/vllm/config.py b/vllm/config.py
index fd2b52d9d7da..285248048c77 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -833,7 +833,6 @@ def maybe_create_spec_config(
         speculative_disable_by_batch_size: Optional[int],
         ngram_prompt_lookup_max: Optional[int],
         ngram_prompt_lookup_min: Optional[int],
-        disable_bonus_tokens_in_kv_cache: bool,
         draft_token_acceptance_method: str,
         typical_acceptance_sampler_posterior_threshold: Optional[float],
         typical_acceptance_sampler_posterior_alpha: Optional[float],
@@ -873,10 +872,6 @@ def maybe_create_spec_config(
                 window, if provided.
             ngram_prompt_lookup_min (Optional[int]): Min size of ngram token
                 window, if provided.
-            disable_bonus_tokens_in_kv_cache (bool): A boolean flag to control
-                the use of bonus tokens during speculative decoding in models that
-                rely on KV cache. If set to True, bonus tokens will be disabled
-                and if set to False, bonus tokens will be enabled.
             draft_token_acceptance_method (str): The method to use for
                 accepting draft tokens. This can take two possible
                 values 'rejection_sampler' and 'typical_acceptance_sampler'
@@ -1015,7 +1010,6 @@ def maybe_create_spec_config(
             speculative_disable_by_batch_size,
             ngram_prompt_lookup_max,
             ngram_prompt_lookup_min,
-            disable_bonus_tokens_in_kv_cache,
             draft_token_acceptance_method=draft_token_acceptance_method,
             typical_acceptance_sampler_posterior_threshold=\
                 typical_acceptance_sampler_posterior_threshold,
@@ -1102,7 +1096,6 @@ def __init__(
         speculative_disable_by_batch_size: Optional[int],
         ngram_prompt_lookup_max: Optional[int],
         ngram_prompt_lookup_min: Optional[int],
-        disable_bonus_tokens_in_kv_cache: bool,
         draft_token_acceptance_method: str,
         typical_acceptance_sampler_posterior_threshold: float,
         typical_acceptance_sampler_posterior_alpha: float,
@@ -1119,10 +1112,6 @@ def __init__(
                 enqueue requests is larger than this value.
             ngram_prompt_lookup_max: Max size of ngram token window.
             ngram_prompt_lookup_min: Min size of ngram token window.
-            disable_bonus_tokens_in_kv_cache: A boolean flag to control the
-                use of bonus tokens during speculative decoding in models that
-                rely on KV cache. If set to True, bonus tokens will be disabled
-                and if set to False, bonus tokens will be enabled.
             draft_token_acceptance_method (str): The method to use for
                 accepting draft tokens. This can take two possible
                 values 'rejection_sampler' and 'typical_acceptance_sampler'
@@ -1144,8 +1133,6 @@ def __init__(
             speculative_disable_by_batch_size
         self.ngram_prompt_lookup_max = ngram_prompt_lookup_max or 0
         self.ngram_prompt_lookup_min = ngram_prompt_lookup_min or 0
-        self.disable_bonus_tokens_in_kv_cache = \
-            disable_bonus_tokens_in_kv_cache
         self.draft_token_acceptance_method = draft_token_acceptance_method
         self.typical_acceptance_sampler_posterior_threshold = \
             typical_acceptance_sampler_posterior_threshold
diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index c09054684b75..d4d8306ca078 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -100,7 +100,6 @@ class EngineArgs:
     speculative_disable_by_batch_size: Optional[int] = None
     ngram_prompt_lookup_max: Optional[int] = None
     ngram_prompt_lookup_min: Optional[int] = None
-    disable_bonus_tokens_in_kv_cache: bool = True
     spec_decoding_acceptance_method: str = 'rejection_sampler'
     typical_acceptance_sampler_posterior_threshold: Optional[float] = None
     typical_acceptance_sampler_posterior_alpha: Optional[float] = None
@@ -579,15 +578,6 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
             default=EngineArgs.ngram_prompt_lookup_min,
             help='Min size of window for ngram prompt lookup in speculative '
             'decoding.')
-
-        parser.add_argument(
-            '--disable-bonus-tokens-in-kv-cache',
-            type=int,
-            default=EngineArgs.disable_bonus_tokens_in_kv_cache,
-            help='A boolean flag to control the use of bonus tokens during '
-            'speculative decoding in models that rely on KV cache. If set '
-            'to True, bonus tokens will be disabled and if set to False, '
-            'bonus tokens will be enabled.')
 
         parser.add_argument(
             '--spec-decoding-acceptance-method',
@@ -781,7 +771,6 @@ def create_engine_config(self, ) -> EngineConfig:
             use_v2_block_manager=self.use_v2_block_manager,
             ngram_prompt_lookup_max=self.ngram_prompt_lookup_max,
             ngram_prompt_lookup_min=self.ngram_prompt_lookup_min,
-            disable_bonus_tokens_in_kv_cache=self.disable_bonus_tokens_in_kv_cache,
             draft_token_acceptance_method=\
                 self.spec_decoding_acceptance_method,
             typical_acceptance_sampler_posterior_threshold=self.
diff --git a/vllm/spec_decode/multi_step_worker.py b/vllm/spec_decode/multi_step_worker.py
index 298cdc86bcf4..4e95edf8a646 100644
--- a/vllm/spec_decode/multi_step_worker.py
+++ b/vllm/spec_decode/multi_step_worker.py
@@ -62,13 +62,6 @@ def sampler_output(
         """
         self._raise_if_unsupported(execute_model_req)
 
-        # Shallow copy input data so modifications (such as appending tokens)
-        # do not cause side-effects.
-        copied_seq_group_metadata_list = self._shallow_copy_inputs(
-            execute_model_req.seq_group_metadata_list)
-        copied_execute_model_req = execute_model_req.clone(
-            copied_seq_group_metadata_list)
-
         # Assert enough KV space for sample_len tokens per sequence.
         self._assert_enough_kv_space(execute_model_req.seq_group_metadata_list,
                                      sample_len)
@@ -79,11 +72,11 @@ def sampler_output(
         # response to retain only the original sequences' responses.
         expanded_request, indices_of_seq_with_bonus_tokens =\
             self._expand_execute_model_request(
-                copied_execute_model_req, seq_ids_with_bonus_token_in_last_step)
+                execute_model_req, seq_ids_with_bonus_token_in_last_step)
         # Run model sample_len times.
         model_outputs: List[SamplerOutput] = []
         if isinstance(self.model_runner, TP1DraftModelRunner):
-            copied_execute_model_req.num_steps = sample_len
+            execute_model_req.num_steps = sample_len
             model_outputs = self.execute_model(
                 execute_model_req=expanded_request)
         else:
@@ -95,10 +88,14 @@ def sampler_output(
                         ), "composing multistep workers not supported"
                 model_output = model_output[0]
 
-                self._append_new_tokens(model_output,
-                                        copied_seq_group_metadata_list)
+                self._append_new_tokens(
+                    model_output,
+                    expanded_request.seq_group_metadata_list)
                 model_outputs.append(model_output)
-        return model_outputs, True
+
+        filtered_model_outputs = self._filter_model_output(
+            model_outputs, indices_of_seq_with_bonus_tokens)
+        return filtered_model_outputs, True
 
     @staticmethod
     def _expand_execute_model_request(
@@ -134,9 +131,7 @@ def _expand_execute_model_request(
             #Create new sequences without the last bonus token. These new
             # sequence have the same sequence id as the original sequence.
             # We create a new sequence group and add them there.
-            updated_seq_group_without_bonus_token = \
-                MultiStepWorker._shallow_copy_sequence_group_metadata(
-                    seq_group)
+            updated_seq_group_without_bonus_token = copy.copy(seq_group)
             seq_group_without_bonus_token_data = {
                 seq_id: SequenceData(
                     prompt_token_ids=seq_group.seq_data[seq_id].prompt_token_ids,
@@ -153,7 +148,8 @@ def _expand_execute_model_request(
             updated_seq_group_without_bonus_token.seq_data = seq_group_without_bonus_token_data
             updated_seq_group_metadata_list.append(updated_seq_group_without_bonus_token)
             # Add the original sequence group.
-            updated_seq_group_metadata_list.append(seq_group)
+            updated_seq_group_metadata_list.append(
+                MultiStepWorker._shallow_copy_input(seq_group))
             # Record the index of the original sequence group.
             indices_of_original_sequence_groups.append(len(updated_seq_group_metadata_list) - 1)
 
@@ -162,7 +158,7 @@ def _filter_model_output(
-            expanded_batch_output: SamplerOutput,
+            expanded_batch_outputs: List[SamplerOutput],
             output_indices_to_retain: List[int]
     ) -> List[SamplerOutput]:
         """
@@ -172,7 +168,7 @@ def _filter_model_output(
         the outputs of only those sequences indicated by the provided indices.
 
         Args:
-            expanded_batch_output (SamplerOutput): The expanded output batch
+            expanded_batch_output (List[SamplerOutput]): The expanded output batch
                 from the model.
             output_indices_to_retain (List[int]): Indices of the model outputs
                 to retain.
@@ -196,7 +192,8 @@ def _filter_model_output(
                     expanded_batch_output.sampled_token_ids[output_indices_to_retain]
                     if expanded_batch_output.sampled_token_ids is not None else None
                 )
-            )
+            )
+            for expanded_batch_output in expanded_batch_outputs
         ]
 
     def get_spec_proposals(
@@ -235,9 +232,8 @@ def _append_new_tokens(
             seq.update_num_computed_tokens(1)
 
     @staticmethod
-    def _shallow_copy_inputs(
-            seq_group_metadata_list: List[SequenceGroupMetadata]
-    ) -> List[SequenceGroupMetadata]:
+    def _shallow_copy_input(seq_group_metadata: SequenceGroupMetadata
+                            ) -> SequenceGroupMetadata:
         """Copy input data structures to remove side-effects when input data
         structures are shared with other modules.
 
@@ -248,23 +244,19 @@ def _shallow_copy_inputs(
 
         # Shallow-copy the list of SequenceGroupMetadata. This allows us to
         # append tokens and change is_prompt without external side-effects.
-        new_seq_group_metadata_list: List[SequenceGroupMetadata] = []
-
-        for old_seq_group_metadata in seq_group_metadata_list:
-            # We must shallow-copy seq_group_metadata as is_prompt could change.
-            seq_group_metadata = copy.copy(old_seq_group_metadata)
-            new_seq_group_metadata_list.append(seq_group_metadata)
+        # We must shallow-copy seq_group_metadata as is_prompt could change.
+        new_seq_group_metadata = copy.copy(seq_group_metadata)
 
-            # We must shallow-copy seq_data as we will append token ids
-            new_seq_data: Dict[int, SequenceData] = {}
-            for seq_id, old_seq_data in seq_group_metadata.seq_data.items():
-                new_seq_data[seq_id] = copy.copy(old_seq_data)
-                new_seq_data[
-                    seq_id].output_token_ids = old_seq_data.output_token_ids[:]
+        # We must shallow-copy seq_data as we will append token ids
+        new_seq_data: Dict[int, SequenceData] = {}
+        for seq_id, old_seq_data in seq_group_metadata.seq_data.items():
+            new_seq_data[seq_id] = copy.copy(old_seq_data)
+            new_seq_data[
+                seq_id].output_token_ids = old_seq_data.output_token_ids[:]
 
-            seq_group_metadata.seq_data = new_seq_data
+        new_seq_group_metadata.seq_data = new_seq_data
 
-        return new_seq_group_metadata_list
+        return new_seq_group_metadata
 
     @staticmethod
     def _shallow_copy_sequence_group_metadata(
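Note on the new expand/filter flow in MultiStepWorker.sampler_output: sequences that kept a bonus token from the previous step are duplicated in the request without that token, the draft model runs on the expanded batch, and _filter_model_output keeps only the rows at the recorded original indices. The sketch below illustrates just that index bookkeeping; SeqGroup and the plain token-id lists are hypothetical stand-ins rather than vLLM classes, so treat it as an illustration of the idea, not the worker's actual code.

from dataclasses import dataclass, field
from typing import List, Set


@dataclass
class SeqGroup:
    """Simplified stand-in for SequenceGroupMetadata (hypothetical)."""
    seq_id: int
    token_ids: List[int] = field(default_factory=list)


def expand_batch(batch: List[SeqGroup],
                 seq_ids_with_bonus_token: Set[int]):
    """Mirror the expansion step: for every sequence that kept a bonus token
    from the last step, also enqueue a copy without that token, and remember
    where the original sequences land in the expanded batch."""
    expanded: List[SeqGroup] = []
    original_indices: List[int] = []
    for group in batch:
        if group.seq_id in seq_ids_with_bonus_token:
            # Copy without the last (bonus) token, same seq_id.
            expanded.append(SeqGroup(group.seq_id, group.token_ids[:-1]))
        # The original sequence is always appended; record its index so the
        # extra outputs can be dropped again after the draft model runs.
        expanded.append(group)
        original_indices.append(len(expanded) - 1)
    return expanded, original_indices


def filter_step_outputs(step_outputs: List[List[int]],
                        original_indices: List[int]) -> List[List[int]]:
    """Mirror the filtering step: keep, for every decode step, only the rows
    that correspond to the original (non-duplicated) sequences."""
    return [[row[i] for i in original_indices] for row in step_outputs]


if __name__ == "__main__":
    batch = [SeqGroup(0, [7, 8, 9]), SeqGroup(1, [4, 5])]
    expanded, idx = expand_batch(batch, seq_ids_with_bonus_token={0})
    assert [g.token_ids for g in expanded] == [[7, 8], [7, 8, 9], [4, 5]]
    assert idx == [1, 2]
    # Fake per-step sampler output: one sampled token id per expanded row.
    fake_outputs = [[100, 101, 102], [200, 201, 202]]
    assert filter_step_outputs(fake_outputs, idx) == [[101, 102], [201, 202]]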
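Similarly, a rough sketch of what the per-group _shallow_copy_input helper guards against, again with hypothetical stand-ins (Meta/SeqData) rather than vLLM's SequenceGroupMetadata/SequenceData: the metadata object and each per-sequence data object are shallow-copied and the mutable output_token_ids list is sliced, so tokens appended during the draft run never leak back into the caller's objects. Because this copy is now taken inside _expand_execute_model_request for the original sequence group only, the old list-level _shallow_copy_inputs helper and the upfront execute_model_req.clone(...) in sampler_output could be dropped.

import copy
from dataclasses import dataclass, field
from typing import Dict, List


@dataclass
class SeqData:
    """Hypothetical stand-in for SequenceData."""
    prompt_token_ids: List[int]
    output_token_ids: List[int] = field(default_factory=list)


@dataclass
class Meta:
    """Hypothetical stand-in for SequenceGroupMetadata."""
    is_prompt: bool
    seq_data: Dict[int, SeqData]


def shallow_copy_input(meta: Meta) -> Meta:
    """Same idea as the per-group shallow copy: copy the metadata object,
    rebuild seq_data with shallow-copied entries, and slice output_token_ids
    so appends do not propagate back to the original."""
    new_meta = copy.copy(meta)
    new_seq_data: Dict[int, SeqData] = {}
    for seq_id, old in meta.seq_data.items():
        new_seq_data[seq_id] = copy.copy(old)
        new_seq_data[seq_id].output_token_ids = old.output_token_ids[:]
    new_meta.seq_data = new_seq_data
    return new_meta


if __name__ == "__main__":
    original = Meta(is_prompt=False, seq_data={0: SeqData([1, 2], [3])})
    draft_view = shallow_copy_input(original)
    draft_view.seq_data[0].output_token_ids.append(42)  # simulate a draft token
    assert original.seq_data[0].output_token_ids == [3]  # original untouched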