diff --git a/tests/spec_decode/e2e/conftest.py b/tests/spec_decode/e2e/conftest.py index 740a227efcad..60dfe33f2918 100644 --- a/tests/spec_decode/e2e/conftest.py +++ b/tests/spec_decode/e2e/conftest.py @@ -282,6 +282,6 @@ def run_greedy_equality_correctness_test(baseline_llm_generator, if print_tokens: print(f'{i=} {baseline_tokens=}') print(f'{i=} {spec_tokens=}') - #print(f'{i=} {baseline_token_ids=}') - #print(f'{i=} {spec_token_ids=}') + print(f'{i=} {baseline_token_ids=}') + print(f'{i=} {spec_token_ids=}') assert baseline_token_ids == spec_token_ids diff --git a/tests/spec_decode/e2e/test_multistep_correctness.py b/tests/spec_decode/e2e/test_multistep_correctness.py index 5f870d6b4e23..94cc36f22875 100644 --- a/tests/spec_decode/e2e/test_multistep_correctness.py +++ b/tests/spec_decode/e2e/test_multistep_correctness.py @@ -643,12 +643,12 @@ def test_many_k(baseline_llm_generator, test_llm_generator, batch_size: int, # Try a range of common k. for k in [1, 2, 3] ]) -@pytest.mark.parametrize("batch_size", [1, 8, 64]) +@pytest.mark.parametrize("batch_size", [1, 32]) @pytest.mark.parametrize( "output_len", [ # Use smaller output len for fast test. - 256, + 32, ]) @pytest.mark.parametrize("seed", [1]) def test_typical_acceptance_sampling(baseline_llm_generator, diff --git a/vllm/config.py b/vllm/config.py index 285248048c77..d7a037fef0c2 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -885,6 +885,7 @@ def maybe_create_spec_config( typical_acceptance_sampler_posterior_alpha (Optional[float]): A scaling factor for the entropy-based threshold in the TypicalAcceptanceSampler. + Returns: Optional["SpeculativeConfig"]: An instance of SpeculativeConfig if the necessary conditions are met, else None. diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index d4d8306ca078..d4044adfce61 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -578,7 +578,7 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: default=EngineArgs.ngram_prompt_lookup_min, help='Min size of window for ngram prompt lookup in speculative ' 'decoding.') - + parser.add_argument( '--spec-decoding-acceptance-method', type=str,