-
Notifications
You must be signed in to change notification settings - Fork 1.5k
fix(formatter): preserve reasoning content on surviving assistant turns #1538
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -101,6 +101,27 @@ def _get_formatter_for_chat_model( | |
| ) | ||
|
|
||
|
|
||
| def _assistant_message_survives_format(msg: Msg) -> bool: | ||
| """Best-effort check for assistant messages kept by AgentScope formatters.""" | ||
| if msg.role != "assistant": | ||
| return False | ||
|
|
||
| if isinstance(msg.content, str): | ||
| return bool(msg.content) | ||
|
|
||
| for block in msg.get_content_blocks(): | ||
| block_type = block.get("type") | ||
| if block_type == "thinking": | ||
| continue | ||
| if block_type == "text": | ||
| if block.get("text"): | ||
| return True | ||
| continue | ||
| return True | ||
|
|
||
| return False | ||
|
|
||
|
|
||
| def _create_file_block_support_formatter( | ||
| base_formatter_class: Type[FormatterBase], | ||
| ) -> Type[FormatterBase]: | ||
|
|
@@ -160,26 +181,19 @@ async def _format(self, msgs): | |
| tc["extra_content"] = ec | ||
|
|
||
| if reasoning_contents: | ||
| in_assistant = [m for m in msgs if m.role == "assistant"] | ||
| in_assistant = [ | ||
| m for m in msgs if _assistant_message_survives_format(m) | ||
| ] | ||
| out_assistant = [ | ||
| m for m in messages if m.get("role") == "assistant" | ||
| ] | ||
| if len(in_assistant) != len(out_assistant): | ||
| logger.warning( | ||
| "Assistant message count mismatch after formatting " | ||
| "(%d before, %d after). " | ||
| "Skipping reasoning_content injection.", | ||
| len(in_assistant), | ||
| len(out_assistant), | ||
| ) | ||
| else: | ||
| for in_msg, out_msg in zip( | ||
| in_assistant, | ||
| out_assistant, | ||
| ): | ||
| reasoning = reasoning_contents.get(id(in_msg)) | ||
| if reasoning: | ||
| out_msg["reasoning_content"] = reasoning | ||
| for in_msg, out_msg in zip( | ||
| in_assistant, | ||
| out_assistant, | ||
| ): | ||
| reasoning = reasoning_contents.get(id(in_msg)) | ||
| if reasoning: | ||
| out_msg["reasoning_content"] = reasoning | ||
|
Comment on lines
+184
to
+196
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Removing the length check between the predicted surviving messages and the actual formatted messages could hide future bugs. Since in_assistant = [
m for m in msgs if _assistant_message_survives_format(m)
]
out_assistant = [
m for m in messages if m.get("role") == "assistant"
]
if len(in_assistant) != len(out_assistant):
logger.warning(
"Assistant message count mismatch after formatting "
"(%d predicted, %d actual). "
"Reasoning content may be misaligned.",
len(in_assistant),
len(out_assistant),
)
for in_msg, out_msg in zip(
in_assistant,
out_assistant,
):
reasoning = reasoning_contents.get(id(in_msg))
if reasoning:
out_msg["reasoning_content"] = reasoning |
||
|
|
||
| return _strip_top_level_message_name(messages) | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,207 @@ | ||
| # -*- coding: utf-8 -*- | ||
| from __future__ import annotations | ||
|
|
||
| import importlib.util | ||
| import sys | ||
| import types | ||
| from pathlib import Path | ||
|
|
||
| import pytest | ||
|
|
||
|
|
||
| MODEL_FACTORY_PATH = Path(__file__).resolve().parents[3] / "src" / "copaw" / "agents" / "model_factory.py" | ||
|
|
||
|
|
||
| class FakeMsg: | ||
| def __init__(self, role: str, blocks: list[dict] | str): | ||
| self.role = role | ||
| self.content = blocks | ||
|
|
||
| def get_content_blocks(self) -> list[dict]: | ||
| if isinstance(self.content, list): | ||
| return list(self.content) | ||
| if not self.content: | ||
| return [] | ||
| return [{"type": "text", "text": self.content}] | ||
|
|
||
|
|
||
| @pytest.fixture | ||
| def anyio_backend() -> str: | ||
| return "asyncio" | ||
|
|
||
|
|
||
| @pytest.fixture | ||
| def model_factory_module(monkeypatch: pytest.MonkeyPatch): | ||
| for name in list(sys.modules): | ||
| if name.startswith("copaw.agents.model_factory"): | ||
| sys.modules.pop(name, None) | ||
|
|
||
| copaw_pkg = types.ModuleType("copaw") | ||
| copaw_pkg.__path__ = [] | ||
| agents_pkg = types.ModuleType("copaw.agents") | ||
| agents_pkg.__path__ = [] | ||
| utils_pkg = types.ModuleType("copaw.agents.utils") | ||
| utils_pkg.__path__ = [] | ||
| providers_pkg = types.ModuleType("copaw.providers") | ||
| providers_pkg.__path__ = [] | ||
|
|
||
| monkeypatch.setitem(sys.modules, "copaw", copaw_pkg) | ||
| monkeypatch.setitem(sys.modules, "copaw.agents", agents_pkg) | ||
| monkeypatch.setitem(sys.modules, "copaw.agents.utils", utils_pkg) | ||
| monkeypatch.setitem(sys.modules, "copaw.providers", providers_pkg) | ||
|
|
||
| tool_utils = types.ModuleType("copaw.agents.utils.tool_message_utils") | ||
| tool_utils._sanitize_tool_messages = lambda msgs: msgs | ||
| monkeypatch.setitem( | ||
| sys.modules, | ||
| "copaw.agents.utils.tool_message_utils", | ||
| tool_utils, | ||
| ) | ||
|
|
||
| providers_pkg.ProviderManager = object | ||
| retry_model = types.ModuleType("copaw.providers.retry_chat_model") | ||
| retry_model.RetryChatModel = object | ||
| monkeypatch.setitem( | ||
| sys.modules, | ||
| "copaw.providers.retry_chat_model", | ||
| retry_model, | ||
| ) | ||
|
|
||
| token_usage = types.ModuleType("copaw.token_usage") | ||
| token_usage.TokenRecordingModelWrapper = object | ||
| monkeypatch.setitem(sys.modules, "copaw.token_usage", token_usage) | ||
|
|
||
| agentscope = types.ModuleType("agentscope") | ||
| agentscope.__version__ = "1.0.16" | ||
| formatter_mod = types.ModuleType("agentscope.formatter") | ||
| model_mod = types.ModuleType("agentscope.model") | ||
| message_mod = types.ModuleType("agentscope.message") | ||
|
|
||
| class FormatterBase: | ||
| pass | ||
|
|
||
| class OpenAIChatFormatter(FormatterBase): | ||
| async def format(self, msgs, **kwargs): | ||
| return await self._format(msgs, **kwargs) | ||
|
|
||
| class ChatModelBase: | ||
| pass | ||
|
|
||
| class OpenAIChatModel: | ||
| pass | ||
|
|
||
| class Msg: | ||
| pass | ||
|
|
||
| formatter_mod.FormatterBase = FormatterBase | ||
| formatter_mod.OpenAIChatFormatter = OpenAIChatFormatter | ||
| model_mod.ChatModelBase = ChatModelBase | ||
| model_mod.OpenAIChatModel = OpenAIChatModel | ||
| message_mod.Msg = Msg | ||
|
|
||
| monkeypatch.setitem(sys.modules, "agentscope", agentscope) | ||
| monkeypatch.setitem(sys.modules, "agentscope.formatter", formatter_mod) | ||
| monkeypatch.setitem(sys.modules, "agentscope.model", model_mod) | ||
| monkeypatch.setitem(sys.modules, "agentscope.message", message_mod) | ||
|
|
||
| spec = importlib.util.spec_from_file_location( | ||
| "copaw.agents.model_factory", | ||
| MODEL_FACTORY_PATH, | ||
| ) | ||
| assert spec and spec.loader | ||
| module = importlib.util.module_from_spec(spec) | ||
| monkeypatch.setitem(sys.modules, "copaw.agents.model_factory", module) | ||
| spec.loader.exec_module(module) | ||
| return module | ||
|
|
||
|
|
||
| @pytest.mark.anyio | ||
| async def test_reasoning_content_skips_dropped_thinking_only_messages( | ||
| model_factory_module, | ||
| caplog: pytest.LogCaptureFixture, | ||
| ): | ||
| class BaseFormatter: | ||
| async def _format(self, msgs): | ||
| messages = [] | ||
| for msg in msgs: | ||
| blocks = msg.get_content_blocks() | ||
| text = " ".join( | ||
| block.get("text", "") | ||
| for block in blocks | ||
| if block.get("type") == "text" and block.get("text") | ||
| ) | ||
| if msg.role == "assistant" and text: | ||
| messages.append({"role": "assistant", "content": text}) | ||
| return messages | ||
|
|
||
| formatter = model_factory_module._create_file_block_support_formatter( | ||
| BaseFormatter | ||
| )() | ||
|
|
||
| messages = [ | ||
| FakeMsg("assistant", [{"type": "thinking", "thinking": "discard me"}]), | ||
| FakeMsg( | ||
| "assistant", | ||
| [ | ||
| {"type": "thinking", "thinking": "keep me"}, | ||
| {"type": "text", "text": "visible answer"}, | ||
| ], | ||
| ), | ||
| ] | ||
|
|
||
| with caplog.at_level("WARNING"): | ||
| formatted = await formatter._format(messages) | ||
|
|
||
| assert formatted == [ | ||
| { | ||
| "role": "assistant", | ||
| "content": "visible answer", | ||
| "reasoning_content": "keep me", | ||
| } | ||
| ] | ||
| assert "Assistant message count mismatch after formatting" not in caplog.text | ||
|
|
||
|
|
||
| @pytest.mark.anyio | ||
| async def test_reasoning_content_stays_aligned_with_surviving_assistant_messages( | ||
| model_factory_module, | ||
| ): | ||
| class BaseFormatter: | ||
| async def _format(self, msgs): | ||
| messages = [] | ||
| for msg in msgs: | ||
| blocks = msg.get_content_blocks() | ||
| text = " ".join( | ||
| block.get("text", "") | ||
| for block in blocks | ||
| if block.get("type") == "text" and block.get("text") | ||
| ) | ||
| if msg.role == "assistant" and text: | ||
| messages.append({"role": "assistant", "content": text}) | ||
| return messages | ||
|
|
||
| formatter = model_factory_module._create_file_block_support_formatter( | ||
| BaseFormatter | ||
| )() | ||
|
|
||
| formatted = await formatter._format( | ||
| [ | ||
| FakeMsg("assistant", [{"type": "text", "text": "plain"}]), | ||
| FakeMsg( | ||
| "assistant", | ||
| [ | ||
| {"type": "thinking", "thinking": "reasoned"}, | ||
| {"type": "text", "text": "final"}, | ||
| ], | ||
| ), | ||
| ] | ||
| ) | ||
|
|
||
| assert formatted == [ | ||
| {"role": "assistant", "content": "plain"}, | ||
| { | ||
| "role": "assistant", | ||
| "content": "final", | ||
| "reasoning_content": "reasoned", | ||
| }, | ||
| ] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The logic within this
forloop can be simplified for improved readability and maintainability. By restructuring the conditions, you can make the code more concise while preserving the existing functionality.