Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 26 additions & 6 deletions src/copaw/local_models/chat_model.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# -*- coding: utf-8 -*-
# pylint:disable=too-many-branches,too-many-statements
"""LocalChatModel ChatModelBase implementation for local backends."""
"""LocalChatModel - ChatModelBase implementation for local backends."""

from __future__ import annotations

Expand Down Expand Up @@ -36,11 +36,22 @@ def _json_loads_safe(s: str) -> dict:
return {}


def _stringify_tool_arguments(arguments: Any) -> str:
if arguments is None:
return ""
if isinstance(arguments, str):
return arguments
try:
return json.dumps(arguments, ensure_ascii=False)
except TypeError:
return ""


class LocalChatModel(ChatModelBase):
"""ChatModelBase implementation for local model backends.

Wraps any ``LocalBackend`` (llama.cpp, future MLX) and presents it
through the agentscope ``ChatModelBase`` interface. Since backends are
through the agentscope ``ChatModelBase`` interface. Since backends are
synchronous, inference runs in a thread executor for async compatibility.
"""

Expand Down Expand Up @@ -158,12 +169,19 @@ def _produce() -> None:
if idx not in tool_calls:
tool_calls[idx] = {
"id": tc.get("id", f"call_{idx}"),
"name": (tc.get("function") or {}).get("name", ""),
"name": "",
"arguments": "",
}
tool_calls[idx]["arguments"] += (tc.get("function") or {}).get(
"arguments",
) or ""
if tc.get("id"):
tool_calls[idx]["id"] = tc["id"]
Comment on lines 169 to +176
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The current logic for handling tool call IDs can result in an empty string ID if the model provides id: "". This is inconsistent with the logic in tag_parser.py, which generates a new ID in such cases. An empty ID might cause issues downstream. The logic can also be simplified to be more readable and robust.

Suggested change
if idx not in tool_calls:
tool_calls[idx] = {
"id": tc.get("id", f"call_{idx}"),
"name": (tc.get("function") or {}).get("name", ""),
"name": "",
"arguments": "",
}
tool_calls[idx]["arguments"] += (tc.get("function") or {}).get(
"arguments",
) or ""
if tc.get("id"):
tool_calls[idx]["id"] = tc["id"]
if idx not in tool_calls:
tool_calls[idx] = {
"id": f"call_{idx}",
"name": "",
"arguments": "",
}
if tc.get("id"):
tool_calls[idx]["id"] = tc["id"]


function_data = tc.get("function") or {}
if function_data.get("name"):
tool_calls[idx]["name"] = function_data["name"]

tool_calls[idx]["arguments"] += _stringify_tool_arguments(
function_data.get("arguments"),
)

# Build content blocks
contents: list = []
Expand Down Expand Up @@ -231,6 +249,8 @@ def _produce() -> None:
)

for tc_data in tool_calls.values():
if not tc_data["name"]:
continue
contents.append(
ToolUseBlock(
type="tool_use",
Expand Down
66 changes: 47 additions & 19 deletions src/copaw/local_models/tag_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,28 @@ def _generate_call_id() -> str:
return f"call_{uuid.uuid4().hex[:12]}"


def _normalize_tool_arguments(arguments: object) -> tuple[dict, str]:
if isinstance(arguments, str):
try:
parsed = json.loads(arguments)
except (json.JSONDecodeError, TypeError):
return {}, arguments
return (parsed if isinstance(parsed, dict) else {}), arguments

if isinstance(arguments, dict):
return arguments, json.dumps(arguments, ensure_ascii=False)
Comment on lines +103 to +104
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The json.dumps call for dictionary arguments is not wrapped in a try-except block. If the dictionary contains non-serializable types (e.g., datetime objects), this will raise an unhandled TypeError, causing the application to crash. This is inconsistent with the error handling for other types within the same function.

Suggested change
if isinstance(arguments, dict):
return arguments, json.dumps(arguments, ensure_ascii=False)
if isinstance(arguments, dict):
try:
return arguments, json.dumps(arguments, ensure_ascii=False)
except TypeError:
return arguments, ""


if arguments is None:
return {}, ""

try:
raw_arguments = json.dumps(arguments, ensure_ascii=False)
except TypeError:
raw_arguments = ""

return {}, raw_arguments


def _parse_single_tool_call(raw_text: str) -> ParsedToolCall | None:
"""
Parse the JSON content between a ``<tool_call>`` / ``</tool_call>`` pair.
Expand All @@ -106,23 +128,29 @@ def _parse_single_tool_call(raw_text: str) -> ParsedToolCall | None:
logger.warning("Failed to parse tool call JSON: %s", raw_text[:200])
return None

name = data.get("name", "")
if not isinstance(data, dict):
logger.warning("Tool call JSON must decode to an object: %s", raw_text[:200])
return None

function_data = data.get("function")
if isinstance(function_data, dict):
name = function_data.get("name", "")
arguments_value = function_data.get("arguments", {})
else:
name = data.get("name", "")
arguments_value = data.get("arguments", {})

if not name:
logger.warning("Tool call missing 'name' field: %s", raw_text[:200])
return None

arguments = data.get("arguments", {})
if isinstance(arguments, str):
try:
arguments = json.loads(arguments)
except (json.JSONDecodeError, TypeError):
arguments = {}
arguments, raw_arguments = _normalize_tool_arguments(arguments_value)

return ParsedToolCall(
id=_generate_call_id(),
id=data.get("id") or _generate_call_id(),
name=name,
arguments=arguments,
raw_arguments=json.dumps(arguments, ensure_ascii=False),
raw_arguments=raw_arguments,
)


Expand All @@ -141,9 +169,9 @@ def extract_thinking_from_text(text: str) -> TextWithThinking:

Returns a :class:`TextWithThinking` with:

* ``thinking`` the reasoning content (empty if none found)
* ``remaining_text`` everything outside the think tags
* ``has_open_tag`` ``True`` if ``<think>`` opened but not closed yet
* ``thinking`` the reasoning content (empty if none found)
* ``remaining_text`` everything outside the think tags
* ``has_open_tag`` ``True`` if ``<think>`` opened but not closed yet
"""
match = _THINK_RE.search(text)
if match:
Expand All @@ -154,7 +182,7 @@ def extract_thinking_from_text(text: str) -> TextWithThinking:
remaining_text=remaining,
)

# No complete block check for an unclosed <think>.
# No complete block; check for an unclosed <think>.
open_idx = text.find(THINK_START)
if open_idx != -1:
remaining = text[:open_idx].strip()
Expand All @@ -178,17 +206,17 @@ def parse_tool_calls_from_text(text: str) -> TextWithToolCalls:

Returns a :class:`TextWithToolCalls` with:

* ``text_before`` all text before the first ``<tool_call>`` tag
* ``text_after`` all text after the last ``</tool_call>`` tag
* ``tool_calls`` successfully parsed tool calls
* ``has_open_tag`` whether there is an unclosed ``<tool_call>``
* ``text_before`` all text before the first ``<tool_call>`` tag
* ``text_after`` all text after the last ``</tool_call>`` tag
* ``tool_calls`` successfully parsed tool calls
* ``has_open_tag`` whether there is an unclosed ``<tool_call>``
(streaming)
* ``partial_tool_text`` content after the unclosed tag
* ``partial_tool_text`` content after the unclosed tag
"""
matches = list(_TOOL_CALL_RE.finditer(text))

if not matches:
# No complete blocks. Check for an unclosed opening tag.
# No complete blocks. Check for an unclosed opening tag.
open_idx = text.rfind(TOOL_CALL_START)
if open_idx != -1:
return TextWithToolCalls(
Expand Down
142 changes: 142 additions & 0 deletions tests/unit/local_models/test_local_model_tool_calls.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
from __future__ import annotations

import asyncio
import json
from datetime import datetime
from typing import Any

from copaw.local_models.backends.base import LocalBackend
from copaw.local_models.chat_model import LocalChatModel
from copaw.local_models.tag_parser import parse_tool_calls_from_text


class DummyLocalBackend(LocalBackend):
    """Minimal in-memory ``LocalBackend`` double that replays canned chunks."""

    def __init__(
        self,
        model_path: str = "",
        *,
        stream_chunks: list[dict[str, Any]] | None = None,
        **_: Any,
    ) -> None:
        # Fall back to a fresh empty list so iteration is always safe.
        self._stream_chunks = stream_chunks or []
        self._loaded = True

    def chat_completion(
        self,
        messages: list[dict],
        tools: list[dict] | None = None,
        tool_choice: str | None = None,
        structured_model: Any = None,
        **kwargs: Any,
    ) -> dict:
        # Non-streaming path is unused by these tests; return an empty shell.
        return {"choices": [], "usage": None}

    def chat_completion_stream(
        self,
        messages: list[dict],
        tools: list[dict] | None = None,
        tool_choice: str | None = None,
        **kwargs: Any,
    ):
        # Replay the canned chunks exactly as a real backend would stream them.
        for chunk in self._stream_chunks:
            yield chunk

    def unload(self) -> None:
        self._loaded = False

    @property
    def is_loaded(self) -> bool:
        return self._loaded


def _make_stream_chunk(tool_calls: list[dict[str, Any]]) -> dict[str, Any]:
return {
"choices": [
{
"delta": {
"content": None,
"reasoning_content": None,
"tool_calls": tool_calls,
},
},
],
}


def test_parse_tool_calls_from_text_supports_openai_function_format() -> None:
    """The parser should accept OpenAI-style ``{"function": {...}}`` payloads."""
    arguments = {"command": "ls -la"}
    tool_call = {
        "id": "call_abc123",
        "type": "function",
        "function": {
            "name": "execute_shell_command",
            "arguments": json.dumps(arguments),
        },
    }
    text = "\n".join(
        [
            "prefix",
            f"<tool_call>\n{json.dumps(tool_call)}\n</tool_call>",
            "suffix",
        ],
    )

    parsed = parse_tool_calls_from_text(text)

    assert parsed.text_before == "prefix"
    assert parsed.text_after == "suffix"
    assert len(parsed.tool_calls) == 1
    call = parsed.tool_calls[0]
    assert call.id == "call_abc123"
    assert call.name == "execute_shell_command"
    assert call.arguments == arguments
    assert call.raw_arguments == '{"command": "ls -la"}'


def test_stream_response_waits_for_non_empty_tool_name() -> None:
    """A tool_use block is emitted only once the streamed name is non-empty."""
    first_delta = {
        "index": 0,
        "id": "call_stream",
        "function": {"arguments": '{"command": '},
    }
    second_delta = {
        "index": 0,
        "function": {
            "name": "execute_shell_command",
            "arguments": '"ls -la"}',
        },
    }
    backend = DummyLocalBackend(
        stream_chunks=[
            _make_stream_chunk([first_delta]),
            _make_stream_chunk([second_delta]),
        ],
    )
    model = LocalChatModel("dummy", backend, stream=True)

    async def _drain() -> list[Any]:
        collected = []
        async for chunk in model._stream_response(
            messages=[],
            tools=None,
            tool_choice=None,
            start_datetime=datetime.now(),
        ):
            collected.append(chunk)
        return collected

    responses = asyncio.run(_drain())

    tool_blocks = []
    for response in responses:
        for block in response.content:
            if block.get("type") == "tool_use":
                tool_blocks.append(block)

    assert tool_blocks
    assert [block["name"] for block in tool_blocks] == ["execute_shell_command"]
    first_block = tool_blocks[0]
    assert first_block["id"] == "call_stream"
    assert first_block["input"] == {"command": "ls -la"}
    assert first_block["raw_input"] == '{"command": "ls -la"}'
Loading