11from __future__ import annotations
22
3- from collections .abc import Callable , Iterator , Sequence
3+ from collections .abc import Iterator , Sequence
44from contextlib import contextmanager
5- from typing import Any , Generic , Never , overload
5+ from typing import Any , Never , overload
66
77from restate import Context , TerminalError
88
2424from ._toolset import RestateContextRunToolSet
2525
2626
27- class RestateAgentProvider (Generic [AgentDepsT , OutputDataT ]):
28- def __init__ (self , wrapped : AbstractAgent [AgentDepsT , OutputDataT ], * , max_attempts : int = 3 ):
29- if not isinstance (wrapped .model , Model ):
30- raise TerminalError (
31- 'An agent needs to have a `model` in order to be used with Restate, it cannot be set at agent run time.'
32- )
33- # here we collect all the configuration that will be passed to the RestateAgent
34- # the actual context will be provided at runtime.
35- self .wrapped = wrapped
36- self .model = wrapped .model
37- self .max_attempts = max_attempts
38-
39- def create_agent (self , context : Context ) -> AbstractAgent [AgentDepsT , OutputDataT ]:
40- """Create an agent instance with the given Restate context.
41-
42- Use this method to create an agent that is tied to a specific Restate context.
43- With this agent, all operations will be executed within the provided context,
44- enabling features like retries and durable steps.
45- Note that the agent will automatically wrap tool calls with restate's `ctx.run()`.
46-
47- Example:
48- ...
49- agent_provider = RestateAgentProvider(weather_agent)
50-
51- weather = restate.Service('weather')
52-
53- @weather.handler()
54- async def get_weather(ctx: restate.Context, city: str):
55- agent = agent_provider.create_agent_from_context(ctx)
56- result = await agent.run(f'What is the weather in {city}?')
57- return result.output
58- ...
59-
60- Args:
61- context: The Restate context to use for the agent.
62- auto_wrap_tool_calls: Whether to automatically wrap tool calls with restate's ctx.run() (durable step), True by default.
63-
64- Returns:
65- A RestateAgent instance that uses the provided context for its operations.
66- """
67-
68- def get_context (_unused : AgentDepsT ) -> Context :
69- return context
70-
71- builder = self
72- return RestateAgent (builder = builder , get_context = get_context , auto_wrap_tools = True )
73-
74- def create_agent_with_advanced_tools (
75- self , get_context : Callable [[AgentDepsT ], Context ]
76- ) -> AbstractAgent [AgentDepsT , OutputDataT ]:
77- """Create an agent instance that is able to obtain Restate context from its dependencies.
27+ class RestateAgent (WrapperAgent [AgentDepsT , OutputDataT ]):
28+ """An agent that integrates with Restate framework for building resilient applications.
7829
79- Use this method, if you are planning to use restate's context inside the tools (for rpc, timers, multi step tools etc.)
80- To obtain a context inside a tool you can add a dependency that has a `restate_context` attribute, and provide a `get_context` extractor
81- function that extracts the context from the dependencies at runtime .
30+ This agent wraps an existing agent with Restate context capabilities, providing
31+ automatic retries and durable execution for all operations. By default, tool calls
32+ are automatically wrapped with Restate's execution model .
8233
83- Note: that the agent will NOT automatically wrap tool calls with restate's `ctx.run()`
84- since the tools may use the context in different ways .
34+ Example:
35+ .. .
8536
86- Example:
87- ...
37+ weather = restate.Service('weather')
8838
89- @dataclass
90- WeatherDeps:
91- ...
92- restate_context: Context
39+ @weather.handler()
40+ async def get_weather(ctx: restate.Context, city: str):
41+ agent = RestateAgent(weather_agent, context=ctx)
42+ result = await agent.run(f'What is the weather in {city}?')
43+ return result.output
44+ ...
9345
94- weather_agent = Agent(..., deps_type=WeatherDeps, ...)
46+ For advanced scenarios, you can disable automatic tool wrapping by setting
47+ `disable_auto_wrapping_tools=True`. This allows direct usage of Restate context
48+ within your tools for features like RPC calls, timers, and multi-step operations.
9549
96- @weather_agent.tool
97- async def get_lat_lng(ctx: RunContext[WeatherDeps], location_description: str) -> LatLng:
98- restate_context = ctx.deps.restate_context
99- lat = await restate_context.run(...) # <---- note the direct usage of the restate context
100- lng = await restate_context.run(...)
101- return LatLng(lat, lng)
50+ When automatic wrapping is disabled, function tools will NOT be automatically executed
51+ within Restate's `ctx.run()` context, giving you full control over how the
52+ Restate context is used within your tool implementations.
53+ But model calls, and MCP tool calls will still be automatically wrapped.
10254
103- agent = RestateAgentProvider(weather_agent).create_agent_from_deps(lambda deps: deps.restate_context)
55+ Example:
56+ ...
10457
105- weather = restate.Service('weather')
58+ @dataclass
59+ WeatherDeps:
60+ ...
61+ restate_context: Context
10662
107- @weather.handler()
108- async def get_weather(ctx: restate.Context, city: str):
109- result = await agent.run(f'What is the weather in {city}?', deps=WeatherDeps(restate_context=ctx, ...))
110- return result.output
111- ...
63+ weather_agent = Agent(..., deps_type=WeatherDeps, ...)
11264
113- Args:
114- get_context: A callable that extracts the Restate context from the agent's dependencies.
65+ @weather_agent.tool
66+ async def get_lat_lng(ctx: RunContext[WeatherDeps], location_description: str) -> LatLng:
67+ restate_context = ctx.deps.restate_context
68+ lat = await restate_context.run(...) # <---- note the direct usage of the restate context
69+ lng = await restate_context.run(...)
70+ return LatLng(lat, lng)
11571
116- Returns:
117- A RestateAgent instance that uses the provided context extractor to obtain the Restate context at runtime.
11872
119- """
120- builder = self
121- return RestateAgent (builder = builder , get_context = get_context , auto_wrap_tools = False )
73+ weather = restate.Service('weather')
12274
75+ @weather.handler()
76+ async def get_weather(ctx: restate.Context, city: str):
77+ agent = RestateAgent(weather_agent, context=ctx)
78+ result = await agent.run(f'What is the weather in {city}?', deps=WeatherDeps(restate_context=ctx, ...))
79+ return result.output
80+ ...
12381
124- class RestateAgent (WrapperAgent [AgentDepsT , OutputDataT ]):
125- """An agent that integrates with the Restate framework for resilient applications."""
82+ """
12683
12784 def __init__ (
12885 self ,
129- builder : RestateAgentProvider [AgentDepsT , OutputDataT ],
130- get_context : Callable [[AgentDepsT ], Context ],
131- auto_wrap_tools : bool ,
86+ wrapped : AbstractAgent [AgentDepsT , OutputDataT ],
87+ restate_context : Context ,
88+ * ,
89+ disable_auto_wrapping_tools : bool = False ,
13290 ):
133- super ().__init__ (builder .wrapped )
134- self ._builder = builder
135- self ._get_context = get_context
136- self ._auto_wrap_tools = auto_wrap_tools
137-
138- @contextmanager
139- def _restate_overrides (self , context : Context ) -> Iterator [None ]:
140- model = RestateModelWrapper (self ._builder .model , context , max_attempts = self ._builder .max_attempts )
91+ super ().__init__ (wrapped )
92+ if not isinstance (wrapped .model , Model ):
93+ raise TerminalError (
94+ 'An agent needs to have a `model` in order to be used with Restate, it cannot be set at agent run time.'
95+ )
96+ self ._model = RestateModelWrapper (wrapped .model , restate_context , max_attempts = 3 )
14197
14298 def set_context (toolset : AbstractToolset [AgentDepsT ]) -> AbstractToolset [AgentDepsT ]:
14399 """Set the Restate context for the toolset, wrapping tools if needed."""
144- if isinstance (toolset , FunctionToolset ) and self . _auto_wrap_tools :
145- return RestateContextRunToolSet (toolset , context )
100+ if isinstance (toolset , FunctionToolset ) and not disable_auto_wrapping_tools :
101+ return RestateContextRunToolSet (toolset , restate_context )
146102 try :
147103 from pydantic_ai .mcp import MCPServer
148104
@@ -151,14 +107,16 @@ def set_context(toolset: AbstractToolset[AgentDepsT]) -> AbstractToolset[AgentDe
151107 pass
152108 else :
153109 if isinstance (toolset , MCPServer ):
154- return RestateMCPServer (toolset , context )
110+ return RestateMCPServer (toolset , restate_context )
155111
156112 return toolset
157113
158- toolsets = [toolset .visit_and_replace (set_context ) for toolset in self . _builder . wrapped .toolsets ]
114+ self . _toolsets = [toolset .visit_and_replace (set_context ) for toolset in wrapped .toolsets ]
159115
116+ @contextmanager
117+ def _restate_overrides (self ) -> Iterator [None ]:
160118 with (
161- super ().override (model = model , toolsets = toolsets , tools = []),
119+ super ().override (model = self . _model , toolsets = self . _toolsets , tools = []),
162120 self .sequential_tool_calls (),
163121 ):
164122 yield
@@ -255,8 +213,7 @@ async def main():
255213 raise TerminalError (
256214 'An agent needs to have a `model` in order to be used with Restate, it cannot be set at agent run time.'
257215 )
258- context = self ._get_context (deps )
259- with self ._restate_overrides (context ):
216+ with self ._restate_overrides ():
260217 return await super (WrapperAgent , self ).run (
261218 user_prompt = user_prompt ,
262219 output_type = output_type ,
0 commit comments