diff --git a/app/dependencies.py b/app/dependencies.py index 3759748f..a43eec00 100644 --- a/app/dependencies.py +++ b/app/dependencies.py @@ -3,12 +3,19 @@ def get_llm() -> BaseLanguageModel: """ - FastAPI dependency that provides the singleton LLM instance. - - This function is used as a dependency injection in FastAPI endpoints - to access the configured language model. - + FastAPI dependency that provides the default LLM instance. + + Returns: + The default language model instance. + """ + return LLMManager.get_instance() + + +def get_uitools_llm() -> BaseLanguageModel: + """ + FastAPI dependency that provides the UI tools LLM instance. + Returns: - The singleton language model instance. + The UI tools language model instance (falls back to default if not configured). """ - return LLMManager.get_instance() \ No newline at end of file + return LLMManager.get_instance_for_role("uitools") \ No newline at end of file diff --git a/app/routers/configuration.py b/app/routers/configuration.py index 6fdff95a..955d6585 100644 --- a/app/routers/configuration.py +++ b/app/routers/configuration.py @@ -396,8 +396,8 @@ async def update_settings(settings: SettingsUpdate, request: Request): secret.data = secret_data v1.patch_namespaced_secret(SETTINGS_SECRET_NAME, AGENT_NAMESPACE, secret) - # Reset LLMManager singleton to force reinitialization for consistency, but as of now we are redeploying the agent... - LLMManager._instance = None + # Reset LLMManager to force reinitialization for consistency, but as of now we are redeploying the agent... + LLMManager.reset() # Re-fetch both resources to return the updated content updated_cm = v1.read_namespaced_config_map(SETTINGS_CONFIGMAP_NAME, AGENT_NAMESPACE) diff --git a/app/routers/llm.py b/app/routers/llm.py index a7099a36..00f49859 100644 --- a/app/routers/llm.py +++ b/app/routers/llm.py @@ -5,7 +5,7 @@ from langchain_core.language_models.llms import BaseLanguageModel from ..services.auth import get_user_id_from_request -from ..dependencies import get_llm +from ..dependencies import get_uitools_llm from ..services.ui_tools.selector import create_ui_tools_selector router = APIRouter(prefix="/v1/api/complete", tags=["llm"]) @@ -31,7 +31,7 @@ class UIToolsRequest(LLMRequest): async def complete_ui_tools( request: Request, ui_tools_request: UIToolsRequest, - llm: BaseLanguageModel = Depends(get_llm) + llm: BaseLanguageModel = Depends(get_uitools_llm) ) -> JSONResponse: """ Select appropriate UI tools based on the provided context using the LLM. diff --git a/app/routers/websocket.py b/app/routers/websocket.py index 0b5f43a5..046259e7 100644 --- a/app/routers/websocket.py +++ b/app/routers/websocket.py @@ -4,16 +4,14 @@ import json from datetime import datetime -from ..dependencies import get_llm from ..services.agent.factory import NoAgentAvailableError, build_agent from ..services.agent.supervisor import SupervisorGraph from dataclasses import dataclass from fastapi import APIRouter -from fastapi import WebSocket, WebSocketDisconnect, Depends +from fastapi import WebSocket, WebSocketDisconnect from starlette.websockets import WebSocketState from langgraph.graph.state import CompiledStateGraph from langfuse.langchain import CallbackHandler -from langchain_core.language_models.llms import BaseLanguageModel from langchain_core.messages import HumanMessage from langgraph.types import Command @@ -59,7 +57,7 @@ class WebSocketRequest: @router.websocket("/v1/ws/messages") @router.websocket("/v1/ws/messages/{thread_id}") -async def websocket_endpoint(websocket: WebSocket, thread_id: str = None, llm: BaseLanguageModel = Depends(get_llm)): +async def websocket_endpoint(websocket: WebSocket, thread_id: str = None): """ WebSocket endpoint for the agent. @@ -77,7 +75,7 @@ async def websocket_endpoint(websocket: WebSocket, thread_id: str = None, llm: B logging.debug("ws/messages connection opened") try: - agent, agents_metadata = await build_agent(llm=llm, websocket=websocket) + agent, agents_metadata = await build_agent(websocket=websocket) except NoAgentAvailableError as e: logging.error(f"Error creating agent: {e}") await websocket.send_text(f'{json.dumps({"message": str(e)})}') diff --git a/app/services/agent/child.py b/app/services/agent/child.py index 4c67f287..53d958bd 100644 --- a/app/services/agent/child.py +++ b/app/services/agent/child.py @@ -40,6 +40,8 @@ def create_child_agent( system_prompt: str, checkpointer: Checkpointer, agent_config: AgentConfig, + summary_llm: BaseChatModel | None = None, + uitools_llm: BaseChatModel | None = None, ) -> CompiledStateGraph: """Create and compile a child agent graph using langchain create_agent with middleware. @@ -57,8 +59,8 @@ def create_child_agent( identity_preamble_middleware(), cancel_human_validation_middleware(), inject_additional_kwargs_middleware(), - ui_tools_middleware(llm, only_when_direct=True), - SummarizationMiddleware(model=llm, trigger=[("messages", 30), ("tokens", 30000)], keep=("messages", 15)), + ui_tools_middleware(uitools_llm or llm, only_when_direct=True), + SummarizationMiddleware(model=summary_llm or llm, trigger=[("messages", 30), ("tokens", 30000)], keep=("messages", 15)), ] return create_agent( diff --git a/app/services/agent/factory.py b/app/services/agent/factory.py index 7fe69d92..3877f033 100644 --- a/app/services/agent/factory.py +++ b/app/services/agent/factory.py @@ -12,6 +12,7 @@ from langchain_core.language_models.llms import BaseLanguageModel from langchain_mcp_adapters.client import MultiServerMCPClient from langgraph.graph.state import Checkpointer, CompiledStateGraph +from ..llm import LLMManager NAMESPACE = "cattle-ai-agent-system" @@ -20,16 +21,16 @@ class NoAgentAvailableError(Exception): pass -async def build_agent(llm: BaseLanguageModel, websocket: WebSocket) -> tuple[CompiledStateGraph | SupervisorGraph, list[dict]]: +async def build_agent(websocket: WebSocket) -> tuple[CompiledStateGraph | SupervisorGraph, list[dict]]: """ Build an agent graph from AIAgentConfig CRDs. Loads agent configurations and creates either a supervisor (multi-agent) or a single child agent depending on how many configs are available and successfully - connect to their MCP servers. + connect to their MCP servers. LLM instances are resolved per-role and per-agent + via LLMManager. Args: - llm: The language model to use for agent reasoning. websocket: WebSocket connection for authentication context. Returns: @@ -47,15 +48,21 @@ async def build_agent(llm: BaseLanguageModel, websocket: WebSocket) -> tuple[Com logging.info(f"Loaded {len(agent_configs)} agent configuration(s)") + summary_llm = LLMManager.get_instance_for_role("summary") + uitools_llm = LLMManager.get_instance_for_role("uitools") + if len(agent_configs) == 1: agent_cfg = agent_configs[0] + agent_llm = LLMManager.get_instance_for_agent(agent_cfg.name, agent_cfg.llm, agent_cfg.llm_model) tools = await _load_mcp_tools(agent_cfg, websocket) - graph = create_child_agent(llm, tools, agent_cfg.system_prompt, checkpointer, agent_cfg) + graph = create_child_agent(agent_llm, tools, agent_cfg.system_prompt, checkpointer, agent_cfg, + summary_llm=summary_llm, uitools_llm=uitools_llm) return graph, [{"name": agent_cfg.name, "status": "active"}] # Multi-agent setup logging.info(f"Multi-agent setup detected, creating supervisor with {len(agent_configs)} agents.") - child_agents, agents_metadata = await _build_child_agents(llm, agent_configs, checkpointer, websocket) + child_agents, agents_metadata = await _build_child_agents(agent_configs, checkpointer, websocket, + summary_llm=summary_llm, uitools_llm=uitools_llm) if not child_agents: logging.error("Failed to create any child agents due to MCP connection issues") @@ -67,7 +74,9 @@ async def build_agent(llm: BaseLanguageModel, websocket: WebSocket) -> tuple[Com logging.warning("Only one child agent connected successfully. Using it directly instead of a supervisor.") return child_agents[0].agent, agents_metadata - graph = create_supervisor_agent(llm, child_agents, checkpointer) + supervisor_llm = LLMManager.get_instance_for_role("supervisor") + graph = create_supervisor_agent(supervisor_llm, child_agents, checkpointer, + summary_llm=summary_llm, uitools_llm=uitools_llm) supervisor = SupervisorGraph( graph=graph, child_agents={ca.config.name: ca.agent for ca in child_agents}, @@ -76,10 +85,11 @@ async def build_agent(llm: BaseLanguageModel, websocket: WebSocket) -> tuple[Com async def _build_child_agents( - llm: BaseLanguageModel, agent_configs: list[AgentConfig], checkpointer: Checkpointer, websocket: WebSocket, + summary_llm: BaseLanguageModel | None = None, + uitools_llm: BaseLanguageModel | None = None, ) -> tuple[list[ChildAgent], list[dict]]: """ Attempt to build a child agent for each config, collecting successes and failures. @@ -92,10 +102,12 @@ async def _build_child_agents( for agent_cfg in agent_configs: try: + agent_llm = LLMManager.get_instance_for_agent(agent_cfg.name, agent_cfg.llm, agent_cfg.llm_model) tools = await _load_mcp_tools(agent_cfg, websocket) child_agents.append(ChildAgent( config=agent_cfg, - agent=create_child_agent(llm, tools, agent_cfg.system_prompt, checkpointer, agent_cfg), + agent=create_child_agent(agent_llm, tools, agent_cfg.system_prompt, checkpointer, agent_cfg, + summary_llm=summary_llm, uitools_llm=uitools_llm), )) agents_metadata.append({"name": agent_cfg.name, "status": "active"}) except (NoAgentAvailableError, Exception) as e: diff --git a/app/services/agent/loader.py b/app/services/agent/loader.py index dcdcda9b..db53ec15 100644 --- a/app/services/agent/loader.py +++ b/app/services/agent/loader.py @@ -175,6 +175,8 @@ class AgentConfig(BaseModel): toolset: Optional[str] = None human_validation_tools: list[str] = [] ready: bool = False + llm: Optional[str] = None + llm_model: Optional[str] = None def _load_k8s_config(): @@ -319,7 +321,9 @@ def _crd_to_agent_config(crd_obj: dict) -> AgentConfig: ca_bundle_ref=CABundleRef(**spec["caBundleRef"]) if spec.get("caBundleRef") else None, toolset=spec.get("toolSet", None), human_validation_tools=human_validation_tools, - ready=status.get("phase", "Failed") == "Ready" + ready=status.get("phase", "Failed") == "Ready", + llm=spec.get("llm", None), + llm_model=spec.get("llmModel", None), ) diff --git a/app/services/agent/supervisor.py b/app/services/agent/supervisor.py index 17bb5700..b0e702f3 100644 --- a/app/services/agent/supervisor.py +++ b/app/services/agent/supervisor.py @@ -94,6 +94,8 @@ def create_supervisor_agent( llm: BaseChatModel, child_agents: list[ChildAgent], checkpointer: Checkpointer, + summary_llm: BaseChatModel | None = None, + uitools_llm: BaseChatModel | None = None, ) -> CompiledStateGraph: """ Creates a supervisor agent that coordinates multiple child agents as tools. @@ -106,6 +108,8 @@ def create_supervisor_agent( llm: The language model instance to use for the supervisor agent. child_agents: List of child agents read from AIAgentConfig CRDs. checkpointer: Checkpointer for persisting agent state. + summary_llm: Optional LLM for summarization middleware. Falls back to llm. + uitools_llm: Optional LLM for UI tools middleware. Falls back to llm. Returns: A compiled LangGraph StateGraph ready to be invoked. @@ -128,8 +132,8 @@ def create_supervisor_agent( MessagesHistoryMiddleware(), cancel_human_validation_middleware(), inject_additional_kwargs_middleware(), - ui_tools_middleware(llm), - SummarizationMiddleware(model=llm, trigger=[("messages", 30), ("tokens", 30000)], keep=("messages", 15)), + ui_tools_middleware(uitools_llm or llm), + SummarizationMiddleware(model=summary_llm or llm, trigger=[("messages", 30), ("tokens", 30000)], keep=("messages", 15)), ], ) diff --git a/app/services/llm.py b/app/services/llm.py index d27d67ee..4fb18ca6 100644 --- a/app/services/llm.py +++ b/app/services/llm.py @@ -1,5 +1,6 @@ import os import logging +from typing import Optional from langchain_ollama import ChatOllama from langchain_google_genai import ChatGoogleGenerativeAI @@ -7,60 +8,63 @@ from langchain_core.language_models.llms import BaseLanguageModel from langchain_aws import ChatBedrockConverse +VALID_PROVIDERS = ("ollama", "gemini", "openai", "bedrock") + +ROLE_ENV_VARS = { + "supervisor": ("SUPERVISOR_LLM", "SUPERVISOR_MODEL"), + "uitools": ("UITOOLS_LLM", "UITOOLS_MODEL"), + "summary": ("SUMMARY_LLM", "SUMMARY_MODEL"), +} + + class LLMManager: - """ - Singleton manager for language model instances. - - This class ensures that only one instance of the language model is created - and reused throughout the application, avoiding redundant initializations - and ensuring consistent model configuration. - """ - _instance: BaseLanguageModel = None + _instances: dict[str, BaseLanguageModel] = {} + + @classmethod + def get_instance(cls, key: str = "default") -> BaseLanguageModel: + if key not in cls._instances: + cls._instances[key] = get_llm() + return cls._instances[key] + + @classmethod + def get_instance_for_role(cls, role: str) -> BaseLanguageModel: + if role in cls._instances: + return cls._instances[role] + llm = get_llm_for_role(role) + cls._instances[role] = llm + return llm + + @classmethod + def get_instance_for_agent(cls, agent_name: str, llm_provider: Optional[str], llm_model: Optional[str]) -> BaseLanguageModel: + key = f"agent:{agent_name}" + if key in cls._instances: + return cls._instances[key] + llm = get_llm_for_agent(llm_provider, llm_model) + cls._instances[key] = llm + return llm @classmethod - def get_instance(cls) -> BaseLanguageModel: - """ - Retrieves the singleton instance of the language model. - - If the instance doesn't exist yet, it initializes it by calling get_llm(). - Subsequent calls return the same instance. - - Returns: - The singleton language model instance. - """ - if cls._instance is None: - cls._instance = get_llm() - logging.info(f"Using model: {cls._instance}") - return cls._instance - -def get_llm() -> BaseLanguageModel: - """ - Selects and returns a language model instance based on environment variables. - - If the active LLM or the model is not configured, it raises a ValueError. - - If LLM mocking is enabled, it configures the connections to the mock server. - - Returns: - An instance of a language model. - - Raises: - ValueError: If the active LLM or the model is not configured. - """ - - activeLlm = get_active_llm() - model = get_llm_model(activeLlm) - + def reset(cls): + cls._instances = {} + + +def get_llm(llm_provider: Optional[str] = None, model: Optional[str] = None) -> BaseLanguageModel: + if llm_provider is None: + llm_provider = get_active_llm() + if model is None: + model = get_llm_model(llm_provider) + llm_mock_enabled = os.environ.get("LLM_MOCK_ENABLED", "false").lower() == "true" llm_mock_url = os.environ.get("LLM_MOCK_URL", "") if llm_mock_enabled: logging.info(f"Connecting to LLM Mock server at {llm_mock_url}") - if activeLlm == "ollama": + if llm_provider == "ollama": if llm_mock_enabled: return ChatOllama(model=model, base_url=llm_mock_url) - ollama_url = os.environ.get("OLLAMA_URL") return ChatOllama(model=model, base_url=ollama_url) - if activeLlm == "gemini": + if llm_provider == "gemini": if llm_mock_enabled: return ChatGoogleGenerativeAI( model=model, @@ -68,55 +72,62 @@ def get_llm() -> BaseLanguageModel: transport="rest" ) if model == "gemini-2.5-flash": - # Disable thinking budget for gemini-2.5-flash to avoid empty responses due to all tokens being used for thinking budget return ChatGoogleGenerativeAI(model=model, thinking_budget=0) - return ChatGoogleGenerativeAI(model=model) - if activeLlm == "openai": + if llm_provider == "openai": if llm_mock_enabled: return ChatOpenAI(model=model, base_url=llm_mock_url) - openai_url = os.environ.get("OPENAI_URL") if openai_url: return ChatOpenAI(model=model, base_url=openai_url) return ChatOpenAI(model=model) - if activeLlm == "bedrock": + if llm_provider == "bedrock": if llm_mock_enabled: os.environ["AWS_ENDPOINT_URL"] = llm_mock_url return ChatBedrockConverse(model=model) + raise ValueError(f"Unsupported LLM provider: {llm_provider}") + + +def get_llm_for_role(role: str) -> BaseLanguageModel: + env_llm, env_model = ROLE_ENV_VARS.get(role, (None, None)) + if not env_llm: + raise ValueError(f"Unknown role: {role}") + + provider = os.environ.get(env_llm, "").strip() or None + model = os.environ.get(env_model, "").strip() or None + + if provider and provider not in VALID_PROVIDERS: + raise ValueError(f"Invalid LLM provider '{provider}' for role '{role}'.") + + if provider and model: + return get_llm(llm_provider=provider, model=model) + return get_llm() + + +def get_llm_for_agent(llm_provider: Optional[str], llm_model: Optional[str]) -> BaseLanguageModel: + provider = (llm_provider or "").strip() or None + model = (llm_model or "").strip() or None + + if provider and provider not in VALID_PROVIDERS: + raise ValueError(f"Invalid LLM provider '{provider}' for agent.") + + if provider and model: + return get_llm(llm_provider=provider, model=model) + return get_llm() + + def get_active_llm() -> str: - """ - Retrieves the active LLM identifier from environment variables. - - Returns: - The active LLM as a string, or None if not set. - """ llm = os.environ.get("ACTIVE_LLM", "") - - if llm not in ["ollama", "gemini", "openai", "bedrock"]: + if llm not in VALID_PROVIDERS: raise ValueError("LLM not configured.") - return llm -def get_llm_model(llm: str) -> str: - """ - Retrieves the model name from environment variables. - - Args: - llm: The LLM identifier, one of 'ollama', 'gemini', 'openai', 'bedrock'. - - Returns: - The model name as a string. - """ +def get_llm_model(llm: str) -> str: model = None - if llm: model = os.environ.get(f"{llm.upper()}_MODEL") - if not model: raise ValueError("LLM Model not configured.") - return model - diff --git a/chart/agent/templates/ai-agent-deployment.yaml b/chart/agent/templates/ai-agent-deployment.yaml index 152659ae..51759d9d 100644 --- a/chart/agent/templates/ai-agent-deployment.yaml +++ b/chart/agent/templates/ai-agent-deployment.yaml @@ -103,11 +103,41 @@ spec: name: llm-secret key: GOOGLE_API_KEY optional: true - - name: ACTIVE_LLM + - name: SUPERVISOR_LLM valueFrom: configMapKeyRef: name: llm-config - key: ACTIVE_LLM + key: SUPERVISOR_LLM + optional: true + - name: UITOOLS_LLM + valueFrom: + configMapKeyRef: + name: llm-config + key: UITOOLS_LLM + optional: true + - name: SUMMARY_LLM + valueFrom: + configMapKeyRef: + name: llm-config + key: SUMMARY_LLM + optional: true + - name: SUPERVISOR_MODEL + valueFrom: + configMapKeyRef: + name: llm-config + key: SUPERVISOR_MODEL + optional: true + - name: UITOOLS_MODEL + valueFrom: + configMapKeyRef: + name: llm-config + key: UITOOLS_MODEL + optional: true + - name: SUMMARY_MODEL + valueFrom: + configMapKeyRef: + name: llm-config + key: SUMMARY_MODEL optional: true - name: LANGFUSE_PUBLIC_KEY valueFrom: diff --git a/chart/agent/templates/crds/ai.cattle.io_aiagentconfigs.yaml b/chart/agent/templates/crds/ai.cattle.io_aiagentconfigs.yaml index 14083219..e7e72531 100644 --- a/chart/agent/templates/crds/ai.cattle.io_aiagentconfigs.yaml +++ b/chart/agent/templates/crds/ai.cattle.io_aiagentconfigs.yaml @@ -3,7 +3,7 @@ apiVersion: apiextensions.k8s.io/v1 kind: CustomResourceDefinition metadata: annotations: - controller-gen.kubebuilder.io/version: v0.17.1 + controller-gen.kubebuilder.io/version: v0.21.0 name: aiagentconfigs.ai.cattle.io spec: group: ai.cattle.io @@ -66,17 +66,19 @@ spec: description: BuiltIn indicates if this is a built-in agent configuration type: boolean caBundleRef: - description: CABundleRef references a secret key containing a - PEM-encoded CA certificate bundle to trust when connecting to - this agent's MCP server + description: |- + CABundleRef references a secret key containing a PEM-encoded CA + certificate bundle to trust when connecting to this agent's MCP server. + The CA is scoped to this agent's connection only. properties: - name: - description: Name is the name of the secret in the agent namespace. - type: string key: default: ca.crt - description: Key is the key inside the secret that holds the - data. Defaults to "ca.crt". + description: |- + Key is the key inside the secret that holds the data. + Defaults to "ca.crt". + type: string + name: + description: Name is the name of the secret in the agent namespace. type: string required: - name @@ -95,6 +97,13 @@ spec: items: type: string type: array + llm: + description: LLM specifies the LLM provider to be used by the agent + type: string + llmModel: + description: LLMModel specifies the language model to be used by the + agent + type: string mcpURL: description: MCPURL is the Model Context Protocol server URL type: string diff --git a/chart/agent/templates/llm-config.yaml b/chart/agent/templates/llm-config.yaml index ef2bb704..4b250127 100644 --- a/chart/agent/templates/llm-config.yaml +++ b/chart/agent/templates/llm-config.yaml @@ -9,3 +9,9 @@ data: OPENAI_MODEL: "{{ .Values.openaiLlmModel }}" BEDROCK_MODEL: "{{ .Values.bedrockLlmModel }}" ACTIVE_LLM: "{{ .Values.activeLlm }}" + SUPERVISOR_LLM: "{{ .Values.supervisorLlm }}" + SUPERVISOR_MODEL: "{{ .Values.supervisorModel }}" + UITOOLS_LLM: "{{ .Values.uitoolsLlm }}" + UITOOLS_MODEL: "{{ .Values.uitoolsModel }}" + SUMMARY_LLM: "{{ .Values.summaryLlm }}" + SUMMARY_MODEL: "{{ .Values.summaryModel }}" diff --git a/chart/agent/values.yaml b/chart/agent/values.yaml index 461b1300..3a3ff195 100644 --- a/chart/agent/values.yaml +++ b/chart/agent/values.yaml @@ -26,6 +26,12 @@ awsBedrock: bearerToken: region: activeLlm: # ollama, gemini, openai, bedrock +supervisorLlm: "" +supervisorModel: "" +uitoolsLlm: "" +uitoolsModel: "" +summaryLlm: "" +summaryModel: "" # Enable RAG with embedded rancher documentation rag: enabled: false diff --git a/crd-generation/api/v1alpha1/aiagentconfig_types.go b/crd-generation/api/v1alpha1/aiagentconfig_types.go index c05e3e90..6884ff45 100644 --- a/crd-generation/api/v1alpha1/aiagentconfig_types.go +++ b/crd-generation/api/v1alpha1/aiagentconfig_types.go @@ -53,6 +53,14 @@ type AIAgentConfigSpec struct { // The CA is scoped to this agent's connection only. // +optional CABundleRef *SecretKeyRef `json:"caBundleRef,omitempty"` + + // LLMModel specifies the language model to be used by the agent + // +optional + LLMModel string `json:"llmModel,omitempty"` + + // LLM specifies the LLM provider to be used by the agent + // +optional + LLM string `json:"llm,omitempty"` } // SecretKeyRef identifies a key within a Kubernetes secret. diff --git a/tests/integration/test_multi_agent.py b/tests/integration/test_multi_agent.py index 8b8a41e9..f4bbe7c0 100644 --- a/tests/integration/test_multi_agent.py +++ b/tests/integration/test_multi_agent.py @@ -258,7 +258,7 @@ def test_single_prompt(): fake_llm = FakeMessagesListChatModelWithTools(responses=fake_llm_responses) fake_llm.all_calls = [] - LLMManager._instance = fake_llm + LLMManager._instances["default"] = fake_llm try: with client.websocket_connect("/v1/ws/messages") as websocket: @@ -316,7 +316,7 @@ def test_single_prompt(): "Third call should include tool result from math-agent" finally: - LLMManager._instance = None + LLMManager.reset() def test_multiple_prompts(): @@ -355,7 +355,7 @@ def test_multiple_prompts(): fake_llm = FakeMessagesListChatModelWithTools(responses=fake_llm_responses) fake_llm.all_calls = [] - LLMManager._instance = fake_llm + LLMManager._instances["default"] = fake_llm try: with client.websocket_connect("/v1/ws/messages") as websocket: @@ -432,7 +432,7 @@ def test_multiple_prompts(): assert isinstance(sixth_call[7], ToolMessage) and sixth_call[7].content == fake_llm_response_2 finally: - LLMManager._instance = None + LLMManager.reset() def test_delegate_to_child_agent_with_tool(): @@ -461,7 +461,7 @@ def test_delegate_to_child_agent_with_tool(): fake_llm = FakeMessagesListChatModelWithTools(responses=fake_llm_responses) fake_llm.all_calls = [] - LLMManager._instance = fake_llm + LLMManager._instances["default"] = fake_llm try: with client.websocket_connect("/v1/ws/messages") as websocket: @@ -505,7 +505,7 @@ def test_delegate_to_child_agent_with_tool(): assert fake_llm.all_calls[3][0] == SystemMessage(content=SUPERVISOR_PROMPT) finally: - LLMManager._instance = None + LLMManager.reset() def test_delegate_to_child_agent_with_ui_tools(): @@ -547,7 +547,7 @@ def test_delegate_to_child_agent_with_ui_tools(): ] fake_llm = FakeMessagesListChatModelWithTools(responses=fake_llm_responses) - LLMManager._instance = fake_llm + LLMManager._instances["default"] = fake_llm try: # Create agent configs with UI tools enabled @@ -632,6 +632,6 @@ def test_delegate_to_child_agent_with_ui_tools(): mock_load_configmap.assert_called() finally: - LLMManager._instance = None + LLMManager.reset() diff --git a/tests/integration/test_single_agent.py b/tests/integration/test_single_agent.py index 4303569a..1d23cf24 100644 --- a/tests/integration/test_single_agent.py +++ b/tests/integration/test_single_agent.py @@ -98,7 +98,7 @@ def test_websocket_single_prompt(): fake_llm = FakeMessagesListChatModelWithTools(responses=fake_llm_responses) fake_llm.all_calls = [] # Reset call tracking - LLMManager._instance = fake_llm + LLMManager._instances["default"] = fake_llm try: messages = [] @@ -121,7 +121,7 @@ def test_websocket_single_prompt(): SystemMessage(content=IDENTITY_PREAMBLE), ], "First call should have system prompt and user message" finally: - LLMManager._instance = None + LLMManager.reset() def test_websocket_multiple_prompts(): """Tests multiple prompt-response interactions in sequence.""" @@ -137,7 +137,7 @@ def test_websocket_multiple_prompts(): fake_llm = FakeMessagesListChatModelWithTools(responses=fake_llm_responses) fake_llm.all_calls = [] # Reset call tracking - LLMManager._instance = fake_llm + LLMManager._instances["default"] = fake_llm try: messages = [] @@ -167,7 +167,7 @@ def test_websocket_multiple_prompts(): HumanMessage(content="fake prompt 2"), ], "Second call should include conversation history" finally: - LLMManager._instance = None + LLMManager.reset() def test_websocket_tool_call(): """Tests agent interaction with tool calling.""" @@ -187,7 +187,7 @@ def test_websocket_tool_call(): fake_llm = FakeMessagesListChatModelWithTools(responses=fake_llm_responses) fake_llm.all_calls = [] # Reset call tracking - LLMManager._instance = fake_llm + LLMManager._instances["default"] = fake_llm try: messages = [] @@ -217,7 +217,7 @@ def test_websocket_tool_call(): assert isinstance(second_call[3], AIMessage) and second_call[3].tool_calls[0]["name"] == "add", "Second call should have AI message with tool call" assert isinstance(second_call[4], ToolMessage) and second_call[4].content == "sum is 9", "Second call should have tool result" finally: - LLMManager._instance = None + LLMManager.reset() def test_conversation_history(): """Tests that conversation history is maintained across multiple prompts.""" @@ -252,7 +252,7 @@ def test_conversation_history(): fake_llm = FakeMessagesListChatModelWithTools(responses=fake_llm_responses) fake_llm.all_calls = [] # Reset call tracking - LLMManager._instance = fake_llm + LLMManager._instances["default"] = fake_llm try: messages = [] @@ -326,7 +326,7 @@ def test_conversation_history(): ], "Fifth call should include full conversation history" finally: - LLMManager._instance = None + LLMManager.reset() def test_websocket_with_ui_tools(): @@ -358,7 +358,7 @@ def test_websocket_with_ui_tools(): fake_llm = FakeMessagesListChatModelWithTools(responses=fake_llm_responses) fake_llm.all_calls = [] - LLMManager._instance = fake_llm + LLMManager._instances["default"] = fake_llm try: agent_config = AgentConfig( @@ -434,4 +434,4 @@ def test_websocket_with_ui_tools(): mock_load_configmap.assert_called(), "UI tools config should be loaded from ConfigMap" finally: - LLMManager._instance = None + LLMManager.reset() diff --git a/tests/unit/routers/test_websocket.py b/tests/unit/routers/test_websocket.py index fc294c07..374ff5d1 100644 --- a/tests/unit/routers/test_websocket.py +++ b/tests/unit/routers/test_websocket.py @@ -79,12 +79,10 @@ def mock_dependencies(): @pytest.mark.asyncio async def test_websocket_endpoint(mock_dependencies): mock_ws = MockWebSocket(messages=["test message"]) - mock_llm = MagicMock() - - await websocket_endpoint(mock_ws, None, mock_llm) + await websocket_endpoint(mock_ws, None) assert mock_ws.accepted - mock_dependencies["build_agent"].assert_called_once_with(llm=mock_llm, websocket=mock_ws) + mock_dependencies["build_agent"].assert_called_once_with(websocket=mock_ws) mock_dependencies["call_agent"].assert_awaited_once() call_kwargs = mock_dependencies["call_agent"].call_args.kwargs @@ -97,9 +95,7 @@ async def test_websocket_endpoint_context_message(mock_dependencies): mock_ws = MockWebSocket(messages=[ '{"prompt": "show all pods", "context": {"namespace": "default", "cluster": "local"}}' ]) - mock_llm = MagicMock() - - await websocket_endpoint(mock_ws, None, mock_llm) + await websocket_endpoint(mock_ws, None) mock_dependencies["build_agent"].assert_called_once() mock_dependencies["call_agent"].assert_awaited_once() @@ -115,9 +111,7 @@ async def test_websocket_endpoint_context_message(mock_dependencies): @pytest.mark.asyncio async def test_websocket_endpoint_sends_chat_metadata(mock_dependencies): mock_ws = MockWebSocket(messages=["hello"]) - mock_llm = MagicMock() - - await websocket_endpoint(mock_ws, None, mock_llm) + await websocket_endpoint(mock_ws, None) # First message sent should be chat metadata metadata_msg = mock_ws._send_queue[0] @@ -131,9 +125,7 @@ async def test_websocket_endpoint_no_agent_available(mock_dependencies): from app.services.agent.factory import NoAgentAvailableError mock_dependencies["build_agent"].side_effect = NoAgentAvailableError("No agents") mock_ws = MockWebSocket(messages=[]) - mock_llm = MagicMock() - - await websocket_endpoint(mock_ws, None, mock_llm) + await websocket_endpoint(mock_ws, None) assert mock_ws.accepted assert mock_ws.closed @@ -145,9 +137,7 @@ async def test_websocket_endpoint_no_agent_available(mock_dependencies): @pytest.mark.asyncio async def test_websocket_endpoint_with_thread_id(mock_dependencies): mock_ws = MockWebSocket(messages=["hello"]) - mock_llm = MagicMock() - - await websocket_endpoint(mock_ws, "custom-thread-id", mock_llm) + await websocket_endpoint(mock_ws, "custom-thread-id") assert mock_ws.accepted metadata_msg = mock_ws._send_queue[0] @@ -159,9 +149,7 @@ async def test_websocket_endpoint_error_during_processing(mock_dependencies): """Test that errors during message processing are sent to the client.""" mock_dependencies["call_agent"].side_effect = Exception("Something went wrong") mock_ws = MockWebSocket(messages=["test", "second"]) - mock_llm = MagicMock() - - await websocket_endpoint(mock_ws, None, mock_llm) + await websocket_endpoint(mock_ws, None) # Should have sent error and end message error_messages = [m for m in mock_ws._send_queue if "" in m] diff --git a/tests/unit/services/agent/test_factory.py b/tests/unit/services/agent/test_factory.py index 5011db63..bae060b7 100644 --- a/tests/unit/services/agent/test_factory.py +++ b/tests/unit/services/agent/test_factory.py @@ -23,102 +23,112 @@ # ============================================================================ @pytest.mark.asyncio +@patch('app.services.agent.factory.LLMManager') @patch('app.services.agent.factory.load_agent_configs') @patch('app.services.agent.factory._load_mcp_tools') @patch('app.services.agent.factory.create_child_agent') -async def test_create_agent_single_agent(mock_create_child, mock_load_tools, mock_load_configs): +async def test_create_agent_single_agent(mock_create_child, mock_load_tools, mock_load_configs, mock_llm_manager): """Verify build_agent creates a child agent when one config is available.""" # Setup mocks mock_llm = MagicMock() + mock_summary_llm = MagicMock() + mock_uitools_llm = MagicMock() + mock_llm_manager.get_instance_for_agent.return_value = mock_llm + mock_llm_manager.get_instance_for_role.side_effect = lambda role: mock_summary_llm if role == "summary" else mock_uitools_llm + mock_websocket = MagicMock() mock_memory_manager = MagicMock() mock_checkpointer = MagicMock() mock_memory_manager.get_checkpointer.return_value = mock_checkpointer mock_websocket.app.memory_manager = mock_memory_manager - + mock_agent_config = MagicMock() mock_agent_config.name = "RancherAgent" mock_agent_config.system_prompt = "Test prompt" mock_load_configs.return_value = [mock_agent_config] - + mock_tools = [MagicMock()] mock_load_tools.return_value = mock_tools mock_agent = MagicMock() mock_create_child.return_value = mock_agent - + # Execute - result = await build_agent(mock_llm, mock_websocket) - + result = await build_agent(mock_websocket) + # Verify assert result[0] == mock_agent assert result[1] == [{"name": "RancherAgent", "status": "active"}] mock_load_tools.assert_called_once_with(mock_agent_config, mock_websocket) mock_create_child.assert_called_once_with( - mock_llm, mock_tools, "Test prompt", mock_checkpointer, mock_agent_config + mock_llm, mock_tools, "Test prompt", mock_checkpointer, mock_agent_config, + summary_llm=mock_summary_llm, uitools_llm=mock_uitools_llm ) @pytest.mark.asyncio +@patch('app.services.agent.factory.LLMManager') @patch('app.services.agent.factory.load_agent_configs') @patch('app.services.agent.factory.create_supervisor_agent') @patch('app.services.agent.factory.create_child_agent') @patch('app.services.agent.factory.create_mcp_client') @patch('app.services.agent.factory._update_agent_status') -async def test_build_agent_three_agents(mock_update_status, mock_create_client, mock_create_child, mock_create_parent, mock_load_configs): +async def test_build_agent_three_agents(mock_update_status, mock_create_client, mock_create_child, mock_create_parent, mock_load_configs, mock_llm_manager): """Verify build_agent creates a supervisor agent when three configs are available.""" # Setup mocks mock_llm = MagicMock() + mock_llm_manager.get_instance_for_agent.return_value = mock_llm + mock_llm_manager.get_instance_for_role.return_value = mock_llm + mock_websocket = MagicMock() mock_memory_manager = MagicMock() mock_checkpointer = MagicMock() mock_memory_manager.get_checkpointer.return_value = mock_checkpointer mock_websocket.app.memory_manager = mock_memory_manager - + mock_config1 = MagicMock() mock_config1.name = "RancherAgent" mock_config1.description = "Rancher core agent" mock_config1.system_prompt = "Prompt 1" - + mock_config2 = MagicMock() mock_config2.name = "FleetAgent" mock_config2.description = "Fleet agent" mock_config2.system_prompt = "Prompt 2" - + mock_config3 = MagicMock() mock_config3.name = "HarvesterAgent" mock_config3.description = "Harvester agent" mock_config3.system_prompt = "Prompt 3" - + mock_load_configs.return_value = [mock_config1, mock_config2, mock_config3] - + # Mock MCP client mock_client_instance = MagicMock() mock_tools = [MagicMock()] mock_client_instance.get_tools = AsyncMock(return_value=mock_tools) mock_create_client.return_value = mock_client_instance - + mock_parent_agent = MagicMock() mock_create_parent.return_value = mock_parent_agent mock_create_child.return_value = MagicMock() - + # Execute - result = await build_agent(mock_llm, mock_websocket) - + result = await build_agent(mock_websocket) + # Verify - build_agent wraps multi-agent result in SupervisorGraph from app.services.agent.supervisor import SupervisorGraph assert isinstance(result[0], SupervisorGraph) assert result[0]._graph == mock_parent_agent mock_create_parent.assert_called_once() - + # Verify parent was called with correct subagents call_args = mock_create_parent.call_args - assert call_args[0][0] == mock_llm subagents = call_args[0][1] assert len(subagents) == 3 assert subagents[0].config.name == "RancherAgent" assert subagents[1].config.name == "FleetAgent" assert subagents[2].config.name == "HarvesterAgent" - + # Verify metadata includes all agents metadata = result[1] assert len(metadata) == 3 @@ -126,79 +136,79 @@ async def test_build_agent_three_agents(mock_update_status, mock_create_client, @pytest.mark.asyncio +@patch('app.services.agent.factory.LLMManager') @patch('app.services.agent.factory.load_agent_configs') @patch('app.services.agent.factory.create_supervisor_agent') @patch('app.services.agent.factory.create_child_agent') @patch('app.services.agent.factory.create_mcp_client') @patch('app.services.agent.factory._update_agent_status') -async def test_build_agent_filters_tools_by_toolset(mock_update_status, mock_create_client, mock_create_child, mock_create_parent, mock_load_configs): +async def test_build_agent_filters_tools_by_toolset(mock_update_status, mock_create_client, mock_create_child, mock_create_parent, mock_load_configs, mock_llm_manager): """Verify build_agent filters tools based on toolset configuration.""" # Setup mocks mock_llm = MagicMock() + mock_llm_manager.get_instance_for_agent.return_value = mock_llm + mock_llm_manager.get_instance_for_role.return_value = mock_llm + mock_websocket = MagicMock() mock_memory_manager = MagicMock() mock_checkpointer = MagicMock() mock_memory_manager.get_checkpointer.return_value = mock_checkpointer mock_websocket.app.memory_manager = mock_memory_manager - + mock_config1 = MagicMock() mock_config1.name = "RancherAgent" mock_config1.description = "Rancher agent with specific toolset" mock_config1.system_prompt = "Prompt 1" - mock_config1.toolset = "rancher-core" # Specify toolset filter - + mock_config1.toolset = "rancher-core" + mock_config2 = MagicMock() mock_config2.name = "FleetAgent" mock_config2.description = "Fleet agent without toolset filter" mock_config2.system_prompt = "Prompt 2" - mock_config2.toolset = None # No toolset filter - + mock_config2.toolset = None + mock_load_configs.return_value = [mock_config1, mock_config2] - + # Mock MCP client with tools that have different toolsets - # Create mock tools with metadata tool_rancher_core = MagicMock() tool_rancher_core.name = "rancher_tool" tool_rancher_core.metadata = {"_meta": {"toolset": "rancher-core"}} - + tool_rancher_extensions = MagicMock() tool_rancher_extensions.name = "extensions_tool" tool_rancher_extensions.metadata = {"_meta": {"toolset": "rancher-extensions"}} - + tool_fleet = MagicMock() tool_fleet.name = "fleet_tool" tool_fleet.metadata = {"_meta": {"toolset": "fleet"}} - + tool_no_toolset = MagicMock() tool_no_toolset.name = "generic_tool" tool_no_toolset.metadata = {} - + all_tools = [tool_rancher_core, tool_rancher_extensions, tool_fleet, tool_no_toolset] - + mock_client_instance = MagicMock() mock_client_instance.get_tools = AsyncMock(return_value=all_tools) mock_create_client.return_value = mock_client_instance - + mock_parent_agent = MagicMock() mock_create_parent.return_value = mock_parent_agent mock_create_child.return_value = MagicMock() - + # Execute - result = await build_agent(mock_llm, mock_websocket) - + result = await build_agent(mock_websocket) + # Verify - build_agent wraps multi-agent result in SupervisorGraph from app.services.agent.supervisor import SupervisorGraph assert isinstance(result[0], SupervisorGraph) - + # Verify subagents passed to create_supervisor_agent have correct tools call_args = mock_create_parent.call_args subagents = call_args[0][1] assert len(subagents) == 2 - - # First subagent (RancherAgent) should have only rancher-core tools + assert subagents[0].config.name == "RancherAgent" - - # Second subagent (FleetAgent) should have all tools (no toolset filter) assert subagents[1].config.name == "FleetAgent" # Verify create_child_agent was called with filtered tools @@ -214,63 +224,65 @@ async def test_build_agent_filters_tools_by_toolset(mock_update_status, mock_cre @pytest.mark.asyncio +@patch('app.services.agent.factory.LLMManager') @patch('app.services.agent.factory.load_agent_configs') @patch('app.services.agent.factory.create_supervisor_agent') @patch('app.services.agent.factory.create_child_agent') @patch('app.services.agent.factory.create_mcp_client') @patch('app.services.agent.factory._update_agent_status') -async def test_build_agent_one_fails_mcp_connection(mock_update_status, mock_create_client, mock_create_child, mock_create_parent, mock_load_configs): +async def test_build_agent_one_fails_mcp_connection(mock_update_status, mock_create_client, mock_create_child, mock_create_parent, mock_load_configs, mock_llm_manager): """Verify build_agent handles MCP connection failure for one agent and continues with others.""" # Setup mocks mock_llm = MagicMock() + mock_llm_manager.get_instance_for_agent.return_value = mock_llm + mock_llm_manager.get_instance_for_role.return_value = mock_llm + mock_websocket = MagicMock() mock_memory_manager = MagicMock() mock_checkpointer = MagicMock() mock_memory_manager.get_checkpointer.return_value = mock_checkpointer mock_websocket.app.memory_manager = mock_memory_manager - + mock_config1 = MagicMock() mock_config1.name = "Agent1" mock_config1.description = "First agent" mock_config1.system_prompt = "Prompt 1" mock_config1.ready = False - + mock_config2 = MagicMock() mock_config2.name = "Agent2" mock_config2.description = "Second agent" mock_config2.system_prompt = "Prompt 2" mock_config2.ready = False - + mock_config3 = MagicMock() mock_config3.name = "Agent3" mock_config3.description = "Third agent" mock_config3.system_prompt = "Prompt 3" mock_config3.ready = False - + mock_load_configs.return_value = [mock_config1, mock_config2, mock_config3] - + # Mock MCP client - first one fails, others succeed - # Create three different client instances mock_client_fail = MagicMock() mock_client_fail.get_tools = AsyncMock(side_effect=Exception("Connection refused: invalid MCP URL")) - + mock_client_success1 = MagicMock() mock_tools = [MagicMock()] mock_client_success1.get_tools = AsyncMock(return_value=mock_tools) - + mock_client_success2 = MagicMock() mock_client_success2.get_tools = AsyncMock(return_value=mock_tools) - - # Return different clients on each call + mock_create_client.side_effect = [mock_client_fail, mock_client_success1, mock_client_success2] - + mock_parent_agent = MagicMock() mock_create_parent.return_value = mock_parent_agent mock_create_child.return_value = MagicMock() - + # Execute - result = await build_agent(mock_llm, mock_websocket) - + result = await build_agent(mock_websocket) + # Should return supervisor since 2 agents succeeded from app.services.agent.supervisor import SupervisorGraph assert isinstance(result[0], SupervisorGraph) @@ -278,12 +290,11 @@ async def test_build_agent_one_fails_mcp_connection(mock_update_status, mock_cre # Verify parent was called with correct subagents call_args = mock_create_parent.call_args - assert call_args[0][0] == mock_llm subagents = call_args[0][1] assert len(subagents) == 2 assert subagents[0].config.name == "Agent2" assert subagents[1].config.name == "Agent3" - + # Verify metadata includes all three agents with correct status metadata = result[1] assert len(metadata) == 3 @@ -294,51 +305,54 @@ async def test_build_agent_one_fails_mcp_connection(mock_update_status, mock_cre assert metadata[1]["name"] == "Agent2" assert metadata[2]["status"] == "active" assert metadata[2]["name"] == "Agent3" - - + # Verify status update was called for the failed agent update_calls = [call for call in mock_update_status.call_args_list if call[0][1] == False] assert len(update_calls) > 0 @pytest.mark.asyncio +@patch('app.services.agent.factory.LLMManager') @patch('app.services.agent.factory.load_agent_configs') @patch('app.services.agent.factory.create_mcp_client') @patch('app.services.agent.factory._update_agent_status') -async def test_build_agent_all_fail_mcp_connection(mock_update_status, mock_create_client, mock_load_configs): +async def test_build_agent_all_fail_mcp_connection(mock_update_status, mock_create_client, mock_load_configs, mock_llm_manager): """Verify build_agent raises NoAgentAvailableError when all agents fail MCP connection.""" # Setup mocks mock_llm = MagicMock() + mock_llm_manager.get_instance_for_agent.return_value = mock_llm + mock_llm_manager.get_instance_for_role.return_value = mock_llm + mock_websocket = MagicMock() mock_memory_manager = MagicMock() mock_checkpointer = MagicMock() mock_memory_manager.get_checkpointer.return_value = mock_checkpointer mock_websocket.app.memory_manager = mock_memory_manager - + mock_config1 = MagicMock() mock_config1.name = "Agent1" mock_config1.description = "First agent" mock_config1.ready = False - + mock_config2 = MagicMock() mock_config2.name = "Agent2" mock_config2.description = "Second agent" mock_config2.ready = False - + mock_load_configs.return_value = [mock_config1, mock_config2] - + # Mock MCP client - all fail mock_client_fail = MagicMock() mock_client_fail.get_tools = AsyncMock(side_effect=Exception("Connection refused: invalid MCP URL")) - + mock_create_client.return_value = mock_client_fail - + # Execute and verify exception with pytest.raises(NoAgentAvailableError) as exc_info: - await build_agent(mock_llm, mock_websocket) - + await build_agent(mock_websocket) + assert "No agents could be created" in str(exc_info.value) - + # Verify status update was called for both failed agents assert mock_update_status.call_count >= 2 @@ -347,16 +361,15 @@ async def test_build_agent_all_fail_mcp_connection(mock_update_status, mock_crea @patch('app.services.agent.factory.load_agent_configs') async def test_build_agent_no_configs_raises_error(mock_load_configs): """Verify build_agent raises NoAgentAvailableError when no configs are available.""" - mock_llm = MagicMock() mock_websocket = MagicMock() mock_memory_manager = MagicMock() mock_websocket.app.memory_manager = mock_memory_manager - + mock_load_configs.return_value = [] - + with pytest.raises(NoAgentAvailableError) as exc_info: - await build_agent(mock_llm, mock_websocket) - + await build_agent(mock_websocket) + assert "No agent configurations available" in str(exc_info.value)