diff --git a/tests/entrypoints/openai/test_serving_chat_stream_harmony.py b/tests/entrypoints/openai/test_serving_chat_stream_harmony.py
new file mode 100644
index 000000000000..1934d43d5cfb
--- /dev/null
+++ b/tests/entrypoints/openai/test_serving_chat_stream_harmony.py
@@ -0,0 +1,212 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""
+Unit tests for harmony streaming delta extraction.
+"""
+
+from dataclasses import dataclass, field
+from unittest.mock import patch
+
+import pytest
+
+from vllm.entrypoints.openai.serving_chat_stream_harmony import (
+    extract_harmony_streaming_delta,
+)
+
+
+@dataclass
+class MockMessage:
+    """Mock message object for testing."""
+
+    channel: str | None = None
+    recipient: str | None = None
+
+
+@dataclass
+class MockStreamableParser:
+    """Mock StreamableParser for testing without openai_harmony dependency."""
+
+    messages: list[MockMessage] = field(default_factory=list)
+
+
+class TestExtractHarmonyStreamingDelta:
+    """Tests for extract_harmony_streaming_delta function."""
+
+    @pytest.mark.parametrize(
+        "delta_text,expected_content",
+        [
+            ("Hello, world!", "Hello, world!"),
+            ("", ""),
+        ],
+    )
+    def test_final_channel_returns_content_delta(self, delta_text, expected_content):
+        """Test that final channel returns a DeltaMessage with content."""
+        parser = MockStreamableParser()
+        delta_message, tools_streamed = extract_harmony_streaming_delta(
+            harmony_parser=parser,
+            cur_channel="final",
+            cur_recipient=None,
+            prev_recipient=None,
+            delta_text=delta_text,
+            include_reasoning=False,
+        )
+
+        assert delta_message is not None
+        assert delta_message.content == expected_content
+        assert tools_streamed is False
+
+    @pytest.mark.parametrize(
+        "include_reasoning,expected_has_message",
+        [
+            (True, True),
+            (False, False),
+        ],
+    )
+    def test_analysis_channel_reasoning(self, include_reasoning, expected_has_message):
+        """Test analysis channel respects include_reasoning flag."""
+        parser = MockStreamableParser()
+        delta_message, tools_streamed = extract_harmony_streaming_delta(
+            harmony_parser=parser,
+            cur_channel="analysis",
+            cur_recipient=None,
+            prev_recipient=None,
+            delta_text="Let me think...",
+            include_reasoning=include_reasoning,
+        )
+
+        if expected_has_message:
+            assert delta_message is not None
+            assert delta_message.reasoning == "Let me think..."
+        else:
+            assert delta_message is None
+        assert tools_streamed is False
+
+    @pytest.mark.parametrize("channel", ["commentary", "analysis"])
+    @patch("vllm.entrypoints.openai.serving_chat_stream_harmony.make_tool_call_id")
+    def test_new_tool_call(self, mock_make_tool_call_id, channel):
+        """Test new tool call creation when recipient changes."""
+        mock_make_tool_call_id.return_value = "call_test123"
+        parser = MockStreamableParser()
+
+        delta_message, tools_streamed = extract_harmony_streaming_delta(
+            harmony_parser=parser,
+            cur_channel=channel,
+            cur_recipient="functions.get_weather",
+            prev_recipient=None,
+            delta_text="",
+            include_reasoning=False,
+        )
+
+        assert delta_message is not None
+        assert len(delta_message.tool_calls) == 1
+        tool_call = delta_message.tool_calls[0]
+        assert tool_call.id == "call_test123"
+        assert tool_call.type == "function"
+        assert tool_call.function.name == "get_weather"
+        assert tool_call.function.arguments == ""
+        assert tool_call.index == 0
+        assert tools_streamed is True
+
+    @pytest.mark.parametrize("channel", ["commentary", "analysis"])
+    def test_tool_call_argument_streaming(self, channel):
+        """Test streaming tool call arguments (same recipient)."""
+        parser = MockStreamableParser()
+
+        delta_message, tools_streamed = extract_harmony_streaming_delta(
+            harmony_parser=parser,
+            cur_channel=channel,
+            cur_recipient="functions.get_weather",
+            prev_recipient="functions.get_weather",
+            delta_text='{"location": "Paris"}',
+            include_reasoning=False,
+        )
+
+        assert delta_message is not None
+        tool_call = delta_message.tool_calls[0]
+        assert tool_call.id is None
+        assert tool_call.function.arguments == '{"location": "Paris"}'
+        assert tool_call.index == 0
+        assert tools_streamed is True
+
+    @pytest.mark.parametrize("channel", ["commentary", "analysis"])
+    def test_tool_call_empty_arguments_returns_none(self, channel):
+        """Test empty delta_text with same recipient returns None."""
+        parser = MockStreamableParser()
+
+        delta_message, tools_streamed = extract_harmony_streaming_delta(
+            harmony_parser=parser,
+            cur_channel=channel,
+            cur_recipient="functions.get_weather",
+            prev_recipient="functions.get_weather",
+            delta_text="",
+            include_reasoning=False,
+        )
+
+        assert delta_message is None
+        assert tools_streamed is False
+
+    def test_tool_call_index_from_previous_messages(self):
+        """Test tool call index accounts for previous function messages."""
+        messages = [
+            MockMessage(channel="analysis", recipient=None),  # Not counted
+            MockMessage(channel="commentary", recipient="functions.tool1"),  # Counted
+            MockMessage(channel="final", recipient=None),  # Not counted
+        ]
+        parser = MockStreamableParser(messages=messages)
+
+        delta_message, _ = extract_harmony_streaming_delta(
+            harmony_parser=parser,
+            cur_channel="commentary",
+            cur_recipient="functions.tool2",
+            prev_recipient="functions.tool2",
+            delta_text="args",
+            include_reasoning=False,
+        )
+
+        assert delta_message.tool_calls[0].index == 1
+
+    @pytest.mark.parametrize(
+        "channel,recipient",
+        [
+            ("commentary", None),
+            ("commentary", "browser.search"),
+        ],
+    )
+    def test_returns_tool_call_preambles(self, channel, recipient):
+        """Test that commentary preambles (non-function recipients) stream as content."""
+        parser = MockStreamableParser()
+        delta_text = "some text"
+        delta_message, tools_streamed = extract_harmony_streaming_delta(
+            harmony_parser=parser,
+            cur_channel=channel,
+            cur_recipient=recipient,
+            prev_recipient=None,
+            delta_text=delta_text,
+            include_reasoning=True,
+        )
+
+        assert delta_message.content == delta_text
+        assert tools_streamed is False
+
+    @pytest.mark.parametrize(
+        "channel,recipient",
+        [
+            (None, None),
+            ("unknown_channel", None),
+        ],
+    )
+    def test_returns_none_for_invalid_inputs(self, channel, recipient):
+        """Test that invalid channel/recipient combinations return None."""
+        parser = MockStreamableParser()
+
+        delta_message, tools_streamed = extract_harmony_streaming_delta(
+            harmony_parser=parser,
+            cur_channel=channel,
+            cur_recipient=recipient,
+            prev_recipient=None,
+            delta_text="some text",
+            include_reasoning=True,
+        )
+
+        assert delta_message is None
+        assert tools_streamed is False
diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py
index e36ae00fc9b8..40258c5c929a 100644
--- a/vllm/entrypoints/openai/serving_chat.py
+++ b/vllm/entrypoints/openai/serving_chat.py
@@ -51,6 +51,9 @@
     ToolCall,
     UsageInfo,
 )
+from vllm.entrypoints.openai.serving_chat_stream_harmony import (
+    extract_harmony_streaming_delta,
+)
 from vllm.entrypoints.openai.serving_engine import (
     GenerationError,
     OpenAIServing,
@@ -824,64 +827,17 @@ async def chat_completion_stream_generator(
                         current_token_ids = as_list(output.token_ids)
 
                         if self.use_harmony:
-                            if cur_channel == "final":
-                                delta_message = DeltaMessage(content=delta_text)
-                            elif cur_channel == "analysis":
-                                if request.include_reasoning:
-                                    delta_message = DeltaMessage(reasoning=delta_text)
-                                else:
-                                    delta_message = None
-                            elif (
-                                cur_channel == "commentary"
-                                and cur_recipient
-                                and cur_recipient.startswith("functions.")
-                            ):
-                                # Count completed tool calls to determine index
-                                base_index = 0
-                                for msg in harmony_parser.messages:
-                                    if (
-                                        msg.channel == "commentary"
-                                        and msg.recipient
-                                        and msg.recipient.startswith("functions.")
-                                    ):
-                                        base_index += 1
-
-                                if prev_recipient != cur_recipient:
-                                    tool_name = cur_recipient.split("functions.", 1)[1]
-                                    delta_message = DeltaMessage(
-                                        tool_calls=[
-                                            DeltaToolCall(
-                                                id=make_tool_call_id(),
-                                                type="function",
-                                                function=DeltaFunctionCall(
-                                                    name=tool_name,
-                                                    arguments="",
-                                                ),
-                                                index=base_index,
-                                            )
-                                        ]
-                                    )
-                                elif delta_text:
-                                    delta_message = DeltaMessage(
-                                        tool_calls=[
-                                            DeltaToolCall(
-                                                index=base_index,
-                                                function=DeltaFunctionCall(
-                                                    arguments=delta_text
-                                                ),
-                                            )
-                                        ]
-                                    )
-                                else:
-                                    delta_message = None
-
-                                if delta_message is not None:
-                                    harmony_tools_streamed[i] = True
-                            elif cur_channel == "commentary":
-                                # Tool call preambles meant to be shown to the user
-                                delta_message = DeltaMessage(content=delta_text)
-                            else:
-                                delta_message = None
+                            delta_message, tools_streamed_flag = (
+                                extract_harmony_streaming_delta(
+                                    harmony_parser=harmony_parser,
+                                    cur_channel=cur_channel,
+                                    cur_recipient=cur_recipient,
+                                    prev_recipient=prev_recipient,
+                                    delta_text=delta_text,
+                                    include_reasoning=request.include_reasoning,
+                                )
+                            )
+                            harmony_tools_streamed[i] |= tools_streamed_flag
                         # handle streaming deltas for tools with named tool_choice
                         elif tool_choice_function_name:
                             if (
diff --git a/vllm/entrypoints/openai/serving_chat_stream_harmony.py b/vllm/entrypoints/openai/serving_chat_stream_harmony.py
new file mode 100644
index 000000000000..1b5ae620651c
--- /dev/null
+++ b/vllm/entrypoints/openai/serving_chat_stream_harmony.py
@@ -0,0 +1,101 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""
+Harmony-specific streaming delta extraction for chat completions.
+
+This module handles the extraction of DeltaMessage objects from
+harmony parser state during streaming chat completions.
+""" + +from openai_harmony import StreamableParser + +from vllm.entrypoints.chat_utils import make_tool_call_id +from vllm.entrypoints.openai.protocol import ( + DeltaFunctionCall, + DeltaMessage, + DeltaToolCall, +) + + +def extract_harmony_streaming_delta( + harmony_parser: StreamableParser, + cur_channel: str | None, + cur_recipient: str | None, + prev_recipient: str | None, + delta_text: str, + include_reasoning: bool, +) -> tuple[DeltaMessage | None, bool]: + """ + Extract a DeltaMessage from harmony parser state during streaming. + + Args: + harmony_parser: The StreamableParser instance tracking parse state + cur_channel: Current channel ("final", "analysis", "commentary", etc.) + cur_recipient: Current recipient (e.g., "functions.my_func") + prev_recipient: Previous recipient for detecting tool call transitions + delta_text: The text delta to include in the message + include_reasoning: Whether to include reasoning content + + Returns: + A tuple of (DeltaMessage or None, tools_streamed_flag) + """ + tools_streamed = False + + if cur_channel == "final": + delta_message = DeltaMessage(content=delta_text) + elif ( + (cur_channel == "commentary" or cur_channel == "analysis") + and cur_recipient + and cur_recipient.startswith("functions.") + ): + # Count completed tool calls to determine index + base_index = 0 + for msg in harmony_parser.messages: + if ( + (msg.channel == "commentary" or msg.channel == "analysis") + and msg.recipient + and msg.recipient.startswith("functions.") + ): + base_index += 1 + + if prev_recipient != cur_recipient: + tool_name = cur_recipient.split("functions.", 1)[1] + delta_message = DeltaMessage( + tool_calls=[ + DeltaToolCall( + id=make_tool_call_id(), + type="function", + function=DeltaFunctionCall( + name=tool_name, + arguments="", + ), + index=base_index, + ) + ] + ) + elif delta_text: + delta_message = DeltaMessage( + tool_calls=[ + DeltaToolCall( + index=base_index, + function=DeltaFunctionCall(arguments=delta_text), + ) + ] + ) + else: + delta_message = None + + if delta_message is not None: + tools_streamed = True + elif cur_channel == "commentary": + # Tool call preambles meant to be shown to the user + delta_message = DeltaMessage(content=delta_text) + elif cur_channel == "analysis": + if include_reasoning: + delta_message = DeltaMessage(reasoning=delta_text) + else: + delta_message = None + else: + delta_message = None + + return delta_message, tools_streamed