-
Notifications
You must be signed in to change notification settings - Fork 1
refactor: remove caching of MCP metadata in MultiServerMCPClient #79
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -29,9 +29,6 @@ | |||||
| from language_model_gateway.gateway.mcp.exceptions.mcp_tool_unknown_exception import ( | ||||||
| McpToolUnknownException, | ||||||
| ) | ||||||
| from language_model_gateway.gateway.utilities.cache.mcp_tools_expiring_cache import ( | ||||||
| McpToolsMetadataExpiringCache, | ||||||
| ) | ||||||
| from mcp.types import Tool as MCPTool | ||||||
|
|
||||||
| from language_model_gateway.gateway.utilities.logger.log_levels import SRC_LOG_LEVELS | ||||||
|
|
@@ -44,11 +41,10 @@ | |||||
|
|
||||||
|
|
||||||
| class MultiServerMCPClientWithCaching(MultiServerMCPClient): | ||||||
| """A MultiServerMCPClient that caches tool metadata to avoid repeated calls to the MCP server. | ||||||
| """A MultiServerMCPClient that loads tool metadata from the MCP server every time. | ||||||
|
|
||||||
| This class extends the MultiServerMCPClient to cache the metadata of tools | ||||||
| across multiple calls to `get_tools`. It allows for efficient retrieval of tools | ||||||
| without needing to repeatedly query the MCP server for the same tool metadata. | ||||||
| This class extends the MultiServerMCPClient to always fetch the metadata of tools | ||||||
| across multiple calls to `get_tools`. It does not cache tool metadata. | ||||||
| """ | ||||||
|
|
||||||
| _identifier: UUID = uuid4() | ||||||
|
|
@@ -57,8 +53,7 @@ class MultiServerMCPClientWithCaching(MultiServerMCPClient): | |||||
| def __init__( | ||||||
| self, | ||||||
| *, | ||||||
| connections: Optional[dict[str, Connection]] = None, | ||||||
| cache: McpToolsMetadataExpiringCache, | ||||||
| connections: Optional[Dict[str, Connection]] = None, | ||||||
| tool_names: List[str] | None, | ||||||
| tool_output_token_limit: int | None, | ||||||
| token_reducer: TokenReducer, | ||||||
|
|
@@ -67,23 +62,15 @@ def __init__( | |||||
| Initialize the async config reader | ||||||
|
|
||||||
| Args: | ||||||
| cache: Expiring cache for model configurations | ||||||
| connections: Optional dictionary of server name to connection config. | ||||||
| If None, an empty dictionary will be used (default). | ||||||
| tool_names: Optional list of tool names to filter the tools. | ||||||
| If None, all tools will be returned. | ||||||
| tool_output_token_limit: Optional limit for the number of tokens | ||||||
| token_reducer: TokenReducer instance to manage token limits | ||||||
| """ | ||||||
| if cache is None: | ||||||
| raise ValueError("cache must not be None") | ||||||
| self._cache: McpToolsMetadataExpiringCache = cache | ||||||
| self._tool_names: List[str] | None = tool_names | ||||||
| self._tool_output_token_limit: int | None = tool_output_token_limit | ||||||
| if not isinstance(self._cache, McpToolsMetadataExpiringCache): | ||||||
| raise TypeError( | ||||||
| f"self._cache must be McpToolsMetadataExpiringCache, got {type(self._cache)}" | ||||||
| ) | ||||||
| if not isinstance(token_reducer, TokenReducer): | ||||||
| raise TypeError( | ||||||
| f"token_reducer must be TokenReducer, got {type(token_reducer)}" | ||||||
|
|
@@ -93,130 +80,89 @@ def __init__( | |||||
|
|
||||||
| async def load_tools_metadata_cache( | ||||||
| self, *, server_name: str | None = None, tool_names: List[str] | None | ||||||
| ) -> None: | ||||||
| """Get a list of all tools from all connected servers. | ||||||
| ) -> Dict[str, List[Tool]]: | ||||||
| """Get a list of all tools from all connected servers, always loading metadata. | ||||||
|
|
||||||
| Args: | ||||||
| server_name: Optional name of the server to get tools from. | ||||||
| If None, all tools from all servers will be returned (default). | ||||||
| tool_names: Optional list of tool names to filter the tools. | ||||||
|
|
||||||
| NOTE: a new session will be created for each tool call | ||||||
|
|
||||||
| Returns: | ||||||
| A list of LangChain tools | ||||||
|
|
||||||
| Dictionary mapping connection URLs to lists of Tool objects. | ||||||
| """ | ||||||
| async with self._lock: | ||||||
| cache: Dict[str, List[Tool]] | None = await self._cache.get() | ||||||
| if cache is None: | ||||||
| cache = await self._cache.create() | ||||||
| if cache is None: | ||||||
| raise RuntimeError("Cache must be initialized before getting tools") | ||||||
|
|
||||||
| tools_metadata: Dict[str, List[Tool]] = {} | ||||||
| if server_name is not None: | ||||||
| if server_name not in self.connections: | ||||||
| msg = f"Couldn't find a server with name '{server_name}', expected one of '{list(self.connections.keys())}'" | ||||||
| raise ValueError(msg) | ||||||
| connection_for_server: StreamableHttpConnection = cast( | ||||||
| StreamableHttpConnection, self.connections[server_name] | ||||||
| ) | ||||||
| if connection_for_server["url"] not in cache: | ||||||
| cache[ | ||||||
| connection_for_server["url"] | ||||||
| ] = await self.load_metadata_for_mcp_tools( | ||||||
| session=None, | ||||||
| connection=connection_for_server, | ||||||
| tool_names=tool_names, | ||||||
| ) | ||||||
| logger.info( | ||||||
| f"Loaded tools for connection {connection_for_server['url']}" | ||||||
| ) | ||||||
| else: | ||||||
| logger.debug( | ||||||
| f"Tools for connection {connection_for_server['url']} are already cached" | ||||||
| ) | ||||||
| tools_metadata[ | ||||||
| connection_for_server["url"] | ||||||
| ] = await self.load_metadata_for_mcp_tools( | ||||||
| session=None, | ||||||
| connection=connection_for_server, | ||||||
| tool_names=tool_names, | ||||||
| ) | ||||||
| logger.info( | ||||||
| f"Loaded tools for connection {connection_for_server['url']}" | ||||||
| ) | ||||||
| else: | ||||||
| connection: StreamableHttpConnection | ||||||
| for connection in [ | ||||||
| cast(StreamableHttpConnection, c) for c in self.connections.values() | ||||||
| ]: | ||||||
| # if the tools for this connection are already cached, skip loading them | ||||||
| if connection["url"] not in cache: | ||||||
| cache[ | ||||||
| connection["url"] | ||||||
| ] = await self.load_metadata_for_mcp_tools( | ||||||
| session=None, | ||||||
| connection=connection, | ||||||
| tool_names=self._tool_names, | ||||||
| ) | ||||||
| logger.info(f"Loaded tools for connection {connection['url']}") | ||||||
| else: | ||||||
| # see if we are missing any tools in the cache | ||||||
| if self._tool_names: | ||||||
| cached_tool_names = [ | ||||||
| tool.name for tool in cache[connection["url"]] | ||||||
| ] | ||||||
| missing_tools = set(self._tool_names) - set( | ||||||
| cached_tool_names | ||||||
| ) | ||||||
| if missing_tools: | ||||||
| logger.info( | ||||||
| f"Missing tools {missing_tools} for connection {connection['url']}, loading them" | ||||||
| ) | ||||||
| tools = await self.load_metadata_for_mcp_tools( | ||||||
| session=None, | ||||||
| connection=connection, | ||||||
| tool_names=list(missing_tools), | ||||||
| ) | ||||||
| cache[connection["url"]].extend(tools) | ||||||
| else: | ||||||
| logger.debug( | ||||||
| f"Tools for connection {connection['url']} are already cached and all tools are present" | ||||||
| ) | ||||||
| else: | ||||||
| logger.debug( | ||||||
| f"Tools for connection {connection['url']} are already cached" | ||||||
| ) | ||||||
| # set the cache with the loaded tools | ||||||
| await self._cache.set(cache) | ||||||
| tools_metadata[ | ||||||
| connection["url"] | ||||||
| ] = await self.load_metadata_for_mcp_tools( | ||||||
| session=None, | ||||||
| connection=connection, | ||||||
| tool_names=self._tool_names, | ||||||
| ) | ||||||
| logger.info(f"Loaded tools for connection {connection['url']}") | ||||||
| return tools_metadata | ||||||
imranq2 marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||||||
|
|
||||||
| @override | ||||||
| async def get_tools(self, *, server_name: str | None = None) -> list[BaseTool]: | ||||||
| """Get a list of all tools from all connected servers. | ||||||
|
|
||||||
| Args: | ||||||
| server_name: Optional name of the server to get tools from. | ||||||
| If None, all tools from all servers will be returned (default). | ||||||
|
|
||||||
| NOTE: a new session will be created for each tool call | ||||||
|
|
||||||
| Returns: | ||||||
| A list of LangChain tools | ||||||
|
|
||||||
| """ | ||||||
|
|
||||||
| await self.load_tools_metadata_cache( | ||||||
| server_name=server_name, tool_names=self._tool_names | ||||||
| ) | ||||||
| async with self._lock: | ||||||
| cache: Dict[str, List[Tool]] | None = await self._cache.get() | ||||||
| if cache is None: | ||||||
| raise RuntimeError("Cache must be initialized before getting tools") | ||||||
|
|
||||||
| # create LangChain tools from the loaded MCP tools | ||||||
| all_tools: List[BaseTool] = [] | ||||||
| connection: StreamableHttpConnection | ||||||
| """Get a list of all tools from all connected servers, always loading metadata and never storing it.""" | ||||||
| all_tools: List[BaseTool] = [] | ||||||
| if server_name is not None: | ||||||
| if server_name not in self.connections: | ||||||
| msg = f"Couldn't find a server with name '{server_name}', expected one of '{list(self.connections.keys())}'" | ||||||
| raise ValueError(msg) | ||||||
| connection_for_server: StreamableHttpConnection = cast( | ||||||
| StreamableHttpConnection, self.connections[server_name] | ||||||
| ) | ||||||
| tools_for_connection: List[Tool] = await self.load_metadata_for_mcp_tools( | ||||||
| session=None, | ||||||
| connection=connection_for_server, | ||||||
| tool_names=self._tool_names, | ||||||
| ) | ||||||
| all_tools.extend( | ||||||
| self.create_tools_from_list( | ||||||
| tools=tools_for_connection, | ||||||
| session=None, | ||||||
| connection=connection_for_server, | ||||||
| ) | ||||||
| ) | ||||||
| else: | ||||||
| for connection in [ | ||||||
| cast(StreamableHttpConnection, c) for c in self.connections.values() | ||||||
| ]: | ||||||
| tools_for_connection: List[Tool] = cache[connection["url"]] | ||||||
| tools_for_connection = await self.load_metadata_for_mcp_tools( | ||||||
| session=None, | ||||||
| connection=connection, | ||||||
| tool_names=self._tool_names, | ||||||
| ) | ||||||
| all_tools.extend( | ||||||
| self.create_tools_from_list( | ||||||
| tools=tools_for_connection, session=None, connection=connection | ||||||
| ) | ||||||
| ) | ||||||
| return all_tools | ||||||
| return all_tools | ||||||
imranq2 marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||
|
|
||||||
| @staticmethod | ||||||
| async def load_metadata_for_mcp_tools( | ||||||
|
|
@@ -449,7 +395,7 @@ def create_tools_from_list( | |||||
| self, | ||||||
| *, | ||||||
| tools: list[Tool], | ||||||
| session: ClientSession | None = None, | ||||||
| session: ClientSession | None, | ||||||
|
||||||
| session: ClientSession | None, | |
| session: ClientSession | None = None, |
Uh oh!
There was an error while loading. Please reload this page.