Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 9 additions & 3 deletions service/app/agents/graph_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,15 +334,21 @@ async def _build_llm_node(self, config: GraphNodeConfig) -> NodeFunction:
async def llm_node(state: StateDict | BaseModel) -> StateDict:
logger.info(f"[LLM Node: {config.id}] Starting execution")

# Convert state to dict (handles both dict and Pydantic BaseModel)
# Get messages BEFORE converting state to dict to preserve BaseMessage types
# model_dump() loses tool_call_id and other message-specific fields
if isinstance(state, BaseModel):
messages: list[BaseMessage] = list(getattr(state, "messages", []))
Comment on lines +339 to +340
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

issue: 在调用 list() 之前要防止 state.messages 为 None,以避免 TypeError。

如果 `state.messages` / `state.get("messages")` 可能为 `None`,`list(None)` 会抛出 TypeError。可以考虑先做归一化处理,例如:

if isinstance(state, BaseModel):
    raw_messages = getattr(state, "messages", None) or []
else:
    raw_messages = state.get("messages") or []
messages: list[BaseMessage] = list(raw_messages)

这样既保持了当前行为(拷贝并保留 BaseMessage 实例),又能在 messages 缺失或为 None 时避免运行时错误。

Original comment in English

issue: Guard against state.messages being None before calling list() to avoid TypeError.

If state.messages/state.get("messages") can be None, list(None) will raise a TypeError. Consider normalizing first, e.g.:

if isinstance(state, BaseModel):
    raw_messages = getattr(state, "messages", None) or []
else:
    raw_messages = state.get("messages") or []
messages: list[BaseMessage] = list(raw_messages)

This keeps the current behavior (copying and preserving BaseMessage instances) while avoiding runtime errors when messages is missing or None.

else:
messages = list(state.get("messages", []))

# Convert state to dict for template rendering (but we already have messages)
state_dict = self._state_to_dict(state)
messages: list[BaseMessage] = list(state_dict.get("messages", []))

# Render prompt template
prompt = self._render_template(llm_config.prompt_template, state_dict)

# Build messages for LLM
llm_messages = messages + [HumanMessage(content=prompt)]
llm_messages = list(messages) + [HumanMessage(content=prompt)]
Copy link

Copilot AI Jan 19, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The code creates a list copy three times: first at line 340/342 with list(...), then again at line 351 when building llm_messages. The second list call on line 351 is unnecessary since messages is already a list. Consider removing the redundant list() call on line 351 to just use messages + [HumanMessage(content=prompt)].

Suggested change
llm_messages = list(messages) + [HumanMessage(content=prompt)]
llm_messages = messages + [HumanMessage(content=prompt)]

Copilot uses AI. Check for mistakes.

# Invoke LLM (using pre-created configured_llm)
response = await configured_llm.ainvoke(llm_messages)
Expand Down
72 changes: 68 additions & 4 deletions service/app/core/chat/history.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,11 @@ async def load_conversation_history(db: AsyncSession, topic: "TopicModel") -> li
history.append(tool_messages[0])
# Skip unknown roles

logger.info(f"Length of history: {len(history)}")
return history
# Validate and filter messages before returning
validated_history = _validate_and_filter_messages(history)

logger.info(f"Loaded {len(history)} messages, {len(validated_history)} after validation")
return validated_history

except Exception as e:
logger.warning(f"Failed to load DB chat history for topic {getattr(topic, 'id', None)}: {e}")
Expand Down Expand Up @@ -180,13 +183,74 @@ def _build_tool_messages(
return None, num_tool_calls + 1

elif formatted_content.get("event") == ChatEventType.TOOL_CALL_RESPONSE:
tool_call_id = formatted_content.get("toolCallId")

# Validate tool_call_id - must be a non-empty string for LangChain
if not tool_call_id or not isinstance(tool_call_id, str):
logger.warning(f"Skipping tool response with invalid tool_call_id: {tool_call_id!r}")
return None, num_tool_calls

message = ToolMessage(
content=formatted_content["result"],
tool_call_id=formatted_content["toolCallId"],
content=formatted_content.get("result", ""),
tool_call_id=tool_call_id,
)
return message, num_tool_calls - 1

except (json.JSONDecodeError, KeyError) as e:
logger.warning(f"Failed to parse tool message content: {e}")

return None, num_tool_calls


def _validate_and_filter_messages(messages: list[BaseMessage]) -> list[BaseMessage]:
    """
    Validate and filter messages to ensure LangChain compatibility.

    This function:
    1. Removes ToolMessages whose tool_call_id is missing or not a string
    2. Removes orphaned ToolMessages that no AIMessage's tool_calls refer to
    3. Logs a warning for every message it drops

    Args:
        messages: List of LangChain messages loaded from history

    Returns:
        Filtered list of valid messages
    """
    # First pass: gather every tool-call id announced by an AIMessage, so we
    # can detect ToolMessages that answer a call nobody made.
    known_call_ids: set[str] = set()
    for candidate in messages:
        if isinstance(candidate, AIMessage) and getattr(candidate, "tool_calls", None):
            for call in candidate.tool_calls:
                call_id = call.get("id")
                if call_id:
                    known_call_ids.add(call_id)

    kept: list[BaseMessage] = []
    dropped = 0

    # Second pass: non-ToolMessages pass through untouched; ToolMessages must
    # carry a non-empty string id that matches a known AIMessage tool call.
    for candidate in messages:
        if not isinstance(candidate, ToolMessage):
            kept.append(candidate)
            continue

        call_id = getattr(candidate, "tool_call_id", None)

        # Check 1: tool_call_id must be a non-empty string
        if not call_id or not isinstance(call_id, str):
            logger.warning(f"Filtering out ToolMessage with invalid tool_call_id: {call_id!r}")
            dropped += 1
            continue

        # Check 2: tool_call_id must have a matching AIMessage tool_call
        if call_id not in known_call_ids:
            logger.warning(
                f"Filtering out orphaned ToolMessage: tool_call_id={call_id} "
                f"not found in any AIMessage.tool_calls"
            )
            dropped += 1
            continue

        kept.append(candidate)

    if dropped > 0:
        logger.info(f"Filtered {dropped} invalid/orphaned tool messages from history")

    return kept
14 changes: 12 additions & 2 deletions service/app/tasks/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,14 +334,24 @@ async def _process_chat_message_async(
# Persist tool call response
try:
resp = stream_event["data"]
tool_call_id = resp.get("toolCallId")

# Only persist if toolCallId is valid - skip otherwise
if not tool_call_id or not isinstance(tool_call_id, str):
logger.warning(
f"Skipping persistence of tool response with invalid toolCallId: {tool_call_id!r}"
)
await publisher.publish(json.dumps(stream_event))
continue # Skip to next event, but still publish to frontend

tool_message = MessageCreate(
role="tool",
content=json.dumps(
{
"event": ChatEventType.TOOL_CALL_RESPONSE,
"toolCallId": resp.get("toolCallId"),
"toolCallId": tool_call_id,
"status": resp.get("status"),
"result": resp.get("result"),
"result": resp.get("result", ""),
"error": resp.get("error"),
}
),
Expand Down
Loading