From 6f1df80436c46175e09f660a99075a5eba3a2273 Mon Sep 17 00:00:00 2001 From: SangBin Cho Date: Wed, 1 May 2024 21:45:42 +0900 Subject: [PATCH] [Test] Add ignore_eos test (#4519) --- tests/samplers/test_ignore_eos.py | 31 +++++++++++++++++++++++++++++++ 1 file changed, 31 insertions(+) create mode 100644 tests/samplers/test_ignore_eos.py diff --git a/tests/samplers/test_ignore_eos.py b/tests/samplers/test_ignore_eos.py new file mode 100644 index 0000000000000..864657a3c2b28 --- /dev/null +++ b/tests/samplers/test_ignore_eos.py @@ -0,0 +1,31 @@ +"""Make sure ignore_eos works. + +Run `pytest tests/samplers/test_ignore_eos.py`. +""" + +import pytest + +from vllm import SamplingParams + +MODELS = ["facebook/opt-125m"] + + +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["half"]) +@pytest.mark.parametrize("max_tokens", [1024]) +def test_beam_search_single_input( + vllm_runner, + example_prompts, + model: str, + dtype: str, + max_tokens: int, +) -> None: + example_prompts = "1 + 1 is" + + vllm_model = vllm_runner(model, dtype=dtype) + sampling_params = SamplingParams(max_tokens=max_tokens, ignore_eos=True) + ignore_eos_output = vllm_model.model.generate( + example_prompts, sampling_params=sampling_params) + print(len(ignore_eos_output[0].outputs[0].token_ids)) + assert max_tokens - len(ignore_eos_output[0].outputs[0].token_ids) < 10 + assert max_tokens - len(ignore_eos_output[0].outputs[0].token_ids) >= 0