-
Notifications
You must be signed in to change notification settings - Fork 1.5k
fix(local-models): support OpenAI-style tool calls #1512
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -92,6 +92,28 @@ def _generate_call_id() -> str: | |||||||||||||||
| return f"call_{uuid.uuid4().hex[:12]}" | ||||||||||||||||
|
|
||||||||||||||||
|
|
||||||||||||||||
| def _normalize_tool_arguments(arguments: object) -> tuple[dict, str]: | ||||||||||||||||
| if isinstance(arguments, str): | ||||||||||||||||
| try: | ||||||||||||||||
| parsed = json.loads(arguments) | ||||||||||||||||
| except (json.JSONDecodeError, TypeError): | ||||||||||||||||
| return {}, arguments | ||||||||||||||||
| return (parsed if isinstance(parsed, dict) else {}), arguments | ||||||||||||||||
|
|
||||||||||||||||
| if isinstance(arguments, dict): | ||||||||||||||||
| return arguments, json.dumps(arguments, ensure_ascii=False) | ||||||||||||||||
|
Comment on lines
+103
to
+104
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The
Suggested change
|
||||||||||||||||
|
|
||||||||||||||||
| if arguments is None: | ||||||||||||||||
| return {}, "" | ||||||||||||||||
|
|
||||||||||||||||
| try: | ||||||||||||||||
| raw_arguments = json.dumps(arguments, ensure_ascii=False) | ||||||||||||||||
| except TypeError: | ||||||||||||||||
| raw_arguments = "" | ||||||||||||||||
|
|
||||||||||||||||
| return {}, raw_arguments | ||||||||||||||||
|
|
||||||||||||||||
|
|
||||||||||||||||
| def _parse_single_tool_call(raw_text: str) -> ParsedToolCall | None: | ||||||||||||||||
| """ | ||||||||||||||||
| Parse the JSON content between a ``<tool_call>`` / ``</tool_call>`` pair. | ||||||||||||||||
|
|
@@ -106,23 +128,29 @@ def _parse_single_tool_call(raw_text: str) -> ParsedToolCall | None: | |||||||||||||||
| logger.warning("Failed to parse tool call JSON: %s", raw_text[:200]) | ||||||||||||||||
| return None | ||||||||||||||||
|
|
||||||||||||||||
| name = data.get("name", "") | ||||||||||||||||
| if not isinstance(data, dict): | ||||||||||||||||
| logger.warning("Tool call JSON must decode to an object: %s", raw_text[:200]) | ||||||||||||||||
| return None | ||||||||||||||||
|
|
||||||||||||||||
| function_data = data.get("function") | ||||||||||||||||
| if isinstance(function_data, dict): | ||||||||||||||||
| name = function_data.get("name", "") | ||||||||||||||||
| arguments_value = function_data.get("arguments", {}) | ||||||||||||||||
| else: | ||||||||||||||||
| name = data.get("name", "") | ||||||||||||||||
| arguments_value = data.get("arguments", {}) | ||||||||||||||||
|
|
||||||||||||||||
| if not name: | ||||||||||||||||
| logger.warning("Tool call missing 'name' field: %s", raw_text[:200]) | ||||||||||||||||
| return None | ||||||||||||||||
|
|
||||||||||||||||
| arguments = data.get("arguments", {}) | ||||||||||||||||
| if isinstance(arguments, str): | ||||||||||||||||
| try: | ||||||||||||||||
| arguments = json.loads(arguments) | ||||||||||||||||
| except (json.JSONDecodeError, TypeError): | ||||||||||||||||
| arguments = {} | ||||||||||||||||
| arguments, raw_arguments = _normalize_tool_arguments(arguments_value) | ||||||||||||||||
|
|
||||||||||||||||
| return ParsedToolCall( | ||||||||||||||||
| id=_generate_call_id(), | ||||||||||||||||
| id=data.get("id") or _generate_call_id(), | ||||||||||||||||
| name=name, | ||||||||||||||||
| arguments=arguments, | ||||||||||||||||
| raw_arguments=json.dumps(arguments, ensure_ascii=False), | ||||||||||||||||
| raw_arguments=raw_arguments, | ||||||||||||||||
| ) | ||||||||||||||||
|
|
||||||||||||||||
|
|
||||||||||||||||
|
|
@@ -141,9 +169,9 @@ def extract_thinking_from_text(text: str) -> TextWithThinking: | |||||||||||||||
|
|
||||||||||||||||
| Returns a :class:`TextWithThinking` with: | ||||||||||||||||
|
|
||||||||||||||||
| * ``thinking`` – the reasoning content (empty if none found) | ||||||||||||||||
| * ``remaining_text`` – everything outside the think tags | ||||||||||||||||
| * ``has_open_tag`` – ``True`` if ``<think>`` opened but not closed yet | ||||||||||||||||
| * ``thinking`` the reasoning content (empty if none found) | ||||||||||||||||
| * ``remaining_text`` everything outside the think tags | ||||||||||||||||
| * ``has_open_tag`` ``True`` if ``<think>`` opened but not closed yet | ||||||||||||||||
| """ | ||||||||||||||||
| match = _THINK_RE.search(text) | ||||||||||||||||
| if match: | ||||||||||||||||
|
|
@@ -154,7 +182,7 @@ def extract_thinking_from_text(text: str) -> TextWithThinking: | |||||||||||||||
| remaining_text=remaining, | ||||||||||||||||
| ) | ||||||||||||||||
|
|
||||||||||||||||
| # No complete block — check for an unclosed <think>. | ||||||||||||||||
| # No complete block; check for an unclosed <think>. | ||||||||||||||||
| open_idx = text.find(THINK_START) | ||||||||||||||||
| if open_idx != -1: | ||||||||||||||||
| remaining = text[:open_idx].strip() | ||||||||||||||||
|
|
@@ -178,17 +206,17 @@ def parse_tool_calls_from_text(text: str) -> TextWithToolCalls: | |||||||||||||||
|
|
||||||||||||||||
| Returns a :class:`TextWithToolCalls` with: | ||||||||||||||||
|
|
||||||||||||||||
| * ``text_before`` – all text before the first ``<tool_call>`` tag | ||||||||||||||||
| * ``text_after`` – all text after the last ``</tool_call>`` tag | ||||||||||||||||
| * ``tool_calls`` – successfully parsed tool calls | ||||||||||||||||
| * ``has_open_tag`` – whether there is an unclosed ``<tool_call>`` | ||||||||||||||||
| * ``text_before`` all text before the first ``<tool_call>`` tag | ||||||||||||||||
| * ``text_after`` all text after the last ``</tool_call>`` tag | ||||||||||||||||
| * ``tool_calls`` successfully parsed tool calls | ||||||||||||||||
| * ``has_open_tag`` whether there is an unclosed ``<tool_call>`` | ||||||||||||||||
| (streaming) | ||||||||||||||||
| * ``partial_tool_text`` – content after the unclosed tag | ||||||||||||||||
| * ``partial_tool_text`` content after the unclosed tag | ||||||||||||||||
| """ | ||||||||||||||||
| matches = list(_TOOL_CALL_RE.finditer(text)) | ||||||||||||||||
|
|
||||||||||||||||
| if not matches: | ||||||||||||||||
| # No complete blocks. Check for an unclosed opening tag. | ||||||||||||||||
| # No complete blocks. Check for an unclosed opening tag. | ||||||||||||||||
| open_idx = text.rfind(TOOL_CALL_START) | ||||||||||||||||
| if open_idx != -1: | ||||||||||||||||
| return TextWithToolCalls( | ||||||||||||||||
|
|
||||||||||||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,142 @@ | ||
| from __future__ import annotations | ||
|
|
||
| import asyncio | ||
| import json | ||
| from datetime import datetime | ||
| from typing import Any | ||
|
|
||
| from copaw.local_models.backends.base import LocalBackend | ||
| from copaw.local_models.chat_model import LocalChatModel | ||
| from copaw.local_models.tag_parser import parse_tool_calls_from_text | ||
|
|
||
|
|
||
class DummyLocalBackend(LocalBackend):
    """Test double for ``LocalBackend`` that replays canned stream chunks.

    The backend reports itself as loaded from construction until
    :meth:`unload` is called.
    """

    def __init__(
        self,
        model_path: str = "",
        *,
        stream_chunks: list[dict[str, Any]] | None = None,
        **_: Any,
    ) -> None:
        # ``model_path`` and any extra keyword arguments are accepted but
        # ignored; the base-class initializer is not invoked here.
        self._loaded = True
        self._stream_chunks = stream_chunks if stream_chunks else []

    def chat_completion(
        self,
        messages: list[dict],
        tools: list[dict] | None = None,
        tool_choice: str | None = None,
        structured_model: Any = None,
        **kwargs: Any,
    ) -> dict:
        """Return a fixed completion with no choices and no usage data."""
        empty_completion = {"choices": [], "usage": None}
        return empty_completion

    def chat_completion_stream(
        self,
        messages: list[dict],
        tools: list[dict] | None = None,
        tool_choice: str | None = None,
        **kwargs: Any,
    ):
        """Yield the chunks given at construction time, in order."""
        for chunk in self._stream_chunks:
            yield chunk

    def unload(self) -> None:
        """Mark the backend as no longer loaded."""
        self._loaded = False

    @property
    def is_loaded(self) -> bool:
        """Whether the backend is currently flagged as loaded."""
        return self._loaded
|
|
||
|
|
||
| def _make_stream_chunk(tool_calls: list[dict[str, Any]]) -> dict[str, Any]: | ||
| return { | ||
| "choices": [ | ||
| { | ||
| "delta": { | ||
| "content": None, | ||
| "reasoning_content": None, | ||
| "tool_calls": tool_calls, | ||
| }, | ||
| }, | ||
| ], | ||
| } | ||
|
|
||
|
|
||
def test_parse_tool_calls_from_text_supports_openai_function_format() -> None:
    """An OpenAI-style ``function`` payload in ``<tool_call>`` tags parses.

    The payload nests ``name`` and JSON-encoded ``arguments`` under a
    ``function`` key and carries its own ``id``; the parsed call is expected
    to keep that id, the name, the decoded arguments and the raw string.
    """
    payload = json.dumps(
        {
            "id": "call_abc123",
            "type": "function",
            "function": {
                "name": "execute_shell_command",
                "arguments": json.dumps({"command": "ls -la"}),
            },
        },
    )
    text = f"prefix\n<tool_call>\n{payload}\n</tool_call>\nsuffix"

    result = parse_tool_calls_from_text(text)

    assert result.text_before == "prefix"
    assert result.text_after == "suffix"
    assert len(result.tool_calls) == 1

    call = result.tool_calls[0]
    assert call.id == "call_abc123"
    assert call.name == "execute_shell_command"
    assert call.arguments == {"command": "ls -la"}
    assert call.raw_arguments == '{"command": "ls -la"}'
|
|
||
|
|
||
def test_stream_response_waits_for_non_empty_tool_name() -> None:
    """Streamed tool-call deltas only surface once a name is known.

    The first chunk carries an id and partial arguments but no function
    name; the second supplies the name and the rest of the arguments.
    Every ``tool_use`` block collected from the stream must carry the
    function name, the original id and the fully assembled arguments.
    """
    nameless_delta = {
        "index": 0,
        "id": "call_stream",
        "function": {"arguments": '{"command": '},
    }
    named_delta = {
        "index": 0,
        "function": {
            "name": "execute_shell_command",
            "arguments": '"ls -la"}',
        },
    }
    backend = DummyLocalBackend(
        stream_chunks=[
            _make_stream_chunk([nameless_delta]),
            _make_stream_chunk([named_delta]),
        ],
    )
    model = LocalChatModel("dummy", backend, stream=True)

    async def _collect_responses() -> list[Any]:
        # Drain the async generator into a list so it can be inspected
        # synchronously below.
        return [
            response
            async for response in model._stream_response(
                messages=[],
                tools=None,
                tool_choice=None,
                start_datetime=datetime.now(),
            )
        ]

    responses = asyncio.run(_collect_responses())

    tool_blocks = []
    for response in responses:
        for block in response.content:
            if block.get("type") == "tool_use":
                tool_blocks.append(block)

    assert tool_blocks
    names = [block["name"] for block in tool_blocks]
    assert names == ["execute_shell_command"]

    first_block = tool_blocks[0]
    assert first_block["id"] == "call_stream"
    assert first_block["input"] == {"command": "ls -la"}
    assert first_block["raw_input"] == '{"command": "ls -la"}'
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The current logic for handling tool call IDs can result in an empty string ID if the model provides `id: ""`. This is inconsistent with the logic in `tag_parser.py`, which generates a new ID in such cases. An empty ID might cause issues downstream. The logic can also be simplified to be more readable and robust.