
Commit 91feb7d

[1/n] Add MCP types to the SDK (#320)
### Summary

1. Add the `mcp` dependency for Python 3.10+, since it doesn't support 3.9 and below.
2. Create `MCPServer`, the Agents SDK representation of an MCP server.
3. Create implementations for HTTP-SSE and stdio servers, directly copying the [MCP SDK example](https://github.com/modelcontextprotocol/python-sdk/blob/main/examples/clients/simple-chatbot/mcp_simple_chatbot/main.py).
4. Add a util to transform MCP tools into Agents SDK tools.

Note: I added optional caching support to the servers, so that if you know a server's tools don't change, you can just cache them.

### Test Plan

Checks pass. I added tests at the end of the stack.

---

#324 #322 #321 -> #320 #319
2 parents 923a354 + 97e3dc3 commit 91feb7d
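
A minimal sketch of how the pieces added in this commit fit together, assuming Python 3.10+ and an MCP server started from a hypothetical `my_mcp_server.py` script (the command and filename are placeholders, not part of this commit):

```python
import asyncio

from agents.mcp import MCPServerStdio, MCPUtil


async def main() -> None:
    server = MCPServerStdio(
        params={"command": "python", "args": ["my_mcp_server.py"]},  # placeholder server
        cache_tools_list=True,  # safe when the server's tool list never changes
    )
    async with server:  # __aenter__ calls connect(), __aexit__ calls cleanup()
        # Convert the server's MCP tools into Agents SDK function tools.
        tools = await MCPUtil.get_all_function_tools([server])
        print([tool.name for tool in tools])


asyncio.run(main())
```

This commit only adds the server types and the conversion utility; agent-level wiring is presumably part of the later PRs in the stack.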

File tree: 6 files changed, +572 −2 lines


pyproject.toml

+1 line

```diff
@@ -13,6 +13,7 @@ dependencies = [
     "typing-extensions>=4.12.2, <5",
     "requests>=2.0, <3",
     "types-requests>=2.0, <3",
+    "mcp; python_version >= '3.10'",
 ]
 classifiers = [
     "Typing :: Typed",
```

src/agents/mcp/__init__.py

+21 lines (new file)

```python
try:
    from .server import (
        MCPServer,
        MCPServerSse,
        MCPServerSseParams,
        MCPServerStdio,
        MCPServerStdioParams,
    )
except ImportError:
    pass

from .util import MCPUtil

__all__ = [
    "MCPServer",
    "MCPServerSse",
    "MCPServerSseParams",
    "MCPServerStdio",
    "MCPServerStdioParams",
    "MCPUtil",
]
```
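
Because `pyproject.toml` only installs `mcp` on Python 3.10+, the `try`/`except ImportError` above leaves the server classes unexported on older interpreters. A hedged sketch of how downstream code might guard for that (purely illustrative, not part of this commit):

```python
try:
    # Only resolves when the `mcp` dependency is installed (Python 3.10+).
    from agents.mcp import MCPServerStdio, MCPUtil
except ImportError:
    MCPServerStdio = MCPUtil = None  # MCP support unavailable on this interpreter
```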

src/agents/mcp/mcp_util.py

+94 lines (new file)

```python
import functools
import json
from typing import Any

from mcp.types import Tool as MCPTool

from .. import _debug
from ..exceptions import AgentsException, ModelBehaviorError, UserError
from ..logger import logger
from ..run_context import RunContextWrapper
from ..tool import FunctionTool, Tool
from .server import MCPServer


class MCPUtil:
    """Set of utilities for interop between MCP and Agents SDK tools."""

    @classmethod
    async def get_all_function_tools(cls, servers: list[MCPServer]) -> list[Tool]:
        """Get all function tools from a list of MCP servers."""
        tools = []
        tool_names: set[str] = set()
        for server in servers:
            server_tools = await cls.get_function_tools(server)
            server_tool_names = {tool.name for tool in server_tools}
            if len(server_tool_names & tool_names) > 0:
                raise UserError(
                    f"Duplicate tool names found across MCP servers: "
                    f"{server_tool_names & tool_names}"
                )
            tool_names.update(server_tool_names)
            tools.extend(server_tools)

        return tools

    @classmethod
    async def get_function_tools(cls, server: MCPServer) -> list[Tool]:
        """Get all function tools from a single MCP server."""
        tools = await server.list_tools()
        return [cls.to_function_tool(tool, server) for tool in tools]

    @classmethod
    def to_function_tool(cls, tool: MCPTool, server: MCPServer) -> FunctionTool:
        """Convert an MCP tool to an Agents SDK function tool."""
        invoke_func = functools.partial(cls.invoke_mcp_tool, server, tool)
        return FunctionTool(
            name=tool.name,
            description=tool.description or "",
            params_json_schema=tool.inputSchema,
            on_invoke_tool=invoke_func,
            strict_json_schema=False,
        )

    @classmethod
    async def invoke_mcp_tool(
        cls, server: MCPServer, tool: MCPTool, context: RunContextWrapper[Any], input_json: str
    ) -> str:
        """Invoke an MCP tool and return the result as a string."""
        try:
            json_data: dict[str, Any] = json.loads(input_json) if input_json else {}
        except Exception as e:
            if _debug.DONT_LOG_TOOL_DATA:
                logger.debug(f"Invalid JSON input for tool {tool.name}")
            else:
                logger.debug(f"Invalid JSON input for tool {tool.name}: {input_json}")
            raise ModelBehaviorError(
                f"Invalid JSON input for tool {tool.name}: {input_json}"
            ) from e

        if _debug.DONT_LOG_TOOL_DATA:
            logger.debug(f"Invoking MCP tool {tool.name}")
        else:
            logger.debug(f"Invoking MCP tool {tool.name} with input {input_json}")

        try:
            result = await server.call_tool(tool.name, json_data)
        except Exception as e:
            logger.error(f"Error invoking MCP tool {tool.name}: {e}")
            raise AgentsException(f"Error invoking MCP tool {tool.name}: {e}") from e

        if _debug.DONT_LOG_TOOL_DATA:
            logger.debug(f"MCP tool {tool.name} completed.")
        else:
            logger.debug(f"MCP tool {tool.name} returned {result}")

        # The MCP tool result is a list of content items, whereas OpenAI tool outputs are a single
        # string. We'll try to convert.
        if len(result.content) == 1:
            return result.content[0].model_dump_json()
        elif len(result.content) > 1:
            return json.dumps([item.model_dump() for item in result.content])
        else:
            logger.error(f"Errored MCP tool result: {result}")
            return "Error running tool."
```

src/agents/mcp/server.py

+269 lines (new file)

```python
from __future__ import annotations

import abc
import asyncio
from contextlib import AbstractAsyncContextManager, AsyncExitStack
from pathlib import Path
from typing import Any, Literal

from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
from mcp import ClientSession, StdioServerParameters, Tool as MCPTool, stdio_client
from mcp.client.sse import sse_client
from mcp.types import CallToolResult, JSONRPCMessage
from typing_extensions import NotRequired, TypedDict

from ..exceptions import UserError
from ..logger import logger


class MCPServer(abc.ABC):
    """Base class for Model Context Protocol servers."""

    @abc.abstractmethod
    async def connect(self):
        """Connect to the server. For example, this might mean spawning a subprocess or
        opening a network connection. The server is expected to remain connected until
        `cleanup()` is called.
        """
        pass

    @abc.abstractmethod
    async def cleanup(self):
        """Cleanup the server. For example, this might mean closing a subprocess or
        closing a network connection.
        """
        pass

    @abc.abstractmethod
    async def list_tools(self) -> list[MCPTool]:
        """List the tools available on the server."""
        pass

    @abc.abstractmethod
    async def call_tool(self, tool_name: str, arguments: dict[str, Any] | None) -> CallToolResult:
        """Invoke a tool on the server."""
        pass


class _MCPServerWithClientSession(MCPServer, abc.ABC):
    """Base class for MCP servers that use a `ClientSession` to communicate with the server."""

    def __init__(self, cache_tools_list: bool):
        """
        Args:
            cache_tools_list: Whether to cache the tools list. If `True`, the tools list will be
                cached and only fetched from the server once. If `False`, the tools list will be
                fetched from the server on each call to `list_tools()`. The cache can be
                invalidated by calling `invalidate_tools_cache()`. You should set this to `True`
                if you know the server will not change its tools list, because it can drastically
                improve latency (by avoiding a round-trip to the server every time).
        """
        self.session: ClientSession | None = None
        self.exit_stack: AsyncExitStack = AsyncExitStack()
        self._cleanup_lock: asyncio.Lock = asyncio.Lock()
        self.cache_tools_list = cache_tools_list

        # The cache is always dirty at startup, so that we fetch tools at least once
        self._cache_dirty = True
        self._tools_list: list[MCPTool] | None = None

    @abc.abstractmethod
    def create_streams(
        self,
    ) -> AbstractAsyncContextManager[
        tuple[
            MemoryObjectReceiveStream[JSONRPCMessage | Exception],
            MemoryObjectSendStream[JSONRPCMessage],
        ]
    ]:
        """Create the streams for the server."""
        pass

    async def __aenter__(self):
        await self.connect()
        return self

    async def __aexit__(self, exc_type, exc_value, traceback):
        await self.cleanup()

    def invalidate_tools_cache(self):
        """Invalidate the tools cache."""
        self._cache_dirty = True

    async def connect(self):
        """Connect to the server."""
        try:
            transport = await self.exit_stack.enter_async_context(self.create_streams())
            read, write = transport
            session = await self.exit_stack.enter_async_context(ClientSession(read, write))
            await session.initialize()
            self.session = session
        except Exception as e:
            logger.error(f"Error initializing MCP server: {e}")
            await self.cleanup()
            raise

    async def list_tools(self) -> list[MCPTool]:
        """List the tools available on the server."""
        if not self.session:
            raise UserError("Server not initialized. Make sure you call `connect()` first.")

        # Return from cache if caching is enabled, we have tools, and the cache is not dirty
        if self.cache_tools_list and not self._cache_dirty and self._tools_list:
            return self._tools_list

        # Reset the cache dirty to False
        self._cache_dirty = False

        # Fetch the tools from the server
        self._tools_list = (await self.session.list_tools()).tools
        return self._tools_list

    async def call_tool(self, tool_name: str, arguments: dict[str, Any] | None) -> CallToolResult:
        """Invoke a tool on the server."""
        if not self.session:
            raise UserError("Server not initialized. Make sure you call `connect()` first.")

        return await self.session.call_tool(tool_name, arguments)

    async def cleanup(self):
        """Cleanup the server."""
        async with self._cleanup_lock:
            try:
                await self.exit_stack.aclose()
                self.session = None
            except Exception as e:
                logger.error(f"Error cleaning up server: {e}")


class MCPServerStdioParams(TypedDict):
    """Mirrors `mcp.client.stdio.StdioServerParameters`, but lets you pass params without another
    import.
    """

    command: str
    """The executable to run to start the server. For example, `python` or `node`."""

    args: NotRequired[list[str]]
    """Command line args to pass to the `command` executable. For example, `['foo.py']` or
    `['server.js', '--port', '8080']`."""

    env: NotRequired[dict[str, str]]
    """The environment variables to set for the server."""

    cwd: NotRequired[str | Path]
    """The working directory to use when spawning the process."""

    encoding: NotRequired[str]
    """The text encoding used when sending/receiving messages to the server. Defaults to `utf-8`."""

    encoding_error_handler: NotRequired[Literal["strict", "ignore", "replace"]]
    """The text encoding error handler. Defaults to `strict`.

    See https://docs.python.org/3/library/codecs.html#codec-base-classes for
    explanations of possible values.
    """


class MCPServerStdio(_MCPServerWithClientSession):
    """MCP server implementation that uses the stdio transport. See the [spec]
    (https://spec.modelcontextprotocol.io/specification/2024-11-05/basic/transports/#stdio) for
    details.
    """

    def __init__(self, params: MCPServerStdioParams, cache_tools_list: bool = False):
        """Create a new MCP server based on the stdio transport.

        Args:
            params: The params that configure the server. This includes:
                - The command (e.g. `python` or `node`) that starts the server.
                - The args to pass to the server command (e.g. `foo.py` or `server.js`).
                - The environment variables to set for the server.
            cache_tools_list: Whether to cache the tools list. If `True`, the tools list will be
                cached and only fetched from the server once. If `False`, the tools list will be
                fetched from the server on each call to `list_tools()`. The cache can be
                invalidated by calling `invalidate_tools_cache()`. You should set this to `True`
                if you know the server will not change its tools list, because it can drastically
                improve latency (by avoiding a round-trip to the server every time).
        """
        super().__init__(cache_tools_list)

        self.params = StdioServerParameters(
            command=params["command"],
            args=params.get("args", []),
            env=params.get("env"),
            cwd=params.get("cwd"),
            encoding=params.get("encoding", "utf-8"),
            encoding_error_handler=params.get("encoding_error_handler", "strict"),
        )

    def create_streams(
        self,
    ) -> AbstractAsyncContextManager[
        tuple[
            MemoryObjectReceiveStream[JSONRPCMessage | Exception],
            MemoryObjectSendStream[JSONRPCMessage],
        ]
    ]:
        """Create the streams for the server."""
        return stdio_client(self.params)


class MCPServerSseParams(TypedDict):
    """Mirrors the params in `mcp.client.sse.sse_client`."""

    url: str
    """The URL of the server."""

    headers: NotRequired[dict[str, str]]
    """The headers to send to the server."""

    timeout: NotRequired[float]
    """The timeout for the HTTP request. Defaults to 5 seconds."""

    sse_read_timeout: NotRequired[float]
    """The timeout for the SSE connection, in seconds. Defaults to 5 minutes."""


class MCPServerSse(_MCPServerWithClientSession):
    """MCP server implementation that uses the HTTP with SSE transport. See the [spec]
    (https://spec.modelcontextprotocol.io/specification/2024-11-05/basic/transports/#http-with-sse)
    for details.
    """

    def __init__(self, params: MCPServerSseParams, cache_tools_list: bool = False):
        """Create a new MCP server based on the HTTP with SSE transport.

        Args:
            params: The params that configure the server. This includes:
                - The URL of the server.
                - The headers to send to the server.
                - The timeout for the HTTP request.
                - The timeout for the SSE connection.

            cache_tools_list: Whether to cache the tools list. If `True`, the tools list will be
                cached and only fetched from the server once. If `False`, the tools list will be
                fetched from the server on each call to `list_tools()`. The cache can be
                invalidated by calling `invalidate_tools_cache()`. You should set this to `True`
                if you know the server will not change its tools list, because it can drastically
                improve latency (by avoiding a round-trip to the server every time).
        """
        super().__init__(cache_tools_list)

        self.params = params

    def create_streams(
        self,
    ) -> AbstractAsyncContextManager[
        tuple[
            MemoryObjectReceiveStream[JSONRPCMessage | Exception],
            MemoryObjectSendStream[JSONRPCMessage],
        ]
    ]:
        """Create the streams for the server."""
        return sse_client(
            url=self.params["url"],
            headers=self.params.get("headers", None),
            timeout=self.params.get("timeout", 5),
            sse_read_timeout=self.params.get("sse_read_timeout", 60 * 5),
        )
```
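
Both transports reuse the connect/cache/cleanup machinery in `_MCPServerWithClientSession`. A brief sketch of constructing each one with caching enabled (the command, URL, and port below are placeholders):

```python
from agents.mcp import MCPServerSse, MCPServerStdio

# stdio transport: spawns the server as a subprocess and talks over stdin/stdout.
stdio_server = MCPServerStdio(
    params={"command": "node", "args": ["server.js"]},  # placeholder command
    cache_tools_list=True,
)

# HTTP with SSE transport: connects to an already-running server.
sse_server = MCPServerSse(
    params={"url": "http://localhost:8000/sse"},  # placeholder URL
    cache_tools_list=True,
)

# If a cached server's tool list does change, force the next list_tools() to refetch:
stdio_server.invalidate_tools_cache()
```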
