diff --git a/backend/app/gateway/services.py b/backend/app/gateway/services.py index 634b8b9d14..0b25d7b6eb 100644 --- a/backend/app/gateway/services.py +++ b/backend/app/gateway/services.py @@ -19,6 +19,7 @@ from app.gateway.deps import get_run_context, get_run_manager, get_stream_bridge from app.gateway.utils import sanitize_log_param +from deerflow.config.app_config import get_app_config from deerflow.runtime import ( END_SENTINEL, HEARTBEAT_SENTINEL, @@ -249,6 +250,23 @@ async def start_run( disconnect = DisconnectMode.cancel if body.on_disconnect == "cancel" else DisconnectMode.continue_ + body_context = getattr(body, "context", None) or {} + model_name = body_context.get("model_name") + + # Coerce non-string model_name values to str before truncation. + if model_name is not None and not isinstance(model_name, str): + model_name = str(model_name) + + # Validate model against the allowlist when a model_name is provided. + if model_name: + app_config = get_app_config() + resolved = app_config.get_model_config(model_name) + if resolved is None: + raise HTTPException( + status_code=400, + detail=f"Model {model_name!r} is not in the configured model allowlist", + ) + try: record = await run_mgr.create_or_reject( thread_id, @@ -257,6 +275,7 @@ async def start_run( metadata=body.metadata or {}, kwargs={"input": body.input, "config": body.config}, multitask_strategy=body.multitask_strategy, + model_name=model_name, ) except ConflictError as exc: raise HTTPException(status_code=409, detail=str(exc)) from exc diff --git a/backend/packages/harness/deerflow/persistence/run/sql.py b/backend/packages/harness/deerflow/persistence/run/sql.py index fcd1a34115..430fbe4f62 100644 --- a/backend/packages/harness/deerflow/persistence/run/sql.py +++ b/backend/packages/harness/deerflow/persistence/run/sql.py @@ -23,6 +23,18 @@ class RunRepository(RunStore): def __init__(self, session_factory: async_sessionmaker[AsyncSession]) -> None: self._sf = session_factory + @staticmethod + def _normalize_model_name(model_name: str | None) -> str | None: + """Normalize model_name for storage: strip whitespace, truncate to 128 chars.""" + if model_name is None: + return None + if not isinstance(model_name, str): + model_name = str(model_name) + normalized = model_name.strip() + if len(normalized) > 128: + normalized = normalized[:128] + return normalized + @staticmethod def _safe_json(obj: Any) -> Any: """Ensure obj is JSON-serializable. Falls back to model_dump() or str().""" @@ -70,6 +82,7 @@ async def put( thread_id, assistant_id=None, user_id: str | None | _AutoSentinel = AUTO, + model_name: str | None = None, status="pending", multitask_strategy="reject", metadata=None, @@ -85,6 +98,7 @@ async def put( thread_id=thread_id, assistant_id=assistant_id, user_id=resolved_user_id, + model_name=self._normalize_model_name(model_name), status=status, multitask_strategy=multitask_strategy, metadata_json=self._safe_json(metadata) or {}, diff --git a/backend/packages/harness/deerflow/runtime/runs/manager.py b/backend/packages/harness/deerflow/runtime/runs/manager.py index 533342c872..50dc594abf 100644 --- a/backend/packages/harness/deerflow/runtime/runs/manager.py +++ b/backend/packages/harness/deerflow/runtime/runs/manager.py @@ -36,6 +36,7 @@ class RunRecord: abort_event: asyncio.Event = field(default_factory=asyncio.Event, repr=False) abort_action: str = "interrupt" error: str | None = None + model_name: str | None = None class RunManager: @@ -65,6 +66,7 @@ async def _persist_to_store(self, record: RunRecord) -> None: metadata=record.metadata or {}, kwargs=record.kwargs or {}, created_at=record.created_at, + model_name=record.model_name, ) except Exception: logger.warning("Failed to persist run %s to store", record.run_id, exc_info=True) @@ -137,6 +139,18 @@ async def set_status(self, run_id: str, status: RunStatus, *, error: str | None logger.warning("Failed to persist status update for run %s", run_id, exc_info=True) logger.info("Run %s -> %s", run_id, status.value) + async def update_model_name(self, run_id: str, model_name: str | None) -> None: + """Update the model name for a run.""" + async with self._lock: + record = self._runs.get(run_id) + if record is None: + logger.warning("update_model_name called for unknown run %s", run_id) + return + record.model_name = model_name + record.updated_at = _now_iso() + await self._persist_to_store(record) + logger.info("Run %s model_name=%s", run_id, model_name) + async def cancel(self, run_id: str, *, action: str = "interrupt") -> bool: """Request cancellation of a run. @@ -171,6 +185,7 @@ async def create_or_reject( metadata: dict | None = None, kwargs: dict | None = None, multitask_strategy: str = "reject", + model_name: str | None = None, ) -> RunRecord: """Atomically check for inflight runs and create a new one. @@ -221,6 +236,7 @@ async def create_or_reject( kwargs=kwargs or {}, created_at=now, updated_at=now, + model_name=model_name, ) self._runs[run_id] = record diff --git a/backend/packages/harness/deerflow/runtime/runs/store/base.py b/backend/packages/harness/deerflow/runtime/runs/store/base.py index 518a1903c3..d3c10eba6e 100644 --- a/backend/packages/harness/deerflow/runtime/runs/store/base.py +++ b/backend/packages/harness/deerflow/runtime/runs/store/base.py @@ -23,6 +23,7 @@ async def put( thread_id: str, assistant_id: str | None = None, user_id: str | None = None, + model_name: str | None = None, status: str = "pending", multitask_strategy: str = "reject", metadata: dict[str, Any] | None = None, diff --git a/backend/packages/harness/deerflow/runtime/runs/store/memory.py b/backend/packages/harness/deerflow/runtime/runs/store/memory.py index 5a14af3dff..e41147e3ea 100644 --- a/backend/packages/harness/deerflow/runtime/runs/store/memory.py +++ b/backend/packages/harness/deerflow/runtime/runs/store/memory.py @@ -22,6 +22,7 @@ async def put( thread_id, assistant_id=None, user_id=None, + model_name=None, status="pending", multitask_strategy="reject", metadata=None, @@ -35,6 +36,7 @@ async def put( "thread_id": thread_id, "assistant_id": assistant_id, "user_id": user_id, + "model_name": model_name, "status": status, "multitask_strategy": multitask_strategy, "metadata": metadata or {}, diff --git a/backend/packages/harness/deerflow/runtime/runs/worker.py b/backend/packages/harness/deerflow/runtime/runs/worker.py index 2aecb9a1b3..f78d425a25 100644 --- a/backend/packages/harness/deerflow/runtime/runs/worker.py +++ b/backend/packages/harness/deerflow/runtime/runs/worker.py @@ -230,6 +230,17 @@ async def run_agent( else: agent = agent_factory(config=runnable_config) + # Capture the effective (resolved) model name from the agent's metadata. + # _resolve_model_name in agent.py may return the default model if the + # requested name is not in the allowlist — this update ensures the + # persisted model_name reflects the actual model used. + if record.model_name is not None: + resolved = getattr(agent, "metadata", {}) or {} + if isinstance(resolved, dict): + effective = resolved.get("model_name") + if effective and effective != record.model_name: + await run_manager.update_model_name(record.run_id, effective) + # 4. Attach checkpointer and store if checkpointer is not None: agent.checkpointer = checkpointer diff --git a/backend/tests/test_run_manager.py b/backend/tests/test_run_manager.py index 58ecf1f26c..98cd582640 100644 --- a/backend/tests/test_run_manager.py +++ b/backend/tests/test_run_manager.py @@ -5,6 +5,7 @@ import pytest from deerflow.runtime import RunManager, RunStatus +from deerflow.runtime.runs.store.memory import MemoryRunStore ISO_RE = re.compile(r"^\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}") @@ -141,3 +142,53 @@ async def test_create_defaults(manager: RunManager): assert record.kwargs == {} assert record.multitask_strategy == "reject" assert record.assistant_id is None + + +@pytest.mark.anyio +async def test_model_name_create_or_reject(): + """create_or_reject should accept and persist model_name.""" + from deerflow.runtime.runs.schemas import DisconnectMode + + store = MemoryRunStore() + mgr = RunManager(store=store) + + record = await mgr.create_or_reject( + "thread-1", + assistant_id="lead_agent", + on_disconnect=DisconnectMode.cancel, + metadata={"key": "val"}, + kwargs={"input": {}}, + multitask_strategy="reject", + model_name="anthropic.claude-sonnet-4-20250514-v1:0", + ) + assert record.model_name == "anthropic.claude-sonnet-4-20250514-v1:0" + assert record.status == RunStatus.pending + + # Verify model_name was persisted to store + stored = await store.get(record.run_id) + assert stored is not None + assert stored["model_name"] == "anthropic.claude-sonnet-4-20250514-v1:0" + + # Verify retrieval returns the model_name via in-memory record + fetched = mgr.get(record.run_id) + assert fetched is not None + assert fetched.model_name == "anthropic.claude-sonnet-4-20250514-v1:0" + + +@pytest.mark.anyio +async def test_model_name_default_is_none(): + """create_or_reject without model_name should default to None.""" + from deerflow.runtime.runs.schemas import DisconnectMode + + store = MemoryRunStore() + mgr = RunManager(store=store) + + record = await mgr.create_or_reject( + "thread-1", + on_disconnect=DisconnectMode.cancel, + model_name=None, + ) + assert record.model_name is None + + stored = await store.get(record.run_id) + assert stored["model_name"] is None diff --git a/backend/tests/test_run_repository.py b/backend/tests/test_run_repository.py index 34ab9b492f..6b52aca811 100644 --- a/backend/tests/test_run_repository.py +++ b/backend/tests/test_run_repository.py @@ -194,3 +194,32 @@ async def test_owner_none_returns_all(self, tmp_path): rows = await repo.list_by_thread("t1", user_id=None) assert len(rows) == 2 await _cleanup() + + @pytest.mark.anyio + async def test_model_name_persistence(self, tmp_path): + """RunRepository should persist, normalize, and truncate model_name correctly via SQL.""" + from deerflow.persistence.engine import get_session_factory, init_engine + + url = f"sqlite+aiosqlite:///{tmp_path / 'test.db'}" + await init_engine("sqlite", url=url, sqlite_dir=str(tmp_path)) + repo = RunRepository(get_session_factory()) + + await repo.put("run-1", thread_id="thread-1", model_name="gpt-4o") + row = await repo.get("run-1") + assert row is not None + assert row["model_name"] == "gpt-4o" + + long_name = "a" * 200 + await repo.put("run-2", thread_id="thread-1", model_name=long_name) + row2 = await repo.get("run-2") + assert row2["model_name"] == "a" * 128 + + await repo.put("run-3", thread_id="thread-1", model_name=123) + row3 = await repo.get("run-3") + assert row3["model_name"] == "123" + + await repo.put("run-4", thread_id="thread-1", model_name=None) + row4 = await repo.get("run-4") + assert row4["model_name"] is None + + await _cleanup()