1+ from collections .abc import Awaitable , Callable
12from datetime import timedelta
2- from typing import Any , Protocol
3+ from typing import Any , Protocol , TypeAlias
34
45import anyio .lowlevel
56from anyio .streams .memory import MemoryObjectReceiveStream , MemoryObjectSendStream
7+ from jsonschema import ValidationError , validate
68from pydantic import AnyUrl , TypeAdapter
79
810import mcp .types as types
1113from mcp .shared .session import BaseSession , ProgressFnT , RequestResponder
1214from mcp .shared .version import SUPPORTED_PROTOCOL_VERSIONS
1315
16+
17+ class ToolOutputValidator :
18+ async def validate (
19+ self , request : types .CallToolRequest , result : types .CallToolResult
20+ ) -> bool :
21+ raise RuntimeError ("Not implemented" )
22+
23+
1424DEFAULT_CLIENT_INFO = types .Implementation (name = "mcp" , version = "0.1.0" )
1525
1626
@@ -77,6 +87,25 @@ async def _default_logging_callback(
7787 pass
7888
7989
90+ ToolOutputValidatorProvider : TypeAlias = Callable [
91+ ...,
92+ Awaitable [ToolOutputValidator ],
93+ ]
94+
95+
96+ # this bag of spanners is required in order to
97+ # enable the client session to be parsed to the validator
98+ async def _python_circularity_hell (arg : Any ) -> ToolOutputValidator :
99+ # in any sane version of the universe this should never happen
100+ # of course in any sane programming language class circularity
101+ # dependencies shouldn't be this hard to manage
102+ raise RuntimeError (
103+ "Help I'm stuck in python circularity hell, please send biscuits"
104+ )
105+
106+
107+ _default_tool_output_validator : ToolOutputValidatorProvider = _python_circularity_hell
108+
80109ClientResponse : TypeAdapter [types .ClientResult | types .ErrorData ] = TypeAdapter (
81110 types .ClientResult | types .ErrorData
82111)
@@ -101,6 +130,7 @@ def __init__(
101130 logging_callback : LoggingFnT | None = None ,
102131 message_handler : MessageHandlerFnT | None = None ,
103132 client_info : types .Implementation | None = None ,
133+ tool_output_validator_provider : ToolOutputValidatorProvider | None = None ,
104134 ) -> None :
105135 super ().__init__ (
106136 read_stream ,
@@ -114,6 +144,7 @@ def __init__(
114144 self ._list_roots_callback = list_roots_callback or _default_list_roots_callback
115145 self ._logging_callback = logging_callback or _default_logging_callback
116146 self ._message_handler = message_handler or _default_message_handler
147+ self ._tool_output_validator_provider = tool_output_validator_provider
117148
118149 async def initialize (self ) -> types .InitializeResult :
119150 sampling = types .SamplingCapability ()
@@ -154,6 +185,11 @@ async def initialize(self) -> types.InitializeResult:
154185 )
155186 )
156187
188+ tool_output_validator_provider = (
189+ self ._tool_output_validator_provider or _default_tool_output_validator
190+ )
191+ self ._tool_output_validator = await tool_output_validator_provider (self )
192+
157193 return result
158194
159195 async def send_ping (self ) -> types .EmptyResult :
@@ -271,24 +307,33 @@ async def call_tool(
271307 arguments : dict [str , Any ] | None = None ,
272308 read_timeout_seconds : timedelta | None = None ,
273309 progress_callback : ProgressFnT | None = None ,
310+ validate_result : bool = True ,
274311 ) -> types .CallToolResult :
275312 """Send a tools/call request with optional progress callback support."""
276313
277- return await self .send_request (
278- types .ClientRequest (
279- types .CallToolRequest (
280- method = "tools/call" ,
281- params = types .CallToolRequestParams (
282- name = name ,
283- arguments = arguments ,
284- ),
285- )
314+ request = types .CallToolRequest (
315+ method = "tools/call" ,
316+ params = types .CallToolRequestParams (
317+ name = name ,
318+ arguments = arguments ,
286319 ),
320+ )
321+
322+ result = await self .send_request (
323+ types .ClientRequest (request ),
287324 types .CallToolResult ,
288325 request_read_timeout_seconds = read_timeout_seconds ,
289326 progress_callback = progress_callback ,
290327 )
291328
329+ if validate_result :
330+ valid = await self ._tool_output_validator .validate (request , result )
331+
332+ if not valid :
333+ raise RuntimeError ("Server responded with invalid result: " f"{ result } " )
334+ # not validating or is valid
335+ return result
336+
292337 async def list_prompts (self , cursor : str | None = None ) -> types .ListPromptsResult :
293338 """Send a prompts/list request."""
294339 return await self .send_request (
@@ -404,3 +449,67 @@ async def _received_notification(
404449 await self ._logging_callback (params )
405450 case _:
406451 pass
452+
453+
454+ class SimpleCachingToolOutputValidator (ToolOutputValidator ):
455+ _schema_cache : dict [str , dict [str , Any ] | bool ]
456+
457+ def __init__ (self , session : ClientSession ):
458+ self ._session = session
459+ self ._schema_cache = {}
460+ self ._refresh_cache = True
461+
462+ async def validate (
463+ self , request : types .CallToolRequest , result : types .CallToolResult
464+ ) -> bool :
465+ if result .isError :
466+ # allow errors to be propagated
467+ return True
468+ else :
469+ if self ._refresh_cache :
470+ await self ._refresh_schema_cache ()
471+
472+ schema = self ._schema_cache .get (request .params .name )
473+
474+ if schema is None :
475+ raise RuntimeError (f"Unknown tool { request .params .name } " )
476+ elif schema is False :
477+ # no schema
478+ # TODO add logging
479+ return result .structuredContent is None
480+ else :
481+ try :
482+ # TODO opportunity to build jsonschema.protocol.Validator
483+ # and reuse rather than build every time
484+ validate (result .structuredContent , schema )
485+ return True
486+ except ValidationError :
487+ # TODO log this
488+ return False
489+
490+ async def _refresh_schema_cache (self ):
491+ cursor = None
492+ first = True
493+ while first or cursor is not None :
494+ first = False
495+ tools_result = await self ._session .list_tools (cursor )
496+ for tool in tools_result .tools :
497+ # store a flag to be able to later distinguish between
498+ # no schema for tool and unknown tool which can't be verified
499+ schema_or_flag = (
500+ False if tool .outputSchema is None else tool .outputSchema
501+ )
502+ self ._schema_cache [tool .name ] = schema_or_flag
503+ cursor = tools_result .nextCursor
504+ continue
505+
506+ self ._refresh_cache = False
507+
508+
509+ async def _escape_from_circular_python_hell (
510+ session : ClientSession ,
511+ ) -> ToolOutputValidator :
512+ return SimpleCachingToolOutputValidator (session )
513+
514+
515+ _default_tool_output_validator = _escape_from_circular_python_hell
0 commit comments