From 2da959604d72eced8884f76a9f173aa64f746488 Mon Sep 17 00:00:00 2001
From: Yuxuan Zhang <2448370773@qq.com>
Date: Sat, 20 Dec 2025 12:30:44 +0800
Subject: [PATCH 1/3] [GLM-4.7] GLM-4.7 Tool Parser and Doc Update (#15333)

---
 docs/advanced_features/server_arguments.md | 2 +-
 docs/basic_usage/glm45.md | 6 +-
 .../srt/function_call/function_call_parser.py | 2 +
 .../srt/function_call/glm47_moe_detector.py | 584 ++++++++++++++++++
 .../srt/function_call/glm4_moe_detector.py | 7 +-
 python/sglang/srt/models/glm4_moe.py | 2 +-
 .../test_function_call_parser.py | 275 +++++++++
 7 files changed, 872 insertions(+), 6 deletions(-)
 create mode 100644 python/sglang/srt/function_call/glm47_moe_detector.py

diff --git a/docs/advanced_features/server_arguments.md b/docs/advanced_features/server_arguments.md
index e90b598d932e..19c2ecb9dfe0 100644
--- a/docs/advanced_features/server_arguments.md
+++ b/docs/advanced_features/server_arguments.md
@@ -198,7 +198,7 @@ Please consult the documentation below and [server_args.py](https://github.com/s
 | `--file-storage-path` | The path of the file storage in backend. | `sglang_storage` | Type: str |
 | `--enable-cache-report` | Return number of cached tokens in usage.prompt_tokens_details for each openai request. | `False` | bool flag (set to enable) |
 | `--reasoning-parser` | Specify the parser for reasoning models. Supported parsers: [deepseek-r1, deepseek-v3, glm45, gpt-oss, kimi, qwen3, qwen3-thinking, step3]. | `None` | `deepseek-r1`, `deepseek-v3`, `glm45`, `gpt-oss`, `kimi`, `qwen3`, `qwen3-thinking`, `step3` |
-| `--tool-call-parser` | Specify the parser for handling tool-call interactions. Supported parsers: [deepseekv3, deepseekv31, glm, glm45, gpt-oss, kimi_k2, llama3, mistral, pythonic, qwen, qwen25, qwen3_coder, step3]. | `None` | `deepseekv3`, `deepseekv31`, `glm`, `glm45`, `gpt-oss`, `kimi_k2`, `llama3`, `mistral`, `pythonic`, `qwen`, `qwen25`, `qwen3_coder`, `step3` |
+| `--tool-call-parser` | Specify the parser for handling tool-call interactions. Supported parsers: [deepseekv3, deepseekv31, glm, glm45, glm47, gpt-oss, kimi_k2, llama3, mistral, pythonic, qwen, qwen25, qwen3_coder, step3]. | `None` | `deepseekv3`, `deepseekv31`, `glm`, `glm45`, `glm47`, `gpt-oss`, `kimi_k2`, `llama3`, `mistral`, `pythonic`, `qwen`, `qwen25`, `qwen3_coder`, `step3` |
 | `--sampling-defaults` | Where to get default sampling parameters. 'openai' uses SGLang/OpenAI defaults (temperature=1.0, top_p=1.0, etc.). 'model' uses the model's generation_config.json to get the recommended sampling parameters if available. Default is 'model'. | `model` | `openai`, `model` |
 | `--tool-server` | Either 'demo' or a comma-separated list of tool server urls to use for the model. If not specified, no tool server will be used. | `None` | Type: str |

diff --git a/docs/basic_usage/glm45.md b/docs/basic_usage/glm45.md
index d18b0a68d335..b68f984f90da 100644
--- a/docs/basic_usage/glm45.md
+++ b/docs/basic_usage/glm45.md
@@ -1,4 +1,4 @@
-## Launch GLM-4.5 / GLM-4.6 with SGLang
+## Launch GLM-4.5 / GLM-4.6 / GLM-4.7 with SGLang

 To serve GLM-4.5 / GLM-4.6 FP8 models on 8xH100/H200 GPUs:

@@ -35,7 +35,9 @@ python3 -m sglang.launch_server \
   --enable-custom-logit-processor
 ```

-### Thinking Budget for GLM-4.5 / GLM-4.6
+**Note**: For GLM-4.7, `--tool-call-parser` should be set to `glm47`; for GLM-4.5 and GLM-4.6, it should be set to `glm45`.
+
+### Thinking Budget

 In SGLang, we can implement a thinking budget with `CustomLogitProcessor`.
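+Once the server is running with the matching parser, tool calls can be exercised through the OpenAI-compatible API. A minimal sketch (assuming the server above listens on the default `localhost:30000`; swap in the model id you actually serve):
+
+```python
+from openai import OpenAI
+
+client = OpenAI(base_url="http://localhost:30000/v1", api_key="EMPTY")
+
+tools = [
+    {
+        "type": "function",
+        "function": {
+            "name": "get_weather",
+            "description": "Get weather information",
+            "parameters": {
+                "type": "object",
+                "properties": {
+                    "city": {"type": "string"},
+                    "date": {"type": "string"},
+                },
+                "required": ["city", "date"],
+            },
+        },
+    }
+]
+
+response = client.chat.completions.create(
+    model="glm-4.7",  # placeholder; use the name of the model you launched
+    messages=[{"role": "user", "content": "Weather in Beijing on 2024-06-27?"}],
+    tools=tools,
+)
+# With --tool-call-parser glm47, the parser turns the model's XML-style
+# output into structured tool calls here.
+print(response.choices[0].message.tool_calls)
+```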
diff --git a/python/sglang/srt/function_call/function_call_parser.py b/python/sglang/srt/function_call/function_call_parser.py index 3700ea067af4..0592faf81832 100644 --- a/python/sglang/srt/function_call/function_call_parser.py +++ b/python/sglang/srt/function_call/function_call_parser.py @@ -15,6 +15,7 @@ from sglang.srt.function_call.deepseekv31_detector import DeepSeekV31Detector from sglang.srt.function_call.deepseekv32_detector import DeepSeekV32Detector from sglang.srt.function_call.glm4_moe_detector import Glm4MoeDetector +from sglang.srt.function_call.glm47_moe_detector import Glm47MoeDetector from sglang.srt.function_call.gpt_oss_detector import GptOssDetector from sglang.srt.function_call.kimik2_detector import KimiK2Detector from sglang.srt.function_call.llama32_detector import Llama32Detector @@ -45,6 +46,7 @@ class FunctionCallParser: "deepseekv32": DeepSeekV32Detector, "glm": Glm4MoeDetector, "glm45": Glm4MoeDetector, + "glm47": Glm47MoeDetector, "gpt-oss": GptOssDetector, "longcat": LongCatDetector, "kimi_k2": KimiK2Detector, diff --git a/python/sglang/srt/function_call/glm47_moe_detector.py b/python/sglang/srt/function_call/glm47_moe_detector.py new file mode 100644 index 000000000000..726737c5788a --- /dev/null +++ b/python/sglang/srt/function_call/glm47_moe_detector.py @@ -0,0 +1,584 @@ +import ast +import json +import logging +import re +from enum import Enum +from typing import Any, Dict, List, Optional, Tuple + +from sglang.srt.entrypoints.openai.protocol import Tool +from sglang.srt.function_call.base_format_detector import BaseFormatDetector +from sglang.srt.function_call.core_types import ( + StreamingParseResult, + ToolCallItem, + _GetInfoFunc, +) + +logger = logging.getLogger(__name__) + + +class StreamState(str, Enum): + """State machine states for XML to JSON streaming conversion.""" + + INIT = "INIT" + BETWEEN = "BETWEEN" + IN_KEY = "IN_KEY" + WAITING_VALUE = "WAITING_VALUE" + IN_VALUE = "IN_VALUE" + + +def get_argument_type( + func_name: str, arg_key: str, defined_tools: List[Tool] +) -> Optional[str]: + """Get the expected type of a function argument from tool definitions. + + Args: + func_name: Name of the function/tool + arg_key: Name of the argument + defined_tools: List of available tools + + Returns: + The type string (e.g., 'string', 'number', 'object') or None if not found + """ + name2tool = {tool.function.name: tool for tool in defined_tools} + if func_name not in name2tool: + return None + tool = name2tool[func_name] + properties = (tool.function.parameters or {}).get("properties", {}) + if not isinstance(properties, dict): + properties = {} + if arg_key not in properties: + return None + return properties[arg_key].get("type", None) + + +def _convert_to_number(value: str) -> Any: + """Convert string to appropriate number type (int or float). + + Args: + value: String value to convert + + Returns: + Converted number or original string if conversion fails + """ + try: + if "." in value or "e" in value.lower(): + return float(value) + else: + return int(value) + except (ValueError, AttributeError): + return value + + +def parse_arguments( + json_value: str, arg_type: Optional[str] = None +) -> Tuple[Any, bool]: + """Parse argument value with multiple fallback strategies. + + Args: + json_value: Raw string value to parse + arg_type: Expected type hint ('string', 'number', 'object', etc.) 
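+
+    Illustrative fallbacks (hedged; outputs assume the strategies below fire in order):
+
+        parse_arguments('{"a": 1}')        # -> ({'a': 1}, True)   strategy 1
+        parse_arguments('42', 'number')    # -> (42, True)         strategy 1 + coercion
+        parse_arguments('hello')           # -> ('hello', True)    falls through to strategy 4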
+
+    Returns:
+        Tuple of (parsed_value, is_valid_json)
+    """
+    # Strategy 1: Direct JSON parsing
+    try:
+        parsed_value = json.loads(json_value)
+
+        # Type coercion for number type
+        if arg_type == "number" and isinstance(parsed_value, str):
+            parsed_value = _convert_to_number(parsed_value)
+
+        return parsed_value, True
+    except (json.JSONDecodeError, ValueError):
+        pass
+
+    # Strategy 2: Unescape and parse
+    try:
+        wrapped = json.loads('{"tmp": "' + json_value + '"}')
+        parsed_value = json.loads(wrapped["tmp"])
+
+        if arg_type == "number" and isinstance(parsed_value, str):
+            parsed_value = _convert_to_number(parsed_value)
+
+        return parsed_value, True
+    except (json.JSONDecodeError, ValueError, KeyError):
+        pass
+
+    # Strategy 3: ast.literal_eval
+    try:
+        parsed_value = ast.literal_eval(json_value)
+        return parsed_value, True
+    except (ValueError, SyntaxError):
+        pass
+
+    # Strategy 4: Treat as string
+    try:
+        quoted_value = json.dumps(str(json_value))
+        return json.loads(quoted_value), True
+    except (json.JSONDecodeError, ValueError):
+        return json_value, False
+
+
+class Glm47MoeDetector(BaseFormatDetector):
+    """
+    Detector for GLM-4.7 and GLM-5 models.
+    Assumes function call format:
+        <tool_call>get_weather<arg_key>city</arg_key><arg_value>北京</arg_value><arg_key>date</arg_key><arg_value>2024-06-27</arg_value></tool_call><tool_call>get_weather<arg_key>city</arg_key><arg_value>上海</arg_value><arg_key>date</arg_key><arg_value>2024-06-27</arg_value></tool_call>
+    """
+
+    def __init__(self):
+        super().__init__()
+        self.bot_token = "<tool_call>"
+        self.eot_token = "</tool_call>"
+        self.func_call_regex = r"<tool_call>.*?</tool_call>"
+        self.func_detail_regex = re.compile(
+            r"<tool_call>(.*?)(<arg_key>.*?)?</tool_call>", re.DOTALL
+        )
+        self.func_arg_regex = re.compile(
+            r"<arg_key>(.*?)</arg_key>(?:\\n|\s)*<arg_value>(.*?)</arg_value>",
+            re.DOTALL,
+        )
+        self._last_arguments = ""
+        self.current_tool_id = -1
+        self.current_tool_name_sent = False
+        self._streamed_raw_length = 0
+        self._reset_streaming_state()
+
+    def _reset_streaming_state(self) -> None:
+        """Reset the streaming state machine for a new tool call."""
+        self._stream_state = StreamState.INIT
+        self._current_key = ""
+        self._current_value = ""
+        self._xml_tag_buffer = ""
+        self._is_first_param = True
+        self._value_started = False
+        self._cached_value_type: Optional[str] = (
+            None  # Cache the value type for consistency
+        )
+
+    def has_tool_call(self, text: str) -> bool:
+        """Check if the text contains a GLM-4.7 format tool call."""
+        return self.bot_token in text
+
+    def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult:
+        """
+        One-time parsing: Detects and parses tool calls in the provided text.
+
+        :param text: The complete text to parse.
+        :param tools: List of available tools.
+        :return: ParseResult indicating success or failure, consumed text, leftover text, and parsed calls.
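+
+        Example of the expected wire format (tags as defined by bot_token/eot_token above):
+
+            <tool_call>get_weather<arg_key>city</arg_key><arg_value>Beijing</arg_value><arg_key>date</arg_key><arg_value>2024-06-27</arg_value></tool_call>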
+ """ + idx = text.find(self.bot_token) + normal_text = text[:idx].strip() if idx != -1 else text + if self.bot_token not in text: + return StreamingParseResult(normal_text=normal_text, calls=[]) + match_result_list = re.findall(self.func_call_regex, text, re.DOTALL) + calls = [] + try: + for match_result in match_result_list: + # Get function name + func_detail = self.func_detail_regex.search(match_result) + func_name = func_detail.group(1) + func_args = func_detail.group(2) + arguments = {} + if func_args: + pairs = self.func_arg_regex.findall(func_args) + # Parse arguments using shared method + arguments = self._parse_argument_pairs(pairs, func_name, tools) + + # construct match_result for parse_base_json + match_result = {"name": func_name, "parameters": arguments} + calls.extend(self.parse_base_json(match_result, tools)) + return StreamingParseResult(normal_text=normal_text, calls=calls) + except Exception as e: + logger.error(f"Error in detect_and_parse: {e}", exc_info=True) + # return the normal text if parsing fails + return StreamingParseResult(normal_text=text) + + def _get_value_type(self, func_name: str, key: str, tools: List[Tool]) -> str: + """Get parameter type from tool definition, with fallback to auto-detection. + + Args: + func_name: Name of the function + key: Parameter name + tools: List of available tools + + Returns: + Type string: 'string', 'number', or 'object' + """ + arg_type = get_argument_type(func_name, key, tools) + if arg_type: + return arg_type + + # Auto-detect type from value (best effort) + first_chars = self._current_value.strip()[:10] if self._current_value else "" + if first_chars: + first_char = first_chars[0] + if first_char.isdigit() or first_char in ["-", "."]: + return "number" + elif first_char in ["{", "["]: + return "object" + + return "string" + + def _format_value_complete(self, value: str, value_type: str) -> str: + """Format complete value based on type. + + Args: + value: Raw value string + value_type: Expected type ('string', 'number', 'object') + + Returns: + Properly formatted JSON value string + """ + if value_type == "string": + # Ensure proper JSON string formatting with quotes + return json.dumps(value, ensure_ascii=False) + elif value_type == "number": + try: + num = _convert_to_number(value.strip()) + return str(num) + except (ValueError, AttributeError): + # Fallback to string if not a valid number + logger.warning( + f"Failed to parse '{value}' as number, treating as string" + ) + return json.dumps(str(value), ensure_ascii=False) + else: + # For object/array types, return as-is (should already be valid JSON) + return value + + def _process_xml_to_json_streaming( + self, raw_increment: str, func_name: str, tools: List[Tool] + ) -> str: + """Convert XML increment to JSON streaming output using state machine. + + This method processes XML fragments character by character and converts them + to JSON format incrementally. It maintains state across calls to handle + partial XML tags and values. 
+
+        Args:
+            raw_increment: New XML content to process
+            func_name: Name of the function being called
+            tools: List of available tools for type inference
+
+        Returns:
+            JSON string increment to append to the output
+        """
+        json_output = ""
+
+        for char in raw_increment:
+            self._xml_tag_buffer += char
+
+            if self._stream_state in [StreamState.INIT, StreamState.BETWEEN]:
+                if self._xml_tag_buffer.endswith("<arg_key>"):
+                    self._stream_state = StreamState.IN_KEY
+                    self._current_key = ""
+                    self._xml_tag_buffer = ""
+                    json_output += "{" if self._is_first_param else ", "
+                    self._is_first_param = False
+
+            elif self._stream_state == StreamState.IN_KEY:
+                if self._xml_tag_buffer.endswith("</arg_key>"):
+                    self._current_key = self._xml_tag_buffer[:-10].strip()
+                    self._xml_tag_buffer = ""
+                    self._stream_state = StreamState.WAITING_VALUE
+                    json_output += (
+                        json.dumps(self._current_key, ensure_ascii=False) + ": "
+                    )
+
+            elif self._stream_state == StreamState.WAITING_VALUE:
+                if self._xml_tag_buffer.endswith("<arg_value>"):
+                    self._stream_state = StreamState.IN_VALUE
+                    self._current_value = ""
+                    self._xml_tag_buffer = ""
+                    self._value_started = False
+                    # Determine and cache the value type at the start
+                    self._cached_value_type = self._get_value_type(
+                        func_name, self._current_key, tools
+                    )
+
+            elif self._stream_state == StreamState.IN_VALUE:
+                if self._xml_tag_buffer.endswith("</arg_value>"):
+                    final_value = self._xml_tag_buffer[:-12]
+                    self._current_value += final_value
+
+                    # Use cached value type for consistency
+                    value_type = self._cached_value_type or "string"
+
+                    if self._value_started:
+                        # Output any remaining content
+                        if final_value:
+                            if value_type == "string":
+                                json_output += json.dumps(
+                                    final_value, ensure_ascii=False
+                                )[1:-1]
+                            else:
+                                json_output += final_value
+                        # Always output closing quote for string type when value was started
+                        if value_type == "string":
+                            json_output += '"'
+                    else:
+                        # Value was never started (empty or complete in one chunk)
+                        json_output += self._format_value_complete(
+                            self._current_value, value_type
+                        )
+
+                    self._xml_tag_buffer = ""
+                    self._stream_state = StreamState.BETWEEN
+                    self._current_value = ""
+                    self._value_started = False
+                    self._cached_value_type = None  # Reset cached type
+                else:
+                    closing_tag = "</arg_value>"
+                    is_potential_closing = len(self._xml_tag_buffer) <= len(
+                        closing_tag
+                    ) and closing_tag.startswith(self._xml_tag_buffer)
+
+                    if not is_potential_closing:
+                        content = self._xml_tag_buffer
+                        # Use cached value type for consistency
+                        value_type = self._cached_value_type or "string"
+
+                        if value_type == "string":
+                            if not self._value_started:
+                                json_output += '"'
+                                self._value_started = True
+                            if content:
+                                json_output += json.dumps(content, ensure_ascii=False)[
+                                    1:-1
+                                ]
+                            self._current_value += content
+                            self._xml_tag_buffer = ""
+                        elif value_type == "number":
+                            if content:
+                                if not self._value_started:
+                                    self._value_started = True
+                                json_output += content
+                            self._current_value += content
+                            self._xml_tag_buffer = ""
+                        else:
+                            # For object/array types, output as-is
+                            if content:
+                                if not self._value_started:
+                                    self._value_started = True
+                                json_output += content
+                            self._current_value += content
+                            self._xml_tag_buffer = ""
+
+        return json_output
+
+    def parse_streaming_increment(
+        self, new_text: str, tools: List[Tool]
+    ) -> StreamingParseResult:
+        """
+        Streaming incremental parsing of tool calls in the GLM-4.7 format.
+        Uses a state machine to convert XML to JSON incrementally for true character-by-character streaming.
+        Outputs JSON increments immediately as XML data arrives.
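+
+        Illustrative increments for one call with a string argument:
+
+            ToolCallItem(name='get_weather', parameters='')    # name chunk first
+            parameters: '{'  '"city": '  '"Bei'  'jing'  '"'   # as XML arrives
+            parameters: '}'                                    # on </tool_call>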
+ """ + self._buffer += new_text + current_text = self._buffer + + # Check if we have a tool call + has_tool_call = self.bot_token in current_text + + if not has_tool_call: + # Check if buffer could be the start of a tool call + # Keep buffer if it could be a partial match of bot_token + is_potential_start = any( + self.bot_token.startswith(current_text[-i:]) + for i in range(1, min(len(current_text), len(self.bot_token)) + 1) + ) + + if not is_potential_start: + # Not a potential tool call, return as normal text + # Must return the entire buffer (current_text), not just new_text, + # because buffer may contain previously accumulated characters like '<' + # that turned out not to be part of a tool call + output_text = current_text + self._buffer = "" + if self.eot_token in output_text: + output_text = output_text.replace(self.eot_token, "") + return StreamingParseResult(normal_text=output_text) + else: + # Could be start of tool call, keep buffering + return StreamingParseResult(normal_text="", calls=[]) + + if not hasattr(self, "_tool_indices"): + self._tool_indices = self._get_tool_indices(tools) + + calls: list[ToolCallItem] = [] + try: + # Try to match a partial or complete tool call + partial_match = re.search( + pattern=r"(.*?)(.*?)?(|$)", + string=current_text, + flags=re.DOTALL, + ) + if partial_match: + func_name = partial_match.group(1).strip() + func_args_raw = partial_match.group(2).strip() + is_tool_end = partial_match.group(3) + + # Initialize state if this is the first tool call + if self.current_tool_id == -1: + self.current_tool_id = 0 + self.prev_tool_call_arr = [] + self.streamed_args_for_tool = [""] + self._streamed_raw_length = 0 + self.current_tool_name_sent = False + self._reset_streaming_state() + + # Ensure we have enough entries in our tracking arrays + while len(self.prev_tool_call_arr) <= self.current_tool_id: + self.prev_tool_call_arr.append({}) + while len(self.streamed_args_for_tool) <= self.current_tool_id: + self.streamed_args_for_tool.append("") + + # Send tool name first if not sent yet + if not self.current_tool_name_sent: + assert func_name, "func_name should not be empty" + calls.append( + ToolCallItem( + tool_index=self.current_tool_id, + name=func_name, + parameters="", + ) + ) + self.current_tool_name_sent = True + self._streamed_raw_length = 0 + self._reset_streaming_state() + # Store the tool call info + self.prev_tool_call_arr[self.current_tool_id] = { + "name": func_name, + "arguments": {}, + } + else: + # Process XML to JSON streaming + current_raw_length = len(func_args_raw) + + if current_raw_length > self._streamed_raw_length: + # Get the new raw XML content + raw_increment = func_args_raw[self._streamed_raw_length :] + + # Convert XML increment to JSON increment using state machine + json_increment = self._process_xml_to_json_streaming( + raw_increment, func_name, tools + ) + + if json_increment: + calls.append( + ToolCallItem( + tool_index=self.current_tool_id, + name=None, + parameters=json_increment, + ) + ) + self._last_arguments += json_increment + self.streamed_args_for_tool[ + self.current_tool_id + ] += json_increment + + # Update the streamed length + self._streamed_raw_length = current_raw_length + + if is_tool_end == self.eot_token: + if self._is_first_param: + empty_object = "{}" + calls.append( + ToolCallItem( + tool_index=self.current_tool_id, + name=None, + parameters=empty_object, + ) + ) + self._last_arguments += empty_object + elif not self._last_arguments.endswith("}"): + closing_brace = "}" + calls.append( + 
ToolCallItem( + tool_index=self.current_tool_id, + name=None, + parameters=closing_brace, + ) + ) + self._last_arguments += closing_brace + self.streamed_args_for_tool[ + self.current_tool_id + ] += closing_brace + + try: + pairs = self.func_arg_regex.findall(func_args_raw) + if pairs: + arguments = self._parse_argument_pairs( + pairs, func_name, tools + ) + self.prev_tool_call_arr[self.current_tool_id][ + "arguments" + ] = arguments + except Exception as e: + logger.debug( + f"Failed to parse arguments: {e}", exc_info=True + ) + + # Remove the completed tool call from buffer + self._buffer = current_text[partial_match.end(3) :] + + result = StreamingParseResult(normal_text="", calls=calls) + self.current_tool_id += 1 + self._last_arguments = "" + self.current_tool_name_sent = False + self._streamed_raw_length = 0 + self._reset_streaming_state() + return result + + return StreamingParseResult(normal_text="", calls=calls) + + except Exception as e: + logger.error(f"Error in parse_streaming_increment: {e}", exc_info=True) + return StreamingParseResult(normal_text=current_text) + + def _parse_argument_pairs( + self, pairs: List[Tuple[str, str]], func_name: str, tools: List[Tool] + ) -> Dict[str, Any]: + """Parse argument key-value pairs with type coercion. + + Args: + pairs: List of (key, value) tuples from regex matching + func_name: Name of the function + tools: List of available tools + + Returns: + Dictionary of parsed arguments + """ + arguments = {} + for arg_key, arg_value in pairs: + arg_key = arg_key.strip() + arg_value = arg_value.strip() + arg_type = get_argument_type(func_name, arg_key, tools) + parsed_value, is_good_json = parse_arguments(arg_value, arg_type) + + if arg_type == "string": + # Only convert to string if explicitly defined as string type + if isinstance(parsed_value, str): + arguments[arg_key] = parsed_value + elif isinstance(parsed_value, (dict, list)): + # If parsed as dict/list but schema says string, convert to JSON string + arguments[arg_key] = json.dumps(parsed_value, ensure_ascii=False) + else: + arguments[arg_key] = str(parsed_value) + elif arg_type is None: + # If type is not defined, keep the parsed value as-is + arguments[arg_key] = parsed_value if is_good_json else arg_value + else: + # For other types (number, object, array, etc.), use parsed value + arguments[arg_key] = parsed_value if is_good_json else arg_value + + return arguments + + def supports_structural_tag(self) -> bool: + return False + + def structure_info(self) -> _GetInfoFunc: + raise NotImplementedError() diff --git a/python/sglang/srt/function_call/glm4_moe_detector.py b/python/sglang/srt/function_call/glm4_moe_detector.py index b0fc78249aca..d1e684d2d4bc 100644 --- a/python/sglang/srt/function_call/glm4_moe_detector.py +++ b/python/sglang/srt/function_call/glm4_moe_detector.py @@ -16,9 +16,12 @@ def get_argument_type(func_name: str, arg_key: str, defined_tools: list): if func_name not in name2tool: return None tool = name2tool[func_name] - if arg_key not in tool.function.parameters["properties"]: + properties = (tool.function.parameters or {}).get("properties", {}) + if not isinstance(properties, dict): + properties = {} + if arg_key not in properties: return None - return tool.function.parameters["properties"][arg_key].get("type", None) + return properties[arg_key].get("type", None) def parse_arguments(json_value): diff --git a/python/sglang/srt/models/glm4_moe.py b/python/sglang/srt/models/glm4_moe.py index 6483e41c1c62..c27816c2c93a 100644 --- a/python/sglang/srt/models/glm4_moe.py 
+++ b/python/sglang/srt/models/glm4_moe.py @@ -12,7 +12,7 @@ # limitations under the License. # ============================================================================== -"""Inference-only GLM-4.5, GLM-4.6 model compatible with HuggingFace weights""" +"""Inference-only GLM-4.5, GLM-4.6 and GLM-4.7 model compatible with HuggingFace weights""" import logging from typing import Any, Dict, Iterable, List, Optional, Tuple, Union diff --git a/test/registered/function_call/test_function_call_parser.py b/test/registered/function_call/test_function_call_parser.py index 939edb00196d..a395237ad968 100644 --- a/test/registered/function_call/test_function_call_parser.py +++ b/test/registered/function_call/test_function_call_parser.py @@ -7,6 +7,7 @@ from sglang.srt.function_call.deepseekv3_detector import DeepSeekV3Detector from sglang.srt.function_call.deepseekv32_detector import DeepSeekV32Detector from sglang.srt.function_call.glm4_moe_detector import Glm4MoeDetector +from sglang.srt.function_call.glm47_moe_detector import Glm47MoeDetector from sglang.srt.function_call.json_array_parser import JsonArrayParser from sglang.srt.function_call.kimik2_detector import KimiK2Detector from sglang.srt.function_call.llama32_detector import Llama32Detector @@ -2387,6 +2388,280 @@ def check_single_todos(tool_result, expected): check_single_todos(result, expected_output) +class TestGlm47MoeDetector(unittest.TestCase): + def setUp(self): + self.tools = [ + Tool( + type="function", + function=Function( + name="get_weather", + description="Get weather information", + parameters={ + "type": "object", + "properties": { + "city": {"type": "string", "description": "City name"}, + "date": {"type": "string", "description": "Date"}, + }, + "required": ["city", "date"], + }, + ), + ), + ] + self.detector = Glm47MoeDetector() + + def test_single_tool_call(self): + text = ( + "get_weather" + "cityBeijing" + "date2024-06-27" + "" + ) + result = self.detector.detect_and_parse(text, self.tools) + self.assertEqual(len(result.calls), 1) + self.assertEqual(result.calls[0].name, "get_weather") + self.assertEqual( + result.calls[0].parameters, '{"city": "Beijing", "date": "2024-06-27"}' + ) + self.assertEqual(result.normal_text, "") + + def test_multiple_tool_calls(self): + text = ( + "get_weather" + "cityBeijing" + "date2024-06-27" + "" + "get_weather" + "cityShanghai" + "date2024-06-28" + "" + ) + result = self.detector.detect_and_parse(text, self.tools) + self.assertEqual(len(result.calls), 2) + self.assertEqual(result.calls[0].name, "get_weather") + self.assertEqual( + result.calls[0].parameters, '{"city": "Beijing", "date": "2024-06-27"}' + ) + self.assertEqual(result.calls[1].name, "get_weather") + self.assertEqual( + result.calls[1].parameters, '{"city": "Shanghai", "date": "2024-06-28"}' + ) + self.assertEqual(result.normal_text, "") + + def test_streaming_tool_call(self): + """Test streaming incremental parsing of a tool call.""" + chunks = [ + "get_weather", + "cityBeijing", + "date2024-06-27", + "", + ] + tool_calls = [] + for chunk in chunks: + result = self.detector.parse_streaming_increment(chunk, self.tools) + for tool_call_chunk in result.calls: + if ( + hasattr(tool_call_chunk, "tool_index") + and tool_call_chunk.tool_index is not None + ): + while len(tool_calls) <= tool_call_chunk.tool_index: + tool_calls.append({"name": "", "parameters": ""}) + tc = tool_calls[tool_call_chunk.tool_index] + if tool_call_chunk.name: + tc["name"] = tool_call_chunk.name + if tool_call_chunk.parameters: + tc["parameters"] += 
tool_call_chunk.parameters + self.assertEqual(len(tool_calls), 1) + self.assertEqual(tool_calls[0]["name"], "get_weather") + self.assertEqual( + tool_calls[0]["parameters"], '{"city": "Beijing", "date": "2024-06-27"}' + ) + + def test_streaming_multiple_tool_calls(self): + """Test streaming incremental parsing of multiple tool calls.""" + chunks = [ + "get_weather", + "cityBeijing", + "date2024-06-27", + "get_weather", + "cityShanghai", + "date2024-06-28", + "", + ] + tool_calls = [] + for chunk in chunks: + result = self.detector.parse_streaming_increment(chunk, self.tools) + for tool_call_chunk in result.calls: + if ( + hasattr(tool_call_chunk, "tool_index") + and tool_call_chunk.tool_index is not None + ): + while len(tool_calls) <= tool_call_chunk.tool_index: + tool_calls.append({"name": "", "parameters": ""}) + tc = tool_calls[tool_call_chunk.tool_index] + if tool_call_chunk.name: + tc["name"] = tool_call_chunk.name + if tool_call_chunk.parameters: + tc["parameters"] += tool_call_chunk.parameters + self.assertEqual(len(tool_calls), 2) + self.assertEqual(tool_calls[0]["name"], "get_weather") + self.assertEqual( + tool_calls[0]["parameters"], '{"city": "Beijing", "date": "2024-06-27"}' + ) + self.assertEqual(tool_calls[1]["name"], "get_weather") + self.assertEqual( + tool_calls[1]["parameters"], '{"city": "Shanghai", "date": "2024-06-28"}' + ) + + def test_tool_call_id(self): + """Test that the buffer and state are reset after a tool call is completed.""" + chunks = [ + "get_weather", + "cityBeijing", + "date2024-06-27", + "", + ] + for chunk in chunks: + result = self.detector.parse_streaming_increment(chunk, self.tools) + self.assertEqual(self.detector.current_tool_id, 1) + + def test_invalid_tool_call(self): + """Test that invalid tool calls are handled correctly.""" + text = "invalid_funccityBeijing" + result = self.detector.detect_and_parse(text, self.tools) + self.assertEqual(len(result.calls), 0) + + def test_partial_tool_call(self): + """Test parsing a partial tool call that spans multiple chunks.""" + chunks = [ + "get_weather", + "cityBeijing", + "date2024-06-27", + ] + + tool_calls = [] + for chunk in chunks: + result = self.detector.parse_streaming_increment(chunk, self.tools) + for tool_call_chunk in result.calls: + if ( + hasattr(tool_call_chunk, "tool_index") + and tool_call_chunk.tool_index is not None + ): + while len(tool_calls) <= tool_call_chunk.tool_index: + tool_calls.append({"name": "", "parameters": ""}) + tc = tool_calls[tool_call_chunk.tool_index] + if tool_call_chunk.name: + tc["name"] = tool_call_chunk.name + if tool_call_chunk.parameters: + tc["parameters"] += tool_call_chunk.parameters + + self.assertEqual(len(tool_calls), 1) + self.assertEqual(tool_calls[0]["name"], "get_weather") + self.assertEqual( + tool_calls[0]["parameters"], '{"city": "Beijing", "date": "2024-06-27"}' + ) + + def test_array_argument_with_escaped_json(self): + """Test that array arguments with escaped JSON are properly handled without double-escaping.""" + # Add a tool with array parameter + tools_with_array = [ + Tool( + type="function", + function=Function( + name="todo_write", + description="Write todos", + parameters={ + "type": "object", + "properties": { + "todos": { + "type": "array", + "description": "The updated todo list", + } + }, + "required": ["todos"], + }, + ), + ), + ] + + def check_params(result): + self.assertEqual(1, len(result.calls)) + self.assertEqual("todo_write", result.calls[0].name) + params = json.loads(result.calls[0].parameters) + 
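+            # The todos argument must round-trip as a real JSON list,
+            # not a doubly-escaped string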
self.assertIsInstance(params["todos"], list) + self.assertEqual(4, len(params["todos"])) + self.assertEqual("1", params["todos"][0]["id"]) + self.assertEqual( + "Check for hard-coded issues in the backend code", + params["todos"][0]["task"], + ) + self.assertEqual("in_progress", params["todos"][0]["status"]) + self.assertEqual("2", params["todos"][1]["id"]) + self.assertEqual( + "Check for hard-coded issues in the frontend code", + params["todos"][1]["task"], + ) + self.assertEqual("pending", params["todos"][1]["status"]) + self.assertEqual("3", params["todos"][2]["id"]) + self.assertEqual( + "Check for code violating the Single Responsibility Principle", + params["todos"][2]["task"], + ) + self.assertEqual("pending", params["todos"][2]["status"]) + self.assertEqual("4", params["todos"][3]["id"]) + self.assertEqual( + "Generate a rectification proposal report", params["todos"][3]["task"] + ) + self.assertEqual("pending", params["todos"][3]["status"]) + + # Simulate the raw response from GLM-4.6 model with normal and escaped JSON in XML + result = self.detector.detect_and_parse( + """todo_writetodos[{\"id\": \"1\", \"task\": \"Check for hard-coded issues in the backend code\", \"status\": \"in_progress\"}, {\"id\": \"2\", \"task\": \"Check for hard-coded issues in the frontend code\", \"status\": \"pending\"}, {\"id\": \"3\", \"task\": \"Check for code violating the Single Responsibility Principle\", \"status\": \"pending\"}, {\"id\": \"4\", \"task\": \"Generate a rectification proposal report\", \"status\": \"pending\"}] +""", + tools_with_array, + ) + check_params(result) + result = self.detector.detect_and_parse( + r"""todo_writetodos[{\"id\": \"1\", \"task\": \"Check for hard-coded issues in the backend code\", \"status\": \"in_progress\"}, {\"id\": \"2\", \"task\": \"Check for hard-coded issues in the frontend code\", \"status\": \"pending\"}, {\"id\": \"3\", \"task\": \"Check for code violating the Single Responsibility Principle\", \"status\": \"pending\"}, {\"id\": \"4\", \"task\": \"Generate a rectification proposal report\", \"status\": \"pending\"}] +""", + tools_with_array, + ) + check_params(result) + + def check_single_todos(tool_result, expected): + self.assertEqual(1, len(tool_result.calls)) + self.assertEqual("todo_write", tool_result.calls[0].name) + params = json.loads(tool_result.calls[0].parameters) + self.assertIsInstance(params["todos"], list) + self.assertEqual(1, len(params["todos"])) + self.assertEqual("1", params["todos"][0]["id"]) + self.assertEqual(expected, params["todos"][0]["task"]) + self.assertEqual("pending", params["todos"][0]["status"]) + + # Test with escaped JSON containing backslashes in content (e.g., Windows paths) + expected_path = r"Check file at C:\Users\test.txt" + result = self.detector.detect_and_parse( + """todo_writetodos[{\"id\": \"1\", \"task\": \"Check file at C:\\\\Users\\\\test.txt\", \"status\": \"pending\"}]""", + tools_with_array, + ) + check_single_todos(result, expected_path) + result = self.detector.detect_and_parse( + r"""todo_writetodos[{\"id\": \"1\", \"task\": \"Check file at C:\\\\Users\\\\test.txt\", \"status\": \"pending\"}]""", + tools_with_array, + ) + check_single_todos(result, expected_path) + + # Should contain literal \n, not actual newline + expected_output = r"Print \n to see newline" + result = self.detector.detect_and_parse( + """todo_writetodos[{\"id\": \"1\", \"task\": \"Print \\\\n to see newline\",\"status\": \"pending\"}]""", + tools_with_array, + ) + check_single_todos(result, expected_output) + result = 
self.detector.detect_and_parse( + r"""todo_writetodos[{\"id\": \"1\", \"task\": \"Print \\\\n to see newline\",\"status\": \"pending\"}]""", + tools_with_array, + ) + check_single_todos(result, expected_output) class TestJsonArrayParser(unittest.TestCase): def setUp(self): # Create sample tools for testing From aa19e9a6911fe02d8f0fb4accd055c517d35fa5f Mon Sep 17 00:00:00 2001 From: Leoyzen Date: Wed, 31 Dec 2025 06:32:35 +0800 Subject: [PATCH 2/3] Fix: Handle empty func_name and None values in GLM MoE detectors (#15754) Signed-off-by: Xinyuan Tong Co-authored-by: Xinyuan Tong Co-authored-by: Xinyuan Tong <115166877+JustinTong0323@users.noreply.github.com> --- .../srt/function_call/glm47_moe_detector.py | 435 ++++-- .../srt/function_call/glm4_moe_detector.py | 248 +++- .../test_function_call_parser.py | 15 + .../function_call/test_glm47_moe_detector.py | 1176 +++++++++++++++++ 4 files changed, 1704 insertions(+), 170 deletions(-) create mode 100644 test/registered/function_call/test_glm47_moe_detector.py diff --git a/python/sglang/srt/function_call/glm47_moe_detector.py b/python/sglang/srt/function_call/glm47_moe_detector.py index 726737c5788a..0cf3a1b622e6 100644 --- a/python/sglang/srt/function_call/glm47_moe_detector.py +++ b/python/sglang/srt/function_call/glm47_moe_detector.py @@ -40,15 +40,27 @@ def get_argument_type( The type string (e.g., 'string', 'number', 'object') or None if not found """ name2tool = {tool.function.name: tool for tool in defined_tools} - if func_name not in name2tool: + + # Check if function exists + tool = name2tool.get(func_name) + if not tool: + return None + + # Get parameters safely using getattr + params = getattr(tool.function, "parameters", None) + if not isinstance(params, dict): return None - tool = name2tool[func_name] - properties = (tool.function.parameters or {}).get("properties", {}) + + # Navigate to the type using dict.get() for safe access + properties = params.get("properties") if not isinstance(properties, dict): - properties = {} - if arg_key not in properties: return None - return properties[arg_key].get("type", None) + + arg_spec = properties.get(arg_key) + if isinstance(arg_spec, dict): + return arg_spec.get("type") + + return None def _convert_to_number(value: str) -> Any: @@ -143,6 +155,10 @@ def __init__(self): self.current_tool_id = -1 self.current_tool_name_sent = False self._streamed_raw_length = 0 + self._tool_call_completed = False # Track if tool call has been completed + self._sent_empty_object = ( + False # Track if empty object has been sent for no-arg functions + ) self._reset_streaming_state() def _reset_streaming_state(self) -> None: @@ -156,6 +172,8 @@ def _reset_streaming_state(self) -> None: self._cached_value_type: Optional[str] = ( None # Cache the value type for consistency ) + self._tool_call_completed = False # Reset tool call completion status + self._sent_empty_object = False # Reset empty object sent status def has_tool_call(self, text: str) -> bool: """Check if the text contains a glm-4.5 / glm-4.6 format tool call.""" @@ -169,18 +187,38 @@ def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult :param tools: List of available tools. :return: ParseResult indicating success or failure, consumed text, leftover text, and parsed calls. 
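+
+        Illustrative: for "A <tool_call>…</tool_call> B" the returned
+        normal_text is "A  B" (text before, between, and after tool calls is
+        preserved) and each well-formed call is parsed independently.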
""" - idx = text.find(self.bot_token) - normal_text = text[:idx].strip() if idx != -1 else text if self.bot_token not in text: - return StreamingParseResult(normal_text=normal_text, calls=[]) + return StreamingParseResult(normal_text=text, calls=[]) + + # Extract all normal text (before, between, and after tool calls) + normal_text_parts = [] + last_end = 0 + + # Find all tool call matches + for match in re.finditer(self.func_call_regex, text, re.DOTALL): + # Add text before this tool call + if match.start() > last_end: + normal_text_parts.append(text[last_end : match.start()]) + last_end = match.end() + + # Add any remaining text after the last tool call + if last_end < len(text): + normal_text_parts.append(text[last_end:]) + + # Combine all normal text parts + normal_text = "".join(normal_text_parts).strip() + + # Parse tool calls match_result_list = re.findall(self.func_call_regex, text, re.DOTALL) calls = [] try: for match_result in match_result_list: # Get function name func_detail = self.func_detail_regex.search(match_result) - func_name = func_detail.group(1) - func_args = func_detail.group(2) + if func_detail is None: + continue + func_name = func_detail.group(1) if func_detail.group(1) else "" + func_args = func_detail.group(2) if func_detail.group(2) else "" arguments = {} if func_args: pairs = self.func_arg_regex.findall(func_args) @@ -212,7 +250,11 @@ def _get_value_type(self, func_name: str, key: str, tools: List[Tool]) -> str: return arg_type # Auto-detect type from value (best effort) - first_chars = self._current_value.strip()[:10] if self._current_value else "" + first_chars = ( + self._current_value.strip()[:10] + if self._current_value and self._current_value.strip() + else "" + ) if first_chars: first_char = first_chars[0] if first_char.isdigit() or first_char in ["-", "."]: @@ -237,14 +279,14 @@ def _format_value_complete(self, value: str, value_type: str) -> str: return json.dumps(value, ensure_ascii=False) elif value_type == "number": try: - num = _convert_to_number(value.strip()) + num = _convert_to_number(value.strip() if value else "") return str(num) except (ValueError, AttributeError): # Fallback to string if not a valid number logger.warning( f"Failed to parse '{value}' as number, treating as string" ) - return json.dumps(str(value), ensure_ascii=False) + return json.dumps(str(value) if value else "", ensure_ascii=False) else: # For object/array types, return as-is (should already be valid JSON) return value @@ -369,6 +411,179 @@ def _process_xml_to_json_streaming( return json_output + def _extract_match_groups(self, match: re.Match) -> tuple[str, str, str]: + """Extract function name, arguments and end marker from regex match. + + Args: + match: Regex match object + + Returns: + (func_name, func_args_raw, is_tool_end) + """ + func_name = match.group(1).strip() + func_args_raw = match.group(2).strip() if match.group(2) else "" + is_tool_end = match.group(3) or "" + return func_name, func_args_raw, is_tool_end + + def _send_tool_name_if_needed( + self, func_name: str, has_arg_key: bool, is_tool_end: str + ) -> Optional[ToolCallItem]: + """Send tool name if needed. + + Args: + func_name: Function name + has_arg_key: Whether current text contains Optional[ToolCallItem]: + """Process streaming arguments. 
+ + Args: + func_name: Function name + func_args_raw: Raw argument string + tools: List of available tools + + Returns: + Tool call item with parameter updates or None + """ + current_raw_length = len(func_args_raw) + + if current_raw_length <= self._streamed_raw_length: + return None + + # Get new raw XML content + raw_increment = func_args_raw[self._streamed_raw_length :] + + # Convert XML to JSON using state machine + json_increment = self._process_xml_to_json_streaming( + raw_increment, func_name, tools + ) + + # CRITICAL: Update streamed length BEFORE early return + # Even if json_increment is empty, the input has been consumed by the state machine + self._streamed_raw_length = current_raw_length + + if not json_increment: + return None + + # Update state + self._last_arguments += json_increment + self.streamed_args_for_tool[self.current_tool_id] += json_increment + + return ToolCallItem( + tool_index=self.current_tool_id, + name=None, + parameters=json_increment, + ) + + def _finalize_tool_call( + self, + func_name: str, + func_args_raw: str, + tools: List[Tool], + match_end_pos: int, + current_text: str, + ) -> List[ToolCallItem]: + """Complete tool call processing. + + Args: + func_name: Function name + func_args_raw: Raw argument string + tools: List of available tools + match_end_pos: Match end position + current_text: Current text + + Returns: + List of tool call items to add + """ + calls = [] + + # Handle no-arg function or need to close braces + if self._is_first_param and not self._sent_empty_object: + # No-arg function + calls.append( + ToolCallItem( + tool_index=self.current_tool_id, + name=None, + parameters="{}", + ) + ) + self._last_arguments += "{}" + self.streamed_args_for_tool[self.current_tool_id] += "{}" + self._sent_empty_object = True + elif not self._last_arguments.endswith("}") and not self._sent_empty_object: + # Need to close brace + calls.append( + ToolCallItem( + tool_index=self.current_tool_id, + name=None, + parameters="}", + ) + ) + self._last_arguments += "}" + self.streamed_args_for_tool[self.current_tool_id] += "}" + self._sent_empty_object = True + + # Parse final arguments + if func_args_raw: + try: + pairs = self.func_arg_regex.findall(func_args_raw) + if pairs: + arguments = self._parse_argument_pairs(pairs, func_name, tools) + self.prev_tool_call_arr[self.current_tool_id][ + "arguments" + ] = arguments + except Exception as e: + logger.debug(f"Failed to parse arguments: {e}", exc_info=True) + + # Clean buffer + self._buffer = current_text[match_end_pos:] + + # Reset state for next tool call + self._tool_call_completed = True + self.current_tool_id += 1 + self._last_arguments = "" + self.current_tool_name_sent = False + self._streamed_raw_length = 0 + self._reset_streaming_state() + + return calls + def parse_streaming_increment( self, new_text: str, tools: List[Tool] ) -> StreamingParseResult: @@ -405,140 +620,96 @@ def parse_streaming_increment( # Could be start of tool call, keep buffering return StreamingParseResult(normal_text="", calls=[]) + # Extract any text before the first bot_token and return it as normal_text + normal_text = "" + first_bot_token_idx = current_text.find(self.bot_token) + if first_bot_token_idx > 0: + normal_text = current_text[:first_bot_token_idx] + current_text = current_text[first_bot_token_idx:] + # Update buffer to only include from the bot token onwards + self._buffer = current_text + if not hasattr(self, "_tool_indices"): self._tool_indices = self._get_tool_indices(tools) calls: list[ToolCallItem] = [] try: # Try 
to match a partial or complete tool call + # Use a single flexible regex pattern that handles all cases partial_match = re.search( - pattern=r"(.*?)(.*?)?(|$)", - string=current_text, - flags=re.DOTALL, + r"(.*?)(?:()|$)", + current_text, + re.DOTALL, ) - if partial_match: - func_name = partial_match.group(1).strip() - func_args_raw = partial_match.group(2).strip() - is_tool_end = partial_match.group(3) - - # Initialize state if this is the first tool call - if self.current_tool_id == -1: - self.current_tool_id = 0 - self.prev_tool_call_arr = [] - self.streamed_args_for_tool = [""] - self._streamed_raw_length = 0 - self.current_tool_name_sent = False - self._reset_streaming_state() - - # Ensure we have enough entries in our tracking arrays - while len(self.prev_tool_call_arr) <= self.current_tool_id: - self.prev_tool_call_arr.append({}) - while len(self.streamed_args_for_tool) <= self.current_tool_id: - self.streamed_args_for_tool.append("") - - # Send tool name first if not sent yet - if not self.current_tool_name_sent: - assert func_name, "func_name should not be empty" - calls.append( - ToolCallItem( - tool_index=self.current_tool_id, - name=func_name, - parameters="", - ) - ) - self.current_tool_name_sent = True - self._streamed_raw_length = 0 - self._reset_streaming_state() - # Store the tool call info - self.prev_tool_call_arr[self.current_tool_id] = { - "name": func_name, - "arguments": {}, - } - else: - # Process XML to JSON streaming - current_raw_length = len(func_args_raw) - if current_raw_length > self._streamed_raw_length: - # Get the new raw XML content - raw_increment = func_args_raw[self._streamed_raw_length :] + if not partial_match: + return StreamingParseResult(normal_text=normal_text, calls=[]) - # Convert XML increment to JSON increment using state machine - json_increment = self._process_xml_to_json_streaming( - raw_increment, func_name, tools - ) + # Extract match groups using helper method + func_name, func_args_raw, is_tool_end = self._extract_match_groups( + partial_match + ) - if json_increment: - calls.append( - ToolCallItem( - tool_index=self.current_tool_id, - name=None, - parameters=json_increment, - ) - ) - self._last_arguments += json_increment - self.streamed_args_for_tool[ - self.current_tool_id - ] += json_increment - - # Update the streamed length - self._streamed_raw_length = current_raw_length - - if is_tool_end == self.eot_token: - if self._is_first_param: - empty_object = "{}" - calls.append( - ToolCallItem( - tool_index=self.current_tool_id, - name=None, - parameters=empty_object, - ) - ) - self._last_arguments += empty_object - elif not self._last_arguments.endswith("}"): - closing_brace = "}" - calls.append( - ToolCallItem( - tool_index=self.current_tool_id, - name=None, - parameters=closing_brace, - ) - ) - self._last_arguments += closing_brace - self.streamed_args_for_tool[ - self.current_tool_id - ] += closing_brace - - try: - pairs = self.func_arg_regex.findall(func_args_raw) - if pairs: - arguments = self._parse_argument_pairs( - pairs, func_name, tools - ) - self.prev_tool_call_arr[self.current_tool_id][ - "arguments" - ] = arguments - except Exception as e: - logger.debug( - f"Failed to parse arguments: {e}", exc_info=True - ) - - # Remove the completed tool call from buffer - self._buffer = current_text[partial_match.end(3) :] - - result = StreamingParseResult(normal_text="", calls=calls) - self.current_tool_id += 1 - self._last_arguments = "" - self.current_tool_name_sent = False - self._streamed_raw_length = 0 - 
self._reset_streaming_state() - return result - - return StreamingParseResult(normal_text="", calls=calls) + # Initialize tool call state if needed (keeping existing logic) + if self.current_tool_id == -1: + self.current_tool_id = 0 + self.prev_tool_call_arr = [] + self.streamed_args_for_tool = [""] + self._streamed_raw_length = 0 + self.current_tool_name_sent = False # Reset for new tool call + self._reset_streaming_state() + # Check if this is a continuation of an existing tool call or a new one + elif not self.current_tool_name_sent: + # Only increment tool_id if we're truly starting a NEW tool call + # Don't increment if this is just the first time we're processing + # a tool call that was received in the buffer + # The key insight: only increment when we've COMPLETED a previous tool call + # and now see another bot_token in new_text + pass # Remove the problematic auto-increment logic + + # Ensure tracking arrays are large enough (keeping existing logic) + while len(self.prev_tool_call_arr) <= self.current_tool_id: + self.prev_tool_call_arr.append({}) + while len(self.streamed_args_for_tool) <= self.current_tool_id: + self.streamed_args_for_tool.append("") + + # Determine if function name is complete by checking for in the full text + # This is important for streaming scenarios where args come in later chunks + has_arg_key = " Dict[str, Any]: diff --git a/python/sglang/srt/function_call/glm4_moe_detector.py b/python/sglang/srt/function_call/glm4_moe_detector.py index d1e684d2d4bc..3158e32166c3 100644 --- a/python/sglang/srt/function_call/glm4_moe_detector.py +++ b/python/sglang/srt/function_call/glm4_moe_detector.py @@ -87,8 +87,10 @@ def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult for match_result in match_result_list: # Get function name func_detail = self.func_detail_regex.search(match_result) - func_name = func_detail.group(1) - func_args = func_detail.group(2) + if func_detail is None: + continue + func_name = func_detail.group(1) if func_detail.group(1) else "" + func_args = func_detail.group(2) if func_detail.group(2) else "" pairs = self.func_arg_regex.findall(func_args) arguments = {} for arg_key, arg_value in pairs: @@ -112,47 +114,217 @@ def parse_streaming_increment( ) -> StreamingParseResult: """ Streaming incremental parsing tool calls for GLM-4.5 and GLM-4.6 format. + Uses a state machine to convert XML to JSON incrementally for true character-by-character streaming. + Outputs JSON increments immediately as XML data arrives. 
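+
+        Illustrative wire format for GLM-4.5 / GLM-4.6 (function name, then a
+        newline, then key/value tags):
+
+            <tool_call>get_weather
+            <arg_key>city</arg_key><arg_value>Beijing</arg_value></tool_call>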
""" self._buffer += new_text current_text = self._buffer - start = current_text.find(self.bot_token) - if start == -1: - self._buffer = "" - if self.current_tool_id > 0: - current_text = "" - return StreamingParseResult(normal_text=current_text) - # find ensures we find the first self.eot_token so there will be at most one tool_call in current_text[:end+len(self.eot_token) - end = current_text.find(self.eot_token) - if end != -1: - # Initialize state if this is the first tool call - if self.current_tool_id == -1: - self.current_tool_id = 0 - self.prev_tool_call_arr = [] - self.streamed_args_for_tool = [""] - # Ensure we have enough entries in our tracking arrays - while len(self.prev_tool_call_arr) <= self.current_tool_id: - self.prev_tool_call_arr.append({}) - while len(self.streamed_args_for_tool) <= self.current_tool_id: - self.streamed_args_for_tool.append("") - result = self.detect_and_parse( - current_text[: end + len(self.eot_token)], tools=tools + # Check if we have a tool call + has_tool_call = self.bot_token in current_text + + if not has_tool_call: + # Check if buffer could be the start of a tool call + # Keep buffer if it could be a partial match of bot_token + is_potential_start = any( + self.bot_token.startswith(current_text[-i:]) + for i in range(1, min(len(current_text), len(self.bot_token)) + 1) + ) + + if not is_potential_start: + # Not a potential tool call, return as normal text + # Must return the entire buffer (current_text), not just new_text, + # because buffer may contain previously accumulated characters like '<' + # that turned out not to be part of a tool call + output_text = current_text + self._buffer = "" + if self.eot_token in output_text: + output_text = output_text.replace(self.eot_token, "") + return StreamingParseResult(normal_text=output_text) + else: + # Could be start of tool call, keep buffering + return StreamingParseResult(normal_text="", calls=[]) + + if not hasattr(self, "_tool_indices"): + self._tool_indices = self._get_tool_indices(tools) + + calls: list[ToolCallItem] = [] + try: + # Try to match a partial or complete tool call + partial_match = re.search( + pattern=r"(.*?)(?:\\n|\n)(.*?)(|$)", + string=current_text, + flags=re.DOTALL, ) - if result.calls: - self.prev_tool_call_arr[self.current_tool_id] = { - "name": result.calls[0].name, - "arguments": json.loads(result.calls[0].parameters), - } - self.streamed_args_for_tool[self.current_tool_id] = result.calls[ - 0 - ].parameters - result.calls[0].tool_index = self.current_tool_id - self.current_tool_id += 1 - self._buffer = current_text[end + len(self.eot_token) :] - return result - normal_text = current_text[:start] - self._buffer = current_text[start:] - return StreamingParseResult(normal_text=normal_text) + if partial_match: + func_name_raw = partial_match.group(1) + func_args_raw = partial_match.group(2) + is_tool_end = partial_match.group(3) + + # Only proceed if we have a non-empty function name + if func_name_raw is None or not func_name_raw.strip(): + # If we only have the start token without a function name, + # continue buffering until we get more content + return StreamingParseResult(normal_text="", calls=[]) + + func_name = func_name_raw.strip() + func_args_raw = func_args_raw.strip() if func_args_raw else "" + + # Initialize state if this is the first tool call + if self.current_tool_id == -1: + self.current_tool_id = 0 + self.prev_tool_call_arr = [] + self.streamed_args_for_tool = [""] + self._streamed_raw_length = 0 + self.current_tool_name_sent = False + 
self._reset_streaming_state() + + # Ensure we have enough entries in our tracking arrays + while len(self.prev_tool_call_arr) <= self.current_tool_id: + self.prev_tool_call_arr.append({}) + while len(self.streamed_args_for_tool) <= self.current_tool_id: + self.streamed_args_for_tool.append("") + + # Send tool name first if not sent yet + if not self.current_tool_name_sent: + calls.append( + ToolCallItem( + tool_index=self.current_tool_id, + name=func_name, + parameters="", + ) + ) + self.current_tool_name_sent = True + self._streamed_raw_length = 0 + self._reset_streaming_state() + # Store the tool call info + self.prev_tool_call_arr[self.current_tool_id] = { + "name": func_name, + "arguments": {}, + } + else: + # Process XML to JSON streaming + current_raw_length = len(func_args_raw) + + if current_raw_length > self._streamed_raw_length: + # Get the new raw XML content + raw_increment = func_args_raw[self._streamed_raw_length :] + + # Convert XML increment to JSON increment using state machine + json_increment = self._process_xml_to_json_streaming( + raw_increment, func_name, tools + ) + + # CRITICAL: Update streamed length BEFORE checking json_increment + # Even if json_increment is empty, the input has been consumed by the state machine + self._streamed_raw_length = current_raw_length + + if json_increment: + calls.append( + ToolCallItem( + tool_index=self.current_tool_id, + name=None, + parameters=json_increment, + ) + ) + self._last_arguments += json_increment + self.streamed_args_for_tool[ + self.current_tool_id + ] += json_increment + + if is_tool_end == self.eot_token: + if self._is_first_param: + empty_object = "{}" + calls.append( + ToolCallItem( + tool_index=self.current_tool_id, + name=None, + parameters=empty_object, + ) + ) + self._last_arguments += empty_object + elif not self._last_arguments.endswith("}"): + closing_brace = "}" + calls.append( + ToolCallItem( + tool_index=self.current_tool_id, + name=None, + parameters=closing_brace, + ) + ) + self._last_arguments += closing_brace + self.streamed_args_for_tool[ + self.current_tool_id + ] += closing_brace + + try: + pairs = self.func_arg_regex.findall(func_args_raw) + if pairs: + arguments = self._parse_argument_pairs( + pairs, func_name, tools + ) + self.prev_tool_call_arr[self.current_tool_id][ + "arguments" + ] = arguments + except Exception as e: + logger.debug( + f"Failed to parse arguments: {e}", exc_info=True + ) + + # Remove the completed tool call from buffer + self._buffer = current_text[partial_match.end(3) :] + + result = StreamingParseResult(normal_text="", calls=calls) + self.current_tool_id += 1 + self._last_arguments = "" + self.current_tool_name_sent = False + self._streamed_raw_length = 0 + self._reset_streaming_state() + return result + + return StreamingParseResult(normal_text="", calls=calls) + + except Exception as e: + logger.error(f"Error in parse_streaming_increment: {e}", exc_info=True) + return StreamingParseResult(normal_text=current_text) + + def _parse_argument_pairs( + self, pairs: List[Tuple[str, str]], func_name: str, tools: List[Tool] + ) -> Dict[str, Any]: + """Parse argument key-value pairs with type coercion. 
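+
+        Illustrative: a value whose schema type is "string" but which parses
+        to a dict/list is re-serialized with json.dumps; with no schema type,
+        the parsed value is kept only when it was valid JSON.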
+ + Args: + pairs: List of (key, value) tuples from regex matching + func_name: Name of the function + tools: List of available tools + + Returns: + Dictionary of parsed arguments + """ + arguments = {} + for arg_key, arg_value in pairs: + arg_key = arg_key.strip() + arg_value = arg_value.strip() + arg_type = get_argument_type(func_name, arg_key, tools) + parsed_value, is_good_json = parse_arguments(arg_value, arg_type) + + if arg_type == "string": + # Only convert to string if explicitly defined as string type + if isinstance(parsed_value, str): + arguments[arg_key] = parsed_value + elif isinstance(parsed_value, (dict, list)): + # If parsed as dict/list but schema says string, convert to JSON string + arguments[arg_key] = json.dumps(parsed_value, ensure_ascii=False) + else: + arguments[arg_key] = str(parsed_value) + elif arg_type is None: + # If type is not defined, keep the parsed value as-is + arguments[arg_key] = parsed_value if is_good_json else arg_value + else: + # For other types (number, object, array, etc.), use parsed value + arguments[arg_key] = parsed_value if is_good_json else arg_value + + return arguments def supports_structural_tag(self) -> bool: return False diff --git a/test/registered/function_call/test_function_call_parser.py b/test/registered/function_call/test_function_call_parser.py index a395237ad968..81cabe07fa3a 100644 --- a/test/registered/function_call/test_function_call_parser.py +++ b/test/registered/function_call/test_function_call_parser.py @@ -2387,6 +2387,21 @@ def check_single_todos(tool_result, expected): ) check_single_todos(result, expected_output) + def test_empty_function_name_handling(self): + """Test that empty function name is handled gracefully without assertion error.""" + # This test simulates the issue where the model outputs only the start token without a function name + chunks = [ + "", # Start token only, no function name yet + "\n", # More content without function name + ] + + for chunk in chunks: + # Should not raise AssertionError: func_name should not be empty + result = self.detector.parse_streaming_increment(chunk, self.tools) + # Should return empty calls without error + self.assertIsInstance(result, StreamingParseResult) + self.assertEqual(result.calls, []) + class TestGlm47MoeDetector(unittest.TestCase): def setUp(self): diff --git a/test/registered/function_call/test_glm47_moe_detector.py b/test/registered/function_call/test_glm47_moe_detector.py new file mode 100644 index 000000000000..a046964064ce --- /dev/null +++ b/test/registered/function_call/test_glm47_moe_detector.py @@ -0,0 +1,1176 @@ +import json +import unittest + +from sglang.srt.entrypoints.openai.protocol import Function, Tool +from sglang.srt.function_call.core_types import StreamingParseResult +from sglang.srt.function_call.glm47_moe_detector import Glm47MoeDetector +from sglang.test.ci.ci_register import register_cpu_ci + +register_cpu_ci(1.0, "default") + + +class TestGlm47MoeDetector(unittest.TestCase): + def setUp(self): + self.tools = [ + Tool( + type="function", + function=Function( + name="get_weather", + description="Get weather information", + parameters={ + "type": "object", + "properties": { + "city": {"type": "string", "description": "City name"}, + "date": {"type": "string", "description": "Date"}, + }, + "required": ["city", "date"], + }, + ), + ), + ] + self.detector = Glm47MoeDetector() + + # ==================== Basic Parsing Tests (5) ==================== + + def test_single_tool_call(self): + """ + Test basic single tool call parsing. 
+
+        Scenario: Parse a complete tool call with two string parameters in a single text block.
+        Purpose: Verify the detector can correctly identify and extract the function name and parameters
+        from a simple, well-formed tool call.
+        """
+        text = (
+            "<tool_call>get_weather"
+            "<arg_key>city</arg_key><arg_value>Beijing</arg_value>"
+            "<arg_key>date</arg_key><arg_value>2024-06-27</arg_value>"
+            "</tool_call>"
+        )
+        result = self.detector.detect_and_parse(text, self.tools)
+        self.assertEqual(len(result.calls), 1)
+        self.assertEqual(result.calls[0].name, "get_weather")
+        self.assertEqual(
+            result.calls[0].parameters, '{"city": "Beijing", "date": "2024-06-27"}'
+        )
+        self.assertEqual(result.normal_text, "")
+
+    def test_multiple_tool_calls(self):
+        """
+        Test parsing multiple consecutive tool calls.
+
+        Scenario: Parse two complete tool calls back-to-back without any text in between.
+        Purpose: Verify the detector correctly handles multiple tool calls and resets state
+        between calls to avoid parameter leakage or ID conflicts.
+        """
+        text = (
+            "<tool_call>get_weather"
+            "<arg_key>city</arg_key><arg_value>Beijing</arg_value>"
+            "<arg_key>date</arg_key><arg_value>2024-06-27</arg_value>"
+            "</tool_call>"
+            "<tool_call>get_weather"
+            "<arg_key>city</arg_key><arg_value>Shanghai</arg_value>"
+            "<arg_key>date</arg_key><arg_value>2024-06-28</arg_value>"
+            "</tool_call>"
+        )
+        result = self.detector.detect_and_parse(text, self.tools)
+        self.assertEqual(len(result.calls), 2)
+        self.assertEqual(result.calls[0].name, "get_weather")
+        self.assertEqual(
+            result.calls[0].parameters, '{"city": "Beijing", "date": "2024-06-27"}'
+        )
+        self.assertEqual(result.calls[1].name, "get_weather")
+        self.assertEqual(
+            result.calls[1].parameters, '{"city": "Shanghai", "date": "2024-06-28"}'
+        )
+        self.assertEqual(result.normal_text, "")
+
+    def test_no_arg_function_non_streaming(self):
+        """
+        Test a no-argument function call without streaming.
+
+        Scenario: Parse a tool call for a function that has no parameters (empty properties).
+        Purpose: Verify the detector generates a single empty object "{}" for no-argument functions
+        and does not duplicate empty parameter objects.
+        """
+        tools_with_no_args = [
+            Tool(
+                type="function",
+                function=Function(
+                    name="list_filenames",
+                    description="List filenames",
+                    parameters={
+                        "type": "object",
+                        "properties": {},
+                    },
+                ),
+            ),
+        ]
+
+        text = "<tool_call>list_filenames</tool_call>"
+        result = self.detector.detect_and_parse(text, tools_with_no_args)
+
+        self.assertEqual(len(result.calls), 1)
+        self.assertEqual(result.calls[0].name, "list_filenames")
+        params = json.loads(result.calls[0].parameters)
+        self.assertEqual(params, {})
+
+    def test_invalid_tool_call(self):
+        """
+        Test handling of invalid tool calls.
+
+        Scenario: Attempt to parse a tool call with a function name that doesn't exist in the tool list.
+        Purpose: Verify the detector gracefully rejects invalid function calls and returns no calls
+        rather than throwing an error or accepting invalid input.
+        """
+        text = "<tool_call>invalid_func<arg_key>city</arg_key><arg_value>Beijing</arg_value></tool_call>"
+        result = self.detector.detect_and_parse(text, self.tools)
+        self.assertEqual(len(result.calls), 0)
+
+    def test_array_argument_with_escaped_json(self):
+        """
+        Test array arguments containing escaped JSON strings.
+
+        Scenario: Parse tool calls with array parameters containing nested JSON objects with
+        escaped quotes (both backslash-escaped and raw escaped strings).
+        Purpose: Verify the detector properly handles JSON escaping without double-escaping,
+        preserving special characters like backslashes in paths and newline sequences.
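+
+        Illustrative expectation (abridged from the inputs below): an escaped
+        payload such as [{\"id\": \"1\", ...}] must decode to a list of dicts,
+        not to a double-escaped JSON string.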
+ """ + tools_with_array = [ + Tool( + type="function", + function=Function( + name="todo_write", + description="Write todos", + parameters={ + "type": "object", + "properties": { + "todos": { + "type": "array", + "description": "The updated todo list", + } + }, + "required": ["todos"], + }, + ), + ), + ] + + def check_params(result): + self.assertEqual(1, len(result.calls)) + self.assertEqual("todo_write", result.calls[0].name) + params = json.loads(result.calls[0].parameters) + self.assertIsInstance(params["todos"], list) + self.assertEqual(4, len(params["todos"])) + self.assertEqual("1", params["todos"][0]["id"]) + self.assertEqual( + "Check for hard-coded issues in the backend code", + params["todos"][0]["task"], + ) + self.assertEqual("in_progress", params["todos"][0]["status"]) + self.assertEqual("2", params["todos"][1]["id"]) + self.assertEqual( + "Check for hard-coded issues in the frontend code", + params["todos"][1]["task"], + ) + self.assertEqual("pending", params["todos"][1]["status"]) + self.assertEqual("3", params["todos"][2]["id"]) + self.assertEqual( + "Check for code violating the Single Responsibility Principle", + params["todos"][2]["task"], + ) + self.assertEqual("pending", params["todos"][2]["status"]) + self.assertEqual("4", params["todos"][3]["id"]) + self.assertEqual( + "Generate a rectification proposal report", params["todos"][3]["task"] + ) + self.assertEqual("pending", params["todos"][3]["status"]) + + # Test with normal escaped JSON in XML + result = self.detector.detect_and_parse( + """todo_writetodos[{\"id\": \"1\", \"task\": \"Check for hard-coded issues in the backend code\", \"status\": \"in_progress\"}, {\"id\": \"2\", \"task\": \"Check for hard-coded issues in the frontend code\", \"status\": \"pending\"}, {\"id\": \"3\", \"task\": \"Check for code violating the Single Responsibility Principle\", \"status\": \"pending\"}, {\"id\": \"4\", \"task\": \"Generate a rectification proposal report\", \"status\": \"pending\"}] +""", + tools_with_array, + ) + check_params(result) + + # Test with raw string escaped JSON + result = self.detector.detect_and_parse( + r"""todo_writetodos[{\"id\": \"1\", \"task\": \"Check for hard-coded issues in the backend code\", \"status\": \"in_progress\"}, {\"id\": \"2\", \"task\": \"Check for hard-coded issues in the frontend code\", \"status\": \"pending\"}, {\"id\": \"3\", \"task\": \"Check for code violating the Single Responsibility Principle\", \"status\": \"pending\"}, {\"id\": \"4\", \"task\": \"Generate a rectification proposal report\", \"status\": \"pending\"}] +""", + tools_with_array, + ) + check_params(result) + + def check_single_todos(tool_result, expected): + self.assertEqual(1, len(tool_result.calls)) + self.assertEqual("todo_write", tool_result.calls[0].name) + params = json.loads(tool_result.calls[0].parameters) + self.assertIsInstance(params["todos"], list) + self.assertEqual(1, len(params["todos"])) + self.assertEqual("1", params["todos"][0]["id"]) + self.assertEqual(expected, params["todos"][0]["task"]) + self.assertEqual("pending", params["todos"][0]["status"]) + + # Test with escaped backslashes (Windows paths) + expected_path = r"Check file at C:\Users\test.txt" + result = self.detector.detect_and_parse( + """todo_writetodos[{\"id\": \"1\", \"task\": \"Check file at C:\\\\Users\\\\test.txt\", \"status\": \"pending\"}]""", + tools_with_array, + ) + check_single_todos(result, expected_path) + + # Test with literal backslash-n (not newline) + expected_output = r"Print \n to see newline" + result = 
self.detector.detect_and_parse(
+            """<tool_call>todo_write<arg_key>todos</arg_key><arg_value>[{\"id\": \"1\", \"task\": \"Print \\\\n to see newline\",\"status\": \"pending\"}]</arg_value></tool_call>""",
+            tools_with_array,
+        )
+        check_single_todos(result, expected_output)
+
+    # ==================== MTP Core Scenarios (3) ====================
+
+    def test_mtp_func_and_string_split(self):
+        """
+        Test MTP-style function name and string parameter value splitting across chunks.
+
+        Scenario: Simulate multi-token prediction (MTP) behavior, where function names and string
+        parameter values are split mid-word across multiple chunks.
+        Purpose: This is the MOST CRITICAL test - verify the detector correctly reassembles:
+        - Function name split as "create_ta" + "sk"
+        - String values split as "Go to Bei" + "jing" and "San Fran" + "cisco"
+        These splits mimic real MTP output where tokenization breaks words arbitrarily.
+        """
+        tools = [
+            Tool(
+                type="function",
+                function=Function(
+                    name="create_task",
+                    parameters={
+                        "type": "object",
+                        "properties": {
+                            "title": {"type": "string"},
+                            "location": {"type": "string"},
+                        },
+                    },
+                ),
+            ),
+        ]
+
+        chunks = [
+            "I'll create a task.",  # normal text before tool call
+            "<tool_call>create_ta",  # function name split mid-word
+            "sk<arg_key>title</arg_key><arg_value>Go to Bei",  # function name completes, param value splits
+            "jing</arg_value>",  # first parameter value completes
+            "<arg_key>location</arg_key><arg_value>San Fran",  # second parameter value splits
+            "cisco</arg_value></tool_call>",  # second parameter and tool call complete
+        ]
+
+        detector = Glm47MoeDetector()
+        all_calls = []
+        all_normal_text = ""
+
+        for chunk in chunks:
+            result = detector.parse_streaming_increment(chunk, tools)
+            all_calls.extend(result.calls)
+            all_normal_text += result.normal_text
+
+        # Verify normal text is preserved
+        self.assertEqual(all_normal_text, "I'll create a task.")
+
+        # Verify function call
+        func_calls = [c for c in all_calls if c.name]
+        self.assertEqual(len(func_calls), 1)
+        self.assertEqual(
+            func_calls[0].name, "create_task"
+        )  # "create_ta" + "sk" reassembled
+
+        # Verify parameter reassembly
+        full_params = "".join([c.parameters for c in all_calls if c.parameters])
+        params = json.loads(full_params)
+        self.assertEqual(
+            params["title"], "Go to Beijing"
+        )  # "Go to Bei" + "jing" reassembled
+        self.assertEqual(
+            params["location"], "San Francisco"
+        )  # "San Fran" + "cisco" reassembled
+
+    def test_mtp_noarg_and_multiple_calls(self):
+        """
+        Test an MTP-style no-argument function and multiple tool calls with state reset.
+
+        Scenario: Stream a no-argument function call followed by a regular function call,
+        simulating MTP's output pattern where function completion triggers state reset.
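+
+        Illustrative stream (mirrors the chunks below): "<tool_call>list_files</tool_call>"
+        arrives as a single chunk and must yield exactly one call named "list_files"
+        whose accumulated parameters are "{}".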
+        Purpose: Verify:
+        - No-argument functions emit exactly ONE empty object "{}", not duplicates
+        - State properly resets between consecutive tool calls (tool_index increments)
+        - The second tool call doesn't inherit parameters from the first call
+        """
+        tools = [
+            Tool(
+                type="function",
+                function=Function(
+                    name="list_files",
+                    parameters={
+                        "type": "object",
+                        "properties": {},
+                    },
+                ),
+            ),
+            Tool(
+                type="function",
+                function=Function(
+                    name="get_weather",
+                    parameters={
+                        "type": "object",
+                        "properties": {
+                            "city": {"type": "string"},
+                        },
+                    },
+                ),
+            ),
+        ]
+
+        chunks = [
+            "<tool_call>list_files</tool_call>",  # no-arg function, complete in one chunk
+            "<tool_call>get_weather<arg_key>city</arg_key><arg_value>Beijing</arg_value></tool_call>",
+        ]
+
+        detector = Glm47MoeDetector()
+        all_calls = []
+
+        for chunk in chunks:
+            result = detector.parse_streaming_increment(chunk, tools)
+            all_calls.extend(result.calls)
+
+        # Verify two distinct tool calls
+        func_calls = [c for c in all_calls if c.name]
+        self.assertEqual(len(func_calls), 2)
+        self.assertEqual(func_calls[0].name, "list_files")
+        self.assertEqual(func_calls[1].name, "get_weather")
+
+        # Verify no duplicate empty objects for the no-arg function
+        empty_object_calls = [c for c in all_calls if c.parameters == "{}"]
+        self.assertLessEqual(
+            len(empty_object_calls),
+            1,
+            "No-argument function should emit at most one empty object",
+        )
+
+        # Verify the second call has correct parameters
+        weather_params = [
+            c.parameters for c in all_calls if c.parameters and c.parameters != "{}"
+        ]
+        if weather_params:
+            full_params = "".join(weather_params)
+            params = json.loads(full_params)
+            self.assertEqual(params["city"], "Beijing")
+
+    def test_mtp_number_and_complex_json(self):
+        """
+        Test MTP-style number parameters and complex JSON array splitting.
+
+        Scenario: Parse tool calls with number parameters (int and float) and JSON arrays
+        split across chunks, including splits within the JSON structure.
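+
+        Illustrative split (mirrors the chunks below): the array text
+        '[{"description' + '": "Test' + 'Todo 1"}]' must re-join into valid
+        JSON before json.loads can succeed on the accumulated parameters.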
+ Purpose: Verify: + - Number types (5.5, 10) are preserved as numbers, not strings + - JSON array content split as "description" + ": \"" maintains validity + - Nested JSON objects in arrays are correctly reconstructed + """ + tools = [ + Tool( + type="function", + function=Function( + name="create_todos", + parameters={ + "type": "object", + "properties": { + "priority": {"type": "number"}, + "count": {"type": "integer"}, + "items": {"type": "array"}, + }, + }, + ), + ), + ] + + chunks = [ + "create_todos", + "priority5.5", # float number + "count10", # integer number + 'items[{"description', # JSON array splits mid-key + '": "Test', # key completes, value starts + 'Todo 1"}, {"description": "TestTodo 2"}]', + ] + + detector = Glm47MoeDetector() + all_calls = [] + + for chunk in chunks: + result = detector.parse_streaming_increment(chunk, tools) + all_calls.extend(result.calls) + + # Verify function name + func_calls = [c for c in all_calls if c.name] + self.assertEqual(len(func_calls), 1) + self.assertEqual(func_calls[0].name, "create_todos") + + # Verify parameters - numbers and JSON array + full_params = "".join([c.parameters for c in all_calls if c.parameters]) + params = json.loads(full_params) + + # Number types should be preserved + self.assertIsInstance(params["priority"], (int, float)) + self.assertEqual(params["priority"], 5.5) + self.assertIsInstance(params["count"], int) + self.assertEqual(params["count"], 10) + + # JSON array should be correctly reconstructed + self.assertIsInstance(params["items"], list) + self.assertEqual(len(params["items"]), 2) + self.assertEqual(params["items"][0]["description"], "TestTodo 1") + self.assertEqual(params["items"][1]["description"], "TestTodo 2") + + # ==================== Streaming Basics (3) ==================== + + def test_streaming_tool_call(self): + """ + Test basic streaming incremental parsing of a single tool call. + + Scenario: Parse a tool call split across 4 chunks with natural boundaries + (function name, first param, second param, closing tag). + Purpose: Verify basic streaming functionality works correctly and accumulates + parameters progressively across chunks. + """ + chunks = [ + "get_weather", + "cityBeijing", + "date2024-06-27", + "", + ] + tool_calls = [] + for chunk in chunks: + result = self.detector.parse_streaming_increment(chunk, self.tools) + for tool_call_chunk in result.calls: + if ( + hasattr(tool_call_chunk, "tool_index") + and tool_call_chunk.tool_index is not None + ): + while len(tool_calls) <= tool_call_chunk.tool_index: + tool_calls.append({"name": "", "parameters": ""}) + tc = tool_calls[tool_call_chunk.tool_index] + if tool_call_chunk.name: + tc["name"] = tool_call_chunk.name + if tool_call_chunk.parameters: + tc["parameters"] += tool_call_chunk.parameters + self.assertEqual(len(tool_calls), 1) + self.assertEqual(tool_calls[0]["name"], "get_weather") + self.assertEqual( + tool_calls[0]["parameters"], '{"city": "Beijing", "date": "2024-06-27"}' + ) + + def test_streaming_multiple_tool_calls(self): + """ + Test streaming incremental parsing of multiple consecutive tool calls. + + Scenario: Stream two complete tool calls with the transition "" + occurring within a single chunk. + Purpose: Verify streaming correctly handles multiple tool calls and properly increments + tool_index for each new call. 
+ """ + chunks = [ + "get_weather", + "cityBeijing", + "date2024-06-27", + "get_weather", # two tool calls transition in same chunk + "cityShanghai", + "date2024-06-28", + "", + ] + tool_calls = [] + for chunk in chunks: + result = self.detector.parse_streaming_increment(chunk, self.tools) + for tool_call_chunk in result.calls: + if ( + hasattr(tool_call_chunk, "tool_index") + and tool_call_chunk.tool_index is not None + ): + while len(tool_calls) <= tool_call_chunk.tool_index: + tool_calls.append({"name": "", "parameters": ""}) + tc = tool_calls[tool_call_chunk.tool_index] + if tool_call_chunk.name: + tc["name"] = tool_call_chunk.name + if tool_call_chunk.parameters: + tc["parameters"] += tool_call_chunk.parameters + self.assertEqual(len(tool_calls), 2) + self.assertEqual(tool_calls[0]["name"], "get_weather") + self.assertEqual( + tool_calls[0]["parameters"], '{"city": "Beijing", "date": "2024-06-27"}' + ) + self.assertEqual(tool_calls[1]["name"], "get_weather") + self.assertEqual( + tool_calls[1]["parameters"], '{"city": "Shanghai", "date": "2024-06-28"}' + ) + + def test_normal_text_before_tool_call(self): + """ + Test preservation of normal text (including punctuation) before tool calls. + + Scenario: Parse chunks containing normal text with various punctuation marks + (English and Chinese) immediately followed by tool call tags. + Purpose: Verify normal text is preserved in result.normal_text and not lost when + tool call parsing begins. This consolidates 6 previous Chinese punctuation tests. + """ + tools = [ + Tool( + type="function", + function=Function( + name="list_dir", + parameters={ + "type": "object", + "properties": { + "path": {"type": "string"}, + }, + }, + ), + ), + ] + + test_cases = [ + ("Sure, let me help.list_dir", "English with period"), + ("结构:list_dir", "Chinese colon"), + ("问题。list_dir", "Chinese period"), + ("Complete!list_dir", "English exclamation"), + ("说明;list_dir", "Chinese semicolon"), + ] + + for text, description in test_cases: + with self.subTest(description=description): + detector = Glm47MoeDetector() + result = detector.parse_streaming_increment(text, tools) + + before_token = text.split("")[0] + self.assertIn( + before_token, + result.normal_text, + f"Should preserve '{before_token}' in '{description}'", + ) + + # ==================== Boundary Cases (9) ==================== + + def test_boundary_empty_param_value(self): + """ + Test handling of empty parameter values. + + Scenario: Parse a tool call where a parameter value is an empty string. + Purpose: Verify the detector correctly handles empty strings as valid parameter values + and doesn't skip or error on them. + """ + tools = [ + Tool( + type="function", + function=Function( + name="create_note", + parameters={ + "type": "object", + "properties": { + "title": {"type": "string"}, + "content": {"type": "string"}, + }, + }, + ), + ), + ] + + text = "create_notetitleTestcontent" + result = self.detector.detect_and_parse(text, tools) + + self.assertEqual(len(result.calls), 1) + params = json.loads(result.calls[0].parameters) + self.assertEqual(params["title"], "Test") + self.assertEqual(params["content"], "") # empty string should be preserved + + def test_boundary_param_value_extreme_split(self): + """ + Test extreme parameter value splitting - one character per chunk. + + Scenario: Stream a parameter value where each character arrives in a separate chunk, + representing worst-case MTP tokenization. 
+ Purpose: Stress test the buffer reassembly mechanism to ensure it can handle + extremely granular chunk boundaries without data loss or corruption. + """ + tools = [ + Tool( + type="function", + function=Function( + name="search", + parameters={ + "type": "object", + "properties": { + "query": {"type": "string"}, + }, + }, + ), + ), + ] + + chunks = [ + "searchqueryN", + "e", + "w ", + "Y", + "o", + "rk", + ] + + detector = Glm47MoeDetector() + all_calls = [] + + for chunk in chunks: + result = detector.parse_streaming_increment(chunk, tools) + all_calls.extend(result.calls) + + full_params = "".join([c.parameters for c in all_calls if c.parameters]) + params = json.loads(full_params) + self.assertEqual( + params["query"], "New York" + ) # all characters correctly reassembled + + def test_boundary_param_value_with_special_chars(self): + """ + Test parameter values containing special characters and escape sequences. + + Scenario: Parse parameter values with quotes, backslashes, newlines, and other + special characters that require JSON escaping. + Purpose: Verify special characters are properly escaped/unescaped and preserved + through the parsing pipeline without corruption. + """ + tools = [ + Tool( + type="function", + function=Function( + name="execute_command", + parameters={ + "type": "object", + "properties": { + "command": {"type": "string"}, + }, + }, + ), + ), + ] + + # Test with single quotes (no escaping needed) + text = "execute_commandcommandecho 'Hello World'" + result = self.detector.detect_and_parse(text, tools) + params = json.loads(result.calls[0].parameters) + self.assertEqual(params["command"], "echo 'Hello World'") + + # Test with spaces and special chars that don't need escaping + text = "execute_commandcommandecho Hello & World" + result = self.detector.detect_and_parse(text, tools) + params = json.loads(result.calls[0].parameters) + self.assertEqual(params["command"], "echo Hello & World") + + def test_boundary_json_deeply_nested(self): + """ + Test deeply nested JSON structures in parameter values. + + Scenario: Parse a parameter containing a deeply nested JSON object with multiple levels. + Purpose: Verify the detector can handle complex nested structures without stack overflow + or parsing errors. + """ + tools = [ + Tool( + type="function", + function=Function( + name="process_data", + parameters={ + "type": "object", + "properties": { + "data": {"type": "object"}, + }, + }, + ), + ), + ] + + nested_json = ( + '{"level1": {"level2": {"level3": {"level4": {"value": "deep"}}}}}' + ) + text = f"process_datadata{nested_json}" + + result = self.detector.detect_and_parse(text, tools) + params = json.loads(result.calls[0].parameters) + + # Navigate through nested structure + self.assertEqual( + params["data"]["level1"]["level2"]["level3"]["level4"]["value"], "deep" + ) + + def test_boundary_json_empty_structures(self): + """ + Test empty JSON structures (empty objects and arrays) in parameters. + + Scenario: Parse parameters containing empty objects {} and empty arrays []. + Purpose: Verify empty structures are preserved and not confused with no-argument + function empty parameter generation. 
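+
+        Illustrative distinction (assuming the <arg_value> tag format): a
+        user-supplied <arg_value>{}</arg_value> must survive as {"empty_obj": {}},
+        whereas the bare "{}" for a no-argument call is synthesized by the detector.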
+ """ + tools = [ + Tool( + type="function", + function=Function( + name="create_structure", + parameters={ + "type": "object", + "properties": { + "empty_obj": {"type": "object"}, + "empty_arr": {"type": "array"}, + }, + }, + ), + ), + ] + + text = "create_structureempty_obj{}empty_arr[]" + result = self.detector.detect_and_parse(text, tools) + + params = json.loads(result.calls[0].parameters) + self.assertEqual(params["empty_obj"], {}) + self.assertEqual(params["empty_arr"], []) + + def test_boundary_multi_tags_one_chunk(self): + """ + Test multiple XML tags appearing in a single chunk. + + Scenario: Parse chunks where multiple complete tags (arg_key, arg_value, etc.) + appear together without any chunk boundaries between them. + Purpose: Verify the regex-based tag extraction correctly handles multiple tags + in one chunk and processes them in the correct order. + """ + tools = [ + Tool( + type="function", + function=Function( + name="multi_param", + parameters={ + "type": "object", + "properties": { + "a": {"type": "string"}, + "b": {"type": "string"}, + "c": {"type": "string"}, + }, + }, + ), + ), + ] + + # All three parameters in one chunk + text = "multi_parama1b2c3" + result = self.detector.detect_and_parse(text, tools) + + params = json.loads(result.calls[0].parameters) + self.assertEqual(params["a"], "1") + self.assertEqual(params["b"], "2") + self.assertEqual(params["c"], "3") + + def test_boundary_normal_text_mixed_with_tool(self): + """ + Test normal text interleaved with tool calls. + + Scenario: Parse text with normal text before and after tool calls. + Purpose: Verify normal text segments are correctly separated from tool call parsing + and preserved in the normal_text output. + + NOTE: Currently, the detector only captures text BEFORE the first tool call. + Text after tool calls is not returned (known limitation). + """ + tools = [ + Tool( + type="function", + function=Function( + name="action", + parameters={ + "type": "object", + "properties": {}, + }, + ), + ), + ] + + text = "First I'll do this.actionThen I'll do that." + result = self.detector.detect_and_parse(text, tools) + + self.assertEqual(len(result.calls), 1) + self.assertEqual(result.calls[0].name, "action") + # Currently only text before tool call is captured + self.assertIn("First I'll do this.", result.normal_text) + # TODO: Text after tool call should also be preserved but currently isn't + self.assertIn("Then I'll do that.", result.normal_text) + + def test_boundary_number_edge_values(self): + """ + Test edge-case number values (zero, negative, scientific notation). + + Scenario: Parse parameters with various numeric edge cases to ensure proper type handling. + Purpose: Verify the detector correctly preserves number types for edge values and doesn't + convert them to strings or lose precision. + """ + tools = [ + Tool( + type="function", + function=Function( + name="calculate", + parameters={ + "type": "object", + "properties": { + "zero": {"type": "number"}, + "negative": {"type": "number"}, + "large": {"type": "number"}, + }, + }, + ), + ), + ] + + text = "calculatezero0negative-42.5large1e10" + result = self.detector.detect_and_parse(text, tools) + + params = json.loads(result.calls[0].parameters) + self.assertEqual(params["zero"], 0) + self.assertEqual(params["negative"], -42.5) + self.assertEqual(params["large"], 1e10) + + def test_boundary_type_string_with_numeric_content(self): + """ + Test string parameters that contain numeric-looking content. 
+ + Scenario: Parse string parameters with values like "123" or "45.67" that look like + numbers but should remain strings based on parameter schema. + Purpose: Verify type preservation based on schema definition, not content appearance. + """ + tools = [ + Tool( + type="function", + function=Function( + name="store_data", + parameters={ + "type": "object", + "properties": { + "id": { + "type": "string" + }, # string type despite numeric content + "code": {"type": "string"}, + }, + }, + ), + ), + ] + + text = "store_dataid12345code67.89" + result = self.detector.detect_and_parse(text, tools) + + params = json.loads(result.calls[0].parameters) + # Should be strings, not numbers + self.assertIsInstance(params["id"], str) + self.assertIsInstance(params["code"], str) + self.assertEqual(params["id"], "12345") + self.assertEqual(params["code"], "67.89") + + # ==================== Error Handling (2) ==================== + + def test_error_undefined_tool(self): + """ + Test error handling for undefined tool names. + + Scenario: Attempt to call a function that doesn't exist in the provided tools list. + Purpose: Verify the detector gracefully handles undefined tools by returning an empty + call list rather than crashing or producing malformed output. + """ + text = "nonexistent_functionparamvalue" + result = self.detector.detect_and_parse(text, self.tools) + + # Should not crash, should return empty calls + self.assertEqual(len(result.calls), 0) + + def test_error_incomplete_buffer_at_end(self): + """ + Test handling of incomplete tool calls at end of stream. + + Scenario: Streaming ends with an incomplete tool call (e.g., missing closing tag). + Purpose: Verify the detector handles incomplete buffers gracefully without throwing + exceptions, as streaming may end mid-parse in real scenarios. + """ + chunks = [ + "get_weathercityBeijing", + # Stream ends here, no closing tags + ] + + detector = Glm47MoeDetector() + + for chunk in chunks: + result = detector.parse_streaming_increment(chunk, self.tools) + # Should not crash + self.assertIsInstance(result, StreamingParseResult) + + # Incomplete call should not be in results + # (or may be partially present - main thing is no exception) + + # ==================== Streamed Raw Length Bug Tests (3) ==================== + + def test_streamed_raw_length_incomplete_xml_tag(self): + """ + Test that _streamed_raw_length is updated even when json_increment is empty. + + Scenario: Stream XML content that is split at an incomplete tag boundary, + causing the state machine to buffer without producing JSON output. + Purpose: Verify that _streamed_raw_length is updated regardless of whether + json_increment is empty, preventing reprocessing of the same input. + + This tests the bug where: + 1. raw_increment is extracted from func_args_raw[self._streamed_raw_length:] + 2. _process_xml_to_json_streaming() returns empty string (buffering state) + 3. 
If _streamed_raw_length is NOT updated before the early return, + the next call will reprocess the same raw_increment + """ + tools = [ + Tool( + type="function", + function=Function( + name="get_weather", + parameters={ + "type": "object", + "properties": { + "city": {"type": "string"}, + "temperature": {"type": "number"}, + }, + }, + ), + ), + ] + + # Simulate streaming chunks where XML tags are split + chunks = [ + "get_weather", + "cityBei", # Split in middle of value + "jing", # Complete the value + "temperature2", # Split numeric value + "5", + ] + + detector = Glm47MoeDetector() + all_calls = [] + collected_params = "" + + for i, chunk in enumerate(chunks): + result = detector.parse_streaming_increment(chunk, tools) + all_calls.extend(result.calls) + + # Collect parameters + for call in result.calls: + if call.parameters: + collected_params += call.parameters + + # Verify complete parameters were collected without duplication + if collected_params: + params = json.loads(collected_params) + self.assertEqual(params["city"], "Beijing") + self.assertEqual(params["temperature"], 25) + + # Critical: Verify no duplicate JSON output due to reprocessing + # Count occurrences of "city" key - should appear exactly once + city_count = collected_params.count('"city"') + self.assertEqual( + city_count, + 1, + f"'city' key appears {city_count} times, expected 1. " + f"This indicates input reprocessing bug.", + ) + + def test_streamed_raw_length_tag_split_across_chunks(self): + """ + Test _streamed_raw_length update when tag is split across chunk boundaries. + + Scenario: XML tags themselves are split across chunks (e.g., ""). + Purpose: Verify that even when the state machine is buffering partial tags, + _streamed_raw_length is correctly updated to prevent reprocessing. + """ + tools = [ + Tool( + type="function", + function=Function( + name="search", + parameters={ + "type": "object", + "properties": { + "query": {"type": "string"}, + "limit": {"type": "integer"}, + }, + }, + ), + ), + ] + + # Split tags in extreme positions + chunks = [ + "searchqueryPython progra", # Complete tag, split value + "mminglimit10", + ] + + detector = Glm47MoeDetector() + all_params = "" + + for chunk in chunks: + result = detector.parse_streaming_increment(chunk, tools) + for call in result.calls: + if call.parameters: + all_params += call.parameters + + # Verify correct reassembly + params = json.loads(all_params) + self.assertEqual(params["query"], "Python programming") + self.assertEqual(params["limit"], 10) + + # Verify no duplication in output + query_count = all_params.count('"query"') + limit_count = all_params.count('"limit"') + self.assertEqual(query_count, 1, "query key duplicated - reprocessing bug") + self.assertEqual(limit_count, 1, "limit key duplicated - reprocessing bug") + + def test_streamed_raw_length_buffer_only_partial_tag(self): + """ + Test that _streamed_raw_length updates even when state machine returns empty. + + Scenario: Send increment that is ONLY a partial opening tag that state machine + must buffer completely without producing any JSON output. + Purpose: Force json_increment to be empty string to expose the bug where + _streamed_raw_length is not updated before early return. 
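+
+        Illustrative sequence (hypothetical values, assuming the <arg_key>/<arg_value>
+        tag format): after "<arg_key>key1</arg_key><arg_value>va", feeding a lone
+        "<" produces an empty json_increment, yet the counter must still advance by 1.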
+ """ + tools = [ + Tool( + type="function", + function=Function( + name="test_func", + parameters={ + "type": "object", + "properties": { + "key1": {"type": "string"}, + }, + }, + ), + ), + ] + + # Manually call _process_arguments_streaming to have precise control + detector = Glm47MoeDetector() + detector.current_tool_id = 0 + detector.current_tool_name_sent = True + detector._reset_streaming_state() + detector.streamed_args_for_tool = [""] + detector._streamed_raw_length = 0 + + # First call: Complete tag that produces JSON output + func_args_1 = "key1va" + result_1 = detector._process_arguments_streaming( + "test_func", func_args_1, tools + ) + + # Should produce JSON output: {"key1": "va (partial) + self.assertIsNotNone(result_1) + self.assertGreater(len(result_1.parameters), 0) + initial_length = detector._streamed_raw_length + self.assertEqual(initial_length, len(func_args_1)) + + # Second call: Add just partial closing tag - state machine will buffer this + # without producing JSON (it's waiting to see if is complete) + func_args_2 = func_args_1 + "<" # Add partial tag + result_2 = detector._process_arguments_streaming( + "test_func", func_args_2, tools + ) + + # This is the critical test: if _streamed_raw_length is NOT updated when + # json_increment is empty, then detector._streamed_raw_length will still be + # at initial_length, and the next call will reprocess the "<" character + + # Check if length was updated (bug test) + updated_length = detector._streamed_raw_length + + # BUG: If code has bug, updated_length will equal initial_length + # FIXED: If code is correct, updated_length should equal len(func_args_2) + self.assertEqual( + updated_length, + len(func_args_2), + "Bug detected: _streamed_raw_length not updated when json_increment is empty. " + f"Expected {len(func_args_2)}, got {updated_length}", + ) + + def test_streamed_raw_length_multiple_empty_returns(self): + """ + Test consecutive chunks that produce empty json_increment. + + Scenario: Multiple consecutive chunks that all result in empty json_increment + as the state machine buffers complex nested structures. + Purpose: Verify _streamed_raw_length advances correctly through multiple + empty-return cycles without getting stuck or reprocessing. 
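+
+        Illustrative failure mode: if the counter stalls while a split tag such
+        as "<arg_" is buffered, the same characters are re-fed on the next call
+        and keys like "name" are emitted twice in the JSON stream.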
+ """ + tools = [ + Tool( + type="function", + function=Function( + name="update_settings", + parameters={ + "type": "object", + "properties": { + "name": {"type": "string"}, + "value": {"type": "string"}, + }, + }, + ), + ), + ] + + # Split XML at positions that may cause state machine buffering + chunks = [ + "update_settingsna", # Split in tag name + "meco", # Complete tag start, split value # codespell:ignore ue + "nf", # Continue value + "ig_v1val", # Complete value, split next key + "ueena", # Complete key name, split value # codespell:ignore ue + "bled", # Complete everything + ] + + detector = Glm47MoeDetector() + all_params = "" + + for i, chunk in enumerate(chunks): + result = detector.parse_streaming_increment(chunk, tools) + + for call in result.calls: + if call.parameters: + all_params += call.parameters + + # Verify final output is correct + self.assertGreater(len(all_params), 0, "Should have generated some parameters") + params = json.loads(all_params) + self.assertEqual(params["name"], "config_v1") + self.assertEqual(params["value"], "enabled") + + # Verify no duplicate keys due to reprocessing + name_count = all_params.count('"name"') + value_count = all_params.count('"value"') + self.assertEqual( + name_count, + 1, + f"'name' appears {name_count} times - indicates reprocessing bug", + ) + self.assertEqual( + value_count, + 1, + f"'value' appears {value_count} times - indicates reprocessing bug", + ) + + +if __name__ == "__main__": + unittest.main() From 5bdb7334411de1d2e595aec2d05e41435faa1e0f Mon Sep 17 00:00:00 2001 From: Leoyzen Date: Sat, 10 Jan 2026 01:21:10 +0800 Subject: [PATCH 3/3] Fix GLM-4.7 MoE Detector complex JSON Schema type parsing (#15753) --- .../srt/function_call/glm47_moe_detector.py | 53 +- .../srt/function_call/glm4_moe_detector.py | 371 +++++++++- python/sglang/srt/function_call/utils.py | 105 ++- .../function_call/test_glm47_moe_detector.py | 681 +++++++++++++++++- 4 files changed, 1165 insertions(+), 45 deletions(-) diff --git a/python/sglang/srt/function_call/glm47_moe_detector.py b/python/sglang/srt/function_call/glm47_moe_detector.py index 0cf3a1b622e6..0b25b11eb379 100644 --- a/python/sglang/srt/function_call/glm47_moe_detector.py +++ b/python/sglang/srt/function_call/glm47_moe_detector.py @@ -12,6 +12,7 @@ ToolCallItem, _GetInfoFunc, ) +from sglang.srt.function_call.utils import infer_type_from_json_schema logger = logging.getLogger(__name__) @@ -31,6 +32,14 @@ def get_argument_type( ) -> Optional[str]: """Get the expected type of a function argument from tool definitions. 
+ Supports complex JSON Schema definitions including: + - Direct type field (including type arrays) + - anyOf/oneOf: parameter can be any of multiple types + - enum: parameter must be one of enum values + - allOf: parameter must satisfy all type definitions + - properties: inferred as object type + - items: inferred as array type + Args: func_name: Name of the function/tool arg_key: Name of the argument @@ -58,7 +67,8 @@ def get_argument_type( arg_spec = properties.get(arg_key) if isinstance(arg_spec, dict): - return arg_spec.get("type") + # Use the new type inference function for complex JSON Schema support + return infer_type_from_json_schema(arg_spec) return None @@ -243,25 +253,48 @@ def _get_value_type(self, func_name: str, key: str, tools: List[Tool]) -> str: tools: List of available tools Returns: - Type string: 'string', 'number', or 'object' + Type string: 'string', 'number', 'object', 'array', or 'boolean' """ arg_type = get_argument_type(func_name, key, tools) if arg_type: return arg_type - # Auto-detect type from value (best effort) - first_chars = ( - self._current_value.strip()[:10] - if self._current_value and self._current_value.strip() - else "" - ) - if first_chars: - first_char = first_chars[0] + # Improved auto-detection type from value (best effort) + value_content = self._current_value.strip() if self._current_value else "" + + if not value_content: + return "string" + + # Try to parse as valid JSON first + try: + parsed = json.loads(value_content) + if isinstance(parsed, dict): + return "object" + elif isinstance(parsed, list): + return "array" + elif isinstance(parsed, bool): + return "boolean" + elif isinstance(parsed, (int, float)): + return "number" + # For string values, check if they look like numbers + elif isinstance(parsed, str): + if parsed.isdigit() or ( + parsed.startswith("-") and parsed[1:].isdigit() + ): + return "number" + return "string" + except json.JSONDecodeError: + # Not valid JSON, try heuristic detection + first_char = value_content[0] if value_content else "" + if first_char.isdigit() or first_char in ["-", "."]: return "number" elif first_char in ["{", "["]: return "object" + elif first_char in ['"', "'"]: + return "string" + # Default to string (safest fallback) return "string" def _format_value_complete(self, value: str, value_type: str) -> str: diff --git a/python/sglang/srt/function_call/glm4_moe_detector.py b/python/sglang/srt/function_call/glm4_moe_detector.py index 3158e32166c3..0761e24e7cba 100644 --- a/python/sglang/srt/function_call/glm4_moe_detector.py +++ b/python/sglang/srt/function_call/glm4_moe_detector.py @@ -2,16 +2,52 @@ import json import logging import re -from typing import List +from enum import Enum +from typing import Any, Dict, List, Optional, Tuple from sglang.srt.entrypoints.openai.protocol import Tool from sglang.srt.function_call.base_format_detector import BaseFormatDetector -from sglang.srt.function_call.core_types import StreamingParseResult, _GetInfoFunc +from sglang.srt.function_call.core_types import ( + StreamingParseResult, + ToolCallItem, + _GetInfoFunc, +) +from sglang.srt.function_call.utils import infer_type_from_json_schema logger = logging.getLogger(__name__) -def get_argument_type(func_name: str, arg_key: str, defined_tools: list): +class StreamState(str, Enum): + """State machine states for XML to JSON streaming conversion.""" + + INIT = "INIT" + BETWEEN = "BETWEEN" + IN_KEY = "IN_KEY" + WAITING_VALUE = "WAITING_VALUE" + IN_VALUE = "IN_VALUE" + + +def get_argument_type( + func_name: str, arg_key: 
str, defined_tools: List[Tool] +) -> Optional[str]: + """Get the expected type of a function argument from tool definitions. + + Supports complex JSON Schema definitions including: + - Direct type field (including type arrays) + - anyOf/oneOf: parameter can be any of multiple types + - enum: parameter must be one of enum values + - allOf: parameter must satisfy all type definitions + - properties: inferred as object type + - items: inferred as array type + + Args: + func_name: Name of the function/tool + arg_key: Name of the argument + defined_tools: List of available tools + + Returns: + The type string (e.g., 'string', 'number', 'object') or None if not found + """ name2tool = {tool.function.name: tool for tool in defined_tools} if func_name not in name2tool: return None @@ -21,35 +57,95 @@ def get_argument_type(func_name: str, arg_key: str, defined_tools: list): properties = {} if arg_key not in properties: return None - return properties[arg_key].get("type", None) + # Use new type inference function for complex JSON Schema support + return infer_type_from_json_schema(properties[arg_key]) + + +def _convert_to_number(value: str) -> Any: + """Convert string to appropriate number type (int or float). + + Args: + value: String value to convert + + Returns: + Converted number or original string if conversion fails + """ + try: + if "." in value or "e" in value.lower(): + return float(value) + else: + return int(value) + except (ValueError, AttributeError): + return value + + +def parse_arguments( + json_value: str, arg_type: Optional[str] = None +) -> Tuple[Any, bool]: + """Parse argument value with multiple fallback strategies. + + Args: + json_value: Raw string value to parse + arg_type: Expected type hint ('string', 'number', 'object', etc.) -def parse_arguments(json_value): + Returns: + Tuple of (parsed_value, is_valid_json) + """ + # Strategy 1: Direct JSON parsing try: parsed_value = json.loads(json_value) + + # Type coercion for number type + if arg_type == "number" and isinstance(parsed_value, str): + parsed_value = _convert_to_number(parsed_value) + return parsed_value, True - except: - # If that fails, try wrapping it to unescape JSON characters - try: - # Wrap the value as a JSON string field - wrapped = json.loads('{"tmp": "' + json_value + '"}') - # parse the unescaped value - parsed_value = json.loads(wrapped["tmp"]) - return parsed_value, True - except: - # Final fallback to ast.literal_eval - try: - parsed_value = ast.literal_eval(json_value) - return parsed_value, True - except: - return json_value, False + except (json.JSONDecodeError, ValueError): + pass + + # Strategy 2: Unescape and parse + try: + wrapped = json.loads('{"tmp": "' + json_value + '"}') + parsed_value = json.loads(wrapped["tmp"]) + + if arg_type == "number" and isinstance(parsed_value, str): + parsed_value = _convert_to_number(parsed_value) + + return parsed_value, True + except (json.JSONDecodeError, ValueError, KeyError): + pass + + # Strategy 3: ast.literal_eval + try: + parsed_value = ast.literal_eval(json_value) + return parsed_value, True + except (ValueError, SyntaxError): + pass + + # Strategy 4: Treat as string + try: + quoted_value = json.dumps(str(json_value)) + return json.loads(quoted_value), True + except (json.JSONDecodeError, ValueError): + return json_value, False class Glm4MoeDetector(BaseFormatDetector): """ Detector for GLM-4.5 and GLM-4.6 models. 
-    Assumes function call format:
-    <tool_call>get_weather\n<arg_key>city</arg_key>\n<arg_value>北京</arg_value>\n<arg_key>date</arg_key>\n<arg_value>2024-06-27</arg_value>\n</tool_call>\n<tool_call>get_weather\n<arg_key>city</arg_key>\n<arg_value>上海</arg_value>\n<arg_key>date</arg_key>\n<arg_value>2024-06-27</arg_value>\n</tool_call>
+    Assumes function call format (with actual newlines):
+    <tool_call>get_weather
+    <arg_key>city</arg_key>
+    <arg_value>北京</arg_value>
+    <arg_key>date</arg_key>
+    <arg_value>2024-06-27</arg_value>
+    </tool_call>
+
+    Or with literal \n characters (escaped as \\n in the output):
+    <tool_call>get_weather\n<arg_key>city</arg_key>\n<arg_value>北京</arg_value>\n</tool_call>
+
+    Uses a streaming state machine to convert XML to JSON incrementally for maximum speed.
     """
 
     def __init__(self):
@@ -64,6 +160,23 @@ def __init__(self):
             r"<arg_key>(.*?)</arg_key>(?:\\n|\s)*<arg_value>(.*?)</arg_value>",
             re.DOTALL,
         )
+        self._last_arguments = ""
+        self.current_tool_id = -1
+        self.current_tool_name_sent = False
+        self._streamed_raw_length = 0
+        self._reset_streaming_state()
+
+    def _reset_streaming_state(self) -> None:
+        """Reset the streaming state machine for a new tool call."""
+        self._stream_state = StreamState.INIT
+        self._current_key = ""
+        self._current_value = ""
+        self._xml_tag_buffer = ""
+        self._is_first_param = True
+        self._value_started = False
+        self._cached_value_type: Optional[str] = (
+            None  # Cache the value type for consistency
+        )
 
     def has_tool_call(self, text: str) -> bool:
         """Check if the text contains a glm-4.5 / glm-4.6 format tool call."""
@@ -92,23 +205,219 @@ def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult
             func_name = func_detail.group(1) if func_detail.group(1) else ""
             func_args = func_detail.group(2) if func_detail.group(2) else ""
             pairs = self.func_arg_regex.findall(func_args)
-            arguments = {}
-            for arg_key, arg_value in pairs:
-                arg_key = arg_key.strip()
-                arg_value = arg_value.strip()
-                arg_type = get_argument_type(func_name, arg_key, tools)
-                if arg_type != "string":
-                    arg_value, is_good_json = parse_arguments(arg_value)
-                arguments[arg_key] = arg_value
+
+            # Parse arguments using shared method
+            arguments = self._parse_argument_pairs(pairs, func_name, tools)
+
             # construct match_result for parse_base_json
             match_result = {"name": func_name, "parameters": arguments}
             calls.extend(self.parse_base_json(match_result, tools))
             return StreamingParseResult(normal_text=normal_text, calls=calls)
         except Exception as e:
-            logger.error(f"Error in detect_and_parse: {e}")
+            logger.error(f"Error in detect_and_parse: {e}", exc_info=True)
             # return the normal text if parsing fails
             return StreamingParseResult(normal_text=text)
 
+    def _get_value_type(self, func_name: str, key: str, tools: List[Tool]) -> str:
+        """Get parameter type from tool definition, with fallback to auto-detection.
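+
+        Illustrative fallback (hypothetical values, used only when the schema
+        defines no type): '{"a": 1}' is detected as "object", "-42.5" as
+        "number", and anything ambiguous falls back to "string".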
+ + Args: + func_name: Name of the function + key: Parameter name + tools: List of available tools + + Returns: + Type string: 'string', 'number', 'object', 'array', or 'boolean' + """ + arg_type = get_argument_type(func_name, key, tools) + if arg_type: + return arg_type + + # Improved auto-detection type from value (best effort) + value_content = self._current_value.strip() if self._current_value else "" + + if not value_content: + return "string" + + # Try to parse as valid JSON first + try: + parsed = json.loads(value_content) + if isinstance(parsed, dict): + return "object" + elif isinstance(parsed, list): + return "array" + elif isinstance(parsed, bool): + return "boolean" + elif isinstance(parsed, (int, float)): + return "number" + # For string values, check if they look like numbers + elif isinstance(parsed, str): + if parsed.isdigit() or ( + parsed.startswith("-") and parsed[1:].isdigit() + ): + return "number" + return "string" + except json.JSONDecodeError: + # Not valid JSON, try heuristic detection + first_char = value_content[0] if value_content else "" + + if first_char.isdigit() or first_char in ["-", "."]: + return "number" + elif first_char in ["{", "["]: + return "object" + elif first_char in ['"', "'"]: + return "string" + + # Default to string (safest fallback) + return "string" + + def _format_value_complete(self, value: str, value_type: str) -> str: + """Format complete value based on type. + + Args: + value: Raw value string + value_type: Expected type ('string', 'number', 'object') + + Returns: + Properly formatted JSON value string + """ + if value_type == "string": + # Ensure proper JSON string formatting with quotes + return json.dumps(value, ensure_ascii=False) + elif value_type == "number": + try: + num = _convert_to_number(value.strip()) + return str(num) + except (ValueError, AttributeError): + # Fallback to string if not a valid number + logger.warning( + f"Failed to parse '{value}' as number, treating as string" + ) + return json.dumps(str(value), ensure_ascii=False) + else: + # For object/array types, return as-is (should already be valid JSON) + return value + + def _process_xml_to_json_streaming( + self, raw_increment: str, func_name: str, tools: List[Tool] + ) -> str: + """Convert XML increment to JSON streaming output using state machine. + + This method processes XML fragments character by character and converts them + to JSON format incrementally. It maintains state across calls to handle + partial XML tags and values. 
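+
+        Illustrative increment (assuming the <arg_key>/<arg_value> tag format):
+            XML in:   "<arg_key>city</arg_key><arg_value>Bei"
+            JSON out: '{"city": "Bei'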
+
+        Args:
+            raw_increment: New XML content to process
+            func_name: Name of the function being called
+            tools: List of available tools for type inference
+
+        Returns:
+            JSON string increment to append to the output
+        """
+        json_output = ""
+
+        for char in raw_increment:
+            self._xml_tag_buffer += char
+
+            if self._stream_state in [StreamState.INIT, StreamState.BETWEEN]:
+                if self._xml_tag_buffer.endswith("<arg_key>"):
+                    self._stream_state = StreamState.IN_KEY
+                    self._current_key = ""
+                    self._xml_tag_buffer = ""
+                    json_output += "{" if self._is_first_param else ", "
+                    self._is_first_param = False
+
+            elif self._stream_state == StreamState.IN_KEY:
+                if self._xml_tag_buffer.endswith("</arg_key>"):
+                    self._current_key = self._xml_tag_buffer[:-10].strip()
+                    self._xml_tag_buffer = ""
+                    self._stream_state = StreamState.WAITING_VALUE
+                    json_output += (
+                        json.dumps(self._current_key, ensure_ascii=False) + ": "
+                    )
+
+            elif self._stream_state == StreamState.WAITING_VALUE:
+                if self._xml_tag_buffer.endswith("<arg_value>"):
+                    self._stream_state = StreamState.IN_VALUE
+                    self._current_value = ""
+                    self._xml_tag_buffer = ""
+                    self._value_started = False
+                    # Determine and cache the value type at the start
+                    self._cached_value_type = self._get_value_type(
+                        func_name, self._current_key, tools
+                    )
+
+            elif self._stream_state == StreamState.IN_VALUE:
+                if self._xml_tag_buffer.endswith("</arg_value>"):
+                    final_value = self._xml_tag_buffer[:-12]
+                    self._current_value += final_value
+
+                    # Use cached value type for consistency
+                    value_type = self._cached_value_type or "string"
+
+                    if self._value_started:
+                        # Output any remaining content
+                        if final_value:
+                            if value_type == "string":
+                                json_output += json.dumps(
+                                    final_value, ensure_ascii=False
+                                )[1:-1]
+                            else:
+                                json_output += final_value
+                        # Always output closing quote for string type when value was started
+                        if value_type == "string":
+                            json_output += '"'
+                    else:
+                        # Value was never started (empty or complete in one chunk)
+                        json_output += self._format_value_complete(
+                            self._current_value, value_type
+                        )
+
+                    self._xml_tag_buffer = ""
+                    self._stream_state = StreamState.BETWEEN
+                    self._current_value = ""
+                    self._value_started = False
+                    self._cached_value_type = None  # Reset cached type
+                else:
+                    closing_tag = "</arg_value>"
+                    is_potential_closing = len(self._xml_tag_buffer) <= len(
+                        closing_tag
+                    ) and closing_tag.startswith(self._xml_tag_buffer)
+
+                    if not is_potential_closing:
+                        content = self._xml_tag_buffer
+                        # Use cached value type for consistency
+                        value_type = self._cached_value_type or "string"
+
+                        if value_type == "string":
+                            if not self._value_started:
+                                json_output += '"'
+                                self._value_started = True
+                            if content:
+                                json_output += json.dumps(content, ensure_ascii=False)[
+                                    1:-1
+                                ]
+                            self._current_value += content
+                            self._xml_tag_buffer = ""
+                        elif value_type == "number":
+                            if content:
+                                if not self._value_started:
+                                    self._value_started = True
+                                json_output += content
+                            self._current_value += content
+                            self._xml_tag_buffer = ""
+                        else:
+                            # For object/array types, output as-is
+                            if content:
+                                if not self._value_started:
+                                    self._value_started = True
+                                json_output += content
+                            self._current_value += content
+                            self._xml_tag_buffer = ""
+
+        return json_output
+
     def parse_streaming_increment(
         self, new_text: str, tools: List[Tool]
     ) -> StreamingParseResult:
diff --git a/python/sglang/srt/function_call/utils.py b/python/sglang/srt/function_call/utils.py
index d85e5e6c0307..567ca583bb8f 100644
--- a/python/sglang/srt/function_call/utils.py
+++ b/python/sglang/srt/function_call/utils.py
@@ -1,6 +1,6 @@
 from
json import JSONDecodeError, JSONDecoder from json.decoder import WHITESPACE -from typing import Any, List, Literal, Optional, Tuple, Union +from typing import Any, Dict, List, Literal, Optional, Tuple, Union import orjson import partial_json_parser @@ -101,6 +101,109 @@ def _get_tool_schema(tool: Tool) -> dict: } +def infer_type_from_json_schema(schema: Dict[str, Any]) -> Optional[str]: + """ + Infer the primary type of a parameter from JSON Schema. + + Supports complex JSON Schema structures including: + - Direct type field (including type arrays) + - anyOf/oneOf: parameter can be any of multiple types + - enum: parameter must be one of enum values + - allOf: parameter must satisfy all type definitions + - properties: inferred as object type + - items: inferred as array type + + Args: + schema: JSON Schema definition + + Returns: + Inferred type ('string', 'number', 'object', 'array', etc.) or None + """ + if not isinstance(schema, dict): + return None + + # Priority 1: Direct type field (including type arrays) + if "type" in schema: + type_value = schema["type"] + if isinstance(type_value, str): + return type_value + elif isinstance(type_value, list) and type_value: + # Handle type arrays: return first non-null type + non_null_types = [t for t in type_value if t != "null"] + if non_null_types: + return non_null_types[0] + return "string" # If only null, default to string + + # Priority 2: Handle anyOf/oneOf + if "anyOf" in schema or "oneOf" in schema: + schemas = schema.get("anyOf") or schema.get("oneOf") + types = [] + + if isinstance(schemas, list): + for sub_schema in schemas: + inferred_type = infer_type_from_json_schema(sub_schema) + if inferred_type: + types.append(inferred_type) + + if types: + # If all types are the same, return unified type + if len(set(types)) == 1: + return types[0] + # When types differ, prioritize string (safest) + if "string" in types: + return "string" + # Otherwise return first type + return types[0] + + # Priority 3: Handle enum (infer type from enum values) + if "enum" in schema and isinstance(schema["enum"], list): + if not schema["enum"]: + return "string" + + # Infer type from enum values + enum_types = set() + for value in schema["enum"]: + if value is None: + enum_types.add("null") + elif isinstance(value, bool): + enum_types.add("boolean") + elif isinstance(value, int): + enum_types.add("integer") + elif isinstance(value, float): + enum_types.add("number") + elif isinstance(value, str): + enum_types.add("string") + elif isinstance(value, list): + enum_types.add("array") + elif isinstance(value, dict): + enum_types.add("object") + + # If type is uniform, return that type + if len(enum_types) == 1: + return enum_types.pop() + # Mixed types, prioritize string + return "string" + + # Priority 4: Handle allOf (must satisfy all types) + if "allOf" in schema and isinstance(schema["allOf"], list): + schemas = schema["allOf"] + for sub_schema in schemas: + inferred_type = infer_type_from_json_schema(sub_schema) + if inferred_type and inferred_type != "string": + return inferred_type + return "string" + + # Priority 5: Infer object type + if "properties" in schema: + return "object" + + # Priority 6: Infer array type + if "items" in schema: + return "array" + + return None + + def get_json_schema_constraint( tools: List[Tool], tool_choice: Union[ToolChoice, Literal["required"]] ) -> Optional[dict]: diff --git a/test/registered/function_call/test_glm47_moe_detector.py b/test/registered/function_call/test_glm47_moe_detector.py index 
a046964064ce..fdd06e9ce7df 100644 --- a/test/registered/function_call/test_glm47_moe_detector.py +++ b/test/registered/function_call/test_glm47_moe_detector.py @@ -3,7 +3,11 @@ from sglang.srt.entrypoints.openai.protocol import Function, Tool from sglang.srt.function_call.core_types import StreamingParseResult -from sglang.srt.function_call.glm47_moe_detector import Glm47MoeDetector +from sglang.srt.function_call.glm4_moe_detector import Glm4MoeDetector +from sglang.srt.function_call.glm47_moe_detector import ( + Glm47MoeDetector, + get_argument_type, +) from sglang.test.ci.ci_register import register_cpu_ci register_cpu_ci(1.0, "default") @@ -1172,5 +1176,676 @@ def test_streamed_raw_length_multiple_empty_returns(self): ) -if __name__ == "__main__": - unittest.main() +class TestGlm4ComplexJsonSchema(unittest.TestCase): + """Test complex JSON Schema type inference for GLM function call parsers.""" + + def setUp(self): + """Set up test tools with complex JSON schemas.""" + self.tools_with_complex_schema = [ + Tool( + type="function", + function=Function( + name="search", + description="Search for information", + parameters={ + "type": "object", + "properties": { + "query": { + "description": "Search query, can be a string or a complex object", + "anyOf": [ + {"type": "string"}, + { + "type": "object", + "properties": { + "text": {"type": "string"}, + "filters": {"type": "object"}, + }, + }, + ], + }, + "priority": {"enum": ["low", "medium", "high"]}, + "options": { + "oneOf": [{"type": "string"}, {"type": "number"}] + }, + "config": { + "allOf": [ + {"type": "object"}, + {"properties": {"timeout": {"type": "number"}}}, + ] + }, + "tags": {"type": ["string", "null"]}, + "data": { + "type": "object", + "properties": { + "nested": { + "anyOf": [ + {"type": "string"}, + { + "type": "object", + "properties": { + "value": {"type": "string"} + }, + }, + ] + } + }, + }, + }, + "required": ["query"], + }, + ), + ), + Tool( + type="function", + function=Function( + name="get_weather", + description="Get weather information", + parameters={ + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "Location to get weather for", + }, + "unit": { + "type": "string", + "description": "Temperature unit", + "enum": ["celsius", "fahrenheit"], + }, + }, + "required": ["location"], + }, + ), + ), + ] + self.glm4_detector = Glm4MoeDetector() + self.glm47_detector = Glm47MoeDetector() + + def test_get_argument_type_simple_type(self): + """Test that get_argument_type correctly handles simple type fields.""" + result = get_argument_type( + "get_weather", "location", self.tools_with_complex_schema + ) + self.assertEqual(result, "string") + + def test_get_argument_type_enum_type(self): + """Test that get_argument_type correctly identifies enum as string type.""" + result = get_argument_type( + "get_weather", "unit", self.tools_with_complex_schema + ) + # Current implementation returns the direct type field, which is "string" for the enum parameter + # But it doesn't handle enum-only schemas properly (without type field) + self.assertEqual(result, "string") + + def test_get_argument_type_anyof_type(self): + """Test that get_argument_type correctly handles anyOf type fields.""" + result = get_argument_type("search", "query", self.tools_with_complex_schema) + # anyOf with [{"type": "string"}, {"type": "object", ...}] should return "string" + self.assertEqual(result, "string") # Returns first common type + + def test_get_argument_type_oneof_type(self): + """Test that get_argument_type 
+    def test_glm4_detector_with_complex_schema_anyof(self):
+        """Test GLM4 detector parsing with an anyOf schema."""
+        # A plain string value for the anyOf "query" parameter should parse as a string
+        text = (
+            "<tool_call>search\n"
+            "<arg_key>query</arg_key>\n<arg_value>Hello world</arg_value>\n"
+            "<arg_key>priority</arg_key>\n<arg_value>medium</arg_value>\n"
+            "</tool_call>"
+        )
+        result = self.glm4_detector.detect_and_parse(
+            text, self.tools_with_complex_schema
+        )
+
+        self.assertEqual(len(result.calls), 1)
+        self.assertEqual(result.calls[0].name, "search")
+
+        # Parse parameters to check that they are correctly handled
+        params = json.loads(result.calls[0].parameters)
+        self.assertEqual(params["query"], "Hello world")
+        self.assertEqual(params["priority"], "medium")
+
+    def test_glm47_detector_with_complex_schema_anyof(self):
+        """Test GLM47 detector parsing with an anyOf schema."""
+        # A plain string value for the anyOf "query" parameter should parse as a string
+        text = (
+            "<tool_call>search"
+            "<arg_key>query</arg_key><arg_value>Hello world</arg_value>"
+            "<arg_key>priority</arg_key><arg_value>medium</arg_value>"
+            "</tool_call>"
+        )
+        result = self.glm47_detector.detect_and_parse(
+            text, self.tools_with_complex_schema
+        )
+
+        self.assertEqual(len(result.calls), 1)
+        self.assertEqual(result.calls[0].name, "search")
+
+        # Parse parameters to check that they are correctly handled
+        params = json.loads(result.calls[0].parameters)
+        self.assertEqual(params["query"], "Hello world")
+        self.assertEqual(params["priority"], "medium")
+
+    def test_glm4_detector_with_enum_values(self):
+        """Test GLM4 detector with enum values in a complex schema."""
+        text = (
+            "<tool_call>search\n"
+            "<arg_key>query</arg_key>\n<arg_value>test query</arg_value>\n"
+            "<arg_key>priority</arg_key>\n<arg_value>high</arg_value>\n"
+            "</tool_call>"
+        )
+        result = self.glm4_detector.detect_and_parse(
+            text, self.tools_with_complex_schema
+        )
+
+        self.assertEqual(len(result.calls), 1)
+        self.assertEqual(result.calls[0].name, "search")
+
+        params = json.loads(result.calls[0].parameters)
+        self.assertEqual(params["query"], "test query")
+        self.assertEqual(params["priority"], "high")
+
+    def test_glm47_detector_with_enum_values(self):
+        """Test GLM47 detector with enum values in a complex schema."""
+        text = (
+            "<tool_call>search"
+            "<arg_key>query</arg_key><arg_value>test query</arg_value>"
+            "<arg_key>priority</arg_key><arg_value>high</arg_value>"
+            "</tool_call>"
+        )
+        result = self.glm47_detector.detect_and_parse(
+            text, self.tools_with_complex_schema
+        )
+
+        self.assertEqual(len(result.calls), 1)
+        self.assertEqual(result.calls[0].name, "search")
+
+        params = json.loads(result.calls[0].parameters)
+        self.assertEqual(params["query"], "test query")
+        self.assertEqual(params["priority"], "high")
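
The streaming tests below all aggregate incremental ToolCallItem chunks with the same accumulation loop. A hypothetical helper equivalent to that loop (illustrative only; the tests keep it inline):

    # Hypothetical helper; mirrors the inline loop used by the streaming tests.
    def collect_tool_calls(detector, chunks, tools):
        tool_calls = []
        for chunk in chunks:
            result = detector.parse_streaming_increment(chunk, tools)
            for tc_chunk in result.calls:
                if getattr(tc_chunk, "tool_index", None) is None:
                    continue
                while len(tool_calls) <= tc_chunk.tool_index:
                    tool_calls.append({"name": "", "parameters": ""})
                tc = tool_calls[tc_chunk.tool_index]
                if tc_chunk.name:
                    tc["name"] = tc_chunk.name
                if tc_chunk.parameters:
                    tc["parameters"] += tc_chunk.parameters
        return tool_calls
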
+    def test_glm4_detector_streaming_with_complex_schema(self):
+        """Test GLM4 detector streaming with a complex schema."""
+        chunks = [
+            "<tool_call>search\n",
+            "<arg_key>query</arg_key>\n<arg_value>nested object</arg_value>\n",
+            "<arg_key>priority</arg_key>\n<arg_value>low</arg_value>\n",
+            "</tool_call>",
+        ]
+        tool_calls = []
+        for chunk in chunks:
+            result = self.glm4_detector.parse_streaming_increment(
+                chunk, self.tools_with_complex_schema
+            )
+            for tool_call_chunk in result.calls:
+                if (
+                    hasattr(tool_call_chunk, "tool_index")
+                    and tool_call_chunk.tool_index is not None
+                ):
+                    while len(tool_calls) <= tool_call_chunk.tool_index:
+                        tool_calls.append({"name": "", "parameters": ""})
+                    tc = tool_calls[tool_call_chunk.tool_index]
+                    if tool_call_chunk.name:
+                        tc["name"] = tool_call_chunk.name
+                    if tool_call_chunk.parameters:
+                        tc["parameters"] += tool_call_chunk.parameters
+
+        self.assertEqual(len(tool_calls), 1)
+        self.assertEqual(tool_calls[0]["name"], "search")
+
+        params = json.loads(tool_calls[0]["parameters"])
+        self.assertEqual(params["query"], "nested object")
+        self.assertEqual(params["priority"], "low")
+
+    def test_glm47_detector_streaming_with_complex_schema(self):
+        """Test GLM47 detector streaming with a complex schema."""
+        chunks = [
+            "<tool_call>search",
+            "<arg_key>query</arg_key><arg_value>nested object</arg_value>",
+            "<arg_key>priority</arg_key><arg_value>low</arg_value>",
+            "</tool_call>",
+        ]
+        tool_calls = []
+        for chunk in chunks:
+            result = self.glm47_detector.parse_streaming_increment(
+                chunk, self.tools_with_complex_schema
+            )
+            for tool_call_chunk in result.calls:
+                if (
+                    hasattr(tool_call_chunk, "tool_index")
+                    and tool_call_chunk.tool_index is not None
+                ):
+                    while len(tool_calls) <= tool_call_chunk.tool_index:
+                        tool_calls.append({"name": "", "parameters": ""})
+                    tc = tool_calls[tool_call_chunk.tool_index]
+                    if tool_call_chunk.name:
+                        tc["name"] = tool_call_chunk.name
+                    if tool_call_chunk.parameters:
+                        tc["parameters"] += tool_call_chunk.parameters
+
+        self.assertEqual(len(tool_calls), 1)
+        self.assertEqual(tool_calls[0]["name"], "search")
+
+        params = json.loads(tool_calls[0]["parameters"])
+        self.assertEqual(params["query"], "nested object")
+        self.assertEqual(params["priority"], "low")
+
+    def test_type_inference_issue_reproduction(self):
+        """Regression test: complex JSON schemas (anyOf, bare enum) resolve to a type."""
+        complex_tools = [
+            Tool(
+                type="function",
+                function=Function(
+                    name="complex_function",
+                    parameters={
+                        "type": "object",
+                        "properties": {
+                            "complex_param": {
+                                "anyOf": [
+                                    {"type": "string"},
+                                    {
+                                        "type": "object",
+                                        "properties": {"value": {"type": "string"}},
+                                    },
+                                ]
+                            },
+                            "enum_param": {"enum": ["option1", "option2", "option3"]},
+                        },
+                    },
+                ),
+            )
+        ]
+
+        # Test that get_argument_type returns appropriate types for complex schemas
+        anyof_result = get_argument_type(
+            "complex_function", "complex_param", complex_tools
+        )
+        enum_result = get_argument_type("complex_function", "enum_param", complex_tools)
+
+        # Verify complex schema types are correctly inferred
+        self.assertEqual(anyof_result, "string")  # anyOf prioritizes string type
+        self.assertEqual(enum_result, "string")  # enum values are strings
+    def test_expected_behavior_for_complex_schemas(self):
+        """Verify the complex-schema behavior described in the RFC."""
+        complex_tools = [
+            Tool(
+                type="function",
+                function=Function(
+                    name="complex_function",
+                    parameters={
+                        "type": "object",
+                        "properties": {
+                            "complex_param": {
+                                "anyOf": [
+                                    {"type": "string"},
+                                    {
+                                        "type": "object",
+                                        "properties": {"value": {"type": "string"}},
+                                    },
+                                ]
+                            },
+                            "enum_param": {"enum": ["option1", "option2", "option3"]},
+                            "oneof_param": {
+                                "oneOf": [{"type": "string"}, {"type": "number"}]
+                            },
+                            "allof_param": {
+                                "allOf": [
+                                    {"type": "object"},
+                                    {"properties": {"timeout": {"type": "number"}}},
+                                ]
+                            },
+                        },
+                    },
+                ),
+            )
+        ]
+
+        # These assertions encode the behavior required by the RFC: every
+        # complex schema construct must resolve to a concrete type.
+        anyof_result = get_argument_type(
+            "complex_function", "complex_param", complex_tools
+        )
+        enum_result = get_argument_type("complex_function", "enum_param", complex_tools)
+        oneof_result = get_argument_type(
+            "complex_function", "oneof_param", complex_tools
+        )
+        allof_result = get_argument_type(
+            "complex_function", "allof_param", complex_tools
+        )
+
+        self.assertIsNotNone(
+            anyof_result, "anyOf should resolve to a concrete type"
+        )
+        self.assertEqual(
+            enum_result,
+            "string",
+            "enum should resolve to the 'string' type",
+        )
+        self.assertIsNotNone(
+            oneof_result, "oneOf should resolve to a concrete type"
+        )
+        self.assertIsNotNone(
+            allof_result, "allOf should resolve to a concrete type"
+        )
+
+    def test_complex_schema_type_inference_scenarios(self):
+        """Test the complex schema scenarios mentioned in the RFC."""
+        # Create tools with different complex schema structures
+        complex_schema_tools = [
+            Tool(
+                type="function",
+                function=Function(
+                    name="search_complex",
+                    parameters={
+                        "type": "object",
+                        "properties": {
+                            # anyOf example - parameter can be string or object
+                            "query": {
+                                "description": "Search query, can be a string or a complex object",
+                                "anyOf": [
+                                    {"type": "string"},
+                                    {
+                                        "type": "object",
+                                        "properties": {
+                                            "text": {"type": "string"},
+                                            "filters": {"type": "object"},
+                                        },
+                                    },
+                                ],
+                            },
+                            # oneOf example - parameter must be one of the specified types
+                            "priority": {
+                                "oneOf": [{"type": "string"}, {"type": "integer"}]
+                            },
+                            # enum example - parameter must be one of the enum values
+                            "category": {"enum": ["news", "sports", "tech"]},
+                            # allOf example - parameter must satisfy all schemas
+                            "config": {
+                                "allOf": [
+                                    {"type": "object"},
+                                    {"properties": {"timeout": {"type": "number"}}},
+                                ]
+                            },
+                            # Type array example
+                            "tags": {"type": ["string", "null"]},
+                        },
+                    },
+                ),
+            ),
+            Tool(
+                type="function",
+                function=Function(
+                    name="get_data",
+                    parameters={
+                        "type": "object",
+                        "properties": {
+                            # Complex nested anyOf
+                            "input": {
+                                "anyOf": [
+                                    {"type": "string"},
+                                    {"type": "number"},
+                                    {
+                                        "type": "object",
+                                        "properties": {
+                                            "type": {"type": "string"},
+                                            "value": {},
+                                        },
+                                    },
+                                ]
+                            }
+                        },
+                    },
+                ),
+            ),
+        ]
+
+        # Test each complex type scenario
+        query_type = get_argument_type("search_complex", "query", complex_schema_tools)
+        priority_type = get_argument_type(
+            "search_complex", "priority", complex_schema_tools
+        )
+        category_type = get_argument_type(
+            "search_complex", "category", complex_schema_tools
+        )
+        config_type = get_argument_type(
+            "search_complex", "config", complex_schema_tools
+        )
+        tags_type = get_argument_type("search_complex", "tags", complex_schema_tools)
+        input_type = get_argument_type("get_data", "input", complex_schema_tools)
+
+        # All of these should return appropriate types according to the RFC
+        self.assertEqual(query_type, "string")  # anyOf: string | object -> string
+        self.assertEqual(priority_type, "string")  # oneOf: string | integer -> string
+        self.assertEqual(
+            category_type, "string"
+        )  # enum: ["news", "sports", "tech"] -> string
+        self.assertEqual(config_type, "object")  # allOf with object -> object
+        self.assertEqual(
+            tags_type, "string"
+        )  # type array: ["string", "null"] -> string
+        self.assertEqual(
+            input_type, "string"
+        )  # nested anyOf: string | number | object -> string
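
The detector tests below exercise two wire formats. For readability, the two shapes side by side, as reconstructed from this test file (GLM-4.5/4.6 separates fields with newlines; GLM-4.7 uses the same tags without newline separators):

    # GLM-4.5 / GLM-4.6 (Glm4MoeDetector): newline-separated fields.
    glm45_call = (
        "<tool_call>get_weather\n"
        "<arg_key>location</arg_key>\n<arg_value>Beijing</arg_value>\n"
        "</tool_call>"
    )
    # GLM-4.7 (Glm47MoeDetector): same tags, no newline separators.
    glm47_call = (
        "<tool_call>get_weather"
        "<arg_key>location</arg_key><arg_value>Beijing</arg_value>"
        "</tool_call>"
    )
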
["string", "null"] -> string + self.assertEqual( + input_type, "string" + ) # nested anyOf: string | number | object -> string + + def test_glm4_detector_type_handling_with_complex_schema(self): + """Test how GLM4 detector handles type inference for complex schemas in practice.""" + complex_tools = [ + Tool( + type="function", + function=Function( + name="complex_search", + parameters={ + "type": "object", + "properties": { + "query": { + "anyOf": [ + {"type": "string"}, + { + "type": "object", + "properties": {"text": {"type": "string"}}, + }, + ] + }, + "category": {"enum": ["tech", "news", "sports"]}, + }, + }, + ), + ) + ] + + # Test with string value for anyOf parameter + text = ( + "complex_search\n" + "query\ntest search\n" + "category\ntech\n" + "" + ) + result = self.glm4_detector.detect_and_parse(text, complex_tools) + + self.assertEqual(len(result.calls), 1) + self.assertEqual(result.calls[0].name, "complex_search") + + params = json.loads(result.calls[0].parameters) + self.assertEqual(params["query"], "test search") + self.assertEqual(params["category"], "tech") + + def test_glm47_detector_type_handling_with_complex_schema(self): + """Test how GLM47 detector handles type inference for complex schemas in practice.""" + complex_tools = [ + Tool( + type="function", + function=Function( + name="complex_search", + parameters={ + "type": "object", + "properties": { + "query": { + "anyOf": [ + {"type": "string"}, + { + "type": "object", + "properties": {"text": {"type": "string"}}, + }, + ] + }, + "category": {"enum": ["tech", "news", "sports"]}, + }, + }, + ), + ) + ] + + # Test with string value for anyOf parameter + text = ( + "complex_search" + "querytest search" + "categorytech" + "" + ) + result = self.glm47_detector.detect_and_parse(text, complex_tools) + + self.assertEqual(len(result.calls), 1) + self.assertEqual(result.calls[0].name, "complex_search") + + params = json.loads(result.calls[0].parameters) + self.assertEqual(params["query"], "test search") + self.assertEqual(params["category"], "tech") + + def test_streaming_with_complex_schema_type_inference(self): + """Test streaming behavior with complex schema type inference.""" + complex_tools = [ + Tool( + type="function", + function=Function( + name="stream_test", + parameters={ + "type": "object", + "properties": { + "data": { + "anyOf": [ + {"type": "string"}, + { + "type": "object", + "properties": {"value": {"type": "string"}}, + }, + ] + }, + "status": {"enum": ["active", "inactive"]}, + }, + }, + ), + ) + ] + + # Test GLM4 detector streaming + chunks = [ + "stream_test\n", + "data\nnested data\n", + "status\nactive\n", + "", + ] + tool_calls = [] + for chunk in chunks: + result = self.glm4_detector.parse_streaming_increment(chunk, complex_tools) + for tool_call_chunk in result.calls: + if ( + hasattr(tool_call_chunk, "tool_index") + and tool_call_chunk.tool_index is not None + ): + while len(tool_calls) <= tool_call_chunk.tool_index: + tool_calls.append({"name": "", "parameters": ""}) + tc = tool_calls[tool_call_chunk.tool_index] + if tool_call_chunk.name: + tc["name"] = tool_call_chunk.name + if tool_call_chunk.parameters: + tc["parameters"] += tool_call_chunk.parameters + + self.assertEqual(len(tool_calls), 1) + self.assertEqual(tool_calls[0]["name"], "stream_test") + + params = json.loads(tool_calls[0]["parameters"]) + self.assertEqual(params["data"], "nested data") + self.assertEqual(params["status"], "active") + + def test_streaming_with_complex_schema_type_inference_glm47(self): + """Test GLM47 streaming 
+    def test_streaming_with_complex_schema_type_inference_glm47(self):
+        """Test GLM47 streaming behavior with complex schema type inference."""
+        complex_tools = [
+            Tool(
+                type="function",
+                function=Function(
+                    name="stream_test",
+                    parameters={
+                        "type": "object",
+                        "properties": {
+                            "data": {
+                                "anyOf": [
+                                    {"type": "string"},
+                                    {
+                                        "type": "object",
+                                        "properties": {"value": {"type": "string"}},
+                                    },
+                                ]
+                            },
+                            "status": {"enum": ["active", "inactive"]},
+                        },
+                    },
+                ),
+            )
+        ]
+
+        # Test GLM47 detector streaming
+        chunks = [
+            "<tool_call>stream_test",
+            "<arg_key>data</arg_key><arg_value>nested data</arg_value>",
+            "<arg_key>status</arg_key><arg_value>active</arg_value>",
+            "</tool_call>",
+        ]
+        tool_calls = []
+        for chunk in chunks:
+            result = self.glm47_detector.parse_streaming_increment(chunk, complex_tools)
+            for tool_call_chunk in result.calls:
+                if (
+                    hasattr(tool_call_chunk, "tool_index")
+                    and tool_call_chunk.tool_index is not None
+                ):
+                    while len(tool_calls) <= tool_call_chunk.tool_index:
+                        tool_calls.append({"name": "", "parameters": ""})
+                    tc = tool_calls[tool_call_chunk.tool_index]
+                    if tool_call_chunk.name:
+                        tc["name"] = tool_call_chunk.name
+                    if tool_call_chunk.parameters:
+                        tc["parameters"] += tool_call_chunk.parameters
+
+        self.assertEqual(len(tool_calls), 1)
+        self.assertEqual(tool_calls[0]["name"], "stream_test")
+
+        params = json.loads(tool_calls[0]["parameters"])
+        self.assertEqual(params["data"], "nested data")
+        self.assertEqual(params["status"], "active")
+
+
+if __name__ == "__main__":
+    unittest.main()
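
Beyond the string-valued cases covered above, the point of the type inference is to serialize non-string arguments correctly. A sketch of the intended behavior, assuming a number-typed parameter is emitted as a JSON number (as suggested by the detector's parse_arguments / _convert_to_number helpers); the set_timeout tool is hypothetical:

    # Hypothetical sketch; assumes number-typed arg_values serialize as JSON numbers.
    import json

    from sglang.srt.entrypoints.openai.protocol import Function, Tool
    from sglang.srt.function_call.glm47_moe_detector import Glm47MoeDetector

    tools = [
        Tool(
            type="function",
            function=Function(
                name="set_timeout",
                parameters={
                    "type": "object",
                    "properties": {"seconds": {"type": "number"}},
                },
            ),
        )
    ]
    text = (
        "<tool_call>set_timeout"
        "<arg_key>seconds</arg_key><arg_value>1.5</arg_value>"
        "</tool_call>"
    )
    result = Glm47MoeDetector().detect_and_parse(text, tools)
    assert json.loads(result.calls[0].parameters) == {"seconds": 1.5}
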