From 2da959604d72eced8884f76a9f173aa64f746488 Mon Sep 17 00:00:00 2001
From: Yuxuan Zhang <2448370773@qq.com>
Date: Sat, 20 Dec 2025 12:30:44 +0800
Subject: [PATCH 1/3] [GLM-4.7] GLM-4.7 Tool Parser and Doc Update (#15333)
---
docs/advanced_features/server_arguments.md | 2 +-
docs/basic_usage/glm45.md | 6 +-
.../srt/function_call/function_call_parser.py | 2 +
.../srt/function_call/glm47_moe_detector.py | 584 ++++++++++++++++++
.../srt/function_call/glm4_moe_detector.py | 7 +-
python/sglang/srt/models/glm4_moe.py | 2 +-
.../test_function_call_parser.py | 275 +++++++++
7 files changed, 872 insertions(+), 6 deletions(-)
create mode 100644 python/sglang/srt/function_call/glm47_moe_detector.py
diff --git a/docs/advanced_features/server_arguments.md b/docs/advanced_features/server_arguments.md
index e90b598d932e..19c2ecb9dfe0 100644
--- a/docs/advanced_features/server_arguments.md
+++ b/docs/advanced_features/server_arguments.md
@@ -198,7 +198,7 @@ Please consult the documentation below and [server_args.py](https://github.com/s
| `--file-storage-path` | The path of the file storage in backend. | `sglang_storage` | Type: str |
| `--enable-cache-report` | Return number of cached tokens in usage.prompt_tokens_details for each openai request. | `False` | bool flag (set to enable) |
| `--reasoning-parser` | Specify the parser for reasoning models. Supported parsers: [deepseek-r1, deepseek-v3, glm45, gpt-oss, kimi, qwen3, qwen3-thinking, step3]. | `None` | `deepseek-r1`, `deepseek-v3`, `glm45`, `gpt-oss`, `kimi`, `qwen3`, `qwen3-thinking`, `step3` |
-| `--tool-call-parser` | Specify the parser for handling tool-call interactions. Supported parsers: [deepseekv3, deepseekv31, glm, glm45, gpt-oss, kimi_k2, llama3, mistral, pythonic, qwen, qwen25, qwen3_coder, step3]. | `None` | `deepseekv3`, `deepseekv31`, `glm`, `glm45`, `gpt-oss`, `kimi_k2`, `llama3`, `mistral`, `pythonic`, `qwen`, `qwen25`, `qwen3_coder`, `step3` |
+| `--tool-call-parser` | Specify the parser for handling tool-call interactions. Supported parsers: [deepseekv3, deepseekv31, glm, glm45, glm47, gpt-oss, kimi_k2, llama3, mistral, pythonic, qwen, qwen25, qwen3_coder, step3]. | `None` | `deepseekv3`, `deepseekv31`, `glm`, `glm45`, `glm47`, `gpt-oss`, `kimi_k2`, `llama3`, `mistral`, `pythonic`, `qwen`, `qwen25`, `qwen3_coder`, `step3` |
| `--sampling-defaults` | Where to get default sampling parameters. 'openai' uses SGLang/OpenAI defaults (temperature=1.0, top_p=1.0, etc.). 'model' uses the model's generation_config.json to get the recommended sampling parameters if available. Default is 'model'. | `model` | `openai`, `model` |
| `--tool-server` | Either 'demo' or a comma-separated list of tool server urls to use for the model. If not specified, no tool server will be used. | `None` | Type: str |
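
Once a server is launched with `--tool-call-parser glm47`, the parser's output surfaces as structured `tool_calls` in the OpenAI-compatible API. Below is a minimal client sketch; the port and `model` name are placeholders for illustration, not part of this patch.

```python
from openai import OpenAI

# Sketch only: assumes an SGLang server already running locally with
# --tool-call-parser glm47; base URL and model name are placeholders.
client = OpenAI(base_url="http://localhost:30000/v1", api_key="EMPTY")

tools = [{
    "type": "function",
    "function": {
        "name": "get_weather",
        "description": "Get weather information",
        "parameters": {
            "type": "object",
            "properties": {
                "city": {"type": "string", "description": "City name"},
                "date": {"type": "string", "description": "Date"},
            },
            "required": ["city", "date"],
        },
    },
}]

resp = client.chat.completions.create(
    model="default",  # placeholder served-model name
    messages=[{"role": "user", "content": "Weather in Beijing on 2024-06-27?"}],
    tools=tools,
)
print(resp.choices[0].message.tool_calls)
```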
diff --git a/docs/basic_usage/glm45.md b/docs/basic_usage/glm45.md
index d18b0a68d335..b68f984f90da 100644
--- a/docs/basic_usage/glm45.md
+++ b/docs/basic_usage/glm45.md
@@ -1,4 +1,4 @@
-## Launch GLM-4.5 / GLM-4.6 with SGLang
+## Launch GLM-4.5 / GLM-4.6 / GLM-4.7 with SGLang
To serve GLM-4.5 / GLM-4.6 FP8 models on 8xH100/H200 GPUs:
@@ -35,7 +35,9 @@ python3 -m sglang.launch_server \
--enable-custom-logit-processor
```
-### Thinking Budget for GLM-4.5 / GLM-4.6
+**Note**: For GLM-4.7, set `--tool-call-parser` to `glm47`; for GLM-4.5 and GLM-4.6, set it to `glm45` (see the launch sketch below).
+
+### Thinking Budget
In SGLang, we can implement a thinking budget with `CustomLogitProcessor`.
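
Illustrating the parser note above, a minimal launch sketch (the model path is a placeholder; add parallelism flags to suit your hardware):

```python
import subprocess

# Minimal launch sketch: GLM-4.7 pairs the new glm47 tool-call parser with
# the existing glm45 reasoning parser. The model path below is a placeholder.
subprocess.run(
    [
        "python3", "-m", "sglang.launch_server",
        "--model-path", "zai-org/GLM-4.7",  # placeholder model id
        "--tool-call-parser", "glm47",
        "--reasoning-parser", "glm45",
    ],
    check=True,
)
```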
diff --git a/python/sglang/srt/function_call/function_call_parser.py b/python/sglang/srt/function_call/function_call_parser.py
index 3700ea067af4..0592faf81832 100644
--- a/python/sglang/srt/function_call/function_call_parser.py
+++ b/python/sglang/srt/function_call/function_call_parser.py
@@ -15,6 +15,7 @@
from sglang.srt.function_call.deepseekv31_detector import DeepSeekV31Detector
from sglang.srt.function_call.deepseekv32_detector import DeepSeekV32Detector
from sglang.srt.function_call.glm4_moe_detector import Glm4MoeDetector
+from sglang.srt.function_call.glm47_moe_detector import Glm47MoeDetector
from sglang.srt.function_call.gpt_oss_detector import GptOssDetector
from sglang.srt.function_call.kimik2_detector import KimiK2Detector
from sglang.srt.function_call.llama32_detector import Llama32Detector
@@ -45,6 +46,7 @@ class FunctionCallParser:
"deepseekv32": DeepSeekV32Detector,
"glm": Glm4MoeDetector,
"glm45": Glm4MoeDetector,
+ "glm47": Glm47MoeDetector,
"gpt-oss": GptOssDetector,
"longcat": LongCatDetector,
"kimi_k2": KimiK2Detector,
diff --git a/python/sglang/srt/function_call/glm47_moe_detector.py b/python/sglang/srt/function_call/glm47_moe_detector.py
new file mode 100644
index 000000000000..726737c5788a
--- /dev/null
+++ b/python/sglang/srt/function_call/glm47_moe_detector.py
@@ -0,0 +1,584 @@
+import ast
+import json
+import logging
+import re
+from enum import Enum
+from typing import Any, Dict, List, Optional, Tuple
+
+from sglang.srt.entrypoints.openai.protocol import Tool
+from sglang.srt.function_call.base_format_detector import BaseFormatDetector
+from sglang.srt.function_call.core_types import (
+ StreamingParseResult,
+ ToolCallItem,
+ _GetInfoFunc,
+)
+
+logger = logging.getLogger(__name__)
+
+
+class StreamState(str, Enum):
+ """State machine states for XML to JSON streaming conversion."""
+
+ INIT = "INIT"
+ BETWEEN = "BETWEEN"
+ IN_KEY = "IN_KEY"
+ WAITING_VALUE = "WAITING_VALUE"
+ IN_VALUE = "IN_VALUE"
+
+
+def get_argument_type(
+ func_name: str, arg_key: str, defined_tools: List[Tool]
+) -> Optional[str]:
+ """Get the expected type of a function argument from tool definitions.
+
+ Args:
+ func_name: Name of the function/tool
+ arg_key: Name of the argument
+ defined_tools: List of available tools
+
+ Returns:
+ The type string (e.g., 'string', 'number', 'object') or None if not found
+ """
+ name2tool = {tool.function.name: tool for tool in defined_tools}
+ if func_name not in name2tool:
+ return None
+ tool = name2tool[func_name]
+ properties = (tool.function.parameters or {}).get("properties", {})
+ if not isinstance(properties, dict):
+ properties = {}
+ if arg_key not in properties:
+ return None
+ return properties[arg_key].get("type", None)
+
+
+def _convert_to_number(value: str) -> Any:
+ """Convert string to appropriate number type (int or float).
+
+ Args:
+ value: String value to convert
+
+ Returns:
+ Converted number or original string if conversion fails
+ """
+ try:
+ if "." in value or "e" in value.lower():
+ return float(value)
+ else:
+ return int(value)
+ except (ValueError, AttributeError):
+ return value
+
+
+def parse_arguments(
+ json_value: str, arg_type: Optional[str] = None
+) -> Tuple[Any, bool]:
+ """Parse argument value with multiple fallback strategies.
+
+ Args:
+ json_value: Raw string value to parse
+ arg_type: Expected type hint ('string', 'number', 'object', etc.)
+
+ Returns:
+ Tuple of (parsed_value, is_valid_json)
+ """
+ # Strategy 1: Direct JSON parsing
+ try:
+ parsed_value = json.loads(json_value)
+
+ # Type coercion for number type
+ if arg_type == "number" and isinstance(parsed_value, str):
+ parsed_value = _convert_to_number(parsed_value)
+
+ return parsed_value, True
+ except (json.JSONDecodeError, ValueError):
+ pass
+
+ # Strategy 2: Unescape and parse
+ try:
+ wrapped = json.loads('{"tmp": "' + json_value + '"}')
+ parsed_value = json.loads(wrapped["tmp"])
+
+ if arg_type == "number" and isinstance(parsed_value, str):
+ parsed_value = _convert_to_number(parsed_value)
+
+ return parsed_value, True
+ except (json.JSONDecodeError, ValueError, KeyError):
+ pass
+
+ # Strategy 3: ast.literal_eval
+ try:
+ parsed_value = ast.literal_eval(json_value)
+ return parsed_value, True
+ except (ValueError, SyntaxError):
+ pass
+
+ # Strategy 4: Treat as string
+ try:
+ quoted_value = json.dumps(str(json_value))
+ return json.loads(quoted_value), True
+ except (json.JSONDecodeError, ValueError):
+ return json_value, False
+
+
+class Glm47MoeDetector(BaseFormatDetector):
+ """
+ Detector for GLM-4.7 and GLM-5 models.
+ Assumes function call format:
+ <tool_call>get_weather<arg_key>city</arg_key><arg_value>北京</arg_value><arg_key>date</arg_key><arg_value>2024-06-27</arg_value></tool_call><tool_call>get_weather<arg_key>city</arg_key><arg_value>上海</arg_value><arg_key>date</arg_key><arg_value>2024-06-27</arg_value></tool_call>
+ """
+
+ def __init__(self):
+ super().__init__()
+ self.bot_token = ""
+ self.eot_token = ""
+ self.func_call_regex = r".*?"
+ self.func_detail_regex = re.compile(
+ r"(.*?)(.*?)?", re.DOTALL
+ )
+ self.func_arg_regex = re.compile(
+ r"(.*?)(?:\\n|\s)*(.*?)",
+ re.DOTALL,
+ )
+ self._last_arguments = ""
+ self.current_tool_id = -1
+ self.current_tool_name_sent = False
+ self._streamed_raw_length = 0
+ self._reset_streaming_state()
+
+ def _reset_streaming_state(self) -> None:
+ """Reset the streaming state machine for a new tool call."""
+ self._stream_state = StreamState.INIT
+ self._current_key = ""
+ self._current_value = ""
+ self._xml_tag_buffer = ""
+ self._is_first_param = True
+ self._value_started = False
+ self._cached_value_type: Optional[str] = (
+ None # Cache the value type for consistency
+ )
+
+ def has_tool_call(self, text: str) -> bool:
+ """Check if the text contains a glm-4.5 / glm-4.6 format tool call."""
+ return self.bot_token in text
+
+ def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult:
+ """
+ One-time parsing: Detects and parses tool calls in the provided text.
+
+ :param text: The complete text to parse.
+ :param tools: List of available tools.
+ :return: ParseResult indicating success or failure, consumed text, leftover text, and parsed calls.
+ """
+ idx = text.find(self.bot_token)
+ normal_text = text[:idx].strip() if idx != -1 else text
+ if self.bot_token not in text:
+ return StreamingParseResult(normal_text=normal_text, calls=[])
+ match_result_list = re.findall(self.func_call_regex, text, re.DOTALL)
+ calls = []
+ try:
+ for match_result in match_result_list:
+ # Get function name
+ func_detail = self.func_detail_regex.search(match_result)
+ func_name = func_detail.group(1)
+ func_args = func_detail.group(2)
+ arguments = {}
+ if func_args:
+ pairs = self.func_arg_regex.findall(func_args)
+ # Parse arguments using shared method
+ arguments = self._parse_argument_pairs(pairs, func_name, tools)
+
+ # construct match_result for parse_base_json
+ match_result = {"name": func_name, "parameters": arguments}
+ calls.extend(self.parse_base_json(match_result, tools))
+ return StreamingParseResult(normal_text=normal_text, calls=calls)
+ except Exception as e:
+ logger.error(f"Error in detect_and_parse: {e}", exc_info=True)
+ # return the normal text if parsing fails
+ return StreamingParseResult(normal_text=text)
+
+ def _get_value_type(self, func_name: str, key: str, tools: List[Tool]) -> str:
+ """Get parameter type from tool definition, with fallback to auto-detection.
+
+ Args:
+ func_name: Name of the function
+ key: Parameter name
+ tools: List of available tools
+
+ Returns:
+ Type string: 'string', 'number', or 'object'
+ """
+ arg_type = get_argument_type(func_name, key, tools)
+ if arg_type:
+ return arg_type
+
+ # Auto-detect type from value (best effort)
+ first_chars = self._current_value.strip()[:10] if self._current_value else ""
+ if first_chars:
+ first_char = first_chars[0]
+ if first_char.isdigit() or first_char in ["-", "."]:
+ return "number"
+ elif first_char in ["{", "["]:
+ return "object"
+
+ return "string"
+
+ def _format_value_complete(self, value: str, value_type: str) -> str:
+ """Format complete value based on type.
+
+ Args:
+ value: Raw value string
+ value_type: Expected type ('string', 'number', 'object')
+
+ Returns:
+ Properly formatted JSON value string
+ """
+ if value_type == "string":
+ # Ensure proper JSON string formatting with quotes
+ return json.dumps(value, ensure_ascii=False)
+ elif value_type == "number":
+ try:
+ num = _convert_to_number(value.strip())
+ return str(num)
+ except (ValueError, AttributeError):
+ # Fallback to string if not a valid number
+ logger.warning(
+ f"Failed to parse '{value}' as number, treating as string"
+ )
+ return json.dumps(str(value), ensure_ascii=False)
+ else:
+ # For object/array types, return as-is (should already be valid JSON)
+ return value
+
+ def _process_xml_to_json_streaming(
+ self, raw_increment: str, func_name: str, tools: List[Tool]
+ ) -> str:
+ """Convert XML increment to JSON streaming output using state machine.
+
+ This method processes XML fragments character by character and converts them
+ to JSON format incrementally. It maintains state across calls to handle
+ partial XML tags and values.
+
+ Args:
+ raw_increment: New XML content to process
+ func_name: Name of the function being called
+ tools: List of available tools for type inference
+
+ Returns:
+ JSON string increment to append to the output
+ """
+ json_output = ""
+
+ for char in raw_increment:
+ self._xml_tag_buffer += char
+
+ if self._stream_state in [StreamState.INIT, StreamState.BETWEEN]:
+ if self._xml_tag_buffer.endswith(""):
+ self._stream_state = StreamState.IN_KEY
+ self._current_key = ""
+ self._xml_tag_buffer = ""
+ json_output += "{" if self._is_first_param else ", "
+ self._is_first_param = False
+
+ elif self._stream_state == StreamState.IN_KEY:
+ if self._xml_tag_buffer.endswith(""):
+ self._current_key = self._xml_tag_buffer[:-10].strip()
+ self._xml_tag_buffer = ""
+ self._stream_state = StreamState.WAITING_VALUE
+ json_output += (
+ json.dumps(self._current_key, ensure_ascii=False) + ": "
+ )
+
+ elif self._stream_state == StreamState.WAITING_VALUE:
+ if self._xml_tag_buffer.endswith(""):
+ self._stream_state = StreamState.IN_VALUE
+ self._current_value = ""
+ self._xml_tag_buffer = ""
+ self._value_started = False
+ # Determine and cache the value type at the start
+ self._cached_value_type = self._get_value_type(
+ func_name, self._current_key, tools
+ )
+
+ elif self._stream_state == StreamState.IN_VALUE:
+ if self._xml_tag_buffer.endswith(""):
+ final_value = self._xml_tag_buffer[:-12]
+ self._current_value += final_value
+
+ # Use cached value type for consistency
+ value_type = self._cached_value_type or "string"
+
+ if self._value_started:
+ # Output any remaining content
+ if final_value:
+ if value_type == "string":
+ json_output += json.dumps(
+ final_value, ensure_ascii=False
+ )[1:-1]
+ else:
+ json_output += final_value
+ # Always output closing quote for string type when value was started
+ if value_type == "string":
+ json_output += '"'
+ else:
+ # Value was never started (empty or complete in one chunk)
+ json_output += self._format_value_complete(
+ self._current_value, value_type
+ )
+
+ self._xml_tag_buffer = ""
+ self._stream_state = StreamState.BETWEEN
+ self._current_value = ""
+ self._value_started = False
+ self._cached_value_type = None # Reset cached type
+ else:
+ closing_tag = ""
+ is_potential_closing = len(self._xml_tag_buffer) <= len(
+ closing_tag
+ ) and closing_tag.startswith(self._xml_tag_buffer)
+
+ if not is_potential_closing:
+ content = self._xml_tag_buffer
+ # Use cached value type for consistency
+ value_type = self._cached_value_type or "string"
+
+ if value_type == "string":
+ if not self._value_started:
+ json_output += '"'
+ self._value_started = True
+ if content:
+ json_output += json.dumps(content, ensure_ascii=False)[
+ 1:-1
+ ]
+ self._current_value += content
+ self._xml_tag_buffer = ""
+ elif value_type == "number":
+ if content:
+ if not self._value_started:
+ self._value_started = True
+ json_output += content
+ self._current_value += content
+ self._xml_tag_buffer = ""
+ else:
+ # For object/array types, output as-is
+ if content:
+ if not self._value_started:
+ self._value_started = True
+ json_output += content
+ self._current_value += content
+ self._xml_tag_buffer = ""
+
+ return json_output
+
+ def parse_streaming_increment(
+ self, new_text: str, tools: List[Tool]
+ ) -> StreamingParseResult:
+ """
+ Streaming incremental parsing of tool calls in the GLM-4.7 format.
+ Uses a state machine to convert XML to JSON incrementally for true character-by-character streaming.
+ Outputs JSON increments immediately as XML data arrives.
+ """
+ self._buffer += new_text
+ current_text = self._buffer
+
+ # Check if we have a tool call
+ has_tool_call = self.bot_token in current_text
+
+ if not has_tool_call:
+ # Check if buffer could be the start of a tool call
+ # Keep buffer if it could be a partial match of bot_token
+ is_potential_start = any(
+ self.bot_token.startswith(current_text[-i:])
+ for i in range(1, min(len(current_text), len(self.bot_token)) + 1)
+ )
+
+ if not is_potential_start:
+ # Not a potential tool call, return as normal text
+ # Must return the entire buffer (current_text), not just new_text,
+ # because buffer may contain previously accumulated characters like '<'
+ # that turned out not to be part of a tool call
+ output_text = current_text
+ self._buffer = ""
+ if self.eot_token in output_text:
+ output_text = output_text.replace(self.eot_token, "")
+ return StreamingParseResult(normal_text=output_text)
+ else:
+ # Could be start of tool call, keep buffering
+ return StreamingParseResult(normal_text="", calls=[])
+
+ if not hasattr(self, "_tool_indices"):
+ self._tool_indices = self._get_tool_indices(tools)
+
+ calls: list[ToolCallItem] = []
+ try:
+ # Try to match a partial or complete tool call
+ partial_match = re.search(
+ pattern=r"(.*?)(.*?)?(|$)",
+ string=current_text,
+ flags=re.DOTALL,
+ )
+ if partial_match:
+ func_name = partial_match.group(1).strip()
+ func_args_raw = partial_match.group(2).strip()
+ is_tool_end = partial_match.group(3)
+
+ # Initialize state if this is the first tool call
+ if self.current_tool_id == -1:
+ self.current_tool_id = 0
+ self.prev_tool_call_arr = []
+ self.streamed_args_for_tool = [""]
+ self._streamed_raw_length = 0
+ self.current_tool_name_sent = False
+ self._reset_streaming_state()
+
+ # Ensure we have enough entries in our tracking arrays
+ while len(self.prev_tool_call_arr) <= self.current_tool_id:
+ self.prev_tool_call_arr.append({})
+ while len(self.streamed_args_for_tool) <= self.current_tool_id:
+ self.streamed_args_for_tool.append("")
+
+ # Send tool name first if not sent yet
+ if not self.current_tool_name_sent:
+ assert func_name, "func_name should not be empty"
+ calls.append(
+ ToolCallItem(
+ tool_index=self.current_tool_id,
+ name=func_name,
+ parameters="",
+ )
+ )
+ self.current_tool_name_sent = True
+ self._streamed_raw_length = 0
+ self._reset_streaming_state()
+ # Store the tool call info
+ self.prev_tool_call_arr[self.current_tool_id] = {
+ "name": func_name,
+ "arguments": {},
+ }
+ else:
+ # Process XML to JSON streaming
+ current_raw_length = len(func_args_raw)
+
+ if current_raw_length > self._streamed_raw_length:
+ # Get the new raw XML content
+ raw_increment = func_args_raw[self._streamed_raw_length :]
+
+ # Convert XML increment to JSON increment using state machine
+ json_increment = self._process_xml_to_json_streaming(
+ raw_increment, func_name, tools
+ )
+
+ if json_increment:
+ calls.append(
+ ToolCallItem(
+ tool_index=self.current_tool_id,
+ name=None,
+ parameters=json_increment,
+ )
+ )
+ self._last_arguments += json_increment
+ self.streamed_args_for_tool[
+ self.current_tool_id
+ ] += json_increment
+
+ # Update the streamed length
+ self._streamed_raw_length = current_raw_length
+
+ if is_tool_end == self.eot_token:
+ if self._is_first_param:
+ empty_object = "{}"
+ calls.append(
+ ToolCallItem(
+ tool_index=self.current_tool_id,
+ name=None,
+ parameters=empty_object,
+ )
+ )
+ self._last_arguments += empty_object
+ elif not self._last_arguments.endswith("}"):
+ closing_brace = "}"
+ calls.append(
+ ToolCallItem(
+ tool_index=self.current_tool_id,
+ name=None,
+ parameters=closing_brace,
+ )
+ )
+ self._last_arguments += closing_brace
+ self.streamed_args_for_tool[
+ self.current_tool_id
+ ] += closing_brace
+
+ try:
+ pairs = self.func_arg_regex.findall(func_args_raw)
+ if pairs:
+ arguments = self._parse_argument_pairs(
+ pairs, func_name, tools
+ )
+ self.prev_tool_call_arr[self.current_tool_id][
+ "arguments"
+ ] = arguments
+ except Exception as e:
+ logger.debug(
+ f"Failed to parse arguments: {e}", exc_info=True
+ )
+
+ # Remove the completed tool call from buffer
+ self._buffer = current_text[partial_match.end(3) :]
+
+ result = StreamingParseResult(normal_text="", calls=calls)
+ self.current_tool_id += 1
+ self._last_arguments = ""
+ self.current_tool_name_sent = False
+ self._streamed_raw_length = 0
+ self._reset_streaming_state()
+ return result
+
+ return StreamingParseResult(normal_text="", calls=calls)
+
+ except Exception as e:
+ logger.error(f"Error in parse_streaming_increment: {e}", exc_info=True)
+ return StreamingParseResult(normal_text=current_text)
+
+ def _parse_argument_pairs(
+ self, pairs: List[Tuple[str, str]], func_name: str, tools: List[Tool]
+ ) -> Dict[str, Any]:
+ """Parse argument key-value pairs with type coercion.
+
+ Args:
+ pairs: List of (key, value) tuples from regex matching
+ func_name: Name of the function
+ tools: List of available tools
+
+ Returns:
+ Dictionary of parsed arguments
+ """
+ arguments = {}
+ for arg_key, arg_value in pairs:
+ arg_key = arg_key.strip()
+ arg_value = arg_value.strip()
+ arg_type = get_argument_type(func_name, arg_key, tools)
+ parsed_value, is_good_json = parse_arguments(arg_value, arg_type)
+
+ if arg_type == "string":
+ # Only convert to string if explicitly defined as string type
+ if isinstance(parsed_value, str):
+ arguments[arg_key] = parsed_value
+ elif isinstance(parsed_value, (dict, list)):
+ # If parsed as dict/list but schema says string, convert to JSON string
+ arguments[arg_key] = json.dumps(parsed_value, ensure_ascii=False)
+ else:
+ arguments[arg_key] = str(parsed_value)
+ elif arg_type is None:
+ # If type is not defined, keep the parsed value as-is
+ arguments[arg_key] = parsed_value if is_good_json else arg_value
+ else:
+ # For other types (number, object, array, etc.), use parsed value
+ arguments[arg_key] = parsed_value if is_good_json else arg_value
+
+ return arguments
+
+ def supports_structural_tag(self) -> bool:
+ return False
+
+ def structure_info(self) -> _GetInfoFunc:
+ raise NotImplementedError()
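
A quick usage sketch of the detector's one-shot path, mirroring the wire format in the class docstring; the tool schema matches the unit tests added below in this patch:

```python
import json

from sglang.srt.entrypoints.openai.protocol import Function, Tool
from sglang.srt.function_call.glm47_moe_detector import Glm47MoeDetector

tools = [
    Tool(
        type="function",
        function=Function(
            name="get_weather",
            description="Get weather information",
            parameters={
                "type": "object",
                "properties": {
                    "city": {"type": "string"},
                    "date": {"type": "string"},
                },
                "required": ["city", "date"],
            },
        ),
    )
]

detector = Glm47MoeDetector()
result = detector.detect_and_parse(
    "<tool_call>get_weather"
    "<arg_key>city</arg_key><arg_value>Beijing</arg_value>"
    "<arg_key>date</arg_key><arg_value>2024-06-27</arg_value>"
    "</tool_call>",
    tools,
)
print(result.calls[0].name)                    # get_weather
print(json.loads(result.calls[0].parameters))  # {'city': 'Beijing', 'date': '2024-06-27'}
```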
diff --git a/python/sglang/srt/function_call/glm4_moe_detector.py b/python/sglang/srt/function_call/glm4_moe_detector.py
index b0fc78249aca..d1e684d2d4bc 100644
--- a/python/sglang/srt/function_call/glm4_moe_detector.py
+++ b/python/sglang/srt/function_call/glm4_moe_detector.py
@@ -16,9 +16,12 @@ def get_argument_type(func_name: str, arg_key: str, defined_tools: list):
if func_name not in name2tool:
return None
tool = name2tool[func_name]
- if arg_key not in tool.function.parameters["properties"]:
+ properties = (tool.function.parameters or {}).get("properties", {})
+ if not isinstance(properties, dict):
+ properties = {}
+ if arg_key not in properties:
return None
- return tool.function.parameters["properties"][arg_key].get("type", None)
+ return properties[arg_key].get("type", None)
def parse_arguments(json_value):
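
The guard above matters for tools that declare no `parameters` at all; previously the bare `tool.function.parameters["properties"]` lookup raised on `None`. A sketch, assuming the `Function` model accepts `parameters=None` (which is exactly the case the guard covers):

```python
from sglang.srt.entrypoints.openai.protocol import Function, Tool
from sglang.srt.function_call.glm4_moe_detector import get_argument_type

# A tool with no parameter schema: the old code raised here, the guarded
# version simply reports that the argument type is unknown.
no_param_tool = Tool(type="function", function=Function(name="ping", parameters=None))
print(get_argument_type("ping", "x", [no_param_tool]))  # None
```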
diff --git a/python/sglang/srt/models/glm4_moe.py b/python/sglang/srt/models/glm4_moe.py
index 6483e41c1c62..c27816c2c93a 100644
--- a/python/sglang/srt/models/glm4_moe.py
+++ b/python/sglang/srt/models/glm4_moe.py
@@ -12,7 +12,7 @@
# limitations under the License.
# ==============================================================================
-"""Inference-only GLM-4.5, GLM-4.6 model compatible with HuggingFace weights"""
+"""Inference-only GLM-4.5, GLM-4.6 and GLM-4.7 model compatible with HuggingFace weights"""
import logging
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
diff --git a/test/registered/function_call/test_function_call_parser.py b/test/registered/function_call/test_function_call_parser.py
index 939edb00196d..a395237ad968 100644
--- a/test/registered/function_call/test_function_call_parser.py
+++ b/test/registered/function_call/test_function_call_parser.py
@@ -7,6 +7,7 @@
from sglang.srt.function_call.deepseekv3_detector import DeepSeekV3Detector
from sglang.srt.function_call.deepseekv32_detector import DeepSeekV32Detector
from sglang.srt.function_call.glm4_moe_detector import Glm4MoeDetector
+from sglang.srt.function_call.glm47_moe_detector import Glm47MoeDetector
from sglang.srt.function_call.json_array_parser import JsonArrayParser
from sglang.srt.function_call.kimik2_detector import KimiK2Detector
from sglang.srt.function_call.llama32_detector import Llama32Detector
@@ -2387,6 +2388,280 @@ def check_single_todos(tool_result, expected):
check_single_todos(result, expected_output)
+class TestGlm47MoeDetector(unittest.TestCase):
+ def setUp(self):
+ self.tools = [
+ Tool(
+ type="function",
+ function=Function(
+ name="get_weather",
+ description="Get weather information",
+ parameters={
+ "type": "object",
+ "properties": {
+ "city": {"type": "string", "description": "City name"},
+ "date": {"type": "string", "description": "Date"},
+ },
+ "required": ["city", "date"],
+ },
+ ),
+ ),
+ ]
+ self.detector = Glm47MoeDetector()
+
+ def test_single_tool_call(self):
+ text = (
+ "get_weather"
+ "cityBeijing"
+ "date2024-06-27"
+ ""
+ )
+ result = self.detector.detect_and_parse(text, self.tools)
+ self.assertEqual(len(result.calls), 1)
+ self.assertEqual(result.calls[0].name, "get_weather")
+ self.assertEqual(
+ result.calls[0].parameters, '{"city": "Beijing", "date": "2024-06-27"}'
+ )
+ self.assertEqual(result.normal_text, "")
+
+ def test_multiple_tool_calls(self):
+ text = (
+ "get_weather"
+ "cityBeijing"
+ "date2024-06-27"
+ ""
+ "get_weather"
+ "cityShanghai"
+ "date2024-06-28"
+ ""
+ )
+ result = self.detector.detect_and_parse(text, self.tools)
+ self.assertEqual(len(result.calls), 2)
+ self.assertEqual(result.calls[0].name, "get_weather")
+ self.assertEqual(
+ result.calls[0].parameters, '{"city": "Beijing", "date": "2024-06-27"}'
+ )
+ self.assertEqual(result.calls[1].name, "get_weather")
+ self.assertEqual(
+ result.calls[1].parameters, '{"city": "Shanghai", "date": "2024-06-28"}'
+ )
+ self.assertEqual(result.normal_text, "")
+
+ def test_streaming_tool_call(self):
+ """Test streaming incremental parsing of a tool call."""
+ chunks = [
+ "get_weather",
+ "cityBeijing",
+ "date2024-06-27",
+ "",
+ ]
+ tool_calls = []
+ for chunk in chunks:
+ result = self.detector.parse_streaming_increment(chunk, self.tools)
+ for tool_call_chunk in result.calls:
+ if (
+ hasattr(tool_call_chunk, "tool_index")
+ and tool_call_chunk.tool_index is not None
+ ):
+ while len(tool_calls) <= tool_call_chunk.tool_index:
+ tool_calls.append({"name": "", "parameters": ""})
+ tc = tool_calls[tool_call_chunk.tool_index]
+ if tool_call_chunk.name:
+ tc["name"] = tool_call_chunk.name
+ if tool_call_chunk.parameters:
+ tc["parameters"] += tool_call_chunk.parameters
+ self.assertEqual(len(tool_calls), 1)
+ self.assertEqual(tool_calls[0]["name"], "get_weather")
+ self.assertEqual(
+ tool_calls[0]["parameters"], '{"city": "Beijing", "date": "2024-06-27"}'
+ )
+
+ def test_streaming_multiple_tool_calls(self):
+ """Test streaming incremental parsing of multiple tool calls."""
+ chunks = [
+ "get_weather",
+ "cityBeijing",
+ "date2024-06-27",
+ "get_weather",
+ "cityShanghai",
+ "date2024-06-28",
+ "",
+ ]
+ tool_calls = []
+ for chunk in chunks:
+ result = self.detector.parse_streaming_increment(chunk, self.tools)
+ for tool_call_chunk in result.calls:
+ if (
+ hasattr(tool_call_chunk, "tool_index")
+ and tool_call_chunk.tool_index is not None
+ ):
+ while len(tool_calls) <= tool_call_chunk.tool_index:
+ tool_calls.append({"name": "", "parameters": ""})
+ tc = tool_calls[tool_call_chunk.tool_index]
+ if tool_call_chunk.name:
+ tc["name"] = tool_call_chunk.name
+ if tool_call_chunk.parameters:
+ tc["parameters"] += tool_call_chunk.parameters
+ self.assertEqual(len(tool_calls), 2)
+ self.assertEqual(tool_calls[0]["name"], "get_weather")
+ self.assertEqual(
+ tool_calls[0]["parameters"], '{"city": "Beijing", "date": "2024-06-27"}'
+ )
+ self.assertEqual(tool_calls[1]["name"], "get_weather")
+ self.assertEqual(
+ tool_calls[1]["parameters"], '{"city": "Shanghai", "date": "2024-06-28"}'
+ )
+
+ def test_tool_call_id(self):
+ """Test that the buffer and state are reset after a tool call is completed."""
+ chunks = [
+ "get_weather",
+ "cityBeijing",
+ "date2024-06-27",
+ "",
+ ]
+ for chunk in chunks:
+ result = self.detector.parse_streaming_increment(chunk, self.tools)
+ self.assertEqual(self.detector.current_tool_id, 1)
+
+ def test_invalid_tool_call(self):
+ """Test that invalid tool calls are handled correctly."""
+ text = "invalid_funccityBeijing"
+ result = self.detector.detect_and_parse(text, self.tools)
+ self.assertEqual(len(result.calls), 0)
+
+ def test_partial_tool_call(self):
+ """Test parsing a partial tool call that spans multiple chunks."""
+ chunks = [
+ "get_weather",
+ "cityBeijing",
+ "date2024-06-27",
+ ]
+
+ tool_calls = []
+ for chunk in chunks:
+ result = self.detector.parse_streaming_increment(chunk, self.tools)
+ for tool_call_chunk in result.calls:
+ if (
+ hasattr(tool_call_chunk, "tool_index")
+ and tool_call_chunk.tool_index is not None
+ ):
+ while len(tool_calls) <= tool_call_chunk.tool_index:
+ tool_calls.append({"name": "", "parameters": ""})
+ tc = tool_calls[tool_call_chunk.tool_index]
+ if tool_call_chunk.name:
+ tc["name"] = tool_call_chunk.name
+ if tool_call_chunk.parameters:
+ tc["parameters"] += tool_call_chunk.parameters
+
+ self.assertEqual(len(tool_calls), 1)
+ self.assertEqual(tool_calls[0]["name"], "get_weather")
+ self.assertEqual(
+ tool_calls[0]["parameters"], '{"city": "Beijing", "date": "2024-06-27"}'
+ )
+
+ def test_array_argument_with_escaped_json(self):
+ """Test that array arguments with escaped JSON are properly handled without double-escaping."""
+ # Add a tool with array parameter
+ tools_with_array = [
+ Tool(
+ type="function",
+ function=Function(
+ name="todo_write",
+ description="Write todos",
+ parameters={
+ "type": "object",
+ "properties": {
+ "todos": {
+ "type": "array",
+ "description": "The updated todo list",
+ }
+ },
+ "required": ["todos"],
+ },
+ ),
+ ),
+ ]
+
+ def check_params(result):
+ self.assertEqual(1, len(result.calls))
+ self.assertEqual("todo_write", result.calls[0].name)
+ params = json.loads(result.calls[0].parameters)
+ self.assertIsInstance(params["todos"], list)
+ self.assertEqual(4, len(params["todos"]))
+ self.assertEqual("1", params["todos"][0]["id"])
+ self.assertEqual(
+ "Check for hard-coded issues in the backend code",
+ params["todos"][0]["task"],
+ )
+ self.assertEqual("in_progress", params["todos"][0]["status"])
+ self.assertEqual("2", params["todos"][1]["id"])
+ self.assertEqual(
+ "Check for hard-coded issues in the frontend code",
+ params["todos"][1]["task"],
+ )
+ self.assertEqual("pending", params["todos"][1]["status"])
+ self.assertEqual("3", params["todos"][2]["id"])
+ self.assertEqual(
+ "Check for code violating the Single Responsibility Principle",
+ params["todos"][2]["task"],
+ )
+ self.assertEqual("pending", params["todos"][2]["status"])
+ self.assertEqual("4", params["todos"][3]["id"])
+ self.assertEqual(
+ "Generate a rectification proposal report", params["todos"][3]["task"]
+ )
+ self.assertEqual("pending", params["todos"][3]["status"])
+
+ # Simulate the raw response from the GLM-4.7 model with normal and escaped JSON in XML
+ result = self.detector.detect_and_parse(
+ """<tool_call>todo_write<arg_key>todos</arg_key><arg_value>[{\"id\": \"1\", \"task\": \"Check for hard-coded issues in the backend code\", \"status\": \"in_progress\"}, {\"id\": \"2\", \"task\": \"Check for hard-coded issues in the frontend code\", \"status\": \"pending\"}, {\"id\": \"3\", \"task\": \"Check for code violating the Single Responsibility Principle\", \"status\": \"pending\"}, {\"id\": \"4\", \"task\": \"Generate a rectification proposal report\", \"status\": \"pending\"}]</arg_value>
+</tool_call>""",
+ tools_with_array,
+ )
+ check_params(result)
+ result = self.detector.detect_and_parse(
+ r"""todo_writetodos[{\"id\": \"1\", \"task\": \"Check for hard-coded issues in the backend code\", \"status\": \"in_progress\"}, {\"id\": \"2\", \"task\": \"Check for hard-coded issues in the frontend code\", \"status\": \"pending\"}, {\"id\": \"3\", \"task\": \"Check for code violating the Single Responsibility Principle\", \"status\": \"pending\"}, {\"id\": \"4\", \"task\": \"Generate a rectification proposal report\", \"status\": \"pending\"}]
+""",
+ tools_with_array,
+ )
+ check_params(result)
+
+ def check_single_todos(tool_result, expected):
+ self.assertEqual(1, len(tool_result.calls))
+ self.assertEqual("todo_write", tool_result.calls[0].name)
+ params = json.loads(tool_result.calls[0].parameters)
+ self.assertIsInstance(params["todos"], list)
+ self.assertEqual(1, len(params["todos"]))
+ self.assertEqual("1", params["todos"][0]["id"])
+ self.assertEqual(expected, params["todos"][0]["task"])
+ self.assertEqual("pending", params["todos"][0]["status"])
+
+ # Test with escaped JSON containing backslashes in content (e.g., Windows paths)
+ expected_path = r"Check file at C:\Users\test.txt"
+ result = self.detector.detect_and_parse(
+ """todo_writetodos[{\"id\": \"1\", \"task\": \"Check file at C:\\\\Users\\\\test.txt\", \"status\": \"pending\"}]""",
+ tools_with_array,
+ )
+ check_single_todos(result, expected_path)
+ result = self.detector.detect_and_parse(
+ r"""todo_writetodos[{\"id\": \"1\", \"task\": \"Check file at C:\\\\Users\\\\test.txt\", \"status\": \"pending\"}]""",
+ tools_with_array,
+ )
+ check_single_todos(result, expected_path)
+
+ # Should contain literal \n, not actual newline
+ expected_output = r"Print \n to see newline"
+ result = self.detector.detect_and_parse(
+ """todo_writetodos[{\"id\": \"1\", \"task\": \"Print \\\\n to see newline\",\"status\": \"pending\"}]""",
+ tools_with_array,
+ )
+ check_single_todos(result, expected_output)
+ result = self.detector.detect_and_parse(
+ r"""todo_writetodos[{\"id\": \"1\", \"task\": \"Print \\\\n to see newline\",\"status\": \"pending\"}]""",
+ tools_with_array,
+ )
+ check_single_todos(result, expected_output)
class TestJsonArrayParser(unittest.TestCase):
def setUp(self):
# Create sample tools for testing
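
The streaming tests above all repeat the same accumulation loop; factored out, the contract of the streamed deltas is easier to see. A sketch (the helper name is ours, not part of the patch):

```python
from typing import Dict, List


def accumulate_tool_calls(detector, chunks: List[str], tools) -> List[Dict[str, str]]:
    """Replay chunks through the detector and rebuild each call's JSON args.

    Each delta carries the tool name at most once; argument fragments are
    concatenated per tool_index until they form the full JSON object.
    """
    tool_calls: List[Dict[str, str]] = []
    for chunk in chunks:
        result = detector.parse_streaming_increment(chunk, tools)
        for delta in result.calls:
            if delta.tool_index is None:
                continue
            while len(tool_calls) <= delta.tool_index:
                tool_calls.append({"name": "", "parameters": ""})
            if delta.name:
                tool_calls[delta.tool_index]["name"] = delta.name
            if delta.parameters:
                tool_calls[delta.tool_index]["parameters"] += delta.parameters
    return tool_calls
```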
From aa19e9a6911fe02d8f0fb4accd055c517d35fa5f Mon Sep 17 00:00:00 2001
From: Leoyzen
Date: Wed, 31 Dec 2025 06:32:35 +0800
Subject: [PATCH 2/3] Fix: Handle empty func_name and None values in GLM MoE
detectors (#15754)
Signed-off-by: Xinyuan Tong
Co-authored-by: Xinyuan Tong
Co-authored-by: Xinyuan Tong <115166877+JustinTong0323@users.noreply.github.com>
---
.../srt/function_call/glm47_moe_detector.py | 435 ++++--
.../srt/function_call/glm4_moe_detector.py | 248 +++-
.../test_function_call_parser.py | 15 +
.../function_call/test_glm47_moe_detector.py | 1176 +++++++++++++++++
4 files changed, 1704 insertions(+), 170 deletions(-)
create mode 100644 test/registered/function_call/test_glm47_moe_detector.py
diff --git a/python/sglang/srt/function_call/glm47_moe_detector.py b/python/sglang/srt/function_call/glm47_moe_detector.py
index 726737c5788a..0cf3a1b622e6 100644
--- a/python/sglang/srt/function_call/glm47_moe_detector.py
+++ b/python/sglang/srt/function_call/glm47_moe_detector.py
@@ -40,15 +40,27 @@ def get_argument_type(
The type string (e.g., 'string', 'number', 'object') or None if not found
"""
name2tool = {tool.function.name: tool for tool in defined_tools}
- if func_name not in name2tool:
+
+ # Check if function exists
+ tool = name2tool.get(func_name)
+ if not tool:
+ return None
+
+ # Get parameters safely using getattr
+ params = getattr(tool.function, "parameters", None)
+ if not isinstance(params, dict):
return None
- tool = name2tool[func_name]
- properties = (tool.function.parameters or {}).get("properties", {})
+
+ # Navigate to the type using dict.get() for safe access
+ properties = params.get("properties")
if not isinstance(properties, dict):
- properties = {}
- if arg_key not in properties:
return None
- return properties[arg_key].get("type", None)
+
+ arg_spec = properties.get(arg_key)
+ if isinstance(arg_spec, dict):
+ return arg_spec.get("type")
+
+ return None
def _convert_to_number(value: str) -> Any:
@@ -143,6 +155,10 @@ def __init__(self):
self.current_tool_id = -1
self.current_tool_name_sent = False
self._streamed_raw_length = 0
+ self._tool_call_completed = False # Track if tool call has been completed
+ self._sent_empty_object = (
+ False # Track if empty object has been sent for no-arg functions
+ )
self._reset_streaming_state()
def _reset_streaming_state(self) -> None:
@@ -156,6 +172,8 @@ def _reset_streaming_state(self) -> None:
self._cached_value_type: Optional[str] = (
None # Cache the value type for consistency
)
+ self._tool_call_completed = False # Reset tool call completion status
+ self._sent_empty_object = False # Reset empty object sent status
def has_tool_call(self, text: str) -> bool:
"""Check if the text contains a glm-4.5 / glm-4.6 format tool call."""
@@ -169,18 +187,38 @@ def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult
:param tools: List of available tools.
:return: ParseResult indicating success or failure, consumed text, leftover text, and parsed calls.
"""
- idx = text.find(self.bot_token)
- normal_text = text[:idx].strip() if idx != -1 else text
if self.bot_token not in text:
- return StreamingParseResult(normal_text=normal_text, calls=[])
+ return StreamingParseResult(normal_text=text, calls=[])
+
+ # Extract all normal text (before, between, and after tool calls)
+ normal_text_parts = []
+ last_end = 0
+
+ # Find all tool call matches
+ for match in re.finditer(self.func_call_regex, text, re.DOTALL):
+ # Add text before this tool call
+ if match.start() > last_end:
+ normal_text_parts.append(text[last_end : match.start()])
+ last_end = match.end()
+
+ # Add any remaining text after the last tool call
+ if last_end < len(text):
+ normal_text_parts.append(text[last_end:])
+
+ # Combine all normal text parts
+ normal_text = "".join(normal_text_parts).strip()
+
+ # Parse tool calls
match_result_list = re.findall(self.func_call_regex, text, re.DOTALL)
calls = []
try:
for match_result in match_result_list:
# Get function name
func_detail = self.func_detail_regex.search(match_result)
- func_name = func_detail.group(1)
- func_args = func_detail.group(2)
+ if func_detail is None:
+ continue
+ func_name = func_detail.group(1) if func_detail.group(1) else ""
+ func_args = func_detail.group(2) if func_detail.group(2) else ""
arguments = {}
if func_args:
pairs = self.func_arg_regex.findall(func_args)
@@ -212,7 +250,11 @@ def _get_value_type(self, func_name: str, key: str, tools: List[Tool]) -> str:
return arg_type
# Auto-detect type from value (best effort)
- first_chars = self._current_value.strip()[:10] if self._current_value else ""
+ first_chars = (
+ self._current_value.strip()[:10]
+ if self._current_value and self._current_value.strip()
+ else ""
+ )
if first_chars:
first_char = first_chars[0]
if first_char.isdigit() or first_char in ["-", "."]:
@@ -237,14 +279,14 @@ def _format_value_complete(self, value: str, value_type: str) -> str:
return json.dumps(value, ensure_ascii=False)
elif value_type == "number":
try:
- num = _convert_to_number(value.strip())
+ num = _convert_to_number(value.strip() if value else "")
return str(num)
except (ValueError, AttributeError):
# Fallback to string if not a valid number
logger.warning(
f"Failed to parse '{value}' as number, treating as string"
)
- return json.dumps(str(value), ensure_ascii=False)
+ return json.dumps(str(value) if value else "", ensure_ascii=False)
else:
# For object/array types, return as-is (should already be valid JSON)
return value
@@ -369,6 +411,179 @@ def _process_xml_to_json_streaming(
return json_output
+ def _extract_match_groups(self, match: re.Match) -> tuple[str, str, str]:
+ """Extract function name, arguments and end marker from regex match.
+
+ Args:
+ match: Regex match object
+
+ Returns:
+ (func_name, func_args_raw, is_tool_end)
+ """
+ func_name = match.group(1).strip()
+ func_args_raw = match.group(2).strip() if match.group(2) else ""
+ is_tool_end = match.group(3) or ""
+ return func_name, func_args_raw, is_tool_end
+
+ def _send_tool_name_if_needed(
+ self, func_name: str, has_arg_key: bool, is_tool_end: str
+ ) -> Optional[ToolCallItem]:
+ """Send tool name if needed.
+
+ Args:
+ func_name: Function name
+ has_arg_key: Whether current text contains <arg_key>
+ ) -> Optional[ToolCallItem]:
+ """Process streaming arguments.
+
+ Args:
+ func_name: Function name
+ func_args_raw: Raw argument string
+ tools: List of available tools
+
+ Returns:
+ Tool call item with parameter updates or None
+ """
+ current_raw_length = len(func_args_raw)
+
+ if current_raw_length <= self._streamed_raw_length:
+ return None
+
+ # Get new raw XML content
+ raw_increment = func_args_raw[self._streamed_raw_length :]
+
+ # Convert XML to JSON using state machine
+ json_increment = self._process_xml_to_json_streaming(
+ raw_increment, func_name, tools
+ )
+
+ # CRITICAL: Update streamed length BEFORE early return
+ # Even if json_increment is empty, the input has been consumed by the state machine
+ self._streamed_raw_length = current_raw_length
+
+ if not json_increment:
+ return None
+
+ # Update state
+ self._last_arguments += json_increment
+ self.streamed_args_for_tool[self.current_tool_id] += json_increment
+
+ return ToolCallItem(
+ tool_index=self.current_tool_id,
+ name=None,
+ parameters=json_increment,
+ )
+
+ def _finalize_tool_call(
+ self,
+ func_name: str,
+ func_args_raw: str,
+ tools: List[Tool],
+ match_end_pos: int,
+ current_text: str,
+ ) -> List[ToolCallItem]:
+ """Complete tool call processing.
+
+ Args:
+ func_name: Function name
+ func_args_raw: Raw argument string
+ tools: List of available tools
+ match_end_pos: Match end position
+ current_text: Current text
+
+ Returns:
+ List of tool call items to add
+ """
+ calls = []
+
+ # Handle no-arg function or need to close braces
+ if self._is_first_param and not self._sent_empty_object:
+ # No-arg function
+ calls.append(
+ ToolCallItem(
+ tool_index=self.current_tool_id,
+ name=None,
+ parameters="{}",
+ )
+ )
+ self._last_arguments += "{}"
+ self.streamed_args_for_tool[self.current_tool_id] += "{}"
+ self._sent_empty_object = True
+ elif not self._last_arguments.endswith("}") and not self._sent_empty_object:
+ # Need to close brace
+ calls.append(
+ ToolCallItem(
+ tool_index=self.current_tool_id,
+ name=None,
+ parameters="}",
+ )
+ )
+ self._last_arguments += "}"
+ self.streamed_args_for_tool[self.current_tool_id] += "}"
+ self._sent_empty_object = True
+
+ # Parse final arguments
+ if func_args_raw:
+ try:
+ pairs = self.func_arg_regex.findall(func_args_raw)
+ if pairs:
+ arguments = self._parse_argument_pairs(pairs, func_name, tools)
+ self.prev_tool_call_arr[self.current_tool_id][
+ "arguments"
+ ] = arguments
+ except Exception as e:
+ logger.debug(f"Failed to parse arguments: {e}", exc_info=True)
+
+ # Clean buffer
+ self._buffer = current_text[match_end_pos:]
+
+ # Reset state for next tool call
+ self._tool_call_completed = True
+ self.current_tool_id += 1
+ self._last_arguments = ""
+ self.current_tool_name_sent = False
+ self._streamed_raw_length = 0
+ self._reset_streaming_state()
+
+ return calls
+
def parse_streaming_increment(
self, new_text: str, tools: List[Tool]
) -> StreamingParseResult:
@@ -405,140 +620,96 @@ def parse_streaming_increment(
# Could be start of tool call, keep buffering
return StreamingParseResult(normal_text="", calls=[])
+ # Extract any text before the first bot_token and return it as normal_text
+ normal_text = ""
+ first_bot_token_idx = current_text.find(self.bot_token)
+ if first_bot_token_idx > 0:
+ normal_text = current_text[:first_bot_token_idx]
+ current_text = current_text[first_bot_token_idx:]
+ # Update buffer to only include from the bot token onwards
+ self._buffer = current_text
+
if not hasattr(self, "_tool_indices"):
self._tool_indices = self._get_tool_indices(tools)
calls: list[ToolCallItem] = []
try:
# Try to match a partial or complete tool call
+ # Use a single flexible regex pattern that handles all cases
partial_match = re.search(
- pattern=r"(.*?)(.*?)?(|$)",
- string=current_text,
- flags=re.DOTALL,
+ r"(.*?)(?:()|$)",
+ current_text,
+ re.DOTALL,
)
- if partial_match:
- func_name = partial_match.group(1).strip()
- func_args_raw = partial_match.group(2).strip()
- is_tool_end = partial_match.group(3)
-
- # Initialize state if this is the first tool call
- if self.current_tool_id == -1:
- self.current_tool_id = 0
- self.prev_tool_call_arr = []
- self.streamed_args_for_tool = [""]
- self._streamed_raw_length = 0
- self.current_tool_name_sent = False
- self._reset_streaming_state()
-
- # Ensure we have enough entries in our tracking arrays
- while len(self.prev_tool_call_arr) <= self.current_tool_id:
- self.prev_tool_call_arr.append({})
- while len(self.streamed_args_for_tool) <= self.current_tool_id:
- self.streamed_args_for_tool.append("")
-
- # Send tool name first if not sent yet
- if not self.current_tool_name_sent:
- assert func_name, "func_name should not be empty"
- calls.append(
- ToolCallItem(
- tool_index=self.current_tool_id,
- name=func_name,
- parameters="",
- )
- )
- self.current_tool_name_sent = True
- self._streamed_raw_length = 0
- self._reset_streaming_state()
- # Store the tool call info
- self.prev_tool_call_arr[self.current_tool_id] = {
- "name": func_name,
- "arguments": {},
- }
- else:
- # Process XML to JSON streaming
- current_raw_length = len(func_args_raw)
- if current_raw_length > self._streamed_raw_length:
- # Get the new raw XML content
- raw_increment = func_args_raw[self._streamed_raw_length :]
+ if not partial_match:
+ return StreamingParseResult(normal_text=normal_text, calls=[])
- # Convert XML increment to JSON increment using state machine
- json_increment = self._process_xml_to_json_streaming(
- raw_increment, func_name, tools
- )
+ # Extract match groups using helper method
+ func_name, func_args_raw, is_tool_end = self._extract_match_groups(
+ partial_match
+ )
- if json_increment:
- calls.append(
- ToolCallItem(
- tool_index=self.current_tool_id,
- name=None,
- parameters=json_increment,
- )
- )
- self._last_arguments += json_increment
- self.streamed_args_for_tool[
- self.current_tool_id
- ] += json_increment
-
- # Update the streamed length
- self._streamed_raw_length = current_raw_length
-
- if is_tool_end == self.eot_token:
- if self._is_first_param:
- empty_object = "{}"
- calls.append(
- ToolCallItem(
- tool_index=self.current_tool_id,
- name=None,
- parameters=empty_object,
- )
- )
- self._last_arguments += empty_object
- elif not self._last_arguments.endswith("}"):
- closing_brace = "}"
- calls.append(
- ToolCallItem(
- tool_index=self.current_tool_id,
- name=None,
- parameters=closing_brace,
- )
- )
- self._last_arguments += closing_brace
- self.streamed_args_for_tool[
- self.current_tool_id
- ] += closing_brace
-
- try:
- pairs = self.func_arg_regex.findall(func_args_raw)
- if pairs:
- arguments = self._parse_argument_pairs(
- pairs, func_name, tools
- )
- self.prev_tool_call_arr[self.current_tool_id][
- "arguments"
- ] = arguments
- except Exception as e:
- logger.debug(
- f"Failed to parse arguments: {e}", exc_info=True
- )
-
- # Remove the completed tool call from buffer
- self._buffer = current_text[partial_match.end(3) :]
-
- result = StreamingParseResult(normal_text="", calls=calls)
- self.current_tool_id += 1
- self._last_arguments = ""
- self.current_tool_name_sent = False
- self._streamed_raw_length = 0
- self._reset_streaming_state()
- return result
-
- return StreamingParseResult(normal_text="", calls=calls)
+ # Initialize tool call state if needed (keeping existing logic)
+ if self.current_tool_id == -1:
+ self.current_tool_id = 0
+ self.prev_tool_call_arr = []
+ self.streamed_args_for_tool = [""]
+ self._streamed_raw_length = 0
+ self.current_tool_name_sent = False # Reset for new tool call
+ self._reset_streaming_state()
+ # Check if this is a continuation of an existing tool call or a new one
+ elif not self.current_tool_name_sent:
+ # Only increment tool_id if we're truly starting a NEW tool call
+ # Don't increment if this is just the first time we're processing
+ # a tool call that was received in the buffer
+ # The key insight: only increment when we've COMPLETED a previous tool call
+ # and now see another bot_token in new_text
+ pass # Remove the problematic auto-increment logic
+
+ # Ensure tracking arrays are large enough (keeping existing logic)
+ while len(self.prev_tool_call_arr) <= self.current_tool_id:
+ self.prev_tool_call_arr.append({})
+ while len(self.streamed_args_for_tool) <= self.current_tool_id:
+ self.streamed_args_for_tool.append("")
+
+ # Determine if function name is complete by checking for <arg_key> in the full text
+ # This is important for streaming scenarios where args come in later chunks
+ has_arg_key = "<arg_key>" in current_text
 ) -> Dict[str, Any]:
diff --git a/python/sglang/srt/function_call/glm4_moe_detector.py b/python/sglang/srt/function_call/glm4_moe_detector.py
index d1e684d2d4bc..3158e32166c3 100644
--- a/python/sglang/srt/function_call/glm4_moe_detector.py
+++ b/python/sglang/srt/function_call/glm4_moe_detector.py
@@ -87,8 +87,10 @@ def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult
for match_result in match_result_list:
# Get function name
func_detail = self.func_detail_regex.search(match_result)
- func_name = func_detail.group(1)
- func_args = func_detail.group(2)
+ if func_detail is None:
+ continue
+ func_name = func_detail.group(1) if func_detail.group(1) else ""
+ func_args = func_detail.group(2) if func_detail.group(2) else ""
pairs = self.func_arg_regex.findall(func_args)
arguments = {}
for arg_key, arg_value in pairs:
@@ -112,47 +114,217 @@ def parse_streaming_increment(
) -> StreamingParseResult:
"""
Streaming incremental parsing tool calls for GLM-4.5 and GLM-4.6 format.
+ Uses a state machine to convert XML to JSON incrementally for true character-by-character streaming.
+ Outputs JSON increments immediately as XML data arrives.
"""
self._buffer += new_text
current_text = self._buffer
- start = current_text.find(self.bot_token)
- if start == -1:
- self._buffer = ""
- if self.current_tool_id > 0:
- current_text = ""
- return StreamingParseResult(normal_text=current_text)
- # find ensures we find the first self.eot_token so there will be at most one tool_call in current_text[:end+len(self.eot_token)
- end = current_text.find(self.eot_token)
- if end != -1:
- # Initialize state if this is the first tool call
- if self.current_tool_id == -1:
- self.current_tool_id = 0
- self.prev_tool_call_arr = []
- self.streamed_args_for_tool = [""]
- # Ensure we have enough entries in our tracking arrays
- while len(self.prev_tool_call_arr) <= self.current_tool_id:
- self.prev_tool_call_arr.append({})
- while len(self.streamed_args_for_tool) <= self.current_tool_id:
- self.streamed_args_for_tool.append("")
- result = self.detect_and_parse(
- current_text[: end + len(self.eot_token)], tools=tools
+ # Check if we have a tool call
+ has_tool_call = self.bot_token in current_text
+
+ if not has_tool_call:
+ # Check if buffer could be the start of a tool call
+ # Keep buffer if it could be a partial match of bot_token
+ is_potential_start = any(
+ self.bot_token.startswith(current_text[-i:])
+ for i in range(1, min(len(current_text), len(self.bot_token)) + 1)
+ )
+
+ if not is_potential_start:
+ # Not a potential tool call, return as normal text
+ # Must return the entire buffer (current_text), not just new_text,
+ # because buffer may contain previously accumulated characters like '<'
+ # that turned out not to be part of a tool call
+ output_text = current_text
+ self._buffer = ""
+ if self.eot_token in output_text:
+ output_text = output_text.replace(self.eot_token, "")
+ return StreamingParseResult(normal_text=output_text)
+ else:
+ # Could be start of tool call, keep buffering
+ return StreamingParseResult(normal_text="", calls=[])
+
+ if not hasattr(self, "_tool_indices"):
+ self._tool_indices = self._get_tool_indices(tools)
+
+ calls: list[ToolCallItem] = []
+ try:
+ # Try to match a partial or complete tool call
+ partial_match = re.search(
+ pattern=r"(.*?)(?:\\n|\n)(.*?)(|$)",
+ string=current_text,
+ flags=re.DOTALL,
)
- if result.calls:
- self.prev_tool_call_arr[self.current_tool_id] = {
- "name": result.calls[0].name,
- "arguments": json.loads(result.calls[0].parameters),
- }
- self.streamed_args_for_tool[self.current_tool_id] = result.calls[
- 0
- ].parameters
- result.calls[0].tool_index = self.current_tool_id
- self.current_tool_id += 1
- self._buffer = current_text[end + len(self.eot_token) :]
- return result
- normal_text = current_text[:start]
- self._buffer = current_text[start:]
- return StreamingParseResult(normal_text=normal_text)
+ if partial_match:
+ func_name_raw = partial_match.group(1)
+ func_args_raw = partial_match.group(2)
+ is_tool_end = partial_match.group(3)
+
+ # Only proceed if we have a non-empty function name
+ if func_name_raw is None or not func_name_raw.strip():
+ # If we only have the start token without a function name,
+ # continue buffering until we get more content
+ return StreamingParseResult(normal_text="", calls=[])
+
+ func_name = func_name_raw.strip()
+ func_args_raw = func_args_raw.strip() if func_args_raw else ""
+
+ # Initialize state if this is the first tool call
+ if self.current_tool_id == -1:
+ self.current_tool_id = 0
+ self.prev_tool_call_arr = []
+ self.streamed_args_for_tool = [""]
+ self._streamed_raw_length = 0
+ self.current_tool_name_sent = False
+ self._reset_streaming_state()
+
+ # Ensure we have enough entries in our tracking arrays
+ while len(self.prev_tool_call_arr) <= self.current_tool_id:
+ self.prev_tool_call_arr.append({})
+ while len(self.streamed_args_for_tool) <= self.current_tool_id:
+ self.streamed_args_for_tool.append("")
+
+ # Send tool name first if not sent yet
+ if not self.current_tool_name_sent:
+ calls.append(
+ ToolCallItem(
+ tool_index=self.current_tool_id,
+ name=func_name,
+ parameters="",
+ )
+ )
+ self.current_tool_name_sent = True
+ self._streamed_raw_length = 0
+ self._reset_streaming_state()
+ # Store the tool call info
+ self.prev_tool_call_arr[self.current_tool_id] = {
+ "name": func_name,
+ "arguments": {},
+ }
+ else:
+ # Process XML to JSON streaming
+ current_raw_length = len(func_args_raw)
+
+ if current_raw_length > self._streamed_raw_length:
+ # Get the new raw XML content
+ raw_increment = func_args_raw[self._streamed_raw_length :]
+
+ # Convert XML increment to JSON increment using state machine
+ json_increment = self._process_xml_to_json_streaming(
+ raw_increment, func_name, tools
+ )
+
+ # CRITICAL: Update streamed length BEFORE checking json_increment
+ # Even if json_increment is empty, the input has been consumed by the state machine
+ self._streamed_raw_length = current_raw_length
+
+ if json_increment:
+ calls.append(
+ ToolCallItem(
+ tool_index=self.current_tool_id,
+ name=None,
+ parameters=json_increment,
+ )
+ )
+ self._last_arguments += json_increment
+ self.streamed_args_for_tool[
+ self.current_tool_id
+ ] += json_increment
+
+ if is_tool_end == self.eot_token:
+ if self._is_first_param:
+ empty_object = "{}"
+ calls.append(
+ ToolCallItem(
+ tool_index=self.current_tool_id,
+ name=None,
+ parameters=empty_object,
+ )
+ )
+ self._last_arguments += empty_object
+ elif not self._last_arguments.endswith("}"):
+ closing_brace = "}"
+ calls.append(
+ ToolCallItem(
+ tool_index=self.current_tool_id,
+ name=None,
+ parameters=closing_brace,
+ )
+ )
+ self._last_arguments += closing_brace
+ self.streamed_args_for_tool[
+ self.current_tool_id
+ ] += closing_brace
+
+ try:
+ pairs = self.func_arg_regex.findall(func_args_raw)
+ if pairs:
+ arguments = self._parse_argument_pairs(
+ pairs, func_name, tools
+ )
+ self.prev_tool_call_arr[self.current_tool_id][
+ "arguments"
+ ] = arguments
+ except Exception as e:
+ logger.debug(
+ f"Failed to parse arguments: {e}", exc_info=True
+ )
+
+ # Remove the completed tool call from buffer
+ self._buffer = current_text[partial_match.end(3) :]
+
+ result = StreamingParseResult(normal_text="", calls=calls)
+ self.current_tool_id += 1
+ self._last_arguments = ""
+ self.current_tool_name_sent = False
+ self._streamed_raw_length = 0
+ self._reset_streaming_state()
+ return result
+
+ return StreamingParseResult(normal_text="", calls=calls)
+
+ except Exception as e:
+ logger.error(f"Error in parse_streaming_increment: {e}", exc_info=True)
+ return StreamingParseResult(normal_text=current_text)
+
+ def _parse_argument_pairs(
+ self, pairs: List[Tuple[str, str]], func_name: str, tools: List[Tool]
+ ) -> Dict[str, Any]:
+ """Parse argument key-value pairs with type coercion.
+
+ Args:
+ pairs: List of (key, value) tuples from regex matching
+ func_name: Name of the function
+ tools: List of available tools
+
+ Returns:
+ Dictionary of parsed arguments
+ """
+ arguments = {}
+ for arg_key, arg_value in pairs:
+ arg_key = arg_key.strip()
+ arg_value = arg_value.strip()
+ arg_type = get_argument_type(func_name, arg_key, tools)
+ parsed_value, is_good_json = parse_arguments(arg_value, arg_type)
+
+ if arg_type == "string":
+ # Only convert to string if explicitly defined as string type
+ if isinstance(parsed_value, str):
+ arguments[arg_key] = parsed_value
+ elif isinstance(parsed_value, (dict, list)):
+ # If parsed as dict/list but schema says string, convert to JSON string
+ arguments[arg_key] = json.dumps(parsed_value, ensure_ascii=False)
+ else:
+ arguments[arg_key] = str(parsed_value)
+ elif arg_type is None:
+ # If type is not defined, keep the parsed value as-is
+ arguments[arg_key] = parsed_value if is_good_json else arg_value
+ else:
+ # For other types (number, object, array, etc.), use parsed value
+ arguments[arg_key] = parsed_value if is_good_json else arg_value
+
+ return arguments
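+
+    # Illustrative result (assumed schema): for pairs [("city", "Beijing"), ("count", "3")]
+    # where "count" is declared as {"type": "integer"}, this yields
+    # {"city": "Beijing", "count": 3}.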
def supports_structural_tag(self) -> bool:
return False
diff --git a/test/registered/function_call/test_function_call_parser.py b/test/registered/function_call/test_function_call_parser.py
index a395237ad968..81cabe07fa3a 100644
--- a/test/registered/function_call/test_function_call_parser.py
+++ b/test/registered/function_call/test_function_call_parser.py
@@ -2387,6 +2387,21 @@ def check_single_todos(tool_result, expected):
)
check_single_todos(result, expected_output)
+ def test_empty_function_name_handling(self):
+ """Test that empty function name is handled gracefully without assertion error."""
+ # This test simulates the issue where the model outputs only the start token without a function name
+ chunks = [
+            "<tool_call>",  # Start token only, no function name yet
+ "\n", # More content without function name
+ ]
+
+ for chunk in chunks:
+ # Should not raise AssertionError: func_name should not be empty
+ result = self.detector.parse_streaming_increment(chunk, self.tools)
+ # Should return empty calls without error
+ self.assertIsInstance(result, StreamingParseResult)
+ self.assertEqual(result.calls, [])
+
class TestGlm47MoeDetector(unittest.TestCase):
def setUp(self):
diff --git a/test/registered/function_call/test_glm47_moe_detector.py b/test/registered/function_call/test_glm47_moe_detector.py
new file mode 100644
index 000000000000..a046964064ce
--- /dev/null
+++ b/test/registered/function_call/test_glm47_moe_detector.py
@@ -0,0 +1,1176 @@
+import json
+import unittest
+
+from sglang.srt.entrypoints.openai.protocol import Function, Tool
+from sglang.srt.function_call.core_types import StreamingParseResult
+from sglang.srt.function_call.glm47_moe_detector import Glm47MoeDetector
+from sglang.test.ci.ci_register import register_cpu_ci
+
+register_cpu_ci(1.0, "default")
+
+
+class TestGlm47MoeDetector(unittest.TestCase):
+ def setUp(self):
+ self.tools = [
+ Tool(
+ type="function",
+ function=Function(
+ name="get_weather",
+ description="Get weather information",
+ parameters={
+ "type": "object",
+ "properties": {
+ "city": {"type": "string", "description": "City name"},
+ "date": {"type": "string", "description": "Date"},
+ },
+ "required": ["city", "date"],
+ },
+ ),
+ ),
+ ]
+ self.detector = Glm47MoeDetector()
+
+ # ==================== Basic Parsing Tests (5) ====================
+
+ def test_single_tool_call(self):
+ """
+ Test basic single tool call parsing.
+
+ Scenario: Parse a complete tool call with two string parameters in a single text block.
+ Purpose: Verify the detector can correctly identify and extract function name and parameters
+ from a simple, well-formed tool call.
+ """
+ text = (
+            "<tool_call>get_weather"
+            "<arg_key>city</arg_key><arg_value>Beijing</arg_value>"
+            "<arg_key>date</arg_key><arg_value>2024-06-27</arg_value>"
+            "</tool_call>"
+ )
+ result = self.detector.detect_and_parse(text, self.tools)
+ self.assertEqual(len(result.calls), 1)
+ self.assertEqual(result.calls[0].name, "get_weather")
+ self.assertEqual(
+ result.calls[0].parameters, '{"city": "Beijing", "date": "2024-06-27"}'
+ )
+ self.assertEqual(result.normal_text, "")
+
+ def test_multiple_tool_calls(self):
+ """
+ Test parsing multiple consecutive tool calls.
+
+ Scenario: Parse two complete tool calls back-to-back without any text in between.
+ Purpose: Verify the detector correctly handles multiple tool calls and resets state
+ between calls to avoid parameter leakage or ID conflicts.
+ """
+ text = (
+            "<tool_call>get_weather"
+            "<arg_key>city</arg_key><arg_value>Beijing</arg_value>"
+            "<arg_key>date</arg_key><arg_value>2024-06-27</arg_value>"
+            "</tool_call>"
+            "<tool_call>get_weather"
+            "<arg_key>city</arg_key><arg_value>Shanghai</arg_value>"
+            "<arg_key>date</arg_key><arg_value>2024-06-28</arg_value>"
+            "</tool_call>"
+ )
+ result = self.detector.detect_and_parse(text, self.tools)
+ self.assertEqual(len(result.calls), 2)
+ self.assertEqual(result.calls[0].name, "get_weather")
+ self.assertEqual(
+ result.calls[0].parameters, '{"city": "Beijing", "date": "2024-06-27"}'
+ )
+ self.assertEqual(result.calls[1].name, "get_weather")
+ self.assertEqual(
+ result.calls[1].parameters, '{"city": "Shanghai", "date": "2024-06-28"}'
+ )
+ self.assertEqual(result.normal_text, "")
+
+ def test_no_arg_function_non_streaming(self):
+ """
+ Test no-argument function call without streaming.
+
+ Scenario: Parse a tool call for a function that has no parameters (empty properties).
+ Purpose: Verify the detector generates a single empty object "{}" for no-argument functions
+ and does not duplicate empty parameter objects.
+ """
+ tools_with_no_args = [
+ Tool(
+ type="function",
+ function=Function(
+ name="list_filenames",
+ description="List filenames",
+ parameters={
+ "type": "object",
+ "properties": {},
+ },
+ ),
+ ),
+ ]
+
+        text = "<tool_call>list_filenames</tool_call>"
+ result = self.detector.detect_and_parse(text, tools_with_no_args)
+
+ self.assertEqual(len(result.calls), 1)
+ self.assertEqual(result.calls[0].name, "list_filenames")
+ params = json.loads(result.calls[0].parameters)
+ self.assertEqual(params, {})
+
+ def test_invalid_tool_call(self):
+ """
+ Test handling of invalid tool calls.
+
+ Scenario: Attempt to parse a tool call with a function name that doesn't exist in the tool list.
+ Purpose: Verify the detector gracefully rejects invalid function calls and returns no calls
+ rather than throwing an error or accepting invalid input.
+ """
+        text = "<tool_call>invalid_func<arg_key>city</arg_key><arg_value>Beijing</arg_value></tool_call>"
+ result = self.detector.detect_and_parse(text, self.tools)
+ self.assertEqual(len(result.calls), 0)
+
+ def test_array_argument_with_escaped_json(self):
+ """
+ Test array arguments containing escaped JSON strings.
+
+ Scenario: Parse tool calls with array parameters containing nested JSON objects with
+ escaped quotes (both backslash-escaped and raw escaped strings).
+ Purpose: Verify the detector properly handles JSON escaping without double-escaping,
+ preserving special characters like backslashes in paths and newline sequences.
+ """
+ tools_with_array = [
+ Tool(
+ type="function",
+ function=Function(
+ name="todo_write",
+ description="Write todos",
+ parameters={
+ "type": "object",
+ "properties": {
+ "todos": {
+ "type": "array",
+ "description": "The updated todo list",
+ }
+ },
+ "required": ["todos"],
+ },
+ ),
+ ),
+ ]
+
+ def check_params(result):
+ self.assertEqual(1, len(result.calls))
+ self.assertEqual("todo_write", result.calls[0].name)
+ params = json.loads(result.calls[0].parameters)
+ self.assertIsInstance(params["todos"], list)
+ self.assertEqual(4, len(params["todos"]))
+ self.assertEqual("1", params["todos"][0]["id"])
+ self.assertEqual(
+ "Check for hard-coded issues in the backend code",
+ params["todos"][0]["task"],
+ )
+ self.assertEqual("in_progress", params["todos"][0]["status"])
+ self.assertEqual("2", params["todos"][1]["id"])
+ self.assertEqual(
+ "Check for hard-coded issues in the frontend code",
+ params["todos"][1]["task"],
+ )
+ self.assertEqual("pending", params["todos"][1]["status"])
+ self.assertEqual("3", params["todos"][2]["id"])
+ self.assertEqual(
+ "Check for code violating the Single Responsibility Principle",
+ params["todos"][2]["task"],
+ )
+ self.assertEqual("pending", params["todos"][2]["status"])
+ self.assertEqual("4", params["todos"][3]["id"])
+ self.assertEqual(
+ "Generate a rectification proposal report", params["todos"][3]["task"]
+ )
+ self.assertEqual("pending", params["todos"][3]["status"])
+
+ # Test with normal escaped JSON in XML
+ result = self.detector.detect_and_parse(
+            """<tool_call>todo_write<arg_key>todos</arg_key><arg_value>[{\"id\": \"1\", \"task\": \"Check for hard-coded issues in the backend code\", \"status\": \"in_progress\"}, {\"id\": \"2\", \"task\": \"Check for hard-coded issues in the frontend code\", \"status\": \"pending\"}, {\"id\": \"3\", \"task\": \"Check for code violating the Single Responsibility Principle\", \"status\": \"pending\"}, {\"id\": \"4\", \"task\": \"Generate a rectification proposal report\", \"status\": \"pending\"}]</arg_value>
+</tool_call>""",
+ tools_with_array,
+ )
+ check_params(result)
+
+ # Test with raw string escaped JSON
+ result = self.detector.detect_and_parse(
+            r"""<tool_call>todo_write<arg_key>todos</arg_key><arg_value>[{\"id\": \"1\", \"task\": \"Check for hard-coded issues in the backend code\", \"status\": \"in_progress\"}, {\"id\": \"2\", \"task\": \"Check for hard-coded issues in the frontend code\", \"status\": \"pending\"}, {\"id\": \"3\", \"task\": \"Check for code violating the Single Responsibility Principle\", \"status\": \"pending\"}, {\"id\": \"4\", \"task\": \"Generate a rectification proposal report\", \"status\": \"pending\"}]</arg_value>
+</tool_call>""",
+ tools_with_array,
+ )
+ check_params(result)
+
+ def check_single_todos(tool_result, expected):
+ self.assertEqual(1, len(tool_result.calls))
+ self.assertEqual("todo_write", tool_result.calls[0].name)
+ params = json.loads(tool_result.calls[0].parameters)
+ self.assertIsInstance(params["todos"], list)
+ self.assertEqual(1, len(params["todos"]))
+ self.assertEqual("1", params["todos"][0]["id"])
+ self.assertEqual(expected, params["todos"][0]["task"])
+ self.assertEqual("pending", params["todos"][0]["status"])
+
+ # Test with escaped backslashes (Windows paths)
+ expected_path = r"Check file at C:\Users\test.txt"
+ result = self.detector.detect_and_parse(
+            """<tool_call>todo_write<arg_key>todos</arg_key><arg_value>[{\"id\": \"1\", \"task\": \"Check file at C:\\\\Users\\\\test.txt\", \"status\": \"pending\"}]</arg_value></tool_call>""",
+ tools_with_array,
+ )
+ check_single_todos(result, expected_path)
+
+ # Test with literal backslash-n (not newline)
+ expected_output = r"Print \n to see newline"
+ result = self.detector.detect_and_parse(
+            """<tool_call>todo_write<arg_key>todos</arg_key><arg_value>[{\"id\": \"1\", \"task\": \"Print \\\\n to see newline\",\"status\": \"pending\"}]</arg_value></tool_call>""",
+ tools_with_array,
+ )
+ check_single_todos(result, expected_output)
+
+ # ==================== MTP Core Scenarios (3) ====================
+
+ def test_mtp_func_and_string_split(self):
+ """
+ Test MTP-style function name and string parameter value splitting across chunks.
+
+ Scenario: Simulate Model Token Provider (MTP) behavior where function names and string
+        Scenario: Simulate multi-token prediction (MTP) behavior, where function names and string
+ Purpose: This is the MOST CRITICAL test - verify the detector correctly reassembles:
+ - Function name split as "create_ta" + "sk"
+ - String values split as "Go to Bei" + "jing" and "San Fran" + "cisco"
+ These splits mimic real MTP output where tokenization breaks words arbitrarily.
+ """
+ tools = [
+ Tool(
+ type="function",
+ function=Function(
+ name="create_task",
+ parameters={
+ "type": "object",
+ "properties": {
+ "title": {"type": "string"},
+ "location": {"type": "string"},
+ },
+ },
+ ),
+ ),
+ ]
+
+ chunks = [
+ "I'll create a task.", # normal text before tool call
+            "<tool_call>create_ta",  # function name split mid-word
+            "sk<arg_key>title</arg_key><arg_value>Go to Bei",  # function name completes, param value splits
+            "jing</arg_value>",  # first parameter value completes
+            "<arg_key>location</arg_key><arg_value>San Fran",  # second parameter value splits
+            "cisco</arg_value></tool_call>",  # second parameter and tool call complete
+ ]
+
+ detector = Glm47MoeDetector()
+ all_calls = []
+ all_normal_text = ""
+
+ for chunk in chunks:
+ result = detector.parse_streaming_increment(chunk, tools)
+ all_calls.extend(result.calls)
+ all_normal_text += result.normal_text
+
+ # Verify normal text is preserved
+ self.assertEqual(all_normal_text, "I'll create a task.")
+
+ # Verify function call
+ func_calls = [c for c in all_calls if c.name]
+ self.assertEqual(len(func_calls), 1)
+ self.assertEqual(
+ func_calls[0].name, "create_task"
+ ) # "create_ta" + "sk" reassembled
+
+ # Verify parameter reassembly
+ full_params = "".join([c.parameters for c in all_calls if c.parameters])
+ params = json.loads(full_params)
+ self.assertEqual(
+ params["title"], "Go to Beijing"
+ ) # "Go to Bei" + "jing" reassembled
+ self.assertEqual(
+ params["location"], "San Francisco"
+ ) # "San Fran" + "cisco" reassembled
+
+ def test_mtp_noarg_and_multiple_calls(self):
+ """
+ Test MTP-style no-argument function and multiple tool calls with state reset.
+
+ Scenario: Stream a no-argument function call followed by a regular function call,
+ simulating MTP's output pattern where function completion triggers state reset.
+ Purpose: Verify:
+ - No-argument functions emit exactly ONE empty object "{}", not duplicates
+ - State properly resets between consecutive tool calls (tool_index increments)
+ - Second tool call doesn't inherit parameters from first call
+ """
+ tools = [
+ Tool(
+ type="function",
+ function=Function(
+ name="list_files",
+ parameters={
+ "type": "object",
+ "properties": {},
+ },
+ ),
+ ),
+ Tool(
+ type="function",
+ function=Function(
+ name="get_weather",
+ parameters={
+ "type": "object",
+ "properties": {
+ "city": {"type": "string"},
+ },
+ },
+ ),
+ ),
+ ]
+
+ chunks = [
+            "<tool_call>list_files</tool_call>",  # no-arg function, complete in one chunk
+            "<tool_call>get_weather<arg_key>city</arg_key><arg_value>Beijing</arg_value></tool_call>",
+ ]
+
+ detector = Glm47MoeDetector()
+ all_calls = []
+
+ for chunk in chunks:
+ result = detector.parse_streaming_increment(chunk, tools)
+ all_calls.extend(result.calls)
+
+ # Verify two distinct tool calls
+ func_calls = [c for c in all_calls if c.name]
+ self.assertEqual(len(func_calls), 2)
+ self.assertEqual(func_calls[0].name, "list_files")
+ self.assertEqual(func_calls[1].name, "get_weather")
+
+ # Verify no duplicate empty objects for no-arg function
+ empty_object_calls = [c for c in all_calls if c.parameters == "{}"]
+ self.assertLessEqual(
+ len(empty_object_calls),
+ 1,
+ "No-argument function should emit at most one empty object",
+ )
+
+ # Verify second call has correct parameters
+ weather_params = [
+ c.parameters for c in all_calls if c.parameters and c.parameters != "{}"
+ ]
+ if weather_params:
+ full_params = "".join(weather_params)
+ params = json.loads(full_params)
+ self.assertEqual(params["city"], "Beijing")
+
+ def test_mtp_number_and_complex_json(self):
+ """
+ Test MTP-style number parameters and complex JSON array splitting.
+
+ Scenario: Parse tool calls with number parameters (int and float) and JSON arrays
+ split across chunks, including splits within JSON structure.
+ Purpose: Verify:
+ - Number types (5.5, 10) are preserved as numbers, not strings
+ - JSON array content split as "description" + ": \"" maintains validity
+ - Nested JSON objects in arrays are correctly reconstructed
+ """
+ tools = [
+ Tool(
+ type="function",
+ function=Function(
+ name="create_todos",
+ parameters={
+ "type": "object",
+ "properties": {
+ "priority": {"type": "number"},
+ "count": {"type": "integer"},
+ "items": {"type": "array"},
+ },
+ },
+ ),
+ ),
+ ]
+
+ chunks = [
+            "<tool_call>create_todos",
+            "<arg_key>priority</arg_key><arg_value>5.5</arg_value>",  # float number
+            "<arg_key>count</arg_key><arg_value>10</arg_value>",  # integer number
+            '<arg_key>items</arg_key><arg_value>[{"description',  # JSON array splits mid-key
+            '": "Test',  # key completes, value starts
+            'Todo 1"}, {"description": "TestTodo 2"}]</arg_value></tool_call>',
+ ]
+
+ detector = Glm47MoeDetector()
+ all_calls = []
+
+ for chunk in chunks:
+ result = detector.parse_streaming_increment(chunk, tools)
+ all_calls.extend(result.calls)
+
+ # Verify function name
+ func_calls = [c for c in all_calls if c.name]
+ self.assertEqual(len(func_calls), 1)
+ self.assertEqual(func_calls[0].name, "create_todos")
+
+ # Verify parameters - numbers and JSON array
+ full_params = "".join([c.parameters for c in all_calls if c.parameters])
+ params = json.loads(full_params)
+
+ # Number types should be preserved
+ self.assertIsInstance(params["priority"], (int, float))
+ self.assertEqual(params["priority"], 5.5)
+ self.assertIsInstance(params["count"], int)
+ self.assertEqual(params["count"], 10)
+
+ # JSON array should be correctly reconstructed
+ self.assertIsInstance(params["items"], list)
+ self.assertEqual(len(params["items"]), 2)
+ self.assertEqual(params["items"][0]["description"], "TestTodo 1")
+ self.assertEqual(params["items"][1]["description"], "TestTodo 2")
+
+ # ==================== Streaming Basics (3) ====================
+
+ def test_streaming_tool_call(self):
+ """
+ Test basic streaming incremental parsing of a single tool call.
+
+ Scenario: Parse a tool call split across 4 chunks with natural boundaries
+ (function name, first param, second param, closing tag).
+ Purpose: Verify basic streaming functionality works correctly and accumulates
+ parameters progressively across chunks.
+ """
+ chunks = [
+            "<tool_call>get_weather",
+            "<arg_key>city</arg_key><arg_value>Beijing</arg_value>",
+            "<arg_key>date</arg_key><arg_value>2024-06-27</arg_value>",
+            "</tool_call>",
+ ]
+ tool_calls = []
+ for chunk in chunks:
+ result = self.detector.parse_streaming_increment(chunk, self.tools)
+ for tool_call_chunk in result.calls:
+ if (
+ hasattr(tool_call_chunk, "tool_index")
+ and tool_call_chunk.tool_index is not None
+ ):
+ while len(tool_calls) <= tool_call_chunk.tool_index:
+ tool_calls.append({"name": "", "parameters": ""})
+ tc = tool_calls[tool_call_chunk.tool_index]
+ if tool_call_chunk.name:
+ tc["name"] = tool_call_chunk.name
+ if tool_call_chunk.parameters:
+ tc["parameters"] += tool_call_chunk.parameters
+ self.assertEqual(len(tool_calls), 1)
+ self.assertEqual(tool_calls[0]["name"], "get_weather")
+ self.assertEqual(
+ tool_calls[0]["parameters"], '{"city": "Beijing", "date": "2024-06-27"}'
+ )
+
+ def test_streaming_multiple_tool_calls(self):
+ """
+ Test streaming incremental parsing of multiple consecutive tool calls.
+
+        Scenario: Stream two complete tool calls with the transition "</tool_call><tool_call>"
+ occurring within a single chunk.
+ Purpose: Verify streaming correctly handles multiple tool calls and properly increments
+ tool_index for each new call.
+ """
+ chunks = [
+            "<tool_call>get_weather",
+            "<arg_key>city</arg_key><arg_value>Beijing</arg_value>",
+            "<arg_key>date</arg_key><arg_value>2024-06-27</arg_value>",
+            "</tool_call><tool_call>get_weather",  # two tool calls transition in same chunk
+            "<arg_key>city</arg_key><arg_value>Shanghai</arg_value>",
+            "<arg_key>date</arg_key><arg_value>2024-06-28</arg_value>",
+            "</tool_call>",
+ ]
+ tool_calls = []
+ for chunk in chunks:
+ result = self.detector.parse_streaming_increment(chunk, self.tools)
+ for tool_call_chunk in result.calls:
+ if (
+ hasattr(tool_call_chunk, "tool_index")
+ and tool_call_chunk.tool_index is not None
+ ):
+ while len(tool_calls) <= tool_call_chunk.tool_index:
+ tool_calls.append({"name": "", "parameters": ""})
+ tc = tool_calls[tool_call_chunk.tool_index]
+ if tool_call_chunk.name:
+ tc["name"] = tool_call_chunk.name
+ if tool_call_chunk.parameters:
+ tc["parameters"] += tool_call_chunk.parameters
+ self.assertEqual(len(tool_calls), 2)
+ self.assertEqual(tool_calls[0]["name"], "get_weather")
+ self.assertEqual(
+ tool_calls[0]["parameters"], '{"city": "Beijing", "date": "2024-06-27"}'
+ )
+ self.assertEqual(tool_calls[1]["name"], "get_weather")
+ self.assertEqual(
+ tool_calls[1]["parameters"], '{"city": "Shanghai", "date": "2024-06-28"}'
+ )
+
+ def test_normal_text_before_tool_call(self):
+ """
+ Test preservation of normal text (including punctuation) before tool calls.
+
+ Scenario: Parse chunks containing normal text with various punctuation marks
+ (English and Chinese) immediately followed by tool call tags.
+ Purpose: Verify normal text is preserved in result.normal_text and not lost when
+ tool call parsing begins. This consolidates 6 previous Chinese punctuation tests.
+ """
+ tools = [
+ Tool(
+ type="function",
+ function=Function(
+ name="list_dir",
+ parameters={
+ "type": "object",
+ "properties": {
+ "path": {"type": "string"},
+ },
+ },
+ ),
+ ),
+ ]
+
+ test_cases = [
+            ("Sure, let me help.<tool_call>list_dir", "English with period"),
+            ("结构：<tool_call>list_dir", "Chinese colon"),
+            ("问题。<tool_call>list_dir", "Chinese period"),
+            ("Complete!<tool_call>list_dir", "English exclamation"),
+            ("说明；<tool_call>list_dir", "Chinese semicolon"),
+ ]
+
+ for text, description in test_cases:
+ with self.subTest(description=description):
+ detector = Glm47MoeDetector()
+ result = detector.parse_streaming_increment(text, tools)
+
+                before_token = text.split("<tool_call>")[0]
+ self.assertIn(
+ before_token,
+ result.normal_text,
+ f"Should preserve '{before_token}' in '{description}'",
+ )
+
+ # ==================== Boundary Cases (9) ====================
+
+ def test_boundary_empty_param_value(self):
+ """
+ Test handling of empty parameter values.
+
+ Scenario: Parse a tool call where a parameter value is an empty string.
+ Purpose: Verify the detector correctly handles empty strings as valid parameter values
+ and doesn't skip or error on them.
+ """
+ tools = [
+ Tool(
+ type="function",
+ function=Function(
+ name="create_note",
+ parameters={
+ "type": "object",
+ "properties": {
+ "title": {"type": "string"},
+ "content": {"type": "string"},
+ },
+ },
+ ),
+ ),
+ ]
+
+        text = "<tool_call>create_note<arg_key>title</arg_key><arg_value>Test</arg_value><arg_key>content</arg_key><arg_value></arg_value></tool_call>"
+ result = self.detector.detect_and_parse(text, tools)
+
+ self.assertEqual(len(result.calls), 1)
+ params = json.loads(result.calls[0].parameters)
+ self.assertEqual(params["title"], "Test")
+ self.assertEqual(params["content"], "") # empty string should be preserved
+
+ def test_boundary_param_value_extreme_split(self):
+ """
+ Test extreme parameter value splitting - one character per chunk.
+
+ Scenario: Stream a parameter value where each character arrives in a separate chunk,
+ representing worst-case MTP tokenization.
+ Purpose: Stress test the buffer reassembly mechanism to ensure it can handle
+ extremely granular chunk boundaries without data loss or corruption.
+ """
+ tools = [
+ Tool(
+ type="function",
+ function=Function(
+ name="search",
+ parameters={
+ "type": "object",
+ "properties": {
+ "query": {"type": "string"},
+ },
+ },
+ ),
+ ),
+ ]
+
+ chunks = [
+            "<tool_call>search<arg_key>query</arg_key><arg_value>N",
+ "e",
+ "w ",
+ "Y",
+ "o",
+            "rk</arg_value></tool_call>",
+ ]
+
+ detector = Glm47MoeDetector()
+ all_calls = []
+
+ for chunk in chunks:
+ result = detector.parse_streaming_increment(chunk, tools)
+ all_calls.extend(result.calls)
+
+ full_params = "".join([c.parameters for c in all_calls if c.parameters])
+ params = json.loads(full_params)
+ self.assertEqual(
+ params["query"], "New York"
+ ) # all characters correctly reassembled
+
+ def test_boundary_param_value_with_special_chars(self):
+ """
+ Test parameter values containing special characters and escape sequences.
+
+ Scenario: Parse parameter values with quotes, backslashes, newlines, and other
+ special characters that require JSON escaping.
+ Purpose: Verify special characters are properly escaped/unescaped and preserved
+ through the parsing pipeline without corruption.
+ """
+ tools = [
+ Tool(
+ type="function",
+ function=Function(
+ name="execute_command",
+ parameters={
+ "type": "object",
+ "properties": {
+ "command": {"type": "string"},
+ },
+ },
+ ),
+ ),
+ ]
+
+ # Test with single quotes (no escaping needed)
+        text = "<tool_call>execute_command<arg_key>command</arg_key><arg_value>echo 'Hello World'</arg_value></tool_call>"
+ result = self.detector.detect_and_parse(text, tools)
+ params = json.loads(result.calls[0].parameters)
+ self.assertEqual(params["command"], "echo 'Hello World'")
+
+ # Test with spaces and special chars that don't need escaping
+        text = "<tool_call>execute_command<arg_key>command</arg_key><arg_value>echo Hello & World</arg_value></tool_call>"
+ result = self.detector.detect_and_parse(text, tools)
+ params = json.loads(result.calls[0].parameters)
+ self.assertEqual(params["command"], "echo Hello & World")
+
+ def test_boundary_json_deeply_nested(self):
+ """
+ Test deeply nested JSON structures in parameter values.
+
+ Scenario: Parse a parameter containing a deeply nested JSON object with multiple levels.
+ Purpose: Verify the detector can handle complex nested structures without stack overflow
+ or parsing errors.
+ """
+ tools = [
+ Tool(
+ type="function",
+ function=Function(
+ name="process_data",
+ parameters={
+ "type": "object",
+ "properties": {
+ "data": {"type": "object"},
+ },
+ },
+ ),
+ ),
+ ]
+
+ nested_json = (
+ '{"level1": {"level2": {"level3": {"level4": {"value": "deep"}}}}}'
+ )
+        text = f"<tool_call>process_data<arg_key>data</arg_key><arg_value>{nested_json}</arg_value></tool_call>"
+
+ result = self.detector.detect_and_parse(text, tools)
+ params = json.loads(result.calls[0].parameters)
+
+ # Navigate through nested structure
+ self.assertEqual(
+ params["data"]["level1"]["level2"]["level3"]["level4"]["value"], "deep"
+ )
+
+ def test_boundary_json_empty_structures(self):
+ """
+ Test empty JSON structures (empty objects and arrays) in parameters.
+
+ Scenario: Parse parameters containing empty objects {} and empty arrays [].
+ Purpose: Verify empty structures are preserved and not confused with no-argument
+ function empty parameter generation.
+ """
+ tools = [
+ Tool(
+ type="function",
+ function=Function(
+ name="create_structure",
+ parameters={
+ "type": "object",
+ "properties": {
+ "empty_obj": {"type": "object"},
+ "empty_arr": {"type": "array"},
+ },
+ },
+ ),
+ ),
+ ]
+
+        text = "<tool_call>create_structure<arg_key>empty_obj</arg_key><arg_value>{}</arg_value><arg_key>empty_arr</arg_key><arg_value>[]</arg_value></tool_call>"
+ result = self.detector.detect_and_parse(text, tools)
+
+ params = json.loads(result.calls[0].parameters)
+ self.assertEqual(params["empty_obj"], {})
+ self.assertEqual(params["empty_arr"], [])
+
+ def test_boundary_multi_tags_one_chunk(self):
+ """
+ Test multiple XML tags appearing in a single chunk.
+
+ Scenario: Parse chunks where multiple complete tags (arg_key, arg_value, etc.)
+ appear together without any chunk boundaries between them.
+ Purpose: Verify the regex-based tag extraction correctly handles multiple tags
+ in one chunk and processes them in the correct order.
+ """
+ tools = [
+ Tool(
+ type="function",
+ function=Function(
+ name="multi_param",
+ parameters={
+ "type": "object",
+ "properties": {
+ "a": {"type": "string"},
+ "b": {"type": "string"},
+ "c": {"type": "string"},
+ },
+ },
+ ),
+ ),
+ ]
+
+ # All three parameters in one chunk
+        text = "<tool_call>multi_param<arg_key>a</arg_key><arg_value>1</arg_value><arg_key>b</arg_key><arg_value>2</arg_value><arg_key>c</arg_key><arg_value>3</arg_value></tool_call>"
+ result = self.detector.detect_and_parse(text, tools)
+
+ params = json.loads(result.calls[0].parameters)
+ self.assertEqual(params["a"], "1")
+ self.assertEqual(params["b"], "2")
+ self.assertEqual(params["c"], "3")
+
+ def test_boundary_normal_text_mixed_with_tool(self):
+ """
+ Test normal text interleaved with tool calls.
+
+ Scenario: Parse text with normal text before and after tool calls.
+ Purpose: Verify normal text segments are correctly separated from tool call parsing
+ and preserved in the normal_text output.
+
+        NOTE: Text both before and after the tool call should be preserved in
+        normal_text (earlier revisions only captured the text before the call).
+ """
+ tools = [
+ Tool(
+ type="function",
+ function=Function(
+ name="action",
+ parameters={
+ "type": "object",
+ "properties": {},
+ },
+ ),
+ ),
+ ]
+
+        text = "First I'll do this.<tool_call>action</tool_call>Then I'll do that."
+ result = self.detector.detect_and_parse(text, tools)
+
+ self.assertEqual(len(result.calls), 1)
+ self.assertEqual(result.calls[0].name, "action")
+        # Text before the tool call is captured
+        self.assertIn("First I'll do this.", result.normal_text)
+        # Text after the tool call should be captured as well
+ self.assertIn("Then I'll do that.", result.normal_text)
+
+ def test_boundary_number_edge_values(self):
+ """
+ Test edge-case number values (zero, negative, scientific notation).
+
+ Scenario: Parse parameters with various numeric edge cases to ensure proper type handling.
+ Purpose: Verify the detector correctly preserves number types for edge values and doesn't
+ convert them to strings or lose precision.
+ """
+ tools = [
+ Tool(
+ type="function",
+ function=Function(
+ name="calculate",
+ parameters={
+ "type": "object",
+ "properties": {
+ "zero": {"type": "number"},
+ "negative": {"type": "number"},
+ "large": {"type": "number"},
+ },
+ },
+ ),
+ ),
+ ]
+
+        text = "<tool_call>calculate<arg_key>zero</arg_key><arg_value>0</arg_value><arg_key>negative</arg_key><arg_value>-42.5</arg_value><arg_key>large</arg_key><arg_value>1e10</arg_value></tool_call>"
+ result = self.detector.detect_and_parse(text, tools)
+
+ params = json.loads(result.calls[0].parameters)
+ self.assertEqual(params["zero"], 0)
+ self.assertEqual(params["negative"], -42.5)
+ self.assertEqual(params["large"], 1e10)
+
+ def test_boundary_type_string_with_numeric_content(self):
+ """
+ Test string parameters that contain numeric-looking content.
+
+ Scenario: Parse string parameters with values like "123" or "45.67" that look like
+ numbers but should remain strings based on parameter schema.
+ Purpose: Verify type preservation based on schema definition, not content appearance.
+ """
+ tools = [
+ Tool(
+ type="function",
+ function=Function(
+ name="store_data",
+ parameters={
+ "type": "object",
+ "properties": {
+ "id": {
+ "type": "string"
+ }, # string type despite numeric content
+ "code": {"type": "string"},
+ },
+ },
+ ),
+ ),
+ ]
+
+        text = "<tool_call>store_data<arg_key>id</arg_key><arg_value>12345</arg_value><arg_key>code</arg_key><arg_value>67.89</arg_value></tool_call>"
+ result = self.detector.detect_and_parse(text, tools)
+
+ params = json.loads(result.calls[0].parameters)
+ # Should be strings, not numbers
+ self.assertIsInstance(params["id"], str)
+ self.assertIsInstance(params["code"], str)
+ self.assertEqual(params["id"], "12345")
+ self.assertEqual(params["code"], "67.89")
+
+ # ==================== Error Handling (2) ====================
+
+ def test_error_undefined_tool(self):
+ """
+ Test error handling for undefined tool names.
+
+ Scenario: Attempt to call a function that doesn't exist in the provided tools list.
+ Purpose: Verify the detector gracefully handles undefined tools by returning an empty
+ call list rather than crashing or producing malformed output.
+ """
+        text = "<tool_call>nonexistent_function<arg_key>param</arg_key><arg_value>value</arg_value></tool_call>"
+ result = self.detector.detect_and_parse(text, self.tools)
+
+ # Should not crash, should return empty calls
+ self.assertEqual(len(result.calls), 0)
+
+ def test_error_incomplete_buffer_at_end(self):
+ """
+ Test handling of incomplete tool calls at end of stream.
+
+ Scenario: Streaming ends with an incomplete tool call (e.g., missing closing tag).
+ Purpose: Verify the detector handles incomplete buffers gracefully without throwing
+ exceptions, as streaming may end mid-parse in real scenarios.
+ """
+ chunks = [
+            "<tool_call>get_weather<arg_key>city</arg_key><arg_value>Beijing",
+ # Stream ends here, no closing tags
+ ]
+
+ detector = Glm47MoeDetector()
+
+ for chunk in chunks:
+ result = detector.parse_streaming_increment(chunk, self.tools)
+ # Should not crash
+ self.assertIsInstance(result, StreamingParseResult)
+
+ # Incomplete call should not be in results
+ # (or may be partially present - main thing is no exception)
+
+ # ==================== Streamed Raw Length Bug Tests (3) ====================
+
+ def test_streamed_raw_length_incomplete_xml_tag(self):
+ """
+ Test that _streamed_raw_length is updated even when json_increment is empty.
+
+ Scenario: Stream XML content that is split at an incomplete tag boundary,
+ causing the state machine to buffer without producing JSON output.
+ Purpose: Verify that _streamed_raw_length is updated regardless of whether
+ json_increment is empty, preventing reprocessing of the same input.
+
+ This tests the bug where:
+ 1. raw_increment is extracted from func_args_raw[self._streamed_raw_length:]
+ 2. _process_xml_to_json_streaming() returns empty string (buffering state)
+ 3. If _streamed_raw_length is NOT updated before the early return,
+ the next call will reprocess the same raw_increment
+ """
+ tools = [
+ Tool(
+ type="function",
+ function=Function(
+ name="get_weather",
+ parameters={
+ "type": "object",
+ "properties": {
+ "city": {"type": "string"},
+ "temperature": {"type": "number"},
+ },
+ },
+ ),
+ ),
+ ]
+
+ # Simulate streaming chunks where XML tags are split
+ chunks = [
+            "<tool_call>get_weather",
+            "<arg_key>city</arg_key><arg_value>Bei",  # Split in middle of value
+            "jing</arg_value>",  # Complete the value
+            "<arg_key>temperature</arg_key><arg_value>2",  # Split numeric value
+            "5</arg_value></tool_call>",
+ ]
+
+ detector = Glm47MoeDetector()
+ all_calls = []
+ collected_params = ""
+
+ for i, chunk in enumerate(chunks):
+ result = detector.parse_streaming_increment(chunk, tools)
+ all_calls.extend(result.calls)
+
+ # Collect parameters
+ for call in result.calls:
+ if call.parameters:
+ collected_params += call.parameters
+
+ # Verify complete parameters were collected without duplication
+ if collected_params:
+ params = json.loads(collected_params)
+ self.assertEqual(params["city"], "Beijing")
+ self.assertEqual(params["temperature"], 25)
+
+ # Critical: Verify no duplicate JSON output due to reprocessing
+ # Count occurrences of "city" key - should appear exactly once
+ city_count = collected_params.count('"city"')
+ self.assertEqual(
+ city_count,
+ 1,
+ f"'city' key appears {city_count} times, expected 1. "
+ f"This indicates input reprocessing bug.",
+ )
+
+ def test_streamed_raw_length_tag_split_across_chunks(self):
+ """
+        Scenario: XML tags themselves are split across chunks (e.g., "<arg_ke" + "y>").
+
+ Scenario: XML tags themselves are split across chunks (e.g., "").
+ Purpose: Verify that even when the state machine is buffering partial tags,
+ _streamed_raw_length is correctly updated to prevent reprocessing.
+ """
+ tools = [
+ Tool(
+ type="function",
+ function=Function(
+ name="search",
+ parameters={
+ "type": "object",
+ "properties": {
+ "query": {"type": "string"},
+ "limit": {"type": "integer"},
+ },
+ },
+ ),
+ ),
+ ]
+
+ # Split tags in extreme positions
+ chunks = [
+            "<tool_call>search<arg_key>query</arg_key><arg_value>Python progra",  # Complete tag, split value
+            "mming</arg_value><arg_key>limit</arg_key><arg_value>10</arg_value></tool_call>",
+ ]
+
+ detector = Glm47MoeDetector()
+ all_params = ""
+
+ for chunk in chunks:
+ result = detector.parse_streaming_increment(chunk, tools)
+ for call in result.calls:
+ if call.parameters:
+ all_params += call.parameters
+
+ # Verify correct reassembly
+ params = json.loads(all_params)
+ self.assertEqual(params["query"], "Python programming")
+ self.assertEqual(params["limit"], 10)
+
+ # Verify no duplication in output
+ query_count = all_params.count('"query"')
+ limit_count = all_params.count('"limit"')
+ self.assertEqual(query_count, 1, "query key duplicated - reprocessing bug")
+ self.assertEqual(limit_count, 1, "limit key duplicated - reprocessing bug")
+
+ def test_streamed_raw_length_buffer_only_partial_tag(self):
+ """
+        Test that _streamed_raw_length updates even when the state machine returns empty.
+
+        Scenario: Send an increment that is ONLY a partial opening tag, which the state
+        machine must buffer completely without producing any JSON output.
+        Purpose: Force json_increment to be an empty string to expose the bug where
+        _streamed_raw_length is not updated before the early return.
+ """
+ tools = [
+ Tool(
+ type="function",
+ function=Function(
+ name="test_func",
+ parameters={
+ "type": "object",
+ "properties": {
+ "key1": {"type": "string"},
+ },
+ },
+ ),
+ ),
+ ]
+
+ # Manually call _process_arguments_streaming to have precise control
+ detector = Glm47MoeDetector()
+ detector.current_tool_id = 0
+ detector.current_tool_name_sent = True
+ detector._reset_streaming_state()
+ detector.streamed_args_for_tool = [""]
+ detector._streamed_raw_length = 0
+
+ # First call: Complete tag that produces JSON output
+        func_args_1 = "<arg_key>key1</arg_key><arg_value>va"
+ result_1 = detector._process_arguments_streaming(
+ "test_func", func_args_1, tools
+ )
+
+ # Should produce JSON output: {"key1": "va (partial)
+ self.assertIsNotNone(result_1)
+ self.assertGreater(len(result_1.parameters), 0)
+ initial_length = detector._streamed_raw_length
+ self.assertEqual(initial_length, len(func_args_1))
+
+ # Second call: Add just partial closing tag - state machine will buffer this
+        # without producing JSON (it's waiting to see if </arg_value> is complete)
+ func_args_2 = func_args_1 + "<" # Add partial tag
+ result_2 = detector._process_arguments_streaming(
+ "test_func", func_args_2, tools
+ )
+
+ # This is the critical test: if _streamed_raw_length is NOT updated when
+ # json_increment is empty, then detector._streamed_raw_length will still be
+ # at initial_length, and the next call will reprocess the "<" character
+
+ # Check if length was updated (bug test)
+ updated_length = detector._streamed_raw_length
+
+ # BUG: If code has bug, updated_length will equal initial_length
+ # FIXED: If code is correct, updated_length should equal len(func_args_2)
+ self.assertEqual(
+ updated_length,
+ len(func_args_2),
+ "Bug detected: _streamed_raw_length not updated when json_increment is empty. "
+ f"Expected {len(func_args_2)}, got {updated_length}",
+ )
+
+ def test_streamed_raw_length_multiple_empty_returns(self):
+ """
+ Test consecutive chunks that produce empty json_increment.
+
+ Scenario: Multiple consecutive chunks that all result in empty json_increment
+ as the state machine buffers complex nested structures.
+ Purpose: Verify _streamed_raw_length advances correctly through multiple
+ empty-return cycles without getting stuck or reprocessing.
+ """
+ tools = [
+ Tool(
+ type="function",
+ function=Function(
+ name="update_settings",
+ parameters={
+ "type": "object",
+ "properties": {
+ "name": {"type": "string"},
+ "value": {"type": "string"},
+ },
+ },
+ ),
+ ),
+ ]
+
+ # Split XML at positions that may cause state machine buffering
+ chunks = [
+            "<tool_call>update_settings<arg_key>na",  # Split in tag name
+            "me</arg_key><arg_value>co",  # Complete tag start, split value  # codespell:ignore ue
+            "nf",  # Continue value
+            "ig_v1</arg_value><arg_key>val",  # Complete value, split next key
+            "ue</arg_key><arg_value>ena",  # Complete key name, split value  # codespell:ignore ue
+            "bled</arg_value></tool_call>",  # Complete everything
+ ]
+
+ detector = Glm47MoeDetector()
+ all_params = ""
+
+ for i, chunk in enumerate(chunks):
+ result = detector.parse_streaming_increment(chunk, tools)
+
+ for call in result.calls:
+ if call.parameters:
+ all_params += call.parameters
+
+ # Verify final output is correct
+ self.assertGreater(len(all_params), 0, "Should have generated some parameters")
+ params = json.loads(all_params)
+ self.assertEqual(params["name"], "config_v1")
+ self.assertEqual(params["value"], "enabled")
+
+ # Verify no duplicate keys due to reprocessing
+ name_count = all_params.count('"name"')
+ value_count = all_params.count('"value"')
+ self.assertEqual(
+ name_count,
+ 1,
+ f"'name' appears {name_count} times - indicates reprocessing bug",
+ )
+ self.assertEqual(
+ value_count,
+ 1,
+ f"'value' appears {value_count} times - indicates reprocessing bug",
+ )
+
+
+if __name__ == "__main__":
+ unittest.main()
From 5bdb7334411de1d2e595aec2d05e41435faa1e0f Mon Sep 17 00:00:00 2001
From: Leoyzen
Date: Sat, 10 Jan 2026 01:21:10 +0800
Subject: [PATCH 3/3] Fix GLM-4.7 MoE Detector complex JSON Schema type parsing
(#15753)
---
.../srt/function_call/glm47_moe_detector.py | 53 +-
.../srt/function_call/glm4_moe_detector.py | 371 +++++++++-
python/sglang/srt/function_call/utils.py | 105 ++-
.../function_call/test_glm47_moe_detector.py | 681 +++++++++++++++++-
4 files changed, 1165 insertions(+), 45 deletions(-)
diff --git a/python/sglang/srt/function_call/glm47_moe_detector.py b/python/sglang/srt/function_call/glm47_moe_detector.py
index 0cf3a1b622e6..0b25b11eb379 100644
--- a/python/sglang/srt/function_call/glm47_moe_detector.py
+++ b/python/sglang/srt/function_call/glm47_moe_detector.py
@@ -12,6 +12,7 @@
ToolCallItem,
_GetInfoFunc,
)
+from sglang.srt.function_call.utils import infer_type_from_json_schema
logger = logging.getLogger(__name__)
@@ -31,6 +32,14 @@ def get_argument_type(
) -> Optional[str]:
"""Get the expected type of a function argument from tool definitions.
+ Supports complex JSON Schema definitions including:
+ - Direct type field (including type arrays)
+ - anyOf/oneOf: parameter can be any of multiple types
+ - enum: parameter must be one of enum values
+ - allOf: parameter must satisfy all type definitions
+ - properties: inferred as object type
+ - items: inferred as array type
+
Args:
func_name: Name of the function/tool
arg_key: Name of the argument
@@ -58,7 +67,8 @@ def get_argument_type(
arg_spec = properties.get(arg_key)
if isinstance(arg_spec, dict):
- return arg_spec.get("type")
+ # Use the new type inference function for complex JSON Schema support
+ return infer_type_from_json_schema(arg_spec)
return None
@@ -243,25 +253,48 @@ def _get_value_type(self, func_name: str, key: str, tools: List[Tool]) -> str:
tools: List of available tools
Returns:
- Type string: 'string', 'number', or 'object'
+ Type string: 'string', 'number', 'object', 'array', or 'boolean'
"""
arg_type = get_argument_type(func_name, key, tools)
if arg_type:
return arg_type
- # Auto-detect type from value (best effort)
- first_chars = (
- self._current_value.strip()[:10]
- if self._current_value and self._current_value.strip()
- else ""
- )
- if first_chars:
- first_char = first_chars[0]
+        # Improved auto-detection of the type from the value (best effort)
+ value_content = self._current_value.strip() if self._current_value else ""
+
+ if not value_content:
+ return "string"
+
+ # Try to parse as valid JSON first
+ try:
+ parsed = json.loads(value_content)
+ if isinstance(parsed, dict):
+ return "object"
+ elif isinstance(parsed, list):
+ return "array"
+ elif isinstance(parsed, bool):
+ return "boolean"
+ elif isinstance(parsed, (int, float)):
+ return "number"
+ # For string values, check if they look like numbers
+ elif isinstance(parsed, str):
+ if parsed.isdigit() or (
+ parsed.startswith("-") and parsed[1:].isdigit()
+ ):
+ return "number"
+ return "string"
+ except json.JSONDecodeError:
+ # Not valid JSON, try heuristic detection
+ first_char = value_content[0] if value_content else ""
+
if first_char.isdigit() or first_char in ["-", "."]:
return "number"
elif first_char in ["{", "["]:
return "object"
+ elif first_char in ['"', "'"]:
+ return "string"
+ # Default to string (safest fallback)
return "string"
def _format_value_complete(self, value: str, value_type: str) -> str:
diff --git a/python/sglang/srt/function_call/glm4_moe_detector.py b/python/sglang/srt/function_call/glm4_moe_detector.py
index 3158e32166c3..0761e24e7cba 100644
--- a/python/sglang/srt/function_call/glm4_moe_detector.py
+++ b/python/sglang/srt/function_call/glm4_moe_detector.py
@@ -2,16 +2,52 @@
import json
import logging
import re
-from typing import List
+from enum import Enum
+from typing import Any, Dict, List, Optional, Tuple
from sglang.srt.entrypoints.openai.protocol import Tool
from sglang.srt.function_call.base_format_detector import BaseFormatDetector
-from sglang.srt.function_call.core_types import StreamingParseResult, _GetInfoFunc
+from sglang.srt.function_call.core_types import (
+ StreamingParseResult,
+ ToolCallItem,
+ _GetInfoFunc,
+)
+from sglang.srt.function_call.utils import infer_type_from_json_schema
logger = logging.getLogger(__name__)
-def get_argument_type(func_name: str, arg_key: str, defined_tools: list):
+class StreamState(str, Enum):
+ """State machine states for XML to JSON streaming conversion."""
+
+ INIT = "INIT"
+ BETWEEN = "BETWEEN"
+ IN_KEY = "IN_KEY"
+ WAITING_VALUE = "WAITING_VALUE"
+ IN_VALUE = "IN_VALUE"
+
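+# Typical tag-driven transitions (illustrative):
+#   INIT/BETWEEN --"<arg_key>"--> IN_KEY --"</arg_key>"--> WAITING_VALUE
+#   WAITING_VALUE --"<arg_value>"--> IN_VALUE --"</arg_value>"--> BETWEEN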
+
+def get_argument_type(
+ func_name: str, arg_key: str, defined_tools: List[Tool]
+) -> Optional[str]:
+ """Get the expected type of a function argument from tool definitions.
+
+ Supports complex JSON Schema definitions including:
+ - Direct type field (including type arrays)
+ - anyOf/oneOf: parameter can be any of multiple types
+ - enum: parameter must be one of enum values
+ - allOf: parameter must satisfy all type definitions
+ - properties: inferred as object type
+ - items: inferred as array type
+
+ Args:
+ func_name: Name of the function/tool
+ arg_key: Name of the argument
+ defined_tools: List of available tools
+
+ Returns:
+ The type string (e.g., 'string', 'number', 'object') or None if not found
+ """
name2tool = {tool.function.name: tool for tool in defined_tools}
if func_name not in name2tool:
return None
@@ -21,35 +57,95 @@ def get_argument_type(func_name: str, arg_key: str, defined_tools: list):
properties = {}
if arg_key not in properties:
return None
- return properties[arg_key].get("type", None)
+ # Use new type inference function for complex JSON Schema support
+ return infer_type_from_json_schema(properties[arg_key])
+
+
+def _convert_to_number(value: str) -> Any:
+ """Convert string to appropriate number type (int or float).
+
+ Args:
+ value: String value to convert
+
+ Returns:
+ Converted number or original string if conversion fails
+ """
+ try:
+ if "." in value or "e" in value.lower():
+ return float(value)
+ else:
+ return int(value)
+ except (ValueError, AttributeError):
+ return value
+
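+# Illustrative behavior (assumed): _convert_to_number("10") -> 10, "5.5" -> 5.5,
+# "1e3" -> 1000.0 (scientific notation), and "N/A" -> "N/A" (returned unchanged).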
+
+def parse_arguments(
+ json_value: str, arg_type: Optional[str] = None
+) -> Tuple[Any, bool]:
+ """Parse argument value with multiple fallback strategies.
+
+ Args:
+ json_value: Raw string value to parse
+ arg_type: Expected type hint ('string', 'number', 'object', etc.)
-def parse_arguments(json_value):
+ Returns:
+ Tuple of (parsed_value, is_valid_json)
+ """
+ # Strategy 1: Direct JSON parsing
try:
parsed_value = json.loads(json_value)
+
+ # Type coercion for number type
+ if arg_type == "number" and isinstance(parsed_value, str):
+ parsed_value = _convert_to_number(parsed_value)
+
return parsed_value, True
- except:
- # If that fails, try wrapping it to unescape JSON characters
- try:
- # Wrap the value as a JSON string field
- wrapped = json.loads('{"tmp": "' + json_value + '"}')
- # parse the unescaped value
- parsed_value = json.loads(wrapped["tmp"])
- return parsed_value, True
- except:
- # Final fallback to ast.literal_eval
- try:
- parsed_value = ast.literal_eval(json_value)
- return parsed_value, True
- except:
- return json_value, False
+ except (json.JSONDecodeError, ValueError):
+ pass
+
+ # Strategy 2: Unescape and parse
+ try:
+ wrapped = json.loads('{"tmp": "' + json_value + '"}')
+ parsed_value = json.loads(wrapped["tmp"])
+
+ if arg_type == "number" and isinstance(parsed_value, str):
+ parsed_value = _convert_to_number(parsed_value)
+
+ return parsed_value, True
+ except (json.JSONDecodeError, ValueError, KeyError):
+ pass
+
+ # Strategy 3: ast.literal_eval
+ try:
+ parsed_value = ast.literal_eval(json_value)
+ return parsed_value, True
+ except (ValueError, SyntaxError):
+ pass
+
+ # Strategy 4: Treat as string
+ try:
+ quoted_value = json.dumps(str(json_value))
+ return json.loads(quoted_value), True
+ except (json.JSONDecodeError, ValueError):
+ return json_value, False
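+
+# Illustrative fallback-chain behavior (assumed):
+#   parse_arguments('{"a": 1}')                -> ({"a": 1}, True)      # Strategy 1: direct JSON
+#   parse_arguments('[{\"x\": 1}]')            -> ([{"x": 1}], True)    # Strategy 2: unescape, then parse
+#   parse_arguments("{'a': 1}")                -> ({"a": 1}, True)      # Strategy 3: ast.literal_eval
+#   parse_arguments("plain text")              -> ("plain text", True)  # Strategy 4: treat as string
+#   parse_arguments('"42"', arg_type="number") -> (42, True)            # string coerced to number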
class Glm4MoeDetector(BaseFormatDetector):
"""
Detector for GLM-4.5 and GLM-4.6 models.
- Assumes function call format:
-    <tool_call>get_weather\n<arg_key>city</arg_key>\n<arg_value>北京</arg_value>\n<arg_key>date</arg_key>\n<arg_value>2024-06-27</arg_value>\n</tool_call>\n<tool_call>get_weather\n<arg_key>city</arg_key>\n<arg_value>上海</arg_value>\n<arg_key>date</arg_key>\n<arg_value>2024-06-27</arg_value>\n</tool_call>
+ Assumes function call format (with actual newlines):
+    <tool_call>get_weather
+    <arg_key>city</arg_key>
+    <arg_value>北京</arg_value>
+    <arg_key>date</arg_key>
+    <arg_value>2024-06-27</arg_value>
+    </tool_call>
+
+    Or with literal \n characters (escaped as \\n in the output):
+    <tool_call>get_weather\n<arg_key>city</arg_key>\n<arg_value>北京</arg_value>\n</tool_call>
+
+ Uses a streaming state machine to convert XML to JSON incrementally for maximum speed.
"""
def __init__(self):
@@ -64,6 +160,23 @@ def __init__(self):
r"(.*?)(?:\\n|\s)*(.*?)",
re.DOTALL,
)
+ self._last_arguments = ""
+ self.current_tool_id = -1
+ self.current_tool_name_sent = False
+ self._streamed_raw_length = 0
+ self._reset_streaming_state()
+
+ def _reset_streaming_state(self) -> None:
+ """Reset the streaming state machine for a new tool call."""
+ self._stream_state = StreamState.INIT
+ self._current_key = ""
+ self._current_value = ""
+ self._xml_tag_buffer = ""
+ self._is_first_param = True
+ self._value_started = False
+ self._cached_value_type: Optional[str] = (
+ None # Cache the value type for consistency
+ )
def has_tool_call(self, text: str) -> bool:
"""Check if the text contains a glm-4.5 / glm-4.6 format tool call."""
@@ -92,23 +205,219 @@ def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult
func_name = func_detail.group(1) if func_detail.group(1) else ""
func_args = func_detail.group(2) if func_detail.group(2) else ""
pairs = self.func_arg_regex.findall(func_args)
- arguments = {}
- for arg_key, arg_value in pairs:
- arg_key = arg_key.strip()
- arg_value = arg_value.strip()
- arg_type = get_argument_type(func_name, arg_key, tools)
- if arg_type != "string":
- arg_value, is_good_json = parse_arguments(arg_value)
- arguments[arg_key] = arg_value
+
+ # Parse arguments using shared method
+ arguments = self._parse_argument_pairs(pairs, func_name, tools)
+
# construct match_result for parse_base_json
match_result = {"name": func_name, "parameters": arguments}
calls.extend(self.parse_base_json(match_result, tools))
return StreamingParseResult(normal_text=normal_text, calls=calls)
except Exception as e:
- logger.error(f"Error in detect_and_parse: {e}")
+ logger.error(f"Error in detect_and_parse: {e}", exc_info=True)
# return the normal text if parsing fails
return StreamingParseResult(normal_text=text)
+ def _get_value_type(self, func_name: str, key: str, tools: List[Tool]) -> str:
+ """Get parameter type from tool definition, with fallback to auto-detection.
+
+ Args:
+ func_name: Name of the function
+ key: Parameter name
+ tools: List of available tools
+
+ Returns:
+ Type string: 'string', 'number', 'object', 'array', or 'boolean'
+ """
+ arg_type = get_argument_type(func_name, key, tools)
+ if arg_type:
+ return arg_type
+
+        # Improved auto-detection of the type from the value (best effort)
+ value_content = self._current_value.strip() if self._current_value else ""
+
+ if not value_content:
+ return "string"
+
+ # Try to parse as valid JSON first
+ try:
+ parsed = json.loads(value_content)
+ if isinstance(parsed, dict):
+ return "object"
+ elif isinstance(parsed, list):
+ return "array"
+ elif isinstance(parsed, bool):
+ return "boolean"
+ elif isinstance(parsed, (int, float)):
+ return "number"
+ # For string values, check if they look like numbers
+ elif isinstance(parsed, str):
+ if parsed.isdigit() or (
+ parsed.startswith("-") and parsed[1:].isdigit()
+ ):
+ return "number"
+ return "string"
+ except json.JSONDecodeError:
+ # Not valid JSON, try heuristic detection
+ first_char = value_content[0] if value_content else ""
+
+ if first_char.isdigit() or first_char in ["-", "."]:
+ return "number"
+ elif first_char in ["{", "["]:
+ return "object"
+ elif first_char in ['"', "'"]:
+ return "string"
+
+ # Default to string (safest fallback)
+ return "string"
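+
+    # Auto-detection examples when the schema provides no type (illustrative):
+    #   '{"a": 1}' -> "object"; "[1, 2]" -> "array"; "true" -> "boolean";
+    #   "-42.5" -> "number"; "hello" -> "string" (safe fallback).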
+
+ def _format_value_complete(self, value: str, value_type: str) -> str:
+ """Format complete value based on type.
+
+ Args:
+ value: Raw value string
+ value_type: Expected type ('string', 'number', 'object')
+
+ Returns:
+ Properly formatted JSON value string
+ """
+ if value_type == "string":
+ # Ensure proper JSON string formatting with quotes
+ return json.dumps(value, ensure_ascii=False)
+ elif value_type == "number":
+ try:
+ num = _convert_to_number(value.strip())
+ return str(num)
+ except (ValueError, AttributeError):
+ # Fallback to string if not a valid number
+ logger.warning(
+ f"Failed to parse '{value}' as number, treating as string"
+ )
+ return json.dumps(str(value), ensure_ascii=False)
+ else:
+ # For object/array types, return as-is (should already be valid JSON)
+ return value
+
+ def _process_xml_to_json_streaming(
+ self, raw_increment: str, func_name: str, tools: List[Tool]
+ ) -> str:
+ """Convert XML increment to JSON streaming output using state machine.
+
+ This method processes XML fragments character by character and converts them
+ to JSON format incrementally. It maintains state across calls to handle
+ partial XML tags and values.
+
+ Args:
+ raw_increment: New XML content to process
+ func_name: Name of the function being called
+ tools: List of available tools for type inference
+
+ Returns:
+ JSON string increment to append to the output
+ """
+ json_output = ""
+
+ for char in raw_increment:
+ self._xml_tag_buffer += char
+
+ if self._stream_state in [StreamState.INIT, StreamState.BETWEEN]:
+                if self._xml_tag_buffer.endswith("<arg_key>"):
+ self._stream_state = StreamState.IN_KEY
+ self._current_key = ""
+ self._xml_tag_buffer = ""
+ json_output += "{" if self._is_first_param else ", "
+ self._is_first_param = False
+
+ elif self._stream_state == StreamState.IN_KEY:
+                if self._xml_tag_buffer.endswith("</arg_key>"):
+ self._current_key = self._xml_tag_buffer[:-10].strip()
+ self._xml_tag_buffer = ""
+ self._stream_state = StreamState.WAITING_VALUE
+ json_output += (
+ json.dumps(self._current_key, ensure_ascii=False) + ": "
+ )
+
+ elif self._stream_state == StreamState.WAITING_VALUE:
+                if self._xml_tag_buffer.endswith("<arg_value>"):
+ self._stream_state = StreamState.IN_VALUE
+ self._current_value = ""
+ self._xml_tag_buffer = ""
+ self._value_started = False
+ # Determine and cache the value type at the start
+ self._cached_value_type = self._get_value_type(
+ func_name, self._current_key, tools
+ )
+
+ elif self._stream_state == StreamState.IN_VALUE:
+                if self._xml_tag_buffer.endswith("</arg_value>"):
+ final_value = self._xml_tag_buffer[:-12]
+ self._current_value += final_value
+
+ # Use cached value type for consistency
+ value_type = self._cached_value_type or "string"
+
+ if self._value_started:
+ # Output any remaining content
+ if final_value:
+ if value_type == "string":
+ json_output += json.dumps(
+ final_value, ensure_ascii=False
+ )[1:-1]
+ else:
+ json_output += final_value
+ # Always output closing quote for string type when value was started
+ if value_type == "string":
+ json_output += '"'
+ else:
+ # Value was never started (empty or complete in one chunk)
+ json_output += self._format_value_complete(
+ self._current_value, value_type
+ )
+
+ self._xml_tag_buffer = ""
+ self._stream_state = StreamState.BETWEEN
+ self._current_value = ""
+ self._value_started = False
+ self._cached_value_type = None # Reset cached type
+ else:
+                    closing_tag = "</arg_value>"
+ is_potential_closing = len(self._xml_tag_buffer) <= len(
+ closing_tag
+ ) and closing_tag.startswith(self._xml_tag_buffer)
+
+ if not is_potential_closing:
+ content = self._xml_tag_buffer
+ # Use cached value type for consistency
+ value_type = self._cached_value_type or "string"
+
+ if value_type == "string":
+ if not self._value_started:
+ json_output += '"'
+ self._value_started = True
+ if content:
+ json_output += json.dumps(content, ensure_ascii=False)[
+ 1:-1
+ ]
+ self._current_value += content
+ self._xml_tag_buffer = ""
+ elif value_type == "number":
+ if content:
+ if not self._value_started:
+ self._value_started = True
+ json_output += content
+ self._current_value += content
+ self._xml_tag_buffer = ""
+ else:
+ # For object/array types, output as-is
+ if content:
+ if not self._value_started:
+ self._value_started = True
+ json_output += content
+ self._current_value += content
+ self._xml_tag_buffer = ""
+
+ return json_output
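+
+    # Example trace (illustrative): feeding '<arg_key>city</arg_key><arg_value>Bei' and then
+    # 'jing</arg_value>' emits the JSON increments '{"city": "Bei' and 'jing"' respectively.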
+
def parse_streaming_increment(
self, new_text: str, tools: List[Tool]
) -> StreamingParseResult:
diff --git a/python/sglang/srt/function_call/utils.py b/python/sglang/srt/function_call/utils.py
index d85e5e6c0307..567ca583bb8f 100644
--- a/python/sglang/srt/function_call/utils.py
+++ b/python/sglang/srt/function_call/utils.py
@@ -1,6 +1,6 @@
from json import JSONDecodeError, JSONDecoder
from json.decoder import WHITESPACE
-from typing import Any, List, Literal, Optional, Tuple, Union
+from typing import Any, Dict, List, Literal, Optional, Tuple, Union
import orjson
import partial_json_parser
@@ -101,6 +101,109 @@ def _get_tool_schema(tool: Tool) -> dict:
}
+def infer_type_from_json_schema(schema: Dict[str, Any]) -> Optional[str]:
+ """
+ Infer the primary type of a parameter from JSON Schema.
+
+ Supports complex JSON Schema structures including:
+ - Direct type field (including type arrays)
+ - anyOf/oneOf: parameter can be any of multiple types
+ - enum: parameter must be one of enum values
+ - allOf: parameter must satisfy all type definitions
+ - properties: inferred as object type
+ - items: inferred as array type
+
+ Args:
+ schema: JSON Schema definition
+
+ Returns:
+ Inferred type ('string', 'number', 'object', 'array', etc.) or None
+ """
+ if not isinstance(schema, dict):
+ return None
+
+ # Priority 1: Direct type field (including type arrays)
+ if "type" in schema:
+ type_value = schema["type"]
+ if isinstance(type_value, str):
+ return type_value
+ elif isinstance(type_value, list) and type_value:
+ # Handle type arrays: return first non-null type
+ non_null_types = [t for t in type_value if t != "null"]
+ if non_null_types:
+ return non_null_types[0]
+ return "string" # If only null, default to string
+
+ # Priority 2: Handle anyOf/oneOf
+ if "anyOf" in schema or "oneOf" in schema:
+ schemas = schema.get("anyOf") or schema.get("oneOf")
+ types = []
+
+ if isinstance(schemas, list):
+ for sub_schema in schemas:
+ inferred_type = infer_type_from_json_schema(sub_schema)
+ if inferred_type:
+ types.append(inferred_type)
+
+ if types:
+ # If all types are the same, return unified type
+ if len(set(types)) == 1:
+ return types[0]
+ # When types differ, prioritize string (safest)
+ if "string" in types:
+ return "string"
+ # Otherwise return first type
+ return types[0]
+
+ # Priority 3: Handle enum (infer type from enum values)
+ if "enum" in schema and isinstance(schema["enum"], list):
+ if not schema["enum"]:
+ return "string"
+
+ # Infer type from enum values
+ enum_types = set()
+ for value in schema["enum"]:
+ if value is None:
+ enum_types.add("null")
+ elif isinstance(value, bool):
+ enum_types.add("boolean")
+ elif isinstance(value, int):
+ enum_types.add("integer")
+ elif isinstance(value, float):
+ enum_types.add("number")
+ elif isinstance(value, str):
+ enum_types.add("string")
+ elif isinstance(value, list):
+ enum_types.add("array")
+ elif isinstance(value, dict):
+ enum_types.add("object")
+
+ # If type is uniform, return that type
+ if len(enum_types) == 1:
+ return enum_types.pop()
+ # Mixed types, prioritize string
+ return "string"
+
+ # Priority 4: Handle allOf (must satisfy all types)
+ if "allOf" in schema and isinstance(schema["allOf"], list):
+ schemas = schema["allOf"]
+ for sub_schema in schemas:
+ inferred_type = infer_type_from_json_schema(sub_schema)
+ if inferred_type and inferred_type != "string":
+ return inferred_type
+ return "string"
+
+ # Priority 5: Infer object type
+ if "properties" in schema:
+ return "object"
+
+ # Priority 6: Infer array type
+ if "items" in schema:
+ return "array"
+
+ return None
+
+
def get_json_schema_constraint(
tools: List[Tool], tool_choice: Union[ToolChoice, Literal["required"]]
) -> Optional[dict]:
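
As a quick sanity check of the priority order documented in `infer_type_from_json_schema`, the sketch below exercises each branch; the import path is taken from this diff:

```python
from sglang.srt.function_call.utils import infer_type_from_json_schema

# Priority 1: direct type field, including type arrays (first non-null entry)
assert infer_type_from_json_schema({"type": "integer"}) == "integer"
assert infer_type_from_json_schema({"type": ["string", "null"]}) == "string"
# Priority 2: anyOf/oneOf with mixed types prioritizes "string"
assert infer_type_from_json_schema(
    {"anyOf": [{"type": "string"}, {"type": "object"}]}
) == "string"
# Priority 3: enum-only schema, type inferred from the enum values
assert infer_type_from_json_schema({"enum": ["low", "high"]}) == "string"
# Priority 4: allOf returns the first non-string sub-type
assert infer_type_from_json_schema(
    {"allOf": [{"type": "object"}, {"properties": {"t": {"type": "number"}}}]}
) == "object"
# Priorities 5-6: properties/items imply object/array
assert infer_type_from_json_schema({"properties": {}}) == "object"
assert infer_type_from_json_schema({"items": {"type": "string"}}) == "array"
```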
diff --git a/test/registered/function_call/test_glm47_moe_detector.py b/test/registered/function_call/test_glm47_moe_detector.py
index a046964064ce..fdd06e9ce7df 100644
--- a/test/registered/function_call/test_glm47_moe_detector.py
+++ b/test/registered/function_call/test_glm47_moe_detector.py
@@ -3,7 +3,11 @@
from sglang.srt.entrypoints.openai.protocol import Function, Tool
from sglang.srt.function_call.core_types import StreamingParseResult
-from sglang.srt.function_call.glm47_moe_detector import Glm47MoeDetector
+from sglang.srt.function_call.glm4_moe_detector import Glm4MoeDetector
+from sglang.srt.function_call.glm47_moe_detector import (
+ Glm47MoeDetector,
+ get_argument_type,
+)
from sglang.test.ci.ci_register import register_cpu_ci
register_cpu_ci(1.0, "default")
@@ -1172,5 +1176,676 @@ def test_streamed_raw_length_multiple_empty_returns(self):
)
-if __name__ == "__main__":
- unittest.main()
+class TestGlm4ComplexJsonSchema(unittest.TestCase):
+ """Test complex JSON Schema type inference for GLM function call parsers."""
+
+ def setUp(self):
+ """Set up test tools with complex JSON schemas."""
+ self.tools_with_complex_schema = [
+ Tool(
+ type="function",
+ function=Function(
+ name="search",
+ description="Search for information",
+ parameters={
+ "type": "object",
+ "properties": {
+ "query": {
+ "description": "Search query, can be a string or a complex object",
+ "anyOf": [
+ {"type": "string"},
+ {
+ "type": "object",
+ "properties": {
+ "text": {"type": "string"},
+ "filters": {"type": "object"},
+ },
+ },
+ ],
+ },
+ "priority": {"enum": ["low", "medium", "high"]},
+ "options": {
+ "oneOf": [{"type": "string"}, {"type": "number"}]
+ },
+ "config": {
+ "allOf": [
+ {"type": "object"},
+ {"properties": {"timeout": {"type": "number"}}},
+ ]
+ },
+ "tags": {"type": ["string", "null"]},
+ "data": {
+ "type": "object",
+ "properties": {
+ "nested": {
+ "anyOf": [
+ {"type": "string"},
+ {
+ "type": "object",
+ "properties": {
+ "value": {"type": "string"}
+ },
+ },
+ ]
+ }
+ },
+ },
+ },
+ "required": ["query"],
+ },
+ ),
+ ),
+ Tool(
+ type="function",
+ function=Function(
+ name="get_weather",
+ description="Get weather information",
+ parameters={
+ "type": "object",
+ "properties": {
+ "location": {
+ "type": "string",
+ "description": "Location to get weather for",
+ },
+ "unit": {
+ "type": "string",
+ "description": "Temperature unit",
+ "enum": ["celsius", "fahrenheit"],
+ },
+ },
+ "required": ["location"],
+ },
+ ),
+ ),
+ ]
+ self.glm4_detector = Glm4MoeDetector()
+ self.glm47_detector = Glm47MoeDetector()
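+
+ # Wire formats exercised in the tests below:
+ # GLM-4.5: <tool_call>name\n<arg_key>k</arg_key>\n<arg_value>v</arg_value>\n</tool_call>
+ # GLM-4.7: <tool_call>name<arg_key>k</arg_key><arg_value>v</arg_value></tool_call>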
+
+ def test_get_argument_type_simple_type(self):
+ """Test that get_argument_type correctly handles simple type fields."""
+ result = get_argument_type(
+ "get_weather", "location", self.tools_with_complex_schema
+ )
+ self.assertEqual(result, "string")
+
+ def test_get_argument_type_enum_type(self):
+ """Test that get_argument_type correctly identifies enum as string type."""
+ result = get_argument_type(
+ "get_weather", "unit", self.tools_with_complex_schema
+ )
+ # The "unit" parameter declares type: "string" directly, so the direct
+ # type field takes priority over the enum constraint; enum-only schemas
+ # (without a type field) are covered separately below
+ self.assertEqual(result, "string")
+
+ def test_get_argument_type_anyof_type(self):
+ """Test that get_argument_type correctly handles anyOf type fields."""
+ result = get_argument_type("search", "query", self.tools_with_complex_schema)
+ # anyOf with [{"type": "string"}, {"type": "object", ...}] should return "string"
+ self.assertEqual(result, "string") # Returns first common type
+
+ def test_get_argument_type_oneof_type(self):
+ """Test that get_argument_type correctly handles oneOf type fields."""
+ result = get_argument_type("search", "options", self.tools_with_complex_schema)
+ # oneOf with [{"type": "string"}, {"type": "number"}] should return "string" (prioritizes string)
+ self.assertEqual(result, "string")
+
+ def test_get_argument_type_allof_type(self):
+ """Test that get_argument_type correctly handles allOf type fields."""
+ result = get_argument_type("search", "config", self.tools_with_complex_schema)
+ # allOf with [{"type": "object"}, ...] should return "object"
+ self.assertEqual(result, "object")
+
+ def test_get_argument_type_type_array(self):
+ """Test that get_argument_type correctly handles type arrays."""
+ result = get_argument_type("search", "tags", self.tools_with_complex_schema)
+ # Type arrays should return the first non-null type
+ self.assertEqual(
+ result, "string"
+ ) # ["string", "null"] -> "string" (non-null type)
+
+ def test_glm4_detector_with_complex_schema_anyof(self):
+ """Test GLM4 detector with anyOf schema - should demonstrate current issues."""
+ # This test shows the current behavior with complex schemas
+ text = (
+ "search\n"
+ "query\nHello world\n"
+ "priority\nmedium\n"
+ ""
+ )
+ result = self.glm4_detector.detect_and_parse(
+ text, self.tools_with_complex_schema
+ )
+
+ self.assertEqual(len(result.calls), 1)
+ self.assertEqual(result.calls[0].name, "search")
+
+ # Parse parameters to check if they are correctly handled
+ params = json.loads(result.calls[0].parameters)
+ self.assertEqual(params["query"], "Hello world")
+ self.assertEqual(params["priority"], "medium")
+
+ def test_glm47_detector_with_complex_schema_anyof(self):
+ """Test GLM47 detector with anyOf schema - should demonstrate current issues."""
+ # This test shows the current behavior with complex schemas
+ text = (
+ "search"
+ "queryHello world"
+ "prioritymedium"
+ ""
+ )
+ result = self.glm47_detector.detect_and_parse(
+ text, self.tools_with_complex_schema
+ )
+
+ self.assertEqual(len(result.calls), 1)
+ self.assertEqual(result.calls[0].name, "search")
+
+ # Parse parameters to check if they are correctly handled
+ params = json.loads(result.calls[0].parameters)
+ self.assertEqual(params["query"], "Hello world")
+ self.assertEqual(params["priority"], "medium")
+
+ def test_glm4_detector_with_enum_values(self):
+ """Test GLM4 detector with enum values in complex schema."""
+ text = (
+ "search\n"
+ "query\ntest query\n"
+ "priority\nhigh\n"
+ ""
+ )
+ result = self.glm4_detector.detect_and_parse(
+ text, self.tools_with_complex_schema
+ )
+
+ self.assertEqual(len(result.calls), 1)
+ self.assertEqual(result.calls[0].name, "search")
+
+ params = json.loads(result.calls[0].parameters)
+ self.assertEqual(params["query"], "test query")
+ self.assertEqual(params["priority"], "high")
+
+ def test_glm47_detector_with_enum_values(self):
+ """Test GLM47 detector with enum values in complex schema."""
+ text = (
+ "search"
+ "querytest query"
+ "priorityhigh"
+ ""
+ )
+ result = self.glm47_detector.detect_and_parse(
+ text, self.tools_with_complex_schema
+ )
+
+ self.assertEqual(len(result.calls), 1)
+ self.assertEqual(result.calls[0].name, "search")
+
+ params = json.loads(result.calls[0].parameters)
+ self.assertEqual(params["query"], "test query")
+ self.assertEqual(params["priority"], "high")
+
+ def test_glm4_detector_streaming_with_complex_schema(self):
+ """Test GLM4 detector streaming with complex schema."""
+ chunks = [
+ "search\n",
+ "query\nnested object\n",
+ "priority\nlow\n",
+ "",
+ ]
+ tool_calls = []
+ for chunk in chunks:
+ result = self.glm4_detector.parse_streaming_increment(
+ chunk, self.tools_with_complex_schema
+ )
+ for tool_call_chunk in result.calls:
+ if (
+ hasattr(tool_call_chunk, "tool_index")
+ and tool_call_chunk.tool_index is not None
+ ):
+ while len(tool_calls) <= tool_call_chunk.tool_index:
+ tool_calls.append({"name": "", "parameters": ""})
+ tc = tool_calls[tool_call_chunk.tool_index]
+ if tool_call_chunk.name:
+ tc["name"] = tool_call_chunk.name
+ if tool_call_chunk.parameters:
+ tc["parameters"] += tool_call_chunk.parameters
+
+ self.assertEqual(len(tool_calls), 1)
+ self.assertEqual(tool_calls[0]["name"], "search")
+
+ params = json.loads(tool_calls[0]["parameters"])
+ self.assertEqual(params["query"], "nested object")
+ self.assertEqual(params["priority"], "low")
+
+ def test_glm47_detector_streaming_with_complex_schema(self):
+ """Test GLM47 detector streaming with complex schema."""
+ chunks = [
+ "search",
+ "querynested object",
+ "prioritylow",
+ "",
+ ]
+ tool_calls = []
+ for chunk in chunks:
+ result = self.glm47_detector.parse_streaming_increment(
+ chunk, self.tools_with_complex_schema
+ )
+ for tool_call_chunk in result.calls:
+ if (
+ hasattr(tool_call_chunk, "tool_index")
+ and tool_call_chunk.tool_index is not None
+ ):
+ while len(tool_calls) <= tool_call_chunk.tool_index:
+ tool_calls.append({"name": "", "parameters": ""})
+ tc = tool_calls[tool_call_chunk.tool_index]
+ if tool_call_chunk.name:
+ tc["name"] = tool_call_chunk.name
+ if tool_call_chunk.parameters:
+ tc["parameters"] += tool_call_chunk.parameters
+
+ self.assertEqual(len(tool_calls), 1)
+ self.assertEqual(tool_calls[0]["name"], "search")
+
+ params = json.loads(tool_calls[0]["parameters"])
+ self.assertEqual(params["query"], "nested object")
+ self.assertEqual(params["priority"], "low")
+
+ def test_type_inference_issue_reproduction(self):
+ """Reproduce the issue where complex JSON schemas are not properly handled."""
+ # This test demonstrates the current limitations
+ complex_tools = [
+ Tool(
+ type="function",
+ function=Function(
+ name="complex_function",
+ parameters={
+ "type": "object",
+ "properties": {
+ "complex_param": {
+ "anyOf": [
+ {"type": "string"},
+ {
+ "type": "object",
+ "properties": {"value": {"type": "string"}},
+ },
+ ]
+ },
+ "enum_param": {"enum": ["option1", "option2", "option3"]},
+ },
+ },
+ ),
+ )
+ ]
+
+ # Test that get_argument_type returns appropriate types for complex schemas
+ anyof_result = get_argument_type(
+ "complex_function", "complex_param", complex_tools
+ )
+ enum_result = get_argument_type("complex_function", "enum_param", complex_tools)
+
+ # Verify complex schema types are correctly inferred
+ self.assertEqual(anyof_result, "string") # anyOf prioritizes string type
+ self.assertEqual(enum_result, "string") # enum values are strings
+
+ def test_expected_behavior_for_complex_schemas(self):
+ """Test cases that should work but currently fail - demonstrating the issue."""
+ # This test shows what the behavior SHOULD be after the fix
+ complex_tools = [
+ Tool(
+ type="function",
+ function=Function(
+ name="complex_function",
+ parameters={
+ "type": "object",
+ "properties": {
+ "complex_param": {
+ "anyOf": [
+ {"type": "string"},
+ {
+ "type": "object",
+ "properties": {"value": {"type": "string"}},
+ },
+ ]
+ },
+ "enum_param": {"enum": ["option1", "option2", "option3"]},
+ "oneof_param": {
+ "oneOf": [{"type": "string"}, {"type": "number"}]
+ },
+ "allof_param": {
+ "allOf": [
+ {"type": "object"},
+ {"properties": {"timeout": {"type": "number"}}},
+ ]
+ },
+ },
+ },
+ ),
+ )
+ ]
+
+ # These assertions encode the behavior required by the RFC; they pass
+ # now that type inference handles anyOf/oneOf/allOf/enum
+ anyof_result = get_argument_type(
+ "complex_function", "complex_param", complex_tools
+ )
+ enum_result = get_argument_type("complex_function", "enum_param", complex_tools)
+ oneof_result = get_argument_type(
+ "complex_function", "oneof_param", complex_tools
+ )
+ allof_result = get_argument_type(
+ "complex_function", "allof_param", complex_tools
+ )
+
+ # With infer_type_from_json_schema in place, each complex construct
+ # resolves to a concrete type
+ self.assertIsNotNone(anyof_result, "anyOf should resolve to a concrete type")
+ self.assertEqual(
+ enum_result, "string", "string-valued enums should resolve to 'string'"
+ )
+ self.assertIsNotNone(oneof_result, "oneOf should resolve to a concrete type")
+ self.assertIsNotNone(allof_result, "allOf should resolve to a concrete type")
+
+ def test_complex_schema_type_inference_scenarios(self):
+ """Test various complex schema scenarios mentioned in the RFC."""
+ # Create tools with different complex schema structures
+ complex_schema_tools = [
+ Tool(
+ type="function",
+ function=Function(
+ name="search_complex",
+ parameters={
+ "type": "object",
+ "properties": {
+ # anyOf example - parameter can be string or object
+ "query": {
+ "description": "Search query, can be a string or a complex object",
+ "anyOf": [
+ {"type": "string"},
+ {
+ "type": "object",
+ "properties": {
+ "text": {"type": "string"},
+ "filters": {"type": "object"},
+ },
+ },
+ ],
+ },
+ # oneOf example - parameter must be one of the specified types
+ "priority": {
+ "oneOf": [{"type": "string"}, {"type": "integer"}]
+ },
+ # enum example - parameter must be one of the enum values
+ "category": {"enum": ["news", "sports", "tech"]},
+ # allOf example - parameter must satisfy all schemas
+ "config": {
+ "allOf": [
+ {"type": "object"},
+ {"properties": {"timeout": {"type": "number"}}},
+ ]
+ },
+ # Type array example
+ "tags": {"type": ["string", "null"]},
+ },
+ },
+ ),
+ ),
+ Tool(
+ type="function",
+ function=Function(
+ name="get_data",
+ parameters={
+ "type": "object",
+ "properties": {
+ # Complex nested anyOf
+ "input": {
+ "anyOf": [
+ {"type": "string"},
+ {"type": "number"},
+ {
+ "type": "object",
+ "properties": {
+ "type": {"type": "string"},
+ "value": {},
+ },
+ },
+ ]
+ }
+ },
+ },
+ ),
+ ),
+ ]
+
+ # Test each complex type scenario
+ query_type = get_argument_type("search_complex", "query", complex_schema_tools)
+ priority_type = get_argument_type(
+ "search_complex", "priority", complex_schema_tools
+ )
+ category_type = get_argument_type(
+ "search_complex", "category", complex_schema_tools
+ )
+ config_type = get_argument_type(
+ "search_complex", "config", complex_schema_tools
+ )
+ tags_type = get_argument_type("search_complex", "tags", complex_schema_tools)
+ input_type = get_argument_type("get_data", "input", complex_schema_tools)
+
+ # All of these should return appropriate types according to RFC
+ self.assertEqual(query_type, "string") # anyOf: string | object -> string
+ self.assertEqual(priority_type, "string") # oneOf: string | integer -> string
+ self.assertEqual(
+ category_type, "string"
+ ) # enum: ["news", "sports", "tech"] -> string
+ self.assertEqual(config_type, "object") # allOf with object -> object
+ self.assertEqual(
+ tags_type, "string"
+ ) # type array: ["string", "null"] -> string
+ self.assertEqual(
+ input_type, "string"
+ ) # nested anyOf: string | number | object -> string
+
+ def test_glm4_detector_type_handling_with_complex_schema(self):
+ """Test how GLM4 detector handles type inference for complex schemas in practice."""
+ complex_tools = [
+ Tool(
+ type="function",
+ function=Function(
+ name="complex_search",
+ parameters={
+ "type": "object",
+ "properties": {
+ "query": {
+ "anyOf": [
+ {"type": "string"},
+ {
+ "type": "object",
+ "properties": {"text": {"type": "string"}},
+ },
+ ]
+ },
+ "category": {"enum": ["tech", "news", "sports"]},
+ },
+ },
+ ),
+ )
+ ]
+
+ # Test with string value for anyOf parameter
+ text = (
+ "complex_search\n"
+ "query\ntest search\n"
+ "category\ntech\n"
+ ""
+ )
+ result = self.glm4_detector.detect_and_parse(text, complex_tools)
+
+ self.assertEqual(len(result.calls), 1)
+ self.assertEqual(result.calls[0].name, "complex_search")
+
+ params = json.loads(result.calls[0].parameters)
+ self.assertEqual(params["query"], "test search")
+ self.assertEqual(params["category"], "tech")
+
+ def test_glm47_detector_type_handling_with_complex_schema(self):
+ """Test how GLM47 detector handles type inference for complex schemas in practice."""
+ complex_tools = [
+ Tool(
+ type="function",
+ function=Function(
+ name="complex_search",
+ parameters={
+ "type": "object",
+ "properties": {
+ "query": {
+ "anyOf": [
+ {"type": "string"},
+ {
+ "type": "object",
+ "properties": {"text": {"type": "string"}},
+ },
+ ]
+ },
+ "category": {"enum": ["tech", "news", "sports"]},
+ },
+ },
+ ),
+ )
+ ]
+
+ # Test with string value for anyOf parameter
+ text = (
+ "complex_search"
+ "querytest search"
+ "categorytech"
+ ""
+ )
+ result = self.glm47_detector.detect_and_parse(text, complex_tools)
+
+ self.assertEqual(len(result.calls), 1)
+ self.assertEqual(result.calls[0].name, "complex_search")
+
+ params = json.loads(result.calls[0].parameters)
+ self.assertEqual(params["query"], "test search")
+ self.assertEqual(params["category"], "tech")
+
+ def test_streaming_with_complex_schema_type_inference(self):
+ """Test streaming behavior with complex schema type inference."""
+ complex_tools = [
+ Tool(
+ type="function",
+ function=Function(
+ name="stream_test",
+ parameters={
+ "type": "object",
+ "properties": {
+ "data": {
+ "anyOf": [
+ {"type": "string"},
+ {
+ "type": "object",
+ "properties": {"value": {"type": "string"}},
+ },
+ ]
+ },
+ "status": {"enum": ["active", "inactive"]},
+ },
+ },
+ ),
+ )
+ ]
+
+ # Test GLM4 detector streaming
+ chunks = [
+ "stream_test\n",
+ "data\nnested data\n",
+ "status\nactive\n",
+ "",
+ ]
+ tool_calls = []
+ for chunk in chunks:
+ result = self.glm4_detector.parse_streaming_increment(chunk, complex_tools)
+ for tool_call_chunk in result.calls:
+ if (
+ hasattr(tool_call_chunk, "tool_index")
+ and tool_call_chunk.tool_index is not None
+ ):
+ while len(tool_calls) <= tool_call_chunk.tool_index:
+ tool_calls.append({"name": "", "parameters": ""})
+ tc = tool_calls[tool_call_chunk.tool_index]
+ if tool_call_chunk.name:
+ tc["name"] = tool_call_chunk.name
+ if tool_call_chunk.parameters:
+ tc["parameters"] += tool_call_chunk.parameters
+
+ self.assertEqual(len(tool_calls), 1)
+ self.assertEqual(tool_calls[0]["name"], "stream_test")
+
+ params = json.loads(tool_calls[0]["parameters"])
+ self.assertEqual(params["data"], "nested data")
+ self.assertEqual(params["status"], "active")
+
+ def test_streaming_with_complex_schema_type_inference_glm47(self):
+ """Test GLM47 streaming behavior with complex schema type inference."""
+ complex_tools = [
+ Tool(
+ type="function",
+ function=Function(
+ name="stream_test",
+ parameters={
+ "type": "object",
+ "properties": {
+ "data": {
+ "anyOf": [
+ {"type": "string"},
+ {
+ "type": "object",
+ "properties": {"value": {"type": "string"}},
+ },
+ ]
+ },
+ "status": {"enum": ["active", "inactive"]},
+ },
+ },
+ ),
+ )
+ ]
+
+ # Test GLM47 detector streaming
+ chunks = [
+ "stream_test",
+ "datanested data",
+ "statusactive",
+ "",
+ ]
+ tool_calls = []
+ for chunk in chunks:
+ result = self.glm47_detector.parse_streaming_increment(chunk, complex_tools)
+ for tool_call_chunk in result.calls:
+ if (
+ hasattr(tool_call_chunk, "tool_index")
+ and tool_call_chunk.tool_index is not None
+ ):
+ while len(tool_calls) <= tool_call_chunk.tool_index:
+ tool_calls.append({"name": "", "parameters": ""})
+ tc = tool_calls[tool_call_chunk.tool_index]
+ if tool_call_chunk.name:
+ tc["name"] = tool_call_chunk.name
+ if tool_call_chunk.parameters:
+ tc["parameters"] += tool_call_chunk.parameters
+
+ self.assertEqual(len(tool_calls), 1)
+ self.assertEqual(tool_calls[0]["name"], "stream_test")
+
+ params = json.loads(tool_calls[0]["parameters"])
+ self.assertEqual(params["data"], "nested data")
+ self.assertEqual(params["status"], "active")
+
+ if __name__ == "__main__":
+ unittest.main()