Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 89 additions & 0 deletions tests/tool_parsers/test_openai_tool_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
DeveloperContent,
HarmonyEncodingName,
Message,
RenderConversationConfig,
Role,
SystemContent,
load_harmony_encoding,
Expand Down Expand Up @@ -261,3 +262,91 @@ def test_extract_tool_calls_with_content(
]
assert_tool_calls(extracted_info.tool_calls, expected_tool_calls)
assert extracted_info.content == final_content


def test_extract_partial_response_no_tools(openai_tool_parser, harmony_encoding):
    """Partial (truncated) response without tool calls.

    Renders a conversation that ends with a "final"-channel message, drops
    the last few tokens to simulate a mid-generation cut-off, and checks
    that the parser reports no tool calls while still surfacing the
    truncated final content as a non-empty strict prefix of the full text.
    """
    final_content = "This is a partial response."
    convo = Conversation.from_messages(
        [
            Message.from_role_and_content(
                Role.USER, "What is the weather in Tokyo based on where I'm at?"
            ),
            Message.from_role_and_content(
                Role.ASSISTANT,
                'User asks: "What is the weather in Tokyo?" based on their location. We need to use get_current_weather tool and get_user_location tool.',  # noqa: E501
            ).with_channel("analysis"),
            Message.from_role_and_content(Role.ASSISTANT, final_content).with_channel(
                "final"
            ),
        ]
    )
    # Keep the analysis channel so the rendered token stream matches what the
    # model would actually have emitted before being cut off.
    token_ids = harmony_encoding.render_conversation_for_completion(
        convo, Role.ASSISTANT, config=RenderConversationConfig(auto_drop_analysis=False)
    )
    token_ids = token_ids[:-5]  # Simulate cut-off by removing last 5 tokens
    extracted_info = openai_tool_parser.extract_tool_calls(
        "",
        request=None,
        token_ids=token_ids,
    )
    assert not extracted_info.tools_called
    assert extracted_info.tool_calls == []
    # Truncated content must be a non-empty strict prefix of the full text.
    assert extracted_info.content
    assert len(extracted_info.content) < len(final_content)
    assert final_content.startswith(extracted_info.content)


def test_extract_partial_response_with_tool_call(
    openai_tool_parser,
    harmony_encoding,
):
    """Partial (truncated) response that also contains a tool call.

    Renders a conversation with a completed commentary-channel tool call
    followed by a "final"-channel message, drops the last few tokens to
    simulate a mid-generation cut-off, and checks that the tool call is
    still fully extracted while the truncated final content is surfaced as
    a non-empty strict prefix of the full text.
    """
    final_content = "Let me check the weather."
    convo = Conversation.from_messages(
        [
            Message.from_role_and_content(
                Role.USER, "What is the weather in Tokyo based on where I'm at?"
            ),
            Message.from_role_and_content(
                Role.ASSISTANT,
                'User asks: "What is the weather in Tokyo?" based on their location. We need to use get_current_weather tool and get_user_location tool.',  # noqa: E501
            ).with_channel("analysis"),
            Message.from_role_and_content(Role.ASSISTANT, '{"location": "Tokyo"}')
            .with_channel("commentary")
            .with_recipient("functions.get_current_weather")
            .with_content_type("json"),
            Message.from_role_and_content(Role.ASSISTANT, final_content).with_channel(
                "final"
            ),
        ]
    )
    # Keep the analysis channel so the rendered token stream matches what the
    # model would actually have emitted before being cut off.
    token_ids = harmony_encoding.render_conversation_for_completion(
        convo, Role.ASSISTANT, config=RenderConversationConfig(auto_drop_analysis=False)
    )
    token_ids = token_ids[:-5]  # Simulate cut-off by removing last 5 tokens

    extracted_info = openai_tool_parser.extract_tool_calls(
        "",
        request=None,
        token_ids=token_ids,
    )
    # The completed tool call precedes the cut-off, so it must survive intact.
    assert extracted_info.tools_called
    expected_tool_calls = [
        ToolCall(
            function=FunctionCall(
                name="get_current_weather",
                arguments=json.dumps({"location": "Tokyo"}),
            )
        ),
    ]
    assert_tool_calls(extracted_info.tool_calls, expected_tool_calls)
    # Truncated content must be a non-empty strict prefix of the full text.
    assert extracted_info.content
    assert len(extracted_info.content) < len(final_content)
    assert final_content.startswith(extracted_info.content)
10 changes: 10 additions & 0 deletions vllm/tool_parsers/openai_tool_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,16 @@ def extract_tool_calls(
elif msg.channel == "commentary" and not msg.recipient:
commentary_content = msg_text

# Check for partial responses:
# current content in final channel without recipient and final content.
if (
parser.current_content
and final_content is None
and parser.current_recipient is None
and parser.current_channel in [None, "final"]
):
final_content = parser.current_content

return ExtractedToolCallInformation(
tools_called=len(tool_calls) > 0,
tool_calls=tool_calls,
Expand Down