enable bonus tokens always
sroy745 committed Jul 2, 2024
1 parent f7f3fd7 commit bcadab2
Showing 5 changed files with 32 additions and 64 deletions.
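In speculative decoding, the target model scores the k proposed draft tokens plus one extra position; when every draft token is accepted, that extra sample is the "bonus" token. This commit removes the disable_bonus_tokens_in_kv_cache flag that previously let users turn bonus tokens off for KV-cache-based draft models, and instead handles them by expanding the draft worker's batch (see the multi_step_worker.py diff below). A minimal sketch of the per-step token count, assuming the usual rejection-sampling scheme (the helper below is illustrative, not vLLM code):

    def tokens_emitted_per_step(num_accepted_draft_tokens: int, k: int) -> int:
        # Each step emits the accepted draft tokens plus one token sampled by
        # the target model: either a recovered token at the first rejection or,
        # when all k drafts are accepted, the bonus token this commit keeps enabled.
        assert 0 <= num_accepted_draft_tokens <= k
        return num_accepted_draft_tokens + 1

With k = 5 and all drafts accepted, a step yields 6 tokens rather than being capped at 5 as it was when bonus tokens were disabled.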
4 changes: 2 additions & 2 deletions tests/spec_decode/e2e/conftest.py
@@ -282,6 +282,6 @@ def run_greedy_equality_correctness_test(baseline_llm_generator,
if print_tokens:
print(f'{i=} {baseline_tokens=}')
print(f'{i=} {spec_tokens=}')
print(f'{i=} {baseline_token_ids=}')
print(f'{i=} {spec_token_ids=}')
#print(f'{i=} {baseline_token_ids=}')
#print(f'{i=} {spec_token_ids=}')
assert baseline_token_ids == spec_token_ids
4 changes: 2 additions & 2 deletions tests/spec_decode/e2e/test_multistep_correctness.py
@@ -643,12 +643,12 @@ def test_many_k(baseline_llm_generator, test_llm_generator, batch_size: int,
# Try a range of common k.
for k in [1, 2, 3]
])
@pytest.mark.parametrize("batch_size", [1, 32])
@pytest.mark.parametrize("batch_size", [1, 8, 64])
@pytest.mark.parametrize(
"output_len",
[
# Use smaller output len for fast test.
32,
256,
])
@pytest.mark.parametrize("seed", [1])
def test_typical_acceptance_sampling(baseline_llm_generator,
13 changes: 0 additions & 13 deletions vllm/config.py
@@ -833,7 +833,6 @@ def maybe_create_spec_config(
speculative_disable_by_batch_size: Optional[int],
ngram_prompt_lookup_max: Optional[int],
ngram_prompt_lookup_min: Optional[int],
disable_bonus_tokens_in_kv_cache: bool,
draft_token_acceptance_method: str,
typical_acceptance_sampler_posterior_threshold: Optional[float],
typical_acceptance_sampler_posterior_alpha: Optional[float],
@@ -873,10 +872,6 @@ def maybe_create_spec_config(
window, if provided.
ngram_prompt_lookup_min (Optional[int]): Min size of ngram token
window, if provided.
disable_bonus_tokens_in_kv_cache (bool): A boolean flag to control
the use of bonus tokens during speculative decoding in models that
rely on KV cache. If set to True, bonus tokens will be disabled
and if set to False, bonus tokens will be enabled.
draft_token_acceptance_method (str): The method to use for
accepting draft tokens. This can take two possible
values 'rejection_sampler' and 'typical_acceptance_sampler'
@@ -1015,7 +1010,6 @@ def maybe_create_spec_config(
speculative_disable_by_batch_size,
ngram_prompt_lookup_max,
ngram_prompt_lookup_min,
disable_bonus_tokens_in_kv_cache,
draft_token_acceptance_method=draft_token_acceptance_method,
typical_acceptance_sampler_posterior_threshold=\
typical_acceptance_sampler_posterior_threshold,
@@ -1102,7 +1096,6 @@ def __init__(
speculative_disable_by_batch_size: Optional[int],
ngram_prompt_lookup_max: Optional[int],
ngram_prompt_lookup_min: Optional[int],
disable_bonus_tokens_in_kv_cache: bool,
draft_token_acceptance_method: str,
typical_acceptance_sampler_posterior_threshold: float,
typical_acceptance_sampler_posterior_alpha: float,
@@ -1119,10 +1112,6 @@ def __init__(
enqueue requests is larger than this value.
ngram_prompt_lookup_max: Max size of ngram token window.
ngram_prompt_lookup_min: Min size of ngram token window.
disable_bonus_tokens_in_kv_cache: A boolean flag to control the
use of bonus tokens during speculative decoding in models that
rely on KV cache. If set to True, bonus tokens will be disabled
and if set to False, bonus tokens will be enabled.
draft_token_acceptance_method (str): The method to use for
accepting draft tokens. This can take two possible
values 'rejection_sampler' and 'typical_acceptance_sampler'
@@ -1144,8 +1133,6 @@ def __init__(
speculative_disable_by_batch_size
self.ngram_prompt_lookup_max = ngram_prompt_lookup_max or 0
self.ngram_prompt_lookup_min = ngram_prompt_lookup_min or 0
self.disable_bonus_tokens_in_kv_cache = \
disable_bonus_tokens_in_kv_cache
self.draft_token_acceptance_method = draft_token_acceptance_method
self.typical_acceptance_sampler_posterior_threshold = \
typical_acceptance_sampler_posterior_threshold
11 changes: 0 additions & 11 deletions vllm/engine/arg_utils.py
@@ -100,7 +100,6 @@ class EngineArgs:
speculative_disable_by_batch_size: Optional[int] = None
ngram_prompt_lookup_max: Optional[int] = None
ngram_prompt_lookup_min: Optional[int] = None
disable_bonus_tokens_in_kv_cache: bool = True
spec_decoding_acceptance_method: str = 'rejection_sampler'
typical_acceptance_sampler_posterior_threshold: Optional[float] = None
typical_acceptance_sampler_posterior_alpha: Optional[float] = None
@@ -579,15 +578,6 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
default=EngineArgs.ngram_prompt_lookup_min,
help='Min size of window for ngram prompt lookup in speculative '
'decoding.')

parser.add_argument(
'--disable-bonus-tokens-in-kv-cache',
type=int,
default=EngineArgs.disable_bonus_tokens_in_kv_cache,
help='A boolean flag to control the use of bonus tokens during '
'speculative decoding in models that rely on KV cache. If set '
'to True, bonus tokens will be disabled and if set to False, '
'bonus tokens will be enabled.')

parser.add_argument(
'--spec-decoding-acceptance-method',
@@ -781,7 +771,6 @@ def create_engine_config(self, ) -> EngineConfig:
use_v2_block_manager=self.use_v2_block_manager,
ngram_prompt_lookup_max=self.ngram_prompt_lookup_max,
ngram_prompt_lookup_min=self.ngram_prompt_lookup_min,
disable_bonus_tokens_in_kv_cache=self.disable_bonus_tokens_in_kv_cache,
draft_token_acceptance_method=\
self.spec_decoding_acceptance_method,
typical_acceptance_sampler_posterior_threshold=self.
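With the config field and the --disable-bonus-tokens-in-kv-cache CLI flag removed, speculative decoding is configured exactly as before and bonus tokens are simply always used. A hedged usage sketch against the vLLM API of this era (model names are placeholders):

    from vllm import LLM, SamplingParams

    llm = LLM(
        model="facebook/opt-6.7b",              # target model (placeholder)
        speculative_model="facebook/opt-125m",  # draft model (placeholder)
        num_speculative_tokens=5,
        use_v2_block_manager=True,              # spec decode expects the v2 block manager here
    )
    outputs = llm.generate(["The future of AI is"],
                           SamplingParams(temperature=0.0, max_tokens=64))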
64 changes: 28 additions & 36 deletions vllm/spec_decode/multi_step_worker.py
@@ -62,13 +62,6 @@ def sampler_output(
"""
self._raise_if_unsupported(execute_model_req)

# Shallow copy input data so modifications (such as appending tokens)
# do not cause side-effects.
copied_seq_group_metadata_list = self._shallow_copy_inputs(
execute_model_req.seq_group_metadata_list)
copied_execute_model_req = execute_model_req.clone(
copied_seq_group_metadata_list)

# Assert enough KV space for sample_len tokens per sequence.
self._assert_enough_kv_space(execute_model_req.seq_group_metadata_list,
sample_len)
@@ -79,11 +72,11 @@ def sampler_output(
# response to retain only the original sequences' responses.
expanded_request, indices_of_seq_with_bonus_tokens =\
self._expand_execute_model_request(
copied_execute_model_req, seq_ids_with_bonus_token_in_last_step)
execute_model_req, seq_ids_with_bonus_token_in_last_step)
# Run model sample_len times.
model_outputs: List[SamplerOutput] = []
if isinstance(self.model_runner, TP1DraftModelRunner):
copied_execute_model_req.num_steps = sample_len
execute_model_req.num_steps = sample_len
model_outputs = self.execute_model(
execute_model_req=expanded_request)
else:
@@ -95,10 +88,14 @@ def sampler_output(
), "composing multistep workers not supported"
model_output = model_output[0]

self._append_new_tokens(model_output,
copied_seq_group_metadata_list)
self._append_new_tokens(
model_output,
expanded_request.seq_group_metadata_list)
model_outputs.append(model_output)
return model_outputs, True

filtered_model_outputs = self._filter_model_output(
model_outputs, indices_of_seq_with_bonus_tokens)
return filtered_model_outputs, True

@staticmethod
def _expand_execute_model_request(
@@ -134,9 +131,7 @@ def _expand_execute_model_request(
# Create new sequences without the last bonus token. These new
# sequences have the same sequence id as the original sequence.
# We create a new sequence group and add them there.
updated_seq_group_without_bonus_token = \
MultiStepWorker._shallow_copy_sequence_group_metadata(
seq_group)
updated_seq_group_without_bonus_token = copy.copy(seq_group)
seq_group_without_bonus_token_data = {
seq_id: SequenceData(
prompt_token_ids=seq_group.seq_data[seq_id].prompt_token_ids,
@@ -153,7 +148,8 @@ def _expand_execute_model_request(
updated_seq_group_without_bonus_token.seq_data = seq_group_without_bonus_token_data
updated_seq_group_metadata_list.append(updated_seq_group_without_bonus_token)
# Add the original sequence group.
updated_seq_group_metadata_list.append(seq_group)
updated_seq_group_metadata_list.append(
MultiStepWorker._shallow_copy_input(seq_group))
# Record the index of the original sequence group.
indices_of_original_sequence_groups.append(len(updated_seq_group_metadata_list) - 1)

@@ -162,7 +158,7 @@ def _expand_execute_model_request(

@staticmethod
def _filter_model_output(
expanded_batch_output: SamplerOutput,
expanded_batch_outputs: List[SamplerOutput],
output_indices_to_retain: List[int]
) -> List[SamplerOutput]:
"""
Expand All @@ -172,7 +168,7 @@ def _filter_model_output(
the outputs of only those sequences indicated by the provided indices.
Args:
expanded_batch_output (SamplerOutput): The expanded output batch
expanded_batch_output (List[SamplerOutput]): The expanded output batch
from the model.
output_indices_to_retain (List[int]): Indices of the model outputs to
retain.
@@ -196,7 +192,8 @@ def _filter_model_output(
expanded_batch_output.sampled_token_ids[output_indices_to_retain]
if expanded_batch_output.sampled_token_ids is not None else None
)
)
)
for expanded_batch_output in expanded_batch_outputs
]

def get_spec_proposals(
@@ -235,9 +232,8 @@ def _append_new_tokens(
seq.update_num_computed_tokens(1)

@staticmethod
def _shallow_copy_inputs(
seq_group_metadata_list: List[SequenceGroupMetadata]
) -> List[SequenceGroupMetadata]:
def _shallow_copy_input(seq_group_metadata: SequenceGroupMetadata
) -> SequenceGroupMetadata:
"""Copy input data structures to remove side-effects when input data
structures are shared with other modules.
@@ -248,23 +244,19 @@ def _shallow_copy_inputs(

# Shallow-copy the list of SequenceGroupMetadata. This allows us to
# append tokens and change is_prompt without external side-effects.
new_seq_group_metadata_list: List[SequenceGroupMetadata] = []

for old_seq_group_metadata in seq_group_metadata_list:
# We must shallow-copy seq_group_metadata as is_prompt could change.
seq_group_metadata = copy.copy(old_seq_group_metadata)
new_seq_group_metadata_list.append(seq_group_metadata)
# We must shallow-copy seq_group_metadata as is_prompt could change.
new_seq_group_metadata = copy.copy(seq_group_metadata)

# We must shallow-copy seq_data as we will append token ids
new_seq_data: Dict[int, SequenceData] = {}
for seq_id, old_seq_data in seq_group_metadata.seq_data.items():
new_seq_data[seq_id] = copy.copy(old_seq_data)
new_seq_data[
seq_id].output_token_ids = old_seq_data.output_token_ids[:]
# We must shallow-copy seq_data as we will append token ids
new_seq_data: Dict[int, SequenceData] = {}
for seq_id, old_seq_data in seq_group_metadata.seq_data.items():
new_seq_data[seq_id] = copy.copy(old_seq_data)
new_seq_data[
seq_id].output_token_ids = old_seq_data.output_token_ids[:]

seq_group_metadata.seq_data = new_seq_data
new_seq_group_metadata.seq_data = new_seq_data

return new_seq_group_metadata_list
return new_seq_group_metadata

@staticmethod
def _shallow_copy_sequence_group_metadata(
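Taken together, the multi_step_worker.py changes follow the pattern described in the docstrings above: for every sequence that ended the previous step with a bonus token, the worker adds a copy of that sequence group with the bonus token dropped, runs the draft model over the expanded batch, and then filters each step's output back down to the original sequences. A simplified, hedged sketch of that expand-and-filter pattern (plain dicts and lists stand in for vLLM's SequenceGroupMetadata and SamplerOutput):

    from copy import copy
    from typing import Dict, List, Set, Tuple

    def expand_batch(
        seq_groups: List[Dict],
        seq_ids_with_bonus: Set[int],
    ) -> Tuple[List[Dict], List[int]]:
        """Duplicate, without its bonus token, every sequence that received one
        in the last step; record where the originals sit in the expanded batch."""
        expanded: List[Dict] = []
        original_indices: List[int] = []
        for group in seq_groups:
            for seq_id, token_ids in group["output_token_ids"].items():
                if seq_id in seq_ids_with_bonus:
                    no_bonus = copy(group)  # shallow copy, as in the diff
                    no_bonus["output_token_ids"] = {seq_id: token_ids[:-1]}
                    expanded.append(no_bonus)
            expanded.append(group)  # the original group, bonus token kept
            original_indices.append(len(expanded) - 1)
        return expanded, original_indices

    def filter_outputs(
        step_outputs: List[List[int]],
        original_indices: List[int],
    ) -> List[List[int]]:
        # Keep only the rows that belong to the original (non-duplicated) sequences.
        return [[row[i] for i in original_indices] for row in step_outputs]

For example, expand_batch([{"output_token_ids": {0: [11, 12, 13]}}], {0}) produces two entries, the first with the trailing 13 dropped, and original_indices == [1], which filter_outputs then uses to keep only the original sequence's samples.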