Skip to content

Commit 01dd2dd

Browse files
committed
Support event_stream_handler
1 parent 8f2e08f commit 01dd2dd

File tree

2 files changed

+228
-12
lines changed

2 files changed

+228
-12
lines changed

pydantic_ai_slim/pydantic_ai/durable_exec/restate/_agent.py

Lines changed: 143 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
from __future__ import annotations
22

3-
from collections.abc import Iterator, Sequence
4-
from contextlib import contextmanager
5-
from typing import Any, overload
3+
from collections.abc import AsyncIterable, AsyncIterator, Iterator, Sequence
4+
from contextlib import AbstractAsyncContextManager, asynccontextmanager, contextmanager
5+
from typing import Any, Never, overload
66

77
from restate import Context, TerminalError
88

@@ -11,12 +11,14 @@
1111
from pydantic_ai.agent.abstract import AbstractAgent, EventStreamHandler, RunOutputDataT
1212
from pydantic_ai.agent.wrapper import WrapperAgent
1313
from pydantic_ai.builtin_tools import AbstractBuiltinTool
14-
from pydantic_ai.messages import ModelMessage, UserContent
14+
from pydantic_ai.exceptions import UserError
15+
from pydantic_ai.messages import AgentStreamEvent, ModelMessage, UserContent
1516
from pydantic_ai.models import Model
1617
from pydantic_ai.output import OutputDataT, OutputSpec
18+
from pydantic_ai.result import StreamedRunResult
1719
from pydantic_ai.run import AgentRunResult
1820
from pydantic_ai.settings import ModelSettings
19-
from pydantic_ai.tools import DeferredToolResults
21+
from pydantic_ai.tools import DeferredToolResults, RunContext
2022
from pydantic_ai.toolsets.abstract import AbstractToolset
2123
from pydantic_ai.toolsets.function import FunctionToolset
2224
from pydantic_ai.usage import RunUsage, UsageLimits
@@ -87,14 +89,21 @@ def __init__(
8789
wrapped: AbstractAgent[AgentDepsT, OutputDataT],
8890
restate_context: Context,
8991
*,
92+
event_stream_handler: EventStreamHandler[AgentDepsT] | None = None,
9093
disable_auto_wrapping_tools: bool = False,
9194
):
9295
super().__init__(wrapped)
9396
if not isinstance(wrapped.model, Model):
9497
raise TerminalError(
9598
'An agent needs to have a `model` in order to be used with Restate, it cannot be set at agent run time.'
9699
)
97-
self._model = RestateModelWrapper(wrapped.model, restate_context, max_attempts=3)
100+
101+
self.restate_context = restate_context
102+
self._event_stream_handler = event_stream_handler
103+
self._disable_auto_wrapping_tools = disable_auto_wrapping_tools
104+
self._model = RestateModelWrapper(
105+
wrapped.model, restate_context, event_stream_handler=event_stream_handler, max_attempts=3
106+
)
98107

99108
def set_context(toolset: AbstractToolset[AgentDepsT]) -> AbstractToolset[AgentDepsT]:
100109
"""Set the Restate context for the toolset, wrapping tools if needed."""
@@ -122,6 +131,37 @@ def _restate_overrides(self) -> Iterator[None]:
122131
):
123132
yield
124133

134+
@property
135+
def model(self) -> models.Model | models.KnownModelName | str | None:
136+
return self._model
137+
138+
@property
139+
def event_stream_handler(self) -> EventStreamHandler[AgentDepsT] | None:
140+
handler = self._event_stream_handler or super().event_stream_handler
141+
if handler is None:
142+
return None
143+
if self._disable_auto_wrapping_tools:
144+
return handler
145+
return self.wrapped_event_stream_handler
146+
147+
async def wrapped_event_stream_handler(
148+
self, ctx: RunContext[AgentDepsT], stream: AsyncIterable[AgentStreamEvent]
149+
) -> None:
150+
fn = self._event_stream_handler
151+
if fn is None:
152+
return
153+
async for event in stream:
154+
155+
async def single_event():
156+
yield event
157+
158+
await self.restate_context.run_typed('run event', lambda: fn(ctx, single_event()))
159+
160+
@property
161+
def toolsets(self) -> Sequence[AbstractToolset[AgentDepsT]]:
162+
with self._restate_overrides():
163+
return super().toolsets
164+
125165
@overload
126166
async def run(
127167
self,
@@ -232,3 +272,100 @@ async def main():
232272
toolsets=toolsets,
233273
event_stream_handler=event_stream_handler,
234274
)
275+
276+
@overload
277+
def run_stream(
278+
self,
279+
user_prompt: str | Sequence[UserContent] | None = None,
280+
*,
281+
output_type: None = None,
282+
message_history: Sequence[ModelMessage] | None = None,
283+
deferred_tool_results: DeferredToolResults | None = None,
284+
model: models.Model | models.KnownModelName | str | None = None,
285+
deps: AgentDepsT = None,
286+
model_settings: ModelSettings | None = None,
287+
usage_limits: UsageLimits | None = None,
288+
usage: RunUsage | None = None,
289+
infer_name: bool = True,
290+
toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
291+
builtin_tools: Sequence[AbstractBuiltinTool] | None = None,
292+
event_stream_handler: EventStreamHandler[AgentDepsT] | None = None,
293+
) -> AbstractAsyncContextManager[StreamedRunResult[AgentDepsT, OutputDataT]]: ...
294+
295+
@overload
296+
def run_stream(
297+
self,
298+
user_prompt: str | Sequence[UserContent] | None = None,
299+
*,
300+
output_type: OutputSpec[RunOutputDataT],
301+
message_history: Sequence[ModelMessage] | None = None,
302+
deferred_tool_results: DeferredToolResults | None = None,
303+
model: models.Model | models.KnownModelName | str | None = None,
304+
deps: AgentDepsT = None,
305+
model_settings: ModelSettings | None = None,
306+
usage_limits: UsageLimits | None = None,
307+
usage: RunUsage | None = None,
308+
infer_name: bool = True,
309+
toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
310+
builtin_tools: Sequence[AbstractBuiltinTool] | None = None,
311+
event_stream_handler: EventStreamHandler[AgentDepsT] | None = None,
312+
) -> AbstractAsyncContextManager[StreamedRunResult[AgentDepsT, RunOutputDataT]]: ...
313+
314+
@asynccontextmanager
315+
async def run_stream(
316+
self,
317+
user_prompt: str | Sequence[UserContent] | None = None,
318+
*,
319+
output_type: OutputSpec[RunOutputDataT] | None = None,
320+
message_history: Sequence[ModelMessage] | None = None,
321+
deferred_tool_results: DeferredToolResults | None = None,
322+
model: models.Model | models.KnownModelName | str | None = None,
323+
deps: AgentDepsT = None,
324+
model_settings: ModelSettings | None = None,
325+
usage_limits: UsageLimits | None = None,
326+
usage: RunUsage | None = None,
327+
infer_name: bool = True,
328+
toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
329+
builtin_tools: Sequence[AbstractBuiltinTool] | None = None,
330+
event_stream_handler: EventStreamHandler[AgentDepsT] | None = None,
331+
**_deprecated_kwargs: Never,
332+
) -> AsyncIterator[StreamedRunResult[AgentDepsT, Any]]:
333+
"""Run the agent with a user prompt in async mode, returning a streamed response.
334+
335+
Example:
336+
```python
337+
from pydantic_ai import Agent
338+
339+
agent = Agent('openai:gpt-4o')
340+
341+
async def main():
342+
async with agent.run_stream('What is the capital of the UK?') as response:
343+
print(await response.get_output())
344+
#> The capital of the UK is London.
345+
```
346+
347+
Args:
348+
user_prompt: User input to start/continue the conversation.
349+
output_type: Custom output type to use for this run, `output_type` may only be used if the agent has no
350+
output validators since output validators would expect an argument that matches the agent's output type.
351+
message_history: History of the conversation so far.
352+
deferred_tool_results: Optional results for deferred tool calls in the message history.
353+
model: Optional model to use for this run, required if `model` was not set when creating the agent.
354+
deps: Optional dependencies to use for this run.
355+
model_settings: Optional settings to use for this model's request.
356+
usage_limits: Optional limits on model request count or token usage.
357+
usage: Optional usage to start with, useful for resuming a conversation or agents used in tools.
358+
infer_name: Whether to try to infer the agent name from the call frame if it's not set.
359+
toolsets: Optional additional toolsets for this run.
360+
builtin_tools: Optional additional builtin tools for this run.
361+
event_stream_handler: Optional event stream handler to use for this run. It will receive all the events up until the final result is found, which you can then read or stream from inside the context manager.
362+
363+
Returns:
364+
The result of the run.
365+
"""
366+
raise UserError(
367+
'`agent.run_stream()` cannot be used inside a restate handler. '
368+
'Set an `event_stream_handler` on the agent and use `agent.run()` instead.'
369+
)
370+
371+
yield
Lines changed: 85 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,99 @@
1+
from collections.abc import AsyncIterator
2+
from contextlib import asynccontextmanager
3+
from datetime import datetime
14
from typing import Any
25

36
from restate import Context, RunOptions
47

8+
from pydantic_ai.agent.abstract import EventStreamHandler
59
from pydantic_ai.durable_exec.restate._serde import PydanticTypeAdapter
6-
from pydantic_ai.messages import ModelResponse
7-
from pydantic_ai.models import Model
10+
from pydantic_ai.exceptions import UserError
11+
from pydantic_ai.messages import ModelMessage, ModelResponse, ModelResponseStreamEvent
12+
from pydantic_ai.models import Model, ModelRequestParameters, StreamedResponse
813
from pydantic_ai.models.wrapper import WrapperModel
14+
from pydantic_ai.settings import ModelSettings
15+
from pydantic_ai.tools import AgentDepsT, RunContext
16+
from pydantic_ai.usage import RequestUsage
917

1018
MODEL_RESPONSE_SERDE = PydanticTypeAdapter(ModelResponse)
1119

1220

21+
class RestateStreamedResponse(StreamedResponse):
22+
def __init__(self, model_request_parameters: ModelRequestParameters, response: ModelResponse):
23+
super().__init__(model_request_parameters)
24+
self.response = response
25+
26+
async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
27+
return
28+
# noinspection PyUnreachableCode
29+
yield
30+
31+
def get(self) -> ModelResponse:
32+
return self.response
33+
34+
def usage(self) -> RequestUsage:
35+
return self.response.usage # pragma: no cover
36+
37+
@property
38+
def model_name(self) -> str:
39+
return self.response.model_name or '' # pragma: no cover
40+
41+
@property
42+
def provider_name(self) -> str:
43+
return self.response.provider_name or '' # pragma: no cover
44+
45+
@property
46+
def timestamp(self) -> datetime:
47+
return self.response.timestamp # pragma: no cover
48+
49+
1350
class RestateModelWrapper(WrapperModel):
14-
def __init__(self, wrapped: Model, context: Context, max_attempts: int | None = None):
51+
def __init__(
52+
self,
53+
wrapped: Model,
54+
context: Context,
55+
event_stream_handler: EventStreamHandler[AgentDepsT] | None = None,
56+
max_attempts: int | None = None,
57+
):
1558
super().__init__(wrapped)
16-
self.options = RunOptions(serde=MODEL_RESPONSE_SERDE, max_attempts=max_attempts)
17-
self.context = context
59+
self._options = RunOptions(serde=MODEL_RESPONSE_SERDE, max_attempts=max_attempts)
60+
self._context = context
61+
self._event_stream_handler = event_stream_handler
1862

1963
async def request(self, *args: Any, **kwargs: Any) -> ModelResponse:
20-
return await self.context.run_typed('Model call', self.wrapped.request, self.options, *args, **kwargs)
64+
return await self._context.run_typed('Model call', self.wrapped.request, self._options, *args, **kwargs)
65+
66+
@asynccontextmanager
67+
async def request_stream(
68+
self,
69+
messages: list[ModelMessage],
70+
model_settings: ModelSettings | None,
71+
model_request_parameters: ModelRequestParameters,
72+
run_context: RunContext[AgentDepsT] | None = None,
73+
) -> AsyncIterator[StreamedResponse]:
74+
if run_context is None:
75+
raise UserError(
76+
'A model cannot be used with `pydantic_ai.direct.model_request_stream()` as it requires a `run_context`. Set an `event_stream_handler` on the agent and use `agent.run()` instead.'
77+
)
78+
79+
# We can never get here without an `event_stream_handler`, as `TemporalAgent.run_stream` and `TemporalAgent.iter` raise an error saying to use `TemporalAgent.run` instead,
80+
# and that only calls `request_stream` if `event_stream_handler` is set.
81+
fn = self._event_stream_handler
82+
assert fn is not None
83+
84+
async def request_stream_run():
85+
async with self.wrapped.request_stream(
86+
messages,
87+
model_settings,
88+
model_request_parameters,
89+
run_context,
90+
) as streamed_response:
91+
await fn(run_context, streamed_response)
92+
93+
async for _ in streamed_response:
94+
pass
95+
return streamed_response.get()
96+
97+
response = await self._context.run_typed('Model stream call', request_stream_run, self._options)
98+
99+
yield RestateStreamedResponse(model_request_parameters, response)

0 commit comments

Comments
 (0)