22
33from collections .abc import Callable
44from dataclasses import dataclass
5- from typing import Any , Generic
5+ from typing import Any , Literal
66
7- from restate import Context , RunOptions
7+ from restate import Context , RunOptions , TerminalError
88
99from pydantic_ai import ToolDefinition
1010from pydantic_ai ._run_context import AgentDepsT
11+ from pydantic_ai .exceptions import ApprovalRequired , CallDeferred , ModelRetry , UserError
1112from pydantic_ai .mcp import MCPServer , ToolResult
1213from pydantic_ai .tools import RunContext
1314from pydantic_ai .toolsets .abstract import AbstractToolset , ToolsetTool
2021class RestateContextRunResult :
2122 """A simple wrapper for tool results to be used with Restate's run_typed."""
2223
24+ kind : Literal ['output' , 'call_deferred' , 'approval_required' ]
2325 output : Any
2426
2527
28+ CONTEXT_RUN_SERDE = PydanticTypeAdapter (RestateContextRunResult )
29+
30+
2631@dataclass
2732class RestateMCPGetToolsContextRunResult :
2833 """A simple wrapper for tool results to be used with Restate's run_typed."""
2934
3035 output : dict [str , ToolDefinition ]
3136
3237
38+ MCP_GET_TOOLS_SERDE = PydanticTypeAdapter (RestateMCPGetToolsContextRunResult )
39+
40+
3341@dataclass
34- class RestateMCPToolRunResult ( Generic [ AgentDepsT ]) :
42+ class RestateMCPToolRunResult :
3543 """A simple wrapper for tool results to be used with Restate's run_typed."""
3644
3745 output : ToolResult
3846
3947
48+ MCP_RUN_SERDE = PydanticTypeAdapter (RestateMCPToolRunResult )
49+
50+
4051class RestateContextRunToolSet (WrapperToolset [AgentDepsT ]):
4152 """A toolset that automatically wraps tool calls with restate's `ctx.run_typed()`."""
4253
4354 def __init__ (self , wrapped : AbstractToolset [AgentDepsT ], context : Context ):
4455 super ().__init__ (wrapped )
4556 self ._context = context
46- self .options = RunOptions [RestateContextRunResult ](serde = PydanticTypeAdapter ( RestateContextRunResult ) )
57+ self .options = RunOptions [RestateContextRunResult ](serde = CONTEXT_RUN_SERDE )
4758
4859 async def call_tool (
4960 self , name : str , tool_args : dict [str , Any ], ctx : RunContext [AgentDepsT ], tool : ToolsetTool [AgentDepsT ]
5061 ) -> Any :
5162 async def action () -> RestateContextRunResult :
52- output = await self .wrapped .call_tool (name , tool_args , ctx , tool )
53- return RestateContextRunResult (output = output )
63+ try :
64+ # A tool may raise ModelRetry, CallDeferred, ApprovalRequired, or UserError
65+ # to signal special conditions to the caller.
66+ # Since, restate ctx.run() will retry this exception we need to convert these exceptions
67+ # to a return value and handle them outside of the ctx.run().
68+ output = await self .wrapped .call_tool (name , tool_args , ctx , tool )
69+ return RestateContextRunResult (kind = 'output' , output = output )
70+ except ModelRetry :
71+ # we let restate to retry this
72+ raise
73+ except CallDeferred :
74+ return RestateContextRunResult (kind = 'call_deferred' , output = None )
75+ except ApprovalRequired :
76+ return RestateContextRunResult (kind = 'approval_required' , output = None )
77+ except UserError as e :
78+ raise TerminalError (str (e )) from e
5479
5580 res = await self ._context .run_typed (f'Calling { name } ' , action , self .options )
56- return res .output
81+
82+ if res .kind == 'call_deferred' :
83+ raise CallDeferred ()
84+ elif res .kind == 'approval_required' :
85+ raise ApprovalRequired ()
86+ else :
87+ assert res .kind == 'output'
88+ return res .output
5789
5890 def visit_and_replace (
5991 self , visitor : Callable [[AbstractToolset [AgentDepsT ]], AbstractToolset [AgentDepsT ]]
@@ -68,8 +100,6 @@ def __init__(self, wrapped: MCPServer, context: Context):
68100 super ().__init__ (wrapped )
69101 self ._wrapped = wrapped
70102 self ._context = context
71- self ._mcp_tool_run_serde = PydanticTypeAdapter (RestateMCPToolRunResult [AgentDepsT ])
72- self ._mcp_get_tools_serde = PydanticTypeAdapter (RestateMCPGetToolsContextRunResult )
73103
74104 def visit_and_replace (
75105 self , visitor : Callable [[AbstractToolset [AgentDepsT ]], AbstractToolset [AgentDepsT ]]
@@ -79,11 +109,12 @@ def visit_and_replace(
79109 async def get_tools (self , ctx : RunContext [AgentDepsT ]) -> dict [str , ToolsetTool [AgentDepsT ]]:
80110 async def get_tools_in_context () -> RestateMCPGetToolsContextRunResult :
81111 res = await self ._wrapped .get_tools (ctx )
82- # ToolsetTool is not serializable as it holds a SchemaValidator (which is also the same for every MCP tool so unnecessary to pass along the wire every time),
112+ # ToolsetTool is not serializable as it holds a SchemaValidator
113+ # (which is also the same for every MCP tool so unnecessary to pass along the wire every time),
83114 # so we just return the ToolDefinitions and wrap them in ToolsetTool outside of the activity.
84115 return RestateMCPGetToolsContextRunResult (output = {name : tool .tool_def for name , tool in res .items ()})
85116
86- options = RunOptions (serde = self . _mcp_get_tools_serde )
117+ options = RunOptions (serde = MCP_GET_TOOLS_SERDE )
87118
88119 tool_defs = await self ._context .run_typed ('get mcp tools' , get_tools_in_context , options )
89120
@@ -100,11 +131,11 @@ async def call_tool(
100131 ctx : RunContext [AgentDepsT ],
101132 tool : ToolsetTool [AgentDepsT ],
102133 ) -> ToolResult :
103- async def call_tool_in_context () -> RestateMCPToolRunResult [ AgentDepsT ] :
134+ async def call_tool_in_context () -> RestateMCPToolRunResult :
104135 res = await self ._wrapped .call_tool (name , tool_args , ctx , tool )
105136 return RestateMCPToolRunResult (output = res )
106137
107- options = RunOptions (serde = self . _mcp_tool_run_serde )
138+ options = RunOptions (serde = MCP_RUN_SERDE )
108139 res = await self ._context .run_typed (f'Calling mcp tool { name } ' , call_tool_in_context , options )
109140
110141 return res .output
0 commit comments