Skip to content

Commit a58dd47

Browse files
ajac-zero and DouweM authored
Add Agent.run_stream_sync method and sync convenience methods on StreamedRunResult (#3146)
Co-authored-by: Douwe Maan <[email protected]>
1 parent 063278e commit a58dd47

File tree

5 files changed

+526
-5
lines changed

5 files changed

+526
-5
lines changed

docs/agents.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ There are five ways to run an agent:
6565

6666
1. [`agent.run()`][pydantic_ai.agent.AbstractAgent.run] — an async function which returns a [`RunResult`][pydantic_ai.agent.AgentRunResult] containing a completed response.
6767
2. [`agent.run_sync()`][pydantic_ai.agent.AbstractAgent.run_sync] — a plain, synchronous function which returns a [`RunResult`][pydantic_ai.agent.AgentRunResult] containing a completed response (internally, this just calls `loop.run_until_complete(self.run())`).
68-
3. [`agent.run_stream()`][pydantic_ai.agent.AbstractAgent.run_stream] — an async context manager which returns a [`StreamedRunResult`][pydantic_ai.result.StreamedRunResult], which contains methods to stream text and structured output as an async iterable.
68+
3. [`agent.run_stream()`][pydantic_ai.agent.AbstractAgent.run_stream] — an async context manager which returns a [`StreamedRunResult`][pydantic_ai.result.StreamedRunResult], which contains methods to stream text and structured output as an async iterable. [`agent.run_stream_sync()`][pydantic_ai.agent.AbstractAgent.run_stream_sync] is a synchronous variation that returns a [`StreamedRunResultSync`][pydantic_ai.result.StreamedRunResultSync] with synchronous versions of the same methods.
6969
4. [`agent.run_stream_events()`][pydantic_ai.agent.AbstractAgent.run_stream_events] — a function which returns an async iterable of [`AgentStreamEvent`s][pydantic_ai.messages.AgentStreamEvent] and a [`AgentRunResultEvent`][pydantic_ai.run.AgentRunResultEvent] containing the final run result.
7070
5. [`agent.iter()`][pydantic_ai.Agent.iter] — a context manager which returns an [`AgentRun`][pydantic_ai.agent.AgentRun], an async iterable over the nodes of the agent's underlying [`Graph`][pydantic_graph.graph.Graph].
7171

pydantic_ai_slim/pydantic_ai/_utils.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -234,6 +234,15 @@ def sync_anext(iterator: Iterator[T]) -> T:
234234
raise StopAsyncIteration() from e
235235

236236

237+
def sync_async_iterator(async_iter: AsyncIterator[T]) -> Iterator[T]:
238+
loop = get_event_loop()
239+
while True:
240+
try:
241+
yield loop.run_until_complete(anext(async_iter))
242+
except StopAsyncIteration:
243+
break
244+
245+
237246
def now_utc() -> datetime:
238247
return datetime.now(tz=timezone.utc)
239248

@@ -489,3 +498,12 @@ def get_union_args(tp: Any) -> tuple[Any, ...]:
489498
return tuple(_unwrap_annotated(arg) for arg in get_args(tp))
490499
else:
491500
return ()
501+
502+
503+
def get_event_loop():
504+
try:
505+
event_loop = asyncio.get_event_loop()
506+
except RuntimeError: # pragma: lax no cover
507+
event_loop = asyncio.new_event_loop()
508+
asyncio.set_event_loop(event_loop)
509+
return event_loop

pydantic_ai_slim/pydantic_ai/agent/abstract.py

Lines changed: 129 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
from typing_extensions import Self, TypeIs, TypeVar
1313

1414
from pydantic_graph import End
15-
from pydantic_graph._utils import get_event_loop
1615

1716
from .. import (
1817
_agent_graph,
@@ -335,7 +334,7 @@ def run_sync(
335334
if infer_name and self.name is None:
336335
self._infer_name(inspect.currentframe())
337336

338-
return get_event_loop().run_until_complete(
337+
return _utils.get_event_loop().run_until_complete(
339338
self.run(
340339
user_prompt,
341340
output_type=output_type,
@@ -581,6 +580,133 @@ async def on_complete() -> None:
581580
if not yielded:
582581
raise exceptions.AgentRunError('Agent run finished without producing a final result') # pragma: no cover
583582

583+
@overload
584+
def run_stream_sync(
585+
self,
586+
user_prompt: str | Sequence[_messages.UserContent] | None = None,
587+
*,
588+
output_type: None = None,
589+
message_history: Sequence[_messages.ModelMessage] | None = None,
590+
deferred_tool_results: DeferredToolResults | None = None,
591+
model: models.Model | models.KnownModelName | str | None = None,
592+
deps: AgentDepsT = None,
593+
model_settings: ModelSettings | None = None,
594+
usage_limits: _usage.UsageLimits | None = None,
595+
usage: _usage.RunUsage | None = None,
596+
infer_name: bool = True,
597+
toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
598+
builtin_tools: Sequence[AbstractBuiltinTool] | None = None,
599+
event_stream_handler: EventStreamHandler[AgentDepsT] | None = None,
600+
) -> result.StreamedRunResultSync[AgentDepsT, OutputDataT]: ...
601+
602+
@overload
603+
def run_stream_sync(
604+
self,
605+
user_prompt: str | Sequence[_messages.UserContent] | None = None,
606+
*,
607+
output_type: OutputSpec[RunOutputDataT],
608+
message_history: Sequence[_messages.ModelMessage] | None = None,
609+
deferred_tool_results: DeferredToolResults | None = None,
610+
model: models.Model | models.KnownModelName | str | None = None,
611+
deps: AgentDepsT = None,
612+
model_settings: ModelSettings | None = None,
613+
usage_limits: _usage.UsageLimits | None = None,
614+
usage: _usage.RunUsage | None = None,
615+
infer_name: bool = True,
616+
toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
617+
builtin_tools: Sequence[AbstractBuiltinTool] | None = None,
618+
event_stream_handler: EventStreamHandler[AgentDepsT] | None = None,
619+
) -> result.StreamedRunResultSync[AgentDepsT, RunOutputDataT]: ...
620+
621+
def run_stream_sync(
622+
self,
623+
user_prompt: str | Sequence[_messages.UserContent] | None = None,
624+
*,
625+
output_type: OutputSpec[RunOutputDataT] | None = None,
626+
message_history: Sequence[_messages.ModelMessage] | None = None,
627+
deferred_tool_results: DeferredToolResults | None = None,
628+
model: models.Model | models.KnownModelName | str | None = None,
629+
deps: AgentDepsT = None,
630+
model_settings: ModelSettings | None = None,
631+
usage_limits: _usage.UsageLimits | None = None,
632+
usage: _usage.RunUsage | None = None,
633+
infer_name: bool = True,
634+
toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
635+
builtin_tools: Sequence[AbstractBuiltinTool] | None = None,
636+
event_stream_handler: EventStreamHandler[AgentDepsT] | None = None,
637+
) -> result.StreamedRunResultSync[AgentDepsT, Any]:
638+
"""Run the agent with a user prompt in sync streaming mode.
639+
640+
This is a convenience method that wraps [`run_stream()`][pydantic_ai.agent.AbstractAgent.run_stream] with `loop.run_until_complete(...)`.
641+
You therefore can't use this method inside async code or if there's an active event loop.
642+
643+
This method builds an internal agent graph (using system prompts, tools and output schemas) and then
644+
runs the graph until the model produces output matching the `output_type`, for example text or structured data.
645+
At this point, a streaming run result object is yielded from which you can stream the output as it comes in,
646+
and -- once this output has completed streaming -- get the complete output, message history, and usage.
647+
648+
As this method will consider the first output matching the `output_type` to be the final output,
649+
it will stop running the agent graph and will not execute any tool calls made by the model after this "final" output.
650+
If you want to always run the agent graph to completion and stream events and output at the same time,
651+
use [`agent.run()`][pydantic_ai.agent.AbstractAgent.run] with an `event_stream_handler` or [`agent.iter()`][pydantic_ai.agent.AbstractAgent.iter] instead.
652+
653+
Example:
654+
```python
655+
from pydantic_ai import Agent
656+
657+
agent = Agent('openai:gpt-4o')
658+
659+
def main():
660+
response = agent.run_stream_sync('What is the capital of the UK?')
661+
print(response.get_output())
662+
#> The capital of the UK is London.
663+
```
664+
665+
Args:
666+
user_prompt: User input to start/continue the conversation.
667+
output_type: Custom output type to use for this run, `output_type` may only be used if the agent has no
668+
output validators since output validators would expect an argument that matches the agent's output type.
669+
message_history: History of the conversation so far.
670+
deferred_tool_results: Optional results for deferred tool calls in the message history.
671+
model: Optional model to use for this run, required if `model` was not set when creating the agent.
672+
deps: Optional dependencies to use for this run.
673+
model_settings: Optional settings to use for this model's request.
674+
usage_limits: Optional limits on model request count or token usage.
675+
usage: Optional usage to start with, useful for resuming a conversation or agents used in tools.
676+
infer_name: Whether to try to infer the agent name from the call frame if it's not set.
677+
toolsets: Optional additional toolsets for this run.
678+
builtin_tools: Optional additional builtin tools for this run.
679+
event_stream_handler: Optional handler for events from the model's streaming response and the agent's execution of tools to use for this run.
680+
It will receive all the events up until the final result is found, which you can then read or stream from inside the context manager.
681+
Note that it does _not_ receive any events after the final result is found.
682+
683+
Returns:
684+
The result of the run.
685+
"""
686+
if infer_name and self.name is None:
687+
self._infer_name(inspect.currentframe())
688+
689+
async def _consume_stream():
690+
async with self.run_stream(
691+
user_prompt,
692+
output_type=output_type,
693+
message_history=message_history,
694+
deferred_tool_results=deferred_tool_results,
695+
model=model,
696+
deps=deps,
697+
model_settings=model_settings,
698+
usage_limits=usage_limits,
699+
usage=usage,
700+
infer_name=infer_name,
701+
toolsets=toolsets,
702+
builtin_tools=builtin_tools,
703+
event_stream_handler=event_stream_handler,
704+
) as stream_result:
705+
yield stream_result
706+
707+
async_result = _utils.get_event_loop().run_until_complete(anext(_consume_stream()))
708+
return result.StreamedRunResultSync(async_result)
709+
584710
@overload
585711
def run_stream_events(
586712
self,
@@ -1217,6 +1343,6 @@ def to_cli_sync(
12171343
agent.to_cli_sync(prog_name='assistant')
12181344
```
12191345
"""
1220-
return get_event_loop().run_until_complete(
1346+
return _utils.get_event_loop().run_until_complete(
12211347
self.to_cli(deps=deps, prog_name=prog_name, message_history=message_history)
12221348
)

pydantic_ai_slim/pydantic_ai/result.py

Lines changed: 154 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from __future__ import annotations as _annotations
22

3-
from collections.abc import AsyncIterator, Awaitable, Callable, Iterable
3+
from collections.abc import AsyncIterator, Awaitable, Callable, Iterable, Iterator
44
from copy import deepcopy
55
from dataclasses import dataclass, field
66
from datetime import datetime
@@ -35,6 +35,7 @@
3535
'OutputDataT_inv',
3636
'ToolOutput',
3737
'OutputValidatorFunc',
38+
'StreamedRunResultSync',
3839
)
3940

4041

@@ -555,6 +556,158 @@ async def _marked_completed(self, message: _messages.ModelResponse | None = None
555556
await self._on_complete()
556557

557558

559+
@dataclass(init=False)
560+
class StreamedRunResultSync(Generic[AgentDepsT, OutputDataT]):
561+
"""Synchronous wrapper for [`StreamedRunResult`][pydantic_ai.result.StreamedRunResult] that only exposes sync methods."""
562+
563+
_streamed_run_result: StreamedRunResult[AgentDepsT, OutputDataT]
564+
565+
def __init__(self, streamed_run_result: StreamedRunResult[AgentDepsT, OutputDataT]) -> None:
566+
self._streamed_run_result = streamed_run_result
567+
568+
def all_messages(self, *, output_tool_return_content: str | None = None) -> list[_messages.ModelMessage]:
569+
"""Return the history of messages.
570+
571+
Args:
572+
output_tool_return_content: The return content of the tool call to set in the last message.
573+
This provides a convenient way to modify the content of the output tool call if you want to continue
574+
the conversation and want to set the response to the output tool call. If `None`, the last message will
575+
not be modified.
576+
577+
Returns:
578+
List of messages.
579+
"""
580+
return self._streamed_run_result.all_messages(output_tool_return_content=output_tool_return_content)
581+
582+
def all_messages_json(self, *, output_tool_return_content: str | None = None) -> bytes: # pragma: no cover
583+
"""Return all messages from [`all_messages`][pydantic_ai.result.StreamedRunResultSync.all_messages] as JSON bytes.
584+
585+
Args:
586+
output_tool_return_content: The return content of the tool call to set in the last message.
587+
This provides a convenient way to modify the content of the output tool call if you want to continue
588+
the conversation and want to set the response to the output tool call. If `None`, the last message will
589+
not be modified.
590+
591+
Returns:
592+
JSON bytes representing the messages.
593+
"""
594+
return self._streamed_run_result.all_messages_json(output_tool_return_content=output_tool_return_content)
595+
596+
def new_messages(self, *, output_tool_return_content: str | None = None) -> list[_messages.ModelMessage]:
597+
"""Return new messages associated with this run.
598+
599+
Messages from older runs are excluded.
600+
601+
Args:
602+
output_tool_return_content: The return content of the tool call to set in the last message.
603+
This provides a convenient way to modify the content of the output tool call if you want to continue
604+
the conversation and want to set the response to the output tool call. If `None`, the last message will
605+
not be modified.
606+
607+
Returns:
608+
List of new messages.
609+
"""
610+
return self._streamed_run_result.new_messages(output_tool_return_content=output_tool_return_content)
611+
612+
def new_messages_json(self, *, output_tool_return_content: str | None = None) -> bytes: # pragma: no cover
613+
"""Return new messages from [`new_messages`][pydantic_ai.result.StreamedRunResultSync.new_messages] as JSON bytes.
614+
615+
Args:
616+
output_tool_return_content: The return content of the tool call to set in the last message.
617+
This provides a convenient way to modify the content of the output tool call if you want to continue
618+
the conversation and want to set the response to the output tool call. If `None`, the last message will
619+
not be modified.
620+
621+
Returns:
622+
JSON bytes representing the new messages.
623+
"""
624+
return self._streamed_run_result.new_messages_json(output_tool_return_content=output_tool_return_content)
625+
626+
def stream_output(self, *, debounce_by: float | None = 0.1) -> Iterator[OutputDataT]:
627+
"""Stream the output as an iterable.
628+
629+
The pydantic validator for structured data will be called in
630+
[partial mode](https://docs.pydantic.dev/dev/concepts/experimental/#partial-validation)
631+
on each iteration.
632+
633+
Args:
634+
debounce_by: by how much (if at all) to debounce/group the output chunks by. `None` means no debouncing.
635+
Debouncing is particularly important for long structured outputs to reduce the overhead of
636+
performing validation as each token is received.
637+
638+
Returns:
639+
An iterable of the response data.
640+
"""
641+
return _utils.sync_async_iterator(self._streamed_run_result.stream_output(debounce_by=debounce_by))
642+
643+
def stream_text(self, *, delta: bool = False, debounce_by: float | None = 0.1) -> Iterator[str]:
644+
"""Stream the text result as an iterable.
645+
646+
!!! note
647+
Result validators will NOT be called on the text result if `delta=True`.
648+
649+
Args:
650+
delta: if `True`, yield each chunk of text as it is received, if `False` (default), yield the full text
651+
up to the current point.
652+
debounce_by: by how much (if at all) to debounce/group the response chunks by. `None` means no debouncing.
653+
Debouncing is particularly important for long structured responses to reduce the overhead of
654+
performing validation as each token is received.
655+
"""
656+
return _utils.sync_async_iterator(self._streamed_run_result.stream_text(delta=delta, debounce_by=debounce_by))
657+
658+
def stream_responses(self, *, debounce_by: float | None = 0.1) -> Iterator[tuple[_messages.ModelResponse, bool]]:
659+
"""Stream the response as an iterable of Structured LLM Messages.
660+
661+
Args:
662+
debounce_by: by how much (if at all) to debounce/group the response chunks by. `None` means no debouncing.
663+
Debouncing is particularly important for long structured responses to reduce the overhead of
664+
performing validation as each token is received.
665+
666+
Returns:
667+
An iterable of the structured response message and whether that is the last message.
668+
"""
669+
return _utils.sync_async_iterator(self._streamed_run_result.stream_responses(debounce_by=debounce_by))
670+
671+
def get_output(self) -> OutputDataT:
672+
"""Stream the whole response, validate and return it."""
673+
return _utils.get_event_loop().run_until_complete(self._streamed_run_result.get_output())
674+
675+
@property
676+
def response(self) -> _messages.ModelResponse:
677+
"""Return the current state of the response."""
678+
return self._streamed_run_result.response
679+
680+
def usage(self) -> RunUsage:
681+
"""Return the usage of the whole run.
682+
683+
!!! note
684+
This won't return the full usage until the stream is finished.
685+
"""
686+
return self._streamed_run_result.usage()
687+
688+
def timestamp(self) -> datetime:
689+
"""Get the timestamp of the response."""
690+
return self._streamed_run_result.timestamp()
691+
692+
def validate_response_output(self, message: _messages.ModelResponse, *, allow_partial: bool = False) -> OutputDataT:
693+
"""Validate a structured result message."""
694+
return _utils.get_event_loop().run_until_complete(
695+
self._streamed_run_result.validate_response_output(message, allow_partial=allow_partial)
696+
)
697+
698+
@property
699+
def is_complete(self) -> bool:
700+
"""Whether the stream has all been received.
701+
702+
This is set to `True` when one of
703+
[`stream_output`][pydantic_ai.result.StreamedRunResultSync.stream_output],
704+
[`stream_text`][pydantic_ai.result.StreamedRunResultSync.stream_text],
705+
[`stream_responses`][pydantic_ai.result.StreamedRunResultSync.stream_responses] or
706+
[`get_output`][pydantic_ai.result.StreamedRunResultSync.get_output] completes.
707+
"""
708+
return self._streamed_run_result.is_complete
709+
710+
558711
@dataclass(repr=False)
559712
class FinalResult(Generic[OutputDataT]):
560713
"""Marker class storing the final output of an agent run and associated metadata."""

0 commit comments

Comments (0)