Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 15 additions & 17 deletions backend/packages/harness/deerflow/agents/lead_agent/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,10 +50,9 @@ def _resolve_model_name(requested_model_name: str | None = None, *, app_config:
return default_model_name


def _create_summarization_middleware(*, app_config: AppConfig | None = None) -> DeerFlowSummarizationMiddleware | None:
def _create_summarization_middleware(*, app_config: AppConfig) -> DeerFlowSummarizationMiddleware | None:
"""Create and configure the summarization middleware from config."""
resolved_app_config = app_config or get_app_config()
config = resolved_app_config.summarization
config = app_config.summarization

if not config.enabled:
return None
Expand All @@ -74,9 +73,9 @@ def _create_summarization_middleware(*, app_config: AppConfig | None = None) ->
# as middleware rather than lead_agent (SummarizationMiddleware is a
# LangChain built-in, so we tag the model at creation time).
if config.model_name:
model = create_chat_model(name=config.model_name, thinking_enabled=False, app_config=resolved_app_config)
model = create_chat_model(name=config.model_name, thinking_enabled=False, app_config=app_config)
else:
model = create_chat_model(thinking_enabled=False, app_config=resolved_app_config)
model = create_chat_model(thinking_enabled=False, app_config=app_config)
model = model.with_config(tags=["middleware:summarize"])

# Prepare kwargs
Expand All @@ -93,13 +92,13 @@ def _create_summarization_middleware(*, app_config: AppConfig | None = None) ->
kwargs["summary_prompt"] = config.summary_prompt

hooks: list[BeforeSummarizationHook] = []
if resolved_app_config.memory.enabled:
if app_config.memory.enabled:
hooks.append(memory_flush_hook)

# The logic below relies on two assumptions holding true: this factory is
# the sole entry point for DeerFlowSummarizationMiddleware, and the runtime
# config is not expected to change after startup.
skills_container_path = resolved_app_config.skills.container_path or "/mnt/skills"
skills_container_path = app_config.skills.container_path or "/mnt/skills"

return DeerFlowSummarizationMiddleware(
**kwargs,
Expand Down Expand Up @@ -243,7 +242,7 @@ def _build_middlewares(
agent_name: str | None = None,
custom_middlewares: list[AgentMiddleware] | None = None,
*,
app_config: AppConfig | None = None,
app_config: AppConfig,
):
"""Build middleware chain based on runtime configuration.

Expand All @@ -255,11 +254,10 @@ def _build_middlewares(
Returns:
List of middleware instances.
"""
resolved_app_config = app_config or get_app_config()
middlewares = build_lead_runtime_middlewares(app_config=resolved_app_config, lazy_init=True)
middlewares = build_lead_runtime_middlewares(app_config=app_config, lazy_init=True)

# Add summarization middleware if enabled
summarization_middleware = _create_summarization_middleware(app_config=resolved_app_config)
summarization_middleware = _create_summarization_middleware(app_config=app_config)
if summarization_middleware is not None:
middlewares.append(summarization_middleware)

Expand All @@ -271,23 +269,23 @@ def _build_middlewares(
middlewares.append(todo_list_middleware)

# Add TokenUsageMiddleware when token_usage tracking is enabled
if resolved_app_config.token_usage.enabled:
if app_config.token_usage.enabled:
middlewares.append(TokenUsageMiddleware())

# Add TitleMiddleware
middlewares.append(TitleMiddleware(app_config=resolved_app_config))
middlewares.append(TitleMiddleware(app_config=app_config))

# Add MemoryMiddleware (after TitleMiddleware)
middlewares.append(MemoryMiddleware(agent_name=agent_name, memory_config=resolved_app_config.memory))
middlewares.append(MemoryMiddleware(agent_name=agent_name, memory_config=app_config.memory))

# Add ViewImageMiddleware only if the current model supports vision.
# Use the resolved runtime model_name from make_lead_agent to avoid stale config values.
model_config = resolved_app_config.get_model_config(model_name) if model_name else None
model_config = app_config.get_model_config(model_name) if model_name else None
if model_config is not None and model_config.supports_vision:
middlewares.append(ViewImageMiddleware())

# Add DeferredToolFilterMiddleware to hide deferred tool schemas from model binding
if resolved_app_config.tool_search.enabled:
if app_config.tool_search.enabled:
from deerflow.agents.middlewares.deferred_tool_filter_middleware import DeferredToolFilterMiddleware

middlewares.append(DeferredToolFilterMiddleware())
Expand All @@ -299,7 +297,7 @@ def _build_middlewares(
middlewares.append(SubagentLimitMiddleware(max_concurrent=max_concurrent_subagents))

# LoopDetectionMiddleware — detect and break repetitive tool call loops
loop_detection_config = resolved_app_config.loop_detection
loop_detection_config = app_config.loop_detection
if loop_detection_config.enabled:
middlewares.append(LoopDetectionMiddleware.from_config(loop_detection_config))

Expand Down
8 changes: 7 additions & 1 deletion backend/packages/harness/deerflow/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,13 @@ def _ensure_agent(self, config: RunnableConfig):
kwargs: dict[str, Any] = {
"model": create_chat_model(name=model_name, thinking_enabled=thinking_enabled),
"tools": self._get_tools(model_name=model_name, subagent_enabled=subagent_enabled),
"middleware": _build_middlewares(config, model_name=model_name, agent_name=self._agent_name, custom_middlewares=self._middlewares),
"middleware": _build_middlewares(
config,
model_name=model_name,
agent_name=self._agent_name,
custom_middlewares=self._middlewares,
app_config=self._app_config,
),
"system_prompt": apply_prompt_template(
subagent_enabled=subagent_enabled,
max_concurrent_subagents=max_concurrent_subagents,
Expand Down
19 changes: 14 additions & 5 deletions backend/packages/harness/deerflow/runtime/runs/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import logging
from dataclasses import dataclass, field
from functools import lru_cache
from typing import TYPE_CHECKING, Any, Literal, cast
from typing import TYPE_CHECKING, Any, Literal, NotRequired, Required, TypedDict, cast

from langgraph.checkpoint.base import empty_checkpoint

Expand All @@ -41,12 +41,21 @@
_VALID_LG_MODES = {"values", "updates", "checkpoints", "tasks", "debug", "messages", "custom"}


class DeerFlowRuntimeContext(TypedDict, total=False):
    """Typed shape of the runtime context dict passed to ``ToolRuntime.context``.

    Declared with ``total=False`` so that only the keys explicitly marked
    ``Required`` must be present; ``app_config`` and ``agent_name`` are
    opt-in extras merged from the caller's context.
    """

    # Always populated by _build_runtime_context for every run.
    thread_id: Required[str]
    run_id: Required[str]
    # Present only when the caller's context (or the worker) supplies them.
    app_config: NotRequired[AppConfig]
    agent_name: NotRequired[str]


def _build_runtime_context(
thread_id: str,
run_id: str,
caller_context: Any | None,
app_config: AppConfig | None = None,
) -> dict[str, Any]:
) -> DeerFlowRuntimeContext:
"""Build the dict that becomes ``ToolRuntime.context`` for the run.

Always includes ``thread_id`` and ``run_id``. Additional keys from the caller's
Expand All @@ -59,7 +68,7 @@ def _build_runtime_context(
under ``config['configurable']['__pregel_runtime']`` — see
``langgraph.pregel.main`` where ``parent_runtime.merge(...)`` is invoked.
"""
runtime_ctx: dict[str, Any] = {"thread_id": thread_id, "run_id": run_id}
runtime_ctx: DeerFlowRuntimeContext = {"thread_id": thread_id, "run_id": run_id}
if isinstance(caller_context, dict):
for key, value in caller_context.items():
runtime_ctx.setdefault(key, value)
Expand All @@ -85,7 +94,7 @@ class RunContext:
app_config: AppConfig | None = field(default=None)


def _install_runtime_context(config: dict, runtime_context: dict[str, Any]) -> None:
def _install_runtime_context(config: dict, runtime_context: DeerFlowRuntimeContext) -> None:
existing_context = config.get("context")
if isinstance(existing_context, dict):
existing_context.setdefault("thread_id", runtime_context["thread_id"])
Expand Down Expand Up @@ -216,7 +225,7 @@ async def run_agent(
# without passing the official ``context=`` parameter.
runtime_ctx = _build_runtime_context(thread_id, run_id, config.get("context"), ctx.app_config)
_install_runtime_context(config, runtime_ctx)
runtime = Runtime(context=cast(Any, runtime_ctx), store=store)
runtime = Runtime(context=cast(Any, runtime_ctx), store=store) # TODO(#2687): cast retained because Runtime.context expects Any and TypedDict is not assignable without it
config.setdefault("configurable", {})["__pregel_runtime"] = runtime

# Inject RunJournal as a LangChain callback handler.
Expand Down
8 changes: 5 additions & 3 deletions backend/packages/harness/deerflow/tools/builtins/task_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import logging
import uuid
from dataclasses import replace
from typing import TYPE_CHECKING, Annotated, Any, cast
from typing import TYPE_CHECKING, Annotated, Any

from langchain.tools import InjectedToolCallId, ToolRuntime, tool
from langgraph.config import get_stream_writer
Expand All @@ -24,16 +24,18 @@

if TYPE_CHECKING:
from deerflow.config.app_config import AppConfig
from deerflow.runtime.runs.worker import DeerFlowRuntimeContext

logger = logging.getLogger(__name__)


def _get_runtime_app_config(runtime: Any) -> "AppConfig | None":
context = getattr(runtime, "context", None)
if isinstance(context, dict):
app_config = context.get("app_config")
typed_context: DeerFlowRuntimeContext = context # pyright: ignore[reportAssignmentType]
app_config = typed_context.get("app_config")
if app_config is not None:
return cast("AppConfig", app_config)
return app_config
return None


Expand Down
1 change: 1 addition & 0 deletions backend/tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -923,6 +923,7 @@ def test_creates_agent(self, client):
# Verify agent_name propagation
mock_build_middlewares.assert_called_once()
assert mock_build_middlewares.call_args.kwargs.get("agent_name") == "custom-agent"
assert mock_build_middlewares.call_args.kwargs.get("app_config") is client._app_config
mock_apply_prompt.assert_called_once()
assert mock_apply_prompt.call_args.kwargs.get("agent_name") == "custom-agent"
assert mock_apply_prompt.call_args.kwargs.get("available_skills") == {"test_skill"}
Expand Down
24 changes: 17 additions & 7 deletions backend/tests/test_lead_agent_model_resolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,13 @@ def _fake_create_chat_model(*, name, thinking_enabled, reasoning_effort=None, ap
assert result["model"] is not None


def test_create_summarization_middleware_requires_explicit_app_config():
    """Calling the factory without the keyword-only app_config must raise TypeError."""
    factory = getattr(lead_agent_module, "_create_summarization_middleware")
    with pytest.raises(TypeError):
        factory()


def test_make_lead_agent_uses_runtime_app_config_from_context_without_global_read(monkeypatch):
app_config = _make_app_config([_make_model("context-model", supports_thinking=False)])

Expand Down Expand Up @@ -430,10 +437,10 @@ def _raise_get_app_config():
fake_model.with_config.assert_called_once_with(tags=["middleware:summarize"])


def test_create_summarization_middleware_threads_resolved_app_config_to_model(monkeypatch):
fallback_app_config = _make_app_config([_make_model("fallback-model", supports_thinking=False)])
fallback_app_config.summarization = SummarizationConfig(enabled=True, model_name="fallback-model")
fallback_app_config.memory = MemoryConfig(enabled=False)
def test_create_summarization_middleware_threads_explicit_app_config_to_model(monkeypatch):
app_config = _make_app_config([_make_model("explicit-model", supports_thinking=False)])
app_config.summarization = SummarizationConfig(enabled=True, model_name="explicit-model")
app_config.memory = MemoryConfig(enabled=False)

from unittest.mock import MagicMock

Expand All @@ -445,13 +452,16 @@ def _fake_create_chat_model(*, name=None, thinking_enabled, reasoning_effort=Non
captured["app_config"] = app_config
return fake_model

monkeypatch.setattr(lead_agent_module, "get_app_config", lambda: fallback_app_config)
def _raise_get_app_config():
raise AssertionError("ambient get_app_config() must not be used by summarization middleware")

monkeypatch.setattr(lead_agent_module, "get_app_config", _raise_get_app_config)
monkeypatch.setattr(lead_agent_module, "create_chat_model", _fake_create_chat_model)
monkeypatch.setattr(lead_agent_module, "DeerFlowSummarizationMiddleware", lambda **kwargs: kwargs)

lead_agent_module._create_summarization_middleware()
lead_agent_module._create_summarization_middleware(app_config=app_config)

assert captured["app_config"] is fallback_app_config
assert captured["app_config"] is app_config


def test_memory_middleware_uses_explicit_memory_config_without_global_read(monkeypatch):
Expand Down