[2/n] Add MCP support to Runner #321

Merged 2 commits on Mar 25, 2025

Changes from all commits

1 change: 1 addition & 0 deletions pyproject.toml

@@ -13,6 +13,7 @@ dependencies = [
     "typing-extensions>=4.12.2, <5",
     "requests>=2.0, <3",
     "types-requests>=2.0, <3",
+    "mcp; python_version >= '3.10'",
 ]
 classifiers = [
     "Typing :: Typed",

7 changes: 4 additions & 3 deletions src/agents/_run_impl.py

@@ -50,7 +50,7 @@
 from .models.interface import ModelTracing
 from .run_context import RunContextWrapper, TContext
 from .stream_events import RunItemStreamEvent, StreamEvent
-from .tool import ComputerTool, FunctionTool, FunctionToolResult
+from .tool import ComputerTool, FunctionTool, FunctionToolResult, Tool
 from .tracing import (
     SpanError,
     Trace,
@@ -301,6 +301,7 @@ def process_model_response(
         cls,
         *,
         agent: Agent[Any],
+        all_tools: list[Tool],
         response: ModelResponse,
         output_schema: AgentOutputSchema | None,
         handoffs: list[Handoff],
@@ -312,8 +313,8 @@
         computer_actions = []

         handoff_map = {handoff.tool_name: handoff for handoff in handoffs}
-        function_map = {tool.name: tool for tool in agent.tools if isinstance(tool, FunctionTool)}
-        computer_tool = next((tool for tool in agent.tools if isinstance(tool, ComputerTool)), None)
+        function_map = {tool.name: tool for tool in all_tools if isinstance(tool, FunctionTool)}
+        computer_tool = next((tool for tool in all_tools if isinstance(tool, ComputerTool)), None)

         for output in response.output:
             if isinstance(output, ResponseOutputMessage):
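
With the new `all_tools` parameter, the function and computer tool lookup tables are built from a list the caller resolves, rather than from `agent.tools` directly. A minimal sketch of the call site, not part of this diff: it assumes the enclosing class is named `RunImpl`, uses hypothetical local variable names, and omits the remaining keyword arguments, which are unchanged by this PR.

    # Hypothetical Runner-side call site: resolve MCP + function tools once per
    # turn, then pass them to the response processor added above.
    all_tools = await agent.get_all_tools()  # helper added in src/agents/agent.py below
    processed = RunImpl.process_model_response(
        agent=agent,
        all_tools=all_tools,        # new parameter from this PR
        response=model_response,    # hypothetical name for the ModelResponse
        output_schema=output_schema,
        handoffs=handoffs,
        # ... other arguments unchanged
    )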

20 changes: 20 additions & 0 deletions src/agents/agent.py

@@ -12,6 +12,7 @@
 from .handoffs import Handoff
 from .items import ItemHelpers
 from .logger import logger
+from .mcp import MCPUtil
 from .model_settings import ModelSettings
 from .models.interface import Model
 from .run_context import RunContextWrapper, TContext
@@ -21,6 +22,7 @@

 if TYPE_CHECKING:
     from .lifecycle import AgentHooks
+    from .mcp import MCPServer
     from .result import RunResult


@@ -107,6 +109,16 @@ class Agent(Generic[TContext]):
     tools: list[Tool] = field(default_factory=list)
     """A list of tools that the agent can use."""

+    mcp_servers: list[MCPServer] = field(default_factory=list)
+    """A list of [Model Context Protocol](https://modelcontextprotocol.io/) servers that
+    the agent can use. Every time the agent runs, it will include tools from these servers in the
+    list of available tools.
+
+    NOTE: You are expected to manage the lifecycle of these servers. Specifically, you must call
+    `server.connect()` before passing it to the agent, and `server.cleanup()` when the server is no
+    longer needed.
+    """
+
     input_guardrails: list[InputGuardrail[TContext]] = field(default_factory=list)
     """A list of checks that run in parallel to the agent's execution, before generating a
     response. Runs only if the agent is the first agent in the chain.
@@ -205,3 +217,11 @@ async def get_system_prompt(self, run_context: RunContextWrapper[TContext]) -> str | None:
             logger.error(f"Instructions must be a string or a function, got {self.instructions}")

         return None
+
+    async def get_mcp_tools(self) -> list[Tool]:
+        """Fetches the available tools from the MCP servers."""
+        return await MCPUtil.get_all_function_tools(self.mcp_servers)
+
+    async def get_all_tools(self) -> list[Tool]:
+        """All agent tools, including MCP tools and function tools."""
+        return await MCPUtil.get_all_function_tools(self.mcp_servers) + self.tools
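
The `mcp_servers` field and its lifecycle note suggest a usage pattern like the following. This is a minimal sketch, not part of the diff; it assumes `MCPServerStdio` accepts a `params` dict matching `MCPServerStdioParams` (command plus args), that `Runner.run` is the SDK's entry point, and that the filesystem MCP server package is available via npx.

    import asyncio

    from agents import Agent, Runner
    from agents.mcp import MCPServerStdio

    async def main() -> None:
        # The caller owns the server lifecycle: connect before the run,
        # clean up when the server is no longer needed.
        server = MCPServerStdio(
            params={
                "command": "npx",
                "args": ["-y", "@modelcontextprotocol/server-filesystem", "."],
            }
        )
        await server.connect()
        try:
            agent = Agent(
                name="Assistant",
                instructions="Use the filesystem tools to answer questions.",
                mcp_servers=[server],  # MCP tools are fetched on every run
            )
            result = await Runner.run(agent, "List the files in this directory.")
            print(result.final_output)
        finally:
            await server.cleanup()

    asyncio.run(main())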

21 changes: 21 additions & 0 deletions src/agents/mcp/__init__.py

@@ -0,0 +1,21 @@
try:
    from .server import (
        MCPServer,
        MCPServerSse,
        MCPServerSseParams,
        MCPServerStdio,
        MCPServerStdioParams,
    )
except ImportError:
    pass

from .util import MCPUtil

__all__ = [
    "MCPServer",
    "MCPServerSse",
    "MCPServerSseParams",
    "MCPServerStdio",
    "MCPServerStdioParams",
    "MCPUtil",
]
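
Together with the `python_version >= '3.10'` marker added in pyproject.toml above, the try/except means the server classes only import where the `mcp` package is installed; on older interpreters only `MCPUtil` is guaranteed to resolve. A small, hypothetical guard a downstream caller might use:

    import sys

    if sys.version_info >= (3, 10):
        from agents.mcp import MCPServerStdio
    else:
        MCPServerStdio = None  # MCP servers are unavailable below Python 3.10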

94 changes: 94 additions & 0 deletions src/agents/mcp/mcp_util.py

@@ -0,0 +1,94 @@
import functools
import json
from typing import Any

from mcp.types import Tool as MCPTool

from .. import _debug
from ..exceptions import AgentsException, ModelBehaviorError, UserError
from ..logger import logger
from ..run_context import RunContextWrapper
from ..tool import FunctionTool, Tool
from .server import MCPServer


class MCPUtil:
    """Set of utilities for interop between MCP and Agents SDK tools."""

    @classmethod
    async def get_all_function_tools(cls, servers: list[MCPServer]) -> list[Tool]:
        """Get all function tools from a list of MCP servers."""
        tools = []
        tool_names: set[str] = set()
        for server in servers:
            server_tools = await cls.get_function_tools(server)
            server_tool_names = {tool.name for tool in server_tools}
            if len(server_tool_names & tool_names) > 0:
                raise UserError(
                    f"Duplicate tool names found across MCP servers: "
                    f"{server_tool_names & tool_names}"
                )
            tool_names.update(server_tool_names)
            tools.extend(server_tools)

        return tools

    @classmethod
    async def get_function_tools(cls, server: MCPServer) -> list[Tool]:
        """Get all function tools from a single MCP server."""
        tools = await server.list_tools()
        return [cls.to_function_tool(tool, server) for tool in tools]

    @classmethod
    def to_function_tool(cls, tool: MCPTool, server: MCPServer) -> FunctionTool:
        """Convert an MCP tool to an Agents SDK function tool."""
        invoke_func = functools.partial(cls.invoke_mcp_tool, server, tool)
        return FunctionTool(
            name=tool.name,
            description=tool.description or "",
            params_json_schema=tool.inputSchema,
            on_invoke_tool=invoke_func,
            strict_json_schema=False,
        )

    @classmethod
    async def invoke_mcp_tool(
        cls, server: MCPServer, tool: MCPTool, context: RunContextWrapper[Any], input_json: str
    ) -> str:
        """Invoke an MCP tool and return the result as a string."""
        try:
            json_data: dict[str, Any] = json.loads(input_json) if input_json else {}
        except Exception as e:
            if _debug.DONT_LOG_TOOL_DATA:
                logger.debug(f"Invalid JSON input for tool {tool.name}")
            else:
                logger.debug(f"Invalid JSON input for tool {tool.name}: {input_json}")
            raise ModelBehaviorError(
                f"Invalid JSON input for tool {tool.name}: {input_json}"
            ) from e

        if _debug.DONT_LOG_TOOL_DATA:
            logger.debug(f"Invoking MCP tool {tool.name}")
        else:
            logger.debug(f"Invoking MCP tool {tool.name} with input {input_json}")

        try:
            result = await server.call_tool(tool.name, json_data)
        except Exception as e:
            logger.error(f"Error invoking MCP tool {tool.name}: {e}")
            raise AgentsException(f"Error invoking MCP tool {tool.name}: {e}") from e

        if _debug.DONT_LOG_TOOL_DATA:
            logger.debug(f"MCP tool {tool.name} completed.")
        else:
            logger.debug(f"MCP tool {tool.name} returned {result}")

        # The MCP tool result is a list of content items, whereas OpenAI tool outputs are a single
        # string. We'll try to convert.
        if len(result.content) == 1:
            return result.content[0].model_dump_json()
        elif len(result.content) > 1:
            return json.dumps([item.model_dump() for item in result.content])
        else:
            logger.error(f"Errored MCP tool result: {result}")
            return "Error running tool."