Skip to content

Commit

Permalink
[Bugfix] Support logprobs when using guided_json and other constraine…
Browse files Browse the repository at this point in the history
…d decoding fields (vllm-project#4149)
  • Loading branch information
jamestwhedbee authored Apr 18, 2024
1 parent 705578a commit e1bb2fd
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 1 deletion.
30 changes: 30 additions & 0 deletions tests/entrypoints/test_openai_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -723,6 +723,36 @@ async def test_guided_decoding_type_error(server, client: openai.AsyncOpenAI,
extra_body=dict(guided_regex=TEST_REGEX, guided_json=TEST_SCHEMA))


@pytest.mark.parametrize("guided_decoding_backend",
["outlines", "lm-format-enforcer"])
async def test_guided_choice_chat_logprobs(server, client: openai.AsyncOpenAI,
guided_decoding_backend: str):
messages = [{
"role": "system",
"content": "you are a helpful assistant"
}, {
"role":
"user",
"content":
"The best language for type-safe systems programming is "
}]
chat_completion = await client.chat.completions.create(
model=MODEL_NAME,
messages=messages,
max_tokens=10,
logprobs=True,
top_logprobs=5,
extra_body=dict(guided_choice=TEST_CHOICE,
guided_decoding_backend=guided_decoding_backend))
top_logprobs = chat_completion.choices[0].logprobs.top_logprobs

# -9999.0 is the minimum logprob returned by OpenAI
assert all(
isinstance(logprob, float) and logprob >= -9999.0
for token_dict in top_logprobs
for token, logprob in token_dict.items())


async def test_response_format_json_object(server, client: openai.AsyncOpenAI):
resp = await client.chat.completions.create(
model=MODEL_NAME,
Expand Down
4 changes: 3 additions & 1 deletion vllm/entrypoints/openai/serving_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,9 @@ def _create_logprobs(

if num_output_top_logprobs:
logprobs.top_logprobs.append({
p.decoded_token: p.logprob
# Convert float("-inf") to the
# JSON-serializable float that OpenAI uses
p.decoded_token: max(p.logprob, -9999.0)
for i, p in step_top_logprobs.items()
} if step_top_logprobs else None)

Expand Down

0 comments on commit e1bb2fd

Please sign in to comment.