Skip to content

Commit d4021bf

Browse files
committed
Add support tool approval and deferred tools
1 parent fd6656d commit d4021bf

File tree

1 file changed

+44
-13
lines changed
  • pydantic_ai_slim/pydantic_ai/durable_exec/restate

1 file changed

+44
-13
lines changed

pydantic_ai_slim/pydantic_ai/durable_exec/restate/_toolset.py

Lines changed: 44 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,13 @@
22

33
from collections.abc import Callable
44
from dataclasses import dataclass
5-
from typing import Any, Generic
5+
from typing import Any, Literal
66

7-
from restate import Context, RunOptions
7+
from restate import Context, RunOptions, TerminalError
88

99
from pydantic_ai import ToolDefinition
1010
from pydantic_ai._run_context import AgentDepsT
11+
from pydantic_ai.exceptions import ApprovalRequired, CallDeferred, ModelRetry, UserError
1112
from pydantic_ai.mcp import MCPServer, ToolResult
1213
from pydantic_ai.tools import RunContext
1314
from pydantic_ai.toolsets.abstract import AbstractToolset, ToolsetTool
@@ -20,40 +21,71 @@
2021
class RestateContextRunResult:
2122
"""A simple wrapper for tool results to be used with Restate's run_typed."""
2223

24+
kind: Literal['output', 'call_deferred', 'approval_required']
2325
output: Any
2426

2527

28+
CONTEXT_RUN_SERDE = PydanticTypeAdapter(RestateContextRunResult)
29+
30+
2631
@dataclass
2732
class RestateMCPGetToolsContextRunResult:
2833
"""A simple wrapper for tool results to be used with Restate's run_typed."""
2934

3035
output: dict[str, ToolDefinition]
3136

3237

38+
MCP_GET_TOOLS_SERDE = PydanticTypeAdapter(RestateMCPGetToolsContextRunResult)
39+
40+
3341
@dataclass
34-
class RestateMCPToolRunResult(Generic[AgentDepsT]):
42+
class RestateMCPToolRunResult:
3543
"""A simple wrapper for tool results to be used with Restate's run_typed."""
3644

3745
output: ToolResult
3846

3947

48+
MCP_RUN_SERDE = PydanticTypeAdapter(RestateMCPToolRunResult)
49+
50+
4051
class RestateContextRunToolSet(WrapperToolset[AgentDepsT]):
4152
"""A toolset that automatically wraps tool calls with restate's `ctx.run_typed()`."""
4253

4354
def __init__(self, wrapped: AbstractToolset[AgentDepsT], context: Context):
4455
super().__init__(wrapped)
4556
self._context = context
46-
self.options = RunOptions[RestateContextRunResult](serde=PydanticTypeAdapter(RestateContextRunResult))
57+
self.options = RunOptions[RestateContextRunResult](serde=CONTEXT_RUN_SERDE)
4758

4859
async def call_tool(
4960
self, name: str, tool_args: dict[str, Any], ctx: RunContext[AgentDepsT], tool: ToolsetTool[AgentDepsT]
5061
) -> Any:
5162
async def action() -> RestateContextRunResult:
52-
output = await self.wrapped.call_tool(name, tool_args, ctx, tool)
53-
return RestateContextRunResult(output=output)
63+
try:
64+
# A tool may raise ModelRetry, CallDeferred, ApprovalRequired, or UserError
65+
# to signal special conditions to the caller.
66+
# Since, restate ctx.run() will retry this exception we need to convert these exceptions
67+
# to a return value and handle them outside of the ctx.run().
68+
output = await self.wrapped.call_tool(name, tool_args, ctx, tool)
69+
return RestateContextRunResult(kind='output', output=output)
70+
except ModelRetry:
71+
# we let restate to retry this
72+
raise
73+
except CallDeferred:
74+
return RestateContextRunResult(kind='call_deferred', output=None)
75+
except ApprovalRequired:
76+
return RestateContextRunResult(kind='approval_required', output=None)
77+
except UserError as e:
78+
raise TerminalError(str(e)) from e
5479

5580
res = await self._context.run_typed(f'Calling {name}', action, self.options)
56-
return res.output
81+
82+
if res.kind == 'call_deferred':
83+
raise CallDeferred()
84+
elif res.kind == 'approval_required':
85+
raise ApprovalRequired()
86+
else:
87+
assert res.kind == 'output'
88+
return res.output
5789

5890
def visit_and_replace(
5991
self, visitor: Callable[[AbstractToolset[AgentDepsT]], AbstractToolset[AgentDepsT]]
@@ -68,8 +100,6 @@ def __init__(self, wrapped: MCPServer, context: Context):
68100
super().__init__(wrapped)
69101
self._wrapped = wrapped
70102
self._context = context
71-
self._mcp_tool_run_serde = PydanticTypeAdapter(RestateMCPToolRunResult[AgentDepsT])
72-
self._mcp_get_tools_serde = PydanticTypeAdapter(RestateMCPGetToolsContextRunResult)
73103

74104
def visit_and_replace(
75105
self, visitor: Callable[[AbstractToolset[AgentDepsT]], AbstractToolset[AgentDepsT]]
@@ -79,11 +109,12 @@ def visit_and_replace(
79109
async def get_tools(self, ctx: RunContext[AgentDepsT]) -> dict[str, ToolsetTool[AgentDepsT]]:
80110
async def get_tools_in_context() -> RestateMCPGetToolsContextRunResult:
81111
res = await self._wrapped.get_tools(ctx)
82-
# ToolsetTool is not serializable as it holds a SchemaValidator (which is also the same for every MCP tool so unnecessary to pass along the wire every time),
112+
# ToolsetTool is not serializable as it holds a SchemaValidator
113+
# (which is also the same for every MCP tool so unnecessary to pass along the wire every time),
83114
# so we just return the ToolDefinitions and wrap them in ToolsetTool outside of the activity.
84115
return RestateMCPGetToolsContextRunResult(output={name: tool.tool_def for name, tool in res.items()})
85116

86-
options = RunOptions(serde=self._mcp_get_tools_serde)
117+
options = RunOptions(serde=MCP_GET_TOOLS_SERDE)
87118

88119
tool_defs = await self._context.run_typed('get mcp tools', get_tools_in_context, options)
89120

@@ -100,11 +131,11 @@ async def call_tool(
100131
ctx: RunContext[AgentDepsT],
101132
tool: ToolsetTool[AgentDepsT],
102133
) -> ToolResult:
103-
async def call_tool_in_context() -> RestateMCPToolRunResult[AgentDepsT]:
134+
async def call_tool_in_context() -> RestateMCPToolRunResult:
104135
res = await self._wrapped.call_tool(name, tool_args, ctx, tool)
105136
return RestateMCPToolRunResult(output=res)
106137

107-
options = RunOptions(serde=self._mcp_tool_run_serde)
138+
options = RunOptions(serde=MCP_RUN_SERDE)
108139
res = await self._context.run_typed(f'Calling mcp tool {name}', call_tool_in_context, options)
109140

110141
return res.output

0 commit comments

Comments
 (0)