diff --git a/service/app/agents/graph_builder.py b/service/app/agents/graph_builder.py index d5234917..c0726c15 100644 --- a/service/app/agents/graph_builder.py +++ b/service/app/agents/graph_builder.py @@ -334,15 +334,21 @@ async def _build_llm_node(self, config: GraphNodeConfig) -> NodeFunction: async def llm_node(state: StateDict | BaseModel) -> StateDict: logger.info(f"[LLM Node: {config.id}] Starting execution") - # Convert state to dict (handles both dict and Pydantic BaseModel) + # Get messages BEFORE converting state to dict to preserve BaseMessage types + # model_dump() loses tool_call_id and other message-specific fields + if isinstance(state, BaseModel): + messages: list[BaseMessage] = list(getattr(state, "messages", [])) + else: + messages = list(state.get("messages", [])) + + # Convert state to dict for template rendering (but we already have messages) state_dict = self._state_to_dict(state) - messages: list[BaseMessage] = list(state_dict.get("messages", [])) # Render prompt template prompt = self._render_template(llm_config.prompt_template, state_dict) # Build messages for LLM - llm_messages = messages + [HumanMessage(content=prompt)] + llm_messages = list(messages) + [HumanMessage(content=prompt)] # Invoke LLM (using pre-created configured_llm) response = await configured_llm.ainvoke(llm_messages) diff --git a/service/app/core/chat/history.py b/service/app/core/chat/history.py index 88906d50..25b4c11d 100644 --- a/service/app/core/chat/history.py +++ b/service/app/core/chat/history.py @@ -63,8 +63,11 @@ async def load_conversation_history(db: AsyncSession, topic: "TopicModel") -> li history.append(tool_messages[0]) # Skip unknown roles - logger.info(f"Length of history: {len(history)}") - return history + # Validate and filter messages before returning + validated_history = _validate_and_filter_messages(history) + + logger.info(f"Loaded {len(history)} messages, {len(validated_history)} after validation") + return validated_history except Exception as e: logger.warning(f"Failed to load DB chat history for topic {getattr(topic, 'id', None)}: {e}") @@ -180,9 +183,16 @@ def _build_tool_messages( return None, num_tool_calls + 1 elif formatted_content.get("event") == ChatEventType.TOOL_CALL_RESPONSE: + tool_call_id = formatted_content.get("toolCallId") + + # Validate tool_call_id - must be a non-empty string for LangChain + if not tool_call_id or not isinstance(tool_call_id, str): + logger.warning(f"Skipping tool response with invalid tool_call_id: {tool_call_id!r}") + return None, num_tool_calls + message = ToolMessage( - content=formatted_content["result"], - tool_call_id=formatted_content["toolCallId"], + content=formatted_content.get("result", ""), + tool_call_id=tool_call_id, ) return message, num_tool_calls - 1 @@ -190,3 +200,57 @@ def _build_tool_messages( logger.warning(f"Failed to parse tool message content: {e}") return None, num_tool_calls + + +def _validate_and_filter_messages(messages: list[BaseMessage]) -> list[BaseMessage]: + """ + Validate and filter messages to ensure LangChain compatibility. + + This function: + 1. Removes ToolMessages with invalid tool_call_id + 2. Removes orphaned ToolMessages without matching AIMessage tool_calls + 3. Logs warnings for filtered messages + + Args: + messages: List of LangChain messages loaded from history + + Returns: + Filtered list of valid messages + """ + # Collect all valid tool_call_ids from AIMessages + valid_tool_call_ids: set[str] = set() + for msg in messages: + if isinstance(msg, AIMessage) and hasattr(msg, "tool_calls") and msg.tool_calls: + for tc in msg.tool_calls: + tc_id = tc.get("id") + if tc_id: + valid_tool_call_ids.add(tc_id) + + filtered: list[BaseMessage] = [] + skipped_count = 0 + + for msg in messages: + if isinstance(msg, ToolMessage): + tool_call_id = getattr(msg, "tool_call_id", None) + + # Check 1: tool_call_id must be a non-empty string + if not tool_call_id or not isinstance(tool_call_id, str): + logger.warning(f"Filtering out ToolMessage with invalid tool_call_id: {tool_call_id!r}") + skipped_count += 1 + continue + + # Check 2: tool_call_id must have a matching AIMessage tool_call + if tool_call_id not in valid_tool_call_ids: + logger.warning( + f"Filtering out orphaned ToolMessage: tool_call_id={tool_call_id} " + f"not found in any AIMessage.tool_calls" + ) + skipped_count += 1 + continue + + filtered.append(msg) + + if skipped_count > 0: + logger.info(f"Filtered {skipped_count} invalid/orphaned tool messages from history") + + return filtered diff --git a/service/app/tasks/chat.py b/service/app/tasks/chat.py index c19208c0..d6c2993a 100644 --- a/service/app/tasks/chat.py +++ b/service/app/tasks/chat.py @@ -334,14 +334,24 @@ async def _process_chat_message_async( # Persist tool call response try: resp = stream_event["data"] + tool_call_id = resp.get("toolCallId") + + # Only persist if toolCallId is valid - skip otherwise + if not tool_call_id or not isinstance(tool_call_id, str): + logger.warning( + f"Skipping persistence of tool response with invalid toolCallId: {tool_call_id!r}" + ) + await publisher.publish(json.dumps(stream_event)) + continue # Skip to next event, but still publish to frontend + tool_message = MessageCreate( role="tool", content=json.dumps( { "event": ChatEventType.TOOL_CALL_RESPONSE, - "toolCallId": resp.get("toolCallId"), + "toolCallId": tool_call_id, "status": resp.get("status"), - "result": resp.get("result"), + "result": resp.get("result", ""), "error": resp.get("error"), } ),