Skip to content

Commit 807ace2

Browse files
committed
Remove the agent provider, just use a RestateAgent
1 parent d4021bf commit 807ace2

File tree

2 files changed

+63
-107
lines changed

2 files changed

+63
-107
lines changed
Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,11 @@
1-
from ._agent import RestateAgent, RestateAgentProvider
1+
from ._agent import RestateAgent
22
from ._model import RestateModelWrapper
33
from ._serde import PydanticTypeAdapter
44
from ._toolset import RestateContextRunToolSet
55

66
__all__ = [
77
'RestateModelWrapper',
88
'RestateAgent',
9-
'RestateAgentProvider',
109
'PydanticTypeAdapter',
1110
'RestateContextRunToolSet',
1211
]

pydantic_ai_slim/pydantic_ai/durable_exec/restate/_agent.py

Lines changed: 62 additions & 105 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
from __future__ import annotations
22

3-
from collections.abc import Callable, Iterator, Sequence
3+
from collections.abc import Iterator, Sequence
44
from contextlib import contextmanager
5-
from typing import Any, Generic, Never, overload
5+
from typing import Any, Never, overload
66

77
from restate import Context, TerminalError
88

@@ -24,125 +24,81 @@
2424
from ._toolset import RestateContextRunToolSet
2525

2626

27-
class RestateAgentProvider(Generic[AgentDepsT, OutputDataT]):
28-
def __init__(self, wrapped: AbstractAgent[AgentDepsT, OutputDataT], *, max_attempts: int = 3):
29-
if not isinstance(wrapped.model, Model):
30-
raise TerminalError(
31-
'An agent needs to have a `model` in order to be used with Restate, it cannot be set at agent run time.'
32-
)
33-
# here we collect all the configuration that will be passed to the RestateAgent
34-
# the actual context will be provided at runtime.
35-
self.wrapped = wrapped
36-
self.model = wrapped.model
37-
self.max_attempts = max_attempts
38-
39-
def create_agent(self, context: Context) -> AbstractAgent[AgentDepsT, OutputDataT]:
40-
"""Create an agent instance with the given Restate context.
41-
42-
Use this method to create an agent that is tied to a specific Restate context.
43-
With this agent, all operations will be executed within the provided context,
44-
enabling features like retries and durable steps.
45-
Note that the agent will automatically wrap tool calls with restate's `ctx.run()`.
46-
47-
Example:
48-
...
49-
agent_provider = RestateAgentProvider(weather_agent)
50-
51-
weather = restate.Service('weather')
52-
53-
@weather.handler()
54-
async def get_weather(ctx: restate.Context, city: str):
55-
agent = agent_provider.create_agent_from_context(ctx)
56-
result = await agent.run(f'What is the weather in {city}?')
57-
return result.output
58-
...
59-
60-
Args:
61-
context: The Restate context to use for the agent.
62-
auto_wrap_tool_calls: Whether to automatically wrap tool calls with restate's ctx.run() (durable step), True by default.
63-
64-
Returns:
65-
A RestateAgent instance that uses the provided context for its operations.
66-
"""
67-
68-
def get_context(_unused: AgentDepsT) -> Context:
69-
return context
70-
71-
builder = self
72-
return RestateAgent(builder=builder, get_context=get_context, auto_wrap_tools=True)
73-
74-
def create_agent_with_advanced_tools(
75-
self, get_context: Callable[[AgentDepsT], Context]
76-
) -> AbstractAgent[AgentDepsT, OutputDataT]:
77-
"""Create an agent instance that is able to obtain Restate context from its dependencies.
27+
class RestateAgent(WrapperAgent[AgentDepsT, OutputDataT]):
28+
"""An agent that integrates with Restate framework for building resilient applications.
7829
79-
Use this method, if you are planning to use restate's context inside the tools (for rpc, timers, multi step tools etc.)
80-
To obtain a context inside a tool you can add a dependency that has a `restate_context` attribute, and provide a `get_context` extractor
81-
function that extracts the context from the dependencies at runtime.
30+
This agent wraps an existing agent with Restate context capabilities, providing
31+
automatic retries and durable execution for all operations. By default, tool calls
32+
are automatically wrapped with Restate's execution model.
8233
83-
Note: that the agent will NOT automatically wrap tool calls with restate's `ctx.run()`
84-
since the tools may use the context in different ways.
34+
Example:
35+
...
8536
86-
Example:
87-
...
37+
weather = restate.Service('weather')
8838
89-
@dataclass
90-
WeatherDeps:
91-
...
92-
restate_context: Context
39+
@weather.handler()
40+
async def get_weather(ctx: restate.Context, city: str):
41+
agent = RestateAgent(weather_agent, context=ctx)
42+
result = await agent.run(f'What is the weather in {city}?')
43+
return result.output
44+
...
9345
94-
weather_agent = Agent(..., deps_type=WeatherDeps, ...)
46+
For advanced scenarios, you can disable automatic tool wrapping by setting
47+
`disable_auto_wrapping_tools=True`. This allows direct usage of Restate context
48+
within your tools for features like RPC calls, timers, and multi-step operations.
9549
96-
@weather_agent.tool
97-
async def get_lat_lng(ctx: RunContext[WeatherDeps], location_description: str) -> LatLng:
98-
restate_context = ctx.deps.restate_context
99-
lat = await restate_context.run(...) # <---- note the direct usage of the restate context
100-
lng = await restate_context.run(...)
101-
return LatLng(lat, lng)
50+
When automatic wrapping is disabled, function tools will NOT be automatically executed
51+
within Restate's `ctx.run()` context, giving you full control over how the
52+
Restate context is used within your tool implementations.
53+
But model calls, and MCP tool calls will still be automatically wrapped.
10254
103-
agent = RestateAgentProvider(weather_agent).create_agent_from_deps(lambda deps: deps.restate_context)
55+
Example:
56+
...
10457
105-
weather = restate.Service('weather')
58+
@dataclass
59+
WeatherDeps:
60+
...
61+
restate_context: Context
10662
107-
@weather.handler()
108-
async def get_weather(ctx: restate.Context, city: str):
109-
result = await agent.run(f'What is the weather in {city}?', deps=WeatherDeps(restate_context=ctx, ...))
110-
return result.output
111-
...
63+
weather_agent = Agent(..., deps_type=WeatherDeps, ...)
11264
113-
Args:
114-
get_context: A callable that extracts the Restate context from the agent's dependencies.
65+
@weather_agent.tool
66+
async def get_lat_lng(ctx: RunContext[WeatherDeps], location_description: str) -> LatLng:
67+
restate_context = ctx.deps.restate_context
68+
lat = await restate_context.run(...) # <---- note the direct usage of the restate context
69+
lng = await restate_context.run(...)
70+
return LatLng(lat, lng)
11571
116-
Returns:
117-
A RestateAgent instance that uses the provided context extractor to obtain the Restate context at runtime.
11872
119-
"""
120-
builder = self
121-
return RestateAgent(builder=builder, get_context=get_context, auto_wrap_tools=False)
73+
weather = restate.Service('weather')
12274
75+
@weather.handler()
76+
async def get_weather(ctx: restate.Context, city: str):
77+
agent = RestateAgent(weather_agent, context=ctx)
78+
result = await agent.run(f'What is the weather in {city}?', deps=WeatherDeps(restate_context=ctx, ...))
79+
return result.output
80+
...
12381
124-
class RestateAgent(WrapperAgent[AgentDepsT, OutputDataT]):
125-
"""An agent that integrates with the Restate framework for resilient applications."""
82+
"""
12683

12784
def __init__(
12885
self,
129-
builder: RestateAgentProvider[AgentDepsT, OutputDataT],
130-
get_context: Callable[[AgentDepsT], Context],
131-
auto_wrap_tools: bool,
86+
wrapped: AbstractAgent[AgentDepsT, OutputDataT],
87+
restate_context: Context,
88+
*,
89+
disable_auto_wrapping_tools: bool = False,
13290
):
133-
super().__init__(builder.wrapped)
134-
self._builder = builder
135-
self._get_context = get_context
136-
self._auto_wrap_tools = auto_wrap_tools
137-
138-
@contextmanager
139-
def _restate_overrides(self, context: Context) -> Iterator[None]:
140-
model = RestateModelWrapper(self._builder.model, context, max_attempts=self._builder.max_attempts)
91+
super().__init__(wrapped)
92+
if not isinstance(wrapped.model, Model):
93+
raise TerminalError(
94+
'An agent needs to have a `model` in order to be used with Restate, it cannot be set at agent run time.'
95+
)
96+
self._model = RestateModelWrapper(wrapped.model, restate_context, max_attempts=3)
14197

14298
def set_context(toolset: AbstractToolset[AgentDepsT]) -> AbstractToolset[AgentDepsT]:
14399
"""Set the Restate context for the toolset, wrapping tools if needed."""
144-
if isinstance(toolset, FunctionToolset) and self._auto_wrap_tools:
145-
return RestateContextRunToolSet(toolset, context)
100+
if isinstance(toolset, FunctionToolset) and not disable_auto_wrapping_tools:
101+
return RestateContextRunToolSet(toolset, restate_context)
146102
try:
147103
from pydantic_ai.mcp import MCPServer
148104

@@ -151,14 +107,16 @@ def set_context(toolset: AbstractToolset[AgentDepsT]) -> AbstractToolset[AgentDe
151107
pass
152108
else:
153109
if isinstance(toolset, MCPServer):
154-
return RestateMCPServer(toolset, context)
110+
return RestateMCPServer(toolset, restate_context)
155111

156112
return toolset
157113

158-
toolsets = [toolset.visit_and_replace(set_context) for toolset in self._builder.wrapped.toolsets]
114+
self._toolsets = [toolset.visit_and_replace(set_context) for toolset in wrapped.toolsets]
159115

116+
@contextmanager
117+
def _restate_overrides(self) -> Iterator[None]:
160118
with (
161-
super().override(model=model, toolsets=toolsets, tools=[]),
119+
super().override(model=self._model, toolsets=self._toolsets, tools=[]),
162120
self.sequential_tool_calls(),
163121
):
164122
yield
@@ -255,8 +213,7 @@ async def main():
255213
raise TerminalError(
256214
'An agent needs to have a `model` in order to be used with Restate, it cannot be set at agent run time.'
257215
)
258-
context = self._get_context(deps)
259-
with self._restate_overrides(context):
216+
with self._restate_overrides():
260217
return await super(WrapperAgent, self).run(
261218
user_prompt=user_prompt,
262219
output_type=output_type,

0 commit comments

Comments
 (0)