Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 0 additions & 15 deletions language_model_gateway/container/container_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,9 +57,6 @@
from language_model_gateway.gateway.utilities.cache.config_expiring_cache import (
ConfigExpiringCache,
)
from language_model_gateway.gateway.utilities.cache.mcp_tools_expiring_cache import (
McpToolsMetadataExpiringCache,
)
from language_model_gateway.gateway.utilities.confluence.confluence_helper import (
ConfluenceHelper,
)
Expand Down Expand Up @@ -127,17 +124,6 @@ def create_container(cls, *, source: str) -> SimpleContainer:
)
),
)
container.singleton(
McpToolsMetadataExpiringCache,
lambda c: McpToolsMetadataExpiringCache(
ttl_seconds=(
int(os.environ["MCP_TOOLS_METADATA_CACHE_TIMEOUT_SECONDS"])
if os.environ.get("MCP_TOOLS_METADATA_CACHE_TIMEOUT_SECONDS")
else 60 * 60
),
init_value={},
),
)

container.singleton(HttpClientFactory, lambda c: HttpClientFactory())

Expand Down Expand Up @@ -266,7 +252,6 @@ def create_container(cls, *, source: str) -> SimpleContainer:
container.singleton(
MCPToolProvider,
lambda c: MCPToolProvider(
cache=c.resolve(McpToolsMetadataExpiringCache),
tool_auth_manager=c.resolve(ToolAuthManager),
environment_variables=c.resolve(
LanguageModelGatewayEnvironmentVariables
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,6 @@
from language_model_gateway.gateway.mcp.exceptions.mcp_tool_unknown_exception import (
McpToolUnknownException,
)
from language_model_gateway.gateway.utilities.cache.mcp_tools_expiring_cache import (
McpToolsMetadataExpiringCache,
)
from mcp.types import Tool as MCPTool

from language_model_gateway.gateway.utilities.logger.log_levels import SRC_LOG_LEVELS
Expand All @@ -44,11 +41,10 @@


class MultiServerMCPClientWithCaching(MultiServerMCPClient):
"""A MultiServerMCPClient that caches tool metadata to avoid repeated calls to the MCP server.
"""A MultiServerMCPClient that loads tool metadata from the MCP server every time.

This class extends the MultiServerMCPClient to cache the metadata of tools
across multiple calls to `get_tools`. It allows for efficient retrieval of tools
without needing to repeatedly query the MCP server for the same tool metadata.
This class extends the MultiServerMCPClient to always fetch the metadata of tools
across multiple calls to `get_tools`. It does not cache tool metadata.
"""

_identifier: UUID = uuid4()
Expand All @@ -57,8 +53,7 @@ class MultiServerMCPClientWithCaching(MultiServerMCPClient):
def __init__(
self,
*,
connections: Optional[dict[str, Connection]] = None,
cache: McpToolsMetadataExpiringCache,
connections: Optional[Dict[str, Connection]] = None,
tool_names: List[str] | None,
tool_output_token_limit: int | None,
token_reducer: TokenReducer,
Expand All @@ -67,23 +62,15 @@ def __init__(
Initialize the async config reader

Args:
cache: Expiring cache for model configurations
connections: Optional dictionary of server name to connection config.
If None, an empty dictionary will be used (default).
tool_names: Optional list of tool names to filter the tools.
If None, all tools will be returned.
tool_output_token_limit: Optional limit for the number of tokens
token_reducer: TokenReducer instance to manage token limits
"""
if cache is None:
raise ValueError("cache must not be None")
self._cache: McpToolsMetadataExpiringCache = cache
self._tool_names: List[str] | None = tool_names
self._tool_output_token_limit: int | None = tool_output_token_limit
if not isinstance(self._cache, McpToolsMetadataExpiringCache):
raise TypeError(
f"self._cache must be McpToolsMetadataExpiringCache, got {type(self._cache)}"
)
if not isinstance(token_reducer, TokenReducer):
raise TypeError(
f"token_reducer must be TokenReducer, got {type(token_reducer)}"
Expand All @@ -93,130 +80,89 @@ def __init__(

async def load_tools_metadata_cache(
self, *, server_name: str | None = None, tool_names: List[str] | None
) -> None:
"""Get a list of all tools from all connected servers.
) -> Dict[str, List[Tool]]:
"""Get a list of all tools from all connected servers, always loading metadata.

Args:
server_name: Optional name of the server to get tools from.
If None, all tools from all servers will be returned (default).
tool_names: Optional list of tool names to filter the tools.

NOTE: a new session will be created for each tool call

Returns:
A list of LangChain tools

Dictionary mapping connection URLs to lists of Tool objects.
"""
async with self._lock:
cache: Dict[str, List[Tool]] | None = await self._cache.get()
if cache is None:
cache = await self._cache.create()
if cache is None:
raise RuntimeError("Cache must be initialized before getting tools")

tools_metadata: Dict[str, List[Tool]] = {}
if server_name is not None:
if server_name not in self.connections:
msg = f"Couldn't find a server with name '{server_name}', expected one of '{list(self.connections.keys())}'"
raise ValueError(msg)
connection_for_server: StreamableHttpConnection = cast(
StreamableHttpConnection, self.connections[server_name]
)
if connection_for_server["url"] not in cache:
cache[
connection_for_server["url"]
] = await self.load_metadata_for_mcp_tools(
session=None,
connection=connection_for_server,
tool_names=tool_names,
)
logger.info(
f"Loaded tools for connection {connection_for_server['url']}"
)
else:
logger.debug(
f"Tools for connection {connection_for_server['url']} are already cached"
)
tools_metadata[
connection_for_server["url"]
] = await self.load_metadata_for_mcp_tools(
session=None,
connection=connection_for_server,
tool_names=tool_names,
)
logger.info(
f"Loaded tools for connection {connection_for_server['url']}"
)
else:
connection: StreamableHttpConnection
for connection in [
cast(StreamableHttpConnection, c) for c in self.connections.values()
]:
# if the tools for this connection are already cached, skip loading them
if connection["url"] not in cache:
cache[
connection["url"]
] = await self.load_metadata_for_mcp_tools(
session=None,
connection=connection,
tool_names=self._tool_names,
)
logger.info(f"Loaded tools for connection {connection['url']}")
else:
# see if we are missing any tools in the cache
if self._tool_names:
cached_tool_names = [
tool.name for tool in cache[connection["url"]]
]
missing_tools = set(self._tool_names) - set(
cached_tool_names
)
if missing_tools:
logger.info(
f"Missing tools {missing_tools} for connection {connection['url']}, loading them"
)
tools = await self.load_metadata_for_mcp_tools(
session=None,
connection=connection,
tool_names=list(missing_tools),
)
cache[connection["url"]].extend(tools)
else:
logger.debug(
f"Tools for connection {connection['url']} are already cached and all tools are present"
)
else:
logger.debug(
f"Tools for connection {connection['url']} are already cached"
)
# set the cache with the loaded tools
await self._cache.set(cache)
tools_metadata[
connection["url"]
] = await self.load_metadata_for_mcp_tools(
session=None,
connection=connection,
tool_names=self._tool_names,
)
logger.info(f"Loaded tools for connection {connection['url']}")
return tools_metadata

@override
async def get_tools(self, *, server_name: str | None = None) -> list[BaseTool]:
"""Get a list of all tools from all connected servers.

Args:
server_name: Optional name of the server to get tools from.
If None, all tools from all servers will be returned (default).

NOTE: a new session will be created for each tool call

Returns:
A list of LangChain tools

"""

await self.load_tools_metadata_cache(
server_name=server_name, tool_names=self._tool_names
)
async with self._lock:
cache: Dict[str, List[Tool]] | None = await self._cache.get()
if cache is None:
raise RuntimeError("Cache must be initialized before getting tools")

# create LangChain tools from the loaded MCP tools
all_tools: List[BaseTool] = []
connection: StreamableHttpConnection
"""Get a list of all tools from all connected servers, always loading metadata and never storing it."""
all_tools: List[BaseTool] = []
if server_name is not None:
if server_name not in self.connections:
msg = f"Couldn't find a server with name '{server_name}', expected one of '{list(self.connections.keys())}'"
raise ValueError(msg)
connection_for_server: StreamableHttpConnection = cast(
StreamableHttpConnection, self.connections[server_name]
)
tools_for_connection: List[Tool] = await self.load_metadata_for_mcp_tools(
session=None,
connection=connection_for_server,
tool_names=self._tool_names,
)
all_tools.extend(
self.create_tools_from_list(
tools=tools_for_connection,
session=None,
connection=connection_for_server,
)
)
else:
for connection in [
cast(StreamableHttpConnection, c) for c in self.connections.values()
]:
tools_for_connection: List[Tool] = cache[connection["url"]]
tools_for_connection = await self.load_metadata_for_mcp_tools(
session=None,
connection=connection,
tool_names=self._tool_names,
)
all_tools.extend(
self.create_tools_from_list(
tools=tools_for_connection, session=None, connection=connection
)
)
return all_tools
return all_tools

@staticmethod
async def load_metadata_for_mcp_tools(
Expand Down Expand Up @@ -449,7 +395,7 @@ def create_tools_from_list(
self,
*,
tools: list[Tool],
session: ClientSession | None = None,
session: ClientSession | None,
Copy link

Copilot AI Nov 23, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The removal of the default value (= None) from the session parameter changes the method signature. While this doesn't break existing code (since all current callers explicitly pass session=None), it does change the API by requiring callers to explicitly provide a value for session instead of being able to omit it. Consider whether this is intentional or if the default value should be preserved for backward compatibility.

Suggested change
session: ClientSession | None,
session: ClientSession | None = None,

Copilot uses AI. Check for mistakes.
connection: Optional[Connection] = None,
) -> List[BaseTool]:
"""
Expand Down
11 changes: 0 additions & 11 deletions language_model_gateway/gateway/tools/mcp_tool_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,6 @@
from language_model_gateway.gateway.mcp.exceptions.mcp_tool_unauthorized_exception import (
McpToolUnauthorizedException,
)
from language_model_gateway.gateway.utilities.cache.mcp_tools_expiring_cache import (
McpToolsMetadataExpiringCache,
)
from language_model_gateway.gateway.utilities.language_model_gateway_environment_variables import (
LanguageModelGatewayEnvironmentVariables,
)
Expand All @@ -50,21 +47,14 @@ class MCPToolProvider:
def __init__(
self,
*,
cache: McpToolsMetadataExpiringCache,
tool_auth_manager: ToolAuthManager,
environment_variables: LanguageModelGatewayEnvironmentVariables,
token_reducer: TokenReducer,
) -> None:
"""
Initialize the MCPToolProvider with a cache.

Args:
cache: An ExpiringCache instance to store tools by their MCP URLs.
"""
self.tools_by_mcp_url: Dict[str, List[BaseTool]] = {}
self._cache: McpToolsMetadataExpiringCache = cache
if self._cache is None:
raise ValueError("Cache must be provided")

self.tool_auth_manager = tool_auth_manager
if self.tool_auth_manager is None:
Expand Down Expand Up @@ -214,7 +204,6 @@ async def get_tools_by_url_async(

tool_names: List[str] | None = tool.tools.split(",") if tool.tools else None
client: MultiServerMCPClientWithCaching = MultiServerMCPClientWithCaching(
cache=self._cache,
connections={
f"{tool.name}": mcp_tool_config,
},
Expand Down
Loading
Loading