diff --git a/src/copaw/agents/model_factory.py b/src/copaw/agents/model_factory.py index 609b99f90..536cd24c2 100644 --- a/src/copaw/agents/model_factory.py +++ b/src/copaw/agents/model_factory.py @@ -11,7 +11,7 @@ import logging -from typing import Optional, Sequence, Tuple, Type, Any +from typing import List, Sequence, Tuple, Type, Any, Union, Optional from agentscope.formatter import FormatterBase, OpenAIChatFormatter from agentscope.model import ChatModelBase, OpenAIChatModel @@ -77,6 +77,7 @@ def _get_formatter_for_chat_model( ) +# pylint: disable-next=too-many-statements def _create_file_block_support_formatter( base_formatter_class: Type[FormatterBase], ) -> Type[FormatterBase]: @@ -136,32 +137,44 @@ async def _format(self, msgs): tc["extra_content"] = ec if reasoning_contents: - in_assistant = [m for m in msgs if m.role == "assistant"] + # Build a list of reasoning values aligned with surviving + # assistant messages. The parent formatter drops + # thinking-only messages (no content/tool_calls), so we + # predict survivors and collect reasoning only for those. + aligned_reasoning = [] + for m in (msg for msg in msgs if msg.role == "assistant"): + is_thinking_only = ( + isinstance(m.content, list) + and m.content + and all(b.get("type") == "thinking" for b in m.content) + ) + if not is_thinking_only: + aligned_reasoning.append( + reasoning_contents.get(id(m)), + ) + out_assistant = [ m for m in messages if m.get("role") == "assistant" ] - if len(in_assistant) != len(out_assistant): + + if len(aligned_reasoning) != len(out_assistant): logger.warning( "Assistant message count mismatch after formatting " - "(%d before, %d after). " + "(%d expected survivors, %d actual). " "Skipping reasoning_content injection.", - len(in_assistant), + len(aligned_reasoning), len(out_assistant), ) else: - for in_msg, out_msg in zip( - in_assistant, - out_assistant, - ): - reasoning = reasoning_contents.get(id(in_msg)) - if reasoning: - out_msg["reasoning_content"] = reasoning + for i, out_msg in enumerate(out_assistant): + if aligned_reasoning[i]: + out_msg["reasoning_content"] = aligned_reasoning[i] return _strip_top_level_message_name(messages) @staticmethod def convert_tool_result_to_string( - output: str | list[dict], + output: Union[str, List[dict]], ) -> tuple[str, Sequence[Tuple[str, dict]]]: """Extend parent class to support file blocks.