From e1bb2fd52dea0bbc772bdf35fd27664c5daec7b2 Mon Sep 17 00:00:00 2001
From: James Whedbee
Date: Thu, 18 Apr 2024 16:12:55 -0500
Subject: [PATCH] [Bugfix] Support logprobs when using guided_json and other
 constrained decoding fields (#4149)

---
 tests/entrypoints/test_openai_server.py   | 30 +++++++++++++++++++++++
 vllm/entrypoints/openai/serving_engine.py |  4 ++-
 2 files changed, 33 insertions(+), 1 deletion(-)

diff --git a/tests/entrypoints/test_openai_server.py b/tests/entrypoints/test_openai_server.py
index 14e6ee0ffe9d9..0dd30eec30086 100644
--- a/tests/entrypoints/test_openai_server.py
+++ b/tests/entrypoints/test_openai_server.py
@@ -723,6 +723,36 @@ async def test_guided_decoding_type_error(server, client: openai.AsyncOpenAI,
         extra_body=dict(guided_regex=TEST_REGEX, guided_json=TEST_SCHEMA))
 
 
+@pytest.mark.parametrize("guided_decoding_backend",
+                         ["outlines", "lm-format-enforcer"])
+async def test_guided_choice_chat_logprobs(server, client: openai.AsyncOpenAI,
+                                           guided_decoding_backend: str):
+    messages = [{
+        "role": "system",
+        "content": "you are a helpful assistant"
+    }, {
+        "role":
+        "user",
+        "content":
+        "The best language for type-safe systems programming is "
+    }]
+    chat_completion = await client.chat.completions.create(
+        model=MODEL_NAME,
+        messages=messages,
+        max_tokens=10,
+        logprobs=True,
+        top_logprobs=5,
+        extra_body=dict(guided_choice=TEST_CHOICE,
+                        guided_decoding_backend=guided_decoding_backend))
+    top_logprobs = chat_completion.choices[0].logprobs.top_logprobs
+
+    # -9999.0 is the minimum logprob returned by OpenAI
+    assert all(
+        isinstance(logprob, float) and logprob >= -9999.0
+        for token_dict in top_logprobs
+        for token, logprob in token_dict.items())
+
+
 async def test_response_format_json_object(server, client: openai.AsyncOpenAI):
     resp = await client.chat.completions.create(
         model=MODEL_NAME,
diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py
index b5a7a977ebbab..8e5ee88d7f3a9 100644
--- a/vllm/entrypoints/openai/serving_engine.py
+++ b/vllm/entrypoints/openai/serving_engine.py
@@ -116,7 +116,9 @@ def _create_logprobs(
 
             if num_output_top_logprobs:
                 logprobs.top_logprobs.append({
-                    p.decoded_token: p.logprob
+                    # Convert float("-inf") to the
+                    # JSON-serializable float that OpenAI uses
+                    p.decoded_token: max(p.logprob, -9999.0)
                     for i, p in step_top_logprobs.items()
                 } if step_top_logprobs else None)
 
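
Note (commentary, not part of the patch): guided-decoding backends mask disallowed
tokens by forcing their logits to float("-inf"), which then surfaces in the returned
logprobs, and a strict JSON encoder cannot serialize infinities, so combining
guided_choice/guided_json with logprobs=True could produce a response that fails to
serialize. A minimal standalone sketch of the failure mode and of the clamp the patch
applies, using only the Python standard library (the dict key "logprob" is illustrative):

    import json

    # Guided decoding masks banned tokens, so their logprobs come back as -inf.
    masked_logprob = float("-inf")

    # Strict JSON (what OpenAI-compatible clients expect) rejects infinities.
    try:
        json.dumps({"logprob": masked_logprob}, allow_nan=False)
    except ValueError as exc:
        print(f"serialization fails: {exc}")

    # The patch clamps to -9999.0, the minimum logprob OpenAI itself returns,
    # keeping the value JSON-serializable while still signalling "effectively banned".
    clamped = max(masked_logprob, -9999.0)
    print(json.dumps({"logprob": clamped}, allow_nan=False))  # {"logprob": -9999.0}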