diff --git a/src/fastmcp/server/middleware/caching.py b/src/fastmcp/server/middleware/caching.py index 670c30a44..0bb57578c 100644 --- a/src/fastmcp/server/middleware/caching.py +++ b/src/fastmcp/server/middleware/caching.py @@ -1,5 +1,6 @@ """A middleware for response caching.""" +import hashlib from collections.abc import Sequence from logging import Logger from typing import Any, TypedDict @@ -411,7 +412,7 @@ async def on_call_tool( ) is False or not self._matches_tool_cache_settings(tool_name=tool_name): return await call_next(context=context) - cache_key: str = f"{tool_name}:{_get_arguments_str(context.message.arguments)}" + cache_key: str = _make_call_tool_cache_key(msg=context.message) if cached_value := await self._call_tool_cache.get(key=cache_key): return cached_value.unwrap() @@ -440,7 +441,7 @@ async def on_read_resource( if self._read_resource_settings.get("enabled") is False: return await call_next(context=context) - cache_key: str = str(context.message.uri) + cache_key: str = _make_read_resource_cache_key(msg=context.message) cached_value: CachableResourceResult | None if cached_value := await self._read_resource_cache.get(key=cache_key): @@ -468,7 +469,7 @@ async def on_get_prompt( if self._get_prompt_settings.get("enabled") is False: return await call_next(context=context) - cache_key: str = f"{context.message.name}:{_get_arguments_str(arguments=context.message.arguments)}" + cache_key: str = _make_get_prompt_cache_key(msg=context.message) if cached_value := await self._get_prompt_cache.get(key=cache_key): return cached_value.unwrap() @@ -519,3 +520,27 @@ def _get_arguments_str(arguments: dict[str, Any] | None) -> str: except TypeError: return repr(arguments) + + +def _hash_cache_key(value: str) -> str: + """Build a fixed-length SHA-256 cache key from request-derived input.""" + + return hashlib.sha256(value.encode()).hexdigest() + + +def _make_call_tool_cache_key(msg: mcp.types.CallToolRequestParams) -> str: + """Make a cache key for a tool call using a stable hash of name and arguments.""" + + return _hash_cache_key(f"{msg.name}:{_get_arguments_str(msg.arguments)}") + + +def _make_read_resource_cache_key(msg: mcp.types.ReadResourceRequestParams) -> str: + """Make a cache key for a resource read using a stable hash of URI.""" + + return _hash_cache_key(str(msg.uri)) + + +def _make_get_prompt_cache_key(msg: mcp.types.GetPromptRequestParams) -> str: + """Make a cache key for a prompt get using a stable hash of name and arguments.""" + + return _hash_cache_key(f"{msg.name}:{_get_arguments_str(msg.arguments)}") diff --git a/tests/server/middleware/test_caching.py b/tests/server/middleware/test_caching.py index 858089251..027b29959 100644 --- a/tests/server/middleware/test_caching.py +++ b/tests/server/middleware/test_caching.py @@ -35,6 +35,9 @@ CallToolSettings, ResponseCachingMiddleware, ResponseCachingStatistics, + _make_call_tool_cache_key, + _make_get_prompt_cache_key, + _make_read_resource_cache_key, ) from fastmcp.server.middleware.middleware import CallNext, MiddlewareContext from fastmcp.tools.tool import Tool, ToolResult @@ -631,3 +634,40 @@ async def test_prefixed_tool_callable_after_cache_hit( result = await client.call_tool("child_add", {"a": 5, "b": 3}) assert not result.is_error assert tracking_calculator.add_calls == 1 + + +class TestCacheKeyGeneration: + def test_call_tool_key_is_hashed_and_does_not_include_raw_input(self): + msg = mcp.types.CallToolRequestParams( + name="toolX", + arguments={"password": "secret", "path": "../../etc/passwd"}, + ) + + key = _make_call_tool_cache_key(msg) + + assert len(key) == 64 + assert "secret" not in key + assert "../../etc/passwd" not in key + + def test_read_resource_key_is_hashed_and_does_not_include_raw_uri(self): + msg = mcp.types.ReadResourceRequestParams( + uri=AnyUrl("file:///tmp/../../etc/shadow?token=abcd") + ) + + key = _make_read_resource_cache_key(msg) + + assert len(key) == 64 + assert "shadow" not in key + assert "token=abcd" not in key + + def test_get_prompt_key_is_hashed_and_stable(self): + msg = mcp.types.GetPromptRequestParams( + name="promptY", + arguments={"api_key": "ABC123", "scope": "admin"}, + ) + + key = _make_get_prompt_cache_key(msg) + + assert len(key) == 64 + assert "ABC123" not in key + assert key == _make_get_prompt_cache_key(msg)