Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 14 additions & 7 deletions app/dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,19 @@

def get_llm() -> BaseLanguageModel:
"""
FastAPI dependency that provides the singleton LLM instance.

This function is used as a dependency injection in FastAPI endpoints
to access the configured language model.

FastAPI dependency that provides the default LLM instance.

Returns:
The default language model instance.
"""
return LLMManager.get_instance()


def get_uitools_llm() -> BaseLanguageModel:
"""
FastAPI dependency that provides the UI tools LLM instance.

Returns:
The singleton language model instance.
The UI tools language model instance (falls back to default if not configured).
"""
return LLMManager.get_instance()
return LLMManager.get_instance_for_role("uitools")
4 changes: 2 additions & 2 deletions app/routers/configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -396,8 +396,8 @@ async def update_settings(settings: SettingsUpdate, request: Request):
secret.data = secret_data
v1.patch_namespaced_secret(SETTINGS_SECRET_NAME, AGENT_NAMESPACE, secret)

# Reset LLMManager singleton to force reinitialization for consistency, but as of now we are redeploying the agent...
LLMManager._instance = None
# Reset LLMManager to force reinitialization for consistency, but as of now we are redeploying the agent...
LLMManager.reset()

# Re-fetch both resources to return the updated content
updated_cm = v1.read_namespaced_config_map(SETTINGS_CONFIGMAP_NAME, AGENT_NAMESPACE)
Expand Down
4 changes: 2 additions & 2 deletions app/routers/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from langchain_core.language_models.llms import BaseLanguageModel

from ..services.auth import get_user_id_from_request
from ..dependencies import get_llm
from ..dependencies import get_uitools_llm
from ..services.ui_tools.selector import create_ui_tools_selector

router = APIRouter(prefix="/v1/api/complete", tags=["llm"])
Expand All @@ -31,7 +31,7 @@ class UIToolsRequest(LLMRequest):
async def complete_ui_tools(
request: Request,
ui_tools_request: UIToolsRequest,
llm: BaseLanguageModel = Depends(get_llm)
llm: BaseLanguageModel = Depends(get_uitools_llm)
) -> JSONResponse:
"""
Select appropriate UI tools based on the provided context using the LLM.
Expand Down
8 changes: 3 additions & 5 deletions app/routers/websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,14 @@
import json
from datetime import datetime

from ..dependencies import get_llm
from ..services.agent.factory import NoAgentAvailableError, build_agent
from ..services.agent.supervisor import SupervisorGraph
from dataclasses import dataclass
from fastapi import APIRouter
from fastapi import WebSocket, WebSocketDisconnect, Depends
from fastapi import WebSocket, WebSocketDisconnect
from starlette.websockets import WebSocketState
from langgraph.graph.state import CompiledStateGraph
from langfuse.langchain import CallbackHandler
from langchain_core.language_models.llms import BaseLanguageModel
from langchain_core.messages import HumanMessage
from langgraph.types import Command

Expand Down Expand Up @@ -59,7 +57,7 @@ class WebSocketRequest:

@router.websocket("/v1/ws/messages")
@router.websocket("/v1/ws/messages/{thread_id}")
async def websocket_endpoint(websocket: WebSocket, thread_id: str = None, llm: BaseLanguageModel = Depends(get_llm)):
async def websocket_endpoint(websocket: WebSocket, thread_id: str = None):
"""
WebSocket endpoint for the agent.

Expand All @@ -77,7 +75,7 @@ async def websocket_endpoint(websocket: WebSocket, thread_id: str = None, llm: B
logging.debug("ws/messages connection opened")

try:
agent, agents_metadata = await build_agent(llm=llm, websocket=websocket)
agent, agents_metadata = await build_agent(websocket=websocket)
except NoAgentAvailableError as e:
logging.error(f"Error creating agent: {e}")
await websocket.send_text(f'<chat-error>{json.dumps({"message": str(e)})}</chat-error>')
Expand Down
6 changes: 4 additions & 2 deletions app/services/agent/child.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ def create_child_agent(
system_prompt: str,
checkpointer: Checkpointer,
agent_config: AgentConfig,
summary_llm: BaseChatModel | None = None,
uitools_llm: BaseChatModel | None = None,
) -> CompiledStateGraph:
"""Create and compile a child agent graph using langchain create_agent with middleware.

Expand All @@ -57,8 +59,8 @@ def create_child_agent(
identity_preamble_middleware(),
cancel_human_validation_middleware(),
inject_additional_kwargs_middleware(),
ui_tools_middleware(llm, only_when_direct=True),
SummarizationMiddleware(model=llm, trigger=[("messages", 30), ("tokens", 30000)], keep=("messages", 15)),
ui_tools_middleware(uitools_llm or llm, only_when_direct=True),
SummarizationMiddleware(model=summary_llm or llm, trigger=[("messages", 30), ("tokens", 30000)], keep=("messages", 15)),
]

return create_agent(
Expand Down
28 changes: 20 additions & 8 deletions app/services/agent/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from langchain_core.language_models.llms import BaseLanguageModel
from langchain_mcp_adapters.client import MultiServerMCPClient
from langgraph.graph.state import Checkpointer, CompiledStateGraph
from ..llm import LLMManager

NAMESPACE = "cattle-ai-agent-system"

Expand All @@ -20,16 +21,16 @@ class NoAgentAvailableError(Exception):
pass


async def build_agent(llm: BaseLanguageModel, websocket: WebSocket) -> tuple[CompiledStateGraph | SupervisorGraph, list[dict]]:
async def build_agent(websocket: WebSocket) -> tuple[CompiledStateGraph | SupervisorGraph, list[dict]]:
"""
Build an agent graph from AIAgentConfig CRDs.

Loads agent configurations and creates either a supervisor (multi-agent) or a
single child agent depending on how many configs are available and successfully
connect to their MCP servers.
connect to their MCP servers. LLM instances are resolved per-role and per-agent
via LLMManager.

Args:
llm: The language model to use for agent reasoning.
websocket: WebSocket connection for authentication context.

Returns:
Expand All @@ -47,15 +48,21 @@ async def build_agent(llm: BaseLanguageModel, websocket: WebSocket) -> tuple[Com

logging.info(f"Loaded {len(agent_configs)} agent configuration(s)")

summary_llm = LLMManager.get_instance_for_role("summary")
uitools_llm = LLMManager.get_instance_for_role("uitools")

if len(agent_configs) == 1:
agent_cfg = agent_configs[0]
agent_llm = LLMManager.get_instance_for_agent(agent_cfg.name, agent_cfg.llm, agent_cfg.llm_model)
tools = await _load_mcp_tools(agent_cfg, websocket)
graph = create_child_agent(llm, tools, agent_cfg.system_prompt, checkpointer, agent_cfg)
graph = create_child_agent(agent_llm, tools, agent_cfg.system_prompt, checkpointer, agent_cfg,
summary_llm=summary_llm, uitools_llm=uitools_llm)
return graph, [{"name": agent_cfg.name, "status": "active"}]

# Multi-agent setup
logging.info(f"Multi-agent setup detected, creating supervisor with {len(agent_configs)} agents.")
child_agents, agents_metadata = await _build_child_agents(llm, agent_configs, checkpointer, websocket)
child_agents, agents_metadata = await _build_child_agents(agent_configs, checkpointer, websocket,
summary_llm=summary_llm, uitools_llm=uitools_llm)

if not child_agents:
logging.error("Failed to create any child agents due to MCP connection issues")
Expand All @@ -67,7 +74,9 @@ async def build_agent(llm: BaseLanguageModel, websocket: WebSocket) -> tuple[Com
logging.warning("Only one child agent connected successfully. Using it directly instead of a supervisor.")
return child_agents[0].agent, agents_metadata

graph = create_supervisor_agent(llm, child_agents, checkpointer)
supervisor_llm = LLMManager.get_instance_for_role("supervisor")
graph = create_supervisor_agent(supervisor_llm, child_agents, checkpointer,
summary_llm=summary_llm, uitools_llm=uitools_llm)
supervisor = SupervisorGraph(
graph=graph,
child_agents={ca.config.name: ca.agent for ca in child_agents},
Expand All @@ -76,10 +85,11 @@ async def build_agent(llm: BaseLanguageModel, websocket: WebSocket) -> tuple[Com


async def _build_child_agents(
llm: BaseLanguageModel,
agent_configs: list[AgentConfig],
checkpointer: Checkpointer,
websocket: WebSocket,
summary_llm: BaseLanguageModel | None = None,
uitools_llm: BaseLanguageModel | None = None,
) -> tuple[list[ChildAgent], list[dict]]:
"""
Attempt to build a child agent for each config, collecting successes and failures.
Expand All @@ -92,10 +102,12 @@ async def _build_child_agents(

for agent_cfg in agent_configs:
try:
agent_llm = LLMManager.get_instance_for_agent(agent_cfg.name, agent_cfg.llm, agent_cfg.llm_model)
tools = await _load_mcp_tools(agent_cfg, websocket)
child_agents.append(ChildAgent(
config=agent_cfg,
agent=create_child_agent(llm, tools, agent_cfg.system_prompt, checkpointer, agent_cfg),
agent=create_child_agent(agent_llm, tools, agent_cfg.system_prompt, checkpointer, agent_cfg,
summary_llm=summary_llm, uitools_llm=uitools_llm),
))
agents_metadata.append({"name": agent_cfg.name, "status": "active"})
except (NoAgentAvailableError, Exception) as e:
Expand Down
6 changes: 5 additions & 1 deletion app/services/agent/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,8 @@ class AgentConfig(BaseModel):
toolset: Optional[str] = None
human_validation_tools: list[str] = []
ready: bool = False
llm: Optional[str] = None
llm_model: Optional[str] = None


def _load_k8s_config():
Expand Down Expand Up @@ -319,7 +321,9 @@ def _crd_to_agent_config(crd_obj: dict) -> AgentConfig:
ca_bundle_ref=CABundleRef(**spec["caBundleRef"]) if spec.get("caBundleRef") else None,
toolset=spec.get("toolSet", None),
human_validation_tools=human_validation_tools,
ready=status.get("phase", "Failed") == "Ready"
ready=status.get("phase", "Failed") == "Ready",
llm=spec.get("llm", None),
llm_model=spec.get("llmModel", None),
)


Expand Down
8 changes: 6 additions & 2 deletions app/services/agent/supervisor.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,8 @@ def create_supervisor_agent(
llm: BaseChatModel,
child_agents: list[ChildAgent],
checkpointer: Checkpointer,
summary_llm: BaseChatModel | None = None,
uitools_llm: BaseChatModel | None = None,
) -> CompiledStateGraph:
"""
Creates a supervisor agent that coordinates multiple child agents as tools.
Expand All @@ -106,6 +108,8 @@ def create_supervisor_agent(
llm: The language model instance to use for the supervisor agent.
child_agents: List of child agents read from AIAgentConfig CRDs.
checkpointer: Checkpointer for persisting agent state.
summary_llm: Optional LLM for summarization middleware. Falls back to llm.
uitools_llm: Optional LLM for UI tools middleware. Falls back to llm.

Returns:
A compiled LangGraph StateGraph ready to be invoked.
Expand All @@ -128,8 +132,8 @@ def create_supervisor_agent(
MessagesHistoryMiddleware(),
cancel_human_validation_middleware(),
inject_additional_kwargs_middleware(),
ui_tools_middleware(llm),
SummarizationMiddleware(model=llm, trigger=[("messages", 30), ("tokens", 30000)], keep=("messages", 15)),
ui_tools_middleware(uitools_llm or llm),
SummarizationMiddleware(model=summary_llm or llm, trigger=[("messages", 30), ("tokens", 30000)], keep=("messages", 15)),
],
)

Expand Down
Loading
Loading