Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 31 additions & 17 deletions src/copaw/agents/model_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,27 @@ def _get_formatter_for_chat_model(
)


def _assistant_message_survives_format(msg: Msg) -> bool:
"""Best-effort check for assistant messages kept by AgentScope formatters."""
if msg.role != "assistant":
return False

if isinstance(msg.content, str):
return bool(msg.content)

for block in msg.get_content_blocks():
block_type = block.get("type")
if block_type == "thinking":
continue
if block_type == "text":
if block.get("text"):
return True
continue
return True
Comment on lines +114 to +120
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The logic within this for loop can be simplified for improved readability and maintainability. By restructuring the conditions, you can make the code more concise while preserving the existing functionality.

Suggested change
if block_type == "thinking":
continue
if block_type == "text":
if block.get("text"):
return True
continue
return True
if block_type == "thinking":
continue
if block_type == "text" and not block.get("text"):
continue
# Any other block type, or a text block with content, survives.
return True


return False


def _create_file_block_support_formatter(
base_formatter_class: Type[FormatterBase],
) -> Type[FormatterBase]:
Expand Down Expand Up @@ -160,26 +181,19 @@ async def _format(self, msgs):
tc["extra_content"] = ec

if reasoning_contents:
in_assistant = [m for m in msgs if m.role == "assistant"]
in_assistant = [
m for m in msgs if _assistant_message_survives_format(m)
]
out_assistant = [
m for m in messages if m.get("role") == "assistant"
]
if len(in_assistant) != len(out_assistant):
logger.warning(
"Assistant message count mismatch after formatting "
"(%d before, %d after). "
"Skipping reasoning_content injection.",
len(in_assistant),
len(out_assistant),
)
else:
for in_msg, out_msg in zip(
in_assistant,
out_assistant,
):
reasoning = reasoning_contents.get(id(in_msg))
if reasoning:
out_msg["reasoning_content"] = reasoning
for in_msg, out_msg in zip(
in_assistant,
out_assistant,
):
reasoning = reasoning_contents.get(id(in_msg))
if reasoning:
out_msg["reasoning_content"] = reasoning
Comment on lines +184 to +196
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Removing the length check between the predicted surviving messages and the actual formatted messages could hide future bugs. Since _assistant_message_survives_format is a 'best-effort' check, if the underlying formatter's logic changes, this could lead to silent misalignment or dropping of reasoning_content. It would be safer to reintroduce a warning log if a mismatch occurs. This would provide a valuable signal for debugging if the prediction logic becomes inaccurate over time.

                in_assistant = [
                    m for m in msgs if _assistant_message_survives_format(m)
                ]
                out_assistant = [
                    m for m in messages if m.get("role") == "assistant"
                ]
                if len(in_assistant) != len(out_assistant):
                    logger.warning(
                        "Assistant message count mismatch after formatting "
                        "(%d predicted, %d actual). "
                        "Reasoning content may be misaligned.",
                        len(in_assistant),
                        len(out_assistant),
                    )
                for in_msg, out_msg in zip(
                    in_assistant,
                    out_assistant,
                ):
                    reasoning = reasoning_contents.get(id(in_msg))
                    if reasoning:
                        out_msg["reasoning_content"] = reasoning


return _strip_top_level_message_name(messages)

Expand Down
207 changes: 207 additions & 0 deletions tests/unit/agents/test_model_factory_reasoning_content.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,207 @@
# -*- coding: utf-8 -*-
from __future__ import annotations

import importlib.util
import sys
import types
from pathlib import Path

import pytest


MODEL_FACTORY_PATH = Path(__file__).resolve().parents[3] / "src" / "copaw" / "agents" / "model_factory.py"


class FakeMsg:
def __init__(self, role: str, blocks: list[dict] | str):
self.role = role
self.content = blocks

def get_content_blocks(self) -> list[dict]:
if isinstance(self.content, list):
return list(self.content)
if not self.content:
return []
return [{"type": "text", "text": self.content}]


@pytest.fixture
def anyio_backend() -> str:
    """Run ``@pytest.mark.anyio`` tests on the asyncio backend only."""
    return "asyncio"


@pytest.fixture
def model_factory_module(monkeypatch: pytest.MonkeyPatch):
    """Import ``copaw.agents.model_factory`` in isolation with stubbed deps.

    Builds throwaway package shells and fake ``agentscope`` modules in
    ``sys.modules`` (via monkeypatch, so they are removed after the test),
    then executes the real source file under its package name and returns
    the resulting module object.
    """
    # Drop any previously imported copy so exec_module gets a fresh module.
    for name in list(sys.modules):
        if name.startswith("copaw.agents.model_factory"):
            sys.modules.pop(name, None)

    # Fabricate empty package shells so the dotted import chain resolves.
    copaw_pkg = types.ModuleType("copaw")
    copaw_pkg.__path__ = []
    agents_pkg = types.ModuleType("copaw.agents")
    agents_pkg.__path__ = []
    utils_pkg = types.ModuleType("copaw.agents.utils")
    utils_pkg.__path__ = []
    providers_pkg = types.ModuleType("copaw.providers")
    providers_pkg.__path__ = []

    monkeypatch.setitem(sys.modules, "copaw", copaw_pkg)
    monkeypatch.setitem(sys.modules, "copaw.agents", agents_pkg)
    monkeypatch.setitem(sys.modules, "copaw.agents.utils", utils_pkg)
    monkeypatch.setitem(sys.modules, "copaw.providers", providers_pkg)

    # Stub the tool-message sanitizer as a pass-through.
    tool_utils = types.ModuleType("copaw.agents.utils.tool_message_utils")
    tool_utils._sanitize_tool_messages = lambda msgs: msgs
    monkeypatch.setitem(
        sys.modules,
        "copaw.agents.utils.tool_message_utils",
        tool_utils,
    )

    # Provider/token-usage symbols only need to exist, not to work.
    providers_pkg.ProviderManager = object
    retry_model = types.ModuleType("copaw.providers.retry_chat_model")
    retry_model.RetryChatModel = object
    monkeypatch.setitem(
        sys.modules,
        "copaw.providers.retry_chat_model",
        retry_model,
    )

    token_usage = types.ModuleType("copaw.token_usage")
    token_usage.TokenRecordingModelWrapper = object
    monkeypatch.setitem(sys.modules, "copaw.token_usage", token_usage)

    # Minimal agentscope surface referenced by model_factory at import time.
    agentscope = types.ModuleType("agentscope")
    agentscope.__version__ = "1.0.16"
    formatter_mod = types.ModuleType("agentscope.formatter")
    model_mod = types.ModuleType("agentscope.model")
    message_mod = types.ModuleType("agentscope.message")

    class FormatterBase:
        pass

    class OpenAIChatFormatter(FormatterBase):
        # Delegates to the subclass hook, like the real formatter API.
        async def format(self, msgs, **kwargs):
            return await self._format(msgs, **kwargs)

    class ChatModelBase:
        pass

    class OpenAIChatModel:
        pass

    class Msg:
        pass

    formatter_mod.FormatterBase = FormatterBase
    formatter_mod.OpenAIChatFormatter = OpenAIChatFormatter
    model_mod.ChatModelBase = ChatModelBase
    model_mod.OpenAIChatModel = OpenAIChatModel
    message_mod.Msg = Msg

    monkeypatch.setitem(sys.modules, "agentscope", agentscope)
    monkeypatch.setitem(sys.modules, "agentscope.formatter", formatter_mod)
    monkeypatch.setitem(sys.modules, "agentscope.model", model_mod)
    monkeypatch.setitem(sys.modules, "agentscope.message", message_mod)

    # Execute the real source file under its package name so relative
    # imports inside model_factory.py resolve against the stubs above.
    spec = importlib.util.spec_from_file_location(
        "copaw.agents.model_factory",
        MODEL_FACTORY_PATH,
    )
    assert spec and spec.loader
    module = importlib.util.module_from_spec(spec)
    monkeypatch.setitem(sys.modules, "copaw.agents.model_factory", module)
    spec.loader.exec_module(module)
    return module


@pytest.mark.anyio
async def test_reasoning_content_skips_dropped_thinking_only_messages(
    model_factory_module,
    caplog: pytest.LogCaptureFixture,
):
    """A thinking-only assistant message dropped by the base formatter must
    not shift reasoning_content onto the wrong surviving message, and no
    mismatch warning should be emitted."""

    class BaseFormatter:
        # Mimics a formatter that keeps only assistant messages whose
        # text blocks are non-empty (thinking-only messages vanish).
        async def _format(self, msgs):
            messages = []
            for msg in msgs:
                blocks = msg.get_content_blocks()
                text = " ".join(
                    block.get("text", "")
                    for block in blocks
                    if block.get("type") == "text" and block.get("text")
                )
                if msg.role == "assistant" and text:
                    messages.append({"role": "assistant", "content": text})
            return messages

    formatter = model_factory_module._create_file_block_support_formatter(
        BaseFormatter
    )()

    # First message is thinking-only and should be dropped entirely.
    messages = [
        FakeMsg("assistant", [{"type": "thinking", "thinking": "discard me"}]),
        FakeMsg(
            "assistant",
            [
                {"type": "thinking", "thinking": "keep me"},
                {"type": "text", "text": "visible answer"},
            ],
        ),
    ]

    with caplog.at_level("WARNING"):
        formatted = await formatter._format(messages)

    # Only the surviving message remains, carrying its own reasoning.
    assert formatted == [
        {
            "role": "assistant",
            "content": "visible answer",
            "reasoning_content": "keep me",
        }
    ]
    assert "Assistant message count mismatch after formatting" not in caplog.text


@pytest.mark.anyio
async def test_reasoning_content_stays_aligned_with_surviving_assistant_messages(
    model_factory_module,
):
    """When every assistant message survives formatting, reasoning_content
    is attached to the message it originated from and others are left
    untouched."""

    class BaseFormatter:
        # Same keep-only-textful-assistant-messages formatter as above.
        async def _format(self, msgs):
            messages = []
            for msg in msgs:
                blocks = msg.get_content_blocks()
                text = " ".join(
                    block.get("text", "")
                    for block in blocks
                    if block.get("type") == "text" and block.get("text")
                )
                if msg.role == "assistant" and text:
                    messages.append({"role": "assistant", "content": text})
            return messages

    formatter = model_factory_module._create_file_block_support_formatter(
        BaseFormatter
    )()

    formatted = await formatter._format(
        [
            FakeMsg("assistant", [{"type": "text", "text": "plain"}]),
            FakeMsg(
                "assistant",
                [
                    {"type": "thinking", "thinking": "reasoned"},
                    {"type": "text", "text": "final"},
                ],
            ),
        ]
    )

    # The message with no thinking block gets no reasoning_content key.
    assert formatted == [
        {"role": "assistant", "content": "plain"},
        {
            "role": "assistant",
            "content": "final",
            "reasoning_content": "reasoned",
        },
    ]
Loading