|
1 | 1 | from __future__ import annotations |
2 | 2 |
|
3 | | -from collections.abc import Iterator, Sequence |
4 | | -from contextlib import contextmanager |
5 | | -from typing import Any, overload |
| 3 | +from collections.abc import AsyncIterable, AsyncIterator, Iterator, Sequence |
| 4 | +from contextlib import AbstractAsyncContextManager, asynccontextmanager, contextmanager |
| 5 | +from typing import Any, Never, overload |
6 | 6 |
|
7 | 7 | from restate import Context, TerminalError |
8 | 8 |
|
|
11 | 11 | from pydantic_ai.agent.abstract import AbstractAgent, EventStreamHandler, RunOutputDataT |
12 | 12 | from pydantic_ai.agent.wrapper import WrapperAgent |
13 | 13 | from pydantic_ai.builtin_tools import AbstractBuiltinTool |
14 | | -from pydantic_ai.messages import ModelMessage, UserContent |
| 14 | +from pydantic_ai.exceptions import UserError |
| 15 | +from pydantic_ai.messages import AgentStreamEvent, ModelMessage, UserContent |
15 | 16 | from pydantic_ai.models import Model |
16 | 17 | from pydantic_ai.output import OutputDataT, OutputSpec |
| 18 | +from pydantic_ai.result import StreamedRunResult |
17 | 19 | from pydantic_ai.run import AgentRunResult |
18 | 20 | from pydantic_ai.settings import ModelSettings |
19 | | -from pydantic_ai.tools import DeferredToolResults |
| 21 | +from pydantic_ai.tools import DeferredToolResults, RunContext |
20 | 22 | from pydantic_ai.toolsets.abstract import AbstractToolset |
21 | 23 | from pydantic_ai.toolsets.function import FunctionToolset |
22 | 24 | from pydantic_ai.usage import RunUsage, UsageLimits |
@@ -87,14 +89,21 @@ def __init__( |
87 | 89 | wrapped: AbstractAgent[AgentDepsT, OutputDataT], |
88 | 90 | restate_context: Context, |
89 | 91 | *, |
| 92 | + event_stream_handler: EventStreamHandler[AgentDepsT] | None = None, |
90 | 93 | disable_auto_wrapping_tools: bool = False, |
91 | 94 | ): |
92 | 95 | super().__init__(wrapped) |
93 | 96 | if not isinstance(wrapped.model, Model): |
94 | 97 | raise TerminalError( |
95 | 98 | 'An agent needs to have a `model` in order to be used with Restate, it cannot be set at agent run time.' |
96 | 99 | ) |
97 | | - self._model = RestateModelWrapper(wrapped.model, restate_context, max_attempts=3) |
| 100 | + |
| 101 | + self.restate_context = restate_context |
| 102 | + self._event_stream_handler = event_stream_handler |
| 103 | + self._disable_auto_wrapping_tools = disable_auto_wrapping_tools |
| 104 | + self._model = RestateModelWrapper( |
| 105 | + wrapped.model, restate_context, event_stream_handler=event_stream_handler, max_attempts=3 |
| 106 | + ) |
98 | 107 |
|
99 | 108 | def set_context(toolset: AbstractToolset[AgentDepsT]) -> AbstractToolset[AgentDepsT]: |
100 | 109 | """Set the Restate context for the toolset, wrapping tools if needed.""" |
@@ -122,6 +131,37 @@ def _restate_overrides(self) -> Iterator[None]: |
122 | 131 | ): |
123 | 132 | yield |
124 | 133 |
|
@property
def model(self) -> models.Model | models.KnownModelName | str | None:
    """Return the model held by this agent in `_model`.

    The constructor stores a `RestateModelWrapper` built around the wrapped
    agent's model there, so that wrapper is what this property exposes.
    """
    restate_model = self._model
    return restate_model
| 137 | + |
@property
def event_stream_handler(self) -> EventStreamHandler[AgentDepsT] | None:
    """Return the event stream handler to use for runs of this agent.

    The handler passed to this agent's constructor takes precedence; when it
    is unset (or falsy) the parent class's `event_stream_handler` is used.

    Returns:
        `None` when no handler is configured at either level. The configured
        handler itself when `disable_auto_wrapping_tools` was set on this
        agent. Otherwise the bound `wrapped_event_stream_handler` method.
    """
    configured = self._event_stream_handler or super().event_stream_handler
    if configured is None:
        return None
    # With auto-wrapping disabled the handler is handed out untouched.
    return configured if self._disable_auto_wrapping_tools else self.wrapped_event_stream_handler
| 146 | + |
async def wrapped_event_stream_handler(
    self, ctx: RunContext[AgentDepsT], stream: AsyncIterable[AgentStreamEvent]
) -> None:
    """Forward every event of `stream` to the configured handler, one Restate run step per event.

    The handler is resolved the same way the `event_stream_handler` property
    resolves it: the handler given to this agent first, then the one exposed
    by the parent class. Without that fallback a handler defined only on the
    wrapped agent would be selected by the property but never receive events.

    Args:
        ctx: The run context passed through to the handler unchanged.
        stream: The events to deliver. Each event is re-wrapped in its own
            single-item async iterator before being handed to the handler.

    Returns:
        None. Returns immediately, without consuming `stream`, when no
        handler is configured.
    """
    fn = self._event_stream_handler or super().event_stream_handler
    if fn is None:
        return
    async for event in stream:

        # `event` is bound as a default argument so the step never depends on
        # the loop variable's value at the time the step is executed.
        async def deliver(event: AgentStreamEvent = event) -> None:
            async def single_event() -> AsyncIterator[AgentStreamEvent]:
                yield event

            await fn(ctx, single_event())

        # The action is a real coroutine function that awaits the handler
        # itself, so delivery does not rely on the caller awaiting the return
        # value of a synchronous callable.
        await self.restate_context.run_typed('run event', deliver)
| 159 | + |
@property
def toolsets(self) -> Sequence[AbstractToolset[AgentDepsT]]:
    """Return the parent class's toolsets, read while the Restate overrides are active."""
    with self._restate_overrides():
        resolved = super().toolsets
    return resolved
| 164 | + |
125 | 165 | @overload |
126 | 166 | async def run( |
127 | 167 | self, |
@@ -232,3 +272,100 @@ async def main(): |
232 | 272 | toolsets=toolsets, |
233 | 273 | event_stream_handler=event_stream_handler, |
234 | 274 | ) |
| 275 | + |
@overload
def run_stream(
    self,
    user_prompt: str | Sequence[UserContent] | None = None,
    *,
    output_type: None = None,
    message_history: Sequence[ModelMessage] | None = None,
    deferred_tool_results: DeferredToolResults | None = None,
    model: models.Model | models.KnownModelName | str | None = None,
    deps: AgentDepsT = None,
    model_settings: ModelSettings | None = None,
    usage_limits: UsageLimits | None = None,
    usage: RunUsage | None = None,
    infer_name: bool = True,
    toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
    builtin_tools: Sequence[AbstractBuiltinTool] | None = None,
    event_stream_handler: EventStreamHandler[AgentDepsT] | None = None,
) -> AbstractAsyncContextManager[StreamedRunResult[AgentDepsT, OutputDataT]]: ...


@overload
def run_stream(
    self,
    user_prompt: str | Sequence[UserContent] | None = None,
    *,
    output_type: OutputSpec[RunOutputDataT],
    message_history: Sequence[ModelMessage] | None = None,
    deferred_tool_results: DeferredToolResults | None = None,
    model: models.Model | models.KnownModelName | str | None = None,
    deps: AgentDepsT = None,
    model_settings: ModelSettings | None = None,
    usage_limits: UsageLimits | None = None,
    usage: RunUsage | None = None,
    infer_name: bool = True,
    toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
    builtin_tools: Sequence[AbstractBuiltinTool] | None = None,
    event_stream_handler: EventStreamHandler[AgentDepsT] | None = None,
) -> AbstractAsyncContextManager[StreamedRunResult[AgentDepsT, RunOutputDataT]]: ...


@asynccontextmanager
async def run_stream(
    self,
    user_prompt: str | Sequence[UserContent] | None = None,
    *,
    output_type: OutputSpec[RunOutputDataT] | None = None,
    message_history: Sequence[ModelMessage] | None = None,
    deferred_tool_results: DeferredToolResults | None = None,
    model: models.Model | models.KnownModelName | str | None = None,
    deps: AgentDepsT = None,
    model_settings: ModelSettings | None = None,
    usage_limits: UsageLimits | None = None,
    usage: RunUsage | None = None,
    infer_name: bool = True,
    toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
    builtin_tools: Sequence[AbstractBuiltinTool] | None = None,
    event_stream_handler: EventStreamHandler[AgentDepsT] | None = None,
    **_deprecated_kwargs: Never,
) -> AsyncIterator[StreamedRunResult[AgentDepsT, Any]]:
    """Reject streamed runs: this implementation never produces a stream.

    The signature is kept so the method can be called like any other
    `run_stream()`, but every argument is ignored. Entering the returned
    context manager raises `UserError` unconditionally. As the error message
    says, set an `event_stream_handler` on the agent and call `agent.run()`
    instead.

    Args:
        user_prompt: Ignored.
        output_type: Ignored.
        message_history: Ignored.
        deferred_tool_results: Ignored.
        model: Ignored.
        deps: Ignored.
        model_settings: Ignored.
        usage_limits: Ignored.
        usage: Ignored.
        infer_name: Ignored.
        toolsets: Ignored.
        builtin_tools: Ignored.
        event_stream_handler: Ignored.

    Raises:
        UserError: Always, as soon as the context manager is entered.
    """
    message = (
        '`agent.run_stream()` cannot be used inside a restate handler. '
        'Set an `event_stream_handler` on the agent and use `agent.run()` instead.'
    )
    raise UserError(message)

    # Never reached: the bare `yield` only makes this function an async
    # generator, which `@asynccontextmanager` requires.
    yield
0 commit comments