Skip to content

Commit 68c800d

Browse files
committed
[2/n] Add MCP support to Runner
### Summary:

This enables users to **use** MCP inside the SDK.

1. You add a list of MCP servers to `Agent`, via `mcp_servers=[...]`.
2. When an agent runs, we look up its MCP tools and add them to the list of tools.
3. When a tool call occurs, we call the relevant MCP server.

Notes:

1. There's some refactoring to make sure we send the full list of tools to the Runner/Model etc.
2. Right now, you could have a locally defined tool that conflicts with an MCP-defined tool. I didn't add errors for that; I will do so in a follow-up.

### Test Plan:

See unit tests. An end-to-end example follows in the next PR.
1 parent 300e12c commit 68c800d

14 files changed

+662
-35
lines changed

src/agents/_run_impl.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@
5050
from .models.interface import ModelTracing
5151
from .run_context import RunContextWrapper, TContext
5252
from .stream_events import RunItemStreamEvent, StreamEvent
53-
from .tool import ComputerTool, FunctionTool, FunctionToolResult
53+
from .tool import ComputerTool, FunctionTool, FunctionToolResult, Tool
5454
from .tracing import (
5555
SpanError,
5656
Trace,
@@ -301,6 +301,7 @@ def process_model_response(
301301
cls,
302302
*,
303303
agent: Agent[Any],
304+
all_tools: list[Tool],
304305
response: ModelResponse,
305306
output_schema: AgentOutputSchema | None,
306307
handoffs: list[Handoff],
@@ -312,8 +313,8 @@ def process_model_response(
312313
computer_actions = []
313314

314315
handoff_map = {handoff.tool_name: handoff for handoff in handoffs}
315-
function_map = {tool.name: tool for tool in agent.tools if isinstance(tool, FunctionTool)}
316-
computer_tool = next((tool for tool in agent.tools if isinstance(tool, ComputerTool)), None)
316+
function_map = {tool.name: tool for tool in all_tools if isinstance(tool, FunctionTool)}
317+
computer_tool = next((tool for tool in all_tools if isinstance(tool, ComputerTool)), None)
317318

318319
for output in response.output:
319320
if isinstance(output, ResponseOutputMessage):

src/agents/agent.py

+20
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from .handoffs import Handoff
1313
from .items import ItemHelpers
1414
from .logger import logger
15+
from .mcp import MCPUtil
1516
from .model_settings import ModelSettings
1617
from .models.interface import Model
1718
from .run_context import RunContextWrapper, TContext
@@ -21,6 +22,7 @@
2122

2223
if TYPE_CHECKING:
2324
from .lifecycle import AgentHooks
25+
from .mcp import MCPServer
2426
from .result import RunResult
2527

2628

@@ -107,6 +109,16 @@ class Agent(Generic[TContext]):
107109
tools: list[Tool] = field(default_factory=list)
108110
"""A list of tools that the agent can use."""
109111

112+
mcp_servers: list[MCPServer] = field(default_factory=list)
113+
"""A list of [Model Context Protocol](https://modelcontextprotocol.io/) servers that
114+
the agent can use. Every time the agent runs, it will include tools from these servers in the
115+
list of available tools.
116+
117+
NOTE: You are expected to manage the lifecycle of these servers. Specifically, you must call
118+
`server.connect()` before passing it to the agent, and `server.cleanup()` when the server is no
119+
longer needed.
120+
"""
121+
110122
input_guardrails: list[InputGuardrail[TContext]] = field(default_factory=list)
111123
"""A list of checks that run in parallel to the agent's execution, before generating a
112124
response. Runs only if the agent is the first agent in the chain.
@@ -205,3 +217,11 @@ async def get_system_prompt(self, run_context: RunContextWrapper[TContext]) -> s
205217
logger.error(f"Instructions must be a string or a function, got {self.instructions}")
206218

207219
return None
220+
221+
async def get_mcp_tools(self) -> list[Tool]:
    """Return the function tools exposed by this agent's configured MCP servers."""
    servers = self.mcp_servers
    mcp_tools = await MCPUtil.get_all_function_tools(servers)
    return mcp_tools
224+
225+
async def get_all_tools(self) -> list[Tool]:
    """Return every tool available to the agent: MCP tools first, then `self.tools`.

    The MCP lookup is delegated to `get_mcp_tools()` so there is a single place that
    decides how MCP tools are fetched. The two lists are simply concatenated; no
    de-duplication or name-conflict detection is performed here.
    """
    mcp_tools = await self.get_mcp_tools()
    return mcp_tools + self.tools

src/agents/run.py

+20-4
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77

88
from openai.types.responses import ResponseCompletedEvent
99

10+
from agents.tool import Tool
11+
1012
from ._run_impl import (
1113
NextStepFinalOutput,
1214
NextStepHandoff,
@@ -177,7 +179,8 @@ async def run(
177179
# agent changes, or if the agent loop ends.
178180
if current_span is None:
179181
handoff_names = [h.agent_name for h in cls._get_handoffs(current_agent)]
180-
tool_names = [t.name for t in current_agent.tools]
182+
all_tools = await cls._get_all_tools(current_agent)
183+
tool_names = [t.name for t in all_tools]
181184
if output_schema := cls._get_output_schema(current_agent):
182185
output_type_name = output_schema.output_type_name()
183186
else:
@@ -217,6 +220,7 @@ async def run(
217220
),
218221
cls._run_single_turn(
219222
agent=current_agent,
223+
all_tools=all_tools,
220224
original_input=original_input,
221225
generated_items=generated_items,
222226
hooks=hooks,
@@ -228,6 +232,7 @@ async def run(
228232
else:
229233
turn_result = await cls._run_single_turn(
230234
agent=current_agent,
235+
all_tools=all_tools,
231236
original_input=original_input,
232237
generated_items=generated_items,
233238
hooks=hooks,
@@ -627,7 +632,7 @@ async def _run_single_turn_streamed(
627632
system_prompt = await agent.get_system_prompt(context_wrapper)
628633

629634
handoffs = cls._get_handoffs(agent)
630-
635+
all_tools = await cls._get_all_tools(agent)
631636
model = cls._get_model(agent, run_config)
632637
model_settings = agent.model_settings.resolve(run_config.model_settings)
633638
final_response: ModelResponse | None = None
@@ -640,7 +645,7 @@ async def _run_single_turn_streamed(
640645
system_prompt,
641646
input,
642647
model_settings,
643-
agent.tools,
648+
all_tools,
644649
output_schema,
645650
handoffs,
646651
get_model_tracing_impl(
@@ -677,6 +682,7 @@ async def _run_single_turn_streamed(
677682
pre_step_items=streamed_result.new_items,
678683
new_response=final_response,
679684
output_schema=output_schema,
685+
all_tools=all_tools,
680686
handoffs=handoffs,
681687
hooks=hooks,
682688
context_wrapper=context_wrapper,
@@ -691,6 +697,7 @@ async def _run_single_turn(
691697
cls,
692698
*,
693699
agent: Agent[TContext],
700+
all_tools: list[Tool],
694701
original_input: str | list[TResponseInputItem],
695702
generated_items: list[RunItem],
696703
hooks: RunHooks[TContext],
@@ -721,6 +728,7 @@ async def _run_single_turn(
721728
system_prompt,
722729
input,
723730
output_schema,
731+
all_tools,
724732
handoffs,
725733
context_wrapper,
726734
run_config,
@@ -732,6 +740,7 @@ async def _run_single_turn(
732740
pre_step_items=generated_items,
733741
new_response=new_response,
734742
output_schema=output_schema,
743+
all_tools=all_tools,
735744
handoffs=handoffs,
736745
hooks=hooks,
737746
context_wrapper=context_wrapper,
@@ -743,6 +752,7 @@ async def _get_single_step_result_from_response(
743752
cls,
744753
*,
745754
agent: Agent[TContext],
755+
all_tools: list[Tool],
746756
original_input: str | list[TResponseInputItem],
747757
pre_step_items: list[RunItem],
748758
new_response: ModelResponse,
@@ -754,6 +764,7 @@ async def _get_single_step_result_from_response(
754764
) -> SingleStepResult:
755765
processed_response = RunImpl.process_model_response(
756766
agent=agent,
767+
all_tools=all_tools,
757768
response=new_response,
758769
output_schema=output_schema,
759770
handoffs=handoffs,
@@ -853,6 +864,7 @@ async def _get_new_response(
853864
system_prompt: str | None,
854865
input: list[TResponseInputItem],
855866
output_schema: AgentOutputSchema | None,
867+
all_tools: list[Tool],
856868
handoffs: list[Handoff],
857869
context_wrapper: RunContextWrapper[TContext],
858870
run_config: RunConfig,
@@ -863,7 +875,7 @@ async def _get_new_response(
863875
system_instructions=system_prompt,
864876
input=input,
865877
model_settings=model_settings,
866-
tools=agent.tools,
878+
tools=all_tools,
867879
output_schema=output_schema,
868880
handoffs=handoffs,
869881
tracing=get_model_tracing_impl(
@@ -892,6 +904,10 @@ def _get_handoffs(cls, agent: Agent[Any]) -> list[Handoff]:
892904
handoffs.append(handoff(handoff_item))
893905
return handoffs
894906

907+
@classmethod
async def _get_all_tools(cls, agent: Agent[Any]) -> list[Tool]:
    """Collect the tools for `agent` by delegating to `agent.get_all_tools()`."""
    tools = await agent.get_all_tools()
    return tools
910+
895911
@classmethod
896912
def _get_model(cls, agent: Agent[Any], run_config: RunConfig) -> Model:
897913
if isinstance(run_config.model, Model):

tests/mcp/__init__.py

Whitespace-only changes.

tests/mcp/conftest.py

+11
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
import os
2+
import sys
3+
4+
5+
# Skip MCP tests on Python 3.9
def pytest_ignore_collect(collection_path, config):
    """Tell pytest to ignore everything under this directory when running on Python 3.9.

    Args:
        collection_path: The path pytest is about to collect.
        config: The pytest config object (unused).

    Returns:
        True when running on Python 3.9 and `collection_path` is this directory or
        lies inside it; otherwise None, which leaves the decision to other hooks.
    """
    if sys.version_info[:2] != (3, 9):
        return None

    this_dir = os.path.dirname(os.path.abspath(__file__))
    candidate = os.path.abspath(str(collection_path))

    # Compare on a path-separator boundary: a bare prefix check would also match
    # sibling paths such as "<this_dir>_extra".
    if candidate == this_dir or candidate.startswith(this_dir + os.sep):
        return True
    return None

tests/mcp/helpers.py

+54
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
import json
2+
import shutil
3+
from typing import Any
4+
5+
from mcp import Tool as MCPTool
6+
from mcp.types import CallToolResult, TextContent
7+
8+
from agents.mcp import MCPServer
9+
10+
tee = shutil.which("tee") or ""
11+
assert tee, "tee not found"
12+
13+
14+
# Added dummy stream classes for patching stdio_client to avoid real I/O during tests
15+
class DummyStream:
    """Stream stub that performs no real I/O: sends are dropped, receives always fail."""

    async def send(self, msg):
        # Intentionally discard the message.
        return None

    async def receive(self):
        message = "Dummy receive not implemented"
        raise Exception(message)
21+
22+
23+
class DummyStreamsContextManager:
    """Async context manager that yields a pair of `DummyStream` objects.

    Used as the patched return value of `stdio_client` so tests avoid real I/O.
    """

    async def __aenter__(self):
        first_stream = DummyStream()
        second_stream = DummyStream()
        return (first_stream, second_stream)

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        # Nothing to release; returning None lets any exception propagate.
        return None
29+
30+
31+
class FakeMCPServer(MCPServer):
    """In-memory `MCPServer` test double.

    It serves a configurable list of tools and records the name and the generated
    result text of every `call_tool()` invocation.
    """

    def __init__(self, tools: list[MCPTool] | None = None):
        self.tools: list[MCPTool] = tools or []
        self.tool_calls: list[str] = []
        self.tool_results: list[str] = []

    def add_tool(self, name: str, input_schema: dict[str, Any]):
        """Register an additional tool with the given name and input schema."""
        new_tool = MCPTool(name=name, inputSchema=input_schema)
        self.tools.append(new_tool)

    async def connect(self):
        # No real connection to establish.
        return None

    async def cleanup(self):
        # No resources to release.
        return None

    async def list_tools(self):
        return self.tools

    async def call_tool(self, tool_name: str, arguments: dict[str, Any] | None) -> CallToolResult:
        """Record the call and return a text result derived from the name and arguments."""
        self.tool_calls.append(tool_name)
        result_text = f"result_{tool_name}_{json.dumps(arguments)}"
        self.tool_results.append(result_text)
        return CallToolResult(
            content=[TextContent(text=result_text, type="text")],
        )

tests/mcp/test_caching.py

+57
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
from unittest.mock import AsyncMock, patch
2+
3+
import pytest
4+
from mcp.types import ListToolsResult, Tool as MCPTool
5+
6+
from agents.mcp import MCPServerStdio
7+
8+
from .helpers import DummyStreamsContextManager, tee
9+
10+
11+
@pytest.mark.asyncio
@patch("mcp.client.stdio.stdio_client", return_value=DummyStreamsContextManager())
@patch("mcp.client.session.ClientSession.initialize", new_callable=AsyncMock, return_value=None)
@patch("mcp.client.session.ClientSession.list_tools")
async def test_server_caching_works(
    mock_list_tools: AsyncMock, mock_initialize: AsyncMock, mock_stdio_client
):
    """Test that if we turn caching on, the list of tools is cached and not fetched from the server
    on each call to `list_tools()`.
    """
    server = MCPServerStdio(
        params={
            "command": tee,
        },
        cache_tools_list=True,
    )

    # Keep the expected list under its own name. Rebinding `tools` to the value
    # returned by the server would turn every comparison into `x == x`.
    expected_tools = [
        MCPTool(name="tool1", inputSchema={}),
        MCPTool(name="tool2", inputSchema={}),
    ]

    mock_list_tools.return_value = ListToolsResult(tools=expected_tools)

    async with server:
        # First call hits the (mocked) session.
        fetched_tools = await server.list_tools()
        assert fetched_tools == expected_tools

        assert mock_list_tools.call_count == 1, "list_tools() should have been called once"

        # Call list_tools() again, should return the cached value
        fetched_tools = await server.list_tools()
        assert fetched_tools == expected_tools

        assert mock_list_tools.call_count == 1, "list_tools() should not have been called again"

        # Invalidate the cache and call list_tools() again
        server.invalidate_tools_cache()
        fetched_tools = await server.list_tools()
        assert fetched_tools == expected_tools

        assert mock_list_tools.call_count == 2, "list_tools() should be called again"

        # Without invalidating the cache, calling list_tools() again should return the cached value
        fetched_tools = await server.list_tools()
        assert fetched_tools == expected_tools

        assert mock_list_tools.call_count == 2, "list_tools() should not have been called again"

tests/mcp/test_connect_disconnect.py

+69
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
from unittest.mock import AsyncMock, patch
2+
3+
import pytest
4+
from mcp.types import ListToolsResult, Tool as MCPTool
5+
6+
from agents.mcp import MCPServerStdio
7+
8+
from .helpers import DummyStreamsContextManager, tee
9+
10+
11+
@pytest.mark.asyncio
@patch("mcp.client.stdio.stdio_client", return_value=DummyStreamsContextManager())
@patch("mcp.client.session.ClientSession.initialize", new_callable=AsyncMock, return_value=None)
@patch("mcp.client.session.ClientSession.list_tools")
async def test_async_ctx_manager_works(
    mock_list_tools: AsyncMock, mock_initialize: AsyncMock, mock_stdio_client
):
    """Entering `async with server` connects the session; leaving it disconnects."""
    server_params = {"command": tee}
    server = MCPServerStdio(params=server_params, cache_tools_list=True)

    fake_tools = [MCPTool(name=tool_name, inputSchema={}) for tool_name in ("tool1", "tool2")]
    mock_list_tools.return_value = ListToolsResult(tools=fake_tools)

    assert server.session is None, "Server should not be connected"

    async with server:
        assert server.session is not None, "Server should be connected"

    assert server.session is None, "Server should be disconnected"
39+
40+
41+
@pytest.mark.asyncio
@patch("mcp.client.stdio.stdio_client", return_value=DummyStreamsContextManager())
@patch("mcp.client.session.ClientSession.initialize", new_callable=AsyncMock, return_value=None)
@patch("mcp.client.session.ClientSession.list_tools")
async def test_manual_connect_disconnect_works(
    mock_list_tools: AsyncMock, mock_initialize: AsyncMock, mock_stdio_client
):
    """Test that calling `connect()` and `cleanup()` manually works."""
    server = MCPServerStdio(
        params={
            "command": tee,
        },
        cache_tools_list=True,
    )

    tools = [
        MCPTool(name="tool1", inputSchema={}),
        MCPTool(name="tool2", inputSchema={}),
    ]

    mock_list_tools.return_value = ListToolsResult(tools=tools)

    assert server.session is None, "Server should not be connected"

    await server.connect()
    try:
        assert server.session is not None, "Server should be connected"
    finally:
        # Always clean up, even if the assertion above fails, so a failing test
        # does not leave the server connected.
        await server.cleanup()

    assert server.session is None, "Server should be disconnected"

0 commit comments

Comments
 (0)