Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions backend/app/gateway/services.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

from app.gateway.deps import get_run_context, get_run_manager, get_stream_bridge
from app.gateway.utils import sanitize_log_param
from deerflow.config.app_config import get_app_config
from deerflow.runtime import (
END_SENTINEL,
HEARTBEAT_SENTINEL,
Expand Down Expand Up @@ -249,6 +250,23 @@ async def start_run(

disconnect = DisconnectMode.cancel if body.on_disconnect == "cancel" else DisconnectMode.continue_

body_context = getattr(body, "context", None) or {}
model_name = body_context.get("model_name")

Comment on lines +253 to +255
# Truncate to 128 chars to match DB column constraint (model.py:23).
if model_name and len(model_name) > 128:
model_name = model_name[:128]

# Validate model against the allowlist when a model_name is provided.
if model_name:
app_config = get_app_config()
resolved = app_config.get_model_config(model_name)
if resolved is None:
raise HTTPException(
status_code=400,
detail=f"Model {model_name!r} is not in the configured model allowlist",
)

try:
record = await run_mgr.create_or_reject(
thread_id,
Expand All @@ -257,6 +275,7 @@ async def start_run(
metadata=body.metadata or {},
kwargs={"input": body.input, "config": body.config},
multitask_strategy=body.multitask_strategy,
model_name=model_name,
)
except ConflictError as exc:
raise HTTPException(status_code=409, detail=str(exc)) from exc
Expand Down
14 changes: 14 additions & 0 deletions backend/packages/harness/deerflow/persistence/run/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,18 @@ class RunRepository(RunStore):
def __init__(self, session_factory: async_sessionmaker[AsyncSession]) -> None:
self._sf = session_factory

@staticmethod
def _normalize_model_name(model_name: str | None) -> str | None:
"""Normalize model_name for storage: strip whitespace, truncate to 128 chars."""
if model_name is None:
return None
if not isinstance(model_name, str):
model_name = str(model_name)
normalized = model_name.strip()
if len(normalized) > 128:
normalized = normalized[:128]
return normalized

@staticmethod
def _safe_json(obj: Any) -> Any:
"""Ensure obj is JSON-serializable. Falls back to model_dump() or str()."""
Expand Down Expand Up @@ -70,6 +82,7 @@ async def put(
thread_id,
assistant_id=None,
user_id: str | None | _AutoSentinel = AUTO,
model_name: str | None = None,
status="pending",
multitask_strategy="reject",
metadata=None,
Expand All @@ -85,6 +98,7 @@ async def put(
thread_id=thread_id,
assistant_id=assistant_id,
user_id=resolved_user_id,
model_name=self._normalize_model_name(model_name),
status=status,
multitask_strategy=multitask_strategy,
metadata_json=self._safe_json(metadata) or {},
Expand Down
4 changes: 4 additions & 0 deletions backend/packages/harness/deerflow/runtime/runs/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ class RunRecord:
abort_event: asyncio.Event = field(default_factory=asyncio.Event, repr=False)
abort_action: str = "interrupt"
error: str | None = None
model_name: str | None = None


class RunManager:
Expand Down Expand Up @@ -65,6 +66,7 @@ async def _persist_to_store(self, record: RunRecord) -> None:
metadata=record.metadata or {},
kwargs=record.kwargs or {},
created_at=record.created_at,
model_name=record.model_name,
)
except Exception:
logger.warning("Failed to persist run %s to store", record.run_id, exc_info=True)
Expand Down Expand Up @@ -171,6 +173,7 @@ async def create_or_reject(
metadata: dict | None = None,
kwargs: dict | None = None,
multitask_strategy: str = "reject",
model_name: str | None = None,
) -> RunRecord:
"""Atomically check for inflight runs and create a new one.

Expand Down Expand Up @@ -221,6 +224,7 @@ async def create_or_reject(
kwargs=kwargs or {},
created_at=now,
updated_at=now,
model_name=model_name,
)
self._runs[run_id] = record

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ async def put(
thread_id: str,
assistant_id: str | None = None,
user_id: str | None = None,
model_name: str | None = None,
status: str = "pending",
multitask_strategy: str = "reject",
metadata: dict[str, Any] | None = None,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ async def put(
thread_id,
assistant_id=None,
user_id=None,
model_name=None,
status="pending",
multitask_strategy="reject",
metadata=None,
Expand All @@ -35,6 +36,7 @@ async def put(
"thread_id": thread_id,
"assistant_id": assistant_id,
"user_id": user_id,
"model_name": model_name,
"status": status,
"multitask_strategy": multitask_strategy,
"metadata": metadata or {},
Expand Down
12 changes: 12 additions & 0 deletions backend/packages/harness/deerflow/runtime/runs/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,18 @@ async def run_agent(
else:
agent = agent_factory(config=runnable_config)

# Capture the effective (resolved) model name from the agent's metadata.
# _resolve_model_name in agent.py may return the default model if the
# requested name is not in the allowlist — this update ensures the
# persisted model_name reflects the actual model used.
if record.model_name is not None:
resolved = getattr(agent, "metadata", {}) or {}
if isinstance(resolved, dict):
effective = resolved.get("model_name")
if effective and effective != record.model_name:
record.model_name = effective
await run_manager._persist_to_store(record)

# 4. Attach checkpointer and store
if checkpointer is not None:
agent.checkpointer = checkpointer
Expand Down
51 changes: 51 additions & 0 deletions backend/tests/test_run_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import pytest

from deerflow.runtime import RunManager, RunStatus
from deerflow.runtime.runs.store.memory import MemoryRunStore

ISO_RE = re.compile(r"^\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}")

Expand Down Expand Up @@ -141,3 +142,53 @@ async def test_create_defaults(manager: RunManager):
assert record.kwargs == {}
assert record.multitask_strategy == "reject"
assert record.assistant_id is None


@pytest.mark.anyio
async def test_model_name_create_or_reject():
    """A model_name given to create_or_reject shows up on the record and in the store."""
    from deerflow.runtime.runs.schemas import DisconnectMode

    expected_model = "anthropic.claude-sonnet-4-20250514-v1:0"
    run_store = MemoryRunStore()
    run_manager = RunManager(store=run_store)

    created = await run_manager.create_or_reject(
        "thread-1",
        assistant_id="lead_agent",
        on_disconnect=DisconnectMode.cancel,
        metadata={"key": "val"},
        kwargs={"input": {}},
        multitask_strategy="reject",
        model_name=expected_model,
    )
    assert created.model_name == expected_model
    assert created.status == RunStatus.pending

    # The store entry for this run must carry the same model name.
    persisted = await run_store.get(created.run_id)
    assert persisted is not None
    assert persisted["model_name"] == expected_model

    # Looking the run up through the manager must expose it as well.
    in_memory = run_manager.get(created.run_id)
    assert in_memory is not None
    assert in_memory.model_name == expected_model


@pytest.mark.anyio
async def test_model_name_default_is_none():
    """When model_name is None, both the record and its stored entry hold None."""
    from deerflow.runtime.runs.schemas import DisconnectMode

    run_store = MemoryRunStore()
    run_manager = RunManager(store=run_store)

    created = await run_manager.create_or_reject(
        "thread-1",
        on_disconnect=DisconnectMode.cancel,
        model_name=None,
    )
    assert created.model_name is None

    # The stored entry must record the absence of a model name too.
    persisted = await run_store.get(created.run_id)
    assert persisted["model_name"] is None
Loading