
Commit 7c57254

Hanchenli authored, with gemini-code-assist[bot], frank-wei, and DarkLight1337
[GPT-OSS] Structure_Tag support for gpt-oss tool-call in cot (vllm-project#25515)
Signed-off-by: Hanchenli <[email protected]>
Signed-off-by: Hanchenli <[email protected]>
Signed-off-by: Wei Wei <[email protected]>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: Wei Wei <[email protected]>
Co-authored-by: Wei Wei <[email protected]>
Co-authored-by: Cyrus Leung <[email protected]>
1 parent c312320 commit 7c57254

File tree

14 files changed: +911 -32 lines changed
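
To orient before the diff: the change makes the GPT-OSS reasoning parser emit a structural tag that constrains which tool-call channels may appear during chain-of-thought. Below is a minimal sketch of the intended call flow, pieced together from the new tests; the Mock tokenizer and Mock ToolServer mirror the test fixtures and stand in for real components, and nothing beyond prepare_structured_tag, ToolServer, and StructuredOutputsParams is asserted by the commit itself.

    # Sketch only: build the structural tag a GPT-OSS request would use when
    # the "python" tool is available, then wrap it for structured decoding.
    import json
    from unittest.mock import Mock

    from vllm.entrypoints.openai.protocol import StructuredOutputsParams
    from vllm.entrypoints.tool_server import ToolServer
    from vllm.reasoning.gptoss_reasoning_parser import GptOssReasoningParser

    tokenizer = Mock()  # mock tokenizer, exactly as in the test fixture
    tokenizer.encode = Mock(return_value=[1, 2, 3, 4, 5])
    parser = GptOssReasoningParser(tokenizer)

    tool_server = Mock(spec=ToolServer)
    tool_server.has_tool = Mock(side_effect=lambda tool: tool == "python")

    # Passing None as the first argument asks the parser to synthesize a tag
    # from the available tools; a caller-supplied tag is returned unchanged.
    tag = parser.prepare_structured_tag(None, tool_server)
    params = StructuredOutputsParams(structural_tag=tag)
    print(json.loads(tag)["format"]["triggers"])
    # ["<|channel|>analysis", "<|channel|>commentary to="]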
Lines changed: 280 additions & 0 deletions
@@ -0,0 +1,280 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

"""Integration tests for GPT-OSS structural tags functionality (PR #25515)."""

import json
from unittest.mock import Mock

import pytest

from vllm.entrypoints.openai.protocol import (
    StructuredOutputsParams,
)
from vllm.entrypoints.tool_server import ToolServer
from vllm.reasoning.gptoss_reasoning_parser import (
    GptOssReasoningParser,
)


class TestGptOssStructuralTagsIntegration:
    """Integration tests for structural tags in GPT-OSS tool calls."""

    @pytest.fixture
    def mock_tokenizer(self):
        """Create a mock tokenizer."""
        tokenizer = Mock()
        tokenizer.encode = Mock(return_value=[1, 2, 3, 4, 5])
        return tokenizer

    @pytest.fixture
    def gptoss_parser(self, mock_tokenizer):
        """Create a real GptOssReasoningParser instance."""
        return GptOssReasoningParser(mock_tokenizer)

    @pytest.fixture
    def tool_server_with_python(self):
        """Create a tool server with the Python tool enabled."""
        tool_server = Mock(spec=ToolServer)
        tool_server.has_tool = Mock(side_effect=lambda tool: tool == "python")
        return tool_server

    @pytest.fixture
    def tool_server_empty(self):
        """Create a tool server with no tools."""
        tool_server = Mock(spec=ToolServer)
        tool_server.has_tool = Mock(return_value=False)
        return tool_server

    def test_end_to_end_no_tools(self, gptoss_parser):
        """Test end-to-end flow when no tools are available."""
        # Test the parser directly
        result = gptoss_parser.prepare_structured_tag(None, None)
        parsed_result = json.loads(result)

        # Verify basic structure
        assert parsed_result["type"] == "structural_tag"
        assert parsed_result["format"]["type"] == "triggered_tags"
        assert len(parsed_result["format"]["tags"]) == 1

        # Verify only the analysis channel is allowed
        analysis_tag = parsed_result["format"]["tags"][0]
        assert analysis_tag["begin"] == "<|channel|>analysis<|message|>"
        assert analysis_tag["content"]["type"] == "any_text"
        assert analysis_tag["end"] == "<|end|>"

        # Verify triggers
        assert parsed_result["format"]["triggers"] == ["<|channel|>analysis"]
        assert parsed_result["format"]["stop_after_first"] is False

    def test_end_to_end_with_python_tool(self, gptoss_parser, tool_server_with_python):
        """Test end-to-end flow with the Python tool enabled."""
        result = gptoss_parser.prepare_structured_tag(None, tool_server_with_python)
        parsed_result = json.loads(result)

        # Should have the analysis tag plus 2 python tags
        assert len(parsed_result["format"]["tags"]) == 3

        # Verify all expected tags are present
        tag_begins = [tag["begin"] for tag in parsed_result["format"]["tags"]]
        expected_begins = [
            "<|channel|>analysis<|message|>",
            "<|channel|>commentary to=python",
            "<|channel|>analysis to=python",
        ]
        for expected in expected_begins:
            assert expected in tag_begins

        # Verify triggers include commentary
        assert "<|channel|>analysis" in parsed_result["format"]["triggers"]
        assert "<|channel|>commentary to=" in parsed_result["format"]["triggers"]

    def test_structured_outputs_params_integration(
        self, gptoss_parser, tool_server_with_python
    ):
        """Test integration with StructuredOutputsParams."""
        # Generate the structural tag
        structural_tag = gptoss_parser.prepare_structured_tag(
            None, tool_server_with_python
        )

        # Create StructuredOutputsParams
        params = StructuredOutputsParams(structural_tag=structural_tag)

        # Verify the tag is properly stored and accessible
        assert params.structural_tag == structural_tag

        # Verify the tag is valid JSON
        parsed_tag = json.loads(params.structural_tag)
        assert parsed_tag["type"] == "structural_tag"

    @pytest.mark.parametrize(
        "browser, python, container, expected_tags",
        [
            # No tools
            (False, False, False, 1),
            # Single tool
            (True, False, False, 3),
            # Multiple tools
            (True, True, False, 5),
            # All tools
            (True, True, True, 7),
        ],
    )
    def test_tool_server_interaction_flow(
        self, gptoss_parser, browser, python, container, expected_tags
    ):
        """Test the complete tool server interaction flow."""
        # Create a mock ToolServer
        tool_server = Mock(spec=ToolServer)

        # Simulate tool availability based on parameters
        tool_server.has_tool = Mock(
            side_effect=lambda tool: {
                "browser": browser,
                "python": python,
                "container": container,
            }.get(tool, False)
        )

        # Run the parser and verify results
        result = gptoss_parser.prepare_structured_tag(None, tool_server)
        parsed_result = json.loads(result)

        # Validate the number of tags
        assert len(parsed_result["format"]["tags"]) == expected_tags

        # Verify tool-specific tags exist for enabled tools
        tag_begins = [tag["begin"] for tag in parsed_result["format"]["tags"]]
        for tool, enabled in {
            "browser": browser,
            "python": python,
            "container": container,
        }.items():
            if enabled:
                assert f"<|channel|>commentary to={tool}" in tag_begins
                assert f"<|channel|>analysis to={tool}" in tag_begins

    def test_original_tag_preservation(self, gptoss_parser, tool_server_with_python):
        """Test that original tags are preserved when provided."""
        original_tag = '{"type": "custom_tag", "data": "preserved"}'

        result = gptoss_parser.prepare_structured_tag(
            original_tag, tool_server_with_python
        )

        # Should return the original tag unchanged
        assert result == original_tag

    @pytest.mark.parametrize(
        "tools",
        [
            [],
            ["browser"],
            ["python"],
            ["container"],
            ["browser", "python"],
            ["browser", "container"],
            ["python", "container"],
            ["browser", "python", "container"],
        ],
    )
    def test_json_validity_comprehensive(self, gptoss_parser, tools):
        """Test JSON validity across all possible tool combinations."""
        tool_server = Mock(spec=ToolServer)
        tool_server.has_tool = Mock(side_effect=lambda tool: tool in tools)

        result = gptoss_parser.prepare_structured_tag(None, tool_server)

        # Should be valid JSON
        parsed_result = json.loads(result)

        # Should have the correct structure
        assert parsed_result["type"] == "structural_tag"
        assert "format" in parsed_result
        assert "tags" in parsed_result["format"]
        assert "triggers" in parsed_result["format"]

        # Tag count should be: 1 (analysis) + 2 * len(tools)
        expected_tag_count = 1 + (2 * len(tools))
        assert len(parsed_result["format"]["tags"]) == expected_tag_count
    def test_error_handling_invalid_tool_server(self, gptoss_parser):
        """Test that tool server errors propagate to the caller."""
        # Tool server that raises exceptions
        tool_server = Mock(spec=ToolServer)
        tool_server.has_tool = Mock(side_effect=Exception("Tool server error"))

        # The parser does not swallow tool server failures: the exception
        # propagates rather than falling back to a default tag
        with pytest.raises(Exception, match="Tool server error"):
            gptoss_parser.prepare_structured_tag(None, tool_server)

    def test_concurrent_requests_isolation(self, gptoss_parser):
        """Test that concurrent requests don't interfere with each other."""
        # Simulate concurrent requests with different tool servers
        tool_server_1 = Mock(spec=ToolServer)
        tool_server_1.has_tool = Mock(side_effect=lambda tool: tool == "python")

        tool_server_2 = Mock(spec=ToolServer)
        tool_server_2.has_tool = Mock(side_effect=lambda tool: tool == "browser")

        # Generate tags for both requests
        result_1 = gptoss_parser.prepare_structured_tag(None, tool_server_1)
        result_2 = gptoss_parser.prepare_structured_tag(None, tool_server_2)

        # Parse results
        parsed_1 = json.loads(result_1)
        parsed_2 = json.loads(result_2)

        # Verify they have different tool configurations
        tags_1 = [tag["begin"] for tag in parsed_1["format"]["tags"]]
        tags_2 = [tag["begin"] for tag in parsed_2["format"]["tags"]]

        # Result 1 should have python tags
        assert "<|channel|>commentary to=python" in tags_1
        assert "<|channel|>commentary to=browser" not in tags_1

        # Result 2 should have browser tags
        assert "<|channel|>commentary to=browser" in tags_2
        assert "<|channel|>commentary to=python" not in tags_2

    def test_tag_format_consistency(self, gptoss_parser):
        """Test that all generated tags follow a consistent format."""
        tool_server = Mock(spec=ToolServer)
        tool_server.has_tool = Mock(
            side_effect=lambda tool: tool in ["python", "browser"]
        )

        result = gptoss_parser.prepare_structured_tag(None, tool_server)
        parsed_result = json.loads(result)

        # Verify all tags have the required fields
        for tag in parsed_result["format"]["tags"]:
            assert "begin" in tag
            assert "content" in tag
            assert "end" in tag
            assert tag["content"]["type"] == "any_text"
            assert tag["end"] == "<|end|>"

            # Verify the begin format
            assert tag["begin"].startswith("<|channel|>")

    def test_trigger_configuration(self, gptoss_parser):
        """Test trigger configuration for different tool setups."""
        # Test with no tools
        result_no_tools = gptoss_parser.prepare_structured_tag(None, None)
        parsed_no_tools = json.loads(result_no_tools)
        assert parsed_no_tools["format"]["triggers"] == ["<|channel|>analysis"]

        # Test with tools
        tool_server = Mock(spec=ToolServer)
        tool_server.has_tool = Mock(side_effect=lambda tool: tool == "python")

        result_with_tools = gptoss_parser.prepare_structured_tag(None, tool_server)
        parsed_with_tools = json.loads(result_with_tools)

        expected_triggers = ["<|channel|>analysis", "<|channel|>commentary to="]
        assert set(parsed_with_tools["format"]["triggers"]) == set(expected_triggers)
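
Read together, the assertions in this file pin down the full tag for the python-only case. Reconstructed as a Python literal it would look as follows; the ordering of the "tags" list is not asserted and is illustrative, and stop_after_first is asserted only in the no-tools test and is assumed unchanged here:

    # Reconstructed from the assertions above, not copied from the
    # implementation; list ordering and stop_after_first are assumptions.
    expected_python_only_tag = {
        "type": "structural_tag",
        "format": {
            "type": "triggered_tags",
            "tags": [
                {"begin": "<|channel|>analysis<|message|>",
                 "content": {"type": "any_text"}, "end": "<|end|>"},
                {"begin": "<|channel|>commentary to=python",
                 "content": {"type": "any_text"}, "end": "<|end|>"},
                {"begin": "<|channel|>analysis to=python",
                 "content": {"type": "any_text"}, "end": "<|end|>"},
            ],
            "triggers": ["<|channel|>analysis", "<|channel|>commentary to="],
            "stop_after_first": False,
        },
    }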

tests/v1/entrypoints/llm/test_struct_output_generate.py

Lines changed: 46 additions & 0 deletions
@@ -864,3 +864,49 @@ def test_structured_output_batched_with_non_structured_outputs_requests(
        # non-structured outputs requests should not return a valid JSON here
        with pytest.raises(ValueError):
            output_json = json.loads(generated_text)


@pytest.mark.parametrize("guided_decoding_backend", ["xgrammar"])
def test_structured_output_with_structural_tag(
    monkeypatch: pytest.MonkeyPatch,
    guided_decoding_backend: str,
):
    monkeypatch.setenv("VLLM_USE_V1", "1")

    llm = LLM(
        model="Qwen/Qwen2.5-1.5B-Instruct",
        guided_decoding_backend=guided_decoding_backend,
    )

    structural_tag_config = {
        "type": "structural_tag",
        "format": {
            "type": "triggered_tags",
            "tags": [
                {"begin": "hello_flag", "content": {"type": "any_text"}, "end": "hello"}
            ],
            "triggers": ["hello"],
            "stop_after_first": False,
        },
    }

    sampling_params = SamplingParams(
        temperature=0.0,
        max_tokens=500,
        guided_decoding=StructuredOutputsParams(
            structural_tag=json.dumps(structural_tag_config)
        ),
    )

    prompt = "Hello and repeat hello 10 times, do not say anything else. Only say hello hello hello, now start"
    outputs = llm.generate(prompt, sampling_params=sampling_params, use_tqdm=True)
    assert outputs is not None
    for output in outputs:
        assert output is not None
        assert isinstance(output, RequestOutput)
        prompt = output.prompt
        generated_text = output.outputs[0].text
        assert generated_text is not None
        assert "hello_flag" in generated_text, (
            f"Expected 'hello_flag' to be in generated text, but got: {generated_text}"
        )
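
Why the final assertion expects "hello_flag" rather than just "hello": in xgrammar's structural-tag scheme a trigger is a prefix of a tag's begin string, so once the model emits the trigger text "hello", decoding is constrained to complete a matching tag, i.e. finish the begin "hello_flag", produce free-form content, then emit the end string "hello". That reading of the trigger semantics comes from xgrammar's structural-tag design, not from this diff. A conforming completion could look like:

    # Hypothetical output satisfying the grammar above (illustrative only):
    generated_text = "hello_flag hello hello hello"
    # "hello" triggers the tag, forcing the begin "hello_flag" to be
    # completed; a later "hello" can close the tag as its end string.
    assert "hello_flag" in generated_text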
