65 changes: 40 additions & 25 deletions pydantic_ai_slim/pydantic_ai/_agent_graph.py
@@ -12,6 +12,7 @@
from dataclasses import field, replace
from typing import TYPE_CHECKING, Any, Generic, Literal, TypeGuard, cast

import anyio
from opentelemetry.trace import Tracer
from typing_extensions import TypeVar, assert_never

@@ -663,8 +664,7 @@ async def _handle_tool_calls(
yield event

if output_final_result:
final_result = output_final_result[0]
self._next_node = self._handle_final_result(ctx, final_result, output_parts)
self._next_node = self._handle_final_result(ctx, output_final_result[0], output_parts)
else:
instructions = await ctx.deps.get_instructions(run_context)
self._next_node = ModelRequestNode[DepsT, NodeRunEndT](
@@ -882,7 +882,7 @@ async def process_tool_calls( # noqa: C901
output_final_result.append(final_result)


async def _call_tools(
async def _call_tools( # noqa: C901
tool_manager: ToolManager[DepsT],
tool_calls: list[_messages.ToolCallPart],
tool_call_results: dict[str, DeferredToolResult],
@@ -940,30 +940,45 @@ async def handle_call_or_result(

return _messages.FunctionToolResultEvent(tool_part, content=tool_user_content)

if tool_manager.should_call_sequentially(tool_calls):
for index, call in enumerate(tool_calls):
if event := await handle_call_or_result(
_call_tool(tool_manager, call, tool_call_results.get(call.tool_call_id)),
index,
):
yield event
send_stream, receive_stream = anyio.create_memory_object_stream[_messages.HandleResponseEvent]()

async def _run_tools():
async with send_stream:
assert tool_manager.ctx is not None, 'ToolManager.ctx needs to be set'
tool_manager.ctx.event_stream = send_stream

if tool_manager.should_call_sequentially(tool_calls):
for index, call in enumerate(tool_calls):
if event := await handle_call_or_result(
_call_tool(tool_manager, call, tool_call_results.get(call.tool_call_id)),
index,
):
await send_stream.send(event)

else:
tasks = [
asyncio.create_task(
_call_tool(tool_manager, call, tool_call_results.get(call.tool_call_id)),
name=call.tool_name,
)
for call in tool_calls
]

else:
tasks = [
asyncio.create_task(
_call_tool(tool_manager, call, tool_call_results.get(call.tool_call_id)),
name=call.tool_name,
)
for call in tool_calls
]
pending = tasks
while pending:
done, pending = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
for task in done:
index = tasks.index(task)
if event := await handle_call_or_result(coro_or_task=task, index=index):
await send_stream.send(event)

task = asyncio.create_task(_run_tools())

async with receive_stream:
async for message in receive_stream:
yield message

pending = tasks
while pending:
done, pending = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
for task in done:
index = tasks.index(task)
if event := await handle_call_or_result(coro_or_task=task, index=index):
yield event
await task

# We append the results at the end, rather than as they are received, to retain a consistent ordering
# This is mostly just to simplify testing
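For context, the refactored `_call_tools` above routes tool events through an anyio memory object stream: a background task runs the tools (sequentially or as `asyncio` tasks) and pushes events into the send side, while the generator simply drains the receive side and re-raises producer errors via `await task`. Below is a minimal, self-contained sketch of that producer/consumer pattern; the names (`produce`, `main`) and the string payloads are illustrative only and not taken from this PR, but it uses only the anyio and asyncio APIs already imported above.

```python
import asyncio

import anyio
from anyio.streams.memory import MemoryObjectSendStream


async def produce(send_stream: MemoryObjectSendStream[str]) -> None:
    # Closing the send stream ends iteration on the receive side.
    async with send_stream:
        for i in range(3):
            await send_stream.send(f'event {i}')


async def main() -> None:
    send_stream, receive_stream = anyio.create_memory_object_stream[str]()
    task = asyncio.create_task(produce(send_stream))
    async with receive_stream:
        async for event in receive_stream:
            print(event)
    await task  # re-raise any exception from the producer


asyncio.run(main())
```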
31 changes: 29 additions & 2 deletions pydantic_ai_slim/pydantic_ai/_function_schema.py
@@ -19,9 +19,17 @@
from pydantic_core import SchemaValidator, core_schema
from typing_extensions import ParamSpec, TypeIs, TypeVar

from pydantic_ai.messages import CustomEvent, ToolReturn

from ._griffe import doc_descriptions
from ._run_context import RunContext
from ._utils import check_object_json_schema, is_async_callable, is_model_like, run_in_executor
from ._utils import (
check_object_json_schema,
is_async_callable,
is_async_iterator_callable,
is_model_like,
run_in_executor,
)

if TYPE_CHECKING:
from .tools import DocstringFormat, ObjectJsonSchema
@@ -41,13 +49,31 @@ class FunctionSchema:
# if not None, the function takes a single argument by that name (besides potentially `info`)
takes_ctx: bool
is_async: bool
is_async_iterator: bool
single_arg_name: str | None = None
positional_fields: list[str] = field(default_factory=list)
var_positional_field: str | None = None

async def call(self, args_dict: dict[str, Any], ctx: RunContext[Any]) -> Any:
args, kwargs = self._call_args(args_dict, ctx)
if self.is_async:
if self.is_async_iterator:
assert ctx.event_stream is not None, (
'RunContext.event_stream needs to be set to use FunctionSchema.call with async iterators'
)

async for event_payload in self.function(*args, **kwargs):
if isinstance(event_payload, ToolReturn):
return event_payload

event = (
cast(CustomEvent, event_payload)
if isinstance(event_payload, CustomEvent)
else CustomEvent(payload=event_payload)
)
await ctx.event_stream.send(event)
# TODO (DouweM): Raise if events are yielded after ToolReturn
return None
elif self.is_async:
function = cast(Callable[[Any], Awaitable[str]], self.function)
return await function(*args, **kwargs)
else:
@@ -221,6 +247,7 @@ def function_schema( # noqa: C901
var_positional_field=var_positional_field,
takes_ctx=takes_ctx,
is_async=is_async_callable(function),
is_async_iterator=is_async_iterator_callable(function),
function=function,
)

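The new `is_async_iterator` branch in `FunctionSchema.call` means a tool can be written as an async generator: every yielded value that isn't already a `CustomEvent` is wrapped in one and sent to `RunContext.event_stream`, and a yielded `ToolReturn` ends the call and becomes the tool result. A hedged sketch of that tool shape follows; the tool name, its arguments, and the assumption that `Agent.tool_plain` accepts async generator functions after this change are illustrative, not taken from the PR.

```python
from collections.abc import AsyncIterator

from pydantic_ai import Agent
from pydantic_ai.messages import ToolReturn

agent = Agent('test')  # the 'test' model is used here purely for illustration


@agent.tool_plain
async def process_items(items: list[str]) -> AsyncIterator[dict[str, int] | ToolReturn]:
    # Each plain payload would be wrapped in a CustomEvent and streamed;
    # the final ToolReturn becomes the tool's result.
    for index, item in enumerate(items, start=1):
        yield {'progress': index, 'total': len(items)}
    yield ToolReturn(f'processed {len(items)} items')
```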
3 changes: 3 additions & 0 deletions pydantic_ai_slim/pydantic_ai/_run_context.py
@@ -5,6 +5,7 @@
from dataclasses import field
from typing import TYPE_CHECKING, Generic

from anyio.streams.memory import MemoryObjectSendStream
from opentelemetry.trace import NoOpTracer, Tracer
from typing_extensions import TypeVar

@@ -36,6 +37,8 @@ class RunContext(Generic[AgentDepsT]):
"""Messages exchanged in the conversation so far."""
tracer: Tracer = field(default_factory=NoOpTracer)
"""The tracer to use for tracing the run."""
event_stream: MemoryObjectSendStream[_messages.CustomEvent] | None = None
"""The event stream to use for handling custom events."""
trace_include_content: bool = False
"""Whether to include the content of the messages in the trace."""
instrumentation_version: int = DEFAULT_INSTRUMENTATION_VERSION
4 changes: 4 additions & 0 deletions pydantic_ai_slim/pydantic_ai/_tool_manager.py
@@ -50,6 +50,10 @@ async def for_run_step(self, ctx: RunContext[AgentDepsT]) -> ToolManager[AgentDe
"""Build a new tool manager for the next run step, carrying over the retries from the current run step."""
if self.ctx is not None:
if ctx.run_step == self.ctx.run_step:
# TODO (DouweM): Refactor to make sure it's always set

if ctx.event_stream and not self.ctx.event_stream:
self.ctx.event_stream = ctx.event_stream
return self

retries = {
7 changes: 6 additions & 1 deletion pydantic_ai_slim/pydantic_ai/_utils.py
@@ -366,7 +366,12 @@ def is_async_callable(obj: Any) -> Any:
while isinstance(obj, functools.partial):
obj = obj.func

return inspect.iscoroutinefunction(obj) or (callable(obj) and inspect.iscoroutinefunction(obj.__call__)) # type: ignore
return inspect.iscoroutinefunction(obj) or (callable(obj) and inspect.iscoroutinefunction(obj.__call__)) # pyright: ignore[reportFunctionMemberAccess]


def is_async_iterator_callable(obj: Any) -> bool:
"""Check if a callable is an async iterator."""
return inspect.isasyncgenfunction(obj) or (callable(obj) and inspect.isasyncgenfunction(obj.__call__)) # pyright: ignore[reportFunctionMemberAccess]


def _update_mapped_json_schema_refs(s: dict[str, Any], name_mapping: dict[str, str]) -> None:
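For reference, `is_async_iterator_callable` mirrors `is_async_callable` but keys off `inspect.isasyncgenfunction`. The standalone snippet below (not part of the PR) shows what that predicate distinguishes: a plain coroutine function, an async generator function, and a callable object whose `__call__` is an async generator.

```python
import inspect


async def coro() -> int:
    return 1


async def agen():
    yield 1


class StreamingTool:
    async def __call__(self):
        yield 'event'


print(inspect.isasyncgenfunction(coro))                      # False: coroutine function
print(inspect.isasyncgenfunction(agen))                      # True: async generator function
print(inspect.isasyncgenfunction(StreamingTool().__call__))  # True: bound async-generator method
```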
5 changes: 5 additions & 0 deletions pydantic_ai_slim/pydantic_ai/ag_ui.py
@@ -32,6 +32,7 @@
BaseToolCallPart,
BuiltinToolCallPart,
BuiltinToolReturnPart,
CustomEvent,
FunctionToolResultEvent,
ModelMessage,
ModelRequest,
@@ -431,6 +432,8 @@ async def _agent_stream(run: AgentRun[AgentDepsT, Any]) -> AsyncIterator[BaseEve
if isinstance(event, FunctionToolResultEvent):
async for msg in _handle_tool_result_event(stream_ctx, event):
yield msg
elif isinstance(event, CustomEvent) and isinstance(event.payload, BaseEvent):
yield event.payload


async def _handle_model_request_event( # noqa: C901
@@ -582,6 +585,8 @@ async def _handle_tool_result_event(
content=result.model_response_str(),
)

# TODO (DouweM): Stream `event.content` as if they were user parts?

# Now check for AG-UI events returned by the tool calls.
possible_event = result.metadata or result.content
if isinstance(possible_event, BaseEvent):
@@ -71,6 +71,7 @@ def __init__(

async def call_tool_activity(params: _CallToolParams, deps: AgentDepsT) -> _CallToolResult:
name = params.name
# TODO (DouweM): RunContext.event_stream -> call event_stream_handler directly?
ctx = self.run_context_type.deserialize_run_context(params.serialized_run_context, deps=deps)
try:
tool = (await toolset.get_tools(ctx))[name]
28 changes: 26 additions & 2 deletions pydantic_ai_slim/pydantic_ai/messages.py
@@ -7,13 +7,13 @@
from dataclasses import KW_ONLY, dataclass, field, replace
from datetime import datetime
from mimetypes import guess_type
from typing import TYPE_CHECKING, Annotated, Any, Literal, TypeAlias, cast, overload
from typing import TYPE_CHECKING, Annotated, Any, Generic, Literal, TypeAlias, cast, overload

import pydantic
import pydantic_core
from genai_prices import calc_price, types as genai_types
from opentelemetry._events import Event # pyright: ignore[reportPrivateImportUsage]
from typing_extensions import Self, deprecated
from typing_extensions import Self, TypeVar, deprecated

from . import _otel_messages, _utils
from ._utils import generate_tool_call_id as _generate_tool_call_id, now_utc as _now_utc
@@ -23,6 +23,8 @@
if TYPE_CHECKING:
from .models.instrumented import InstrumentationSettings

EventPayloadT = TypeVar('EventPayloadT', default=Any)


AudioMediaType: TypeAlias = Literal['audio/wav', 'audio/mpeg', 'audio/ogg', 'audio/flac', 'audio/aiff', 'audio/aac']
ImageMediaType: TypeAlias = Literal['image/jpeg', 'image/png', 'image/gif', 'image/webp']
@@ -1724,9 +1726,31 @@ class BuiltinToolResultEvent:
"""Event type identifier, used as a discriminator."""


@dataclass(repr=False)
class CustomEvent(Generic[EventPayloadT]):
"""An event indicating the result of a function tool call."""

payload: EventPayloadT
"""The payload of the custom event."""

_: KW_ONLY

name: str | None = None
"""The optional name of the custom event."""

id: str | None = None
"""The optional ID of the custom event."""

event_kind: Literal['custom'] = 'custom'
"""Event type identifier, used as a discriminator."""

__repr__ = _utils.dataclasses_no_defaults_repr


HandleResponseEvent = Annotated[
FunctionToolCallEvent
| FunctionToolResultEvent
| CustomEvent
| BuiltinToolCallEvent # pyright: ignore[reportDeprecated]
| BuiltinToolResultEvent, # pyright: ignore[reportDeprecated]
pydantic.Discriminator('event_kind'),
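Minimal usage sketch of the new `CustomEvent` dataclass defined above (the payload and name below are made up): `payload` is positional, `name` and `id` are keyword-only, and `event_kind == 'custom'` is the discriminator that slots it into the `HandleResponseEvent` union.

```python
from pydantic_ai.messages import CustomEvent

event = CustomEvent({'step': 1, 'status': 'indexing'}, name='progress')
print(event.event_kind)  # 'custom' -- the discriminator used by HandleResponseEvent
print(event.payload)     # {'step': 1, 'status': 'indexing'}
```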
1 change: 1 addition & 0 deletions pydantic_ai_slim/pydantic_ai/tools.py
@@ -391,6 +391,7 @@ def from_schema(
json_schema=json_schema,
takes_ctx=takes_ctx,
is_async=_utils.is_async_callable(function),
is_async_iterator=_utils.is_async_iterator_callable(function),
)

return cls(
81 changes: 81 additions & 0 deletions tests/test_ag_ui.py
@@ -192,6 +192,20 @@ async def send_custom() -> ToolReturn:
)


async def yield_custom() -> AsyncIterator[CustomEvent | ToolReturn]:
yield CustomEvent(
type=EventType.CUSTOM,
name='custom_event1',
value={'key1': 'value1'},
)
yield CustomEvent(
type=EventType.CUSTOM,
name='custom_event2',
value={'key2': 'value2'},
)
yield ToolReturn('Done')


def uuid_str() -> str:
"""Generate a random UUID string."""
return uuid.uuid4().hex
@@ -815,6 +829,73 @@ async def stream_function(
)


async def test_tool_local_yield_events() -> None:
"""Test local tool call that yields multiple events."""

async def stream_function(
messages: list[ModelMessage], agent_info: AgentInfo
) -> AsyncIterator[DeltaToolCalls | str]:
if len(messages) == 1:
# First call - make a tool call
yield {0: DeltaToolCall(name='yield_custom')}
yield {0: DeltaToolCall(json_args='{}')}
else:
# Second call - return text result
yield 'success yield_custom called'

agent = Agent(
model=FunctionModel(stream_function=stream_function),
tools=[yield_custom],
)

run_input = create_input(
UserMessage(
id='msg_1',
content='Please call yield_custom',
),
)
events = await run_and_collect_events(agent, run_input)

assert events == snapshot(
[
{
'type': 'RUN_STARTED',
'threadId': (thread_id := IsSameStr()),
'runId': (run_id := IsSameStr()),
},
{
'type': 'TOOL_CALL_START',
'toolCallId': (tool_call_id := IsSameStr()),
'toolCallName': 'yield_custom',
'parentMessageId': IsStr(),
},
{'type': 'TOOL_CALL_ARGS', 'toolCallId': tool_call_id, 'delta': '{}'},
{'type': 'TOOL_CALL_END', 'toolCallId': tool_call_id},
{'type': 'CUSTOM', 'name': 'custom_event1', 'value': {'key1': 'value1'}},
{'type': 'CUSTOM', 'name': 'custom_event2', 'value': {'key2': 'value2'}},
{
'type': 'TOOL_CALL_RESULT',
'messageId': IsStr(),
'toolCallId': tool_call_id,
'content': 'Done',
'role': 'tool',
},
{'type': 'TEXT_MESSAGE_START', 'messageId': (message_id := IsSameStr()), 'role': 'assistant'},
{
'type': 'TEXT_MESSAGE_CONTENT',
'messageId': message_id,
'delta': 'success yield_custom called',
},
{'type': 'TEXT_MESSAGE_END', 'messageId': message_id},
{
'type': 'RUN_FINISHED',
'threadId': thread_id,
'runId': run_id,
},
]
)


async def test_tool_local_parts() -> None:
"""Test local tool call with streaming/parts."""
