diff --git a/app/dependencies.py b/app/dependencies.py
index 3759748f..a43eec00 100644
--- a/app/dependencies.py
+++ b/app/dependencies.py
@@ -3,12 +3,19 @@
def get_llm() -> BaseLanguageModel:
"""
- FastAPI dependency that provides the singleton LLM instance.
-
- This function is used as a dependency injection in FastAPI endpoints
- to access the configured language model.
-
+ FastAPI dependency that provides the default LLM instance.
+
+ Returns:
+ The default language model instance.
+ """
+ return LLMManager.get_instance()
+
+
+def get_uitools_llm() -> BaseLanguageModel:
+ """
+ FastAPI dependency that provides the UI tools LLM instance.
+
Returns:
- The singleton language model instance.
+ The UI tools language model instance (falls back to default if not configured).
"""
- return LLMManager.get_instance()
\ No newline at end of file
+ return LLMManager.get_instance_for_role("uitools")
\ No newline at end of file
diff --git a/app/routers/configuration.py b/app/routers/configuration.py
index 6fdff95a..955d6585 100644
--- a/app/routers/configuration.py
+++ b/app/routers/configuration.py
@@ -396,8 +396,8 @@ async def update_settings(settings: SettingsUpdate, request: Request):
secret.data = secret_data
v1.patch_namespaced_secret(SETTINGS_SECRET_NAME, AGENT_NAMESPACE, secret)
- # Reset LLMManager singleton to force reinitialization for consistency, but as of now we are redeploying the agent...
- LLMManager._instance = None
+ # Reset LLMManager to force reinitialization for consistency, but as of now we are redeploying the agent...
+ LLMManager.reset()
# Re-fetch both resources to return the updated content
updated_cm = v1.read_namespaced_config_map(SETTINGS_CONFIGMAP_NAME, AGENT_NAMESPACE)
diff --git a/app/routers/llm.py b/app/routers/llm.py
index a7099a36..00f49859 100644
--- a/app/routers/llm.py
+++ b/app/routers/llm.py
@@ -5,7 +5,7 @@
from langchain_core.language_models.llms import BaseLanguageModel
from ..services.auth import get_user_id_from_request
-from ..dependencies import get_llm
+from ..dependencies import get_uitools_llm
from ..services.ui_tools.selector import create_ui_tools_selector
router = APIRouter(prefix="/v1/api/complete", tags=["llm"])
@@ -31,7 +31,7 @@ class UIToolsRequest(LLMRequest):
async def complete_ui_tools(
request: Request,
ui_tools_request: UIToolsRequest,
- llm: BaseLanguageModel = Depends(get_llm)
+ llm: BaseLanguageModel = Depends(get_uitools_llm)
) -> JSONResponse:
"""
Select appropriate UI tools based on the provided context using the LLM.
diff --git a/app/routers/websocket.py b/app/routers/websocket.py
index 0b5f43a5..046259e7 100644
--- a/app/routers/websocket.py
+++ b/app/routers/websocket.py
@@ -4,16 +4,14 @@
import json
from datetime import datetime
-from ..dependencies import get_llm
from ..services.agent.factory import NoAgentAvailableError, build_agent
from ..services.agent.supervisor import SupervisorGraph
from dataclasses import dataclass
from fastapi import APIRouter
-from fastapi import WebSocket, WebSocketDisconnect, Depends
+from fastapi import WebSocket, WebSocketDisconnect
from starlette.websockets import WebSocketState
from langgraph.graph.state import CompiledStateGraph
from langfuse.langchain import CallbackHandler
-from langchain_core.language_models.llms import BaseLanguageModel
from langchain_core.messages import HumanMessage
from langgraph.types import Command
@@ -59,7 +57,7 @@ class WebSocketRequest:
@router.websocket("/v1/ws/messages")
@router.websocket("/v1/ws/messages/{thread_id}")
-async def websocket_endpoint(websocket: WebSocket, thread_id: str = None, llm: BaseLanguageModel = Depends(get_llm)):
+async def websocket_endpoint(websocket: WebSocket, thread_id: str = None):
"""
WebSocket endpoint for the agent.
@@ -77,7 +75,7 @@ async def websocket_endpoint(websocket: WebSocket, thread_id: str = None, llm: B
logging.debug("ws/messages connection opened")
try:
- agent, agents_metadata = await build_agent(llm=llm, websocket=websocket)
+ agent, agents_metadata = await build_agent(websocket=websocket)
except NoAgentAvailableError as e:
logging.error(f"Error creating agent: {e}")
await websocket.send_text(f'{json.dumps({"message": str(e)})}')
diff --git a/app/services/agent/child.py b/app/services/agent/child.py
index 4c67f287..53d958bd 100644
--- a/app/services/agent/child.py
+++ b/app/services/agent/child.py
@@ -40,6 +40,8 @@ def create_child_agent(
system_prompt: str,
checkpointer: Checkpointer,
agent_config: AgentConfig,
+ summary_llm: BaseChatModel | None = None,
+ uitools_llm: BaseChatModel | None = None,
) -> CompiledStateGraph:
"""Create and compile a child agent graph using langchain create_agent with middleware.
@@ -57,8 +59,8 @@ def create_child_agent(
identity_preamble_middleware(),
cancel_human_validation_middleware(),
inject_additional_kwargs_middleware(),
- ui_tools_middleware(llm, only_when_direct=True),
- SummarizationMiddleware(model=llm, trigger=[("messages", 30), ("tokens", 30000)], keep=("messages", 15)),
+ ui_tools_middleware(uitools_llm or llm, only_when_direct=True),
+ SummarizationMiddleware(model=summary_llm or llm, trigger=[("messages", 30), ("tokens", 30000)], keep=("messages", 15)),
]
return create_agent(
diff --git a/app/services/agent/factory.py b/app/services/agent/factory.py
index 7fe69d92..3877f033 100644
--- a/app/services/agent/factory.py
+++ b/app/services/agent/factory.py
@@ -12,6 +12,7 @@
from langchain_core.language_models.llms import BaseLanguageModel
from langchain_mcp_adapters.client import MultiServerMCPClient
from langgraph.graph.state import Checkpointer, CompiledStateGraph
+from ..llm import LLMManager
NAMESPACE = "cattle-ai-agent-system"
@@ -20,16 +21,16 @@ class NoAgentAvailableError(Exception):
pass
-async def build_agent(llm: BaseLanguageModel, websocket: WebSocket) -> tuple[CompiledStateGraph | SupervisorGraph, list[dict]]:
+async def build_agent(websocket: WebSocket) -> tuple[CompiledStateGraph | SupervisorGraph, list[dict]]:
"""
Build an agent graph from AIAgentConfig CRDs.
Loads agent configurations and creates either a supervisor (multi-agent) or a
single child agent depending on how many configs are available and successfully
- connect to their MCP servers.
+ connect to their MCP servers. LLM instances are resolved per-role and per-agent
+ via LLMManager.
Args:
- llm: The language model to use for agent reasoning.
websocket: WebSocket connection for authentication context.
Returns:
@@ -47,15 +48,21 @@ async def build_agent(llm: BaseLanguageModel, websocket: WebSocket) -> tuple[Com
logging.info(f"Loaded {len(agent_configs)} agent configuration(s)")
+ summary_llm = LLMManager.get_instance_for_role("summary")
+ uitools_llm = LLMManager.get_instance_for_role("uitools")
+
if len(agent_configs) == 1:
agent_cfg = agent_configs[0]
+ agent_llm = LLMManager.get_instance_for_agent(agent_cfg.name, agent_cfg.llm, agent_cfg.llm_model)
tools = await _load_mcp_tools(agent_cfg, websocket)
- graph = create_child_agent(llm, tools, agent_cfg.system_prompt, checkpointer, agent_cfg)
+ graph = create_child_agent(agent_llm, tools, agent_cfg.system_prompt, checkpointer, agent_cfg,
+ summary_llm=summary_llm, uitools_llm=uitools_llm)
return graph, [{"name": agent_cfg.name, "status": "active"}]
# Multi-agent setup
logging.info(f"Multi-agent setup detected, creating supervisor with {len(agent_configs)} agents.")
- child_agents, agents_metadata = await _build_child_agents(llm, agent_configs, checkpointer, websocket)
+ child_agents, agents_metadata = await _build_child_agents(agent_configs, checkpointer, websocket,
+ summary_llm=summary_llm, uitools_llm=uitools_llm)
if not child_agents:
logging.error("Failed to create any child agents due to MCP connection issues")
@@ -67,7 +74,9 @@ async def build_agent(llm: BaseLanguageModel, websocket: WebSocket) -> tuple[Com
logging.warning("Only one child agent connected successfully. Using it directly instead of a supervisor.")
return child_agents[0].agent, agents_metadata
- graph = create_supervisor_agent(llm, child_agents, checkpointer)
+ supervisor_llm = LLMManager.get_instance_for_role("supervisor")
+ graph = create_supervisor_agent(supervisor_llm, child_agents, checkpointer,
+ summary_llm=summary_llm, uitools_llm=uitools_llm)
supervisor = SupervisorGraph(
graph=graph,
child_agents={ca.config.name: ca.agent for ca in child_agents},
@@ -76,10 +85,11 @@ async def build_agent(llm: BaseLanguageModel, websocket: WebSocket) -> tuple[Com
async def _build_child_agents(
- llm: BaseLanguageModel,
agent_configs: list[AgentConfig],
checkpointer: Checkpointer,
websocket: WebSocket,
+ summary_llm: BaseLanguageModel | None = None,
+ uitools_llm: BaseLanguageModel | None = None,
) -> tuple[list[ChildAgent], list[dict]]:
"""
Attempt to build a child agent for each config, collecting successes and failures.
@@ -92,10 +102,12 @@ async def _build_child_agents(
for agent_cfg in agent_configs:
try:
+ agent_llm = LLMManager.get_instance_for_agent(agent_cfg.name, agent_cfg.llm, agent_cfg.llm_model)
tools = await _load_mcp_tools(agent_cfg, websocket)
child_agents.append(ChildAgent(
config=agent_cfg,
- agent=create_child_agent(llm, tools, agent_cfg.system_prompt, checkpointer, agent_cfg),
+ agent=create_child_agent(agent_llm, tools, agent_cfg.system_prompt, checkpointer, agent_cfg,
+ summary_llm=summary_llm, uitools_llm=uitools_llm),
))
agents_metadata.append({"name": agent_cfg.name, "status": "active"})
except (NoAgentAvailableError, Exception) as e:
diff --git a/app/services/agent/loader.py b/app/services/agent/loader.py
index dcdcda9b..db53ec15 100644
--- a/app/services/agent/loader.py
+++ b/app/services/agent/loader.py
@@ -175,6 +175,8 @@ class AgentConfig(BaseModel):
toolset: Optional[str] = None
human_validation_tools: list[str] = []
ready: bool = False
+ llm: Optional[str] = None
+ llm_model: Optional[str] = None
def _load_k8s_config():
@@ -319,7 +321,9 @@ def _crd_to_agent_config(crd_obj: dict) -> AgentConfig:
ca_bundle_ref=CABundleRef(**spec["caBundleRef"]) if spec.get("caBundleRef") else None,
toolset=spec.get("toolSet", None),
human_validation_tools=human_validation_tools,
- ready=status.get("phase", "Failed") == "Ready"
+ ready=status.get("phase", "Failed") == "Ready",
+ llm=spec.get("llm", None),
+ llm_model=spec.get("llmModel", None),
)
diff --git a/app/services/agent/supervisor.py b/app/services/agent/supervisor.py
index 17bb5700..b0e702f3 100644
--- a/app/services/agent/supervisor.py
+++ b/app/services/agent/supervisor.py
@@ -94,6 +94,8 @@ def create_supervisor_agent(
llm: BaseChatModel,
child_agents: list[ChildAgent],
checkpointer: Checkpointer,
+ summary_llm: BaseChatModel | None = None,
+ uitools_llm: BaseChatModel | None = None,
) -> CompiledStateGraph:
"""
Creates a supervisor agent that coordinates multiple child agents as tools.
@@ -106,6 +108,8 @@ def create_supervisor_agent(
llm: The language model instance to use for the supervisor agent.
child_agents: List of child agents read from AIAgentConfig CRDs.
checkpointer: Checkpointer for persisting agent state.
+ summary_llm: Optional LLM for summarization middleware. Falls back to llm.
+ uitools_llm: Optional LLM for UI tools middleware. Falls back to llm.
Returns:
A compiled LangGraph StateGraph ready to be invoked.
@@ -128,8 +132,8 @@ def create_supervisor_agent(
MessagesHistoryMiddleware(),
cancel_human_validation_middleware(),
inject_additional_kwargs_middleware(),
- ui_tools_middleware(llm),
- SummarizationMiddleware(model=llm, trigger=[("messages", 30), ("tokens", 30000)], keep=("messages", 15)),
+ ui_tools_middleware(uitools_llm or llm),
+ SummarizationMiddleware(model=summary_llm or llm, trigger=[("messages", 30), ("tokens", 30000)], keep=("messages", 15)),
],
)
diff --git a/app/services/llm.py b/app/services/llm.py
index d27d67ee..4fb18ca6 100644
--- a/app/services/llm.py
+++ b/app/services/llm.py
@@ -1,5 +1,6 @@
import os
import logging
+from typing import Optional
from langchain_ollama import ChatOllama
from langchain_google_genai import ChatGoogleGenerativeAI
@@ -7,60 +8,63 @@
from langchain_core.language_models.llms import BaseLanguageModel
from langchain_aws import ChatBedrockConverse
+VALID_PROVIDERS = ("ollama", "gemini", "openai", "bedrock")
+
+ROLE_ENV_VARS = {
+ "supervisor": ("SUPERVISOR_LLM", "SUPERVISOR_MODEL"),
+ "uitools": ("UITOOLS_LLM", "UITOOLS_MODEL"),
+ "summary": ("SUMMARY_LLM", "SUMMARY_MODEL"),
+}
+
+
class LLMManager:
- """
- Singleton manager for language model instances.
-
- This class ensures that only one instance of the language model is created
- and reused throughout the application, avoiding redundant initializations
- and ensuring consistent model configuration.
- """
- _instance: BaseLanguageModel = None
+ _instances: dict[str, BaseLanguageModel] = {}
+
+ @classmethod
+ def get_instance(cls, key: str = "default") -> BaseLanguageModel:
+ if key not in cls._instances:
+ cls._instances[key] = get_llm()
+ return cls._instances[key]
+
+ @classmethod
+ def get_instance_for_role(cls, role: str) -> BaseLanguageModel:
+ if role in cls._instances:
+ return cls._instances[role]
+ llm = get_llm_for_role(role)
+ cls._instances[role] = llm
+ return llm
+
+ @classmethod
+ def get_instance_for_agent(cls, agent_name: str, llm_provider: Optional[str], llm_model: Optional[str]) -> BaseLanguageModel:
+ key = f"agent:{agent_name}"
+ if key in cls._instances:
+ return cls._instances[key]
+ llm = get_llm_for_agent(llm_provider, llm_model)
+ cls._instances[key] = llm
+ return llm
@classmethod
- def get_instance(cls) -> BaseLanguageModel:
- """
- Retrieves the singleton instance of the language model.
-
- If the instance doesn't exist yet, it initializes it by calling get_llm().
- Subsequent calls return the same instance.
-
- Returns:
- The singleton language model instance.
- """
- if cls._instance is None:
- cls._instance = get_llm()
- logging.info(f"Using model: {cls._instance}")
- return cls._instance
-
-def get_llm() -> BaseLanguageModel:
- """
- Selects and returns a language model instance based on environment variables.
- - If the active LLM or the model is not configured, it raises a ValueError.
- - If LLM mocking is enabled, it configures the connections to the mock server.
-
- Returns:
- An instance of a language model.
-
- Raises:
- ValueError: If the active LLM or the model is not configured.
- """
-
- activeLlm = get_active_llm()
- model = get_llm_model(activeLlm)
-
+ def reset(cls):
+ cls._instances = {}
+
+
+def get_llm(llm_provider: Optional[str] = None, model: Optional[str] = None) -> BaseLanguageModel:
+ if llm_provider is None:
+ llm_provider = get_active_llm()
+ if model is None:
+ model = get_llm_model(llm_provider)
+
llm_mock_enabled = os.environ.get("LLM_MOCK_ENABLED", "false").lower() == "true"
llm_mock_url = os.environ.get("LLM_MOCK_URL", "")
if llm_mock_enabled:
logging.info(f"Connecting to LLM Mock server at {llm_mock_url}")
- if activeLlm == "ollama":
+ if llm_provider == "ollama":
if llm_mock_enabled:
return ChatOllama(model=model, base_url=llm_mock_url)
-
ollama_url = os.environ.get("OLLAMA_URL")
return ChatOllama(model=model, base_url=ollama_url)
- if activeLlm == "gemini":
+ if llm_provider == "gemini":
if llm_mock_enabled:
return ChatGoogleGenerativeAI(
model=model,
@@ -68,55 +72,62 @@ def get_llm() -> BaseLanguageModel:
transport="rest"
)
if model == "gemini-2.5-flash":
- # Disable thinking budget for gemini-2.5-flash to avoid empty responses due to all tokens being used for thinking budget
return ChatGoogleGenerativeAI(model=model, thinking_budget=0)
-
return ChatGoogleGenerativeAI(model=model)
- if activeLlm == "openai":
+ if llm_provider == "openai":
if llm_mock_enabled:
return ChatOpenAI(model=model, base_url=llm_mock_url)
-
openai_url = os.environ.get("OPENAI_URL")
if openai_url:
return ChatOpenAI(model=model, base_url=openai_url)
return ChatOpenAI(model=model)
- if activeLlm == "bedrock":
+ if llm_provider == "bedrock":
if llm_mock_enabled:
os.environ["AWS_ENDPOINT_URL"] = llm_mock_url
return ChatBedrockConverse(model=model)
+ raise ValueError(f"Unsupported LLM provider: {llm_provider}")
+
+
+def get_llm_for_role(role: str) -> BaseLanguageModel:
+ env_llm, env_model = ROLE_ENV_VARS.get(role, (None, None))
+ if not env_llm:
+ raise ValueError(f"Unknown role: {role}")
+
+ provider = os.environ.get(env_llm, "").strip() or None
+ model = os.environ.get(env_model, "").strip() or None
+
+ if provider and provider not in VALID_PROVIDERS:
+ raise ValueError(f"Invalid LLM provider '{provider}' for role '{role}'.")
+
+ if provider and model:
+ return get_llm(llm_provider=provider, model=model)
+ return get_llm()
+
+
+def get_llm_for_agent(llm_provider: Optional[str], llm_model: Optional[str]) -> BaseLanguageModel:
+ provider = (llm_provider or "").strip() or None
+ model = (llm_model or "").strip() or None
+
+ if provider and provider not in VALID_PROVIDERS:
+ raise ValueError(f"Invalid LLM provider '{provider}' for agent.")
+
+ if provider and model:
+ return get_llm(llm_provider=provider, model=model)
+ return get_llm()
+
+
def get_active_llm() -> str:
- """
- Retrieves the active LLM identifier from environment variables.
-
- Returns:
- The active LLM as a string, or None if not set.
- """
llm = os.environ.get("ACTIVE_LLM", "")
-
- if llm not in ["ollama", "gemini", "openai", "bedrock"]:
+ if llm not in VALID_PROVIDERS:
raise ValueError("LLM not configured.")
-
return llm
-def get_llm_model(llm: str) -> str:
- """
- Retrieves the model name from environment variables.
-
- Args:
- llm: The LLM identifier, one of 'ollama', 'gemini', 'openai', 'bedrock'.
-
- Returns:
- The model name as a string.
- """
+def get_llm_model(llm: str) -> str:
model = None
-
if llm:
model = os.environ.get(f"{llm.upper()}_MODEL")
-
if not model:
raise ValueError("LLM Model not configured.")
-
return model
-
diff --git a/chart/agent/templates/ai-agent-deployment.yaml b/chart/agent/templates/ai-agent-deployment.yaml
index 152659ae..51759d9d 100644
--- a/chart/agent/templates/ai-agent-deployment.yaml
+++ b/chart/agent/templates/ai-agent-deployment.yaml
@@ -103,11 +103,41 @@ spec:
name: llm-secret
key: GOOGLE_API_KEY
optional: true
- - name: ACTIVE_LLM
+ - name: SUPERVISOR_LLM
valueFrom:
configMapKeyRef:
name: llm-config
- key: ACTIVE_LLM
+ key: SUPERVISOR_LLM
+ optional: true
+ - name: UITOOLS_LLM
+ valueFrom:
+ configMapKeyRef:
+ name: llm-config
+ key: UITOOLS_LLM
+ optional: true
+ - name: SUMMARY_LLM
+ valueFrom:
+ configMapKeyRef:
+ name: llm-config
+ key: SUMMARY_LLM
+ optional: true
+ - name: SUPERVISOR_MODEL
+ valueFrom:
+ configMapKeyRef:
+ name: llm-config
+ key: SUPERVISOR_MODEL
+ optional: true
+ - name: UITOOLS_MODEL
+ valueFrom:
+ configMapKeyRef:
+ name: llm-config
+ key: UITOOLS_MODEL
+ optional: true
+ - name: SUMMARY_MODEL
+ valueFrom:
+ configMapKeyRef:
+ name: llm-config
+ key: SUMMARY_MODEL
optional: true
- name: LANGFUSE_PUBLIC_KEY
valueFrom:
diff --git a/chart/agent/templates/crds/ai.cattle.io_aiagentconfigs.yaml b/chart/agent/templates/crds/ai.cattle.io_aiagentconfigs.yaml
index 14083219..e7e72531 100644
--- a/chart/agent/templates/crds/ai.cattle.io_aiagentconfigs.yaml
+++ b/chart/agent/templates/crds/ai.cattle.io_aiagentconfigs.yaml
@@ -3,7 +3,7 @@ apiVersion: apiextensions.k8s.io/v1
kind: CustomResourceDefinition
metadata:
annotations:
- controller-gen.kubebuilder.io/version: v0.17.1
+ controller-gen.kubebuilder.io/version: v0.21.0
name: aiagentconfigs.ai.cattle.io
spec:
group: ai.cattle.io
@@ -66,17 +66,19 @@ spec:
description: BuiltIn indicates if this is a built-in agent configuration
type: boolean
caBundleRef:
- description: CABundleRef references a secret key containing a
- PEM-encoded CA certificate bundle to trust when connecting to
- this agent's MCP server
+ description: |-
+ CABundleRef references a secret key containing a PEM-encoded CA
+ certificate bundle to trust when connecting to this agent's MCP server.
+ The CA is scoped to this agent's connection only.
properties:
- name:
- description: Name is the name of the secret in the agent namespace.
- type: string
key:
default: ca.crt
- description: Key is the key inside the secret that holds the
- data. Defaults to "ca.crt".
+ description: |-
+ Key is the key inside the secret that holds the data.
+ Defaults to "ca.crt".
+ type: string
+ name:
+ description: Name is the name of the secret in the agent namespace.
type: string
required:
- name
@@ -95,6 +97,13 @@ spec:
items:
type: string
type: array
+ llm:
+ description: LLM specifies the LLM provider to be used by the agent
+ type: string
+ llmModel:
+ description: LLMModel specifies the language model to be used by the
+ agent
+ type: string
mcpURL:
description: MCPURL is the Model Context Protocol server URL
type: string
diff --git a/chart/agent/templates/llm-config.yaml b/chart/agent/templates/llm-config.yaml
index ef2bb704..4b250127 100644
--- a/chart/agent/templates/llm-config.yaml
+++ b/chart/agent/templates/llm-config.yaml
@@ -9,3 +9,9 @@ data:
OPENAI_MODEL: "{{ .Values.openaiLlmModel }}"
BEDROCK_MODEL: "{{ .Values.bedrockLlmModel }}"
ACTIVE_LLM: "{{ .Values.activeLlm }}"
+ SUPERVISOR_LLM: "{{ .Values.supervisorLlm }}"
+ SUPERVISOR_MODEL: "{{ .Values.supervisorModel }}"
+ UITOOLS_LLM: "{{ .Values.uitoolsLlm }}"
+ UITOOLS_MODEL: "{{ .Values.uitoolsModel }}"
+ SUMMARY_LLM: "{{ .Values.summaryLlm }}"
+ SUMMARY_MODEL: "{{ .Values.summaryModel }}"
diff --git a/chart/agent/values.yaml b/chart/agent/values.yaml
index 461b1300..3a3ff195 100644
--- a/chart/agent/values.yaml
+++ b/chart/agent/values.yaml
@@ -26,6 +26,12 @@ awsBedrock:
bearerToken:
region:
activeLlm: # ollama, gemini, openai, bedrock
+supervisorLlm: ""
+supervisorModel: ""
+uitoolsLlm: ""
+uitoolsModel: ""
+summaryLlm: ""
+summaryModel: ""
# Enable RAG with embedded rancher documentation
rag:
enabled: false
diff --git a/crd-generation/api/v1alpha1/aiagentconfig_types.go b/crd-generation/api/v1alpha1/aiagentconfig_types.go
index c05e3e90..6884ff45 100644
--- a/crd-generation/api/v1alpha1/aiagentconfig_types.go
+++ b/crd-generation/api/v1alpha1/aiagentconfig_types.go
@@ -53,6 +53,14 @@ type AIAgentConfigSpec struct {
// The CA is scoped to this agent's connection only.
// +optional
CABundleRef *SecretKeyRef `json:"caBundleRef,omitempty"`
+
+ // LLMModel specifies the language model to be used by the agent
+ // +optional
+ LLMModel string `json:"llmModel,omitempty"`
+
+ // LLM specifies the LLM provider to be used by the agent
+ // +optional
+ LLM string `json:"llm,omitempty"`
}
// SecretKeyRef identifies a key within a Kubernetes secret.
diff --git a/tests/integration/test_multi_agent.py b/tests/integration/test_multi_agent.py
index 8b8a41e9..f4bbe7c0 100644
--- a/tests/integration/test_multi_agent.py
+++ b/tests/integration/test_multi_agent.py
@@ -258,7 +258,7 @@ def test_single_prompt():
fake_llm = FakeMessagesListChatModelWithTools(responses=fake_llm_responses)
fake_llm.all_calls = []
- LLMManager._instance = fake_llm
+ LLMManager._instances["default"] = fake_llm
try:
with client.websocket_connect("/v1/ws/messages") as websocket:
@@ -316,7 +316,7 @@ def test_single_prompt():
"Third call should include tool result from math-agent"
finally:
- LLMManager._instance = None
+ LLMManager.reset()
def test_multiple_prompts():
@@ -355,7 +355,7 @@ def test_multiple_prompts():
fake_llm = FakeMessagesListChatModelWithTools(responses=fake_llm_responses)
fake_llm.all_calls = []
- LLMManager._instance = fake_llm
+ LLMManager._instances["default"] = fake_llm
try:
with client.websocket_connect("/v1/ws/messages") as websocket:
@@ -432,7 +432,7 @@ def test_multiple_prompts():
assert isinstance(sixth_call[7], ToolMessage) and sixth_call[7].content == fake_llm_response_2
finally:
- LLMManager._instance = None
+ LLMManager.reset()
def test_delegate_to_child_agent_with_tool():
@@ -461,7 +461,7 @@ def test_delegate_to_child_agent_with_tool():
fake_llm = FakeMessagesListChatModelWithTools(responses=fake_llm_responses)
fake_llm.all_calls = []
- LLMManager._instance = fake_llm
+ LLMManager._instances["default"] = fake_llm
try:
with client.websocket_connect("/v1/ws/messages") as websocket:
@@ -505,7 +505,7 @@ def test_delegate_to_child_agent_with_tool():
assert fake_llm.all_calls[3][0] == SystemMessage(content=SUPERVISOR_PROMPT)
finally:
- LLMManager._instance = None
+ LLMManager.reset()
def test_delegate_to_child_agent_with_ui_tools():
@@ -547,7 +547,7 @@ def test_delegate_to_child_agent_with_ui_tools():
]
fake_llm = FakeMessagesListChatModelWithTools(responses=fake_llm_responses)
- LLMManager._instance = fake_llm
+ LLMManager._instances["default"] = fake_llm
try:
# Create agent configs with UI tools enabled
@@ -632,6 +632,6 @@ def test_delegate_to_child_agent_with_ui_tools():
mock_load_configmap.assert_called()
finally:
- LLMManager._instance = None
+ LLMManager.reset()
diff --git a/tests/integration/test_single_agent.py b/tests/integration/test_single_agent.py
index 4303569a..1d23cf24 100644
--- a/tests/integration/test_single_agent.py
+++ b/tests/integration/test_single_agent.py
@@ -98,7 +98,7 @@ def test_websocket_single_prompt():
fake_llm = FakeMessagesListChatModelWithTools(responses=fake_llm_responses)
fake_llm.all_calls = [] # Reset call tracking
- LLMManager._instance = fake_llm
+ LLMManager._instances["default"] = fake_llm
try:
messages = []
@@ -121,7 +121,7 @@ def test_websocket_single_prompt():
SystemMessage(content=IDENTITY_PREAMBLE),
], "First call should have system prompt and user message"
finally:
- LLMManager._instance = None
+ LLMManager.reset()
def test_websocket_multiple_prompts():
"""Tests multiple prompt-response interactions in sequence."""
@@ -137,7 +137,7 @@ def test_websocket_multiple_prompts():
fake_llm = FakeMessagesListChatModelWithTools(responses=fake_llm_responses)
fake_llm.all_calls = [] # Reset call tracking
- LLMManager._instance = fake_llm
+ LLMManager._instances["default"] = fake_llm
try:
messages = []
@@ -167,7 +167,7 @@ def test_websocket_multiple_prompts():
HumanMessage(content="fake prompt 2"),
], "Second call should include conversation history"
finally:
- LLMManager._instance = None
+ LLMManager.reset()
def test_websocket_tool_call():
"""Tests agent interaction with tool calling."""
@@ -187,7 +187,7 @@ def test_websocket_tool_call():
fake_llm = FakeMessagesListChatModelWithTools(responses=fake_llm_responses)
fake_llm.all_calls = [] # Reset call tracking
- LLMManager._instance = fake_llm
+ LLMManager._instances["default"] = fake_llm
try:
messages = []
@@ -217,7 +217,7 @@ def test_websocket_tool_call():
assert isinstance(second_call[3], AIMessage) and second_call[3].tool_calls[0]["name"] == "add", "Second call should have AI message with tool call"
assert isinstance(second_call[4], ToolMessage) and second_call[4].content == "sum is 9", "Second call should have tool result"
finally:
- LLMManager._instance = None
+ LLMManager.reset()
def test_conversation_history():
"""Tests that conversation history is maintained across multiple prompts."""
@@ -252,7 +252,7 @@ def test_conversation_history():
fake_llm = FakeMessagesListChatModelWithTools(responses=fake_llm_responses)
fake_llm.all_calls = [] # Reset call tracking
- LLMManager._instance = fake_llm
+ LLMManager._instances["default"] = fake_llm
try:
messages = []
@@ -326,7 +326,7 @@ def test_conversation_history():
], "Fifth call should include full conversation history"
finally:
- LLMManager._instance = None
+ LLMManager.reset()
def test_websocket_with_ui_tools():
@@ -358,7 +358,7 @@ def test_websocket_with_ui_tools():
fake_llm = FakeMessagesListChatModelWithTools(responses=fake_llm_responses)
fake_llm.all_calls = []
- LLMManager._instance = fake_llm
+ LLMManager._instances["default"] = fake_llm
try:
agent_config = AgentConfig(
@@ -434,4 +434,4 @@ def test_websocket_with_ui_tools():
mock_load_configmap.assert_called(), "UI tools config should be loaded from ConfigMap"
finally:
- LLMManager._instance = None
+ LLMManager.reset()
diff --git a/tests/unit/routers/test_websocket.py b/tests/unit/routers/test_websocket.py
index fc294c07..374ff5d1 100644
--- a/tests/unit/routers/test_websocket.py
+++ b/tests/unit/routers/test_websocket.py
@@ -79,12 +79,10 @@ def mock_dependencies():
@pytest.mark.asyncio
async def test_websocket_endpoint(mock_dependencies):
mock_ws = MockWebSocket(messages=["test message"])
- mock_llm = MagicMock()
-
- await websocket_endpoint(mock_ws, None, mock_llm)
+ await websocket_endpoint(mock_ws, None)
assert mock_ws.accepted
- mock_dependencies["build_agent"].assert_called_once_with(llm=mock_llm, websocket=mock_ws)
+ mock_dependencies["build_agent"].assert_called_once_with(websocket=mock_ws)
mock_dependencies["call_agent"].assert_awaited_once()
call_kwargs = mock_dependencies["call_agent"].call_args.kwargs
@@ -97,9 +95,7 @@ async def test_websocket_endpoint_context_message(mock_dependencies):
mock_ws = MockWebSocket(messages=[
'{"prompt": "show all pods", "context": {"namespace": "default", "cluster": "local"}}'
])
- mock_llm = MagicMock()
-
- await websocket_endpoint(mock_ws, None, mock_llm)
+ await websocket_endpoint(mock_ws, None)
mock_dependencies["build_agent"].assert_called_once()
mock_dependencies["call_agent"].assert_awaited_once()
@@ -115,9 +111,7 @@ async def test_websocket_endpoint_context_message(mock_dependencies):
@pytest.mark.asyncio
async def test_websocket_endpoint_sends_chat_metadata(mock_dependencies):
mock_ws = MockWebSocket(messages=["hello"])
- mock_llm = MagicMock()
-
- await websocket_endpoint(mock_ws, None, mock_llm)
+ await websocket_endpoint(mock_ws, None)
# First message sent should be chat metadata
metadata_msg = mock_ws._send_queue[0]
@@ -131,9 +125,7 @@ async def test_websocket_endpoint_no_agent_available(mock_dependencies):
from app.services.agent.factory import NoAgentAvailableError
mock_dependencies["build_agent"].side_effect = NoAgentAvailableError("No agents")
mock_ws = MockWebSocket(messages=[])
- mock_llm = MagicMock()
-
- await websocket_endpoint(mock_ws, None, mock_llm)
+ await websocket_endpoint(mock_ws, None)
assert mock_ws.accepted
assert mock_ws.closed
@@ -145,9 +137,7 @@ async def test_websocket_endpoint_no_agent_available(mock_dependencies):
@pytest.mark.asyncio
async def test_websocket_endpoint_with_thread_id(mock_dependencies):
mock_ws = MockWebSocket(messages=["hello"])
- mock_llm = MagicMock()
-
- await websocket_endpoint(mock_ws, "custom-thread-id", mock_llm)
+ await websocket_endpoint(mock_ws, "custom-thread-id")
assert mock_ws.accepted
metadata_msg = mock_ws._send_queue[0]
@@ -159,9 +149,7 @@ async def test_websocket_endpoint_error_during_processing(mock_dependencies):
"""Test that errors during message processing are sent to the client."""
mock_dependencies["call_agent"].side_effect = Exception("Something went wrong")
mock_ws = MockWebSocket(messages=["test", "second"])
- mock_llm = MagicMock()
-
- await websocket_endpoint(mock_ws, None, mock_llm)
+ await websocket_endpoint(mock_ws, None)
# Should have sent error and end message
error_messages = [m for m in mock_ws._send_queue if "" in m]
diff --git a/tests/unit/services/agent/test_factory.py b/tests/unit/services/agent/test_factory.py
index 5011db63..bae060b7 100644
--- a/tests/unit/services/agent/test_factory.py
+++ b/tests/unit/services/agent/test_factory.py
@@ -23,102 +23,112 @@
# ============================================================================
@pytest.mark.asyncio
+@patch('app.services.agent.factory.LLMManager')
@patch('app.services.agent.factory.load_agent_configs')
@patch('app.services.agent.factory._load_mcp_tools')
@patch('app.services.agent.factory.create_child_agent')
-async def test_create_agent_single_agent(mock_create_child, mock_load_tools, mock_load_configs):
+async def test_create_agent_single_agent(mock_create_child, mock_load_tools, mock_load_configs, mock_llm_manager):
"""Verify build_agent creates a child agent when one config is available."""
# Setup mocks
mock_llm = MagicMock()
+ mock_summary_llm = MagicMock()
+ mock_uitools_llm = MagicMock()
+ mock_llm_manager.get_instance_for_agent.return_value = mock_llm
+ mock_llm_manager.get_instance_for_role.side_effect = lambda role: mock_summary_llm if role == "summary" else mock_uitools_llm
+
mock_websocket = MagicMock()
mock_memory_manager = MagicMock()
mock_checkpointer = MagicMock()
mock_memory_manager.get_checkpointer.return_value = mock_checkpointer
mock_websocket.app.memory_manager = mock_memory_manager
-
+
mock_agent_config = MagicMock()
mock_agent_config.name = "RancherAgent"
mock_agent_config.system_prompt = "Test prompt"
mock_load_configs.return_value = [mock_agent_config]
-
+
mock_tools = [MagicMock()]
mock_load_tools.return_value = mock_tools
mock_agent = MagicMock()
mock_create_child.return_value = mock_agent
-
+
# Execute
- result = await build_agent(mock_llm, mock_websocket)
-
+ result = await build_agent(mock_websocket)
+
# Verify
assert result[0] == mock_agent
assert result[1] == [{"name": "RancherAgent", "status": "active"}]
mock_load_tools.assert_called_once_with(mock_agent_config, mock_websocket)
mock_create_child.assert_called_once_with(
- mock_llm, mock_tools, "Test prompt", mock_checkpointer, mock_agent_config
+ mock_llm, mock_tools, "Test prompt", mock_checkpointer, mock_agent_config,
+ summary_llm=mock_summary_llm, uitools_llm=mock_uitools_llm
)
@pytest.mark.asyncio
+@patch('app.services.agent.factory.LLMManager')
@patch('app.services.agent.factory.load_agent_configs')
@patch('app.services.agent.factory.create_supervisor_agent')
@patch('app.services.agent.factory.create_child_agent')
@patch('app.services.agent.factory.create_mcp_client')
@patch('app.services.agent.factory._update_agent_status')
-async def test_build_agent_three_agents(mock_update_status, mock_create_client, mock_create_child, mock_create_parent, mock_load_configs):
+async def test_build_agent_three_agents(mock_update_status, mock_create_client, mock_create_child, mock_create_parent, mock_load_configs, mock_llm_manager):
"""Verify build_agent creates a supervisor agent when three configs are available."""
# Setup mocks
mock_llm = MagicMock()
+ mock_llm_manager.get_instance_for_agent.return_value = mock_llm
+ mock_llm_manager.get_instance_for_role.return_value = mock_llm
+
mock_websocket = MagicMock()
mock_memory_manager = MagicMock()
mock_checkpointer = MagicMock()
mock_memory_manager.get_checkpointer.return_value = mock_checkpointer
mock_websocket.app.memory_manager = mock_memory_manager
-
+
mock_config1 = MagicMock()
mock_config1.name = "RancherAgent"
mock_config1.description = "Rancher core agent"
mock_config1.system_prompt = "Prompt 1"
-
+
mock_config2 = MagicMock()
mock_config2.name = "FleetAgent"
mock_config2.description = "Fleet agent"
mock_config2.system_prompt = "Prompt 2"
-
+
mock_config3 = MagicMock()
mock_config3.name = "HarvesterAgent"
mock_config3.description = "Harvester agent"
mock_config3.system_prompt = "Prompt 3"
-
+
mock_load_configs.return_value = [mock_config1, mock_config2, mock_config3]
-
+
# Mock MCP client
mock_client_instance = MagicMock()
mock_tools = [MagicMock()]
mock_client_instance.get_tools = AsyncMock(return_value=mock_tools)
mock_create_client.return_value = mock_client_instance
-
+
mock_parent_agent = MagicMock()
mock_create_parent.return_value = mock_parent_agent
mock_create_child.return_value = MagicMock()
-
+
# Execute
- result = await build_agent(mock_llm, mock_websocket)
-
+ result = await build_agent(mock_websocket)
+
# Verify - build_agent wraps multi-agent result in SupervisorGraph
from app.services.agent.supervisor import SupervisorGraph
assert isinstance(result[0], SupervisorGraph)
assert result[0]._graph == mock_parent_agent
mock_create_parent.assert_called_once()
-
+
# Verify parent was called with correct subagents
call_args = mock_create_parent.call_args
- assert call_args[0][0] == mock_llm
subagents = call_args[0][1]
assert len(subagents) == 3
assert subagents[0].config.name == "RancherAgent"
assert subagents[1].config.name == "FleetAgent"
assert subagents[2].config.name == "HarvesterAgent"
-
+
# Verify metadata includes all agents
metadata = result[1]
assert len(metadata) == 3
@@ -126,79 +136,79 @@ async def test_build_agent_three_agents(mock_update_status, mock_create_client,
@pytest.mark.asyncio
+@patch('app.services.agent.factory.LLMManager')
@patch('app.services.agent.factory.load_agent_configs')
@patch('app.services.agent.factory.create_supervisor_agent')
@patch('app.services.agent.factory.create_child_agent')
@patch('app.services.agent.factory.create_mcp_client')
@patch('app.services.agent.factory._update_agent_status')
-async def test_build_agent_filters_tools_by_toolset(mock_update_status, mock_create_client, mock_create_child, mock_create_parent, mock_load_configs):
+async def test_build_agent_filters_tools_by_toolset(mock_update_status, mock_create_client, mock_create_child, mock_create_parent, mock_load_configs, mock_llm_manager):
"""Verify build_agent filters tools based on toolset configuration."""
# Setup mocks
mock_llm = MagicMock()
+ mock_llm_manager.get_instance_for_agent.return_value = mock_llm
+ mock_llm_manager.get_instance_for_role.return_value = mock_llm
+
mock_websocket = MagicMock()
mock_memory_manager = MagicMock()
mock_checkpointer = MagicMock()
mock_memory_manager.get_checkpointer.return_value = mock_checkpointer
mock_websocket.app.memory_manager = mock_memory_manager
-
+
mock_config1 = MagicMock()
mock_config1.name = "RancherAgent"
mock_config1.description = "Rancher agent with specific toolset"
mock_config1.system_prompt = "Prompt 1"
- mock_config1.toolset = "rancher-core" # Specify toolset filter
-
+ mock_config1.toolset = "rancher-core"
+
mock_config2 = MagicMock()
mock_config2.name = "FleetAgent"
mock_config2.description = "Fleet agent without toolset filter"
mock_config2.system_prompt = "Prompt 2"
- mock_config2.toolset = None # No toolset filter
-
+ mock_config2.toolset = None
+
mock_load_configs.return_value = [mock_config1, mock_config2]
-
+
# Mock MCP client with tools that have different toolsets
- # Create mock tools with metadata
tool_rancher_core = MagicMock()
tool_rancher_core.name = "rancher_tool"
tool_rancher_core.metadata = {"_meta": {"toolset": "rancher-core"}}
-
+
tool_rancher_extensions = MagicMock()
tool_rancher_extensions.name = "extensions_tool"
tool_rancher_extensions.metadata = {"_meta": {"toolset": "rancher-extensions"}}
-
+
tool_fleet = MagicMock()
tool_fleet.name = "fleet_tool"
tool_fleet.metadata = {"_meta": {"toolset": "fleet"}}
-
+
tool_no_toolset = MagicMock()
tool_no_toolset.name = "generic_tool"
tool_no_toolset.metadata = {}
-
+
all_tools = [tool_rancher_core, tool_rancher_extensions, tool_fleet, tool_no_toolset]
-
+
mock_client_instance = MagicMock()
mock_client_instance.get_tools = AsyncMock(return_value=all_tools)
mock_create_client.return_value = mock_client_instance
-
+
mock_parent_agent = MagicMock()
mock_create_parent.return_value = mock_parent_agent
mock_create_child.return_value = MagicMock()
-
+
# Execute
- result = await build_agent(mock_llm, mock_websocket)
-
+ result = await build_agent(mock_websocket)
+
# Verify - build_agent wraps multi-agent result in SupervisorGraph
from app.services.agent.supervisor import SupervisorGraph
assert isinstance(result[0], SupervisorGraph)
-
+
# Verify subagents passed to create_supervisor_agent have correct tools
call_args = mock_create_parent.call_args
subagents = call_args[0][1]
assert len(subagents) == 2
-
- # First subagent (RancherAgent) should have only rancher-core tools
+
assert subagents[0].config.name == "RancherAgent"
-
- # Second subagent (FleetAgent) should have all tools (no toolset filter)
assert subagents[1].config.name == "FleetAgent"
# Verify create_child_agent was called with filtered tools
@@ -214,63 +224,65 @@ async def test_build_agent_filters_tools_by_toolset(mock_update_status, mock_cre
@pytest.mark.asyncio
+@patch('app.services.agent.factory.LLMManager')
@patch('app.services.agent.factory.load_agent_configs')
@patch('app.services.agent.factory.create_supervisor_agent')
@patch('app.services.agent.factory.create_child_agent')
@patch('app.services.agent.factory.create_mcp_client')
@patch('app.services.agent.factory._update_agent_status')
-async def test_build_agent_one_fails_mcp_connection(mock_update_status, mock_create_client, mock_create_child, mock_create_parent, mock_load_configs):
+async def test_build_agent_one_fails_mcp_connection(mock_update_status, mock_create_client, mock_create_child, mock_create_parent, mock_load_configs, mock_llm_manager):
"""Verify build_agent handles MCP connection failure for one agent and continues with others."""
# Setup mocks
mock_llm = MagicMock()
+ mock_llm_manager.get_instance_for_agent.return_value = mock_llm
+ mock_llm_manager.get_instance_for_role.return_value = mock_llm
+
mock_websocket = MagicMock()
mock_memory_manager = MagicMock()
mock_checkpointer = MagicMock()
mock_memory_manager.get_checkpointer.return_value = mock_checkpointer
mock_websocket.app.memory_manager = mock_memory_manager
-
+
mock_config1 = MagicMock()
mock_config1.name = "Agent1"
mock_config1.description = "First agent"
mock_config1.system_prompt = "Prompt 1"
mock_config1.ready = False
-
+
mock_config2 = MagicMock()
mock_config2.name = "Agent2"
mock_config2.description = "Second agent"
mock_config2.system_prompt = "Prompt 2"
mock_config2.ready = False
-
+
mock_config3 = MagicMock()
mock_config3.name = "Agent3"
mock_config3.description = "Third agent"
mock_config3.system_prompt = "Prompt 3"
mock_config3.ready = False
-
+
mock_load_configs.return_value = [mock_config1, mock_config2, mock_config3]
-
+
# Mock MCP client - first one fails, others succeed
- # Create three different client instances
mock_client_fail = MagicMock()
mock_client_fail.get_tools = AsyncMock(side_effect=Exception("Connection refused: invalid MCP URL"))
-
+
mock_client_success1 = MagicMock()
mock_tools = [MagicMock()]
mock_client_success1.get_tools = AsyncMock(return_value=mock_tools)
-
+
mock_client_success2 = MagicMock()
mock_client_success2.get_tools = AsyncMock(return_value=mock_tools)
-
- # Return different clients on each call
+
mock_create_client.side_effect = [mock_client_fail, mock_client_success1, mock_client_success2]
-
+
mock_parent_agent = MagicMock()
mock_create_parent.return_value = mock_parent_agent
mock_create_child.return_value = MagicMock()
-
+
# Execute
- result = await build_agent(mock_llm, mock_websocket)
-
+ result = await build_agent(mock_websocket)
+
# Should return supervisor since 2 agents succeeded
from app.services.agent.supervisor import SupervisorGraph
assert isinstance(result[0], SupervisorGraph)
@@ -278,12 +290,11 @@ async def test_build_agent_one_fails_mcp_connection(mock_update_status, mock_cre
# Verify parent was called with correct subagents
call_args = mock_create_parent.call_args
- assert call_args[0][0] == mock_llm
subagents = call_args[0][1]
assert len(subagents) == 2
assert subagents[0].config.name == "Agent2"
assert subagents[1].config.name == "Agent3"
-
+
# Verify metadata includes all three agents with correct status
metadata = result[1]
assert len(metadata) == 3
@@ -294,51 +305,54 @@ async def test_build_agent_one_fails_mcp_connection(mock_update_status, mock_cre
assert metadata[1]["name"] == "Agent2"
assert metadata[2]["status"] == "active"
assert metadata[2]["name"] == "Agent3"
-
-
+
# Verify status update was called for the failed agent
update_calls = [call for call in mock_update_status.call_args_list if call[0][1] == False]
assert len(update_calls) > 0
@pytest.mark.asyncio
+@patch('app.services.agent.factory.LLMManager')
@patch('app.services.agent.factory.load_agent_configs')
@patch('app.services.agent.factory.create_mcp_client')
@patch('app.services.agent.factory._update_agent_status')
-async def test_build_agent_all_fail_mcp_connection(mock_update_status, mock_create_client, mock_load_configs):
+async def test_build_agent_all_fail_mcp_connection(mock_update_status, mock_create_client, mock_load_configs, mock_llm_manager):
"""Verify build_agent raises NoAgentAvailableError when all agents fail MCP connection."""
# Setup mocks
mock_llm = MagicMock()
+ mock_llm_manager.get_instance_for_agent.return_value = mock_llm
+ mock_llm_manager.get_instance_for_role.return_value = mock_llm
+
mock_websocket = MagicMock()
mock_memory_manager = MagicMock()
mock_checkpointer = MagicMock()
mock_memory_manager.get_checkpointer.return_value = mock_checkpointer
mock_websocket.app.memory_manager = mock_memory_manager
-
+
mock_config1 = MagicMock()
mock_config1.name = "Agent1"
mock_config1.description = "First agent"
mock_config1.ready = False
-
+
mock_config2 = MagicMock()
mock_config2.name = "Agent2"
mock_config2.description = "Second agent"
mock_config2.ready = False
-
+
mock_load_configs.return_value = [mock_config1, mock_config2]
-
+
# Mock MCP client - all fail
mock_client_fail = MagicMock()
mock_client_fail.get_tools = AsyncMock(side_effect=Exception("Connection refused: invalid MCP URL"))
-
+
mock_create_client.return_value = mock_client_fail
-
+
# Execute and verify exception
with pytest.raises(NoAgentAvailableError) as exc_info:
- await build_agent(mock_llm, mock_websocket)
-
+ await build_agent(mock_websocket)
+
assert "No agents could be created" in str(exc_info.value)
-
+
# Verify status update was called for both failed agents
assert mock_update_status.call_count >= 2
@@ -347,16 +361,15 @@ async def test_build_agent_all_fail_mcp_connection(mock_update_status, mock_crea
@patch('app.services.agent.factory.load_agent_configs')
async def test_build_agent_no_configs_raises_error(mock_load_configs):
"""Verify build_agent raises NoAgentAvailableError when no configs are available."""
- mock_llm = MagicMock()
mock_websocket = MagicMock()
mock_memory_manager = MagicMock()
mock_websocket.app.memory_manager = mock_memory_manager
-
+
mock_load_configs.return_value = []
-
+
with pytest.raises(NoAgentAvailableError) as exc_info:
- await build_agent(mock_llm, mock_websocket)
-
+ await build_agent(mock_websocket)
+
assert "No agent configurations available" in str(exc_info.value)