diff --git a/python/api/mcp_server_get_detail.py b/python/api/mcp_server_get_detail.py index 0608aadc5b..cbd6568785 100644 --- a/python/api/mcp_server_get_detail.py +++ b/python/api/mcp_server_get_detail.py @@ -5,13 +5,12 @@ class McpServerGetDetail(ApiHandler): - async def process(self, input: dict[Any, Any], request: Request) -> dict[Any, Any] | Response: - - # try: - server_name = input.get("server_name") - if not server_name: - return {"success": False, "error": "Missing server_name"} - detail = MCPConfig.get_instance().get_server_detail(server_name) - return {"success": True, "detail": detail} - # except Exception as e: - # return {"success": False, "error": str(e)} + async def process( + self, input: dict[Any, Any], request: Request + ) -> dict[Any, Any] | Response: + server_name = input.get("server_name") + if not server_name: + return {"success": False, "error": "Missing server_name"} + project_name = input.get("project_name", "") + detail = MCPConfig.get_instance().get_server_detail(server_name, project_name) + return {"success": True, "detail": detail} diff --git a/python/api/project_mcp_servers.py b/python/api/project_mcp_servers.py new file mode 100644 index 0000000000..7c8c926bfe --- /dev/null +++ b/python/api/project_mcp_servers.py @@ -0,0 +1,74 @@ +import time +from python.helpers.api import ApiHandler, Request, Response +from typing import Any + +from python.helpers import projects, dirty_json +from python.helpers.settings import get_settings +from python.helpers.mcp_handler import MCPConfig + + +class ProjectMcpServers(ApiHandler): + async def process( + self, input: dict[Any, Any], request: Request + ) -> dict[Any, Any] | Response: + action = input.get("action", "") + project_name = input.get("project_name") + + if action == "list_global": + return self._list_global_servers() + + if not project_name: + return {"success": False, "error": "Missing project_name"} + + if action == "load": + return self._load_config(project_name) + elif action == "save": + config = input.get("config", "") + return self._save_config(project_name, config) + elif action == "apply": + config = input.get("config", "") + return self._apply_config(project_name, config) + elif action == "status": + return self._get_status(project_name) + else: + return {"success": False, "error": "Invalid action"} + + def _load_config(self, project_name: str) -> dict[str, Any]: + config = projects.load_project_mcp_servers(project_name) + return {"success": True, "config": config} + + def _save_config(self, project_name: str, config: str) -> dict[str, Any]: + projects.save_project_mcp_servers(project_name, config) + return {"success": True} + + def _apply_config(self, project_name: str, config: str) -> dict[str, Any]: + projects.save_project_mcp_servers(project_name, config) + return {"success": True, "status": []} + + def _get_status(self, project_name: str) -> dict[str, Any]: + status = self._get_server_status(project_name) + return {"success": True, "status": status} + + def _get_server_status(self, project_name: str) -> list[dict[str, Any]]: + mcp = MCPConfig.get_instance() + servers = mcp.get_project_servers(project_name) + result = [] + for server in servers: + tool_count = len(server.get_tools()) + error = server.get_error() + result.append( + { + "name": server.name, + "description": server.description, + "connected": tool_count > 0 and not error, + "error": error, + "tool_count": tool_count, + } + ) + return result + + def _list_global_servers(self) -> dict[str, Any]: + global_config_str = get_settings().get("mcp_servers", '{"mcpServers": {}}') + global_config = dirty_json.parse(global_config_str) or {"mcpServers": {}} + servers = global_config.get("mcpServers", {}) + return {"success": True, "servers": servers} diff --git a/python/extensions/system_prompt/_10_system_prompt.py b/python/extensions/system_prompt/_10_system_prompt.py index 9a17c0ad9a..b37add2bb2 100644 --- a/python/extensions/system_prompt/_10_system_prompt.py +++ b/python/extensions/system_prompt/_10_system_prompt.py @@ -7,12 +7,11 @@ class SystemPrompt(Extension): - async def execute( self, system_prompt: list[str] = [], loop_data: LoopData = LoopData(), - **kwargs: Any + **kwargs: Any, ): # append main system prompt and tools main = get_main_prompt(self.agent) @@ -44,13 +43,18 @@ def get_tools_prompt(agent: Agent): def get_mcp_tools_prompt(agent: Agent): mcp_config = MCPConfig.get_instance() - if mcp_config.servers: + project_name = projects.get_context_project_name(agent.context) or "" + + if project_name and project_name in mcp_config.project_servers: + has_servers = bool(mcp_config.project_servers[project_name]) + else: + has_servers = bool(mcp_config.servers) + + if has_servers: pre_progress = agent.context.log.progress - agent.context.log.set_progress( - "Collecting MCP tools" - ) # MCP might be initializing, better inform via progress bar - tools = MCPConfig.get_instance().get_tools_prompt() - agent.context.log.set_progress(pre_progress) # return original progress + agent.context.log.set_progress("Collecting MCP tools") + tools = mcp_config.get_tools_prompt(project_name=project_name) + agent.context.log.set_progress(pre_progress) return tools return "" diff --git a/python/helpers/mcp_handler.py b/python/helpers/mcp_handler.py index 1a16acb49e..f6f45174ee 100644 --- a/python/helpers/mcp_handler.py +++ b/python/helpers/mcp_handler.py @@ -58,7 +58,13 @@ def _determine_server_type(config_dict: dict) -> str: # First check if type is explicitly specified if "type" in config_dict: server_type = config_dict["type"].lower() - if server_type in ["sse", "http-stream", "streaming-http", "streamable-http", "http-streaming"]: + if server_type in [ + "sse", + "http-stream", + "streaming-http", + "streamable-http", + "http-streaming", + ]: return "MCPServerRemote" elif server_type == "stdio": return "MCPServerLocal" @@ -77,7 +83,12 @@ def _determine_server_type(config_dict: dict) -> str: def _is_streaming_http_type(server_type: str) -> bool: """Check if the server type is a streaming HTTP variant.""" - return server_type.lower() in ["http-stream", "streaming-http", "streamable-http", "http-streaming"] + return server_type.lower() in [ + "http-stream", + "streaming-http", + "streamable-http", + "http-streaming", + ] def initialize_mcp(mcp_servers_config: str): @@ -93,19 +104,25 @@ def initialize_mcp(mcp_servers_config: str): temp=False, ) - PrintStyle( - background_color="black", font_color="red", padding=True - ).print(f"Failed to update MCP settings: {e}") + PrintStyle(background_color="black", font_color="red", padding=True).print( + f"Failed to update MCP settings: {e}" + ) class MCPTool(Tool): - """MCP Tool wrapper""" - async def execute(self, **kwargs: Any): + from python.helpers import projects + + project_name = "" + try: + project_name = projects.get_context_project_name(self.agent.context) or "" + except Exception: + pass + error = "" try: response: CallToolResult = await MCPConfig.get_instance().call_tool( - self.name, kwargs + self.name, kwargs, project_name=project_name ) message = "\n\n".join( [item.text for item in response.content if item.type == "text"] @@ -219,7 +236,8 @@ class MCPServerRemote(BaseModel): init_timeout: int = Field(default=0) tool_timeout: int = Field(default=0) verify: bool = Field(default=True, description="Verify SSL certificates") - disabled: bool = Field(default=False) + # disabled: bool = Field(default=False) + disabled_tools: list[str] = Field(default_factory=list) __lock: ClassVar[threading.Lock] = PrivateAttr(default=threading.Lock()) __client: Optional["MCPClientRemote"] = PrivateAttr(default=None) @@ -269,14 +287,14 @@ def update(self, config: dict[str, Any]) -> "MCPServerRemote": "tool_timeout", "disabled", "verify", + "disabled_tools", # Added so it can be updated via config ]: if key == "name": value = normalize_name(value) if key == "serverUrl": - key = "url" # remap serverUrl to url + key = "url" setattr(self, key, value) - # We already run in an event loop, dont believe Pylance return asyncio.run(self.__on_update()) async def __on_update(self) -> "MCPServerRemote": @@ -299,6 +317,7 @@ class MCPServerLocal(BaseModel): tool_timeout: int = Field(default=0) verify: bool = Field(default=True, description="Verify SSL certificates") disabled: bool = Field(default=False) + disabled_tools: list[str] = Field(default_factory=list) # Added field __lock: ClassVar[threading.Lock] = PrivateAttr(default=threading.Lock()) __client: Optional["MCPClientLocal"] = PrivateAttr(default=None) @@ -349,6 +368,7 @@ def update(self, config: dict[str, Any]) -> "MCPServerLocal": "init_timeout", "tool_timeout", "disabled", + "disabled_tools", ]: if key == "name": value = normalize_name(value) @@ -373,6 +393,7 @@ async def __on_update(self) -> "MCPServerLocal": class MCPConfig(BaseModel): servers: list[MCPServer] = Field(default_factory=list) disconnected_servers: list[dict[str, Any]] = Field(default_factory=list) + project_servers: dict[str, list[MCPServer]] = Field(default_factory=dict) __lock: ClassVar[threading.Lock] = PrivateAttr(default=threading.Lock()) __instance: ClassVar[Any] = PrivateAttr(default=None) __initialized: ClassVar[bool] = PrivateAttr(default=False) @@ -394,11 +415,8 @@ def update(cls, config_str: str) -> Any: with cls.__lock: servers_data: List[Dict[str, Any]] = [] # Default to empty list - if ( - config_str and config_str.strip() - ): # Only parse if non-empty and not just whitespace + if config_str and config_str.strip(): try: - # Try with standard json.loads first, as it should handle escaped strings correctly parsed_value = dirty_json.try_parse(config_str) normalized = cls.normalize_config(parsed_value) @@ -423,9 +441,7 @@ def update(cls, config_str: str) -> Any: f"Error: Parsed MCP config (from json.loads) top-level structure is not a list. Config string was: '{config_str}'" ) # servers_data remains empty - except ( - Exception - ) as e_json: # Catch json.JSONDecodeError specifically if possible, or general Exception + except Exception as e_json: # Catch json.JSONDecodeError specifically if possible, or general Exception PrintStyle.error( f"Error parsing MCP config string: {e_json}. Config string was: '{config_str}'" ) @@ -528,11 +544,11 @@ def __init__(self, servers_list: List[Dict[str, Any]]): self.disconnected_servers = [] if not isinstance(servers_list, Iterable): - ( - PrintStyle( - background_color="grey", font_color="red", padding=True - ).print("MCPConfig::__init__::servers_list must be a list") - ) + # ( + # PrintStyle( + # background_color="grey", font_color="red", padding=True + # ).print("MCPConfig::__init__::servers_list must be a list") + # ) return for server_item in servers_list: @@ -610,11 +626,58 @@ def __init__(self, servers_list: List[Dict[str, Any]]): f"MCPConfig::__init__: Failed to create MCPServer '{server_name}': {error_msg}" ) ) - # add to failed servers self.disconnected_servers.append( {"config": server_item, "error": error_msg, "name": server_name} ) + def load_project_servers(self, project_name: str, config_str: str): + with self.__lock: + self.unload_project_servers_unlocked(project_name) + + if not config_str or not config_str.strip(): + return + + try: + parsed = dirty_json.try_parse(config_str) + servers_data = self.normalize_config(parsed) + except Exception: + return + + project_server_list: list[MCPServer] = [] + for server_item in servers_data: + if not isinstance(server_item, dict): + continue + if server_item.get("disabled", False): + continue + + server_name = server_item.get("name", "") + if not server_name: + continue + + try: + if server_item.get("url") or server_item.get("serverUrl"): + project_server_list.append(MCPServerRemote(server_item)) + else: + project_server_list.append(MCPServerLocal(server_item)) + except Exception as e: + PrintStyle( + background_color="grey", font_color="red", padding=True + ).print(f"Project MCP server '{server_name}' failed: {e}") + + self.project_servers[project_name] = project_server_list + + def unload_project_servers(self, project_name: str): + with self.__lock: + self.unload_project_servers_unlocked(project_name) + + def unload_project_servers_unlocked(self, project_name: str): + if project_name in self.project_servers: + del self.project_servers[project_name] + + def get_project_servers(self, project_name: str) -> list[MCPServer]: + with self.__lock: + return list(self.project_servers.get(project_name, [])) + def get_server_log(self, server_name: str) -> str: with self.__lock: for server in self.servers: @@ -664,10 +727,20 @@ def get_servers_status(self) -> list[dict[str, Any]]: return result - def get_server_detail(self, server_name: str) -> dict[str, Any]: + def get_server_detail( + self, server_name: str, project_name: str = "" + ) -> dict[str, Any]: with self.__lock: - for server in self.servers: - if server.name == server_name: + # When project is active, use ONLY project servers (they are separate from global) + # This matches the UI design: "These are separate from global MCP servers" + if project_name and project_name in self.project_servers: + all_servers = list(self.project_servers[project_name]) + else: + all_servers = list(self.servers) + + normalized_search_name = normalize_name(server_name) + for server in all_servers: + if server.name == normalized_search_name: try: tools = server.get_tools() except Exception: @@ -676,6 +749,7 @@ def get_server_detail(self, server_name: str) -> dict[str, Any]: "name": server.name, "description": server.description, "tools": tools, + "disabled_tools": list(server.disabled_tools), } return {} @@ -695,87 +769,149 @@ def get_tools(self) -> List[dict[str, dict[str, Any]]]: tools.append({f"{server.name}.{tool['name']}": tool_copy}) return tools - def get_tools_prompt(self, server_name: str = "") -> str: - """Get a prompt for all tools""" - - # just to wait for pending initialization + def get_tools_prompt(self, server_name: str = "", project_name: str = "") -> str: with self.__lock: pass + # When project is active, use ONLY project servers (they are separate from global) + # This matches the UI design: "These are separate from global MCP servers" + if project_name and project_name in self.project_servers: + all_servers = list(self.project_servers[project_name]) + else: + all_servers = list(self.servers) + prompt = '## "Remote (MCP Server) Agent Tools" available:\n\n' server_names = [] - for server in self.servers: + for server in all_servers: if not server_name or server.name == server_name: server_names.append(server.name) if server_name and server_name not in server_names: raise ValueError(f"Server {server_name} not found") - for server in self.servers: + for server in all_servers: if server.name in server_names: - server_name = server.name - prompt += f"### {server_name}\n" + current_server_name = server.name + prompt += f"### {current_server_name}\n" prompt += f"{server.description}\n" tools = server.get_tools() for tool in tools: prompt += ( - f"\n### {server_name}.{tool['name']}:\n" + f"\n### {current_server_name}.{tool['name']}:\n" f"{tool['description']}\n\n" - # f"#### Categories:\n" - # f"* kind: MCP Server Tool\n" - # f'* server: "{server_name}" ({server.description})\n\n' - # f"#### Arguments:\n" ) - input_schema = ( json.dumps(tool["input_schema"]) if tool["input_schema"] else "" ) - prompt += f"#### Input schema for tool_args:\n{input_schema}\n" - prompt += "\n" - prompt += ( f"#### Usage:\n" f"{{\n" - # f' "observations": ["..."],\n' # TODO: this should be a prompt file with placeholders f' "thoughts": ["..."],\n' - # f' "reflection": ["..."],\n' # TODO: this should be a prompt file with placeholders - f" \"tool_name\": \"{server_name}.{tool['name']}\",\n" + f' "tool_name": "{current_server_name}.{tool["name"]}",\n' f' "tool_args": !follow schema above\n' f"}}\n" ) - return prompt - def has_tool(self, tool_name: str) -> bool: - """Check if a tool is available""" + # for server in self.servers: + # if server.name in server_names: + # server_name = server.name + # prompt += f"### {server_name}\n" + # prompt += f"{server.description}\n" + # tools = server.get_tools() + + # for tool in tools: + # prompt += ( + # f"\n### {server_name}.{tool['name']}:\n" + # f"{tool['description']}\n\n" + # # f"#### Categories:\n" + # # f"* kind: MCP Server Tool\n" + # # f'* server: "{server_name}" ({server.description})\n\n' + # # f"#### Arguments:\n" + # ) + + # input_schema = ( + # json.dumps(tool["input_schema"]) if tool["input_schema"] else "" + # ) + + # prompt += f"#### Input schema for tool_args:\n{input_schema}\n" + + # prompt += "\n" + + # prompt += ( + # f"#### Usage:\n" + # f"{{\n" + # # f' "observations": ["..."],\n' # TODO: this should be a prompt file with placeholders + # f' "thoughts": ["..."],\n' + # # f' "reflection": ["..."],\n' # TODO: this should be a prompt file with placeholders + # f" \"tool_name\": \"{server_name}.{tool['name']}\",\n" + # f' "tool_args": !follow schema above\n' + # f"}}\n" + # ) + + # return prompt + + def has_tool(self, tool_name: str, project_name: str = "") -> bool: if "." not in tool_name: return False server_name_part, tool_name_part = tool_name.split(".") with self.__lock: - for server in self.servers: - if server.name == server_name_part: - return server.has_tool(tool_name_part) + # When project is active, use ONLY project servers (they are separate from global) + # This matches the UI design: "These are separate from global MCP servers" + if project_name and project_name in self.project_servers: + for server in self.project_servers[project_name]: + if server.name == server_name_part: + return server.has_tool(tool_name_part) + else: + for server in self.servers: + if server.name == server_name_part: + return server.has_tool(tool_name_part) return False def get_tool(self, agent: Any, tool_name: str) -> MCPTool | None: - if not self.has_tool(tool_name): + from python.helpers import projects + + project_name = "" + try: + project_name = projects.get_context_project_name(agent.context) or "" + except Exception: + pass + + if not self.has_tool(tool_name, project_name): return None - return MCPTool(agent=agent, name=tool_name, method=None, args={}, message="", loop_data=None) + return MCPTool( + agent=agent, + name=tool_name, + method=None, + args={}, + message="", + loop_data=None, + ) async def call_tool( - self, tool_name: str, input_data: Dict[str, Any] + self, tool_name: str, input_data: Dict[str, Any], project_name: str = "" ) -> CallToolResult: - """Call a tool with the given input data""" if "." not in tool_name: raise ValueError(f"Tool {tool_name} not found") server_name_part, tool_name_part = tool_name.split(".") with self.__lock: - for server in self.servers: - if server.name == server_name_part and server.has_tool(tool_name_part): - return await server.call_tool(tool_name_part, input_data) + # When project is active, use ONLY project servers (they are separate from global) + # This matches the UI design: "These are separate from global MCP servers" + if project_name and project_name in self.project_servers: + for server in self.project_servers[project_name]: + if server.name == server_name_part and server.has_tool( + tool_name_part + ): + return await server.call_tool(tool_name_part, input_data) + else: + for server in self.servers: + if server.name == server_name_part and server.has_tool( + tool_name_part + ): + return await server.call_tool(tool_name_part, input_data) raise ValueError(f"Tool {tool_name} not found") @@ -823,7 +959,6 @@ async def _execute_with_session( try: async with AsyncExitStack() as temp_stack: try: - stdio, write = await self._create_stdio_transport(temp_stack) # PrintStyle(font_color="cyan").print(f"MCPClientBase ({self.server.name} - {operation_name}): Transport created. Initializing session...") session = await temp_stack.enter_async_context( @@ -883,6 +1018,7 @@ async def list_tools_op(current_session: ClientSession): "input_schema": tool.inputSchema, } for tool in response.tools + if tool.name not in self.server.disabled_tools ] PrintStyle(font_color="green").print( f"MCPClientBase ({self.server.name}): Tools updated. Found {len(self.tools)} tools." @@ -926,11 +1062,14 @@ async def call_tool( self, tool_name: str, input_data: Dict[str, Any] ) -> CallToolResult: # PrintStyle(font_color="cyan").print(f"MCPClientBase ({self.server.name}): Preparing for 'call_tool' operation for tool '{tool_name}'.") + if tool_name in self.server.disabled_tools: + raise PermissionError(f"Tool '{tool_name}' is disabled by configuration.") + if not self.has_tool(tool_name): PrintStyle(font_color="orange").print( f"MCPClientBase ({self.server.name}): Tool '{tool_name}' not in cache for 'call_tool', refreshing tools..." ) - await self.update_tools() # This will use its own properly managed session + await self.update_tools() if not self.has_tool(tool_name): PrintStyle(font_color="red").print( f"MCPClientBase ({self.server.name}): Tool '{tool_name}' not found after refresh. Raising ValueError." @@ -944,19 +1083,16 @@ async def call_tool( async def call_tool_op(current_session: ClientSession): set = settings.get_settings() - # PrintStyle(font_color="cyan").print(f"MCPClientBase ({self.server.name}): Executing 'call_tool' for '{tool_name}' via MCP session...") response: CallToolResult = await current_session.call_tool( tool_name, input_data, read_timeout_seconds=timedelta(seconds=set["mcp_client_tool_timeout"]), ) - # PrintStyle(font_color="green").print(f"MCPClientBase ({self.server.name}): Tool '{tool_name}' call successful via session.") return response try: return await self._execute_with_session(call_tool_op) except Exception as e: - # Error logged by _execute_with_session. Re-raise a specific error for the caller. PrintStyle( background_color="#AA4455", font_color="white", padding=True ).print( @@ -967,7 +1103,6 @@ async def call_tool_op(current_session: ClientSession): ) def get_log(self): - # read and return lines from self.log_file, do not close it if not hasattr(self, "log_file") or self.log_file is None: return "" self.log_file.seek(0) @@ -980,7 +1115,6 @@ def get_log(self): class MCPClientLocal(MCPClientBase): def __del__(self): - # close the log file if it exists if hasattr(self, "log_file") and self.log_file is not None: try: self.log_file.close() @@ -994,7 +1128,6 @@ async def _create_stdio_transport( MemoryObjectReceiveStream[SessionMessage | Exception], MemoryObjectSendStream[SessionMessage], ]: - """Connect to an MCP server, init client and save stdio/write streams""" server: MCPServerLocal = cast(MCPServerLocal, self.server) if not server.command: @@ -1009,20 +1142,17 @@ async def _create_stdio_transport( encoding=server.encoding, encoding_error_handler=server.encoding_error_handler, ) - # create a custom error log handler that will capture error output import tempfile - # use a temporary file for error logging (text mode) if not already present if not hasattr(self, "log_file") or self.log_file is None: self.log_file = tempfile.TemporaryFile(mode="w+", encoding="utf-8") - # use the stdio_client with our error log file stdio_transport = await current_exit_stack.enter_async_context( stdio_client(server_params, errlog=self.log_file) ) - # do not read or close the file here, as stdio is async return stdio_transport + class CustomHTTPClientFactory(ABC): def __init__(self, verify: bool = True): self.verify = verify @@ -1033,32 +1163,25 @@ def __call__( timeout: httpx.Timeout | None = None, auth: httpx.Auth | None = None, ) -> httpx.AsyncClient: - # Set MCP defaults kwargs: dict[str, Any] = { "follow_redirects": True, } - - # Handle timeout if timeout is None: kwargs["timeout"] = httpx.Timeout(30.0) else: kwargs["timeout"] = timeout - - # Handle headers if headers is not None: kwargs["headers"] = headers - - # Handle authentication if auth is not None: kwargs["auth"] = auth return httpx.AsyncClient(**kwargs, verify=self.verify) -class MCPClientRemote(MCPClientBase): +class MCPClientRemote(MCPClientBase): def __init__(self, server: Union[MCPServerLocal, MCPServerRemote]): super().__init__(server) - self.session_id: Optional[str] = None # Track session ID for streaming HTTP clients + self.session_id: Optional[str] = None self.session_id_callback: Optional[Callable[[], Optional[str]]] = None async def _create_stdio_transport( @@ -1067,18 +1190,14 @@ async def _create_stdio_transport( MemoryObjectReceiveStream[SessionMessage | Exception], MemoryObjectSendStream[SessionMessage], ]: - """Connect to an MCP server, init client and save stdio/write streams""" server: MCPServerRemote = cast(MCPServerRemote, self.server) set = settings.get_settings() - # Use lower timeouts for faster failure detection init_timeout = min(server.init_timeout or set["mcp_client_init_timeout"], 5) tool_timeout = min(server.tool_timeout or set["mcp_client_tool_timeout"], 10) client_factory = CustomHTTPClientFactory(verify=server.verify) - # Check if this is a streaming HTTP type if _is_streaming_http_type(server.type): - # Use streamable HTTP client transport_result = await current_exit_stack.enter_async_context( streamablehttp_client( url=server.url, @@ -1088,15 +1207,10 @@ async def _create_stdio_transport( httpx_client_factory=client_factory, ) ) - # streamablehttp_client returns (read_stream, write_stream, get_session_id_callback) read_stream, write_stream, get_session_id_callback = transport_result - - # Store session ID callback for potential future use self.session_id_callback = get_session_id_callback - return read_stream, write_stream else: - # Use traditional SSE client (default behavior) stdio_transport = await current_exit_stack.enter_async_context( sse_client( url=server.url, @@ -1109,7 +1223,6 @@ async def _create_stdio_transport( return stdio_transport def get_session_id(self) -> Optional[str]: - """Get the current session ID if available (for streaming HTTP clients).""" if self.session_id_callback is not None: return self.session_id_callback() return None diff --git a/python/helpers/mcp_server.py b/python/helpers/mcp_server.py index 3c0308ed9c..255fc75647 100644 --- a/python/helpers/mcp_server.py +++ b/python/helpers/mcp_server.py @@ -43,6 +43,7 @@ class ToolResponse(BaseModel): description="The response from the remote Agent Zero Instance" ) chat_id: str = Field(description="The id of the chat this message belongs to.") + disabled_tools: list[str] = Field(default_factory=list) class ToolError(BaseModel): diff --git a/python/helpers/projects.py b/python/helpers/projects.py index 6e25738c6e..5721940f6e 100644 --- a/python/helpers/projects.py +++ b/python/helpers/projects.py @@ -25,9 +25,11 @@ class FileStructureInjectionSettings(TypedDict): max_lines: int gitignore: str + class SubAgentSettings(TypedDict): enabled: bool - + + class BasicProjectData(TypedDict): title: str description: str @@ -38,6 +40,7 @@ class BasicProjectData(TypedDict): ] # in the future we can add cutom and point to another existing folder file_structure: FileStructureInjectionSettings + class EditProjectData(BasicProjectData): name: str instruction_files_count: int @@ -47,7 +50,6 @@ class EditProjectData(BasicProjectData): subagents: dict[str, SubAgentSettings] - def get_projects_parent_folder(): return files.get_abs_path(PROJECTS_PARENT_DIR) @@ -61,9 +63,12 @@ def get_project_meta_folder(name: str, *sub_dirs: str): def delete_project(name: str): + from python.helpers.mcp_handler import MCPConfig + abs_path = files.get_abs_path(PROJECTS_PARENT_DIR, name) files.delete_dir(abs_path) deactivate_project_in_chats(name) + MCPConfig.get_instance().unload_project_servers(name) return name @@ -231,6 +236,7 @@ def _get_projects_list(parent_dir): def activate_project(context_id: str, name: str): from agent import AgentContext + from python.helpers.mcp_handler import MCPConfig data = load_edit_project_data(name) context = AgentContext.get(context_id) @@ -244,7 +250,9 @@ def activate_project(context_id: str, name: str): {"name": name, "title": display_name, "color": data.get("color", "")}, ) - # persist + mcp_config_str = load_project_mcp_servers(name) + MCPConfig.get_instance().load_project_servers(name, mcp_config_str) + persist_chat.save_tmp_chat(context) @@ -340,7 +348,7 @@ def save_project_subagents(name: str, subagents_data: dict[str, SubAgentSettings def _normalize_subagents( - subagents_data: dict[str, SubAgentSettings] + subagents_data: dict[str, SubAgentSettings], ) -> dict[str, SubAgentSettings]: from python.helpers import subagents @@ -405,25 +413,85 @@ def get_knowledge_files_count(name: str): ) return len(files.list_files_in_dir_recursively(knowledge_folder)) -def get_file_structure(name: str, basic_data: BasicProjectData|None=None) -> str: + +MCP_SERVERS_FILE = "mcp_servers.json" + + +def load_project_mcp_servers(name: str) -> str: + try: + abs_path = files.get_abs_path(get_project_meta_folder(name), MCP_SERVERS_FILE) + return files.read_file(abs_path) + except Exception: + return '{\n "mcpServers": {}\n}' + + +def save_project_mcp_servers(name: str, mcp_servers_json: str): + abs_path = files.get_abs_path(get_project_meta_folder(name), MCP_SERVERS_FILE) + files.write_file(abs_path, mcp_servers_json) + reload_project_mcp_servers(name) + + +def reload_project_mcp_servers(name: str): + import threading + + def _reload(): + from python.helpers.mcp_handler import MCPConfig + + mcp_config_str = load_project_mcp_servers(name) + MCPConfig.get_instance().load_project_servers(name, mcp_config_str) + + thread = threading.Thread(target=_reload, daemon=True) + thread.start() + + +def get_project_mcp_servers_parsed(name: str) -> dict: + try: + raw = load_project_mcp_servers(name) + return dirty_json.parse(raw) or {"mcpServers": {}} + except Exception: + return {"mcpServers": {}} + + +def import_global_mcp_server_to_project(project_name: str, server_name: str) -> dict: + from python.helpers.mcp_handler import MCPConfig + from python.helpers.settings import get_settings + + global_config_str = get_settings().get("mcp_servers", '{"mcpServers": {}}') + global_config = dirty_json.parse(global_config_str) or {"mcpServers": {}} + + global_servers = global_config.get("mcpServers", {}) + if server_name not in global_servers: + raise ValueError(f"Server '{server_name}' not found in global MCP config") + + project_config = get_project_mcp_servers_parsed(project_name) + if "mcpServers" not in project_config: + project_config["mcpServers"] = {} + + project_config["mcpServers"][server_name] = global_servers[server_name].copy() + + save_project_mcp_servers(project_name, dirty_json.stringify(project_config)) + return project_config + + +def get_file_structure(name: str, basic_data: BasicProjectData | None = None) -> str: project_folder = get_project_folder(name) if basic_data is None: basic_data = load_basic_project_data(name) - - tree = str(file_tree.file_tree( - project_folder, - max_depth=basic_data["file_structure"]["max_depth"], - max_files=basic_data["file_structure"]["max_files"], - max_folders=basic_data["file_structure"]["max_folders"], - max_lines=basic_data["file_structure"]["max_lines"], - ignore=basic_data["file_structure"]["gitignore"], - output_mode=file_tree.OUTPUT_MODE_STRING - )) + + tree = str( + file_tree.file_tree( + project_folder, + max_depth=basic_data["file_structure"]["max_depth"], + max_files=basic_data["file_structure"]["max_files"], + max_folders=basic_data["file_structure"]["max_folders"], + max_lines=basic_data["file_structure"]["max_lines"], + ignore=basic_data["file_structure"]["gitignore"], + output_mode=file_tree.OUTPUT_MODE_STRING, + ) + ) # empty? if "\n" not in tree: tree += "\n # Empty" return tree - - \ No newline at end of file diff --git a/webui/components/projects/project-edit-mcp-tools.html b/webui/components/projects/project-edit-mcp-tools.html new file mode 100644 index 0000000000..864142af09 --- /dev/null +++ b/webui/components/projects/project-edit-mcp-tools.html @@ -0,0 +1,320 @@ + +
++ Configure MCP servers for this project. These are separate from global MCP servers. +
+ +