diff --git a/backend/app/channels/manager.py b/backend/app/channels/manager.py index e59dbcf2ce..4a42689849 100644 --- a/backend/app/channels/manager.py +++ b/backend/app/channels/manager.py @@ -146,13 +146,6 @@ def _normalize_custom_agent_name(raw_value: str) -> str: return normalized -def _strip_loop_warning_text(text: str) -> str: - """Remove middleware-authored loop warning lines from display text.""" - if "[LOOP DETECTED]" not in text: - return text - return "\n".join(line for line in text.splitlines() if "[LOOP DETECTED]" not in line).strip() - - def _extract_response_text(result: dict | list) -> str: """Extract the last AI message text from a LangGraph runs.wait result. @@ -162,7 +155,6 @@ def _extract_response_text(result: dict | list) -> str: Handles special cases: - Regular AI text responses - Clarification interrupts (``ask_clarification`` tool messages) - - Strips loop-detection warnings attached to tool-call AI messages """ if isinstance(result, list): messages = result @@ -192,12 +184,7 @@ def _extract_response_text(result: dict | list) -> str: # Regular AI message with text content if msg_type == "ai": content = msg.get("content", "") - has_tool_calls = bool(msg.get("tool_calls")) if isinstance(content, str) and content: - if has_tool_calls: - content = _strip_loop_warning_text(content) - if not content: - continue return content # content can be a list of content blocks if isinstance(content, list): @@ -208,8 +195,6 @@ def _extract_response_text(result: dict | list) -> str: elif isinstance(block, str): parts.append(block) text = "".join(parts) - if has_tool_calls: - text = _strip_loop_warning_text(text) if text: return text return "" diff --git a/backend/docs/middleware-execution-flow.md b/backend/docs/middleware-execution-flow.md index 922cc96402..29fa9328a0 100644 --- a/backend/docs/middleware-execution-flow.md +++ b/backend/docs/middleware-execution-flow.md @@ -4,22 +4,22 @@ `create_deerflow_agent` 通过 `RuntimeFeatures` 组装的完整 middleware 链(默认全开时): -| # | Middleware | `before_agent` | `before_model` | `after_model` | `after_agent` | `wrap_tool_call` | 主 Agent | Subagent | 来源 | -|---|-----------|:-:|:-:|:-:|:-:|:-:|:-:|:-:|------| -| 0 | ThreadDataMiddleware | ✓ | | | | | ✓ | ✓ | `sandbox` | -| 1 | UploadsMiddleware | ✓ | | | | | ✓ | ✗ | `sandbox` | -| 2 | SandboxMiddleware | ✓ | | | ✓ | | ✓ | ✓ | `sandbox` | -| 3 | DanglingToolCallMiddleware | | | ✓ | | | ✓ | ✗ | 始终开启 | -| 4 | GuardrailMiddleware | | | | | ✓ | ✓ | ✓ | *Phase 2 纳入* | -| 5 | ToolErrorHandlingMiddleware | | | | | ✓ | ✓ | ✓ | 始终开启 | -| 6 | SummarizationMiddleware | | | ✓ | | | ✓ | ✗ | `summarization` | -| 7 | TodoMiddleware | | | ✓ | | | ✓ | ✗ | `plan_mode` 参数 | -| 8 | TitleMiddleware | | | ✓ | | | ✓ | ✗ | `auto_title` | -| 9 | MemoryMiddleware | | | | ✓ | | ✓ | ✗ | `memory` | -| 10 | ViewImageMiddleware | | ✓ | | | | ✓ | ✗ | `vision` | -| 11 | SubagentLimitMiddleware | | | ✓ | | | ✓ | ✗ | `subagent` | -| 12 | LoopDetectionMiddleware | | | ✓ | | | ✓ | ✗ | 始终开启 | -| 13 | ClarificationMiddleware | | | ✓ | | | ✓ | ✗ | 始终最后 | +| # | Middleware | `before_agent` | `before_model` | `after_model` | `after_agent` | `wrap_model_call` | `wrap_tool_call` | 主 Agent | Subagent | 来源 | +|---|-----------|:-:|:-:|:-:|:-:|:-:|:-:|:-:|:-:|------| +| 0 | ThreadDataMiddleware | ✓ | | | | | | ✓ | ✓ | `sandbox` | +| 1 | UploadsMiddleware | ✓ | | | | | | ✓ | ✗ | `sandbox` | +| 2 | SandboxMiddleware | ✓ | | | ✓ | | | ✓ | ✓ | `sandbox` | +| 3 | DanglingToolCallMiddleware | | | | | ✓ | | ✓ | ✗ | 始终开启 | +| 4 | GuardrailMiddleware | | | | | | ✓ | ✓ | ✓ | *Phase 2 纳入* | +| 5 | ToolErrorHandlingMiddleware | | | | | | ✓ | ✓ | ✓ | 始终开启 | +| 6 | SummarizationMiddleware | | ✓ | | | | | ✓ | ✗ | `summarization` | +| 7 | TodoMiddleware | | ✓ | ✓ | | | | ✓ | ✗ | `plan_mode` 参数 | +| 8 | TitleMiddleware | | | ✓ | | | | ✓ | ✗ | `auto_title` | +| 9 | MemoryMiddleware | | | | ✓ | | | ✓ | ✗ | `memory` | +| 10 | ViewImageMiddleware | | ✓ | | | | | ✓ | ✗ | `vision` | +| 11 | SubagentLimitMiddleware | | | ✓ | | | | ✓ | ✗ | `subagent` | +| 12 | LoopDetectionMiddleware | ✓ | | ✓ | ✓ | ✓ | | ✓ | ✗ | 始终开启 | +| 13 | ClarificationMiddleware | | | | | | ✓ | ✓ | ✗ | 始终最后 | 主 agent **14 个** middleware(`make_lead_agent`),subagent **4 个**(ThreadData、Sandbox、Guardrail、ToolErrorHandling)。`create_deerflow_agent` Phase 1 实现 **13 个**(Guardrail 仅支持自定义实例,无内置默认)。 @@ -35,7 +35,7 @@ graph TB subgraph BA ["before_agent 正序 0→N"] direction TB - TD["[0] ThreadData
创建线程目录"] --> UL["[1] Uploads
扫描上传文件"] --> SB["[2] Sandbox
获取沙箱"] + TD["[0] ThreadData
创建线程目录"] --> UL["[1] Uploads
扫描上传文件"] --> SB["[2] Sandbox
获取沙箱"] --> LD_BA["[12] LoopDetection
清理 stale warning"] end subgraph BM ["before_model 正序 0→N"] @@ -43,34 +43,42 @@ graph TB VI["[10] ViewImage
注入图片 base64"] end - SB --> VI - VI --> M["MODEL"] + subgraph WM ["wrap_model_call"] + direction TB + DTC_WM["[3] DanglingToolCall
补悬空 ToolMessage"] --> LD_WM["[12] LoopDetection
注入当前 run warning"] + end + + LD_BA --> VI + VI --> DTC_WM + LD_WM --> M["MODEL"] subgraph AM ["after_model 反序 N→0"] direction TB - CL["[13] Clarification
拦截 ask_clarification"] --> LD["[12] LoopDetection
检测循环"] --> SL["[11] SubagentLimit
截断多余 task"] --> TI["[8] Title
生成标题"] --> SM["[6] Summarization
上下文压缩"] --> DTC["[3] DanglingToolCall
补缺失 ToolMessage"] + LD["[12] LoopDetection
检测循环/排队 warning"] --> SL["[11] SubagentLimit
截断多余 task"] --> TI["[8] Title
生成标题"] end - M --> CL + M --> LD subgraph AA ["after_agent 反序 N→0"] direction TB - SBR["[2] Sandbox
释放沙箱"] --> MEM["[9] Memory
入队记忆"] + LD_CLEAN["[12] LoopDetection
清理 pending warning"] --> MEM["[9] Memory
入队记忆"] --> SBR["[2] Sandbox
释放沙箱"] end - DTC --> SBR - MEM --> END(["response"]) + TI --> LD_CLEAN + SBR --> END(["response"]) classDef beforeNode fill:#a0a8b5,stroke:#636b7a,color:#2d3239 classDef modelNode fill:#b5a8a0,stroke:#7a6b63,color:#2d3239 + classDef wrapModelNode fill:#a8a0b5,stroke:#6b637a,color:#2d3239 classDef afterModelNode fill:#b5a0a8,stroke:#7a636b,color:#2d3239 classDef afterAgentNode fill:#a0b5a8,stroke:#637a6b,color:#2d3239 classDef terminalNode fill:#a8b5a0,stroke:#6b7a63,color:#2d3239 - class TD,UL,SB,VI beforeNode + class TD,UL,SB,LD_BA,VI beforeNode + class DTC_WM,LD_WM wrapModelNode class M modelNode - class CL,LD,SL,TI,SM,DTC afterModelNode - class SBR,MEM afterAgentNode + class LD,SL,TI afterModelNode + class LD_CLEAN,SBR,MEM afterAgentNode class START,END terminalNode ``` @@ -82,13 +90,12 @@ sequenceDiagram participant TD as ThreadDataMiddleware participant UL as UploadsMiddleware participant SB as SandboxMiddleware + participant LD as LoopDetectionMiddleware participant VI as ViewImageMiddleware + participant DTC as DanglingToolCallMiddleware participant M as MODEL - participant CL as ClarificationMiddleware participant SL as SubagentLimitMiddleware participant TI as TitleMiddleware - participant SM as SummarizationMiddleware - participant DTC as DanglingToolCallMiddleware participant MEM as MemoryMiddleware U ->> TD: invoke @@ -103,19 +110,26 @@ sequenceDiagram activate SB Note right of SB: before_agent 获取沙箱 - SB ->> VI: before_model + SB ->> LD: before_agent + activate LD + Note right of LD: before_agent 清理同 thread 旧 run 的 pending warning + LD ->> VI: before_model activate VI Note right of VI: before_model 注入图片 base64 - VI ->> M: messages + tools + VI ->> DTC: wrap_model_call + activate DTC + Note right of DTC: wrap_model_call 补悬空 ToolMessage + DTC ->> LD: wrap_model_call + Note right of LD: wrap_model_call drain 当前 run warning 并追加到末尾 + LD ->> M: messages + tools activate M - M -->> CL: AI response + M -->> LD: AI response deactivate M - activate CL - Note right of CL: after_model 拦截 ask_clarification - CL -->> SL: after_model - deactivate CL + Note right of LD: after_model 检测循环;warning 入队,hard-stop 清 tool_calls + LD -->> SL: after_model + deactivate LD activate SL Note right of SL: after_model 截断多余 task @@ -124,22 +138,18 @@ sequenceDiagram activate TI Note right of TI: after_model 生成标题 - TI -->> SM: after_model + TI -->> DTC: done deactivate TI - activate SM - Note right of SM: after_model 上下文压缩 - SM -->> DTC: after_model - deactivate SM - - activate DTC - Note right of DTC: after_model 补缺失 ToolMessage - DTC -->> VI: done deactivate DTC VI -->> SB: done deactivate VI + Note right of LD: after_agent 清理当前 run 未消费 warning + + Note right of MEM: after_agent 入队记忆 + Note right of SB: after_agent 释放沙箱 SB -->> UL: done deactivate SB @@ -147,8 +157,6 @@ sequenceDiagram UL -->> TD: done deactivate UL - Note right of MEM: after_agent 入队记忆 - TD -->> U: response deactivate TD ``` @@ -224,12 +232,12 @@ sequenceDiagram participant TD as ThreadData participant UL as Uploads participant SB as Sandbox + participant LD as LoopDetection participant VI as ViewImage + participant DTC as DanglingToolCall participant M as MODEL - participant CL as Clarification participant SL as SubagentLimit participant TI as Title - participant SM as Summarization participant MEM as Memory U ->> TD: invoke @@ -238,34 +246,40 @@ sequenceDiagram Note right of UL: before_agent 扫描文件 UL ->> SB: . Note right of SB: before_agent 获取沙箱 + SB ->> LD: . + Note right of LD: before_agent 清理 stale pending warning loop 每轮对话(tool call 循环) SB ->> VI: . Note right of VI: before_model 注入图片 - VI ->> M: messages + tools - M -->> CL: AI response - Note right of CL: after_model 拦截 ask_clarification - CL -->> SL: . + VI ->> DTC: . + Note right of DTC: wrap_model_call 补悬空工具结果 + DTC ->> LD: . + Note right of LD: wrap_model_call 注入当前 run warning + LD ->> M: messages + tools + M -->> LD: AI response + Note right of LD: after_model 检测循环/排队 warning + LD -->> SL: . Note right of SL: after_model 截断多余 task SL -->> TI: . Note right of TI: after_model 生成标题 - TI -->> SM: . - Note right of SM: after_model 上下文压缩 end - Note right of SB: after_agent 释放沙箱 - SB -->> MEM: . + Note right of LD: after_agent 清理当前 run pending warning + LD -->> MEM: . Note right of MEM: after_agent 入队记忆 - MEM -->> U: response + MEM -->> SB: . + Note right of SB: after_agent 释放沙箱 + SB -->> U: response ``` > [!warning] 不是洋葱 -> 14 个 middleware 中只有 SandboxMiddleware 有 before/after 对称(获取/释放)。其余都是单向的:要么只在 `before_*` 做事,要么只在 `after_*` 做事。`before_agent` / `after_agent` 只跑一次,`before_model` / `after_model` 每轮循环都跑。 +> 大部分 middleware 只用一个阶段。SandboxMiddleware 使用 `before_agent`/`after_agent` 做资源获取/释放;LoopDetectionMiddleware 也使用这两个钩子,但用途是清理 run-scoped pending warnings,不是资源生命周期对称。`before_agent` / `after_agent` 只跑一次,`before_model` / `after_model` / `wrap_model_call` 每轮循环都跑。 硬依赖只有 2 处: 1. **ThreadData 在 Sandbox 之前** — sandbox 需要线程目录 -2. **Clarification 在列表最后** — `after_model` 反序时最先执行,第一个拦截 `ask_clarification` +2. **Clarification 在列表最后** — `wrap_tool_call` 处理 `ask_clarification` 时优先拦截,并通过 `Command(goto=END)` 中断执行 ### 结论 @@ -273,19 +287,19 @@ sequenceDiagram |---|---|---| | 每个 middleware | before + after 对称 | 大多只用一个钩子 | | 激活条 | 嵌套(外长内短) | 不嵌套(串行) | -| 反序的意义 | 清理与初始化配对 | 仅影响 after_model 的执行优先级 | +| 反序的意义 | 清理与初始化配对 | 影响 `after_model` / `after_agent` 的执行优先级 | | 典型例子 | Auth: 校验 token / 清理上下文 | ThreadData: 只创建目录,没有清理 | ## 关键设计点 ### ClarificationMiddleware 为什么在列表最后? -位置最后 = `after_model` 最先执行。它需要**第一个**看到 model 输出,检查是否有 `ask_clarification` tool call。如果有,立即中断(`Command(goto=END)`),后续 middleware 的 `after_model` 不再执行。 +位置最后使它在工具调用包装链中优先拦截 `ask_clarification`。如果命中,它返回 `Command(goto=END)`,把格式化后的澄清问题写成 `ToolMessage` 并中断执行。 ### SandboxMiddleware 的对称性 `before_agent`(正序第 3 个)获取沙箱,`after_agent`(反序第 1 个)释放沙箱。外层进入 → 外层退出,天然的洋葱对称。 -### 大部分 middleware 只用一个钩子 +### LoopDetectionMiddleware 为什么同时用多个钩子? -14 个 middleware 中,只有 SandboxMiddleware 同时用了 `before_agent` + `after_agent`(获取/释放)。其余都只在一个阶段执行。洋葱模型的反序特性主要影响 `after_model` 阶段的执行顺序。 +`after_model` 只做检测:重复工具调用达到 warning 阈值时,把 warning 放入 `(thread_id, run_id)` 作用域的 pending 队列。真正注入发生在下一次 `wrap_model_call`:此时上一轮 `AIMessage(tool_calls)` 对应的 `ToolMessage` 已经在请求里,warning 追加在末尾,不会破坏 OpenAI/Moonshot 的 tool-call pairing。`before_agent` 清理同一 thread 下旧 run 的残留 warning,`after_agent` 清理当前 run 没被消费的 warning。 diff --git a/backend/packages/harness/deerflow/agents/middlewares/loop_detection_middleware.py b/backend/packages/harness/deerflow/agents/middlewares/loop_detection_middleware.py index db83051e90..1f0ef02f0a 100644 --- a/backend/packages/harness/deerflow/agents/middlewares/loop_detection_middleware.py +++ b/backend/packages/harness/deerflow/agents/middlewares/loop_detection_middleware.py @@ -6,10 +6,30 @@ Detection strategy: 1. After each model response, hash the tool calls (name + args). 2. Track recent hashes in a sliding window. - 3. If the same hash appears >= warn_threshold times, inject a - "you are repeating yourself — wrap up" system message (once per hash). + 3. If the same hash appears >= warn_threshold times, queue a + "you are repeating yourself — wrap up" warning for the current + thread/run. The warning is **injected at the next model call** (in + ``wrap_model_call``) as a ``HumanMessage`` appended to the message + list, *after* all ToolMessage responses to the previous + AIMessage(tool_calls). 4. If it appears >= hard_limit times, strip all tool_calls from the response so the agent is forced to produce a final text answer. + +Why the warning is injected at ``wrap_model_call`` instead of +``after_model``: + + ``after_model`` fires immediately after the model emits an + ``AIMessage`` that may carry ``tool_calls``. The tools node has not + run yet, so no matching ``ToolMessage`` exists in the history. Any + message we add here lands *between* the assistant's tool_calls and + their responses. OpenAI/Moonshot reject the next request with + ``"tool_call_ids did not have response messages"`` because their + validators require the assistant's tool_calls to be followed + immediately by tool messages. Anthropic also disallows mid-stream + ``SystemMessage``. By deferring the warning to ``wrap_model_call``, + every prior ToolMessage is already present in the request's message + list and the warning is appended at the end — pairing intact, no + ``AIMessage`` semantics are mutated. """ from __future__ import annotations @@ -19,11 +39,14 @@ import logging import threading from collections import OrderedDict, defaultdict +from collections.abc import Awaitable, Callable from copy import deepcopy from typing import TYPE_CHECKING, override from langchain.agents import AgentState from langchain.agents.middleware import AgentMiddleware +from langchain.agents.middleware.types import ModelCallResult, ModelRequest, ModelResponse +from langchain_core.messages import HumanMessage from langgraph.runtime import Runtime if TYPE_CHECKING: @@ -195,6 +218,10 @@ def __init__( self._warned: dict[str, set[str]] = defaultdict(set) self._tool_freq: dict[str, dict[str, int]] = defaultdict(lambda: defaultdict(int)) self._tool_freq_warned: dict[str, set[str]] = defaultdict(set) + # Per-thread/run queue of warnings to inject at the next model call. + # Populated by ``after_model`` (detection) and drained by + # ``wrap_model_call`` (injection); see module docstring. + self._pending_warnings: dict[tuple[str, str], list[str]] = defaultdict(list) @classmethod def from_config(cls, config: LoopDetectionConfig) -> LoopDetectionMiddleware: @@ -213,9 +240,20 @@ def _get_thread_id(self, runtime: Runtime) -> str: """Extract thread_id from runtime context for per-thread tracking.""" thread_id = runtime.context.get("thread_id") if runtime.context else None if thread_id: - return thread_id + return str(thread_id) + return "default" + + def _get_run_id(self, runtime: Runtime) -> str: + """Extract run_id from runtime context for per-run warning scoping.""" + run_id = runtime.context.get("run_id") if runtime.context else None + if run_id: + return str(run_id) return "default" + def _pending_key(self, runtime: Runtime) -> tuple[str, str]: + """Return the pending-warning key for the current thread/run.""" + return self._get_thread_id(runtime), self._get_run_id(runtime) + def _evict_if_needed(self) -> None: """Evict least recently used threads if over the limit. @@ -226,6 +264,9 @@ def _evict_if_needed(self) -> None: self._warned.pop(evicted_id, None) self._tool_freq.pop(evicted_id, None) self._tool_freq_warned.pop(evicted_id, None) + for key in list(self._pending_warnings): + if key[0] == evicted_id: + self._pending_warnings.pop(key, None) logger.debug("Evicted loop tracking for thread %s (LRU)", evicted_id) def _track_and_check(self, state: AgentState, runtime: Runtime) -> tuple[str | None, bool]: @@ -381,7 +422,10 @@ def _apply(self, state: AgentState, runtime: Runtime) -> dict | None: warning, hard_stop = self._track_and_check(state, runtime) if hard_stop: - # Strip tool_calls from the last AIMessage to force text output + # Strip tool_calls from the last AIMessage to force text output. + # Once tool_calls are stripped, the AIMessage no longer requires + # matching ToolMessage responses, so mutating it in place here + # is safe for OpenAI/Moonshot pairing validators. messages = state.get("messages", []) last_msg = messages[-1] content = self._append_text(last_msg.content, warning or _HARD_STOP_MSG) @@ -389,33 +433,50 @@ def _apply(self, state: AgentState, runtime: Runtime) -> dict | None: return {"messages": [stripped_msg]} if warning: - # WORKAROUND for v2.0-m1 — see #2724. - # - # Append the warning to the AIMessage content instead of - # injecting a separate HumanMessage. Inserting any non-tool - # message between an AIMessage(tool_calls=...) and its - # ToolMessage responses breaks OpenAI/Moonshot strict pairing - # validation ("tool_call_ids did not have response messages") - # because the tools node has not run yet at after_model time. - # tool_calls are preserved so the tools node still executes. - # - # This is a temporary mitigation: mutating an existing - # AIMessage to carry framework-authored text leaks loop-warning - # text into downstream consumers (MemoryMiddleware fact - # extraction, TitleMiddleware, telemetry, model replay) as if - # the model said it. The proper fix is to defer warning - # injection from after_model to wrap_model_call so every prior - # ToolMessage is already in the request — see RFC #2517 (which - # lists "loop intervention does not leave invalid - # tool-call/tool-message state" as acceptance criteria) and - # the prototype on `fix/loop-detection-tool-call-pairing`. - messages = state.get("messages", []) - last_msg = messages[-1] - patched_msg = last_msg.model_copy(update={"content": self._append_text(last_msg.content, warning)}) - return {"messages": [patched_msg]} + # Defer injection to the next model call. We must NOT alter the + # AIMessage(tool_calls=...) here (would put framework words in + # the model's mouth, polluting downstream consumers like + # MemoryMiddleware), nor insert a separate non-tool message + # (would break OpenAI/Moonshot tool-call pairing because the + # tools node has not produced ToolMessage responses yet). The + # warning is delivered via ``wrap_model_call`` below. + pending_key = self._pending_key(runtime) + with self._lock: + self._pending_warnings[pending_key].append(warning) + return None return None + def _clear_other_run_pending_warnings(self, runtime: Runtime) -> None: + """Drop stale pending warnings for previous runs in this thread.""" + thread_id, current_run_id = self._pending_key(runtime) + with self._lock: + for key in list(self._pending_warnings): + if key[0] == thread_id and key[1] != current_run_id: + self._pending_warnings.pop(key, None) + + def _clear_current_run_pending_warnings(self, runtime: Runtime) -> None: + """Drop pending warnings owned by the current thread/run.""" + pending_key = self._pending_key(runtime) + with self._lock: + self._pending_warnings.pop(pending_key, None) + + @staticmethod + def _format_warning_message(warnings: list[str]) -> str: + """Merge pending warnings into one prompt message.""" + deduped = list(dict.fromkeys(warnings)) + return "\n\n".join(deduped) + + @override + def before_agent(self, state: AgentState, runtime: Runtime) -> dict | None: + self._clear_other_run_pending_warnings(runtime) + return None + + @override + async def abefore_agent(self, state: AgentState, runtime: Runtime) -> dict | None: + self._clear_other_run_pending_warnings(runtime) + return None + @override def after_model(self, state: AgentState, runtime: Runtime) -> dict | None: return self._apply(state, runtime) @@ -424,6 +485,58 @@ def after_model(self, state: AgentState, runtime: Runtime) -> dict | None: async def aafter_model(self, state: AgentState, runtime: Runtime) -> dict | None: return self._apply(state, runtime) + @override + def after_agent(self, state: AgentState, runtime: Runtime) -> dict | None: + self._clear_current_run_pending_warnings(runtime) + return None + + @override + async def aafter_agent(self, state: AgentState, runtime: Runtime) -> dict | None: + self._clear_current_run_pending_warnings(runtime) + return None + + def _drain_pending_warnings(self, runtime: Runtime) -> list[str]: + """Pop and return all queued warnings for *runtime*'s thread/run.""" + pending_key = self._pending_key(runtime) + with self._lock: + warnings = self._pending_warnings.pop(pending_key, []) + return warnings + + def _augment_request(self, request: ModelRequest) -> ModelRequest: + """Append queued loop warnings (if any) to the outgoing message list. + + The warning is placed *after* every existing message, including the + ToolMessage responses to the previous AIMessage(tool_calls). This + keeps ``assistant tool_calls -> tool_messages`` pairing intact for + OpenAI/Moonshot, avoids the Anthropic mid-stream SystemMessage + restriction (we use HumanMessage), and never mutates an existing + AIMessage. + """ + warnings = self._drain_pending_warnings(request.runtime) + if not warnings: + return request + new_messages = [ + *request.messages, + HumanMessage(content=self._format_warning_message(warnings), name="loop_warning"), + ] + return request.override(messages=new_messages) + + @override + def wrap_model_call( + self, + request: ModelRequest, + handler: Callable[[ModelRequest], ModelResponse], + ) -> ModelCallResult: + return handler(self._augment_request(request)) + + @override + async def awrap_model_call( + self, + request: ModelRequest, + handler: Callable[[ModelRequest], Awaitable[ModelResponse]], + ) -> ModelCallResult: + return await handler(self._augment_request(request)) + def reset(self, thread_id: str | None = None) -> None: """Clear tracking state. If thread_id given, clear only that thread.""" with self._lock: @@ -432,8 +545,12 @@ def reset(self, thread_id: str | None = None) -> None: self._warned.pop(thread_id, None) self._tool_freq.pop(thread_id, None) self._tool_freq_warned.pop(thread_id, None) + for key in list(self._pending_warnings): + if key[0] == thread_id: + self._pending_warnings.pop(key, None) else: self._history.clear() self._warned.clear() self._tool_freq.clear() self._tool_freq_warned.clear() + self._pending_warnings.clear() diff --git a/backend/tests/test_channels.py b/backend/tests/test_channels.py index d68701c4ef..62f3ce0f04 100644 --- a/backend/tests/test_channels.py +++ b/backend/tests/test_channels.py @@ -372,37 +372,6 @@ def test_does_not_leak_previous_turn_text(self): # Should return "" (no text in current turn), NOT "Hi there!" from previous turn assert _extract_response_text(result) == "" - def test_does_not_publish_loop_warning_on_tool_calling_ai_message(self): - """Loop-detection warning text on a tool-calling AI message is middleware-authored.""" - from app.channels.manager import _extract_response_text - - result = { - "messages": [ - {"type": "human", "content": "search the repo"}, - { - "type": "ai", - "content": "[LOOP DETECTED] You are repeating the same tool calls.", - "tool_calls": [{"name": "grep", "args": {"pattern": "TODO"}, "id": "call_1"}], - }, - ] - } - assert _extract_response_text(result) == "" - - def test_preserves_visible_text_when_stripping_loop_warning(self): - from app.channels.manager import _extract_response_text - - result = { - "messages": [ - {"type": "human", "content": "prepare the report"}, - { - "type": "ai", - "content": "Here is the report.\n\n[LOOP DETECTED] You are repeating the same tool calls.", - "tool_calls": [{"name": "present_files", "args": {"filepaths": ["/mnt/user-data/outputs/report.md"]}, "id": "call_1"}], - }, - ] - } - assert _extract_response_text(result) == "Here is the report." - # --------------------------------------------------------------------------- # ChannelManager tests diff --git a/backend/tests/test_loop_detection_middleware.py b/backend/tests/test_loop_detection_middleware.py index 022afc1171..027fd973e5 100644 --- a/backend/tests/test_loop_detection_middleware.py +++ b/backend/tests/test_loop_detection_middleware.py @@ -3,7 +3,7 @@ import copy from unittest.mock import MagicMock -from langchain_core.messages import AIMessage, SystemMessage +from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage from deerflow.agents.middlewares.loop_detection_middleware import ( _HARD_STOP_MSG, @@ -12,13 +12,46 @@ ) -def _make_runtime(thread_id="test-thread"): +def _make_runtime(thread_id="test-thread", run_id="test-run"): """Build a minimal Runtime mock with context.""" runtime = MagicMock() - runtime.context = {"thread_id": thread_id} + runtime.context = {"thread_id": thread_id, "run_id": run_id} return runtime +def _pending_key(thread_id="test-thread", run_id="test-run"): + return (thread_id, run_id) + + +def _make_request(messages, runtime): + """Build a minimal ModelRequest stand-in for wrap_model_call tests.""" + request = MagicMock() + request.messages = list(messages) + request.runtime = runtime + request.override = lambda **updates: _override_request(request, updates) + return request + + +def _override_request(request, updates): + """Mimic ModelRequest.override(): return a copy with fields replaced.""" + new = MagicMock() + new.messages = updates.get("messages", request.messages) + new.runtime = updates.get("runtime", request.runtime) + new.override = lambda **u: _override_request(new, u) + return new + + +def _capture_handler(): + """Build a sync handler that records the request it was called with.""" + captured: list = [] + + def handler(req): + captured.append(req) + return MagicMock() + + return captured, handler + + def _make_state(tool_calls=None, content=""): """Build a minimal AgentState dict with an AIMessage. @@ -138,7 +171,15 @@ def test_below_threshold_returns_none(self): result = mw._apply(_make_state(tool_calls=call), runtime) assert result is None - def test_warn_at_threshold(self): + def test_warn_at_threshold_queues_but_does_not_mutate_state(self): + """At warn threshold, ``after_model`` enqueues but returns None. + + Detection observes the just-emitted AIMessage(tool_calls=...). The + tools node hasn't run yet, so injecting any non-tool message here + would split the assistant's tool_calls from their ToolMessage + responses and break OpenAI/Moonshot pairing. The warning is + delivered later from ``wrap_model_call``. + """ mw = LoopDetectionMiddleware(warn_threshold=3, hard_limit=5) runtime = _make_runtime() call = [_bash_call("ls")] @@ -146,44 +187,150 @@ def test_warn_at_threshold(self): for _ in range(2): mw._apply(_make_state(tool_calls=call), runtime) - # Third identical call triggers warning. The warning is appended to - # the AIMessage content (tool_calls preserved) — never inserted as a - # separate HumanMessage between the AIMessage(tool_calls) and its - # ToolMessage responses, which would break OpenAI/Moonshot strict - # tool-call pairing validation. + # Third identical call triggers warning detection. result = mw._apply(_make_state(tool_calls=call), runtime) - assert result is not None - msgs = result["messages"] - assert len(msgs) == 1 - assert isinstance(msgs[0], AIMessage) - assert len(msgs[0].tool_calls) == len(call) - assert msgs[0].tool_calls[0]["id"] == call[0]["id"] - assert "LOOP DETECTED" in msgs[0].content - - def test_warn_does_not_break_tool_call_pairing(self): - """Regression: the warn branch must NOT inject a non-tool message - after an AIMessage(tool_calls=...). Moonshot/OpenAI reject the next - request with 'tool_call_ids did not have response messages' if any - non-tool message is wedged between the AIMessage and its ToolMessage - responses. See #2029. + # Detection must not mutate state — the AIMessage with tool_calls is + # left untouched so the tools node runs normally. + assert result is None + # ...but a warning is queued for the next model call. + assert mw._pending_warnings[_pending_key()] + assert "LOOP DETECTED" in mw._pending_warnings[_pending_key()][0] + + def test_warn_injected_at_next_model_call(self): + """``wrap_model_call`` appends a HumanMessage(loop_warning) to the + outgoing messages — *after* every existing message — so that the + AIMessage(tool_calls=...) -> ToolMessage(...) pairing stays intact. """ mw = LoopDetectionMiddleware(warn_threshold=3, hard_limit=10) runtime = _make_runtime() call = [_bash_call("ls")] + for _ in range(3): + mw._apply(_make_state(tool_calls=call), runtime) - for _ in range(2): + # Build the messages the agent runtime would assemble for the next + # turn: prior AIMessage(tool_calls), its ToolMessage responses, ... + ai_msg = AIMessage(content="", tool_calls=call) + tool_msg = ToolMessage(content="ok", tool_call_id=call[0]["id"], name="bash") + request = _make_request([ai_msg, tool_msg], runtime) + + captured, handler = _capture_handler() + mw.wrap_model_call(request, handler) + + sent = captured[0].messages + # AIMessage and ToolMessage stay in order, untouched. + assert sent[0] is ai_msg + assert sent[1] is tool_msg + # HumanMessage(warning) appears AFTER the ToolMessage — pairing intact. + assert isinstance(sent[2], HumanMessage) + assert sent[2].name == "loop_warning" + assert "LOOP DETECTED" in sent[2].content + + def test_warn_queue_drained_after_injection(self): + """A queued warning must be emitted exactly once per detection event.""" + mw = LoopDetectionMiddleware(warn_threshold=3, hard_limit=10) + runtime = _make_runtime() + call = [_bash_call("ls")] + for _ in range(3): mw._apply(_make_state(tool_calls=call), runtime) - result = mw._apply(_make_state(tool_calls=call), runtime) - assert result is not None - msgs = result["messages"] - assert len(msgs) == 1 - assert isinstance(msgs[0], AIMessage) - assert len(msgs[0].tool_calls) == len(call) - assert msgs[0].tool_calls[0]["id"] == call[0]["id"] + request = _make_request([AIMessage(content="hi")], runtime) + captured, handler = _capture_handler() + + # First call: warning is appended. + mw.wrap_model_call(request, handler) + first = captured[0].messages + assert any(isinstance(m, HumanMessage) for m in first) + + # Subsequent call without new detection: no warning re-emitted. + request2 = _make_request([AIMessage(content="hi")], runtime) + mw.wrap_model_call(request2, handler) + second = captured[1].messages + assert not any(isinstance(m, HumanMessage) for m in second) + + def test_warn_queue_scoped_by_run_id(self): + """A warning queued for one run must not be injected into another run.""" + mw = LoopDetectionMiddleware(warn_threshold=3, hard_limit=10) + runtime_a = _make_runtime(run_id="run-A") + runtime_b = _make_runtime(run_id="run-B") + call = [_bash_call("ls")] + + for _ in range(3): + mw._apply(_make_state(tool_calls=call), runtime_a) + + request_b = _make_request([AIMessage(content="hi")], runtime_b) + captured, handler = _capture_handler() + mw.wrap_model_call(request_b, handler) + assert not any(isinstance(m, HumanMessage) for m in captured[0].messages) + assert mw._pending_warnings.get(_pending_key(run_id="run-A")) + + request_a = _make_request([AIMessage(content="hi")], runtime_a) + mw.wrap_model_call(request_a, handler) + assert any(isinstance(message, HumanMessage) and message.name == "loop_warning" for message in captured[1].messages) + + def test_missing_run_id_uses_default_pending_scope(self): + """When runtime has no run_id, warning handling falls back to the default run scope.""" + mw = LoopDetectionMiddleware(warn_threshold=3, hard_limit=10) + runtime = MagicMock() + runtime.context = {"thread_id": "test-thread"} + call = [_bash_call("ls")] + + for _ in range(3): + mw._apply(_make_state(tool_calls=call), runtime) + + assert mw._pending_warnings.get(_pending_key(run_id="default")) + + request = _make_request([AIMessage(content="hi")], runtime) + captured, handler = _capture_handler() + mw.wrap_model_call(request, handler) + + loop_warnings = [message for message in captured[0].messages if isinstance(message, HumanMessage) and message.name == "loop_warning"] + assert len(loop_warnings) == 1 + assert "LOOP DETECTED" in loop_warnings[0].content + assert not mw._pending_warnings.get(_pending_key(run_id="default")) - def test_warn_only_injected_once(self): - """Warning for the same hash should only be injected once per thread.""" + def test_before_agent_clears_stale_pending_warnings_for_thread(self): + """Starting a new run drops stale warnings from prior runs in the same thread.""" + mw = LoopDetectionMiddleware(warn_threshold=3, hard_limit=10) + runtime_a = _make_runtime(run_id="run-A") + runtime_b = _make_runtime(run_id="run-B") + call = [_bash_call("ls")] + + for _ in range(3): + mw._apply(_make_state(tool_calls=call), runtime_a) + + assert mw._pending_warnings.get(_pending_key(run_id="run-A")) + mw.before_agent({"messages": []}, runtime_b) + assert not mw._pending_warnings.get(_pending_key(run_id="run-A")) + + def test_after_agent_clears_current_run_pending_warnings(self): + """Run cleanup should drop warnings that never reached wrap_model_call.""" + mw = LoopDetectionMiddleware(warn_threshold=3, hard_limit=10) + runtime = _make_runtime() + call = [_bash_call("ls")] + + for _ in range(3): + mw._apply(_make_state(tool_calls=call), runtime) + + assert mw._pending_warnings.get(_pending_key()) + mw.after_agent({"messages": []}, runtime) + assert not mw._pending_warnings.get(_pending_key()) + + def test_multiple_pending_warnings_are_merged_into_one_message(self): + """Edge-case drains should produce one loop_warning prompt message.""" + mw = LoopDetectionMiddleware() + runtime = _make_runtime() + mw._pending_warnings[_pending_key()] = ["first warning", "second warning", "first warning"] + request = _make_request([AIMessage(content="hi")], runtime) + captured, handler = _capture_handler() + + mw.wrap_model_call(request, handler) + + loop_warnings = [message for message in captured[0].messages if isinstance(message, HumanMessage) and message.name == "loop_warning"] + assert len(loop_warnings) == 1 + assert loop_warnings[0].content == "first warning\n\nsecond warning" + + def test_warn_only_queued_once_per_hash(self): + """Same hash repeated past the threshold should warn only once.""" mw = LoopDetectionMiddleware(warn_threshold=3, hard_limit=10) runtime = _make_runtime() call = [_bash_call("ls")] @@ -192,14 +339,13 @@ def test_warn_only_injected_once(self): for _ in range(2): mw._apply(_make_state(tool_calls=call), runtime) - # Third — warning injected - result = mw._apply(_make_state(tool_calls=call), runtime) - assert result is not None - assert "LOOP DETECTED" in result["messages"][0].content + # Third — warning queued + mw._apply(_make_state(tool_calls=call), runtime) + assert len(mw._pending_warnings[_pending_key()]) == 1 - # Fourth — warning already injected, should return None - result = mw._apply(_make_state(tool_calls=call), runtime) - assert result is None + # Fourth — already warned for this hash, no additional enqueue. + mw._apply(_make_state(tool_calls=call), runtime) + assert len(mw._pending_warnings[_pending_key()]) == 1 def test_hard_stop_at_limit(self): mw = LoopDetectionMiddleware(warn_threshold=2, hard_limit=4) @@ -257,6 +403,7 @@ def test_reset_clears_state(self): mw.reset() result = mw._apply(_make_state(tool_calls=call), runtime) assert result is None + assert not mw._pending_warnings.get(_pending_key()) def test_non_ai_message_ignored(self): mw = LoopDetectionMiddleware() @@ -283,15 +430,16 @@ def test_thread_id_from_runtime_context(self): # One call on thread B mw._apply(_make_state(tool_calls=call), runtime_b) - # Second call on thread A — triggers warning (2 >= warn_threshold) - result = mw._apply(_make_state(tool_calls=call), runtime_a) - assert result is not None - assert "LOOP DETECTED" in result["messages"][0].content + # Second call on thread A — queues warning under thread-A only. + mw._apply(_make_state(tool_calls=call), runtime_a) + assert mw._pending_warnings.get(_pending_key("thread-A")) + assert "LOOP DETECTED" in mw._pending_warnings[_pending_key("thread-A")][0] + assert not mw._pending_warnings.get(_pending_key("thread-B")) - # Second call on thread B — also triggers (independent tracking) - result = mw._apply(_make_state(tool_calls=call), runtime_b) - assert result is not None - assert "LOOP DETECTED" in result["messages"][0].content + # Second call on thread B — independent queue. + mw._apply(_make_state(tool_calls=call), runtime_b) + assert mw._pending_warnings.get(_pending_key("thread-B")) + assert "LOOP DETECTED" in mw._pending_warnings[_pending_key("thread-B")][0] def test_lru_eviction(self): """Old threads should be evicted when max_tracked_threads is exceeded.""" @@ -507,33 +655,29 @@ def test_freq_warn_at_threshold(self): for i in range(4): mw._apply(_make_state(tool_calls=[self._read_call(f"/file_{i}.py")]), runtime) - # 5th call to read_file (different file each time) triggers freq warning + # 5th call queues a per-tool-type frequency warning; state untouched. result = mw._apply(_make_state(tool_calls=[self._read_call("/file_4.py")]), runtime) - assert result is not None - msg = result["messages"][0] - # Warning is appended to the AIMessage content; tool_calls preserved - # so the tools node still runs and Moonshot/OpenAI tool-call pairing - # validation does not break. - assert isinstance(msg, AIMessage) - assert msg.tool_calls - assert "read_file" in msg.content - assert "LOOP DETECTED" in msg.content + assert result is None + queued = mw._pending_warnings.get(_pending_key(), []) + assert queued + assert "read_file" in queued[0] + assert "LOOP DETECTED" in queued[0] - def test_freq_warn_only_injected_once(self): + def test_freq_warn_only_queued_once(self): mw = LoopDetectionMiddleware(tool_freq_warn=3, tool_freq_hard_limit=10) runtime = _make_runtime() for i in range(2): mw._apply(_make_state(tool_calls=[self._read_call(f"/file_{i}.py")]), runtime) - # 3rd triggers warning - result = mw._apply(_make_state(tool_calls=[self._read_call("/file_2.py")]), runtime) - assert result is not None - assert "LOOP DETECTED" in result["messages"][0].content + # 3rd queues a frequency warning. + mw._apply(_make_state(tool_calls=[self._read_call("/file_2.py")]), runtime) + assert len(mw._pending_warnings[_pending_key()]) == 1 - # 4th should not re-warn (already warned for read_file) + # 4th: same tool name, no additional enqueue. result = mw._apply(_make_state(tool_calls=[self._read_call("/file_3.py")]), runtime) assert result is None + assert len(mw._pending_warnings[_pending_key()]) == 1 def test_freq_hard_stop_at_limit(self): mw = LoopDetectionMiddleware(tool_freq_warn=3, tool_freq_hard_limit=6) @@ -565,10 +709,10 @@ def test_different_tools_tracked_independently(self): result = mw._apply(_make_state(tool_calls=[_bash_call(f"cmd_{i}")]), runtime) assert result is None - # 3rd read_file triggers (read_file count = 3) + # 3rd read_file triggers — warning is queued (state unchanged). result = mw._apply(_make_state(tool_calls=[self._read_call("/file_2.py")]), runtime) - assert result is not None - assert "read_file" in result["messages"][0].content + assert result is None + assert "read_file" in mw._pending_warnings[_pending_key()][0] def test_freq_reset_clears_state(self): mw = LoopDetectionMiddleware(tool_freq_warn=3, tool_freq_hard_limit=10) @@ -600,10 +744,10 @@ def test_freq_reset_per_thread_clears_only_target(self): assert "thread-A" not in mw._tool_freq assert "thread-A" not in mw._tool_freq_warned - # thread-B state should still be intact — 3rd call triggers warn + # thread-B state should still be intact — 3rd call queues a warn. result = mw._apply(_make_state(tool_calls=[self._read_call("/b_2.py")]), runtime_b) - assert result is not None - assert "LOOP DETECTED" in result["messages"][0].content + assert result is None + assert "LOOP DETECTED" in mw._pending_warnings[_pending_key("thread-B")][0] # thread-A restarted from 0 — should not trigger result = mw._apply(_make_state(tool_calls=[self._read_call("/a_new.py")]), runtime_a) @@ -623,10 +767,11 @@ def test_freq_per_thread_isolation(self): for i in range(2): mw._apply(_make_state(tool_calls=[self._read_call(f"/other_{i}.py")]), runtime_b) - # 3rd call on thread A — triggers (count=3 for thread A only) + # 3rd call on thread A — queues a warning (count=3 for thread A only). result = mw._apply(_make_state(tool_calls=[self._read_call("/file_2.py")]), runtime_a) - assert result is not None - assert "LOOP DETECTED" in result["messages"][0].content + assert result is None + assert "LOOP DETECTED" in mw._pending_warnings[_pending_key("thread-A")][0] + assert not mw._pending_warnings.get(_pending_key("thread-B")) def test_multi_tool_single_response_counted(self): """When a single response has multiple tool calls, each is counted.""" @@ -643,10 +788,10 @@ def test_multi_tool_single_response_counted(self): result = mw._apply(_make_state(tool_calls=call), runtime) assert result is None - # Response 3: 1 more → count = 5 → triggers warn + # Response 3: 1 more → count = 5 → queues warn. result = mw._apply(_make_state(tool_calls=[self._read_call("/e.py")]), runtime) - assert result is not None - assert "read_file" in result["messages"][0].content + assert result is None + assert "read_file" in mw._pending_warnings[_pending_key()][0] def test_override_tool_uses_override_thresholds(self): """A tool in tool_freq_overrides uses its own thresholds, not the global ones.""" @@ -674,10 +819,14 @@ def test_non_override_tool_falls_back_to_global(self): for i in range(2): mw._apply(_make_state(tool_calls=[self._read_call(f"/file_{i}.py")]), runtime) - # 3rd read_file call hits global warn=3 (read_file has no override) + # 3rd read_file call hits global warn=3 (read_file has no override). + # Warning delivery is deferred to wrap_model_call so the just-emitted + # AIMessage(tool_calls=...) is not mutated before ToolMessages exist. result = mw._apply(_make_state(tool_calls=[self._read_call("/file_2.py")]), runtime) - assert result is not None - assert "read_file" in result["messages"][0].content + assert result is None + queued = mw._pending_warnings.get(_pending_key(), []) + assert queued + assert "read_file" in queued[0] def test_hash_detection_takes_priority(self): """Hash-based hard stop fires before frequency check for identical calls.""" @@ -736,11 +885,13 @@ def test_empty_overrides(self): mw = LoopDetectionMiddleware.from_config(self._config()) assert mw._tool_freq_overrides == {} - def test_constructed_middleware_detects_loops(self): + def test_constructed_middleware_queues_loop_warning(self): mw = LoopDetectionMiddleware.from_config(self._config(warn_threshold=2, hard_limit=4)) runtime = _make_runtime() call = [_bash_call("ls")] mw._apply(_make_state(tool_calls=call), runtime) result = mw._apply(_make_state(tool_calls=call), runtime) - assert result is not None - assert "LOOP DETECTED" in result["messages"][0].content + assert result is None + queued = mw._pending_warnings.get(_pending_key(), []) + assert queued + assert "LOOP DETECTED" in queued[0] diff --git a/frontend/src/content/en/harness/middlewares.mdx b/frontend/src/content/en/harness/middlewares.mdx index 72195664d4..389b9881cc 100644 --- a/frontend/src/content/en/harness/middlewares.mdx +++ b/frontend/src/content/en/harness/middlewares.mdx @@ -50,6 +50,8 @@ Intercepts clarification tool calls and converts them into proper user-facing re Detects when the agent is making the same tool call repeatedly without making progress. When a loop is detected, the middleware intervenes to break the cycle and prevents the agent from burning turns indefinitely. +Warning interventions are queued per thread and run, then drained on the next model call as a single hidden `HumanMessage(name="loop_warning")` appended after existing tool results. This keeps provider tool-call pairing valid. Run start/end hooks clear stale or undelivered warnings, and hard stops still strip tool calls before forcing a final text response. + **Configuration**: built-in, no user configuration. --- diff --git a/frontend/src/content/zh/harness/middlewares.mdx b/frontend/src/content/zh/harness/middlewares.mdx index 051729b947..9e81caa3e8 100644 --- a/frontend/src/content/zh/harness/middlewares.mdx +++ b/frontend/src/content/zh/harness/middlewares.mdx @@ -50,6 +50,8 @@ import { Callout } from "nextra/components"; 检测 Agent 是否在没有取得进展的情况下重复进行相同的工具调用。检测到循环时,中间件会介入打破循环,防止 Agent 无限消耗轮次。 +Warning 介入会按 thread 和 run 排队,并在下一次模型调用时合并为一条隐藏的 `HumanMessage(name="loop_warning")`,追加到已有工具结果之后。这样不会破坏 provider 对 tool-call/tool-message 配对的校验。Run 开始和结束时会清理过期或未送达的 warning;达到 hard stop 时仍会清空 tool calls 并强制生成最终文本回复。 + **配置**:内置,无需用户配置。 ---