diff --git a/pydantic_ai_slim/pydantic_ai/__init__.py b/pydantic_ai_slim/pydantic_ai/__init__.py index 1054cef630..71aef65740 100644 --- a/pydantic_ai_slim/pydantic_ai/__init__.py +++ b/pydantic_ai_slim/pydantic_ai/__init__.py @@ -42,6 +42,7 @@ BinaryImage, BuiltinToolCallPart, BuiltinToolReturnPart, + CustomEvent, DocumentFormat, DocumentMediaType, DocumentUrl, @@ -68,6 +69,7 @@ PartEndEvent, PartStartEvent, RetryPromptPart, + Return, SystemPromptPart, TextPart, TextPartDelta, @@ -141,6 +143,7 @@ 'BinaryContent', 'BuiltinToolCallPart', 'BuiltinToolReturnPart', + 'CustomEvent', 'DocumentFormat', 'DocumentMediaType', 'DocumentUrl', @@ -168,6 +171,7 @@ 'PartEndEvent', 'PartStartEvent', 'RetryPromptPart', + 'Return', 'SystemPromptPart', 'TextPart', 'TextPartDelta', diff --git a/pydantic_ai_slim/pydantic_ai/_agent_graph.py b/pydantic_ai_slim/pydantic_ai/_agent_graph.py index c167521079..bb1e97005f 100644 --- a/pydantic_ai_slim/pydantic_ai/_agent_graph.py +++ b/pydantic_ai_slim/pydantic_ai/_agent_graph.py @@ -12,6 +12,8 @@ from dataclasses import field, replace from typing import TYPE_CHECKING, Any, Generic, Literal, TypeGuard, cast +import anyio +from anyio.streams.memory import MemoryObjectSendStream from opentelemetry.trace import Tracer from typing_extensions import TypeVar, assert_never @@ -439,7 +441,7 @@ async def stream( _output_schema=ctx.deps.output_schema, _model_request_parameters=model_request_parameters, _output_validators=ctx.deps.output_validators, - _run_ctx=build_run_context(ctx), + _run_ctx=run_context, _usage_limits=ctx.deps.usage_limits, _tool_manager=ctx.deps.tool_manager, ) @@ -561,115 +563,142 @@ async def stream( async def _run_stream( # noqa: C901 self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]] ) -> AsyncIterator[_messages.HandleResponseEvent]: + # Ensure that the stream is only run once if self._events_iterator is None: - # Ensure that the stream is only run once + run_context = build_run_context(ctx) - output_schema = ctx.deps.output_schema + # This will raise errors for any tool name conflicts + ctx.deps.tool_manager = await ctx.deps.tool_manager.for_run_step(run_context) + tool_manager = ctx.deps.tool_manager async def _run_stream() -> AsyncIterator[_messages.HandleResponseEvent]: # noqa: C901 - if not self.model_response.parts: - # we got an empty response. - # this sometimes happens with anthropic (and perhaps other models) - # when the model has already returned text along side tool calls - if text_processor := output_schema.text_processor: # pragma: no branch - # in this scenario, if text responses are allowed, we return text from the most recent model - # response, if any - for message in reversed(ctx.state.message_history): - if isinstance(message, _messages.ModelResponse): - text = '' - for part in message.parts: - if isinstance(part, _messages.TextPart): - text += part.content - elif isinstance(part, _messages.BuiltinToolCallPart): - # Text parts before a built-in tool call are essentially thoughts, - # not part of the final result output, so we reset the accumulated text - text = '' # pragma: no cover - if text: - try: - self._next_node = await self._handle_text_response(ctx, text, text_processor) - return - except ToolRetryError: - # If the text from the preview response was invalid, ignore it. - pass - - # Go back to the model request node with an empty request, which means we'll essentially - # resubmit the most recent request that resulted in an empty response, - # as the empty response and request will not create any items in the API payload, - # in the hope the model will return a non-empty response this time. - ctx.state.increment_retries(ctx.deps.max_result_retries, model_settings=ctx.deps.model_settings) - run_context = build_run_context(ctx) - instructions = await ctx.deps.get_instructions(run_context) - self._next_node = ModelRequestNode[DepsT, NodeRunEndT]( - _messages.ModelRequest(parts=[], instructions=instructions) - ) - return - - text = '' - tool_calls: list[_messages.ToolCallPart] = [] - files: list[_messages.BinaryContent] = [] - - for part in self.model_response.parts: - if isinstance(part, _messages.TextPart): - text += part.content - elif isinstance(part, _messages.ToolCallPart): - tool_calls.append(part) - elif isinstance(part, _messages.FilePart): - files.append(part.content) - elif isinstance(part, _messages.BuiltinToolCallPart): - # Text parts before a built-in tool call are essentially thoughts, - # not part of the final result output, so we reset the accumulated text - text = '' - yield _messages.BuiltinToolCallEvent(part) # pyright: ignore[reportDeprecated] - elif isinstance(part, _messages.BuiltinToolReturnPart): - yield _messages.BuiltinToolResultEvent(part) # pyright: ignore[reportDeprecated] - elif isinstance(part, _messages.ThinkingPart): - pass - else: - assert_never(part) - - try: - # At the moment, we prioritize at least executing tool calls if they are present. - # In the future, we'd consider making this configurable at the agent or run level. - # This accounts for cases like anthropic returns that might contain a text response - # and a tool call response, where the text response just indicates the tool call will happen. - alternatives: list[str] = [] - if tool_calls: - async for event in self._handle_tool_calls(ctx, tool_calls): - yield event - return - elif output_schema.toolset: - alternatives.append('include your response in a tool call') - else: - alternatives.append('call a tool') - - if output_schema.allows_image: - if image := next((file for file in files if isinstance(file, _messages.BinaryImage)), None): - self._next_node = await self._handle_image_response(ctx, image) + send_stream, receive_stream = anyio.create_memory_object_stream[_messages.HandleResponseEvent]() + + async def _run(): # noqa: C901 + async with send_stream: + assert tool_manager.ctx is not None, 'ToolManager.ctx needs to be set' + tool_manager.ctx.event_stream = send_stream + + output_schema = ctx.deps.output_schema + if not self.model_response.parts: + # we got an empty response. + # this sometimes happens with anthropic (and perhaps other models) + # when the model has already returned text along side tool calls + if text_processor := output_schema.text_processor: # pragma: no branch + # in this scenario, if text responses are allowed, we return text from the most recent model + # response, if any + for message in reversed(ctx.state.message_history): + if isinstance(message, _messages.ModelResponse): + text = '' + for part in message.parts: + if isinstance(part, _messages.TextPart): + text += part.content + elif isinstance(part, _messages.BuiltinToolCallPart): + # Text parts before a built-in tool call are essentially thoughts, + # not part of the final result output, so we reset the accumulated text + text = '' # pragma: no cover + if text: + try: + self._next_node = await self._handle_text_response( + ctx, text, text_processor, send_stream + ) + return + except ToolRetryError: + # If the text from the preview response was invalid, ignore it. + pass + + # Go back to the model request node with an empty request, which means we'll essentially + # resubmit the most recent request that resulted in an empty response, + # as the empty response and request will not create any items in the API payload, + # in the hope the model will return a non-empty response this time. + ctx.state.increment_retries( + ctx.deps.max_result_retries, model_settings=ctx.deps.model_settings + ) + run_context = build_run_context(ctx) + instructions = await ctx.deps.get_instructions(run_context) + self._next_node = ModelRequestNode[DepsT, NodeRunEndT]( + _messages.ModelRequest(parts=[], instructions=instructions) + ) return - alternatives.append('return an image') - if text_processor := output_schema.text_processor: - if text: - self._next_node = await self._handle_text_response(ctx, text, text_processor) - return - alternatives.insert(0, 'return text') + text = '' + tool_calls: list[_messages.ToolCallPart] = [] + files: list[_messages.BinaryContent] = [] + + for part in self.model_response.parts: + if isinstance(part, _messages.TextPart): + text += part.content + elif isinstance(part, _messages.ToolCallPart): + tool_calls.append(part) + elif isinstance(part, _messages.FilePart): + files.append(part.content) + elif isinstance(part, _messages.BuiltinToolCallPart): + # Text parts before a built-in tool call are essentially thoughts, + # not part of the final result output, so we reset the accumulated text + text = '' + await send_stream.send(_messages.BuiltinToolCallEvent(part)) # pyright: ignore[reportDeprecated] + elif isinstance(part, _messages.BuiltinToolReturnPart): + await send_stream.send(_messages.BuiltinToolResultEvent(part)) # pyright: ignore[reportDeprecated] + elif isinstance(part, _messages.ThinkingPart): + pass + else: + assert_never(part) + + try: + # At the moment, we prioritize at least executing tool calls if they are present. + # In the future, we'd consider making this configurable at the agent or run level. + # This accounts for cases like anthropic returns that might contain a text response + # and a tool call response, where the text response just indicates the tool call will happen. + alternatives: list[str] = [] + if tool_calls: + async for event in self._handle_tool_calls(ctx, tool_calls): + await send_stream.send(event) + return + elif output_schema.toolset: + alternatives.append('include your response in a tool call') + else: + alternatives.append('call a tool') - # handle responses with only parts that don't constitute output. - # This can happen with models that support thinking mode when they don't provide - # actionable output alongside their thinking content. so we tell the model to try again. - m = _messages.RetryPromptPart( - content=f'Please {" or ".join(alternatives)}.', - ) - raise ToolRetryError(m) - except ToolRetryError as e: - ctx.state.increment_retries( - ctx.deps.max_result_retries, error=e, model_settings=ctx.deps.model_settings - ) - run_context = build_run_context(ctx) - instructions = await ctx.deps.get_instructions(run_context) - self._next_node = ModelRequestNode[DepsT, NodeRunEndT]( - _messages.ModelRequest(parts=[e.tool_retry], instructions=instructions) - ) + if output_schema.allows_image: + if image := next( + (file for file in files if isinstance(file, _messages.BinaryImage)), None + ): + self._next_node = await self._handle_image_response(ctx, image) + return + alternatives.append('return an image') + + if text_processor := output_schema.text_processor: + if text: + self._next_node = await self._handle_text_response( + ctx, text, text_processor, send_stream + ) + return + alternatives.insert(0, 'return text') + + # handle responses with only parts that don't constitute output. + # This can happen with models that support thinking mode when they don't provide + # actionable output alongside their thinking content. so we tell the model to try again. + m = _messages.RetryPromptPart( + content=f'Please {" or ".join(alternatives)}.', + ) + raise ToolRetryError(m) + except ToolRetryError as e: + ctx.state.increment_retries( + ctx.deps.max_result_retries, error=e, model_settings=ctx.deps.model_settings + ) + run_context = build_run_context(ctx) + instructions = await ctx.deps.get_instructions(run_context) + self._next_node = ModelRequestNode[DepsT, NodeRunEndT]( + _messages.ModelRequest(parts=[e.tool_retry], instructions=instructions) + ) + + task = asyncio.create_task(_run()) + + async with receive_stream: + async for message in receive_stream: + yield message + + await task self._events_iterator = _run_stream() @@ -683,9 +712,6 @@ async def _handle_tool_calls( ) -> AsyncIterator[_messages.HandleResponseEvent]: run_context = build_run_context(ctx) - # This will raise errors for any tool name conflicts - ctx.deps.tool_manager = await ctx.deps.tool_manager.for_run_step(run_context) - output_parts: list[_messages.ModelRequestPart] = [] output_final_result: deque[result.FinalResult[NodeRunEndT]] = deque(maxlen=1) @@ -701,8 +727,7 @@ async def _handle_tool_calls( yield event if output_final_result: - final_result = output_final_result[0] - self._next_node = self._handle_final_result(ctx, final_result, output_parts) + self._next_node = self._handle_final_result(ctx, output_final_result[0], output_parts) else: instructions = await ctx.deps.get_instructions(run_context) self._next_node = ModelRequestNode[DepsT, NodeRunEndT]( @@ -714,8 +739,10 @@ async def _handle_text_response( ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]], text: str, text_processor: _output.BaseOutputProcessor[NodeRunEndT], + event_stream: MemoryObjectSendStream[_messages.CustomEvent], ) -> ModelRequestNode[DepsT, NodeRunEndT] | End[result.FinalResult[NodeRunEndT]]: run_context = build_run_context(ctx) + run_context = replace(run_context, event_stream=event_stream) result_data = await text_processor.process(text, run_context) @@ -1061,7 +1088,8 @@ async def _call_tool( except ToolRetryError as e: return e.tool_retry, None - if isinstance(tool_result, _messages.ToolReturn): + tool_return: _messages.Return | None = None + if isinstance(tool_result, _messages.Return): tool_return = tool_result else: result_is_list = isinstance(tool_result, list) @@ -1072,8 +1100,8 @@ async def _call_tool( for content in contents: if isinstance(content, _messages.ToolReturn): raise exceptions.UserError( - f'The return value of tool {tool_call.tool_name!r} contains invalid nested `ToolReturn` objects. ' - f'`ToolReturn` should be used directly.' + f'The return value of tool {tool_call.tool_name!r} contains invalid nested `Return` objects. ' + f'`Return` should be used directly.' ) elif isinstance(content, _messages.MultiModalContent): identifier = content.identifier @@ -1105,10 +1133,13 @@ async def _call_tool( tool_name=tool_call.tool_name, tool_call_id=tool_call.tool_call_id, content=tool_return.return_value, # type: ignore - metadata=tool_return.metadata, ) - return return_part, tool_return.content or None + if isinstance(tool_return, _messages.ToolReturn): + return_part.metadata = tool_return.metadata + return return_part, tool_return.content or None + else: + return return_part, None @dataclasses.dataclass diff --git a/pydantic_ai_slim/pydantic_ai/_function_schema.py b/pydantic_ai_slim/pydantic_ai/_function_schema.py index 2b8270f322..6d91e474f6 100644 --- a/pydantic_ai_slim/pydantic_ai/_function_schema.py +++ b/pydantic_ai_slim/pydantic_ai/_function_schema.py @@ -6,7 +6,7 @@ from __future__ import annotations as _annotations from collections.abc import Awaitable, Callable -from dataclasses import dataclass, field +from dataclasses import dataclass, field, replace from inspect import Parameter, signature from typing import TYPE_CHECKING, Any, Concatenate, cast, get_origin @@ -19,9 +19,17 @@ from pydantic_core import SchemaValidator, core_schema from typing_extensions import ParamSpec, TypeIs, TypeVar +from pydantic_ai.messages import CustomEvent, Return + from ._griffe import doc_descriptions from ._run_context import RunContext -from ._utils import check_object_json_schema, is_async_callable, is_model_like, run_in_executor +from ._utils import ( + check_object_json_schema, + is_async_callable, + is_async_iterator_callable, + is_model_like, + run_in_executor, +) if TYPE_CHECKING: from .tools import DocstringFormat, ObjectJsonSchema @@ -41,13 +49,46 @@ class FunctionSchema: # if not None, the function takes a single by that name (besides potentially `info`) takes_ctx: bool is_async: bool + is_async_iterator: bool single_arg_name: str | None = None positional_fields: list[str] = field(default_factory=list) var_positional_field: str | None = None async def call(self, args_dict: dict[str, Any], ctx: RunContext[Any]) -> Any: args, kwargs = self._call_args(args_dict, ctx) - if self.is_async: + if self.is_async_iterator: + return_value: Return | None = None + async for event_data in self.function(*args, **kwargs): + if return_value is not None: + from .exceptions import UserError + + raise UserError('Return value must be the last value yielded by the function') + + if isinstance(event_data, Return): + return_value = cast(Return, event_data) + continue + + # If there's no event stream, we're being called from inside `agent.run_stream()` or `AgentStream.get_output()`, + # after event streaming has completed and final result streaming has begun, so there's nowhere to yield custom events to. + # We could consider storing the yielded events somewhere and letting them be accessed after the fact as a list. + if ctx.event_stream is not None: + if isinstance(event_data, CustomEvent): + # TODO (DouweM): Whgat if this is coming from a nested agent run? + # Should rewrap! If there's a tool call ID? + # But what if NativeOutput(output_func)? No tool call ID... + # DeferredToolCalls etc wouldn't work either. Only support yielding output function with ToolOutput()? + # Handoffs; think about custom event T transformer; HandoffEvent with agent_name? + # CustomEvent with tool_call_id, nested deeply... + # How to match tool_call_id to agent name? Through ToolCallPart.tool_name? Start handoff event with metadata? + event = cast(CustomEvent, event_data) + if ctx.tool_call_id: + event = replace(event, tool_call_id=ctx.tool_call_id) + else: + event = CustomEvent(data=event_data, tool_call_id=ctx.tool_call_id) + await ctx.event_stream.send(event) + + return return_value + elif self.is_async: function = cast(Callable[[Any], Awaitable[str]], self.function) return await function(*args, **kwargs) else: @@ -221,6 +262,7 @@ def function_schema( # noqa: C901 var_positional_field=var_positional_field, takes_ctx=takes_ctx, is_async=is_async_callable(function), + is_async_iterator=is_async_iterator_callable(function), function=function, ) diff --git a/pydantic_ai_slim/pydantic_ai/_output.py b/pydantic_ai_slim/pydantic_ai/_output.py index ebb737a1cf..eb21c3374e 100644 --- a/pydantic_ai_slim/pydantic_ai/_output.py +++ b/pydantic_ai_slim/pydantic_ai/_output.py @@ -6,7 +6,7 @@ from abc import ABC, abstractmethod from collections.abc import Awaitable, Callable, Sequence from dataclasses import dataclass, field -from typing import TYPE_CHECKING, Any, Generic, Literal, cast, overload +from typing import Any, Generic, Literal, cast, overload from pydantic import Json, TypeAdapter, ValidationError from pydantic_core import SchemaValidator, to_json @@ -34,9 +34,6 @@ from .tools import GenerateToolJsonSchema, ObjectJsonSchema, ToolDefinition from .toolsets.abstract import AbstractToolset, ToolsetTool -if TYPE_CHECKING: - pass - T = TypeVar('T') """An invariant TypeVar.""" OutputDataT_inv = TypeVar('OutputDataT_inv', default=str) @@ -133,6 +130,8 @@ async def execute_traced_output_function( ) as span: try: output = await function_schema.call(args, run_context) + if isinstance(output, _messages.Return): + output = cast(_messages.Return[Any], output).return_value except ModelRetry as r: if wrap_validation_errors: m = _messages.RetryPromptPart( @@ -229,7 +228,7 @@ def allows_text(self) -> bool: @classmethod def build( # noqa: C901 cls, - output_spec: OutputSpec[OutputDataT], + output_spec: OutputSpec[OutputDataT, Any], *, name: str | None = None, description: str | None = None, diff --git a/pydantic_ai_slim/pydantic_ai/_run_context.py b/pydantic_ai_slim/pydantic_ai/_run_context.py index 1848c42eb1..e19ec705ad 100644 --- a/pydantic_ai_slim/pydantic_ai/_run_context.py +++ b/pydantic_ai_slim/pydantic_ai/_run_context.py @@ -5,6 +5,7 @@ from dataclasses import field from typing import TYPE_CHECKING, Generic +from anyio.streams.memory import MemoryObjectSendStream from opentelemetry.trace import NoOpTracer, Tracer from typing_extensions import TypeVar @@ -40,6 +41,9 @@ class RunContext(Generic[RunContextAgentDepsT]): """Messages exchanged in the conversation so far.""" tracer: Tracer = field(default_factory=NoOpTracer) """The tracer to use for tracing the run.""" + # TODO (DouweM): Generic param? + event_stream: MemoryObjectSendStream[_messages.CustomEvent] | None = None + """The event stream to use for handling custom events.""" trace_include_content: bool = False """Whether to include the content of the messages in the trace.""" instrumentation_version: int = DEFAULT_INSTRUMENTATION_VERSION diff --git a/pydantic_ai_slim/pydantic_ai/_utils.py b/pydantic_ai_slim/pydantic_ai/_utils.py index 4f6deb2f2e..9a6672af02 100644 --- a/pydantic_ai_slim/pydantic_ai/_utils.py +++ b/pydantic_ai_slim/pydantic_ai/_utils.py @@ -375,7 +375,12 @@ def is_async_callable(obj: Any) -> Any: while isinstance(obj, functools.partial): obj = obj.func - return inspect.iscoroutinefunction(obj) or (callable(obj) and inspect.iscoroutinefunction(obj.__call__)) # type: ignore + return inspect.iscoroutinefunction(obj) or (callable(obj) and inspect.iscoroutinefunction(obj.__call__)) # pyright: ignore[reportFunctionMemberAccess] + + +def is_async_iterator_callable(obj: Any) -> bool: + """Check if a callable is an async iterator.""" + return inspect.isasyncgenfunction(obj) or (callable(obj) and inspect.isasyncgenfunction(obj.__call__)) # pyright: ignore[reportFunctionMemberAccess] def _update_mapped_json_schema_refs(s: dict[str, Any], name_mapping: dict[str, str]) -> None: diff --git a/pydantic_ai_slim/pydantic_ai/ag_ui.py b/pydantic_ai_slim/pydantic_ai/ag_ui.py index fb751877f5..5e777aff45 100644 --- a/pydantic_ai_slim/pydantic_ai/ag_ui.py +++ b/pydantic_ai_slim/pydantic_ai/ag_ui.py @@ -13,7 +13,7 @@ from . import DeferredToolResults from .agent import AbstractAgent -from .messages import ModelMessage +from .messages import CustomEventDataT, ModelMessage from .models import KnownModelName, Model from .output import OutputSpec from .settings import ModelSettings @@ -49,10 +49,10 @@ async def handle_ag_ui_request( - agent: AbstractAgent[AgentDepsT, Any], + agent: AbstractAgent[AgentDepsT, Any, CustomEventDataT], request: Request, *, - output_type: OutputSpec[Any] | None = None, + output_type: OutputSpec[Any, CustomEventDataT] | None = None, message_history: Sequence[ModelMessage] | None = None, deferred_tool_results: DeferredToolResults | None = None, model: Model | KnownModelName | str | None = None, @@ -105,11 +105,11 @@ async def handle_ag_ui_request( def run_ag_ui( - agent: AbstractAgent[AgentDepsT, Any], + agent: AbstractAgent[AgentDepsT, Any, CustomEventDataT], run_input: RunAgentInput, accept: str = SSE_CONTENT_TYPE, *, - output_type: OutputSpec[Any] | None = None, + output_type: OutputSpec[Any, CustomEventDataT] | None = None, message_history: Sequence[ModelMessage] | None = None, deferred_tool_results: DeferredToolResults | None = None, model: Model | KnownModelName | str | None = None, diff --git a/pydantic_ai_slim/pydantic_ai/agent/__init__.py b/pydantic_ai_slim/pydantic_ai/agent/__init__.py index 5bcfa6baae..a03878952f 100644 --- a/pydantic_ai_slim/pydantic_ai/agent/__init__.py +++ b/pydantic_ai_slim/pydantic_ai/agent/__init__.py @@ -3,12 +3,11 @@ import dataclasses import inspect import json -import warnings from asyncio import Lock from collections.abc import AsyncIterator, Awaitable, Callable, Iterator, Sequence from contextlib import AbstractAsyncContextManager, AsyncExitStack, asynccontextmanager, contextmanager from contextvars import ContextVar -from typing import TYPE_CHECKING, Any, ClassVar, overload +from typing import Any, ClassVar, overload from opentelemetry.trace import NoOpTracer, use_span from pydantic.json_schema import GenerateJsonSchema @@ -37,6 +36,7 @@ from .._output import OutputToolset from .._tool_manager import ToolManager from ..builtin_tools import AbstractBuiltinTool +from ..messages import CustomEventDataT from ..models.instrumented import InstrumentationSettings, InstrumentedModel, instrument_model from ..output import OutputDataT, OutputSpec from ..run import AgentRun, AgentRunResult @@ -66,9 +66,6 @@ from .abstract import AbstractAgent, EventStreamHandler, Instructions, RunOutputDataT from .wrapper import WrapperAgent -if TYPE_CHECKING: - from ..mcp import MCPServer - __all__ = ( 'Agent', 'AgentRun', @@ -91,7 +88,7 @@ @dataclasses.dataclass(init=False) -class Agent(AbstractAgent[AgentDepsT, OutputDataT]): +class Agent(AbstractAgent[AgentDepsT, OutputDataT, CustomEventDataT]): """Class for defining "agents" - a way to have a specific type of "conversation" with an LLM. Agents are generic in the dependency type they take [`AgentDepsT`][pydantic_ai.tools.AgentDepsT] @@ -124,7 +121,7 @@ class Agent(AbstractAgent[AgentDepsT, OutputDataT]): be merged with this value, with the runtime argument taking priority. """ - _output_type: OutputSpec[OutputDataT] + _output_type: OutputSpec[OutputDataT, CustomEventDataT] instrument: InstrumentationSettings | bool | None """Options to automatically instrument with OpenTelemetry.""" @@ -148,68 +145,17 @@ class Agent(AbstractAgent[AgentDepsT, OutputDataT]): _max_result_retries: int = dataclasses.field(repr=False) _max_tool_retries: int = dataclasses.field(repr=False) - _event_stream_handler: EventStreamHandler[AgentDepsT] | None = dataclasses.field(repr=False) + _event_stream_handler: EventStreamHandler[AgentDepsT, CustomEventDataT] | None = dataclasses.field(repr=False) _enter_lock: Lock = dataclasses.field(repr=False) _entered_count: int = dataclasses.field(repr=False) _exit_stack: AsyncExitStack | None = dataclasses.field(repr=False) - @overload - def __init__( - self, - model: models.Model | models.KnownModelName | str | None = None, - *, - output_type: OutputSpec[OutputDataT] = str, - instructions: Instructions[AgentDepsT] = None, - system_prompt: str | Sequence[str] = (), - deps_type: type[AgentDepsT] = NoneType, - name: str | None = None, - model_settings: ModelSettings | None = None, - retries: int = 1, - output_retries: int | None = None, - tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] = (), - builtin_tools: Sequence[AbstractBuiltinTool] = (), - prepare_tools: ToolsPrepareFunc[AgentDepsT] | None = None, - prepare_output_tools: ToolsPrepareFunc[AgentDepsT] | None = None, - toolsets: Sequence[AbstractToolset[AgentDepsT] | ToolsetFunc[AgentDepsT]] | None = None, - defer_model_check: bool = False, - end_strategy: EndStrategy = 'early', - instrument: InstrumentationSettings | bool | None = None, - history_processors: Sequence[HistoryProcessor[AgentDepsT]] | None = None, - event_stream_handler: EventStreamHandler[AgentDepsT] | None = None, - ) -> None: ... - - @overload - @deprecated('`mcp_servers` is deprecated, use `toolsets` instead.') - def __init__( - self, - model: models.Model | models.KnownModelName | str | None = None, - *, - output_type: OutputSpec[OutputDataT] = str, - instructions: Instructions[AgentDepsT] = None, - system_prompt: str | Sequence[str] = (), - deps_type: type[AgentDepsT] = NoneType, - name: str | None = None, - model_settings: ModelSettings | None = None, - retries: int = 1, - output_retries: int | None = None, - tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] = (), - builtin_tools: Sequence[AbstractBuiltinTool] = (), - prepare_tools: ToolsPrepareFunc[AgentDepsT] | None = None, - prepare_output_tools: ToolsPrepareFunc[AgentDepsT] | None = None, - mcp_servers: Sequence[MCPServer] = (), - defer_model_check: bool = False, - end_strategy: EndStrategy = 'early', - instrument: InstrumentationSettings | bool | None = None, - history_processors: Sequence[HistoryProcessor[AgentDepsT]] | None = None, - event_stream_handler: EventStreamHandler[AgentDepsT] | None = None, - ) -> None: ... - def __init__( self, model: models.Model | models.KnownModelName | str | None = None, *, - output_type: OutputSpec[OutputDataT] = str, + output_type: OutputSpec[OutputDataT, CustomEventDataT] = str, instructions: Instructions[AgentDepsT] = None, system_prompt: str | Sequence[str] = (), deps_type: type[AgentDepsT] = NoneType, @@ -217,7 +163,8 @@ def __init__( model_settings: ModelSettings | None = None, retries: int = 1, output_retries: int | None = None, - tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] = (), + tools: Sequence[Tool[AgentDepsT, CustomEventDataT] | ToolFuncEither[AgentDepsT, ..., CustomEventDataT]] + | None = None, builtin_tools: Sequence[AbstractBuiltinTool] = (), prepare_tools: ToolsPrepareFunc[AgentDepsT] | None = None, prepare_output_tools: ToolsPrepareFunc[AgentDepsT] | None = None, @@ -226,8 +173,7 @@ def __init__( end_strategy: EndStrategy = 'early', instrument: InstrumentationSettings | bool | None = None, history_processors: Sequence[HistoryProcessor[AgentDepsT]] | None = None, - event_stream_handler: EventStreamHandler[AgentDepsT] | None = None, - **_deprecated_kwargs: Any, + event_stream_handler: EventStreamHandler[AgentDepsT, CustomEventDataT] | None = None, ): """Create an agent. @@ -294,14 +240,6 @@ def __init__( self.instrument = instrument self._deps_type = deps_type - if mcp_servers := _deprecated_kwargs.pop('mcp_servers', None): - if toolsets is not None: # pragma: no cover - raise TypeError('`mcp_servers` and `toolsets` cannot be set at the same time.') - warnings.warn('`mcp_servers` is deprecated, use `toolsets` instead', DeprecationWarning) - toolsets = mcp_servers - - _utils.validate_empty_kwargs(_deprecated_kwargs) - self._output_schema = _output.OutputSchema[OutputDataT].build(output_type) self._output_validators = [] @@ -324,7 +262,7 @@ def __init__( self._output_toolset.max_retries = self._max_result_retries self._function_toolset = _AgentFunctionToolset( - tools, max_retries=self._max_tool_retries, output_schema=self._output_schema + tools or [], max_retries=self._max_tool_retries, output_schema=self._output_schema ) self._dynamic_toolsets = [ DynamicToolset[AgentDepsT](toolset_func=toolset) @@ -344,7 +282,9 @@ def __init__( '_override_toolsets', default=None ) self._override_tools: ContextVar[ - _utils.Option[Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]]] + _utils.Option[ + Sequence[Tool[AgentDepsT, CustomEventDataT] | ToolFuncEither[AgentDepsT, ..., CustomEventDataT]] + ] ] = ContextVar('_override_tools', default=None) self._override_instructions: ContextVar[ _utils.Option[list[str | _system_prompt.SystemPromptFunc[AgentDepsT]]] @@ -392,12 +332,12 @@ def deps_type(self) -> type: return self._deps_type @property - def output_type(self) -> OutputSpec[OutputDataT]: + def output_type(self) -> OutputSpec[OutputDataT, CustomEventDataT]: """The type of data output by agent runs, used to validate the data returned by the model, defaults to `str`.""" return self._output_type @property - def event_stream_handler(self) -> EventStreamHandler[AgentDepsT] | None: + def event_stream_handler(self) -> EventStreamHandler[AgentDepsT, CustomEventDataT] | None: """Optional handler for events from the model's streaming response and the agent's execution of tools.""" return self._event_stream_handler @@ -428,7 +368,7 @@ def iter( self, user_prompt: str | Sequence[_messages.UserContent] | None = None, *, - output_type: OutputSpec[RunOutputDataT], + output_type: OutputSpec[RunOutputDataT, CustomEventDataT], message_history: Sequence[_messages.ModelMessage] | None = None, deferred_tool_results: DeferredToolResults | None = None, model: models.Model | models.KnownModelName | str | None = None, @@ -447,7 +387,7 @@ async def iter( self, user_prompt: str | Sequence[_messages.UserContent] | None = None, *, - output_type: OutputSpec[Any] | None = None, + output_type: OutputSpec[Any, Any] | None = None, message_history: Sequence[_messages.ModelMessage] | None = None, deferred_tool_results: DeferredToolResults | None = None, model: models.Model | models.KnownModelName | str | None = None, @@ -736,7 +676,8 @@ def override( deps: AgentDepsT | _utils.Unset = _utils.UNSET, model: models.Model | models.KnownModelName | str | _utils.Unset = _utils.UNSET, toolsets: Sequence[AbstractToolset[AgentDepsT]] | _utils.Unset = _utils.UNSET, - tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] | _utils.Unset = _utils.UNSET, + tools: Sequence[Tool[AgentDepsT, CustomEventDataT] | ToolFuncEither[AgentDepsT, ..., CustomEventDataT]] + | _utils.Unset = _utils.UNSET, instructions: Instructions[AgentDepsT] | _utils.Unset = _utils.UNSET, ) -> Iterator[None]: """Context manager to temporarily override agent name, dependencies, model, toolsets, tools, or instructions. @@ -1003,7 +944,9 @@ async def output_validator_deps(ctx: RunContext[str], data: str) -> str: return func @overload - def tool(self, func: ToolFuncContext[AgentDepsT, ToolParams], /) -> ToolFuncContext[AgentDepsT, ToolParams]: ... + def tool( + self, func: ToolFuncContext[AgentDepsT, ToolParams, CustomEventDataT], / + ) -> ToolFuncContext[AgentDepsT, ToolParams, CustomEventDataT]: ... @overload def tool( @@ -1021,11 +964,14 @@ def tool( sequential: bool = False, requires_approval: bool = False, metadata: dict[str, Any] | None = None, - ) -> Callable[[ToolFuncContext[AgentDepsT, ToolParams]], ToolFuncContext[AgentDepsT, ToolParams]]: ... + ) -> Callable[ + [ToolFuncContext[AgentDepsT, ToolParams, CustomEventDataT]], + ToolFuncContext[AgentDepsT, ToolParams, CustomEventDataT], + ]: ... def tool( self, - func: ToolFuncContext[AgentDepsT, ToolParams] | None = None, + func: ToolFuncContext[AgentDepsT, ToolParams, CustomEventDataT] | None = None, /, *, name: str | None = None, @@ -1091,8 +1037,8 @@ async def spam(ctx: RunContext[str], y: float) -> float: """ def tool_decorator( - func_: ToolFuncContext[AgentDepsT, ToolParams], - ) -> ToolFuncContext[AgentDepsT, ToolParams]: + func_: ToolFuncContext[AgentDepsT, ToolParams, CustomEventDataT], + ) -> ToolFuncContext[AgentDepsT, ToolParams, CustomEventDataT]: # noinspection PyTypeChecker self._function_toolset.add_function( func_, @@ -1114,7 +1060,9 @@ def tool_decorator( return tool_decorator if func is None else tool_decorator(func) @overload - def tool_plain(self, func: ToolFuncPlain[ToolParams], /) -> ToolFuncPlain[ToolParams]: ... + def tool_plain( + self, func: ToolFuncPlain[ToolParams, CustomEventDataT], / + ) -> ToolFuncPlain[ToolParams, CustomEventDataT]: ... @overload def tool_plain( @@ -1132,11 +1080,11 @@ def tool_plain( sequential: bool = False, requires_approval: bool = False, metadata: dict[str, Any] | None = None, - ) -> Callable[[ToolFuncPlain[ToolParams]], ToolFuncPlain[ToolParams]]: ... + ) -> Callable[[ToolFuncPlain[ToolParams, CustomEventDataT]], ToolFuncPlain[ToolParams, CustomEventDataT]]: ... def tool_plain( self, - func: ToolFuncPlain[ToolParams] | None = None, + func: ToolFuncPlain[ToolParams, CustomEventDataT] | None = None, /, *, name: str | None = None, @@ -1201,7 +1149,9 @@ async def spam(ctx: RunContext[str]) -> float: metadata: Optional metadata for the tool. This is not sent to the model but can be used for filtering and tool behavior customization. """ - def tool_decorator(func_: ToolFuncPlain[ToolParams]) -> ToolFuncPlain[ToolParams]: + def tool_decorator( + func_: ToolFuncPlain[ToolParams, CustomEventDataT], + ) -> ToolFuncPlain[ToolParams, CustomEventDataT]: # noinspection PyTypeChecker self._function_toolset.add_function( func_, @@ -1413,10 +1363,10 @@ def _prepare_output_schema(self, output_type: None) -> _output.OutputSchema[Outp @overload def _prepare_output_schema( - self, output_type: OutputSpec[RunOutputDataT] + self, output_type: OutputSpec[RunOutputDataT, CustomEventDataT] ) -> _output.OutputSchema[RunOutputDataT]: ... - def _prepare_output_schema(self, output_type: OutputSpec[Any] | None) -> _output.OutputSchema[Any]: + def _prepare_output_schema(self, output_type: OutputSpec[Any, Any] | None) -> _output.OutputSchema[Any]: if output_type is not None: if self._output_validators: raise exceptions.UserError('Cannot set a custom run `output_type` when the agent has output validators') @@ -1498,7 +1448,7 @@ class _AgentFunctionToolset(FunctionToolset[AgentDepsT]): def __init__( self, - tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] = [], + tools: Sequence[Tool[AgentDepsT, CustomEventDataT] | ToolFuncEither[AgentDepsT, ..., CustomEventDataT]] = [], *, max_retries: int = 1, id: str | None = None, @@ -1515,7 +1465,7 @@ def id(self) -> str: def label(self) -> str: return 'the agent' - def add_tool(self, tool: Tool[AgentDepsT]) -> None: + def add_tool(self, tool: Tool[AgentDepsT, CustomEventDataT]) -> None: if tool.requires_approval and not self.output_schema.allows_deferred_tools: raise exceptions.UserError( 'To use tools that require approval, add `DeferredToolRequests` to the list of output types for this agent.' diff --git a/pydantic_ai_slim/pydantic_ai/agent/abstract.py b/pydantic_ai_slim/pydantic_ai/agent/abstract.py index fa5846a31d..cb6e206349 100644 --- a/pydantic_ai_slim/pydantic_ai/agent/abstract.py +++ b/pydantic_ai_slim/pydantic_ai/agent/abstract.py @@ -6,10 +6,10 @@ from collections.abc import AsyncIterable, AsyncIterator, Awaitable, Callable, Iterator, Mapping, Sequence from contextlib import AbstractAsyncContextManager, asynccontextmanager, contextmanager from types import FrameType -from typing import TYPE_CHECKING, Any, Generic, TypeAlias, cast, overload +from typing import TYPE_CHECKING, Any, Generic, cast, overload import anyio -from typing_extensions import Self, TypeIs, TypeVar +from typing_extensions import Self, TypeAliasType, TypeIs, TypeVar from pydantic_graph import End @@ -25,6 +25,7 @@ ) from .._tool_manager import ToolManager from ..builtin_tools import AbstractBuiltinTool +from ..messages import CustomEventDataT from ..output import OutputDataT, OutputSpec from ..result import AgentStream, FinalResult, StreamedRunResult from ..run import AgentRun, AgentRunResult, AgentRunResultEvent @@ -57,9 +58,11 @@ RunOutputDataT = TypeVar('RunOutputDataT') """Type variable for the result data of a run where `output_type` was customized on the run call.""" -EventStreamHandler: TypeAlias = Callable[ - [RunContext[AgentDepsT], AsyncIterable[_messages.AgentStreamEvent]], Awaitable[None] -] +EventStreamHandler = TypeAliasType( + 'EventStreamHandler', + Callable[[RunContext[AgentDepsT], AsyncIterable[_messages.AgentStreamEvent[CustomEventDataT]]], Awaitable[None]], + type_params=(AgentDepsT, CustomEventDataT), +) """A function that receives agent [`RunContext`][pydantic_ai.tools.RunContext] and an async iterable of events from the model's streaming response and the agent's execution of tools.""" @@ -71,7 +74,7 @@ ) -class AbstractAgent(Generic[AgentDepsT, OutputDataT], ABC): +class AbstractAgent(Generic[AgentDepsT, OutputDataT, CustomEventDataT], ABC): """Abstract superclass for [`Agent`][pydantic_ai.agent.Agent], [`WrapperAgent`][pydantic_ai.agent.WrapperAgent], and your own custom agent implementations.""" @property @@ -103,13 +106,13 @@ def deps_type(self) -> type: @property @abstractmethod - def output_type(self) -> OutputSpec[OutputDataT]: + def output_type(self) -> OutputSpec[OutputDataT, CustomEventDataT]: """The type of data output by agent runs, used to validate the data returned by the model, defaults to `str`.""" raise NotImplementedError @property @abstractmethod - def event_stream_handler(self) -> EventStreamHandler[AgentDepsT] | None: + def event_stream_handler(self) -> EventStreamHandler[AgentDepsT, CustomEventDataT] | None: """Optional handler for events from the model's streaming response and the agent's execution of tools.""" raise NotImplementedError @@ -139,7 +142,7 @@ async def run( infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, builtin_tools: Sequence[AbstractBuiltinTool] | None = None, - event_stream_handler: EventStreamHandler[AgentDepsT] | None = None, + event_stream_handler: EventStreamHandler[AgentDepsT, CustomEventDataT] | None = None, ) -> AgentRunResult[OutputDataT]: ... @overload @@ -147,7 +150,7 @@ async def run( self, user_prompt: str | Sequence[_messages.UserContent] | None = None, *, - output_type: OutputSpec[RunOutputDataT], + output_type: OutputSpec[RunOutputDataT, CustomEventDataT], message_history: Sequence[_messages.ModelMessage] | None = None, deferred_tool_results: DeferredToolResults | None = None, model: models.Model | models.KnownModelName | str | None = None, @@ -159,14 +162,14 @@ async def run( infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, builtin_tools: Sequence[AbstractBuiltinTool] | None = None, - event_stream_handler: EventStreamHandler[AgentDepsT] | None = None, + event_stream_handler: EventStreamHandler[AgentDepsT, CustomEventDataT] | None = None, ) -> AgentRunResult[RunOutputDataT]: ... async def run( self, user_prompt: str | Sequence[_messages.UserContent] | None = None, *, - output_type: OutputSpec[RunOutputDataT] | None = None, + output_type: OutputSpec[RunOutputDataT, CustomEventDataT] | None = None, message_history: Sequence[_messages.ModelMessage] | None = None, deferred_tool_results: DeferredToolResults | None = None, model: models.Model | models.KnownModelName | str | None = None, @@ -178,7 +181,7 @@ async def run( infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, builtin_tools: Sequence[AbstractBuiltinTool] | None = None, - event_stream_handler: EventStreamHandler[AgentDepsT] | None = None, + event_stream_handler: EventStreamHandler[AgentDepsT, CustomEventDataT] | None = None, ) -> AgentRunResult[Any]: """Run the agent with a user prompt in async mode. @@ -263,7 +266,7 @@ def run_sync( infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, builtin_tools: Sequence[AbstractBuiltinTool] | None = None, - event_stream_handler: EventStreamHandler[AgentDepsT] | None = None, + event_stream_handler: EventStreamHandler[AgentDepsT, CustomEventDataT] | None = None, ) -> AgentRunResult[OutputDataT]: ... @overload @@ -271,7 +274,7 @@ def run_sync( self, user_prompt: str | Sequence[_messages.UserContent] | None = None, *, - output_type: OutputSpec[RunOutputDataT], + output_type: OutputSpec[RunOutputDataT, CustomEventDataT], message_history: Sequence[_messages.ModelMessage] | None = None, deferred_tool_results: DeferredToolResults | None = None, model: models.Model | models.KnownModelName | str | None = None, @@ -283,14 +286,14 @@ def run_sync( infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, builtin_tools: Sequence[AbstractBuiltinTool] | None = None, - event_stream_handler: EventStreamHandler[AgentDepsT] | None = None, + event_stream_handler: EventStreamHandler[AgentDepsT, CustomEventDataT] | None = None, ) -> AgentRunResult[RunOutputDataT]: ... def run_sync( self, user_prompt: str | Sequence[_messages.UserContent] | None = None, *, - output_type: OutputSpec[RunOutputDataT] | None = None, + output_type: OutputSpec[RunOutputDataT, CustomEventDataT] | None = None, message_history: Sequence[_messages.ModelMessage] | None = None, deferred_tool_results: DeferredToolResults | None = None, model: models.Model | models.KnownModelName | str | None = None, @@ -302,7 +305,7 @@ def run_sync( infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, builtin_tools: Sequence[AbstractBuiltinTool] | None = None, - event_stream_handler: EventStreamHandler[AgentDepsT] | None = None, + event_stream_handler: EventStreamHandler[AgentDepsT, CustomEventDataT] | None = None, ) -> AgentRunResult[Any]: """Synchronously run the agent with a user prompt. @@ -379,7 +382,7 @@ def run_stream( infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, builtin_tools: Sequence[AbstractBuiltinTool] | None = None, - event_stream_handler: EventStreamHandler[AgentDepsT] | None = None, + event_stream_handler: EventStreamHandler[AgentDepsT, CustomEventDataT] | None = None, ) -> AbstractAsyncContextManager[result.StreamedRunResult[AgentDepsT, OutputDataT]]: ... @overload @@ -387,7 +390,7 @@ def run_stream( self, user_prompt: str | Sequence[_messages.UserContent] | None = None, *, - output_type: OutputSpec[RunOutputDataT], + output_type: OutputSpec[RunOutputDataT, CustomEventDataT], message_history: Sequence[_messages.ModelMessage] | None = None, deferred_tool_results: DeferredToolResults | None = None, model: models.Model | models.KnownModelName | str | None = None, @@ -399,7 +402,7 @@ def run_stream( infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, builtin_tools: Sequence[AbstractBuiltinTool] | None = None, - event_stream_handler: EventStreamHandler[AgentDepsT] | None = None, + event_stream_handler: EventStreamHandler[AgentDepsT, CustomEventDataT] | None = None, ) -> AbstractAsyncContextManager[result.StreamedRunResult[AgentDepsT, RunOutputDataT]]: ... @asynccontextmanager @@ -407,7 +410,7 @@ async def run_stream( # noqa C901 self, user_prompt: str | Sequence[_messages.UserContent] | None = None, *, - output_type: OutputSpec[RunOutputDataT] | None = None, + output_type: OutputSpec[RunOutputDataT, CustomEventDataT] | None = None, message_history: Sequence[_messages.ModelMessage] | None = None, deferred_tool_results: DeferredToolResults | None = None, model: models.Model | models.KnownModelName | str | None = None, @@ -419,7 +422,7 @@ async def run_stream( # noqa C901 infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, builtin_tools: Sequence[AbstractBuiltinTool] | None = None, - event_stream_handler: EventStreamHandler[AgentDepsT] | None = None, + event_stream_handler: EventStreamHandler[AgentDepsT, CustomEventDataT] | None = None, ) -> AsyncIterator[result.StreamedRunResult[AgentDepsT, Any]]: """Run the agent with a user prompt in async streaming mode. @@ -611,7 +614,7 @@ def run_stream_sync( infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, builtin_tools: Sequence[AbstractBuiltinTool] | None = None, - event_stream_handler: EventStreamHandler[AgentDepsT] | None = None, + event_stream_handler: EventStreamHandler[AgentDepsT, CustomEventDataT] | None = None, ) -> result.StreamedRunResultSync[AgentDepsT, OutputDataT]: ... @overload @@ -619,7 +622,7 @@ def run_stream_sync( self, user_prompt: str | Sequence[_messages.UserContent] | None = None, *, - output_type: OutputSpec[RunOutputDataT], + output_type: OutputSpec[RunOutputDataT, CustomEventDataT], message_history: Sequence[_messages.ModelMessage] | None = None, deferred_tool_results: DeferredToolResults | None = None, model: models.Model | models.KnownModelName | str | None = None, @@ -630,14 +633,14 @@ def run_stream_sync( infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, builtin_tools: Sequence[AbstractBuiltinTool] | None = None, - event_stream_handler: EventStreamHandler[AgentDepsT] | None = None, + event_stream_handler: EventStreamHandler[AgentDepsT, CustomEventDataT] | None = None, ) -> result.StreamedRunResultSync[AgentDepsT, RunOutputDataT]: ... def run_stream_sync( self, user_prompt: str | Sequence[_messages.UserContent] | None = None, *, - output_type: OutputSpec[RunOutputDataT] | None = None, + output_type: OutputSpec[RunOutputDataT, CustomEventDataT] | None = None, message_history: Sequence[_messages.ModelMessage] | None = None, deferred_tool_results: DeferredToolResults | None = None, model: models.Model | models.KnownModelName | str | None = None, @@ -648,7 +651,7 @@ def run_stream_sync( infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, builtin_tools: Sequence[AbstractBuiltinTool] | None = None, - event_stream_handler: EventStreamHandler[AgentDepsT] | None = None, + event_stream_handler: EventStreamHandler[AgentDepsT, CustomEventDataT] | None = None, ) -> result.StreamedRunResultSync[AgentDepsT, Any]: """Run the agent with a user prompt in sync streaming mode. @@ -739,14 +742,14 @@ def run_stream_events( infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, builtin_tools: Sequence[AbstractBuiltinTool] | None = None, - ) -> AsyncIterator[_messages.AgentStreamEvent | AgentRunResultEvent[OutputDataT]]: ... + ) -> AsyncIterator[_messages.AgentStreamEvent[CustomEventDataT] | AgentRunResultEvent[OutputDataT]]: ... @overload def run_stream_events( self, user_prompt: str | Sequence[_messages.UserContent] | None = None, *, - output_type: OutputSpec[RunOutputDataT], + output_type: OutputSpec[RunOutputDataT, CustomEventDataT], message_history: Sequence[_messages.ModelMessage] | None = None, deferred_tool_results: DeferredToolResults | None = None, model: models.Model | models.KnownModelName | str | None = None, @@ -758,13 +761,13 @@ def run_stream_events( infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, builtin_tools: Sequence[AbstractBuiltinTool] | None = None, - ) -> AsyncIterator[_messages.AgentStreamEvent | AgentRunResultEvent[RunOutputDataT]]: ... + ) -> AsyncIterator[_messages.AgentStreamEvent[CustomEventDataT] | AgentRunResultEvent[RunOutputDataT]]: ... def run_stream_events( self, user_prompt: str | Sequence[_messages.UserContent] | None = None, *, - output_type: OutputSpec[RunOutputDataT] | None = None, + output_type: OutputSpec[RunOutputDataT, CustomEventDataT] | None = None, message_history: Sequence[_messages.ModelMessage] | None = None, deferred_tool_results: DeferredToolResults | None = None, model: models.Model | models.KnownModelName | str | None = None, @@ -776,7 +779,7 @@ def run_stream_events( infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, builtin_tools: Sequence[AbstractBuiltinTool] | None = None, - ) -> AsyncIterator[_messages.AgentStreamEvent | AgentRunResultEvent[Any]]: + ) -> AsyncIterator[_messages.AgentStreamEvent[CustomEventDataT] | AgentRunResultEvent[Any]]: """Run the agent with a user prompt in async mode and stream events from the run. This is a convenience method that wraps [`self.run`][pydantic_ai.agent.AbstractAgent.run] and @@ -856,7 +859,7 @@ async def _run_stream_events( self, user_prompt: str | Sequence[_messages.UserContent] | None = None, *, - output_type: OutputSpec[RunOutputDataT] | None = None, + output_type: OutputSpec[RunOutputDataT, CustomEventDataT] | None = None, message_history: Sequence[_messages.ModelMessage] | None = None, deferred_tool_results: DeferredToolResults | None = None, model: models.Model | models.KnownModelName | str | None = None, @@ -867,13 +870,13 @@ async def _run_stream_events( usage: _usage.RunUsage | None = None, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, builtin_tools: Sequence[AbstractBuiltinTool] | None = None, - ) -> AsyncIterator[_messages.AgentStreamEvent | AgentRunResultEvent[Any]]: + ) -> AsyncIterator[_messages.AgentStreamEvent[CustomEventDataT] | AgentRunResultEvent[Any]]: send_stream, receive_stream = anyio.create_memory_object_stream[ - _messages.AgentStreamEvent | AgentRunResultEvent[Any] + _messages.AgentStreamEvent[CustomEventDataT] | AgentRunResultEvent[Any] ]() async def event_stream_handler( - _: RunContext[AgentDepsT], events: AsyncIterable[_messages.AgentStreamEvent] + _: RunContext[AgentDepsT], events: AsyncIterable[_messages.AgentStreamEvent[CustomEventDataT]] ) -> None: async for event in events: await send_stream.send(event) @@ -930,7 +933,7 @@ def iter( self, user_prompt: str | Sequence[_messages.UserContent] | None = None, *, - output_type: OutputSpec[RunOutputDataT], + output_type: OutputSpec[RunOutputDataT, CustomEventDataT], message_history: Sequence[_messages.ModelMessage] | None = None, deferred_tool_results: DeferredToolResults | None = None, model: models.Model | models.KnownModelName | str | None = None, @@ -950,7 +953,7 @@ async def iter( self, user_prompt: str | Sequence[_messages.UserContent] | None = None, *, - output_type: OutputSpec[RunOutputDataT] | None = None, + output_type: OutputSpec[RunOutputDataT, CustomEventDataT] | None = None, message_history: Sequence[_messages.ModelMessage] | None = None, deferred_tool_results: DeferredToolResults | None = None, model: models.Model | models.KnownModelName | str | None = None, @@ -1052,7 +1055,8 @@ def override( deps: AgentDepsT | _utils.Unset = _utils.UNSET, model: models.Model | models.KnownModelName | str | _utils.Unset = _utils.UNSET, toolsets: Sequence[AbstractToolset[AgentDepsT]] | _utils.Unset = _utils.UNSET, - tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] | _utils.Unset = _utils.UNSET, + tools: Sequence[Tool[AgentDepsT, CustomEventDataT] | ToolFuncEither[AgentDepsT, ..., CustomEventDataT]] + | _utils.Unset = _utils.UNSET, instructions: Instructions[AgentDepsT] | _utils.Unset = _utils.UNSET, ) -> Iterator[None]: """Context manager to temporarily override agent name, dependencies, model, toolsets, tools, or instructions. @@ -1150,7 +1154,7 @@ def to_ag_ui( self, *, # Agent.iter parameters - output_type: OutputSpec[OutputDataT] | None = None, + output_type: OutputSpec[OutputDataT, CustomEventDataT] | None = None, message_history: Sequence[_messages.ModelMessage] | None = None, deferred_tool_results: DeferredToolResults | None = None, model: models.Model | models.KnownModelName | str | None = None, diff --git a/pydantic_ai_slim/pydantic_ai/agent/wrapper.py b/pydantic_ai_slim/pydantic_ai/agent/wrapper.py index fcf7826f13..9a73784aa7 100644 --- a/pydantic_ai_slim/pydantic_ai/agent/wrapper.py +++ b/pydantic_ai_slim/pydantic_ai/agent/wrapper.py @@ -11,6 +11,7 @@ usage as _usage, ) from ..builtin_tools import AbstractBuiltinTool +from ..messages import CustomEventDataT from ..output import OutputDataT, OutputSpec from ..run import AgentRun from ..settings import ModelSettings @@ -24,13 +25,13 @@ from .abstract import AbstractAgent, EventStreamHandler, Instructions, RunOutputDataT -class WrapperAgent(AbstractAgent[AgentDepsT, OutputDataT]): +class WrapperAgent(AbstractAgent[AgentDepsT, OutputDataT, CustomEventDataT]): """Agent which wraps another agent. Does nothing on its own, used as a base class. """ - def __init__(self, wrapped: AbstractAgent[AgentDepsT, OutputDataT]): + def __init__(self, wrapped: AbstractAgent[AgentDepsT, OutputDataT, CustomEventDataT]): self.wrapped = wrapped @property @@ -50,11 +51,11 @@ def deps_type(self) -> type: return self.wrapped.deps_type @property - def output_type(self) -> OutputSpec[OutputDataT]: + def output_type(self) -> OutputSpec[OutputDataT, CustomEventDataT]: return self.wrapped.output_type @property - def event_stream_handler(self) -> EventStreamHandler[AgentDepsT] | None: + def event_stream_handler(self) -> EventStreamHandler[AgentDepsT, CustomEventDataT] | None: return self.wrapped.event_stream_handler @property @@ -91,7 +92,7 @@ def iter( self, user_prompt: str | Sequence[_messages.UserContent] | None = None, *, - output_type: OutputSpec[RunOutputDataT], + output_type: OutputSpec[RunOutputDataT, CustomEventDataT], message_history: Sequence[_messages.ModelMessage] | None = None, deferred_tool_results: DeferredToolResults | None = None, model: models.Model | models.KnownModelName | str | None = None, @@ -110,7 +111,7 @@ async def iter( self, user_prompt: str | Sequence[_messages.UserContent] | None = None, *, - output_type: OutputSpec[RunOutputDataT] | None = None, + output_type: OutputSpec[RunOutputDataT, CustomEventDataT] | None = None, message_history: Sequence[_messages.ModelMessage] | None = None, deferred_tool_results: DeferredToolResults | None = None, model: models.Model | models.KnownModelName | str | None = None, @@ -225,7 +226,8 @@ def override( deps: AgentDepsT | _utils.Unset = _utils.UNSET, model: models.Model | models.KnownModelName | str | _utils.Unset = _utils.UNSET, toolsets: Sequence[AbstractToolset[AgentDepsT]] | _utils.Unset = _utils.UNSET, - tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] | _utils.Unset = _utils.UNSET, + tools: Sequence[Tool[AgentDepsT, CustomEventDataT] | ToolFuncEither[AgentDepsT, ..., CustomEventDataT]] + | _utils.Unset = _utils.UNSET, instructions: Instructions[AgentDepsT] | _utils.Unset = _utils.UNSET, ) -> Iterator[None]: """Context manager to temporarily override agent name, dependencies, model, toolsets, tools, or instructions. diff --git a/pydantic_ai_slim/pydantic_ai/durable_exec/dbos/_agent.py b/pydantic_ai_slim/pydantic_ai/durable_exec/dbos/_agent.py index 42aec0bd83..243257aee5 100644 --- a/pydantic_ai_slim/pydantic_ai/durable_exec/dbos/_agent.py +++ b/pydantic_ai_slim/pydantic_ai/durable_exec/dbos/_agent.py @@ -19,6 +19,7 @@ from pydantic_ai.agent.abstract import Instructions, RunOutputDataT from pydantic_ai.builtin_tools import AbstractBuiltinTool from pydantic_ai.exceptions import UserError +from pydantic_ai.messages import CustomEventDataT from pydantic_ai.models import Model from pydantic_ai.output import OutputDataT, OutputSpec from pydantic_ai.result import StreamedRunResult @@ -36,13 +37,13 @@ @DBOS.dbos_class() -class DBOSAgent(WrapperAgent[AgentDepsT, OutputDataT], DBOSConfiguredInstance): +class DBOSAgent(WrapperAgent[AgentDepsT, OutputDataT, CustomEventDataT], DBOSConfiguredInstance): def __init__( self, - wrapped: AbstractAgent[AgentDepsT, OutputDataT], + wrapped: AbstractAgent[AgentDepsT, OutputDataT, CustomEventDataT], *, name: str | None = None, - event_stream_handler: EventStreamHandler[AgentDepsT] | None = None, + event_stream_handler: EventStreamHandler[AgentDepsT, CustomEventDataT] | None = None, mcp_step_config: StepConfig | None = None, model_step_config: StepConfig | None = None, ): @@ -112,7 +113,7 @@ def dbosify_toolset(toolset: AbstractToolset[AgentDepsT]) -> AbstractToolset[Age async def wrapped_run_workflow( user_prompt: str | Sequence[_messages.UserContent] | None = None, *, - output_type: OutputSpec[RunOutputDataT] | None = None, + output_type: OutputSpec[RunOutputDataT, CustomEventDataT] | None = None, message_history: Sequence[_messages.ModelMessage] | None = None, deferred_tool_results: DeferredToolResults | None = None, model: models.Model | models.KnownModelName | str | None = None, @@ -124,7 +125,7 @@ async def wrapped_run_workflow( infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, builtin_tools: Sequence[AbstractBuiltinTool] | None = None, - event_stream_handler: EventStreamHandler[AgentDepsT] | None = None, + event_stream_handler: EventStreamHandler[AgentDepsT, CustomEventDataT] | None = None, **_deprecated_kwargs: Never, ) -> AgentRunResult[Any]: with self._dbos_overrides(): @@ -153,7 +154,7 @@ async def wrapped_run_workflow( def wrapped_run_sync_workflow( user_prompt: str | Sequence[_messages.UserContent] | None = None, *, - output_type: OutputSpec[RunOutputDataT] | None = None, + output_type: OutputSpec[RunOutputDataT, CustomEventDataT] | None = None, message_history: Sequence[_messages.ModelMessage] | None = None, deferred_tool_results: DeferredToolResults | None = None, model: models.Model | models.KnownModelName | str | None = None, @@ -165,7 +166,7 @@ def wrapped_run_sync_workflow( infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, builtin_tools: Sequence[AbstractBuiltinTool] | None = None, - event_stream_handler: EventStreamHandler[AgentDepsT] | None = None, + event_stream_handler: EventStreamHandler[AgentDepsT, CustomEventDataT] | None = None, **_deprecated_kwargs: Never, ) -> AgentRunResult[Any]: with self._dbos_overrides(): @@ -204,7 +205,7 @@ def model(self) -> Model: return self._model @property - def event_stream_handler(self) -> EventStreamHandler[AgentDepsT] | None: + def event_stream_handler(self) -> EventStreamHandler[AgentDepsT, CustomEventDataT] | None: handler = self._event_stream_handler or super().event_stream_handler if handler is None: return None @@ -215,12 +216,12 @@ def event_stream_handler(self) -> EventStreamHandler[AgentDepsT] | None: return handler async def _call_event_stream_handler_in_workflow( - self, ctx: RunContext[AgentDepsT], stream: AsyncIterable[_messages.AgentStreamEvent] + self, ctx: RunContext[AgentDepsT], stream: AsyncIterable[_messages.AgentStreamEvent[CustomEventDataT]] ) -> None: handler = self._event_stream_handler or super().event_stream_handler assert handler is not None - async def streamed_response(event: _messages.AgentStreamEvent): + async def streamed_response(event: _messages.AgentStreamEvent[CustomEventDataT]): yield event async for event in stream: @@ -257,7 +258,7 @@ async def run( infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, builtin_tools: Sequence[AbstractBuiltinTool] | None = None, - event_stream_handler: EventStreamHandler[AgentDepsT] | None = None, + event_stream_handler: EventStreamHandler[AgentDepsT, CustomEventDataT] | None = None, ) -> AgentRunResult[OutputDataT]: ... @overload @@ -265,7 +266,7 @@ async def run( self, user_prompt: str | Sequence[_messages.UserContent] | None = None, *, - output_type: OutputSpec[RunOutputDataT], + output_type: OutputSpec[RunOutputDataT, CustomEventDataT], message_history: Sequence[_messages.ModelMessage] | None = None, deferred_tool_results: DeferredToolResults | None = None, model: models.Model | models.KnownModelName | str | None = None, @@ -277,14 +278,14 @@ async def run( infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, builtin_tools: Sequence[AbstractBuiltinTool] | None = None, - event_stream_handler: EventStreamHandler[AgentDepsT] | None = None, + event_stream_handler: EventStreamHandler[AgentDepsT, CustomEventDataT] | None = None, ) -> AgentRunResult[RunOutputDataT]: ... async def run( self, user_prompt: str | Sequence[_messages.UserContent] | None = None, *, - output_type: OutputSpec[RunOutputDataT] | None = None, + output_type: OutputSpec[RunOutputDataT, CustomEventDataT] | None = None, message_history: Sequence[_messages.ModelMessage] | None = None, deferred_tool_results: DeferredToolResults | None = None, model: models.Model | models.KnownModelName | str | None = None, @@ -296,7 +297,7 @@ async def run( infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, builtin_tools: Sequence[AbstractBuiltinTool] | None = None, - event_stream_handler: EventStreamHandler[AgentDepsT] | None = None, + event_stream_handler: EventStreamHandler[AgentDepsT, CustomEventDataT] | None = None, **_deprecated_kwargs: Never, ) -> AgentRunResult[Any]: """Run the agent with a user prompt in async mode. @@ -371,7 +372,7 @@ def run_sync( infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, builtin_tools: Sequence[AbstractBuiltinTool] | None = None, - event_stream_handler: EventStreamHandler[AgentDepsT] | None = None, + event_stream_handler: EventStreamHandler[AgentDepsT, CustomEventDataT] | None = None, ) -> AgentRunResult[OutputDataT]: ... @overload @@ -379,7 +380,7 @@ def run_sync( self, user_prompt: str | Sequence[_messages.UserContent] | None = None, *, - output_type: OutputSpec[RunOutputDataT], + output_type: OutputSpec[RunOutputDataT, CustomEventDataT], message_history: Sequence[_messages.ModelMessage] | None = None, deferred_tool_results: DeferredToolResults | None = None, model: models.Model | models.KnownModelName | str | None = None, @@ -391,14 +392,14 @@ def run_sync( infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, builtin_tools: Sequence[AbstractBuiltinTool] | None = None, - event_stream_handler: EventStreamHandler[AgentDepsT] | None = None, + event_stream_handler: EventStreamHandler[AgentDepsT, CustomEventDataT] | None = None, ) -> AgentRunResult[RunOutputDataT]: ... def run_sync( self, user_prompt: str | Sequence[_messages.UserContent] | None = None, *, - output_type: OutputSpec[RunOutputDataT] | None = None, + output_type: OutputSpec[RunOutputDataT, CustomEventDataT] | None = None, message_history: Sequence[_messages.ModelMessage] | None = None, deferred_tool_results: DeferredToolResults | None = None, model: models.Model | models.KnownModelName | str | None = None, @@ -410,7 +411,7 @@ def run_sync( infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, builtin_tools: Sequence[AbstractBuiltinTool] | None = None, - event_stream_handler: EventStreamHandler[AgentDepsT] | None = None, + event_stream_handler: EventStreamHandler[AgentDepsT, CustomEventDataT] | None = None, **_deprecated_kwargs: Never, ) -> AgentRunResult[Any]: """Synchronously run the agent with a user prompt. @@ -484,7 +485,7 @@ def run_stream( infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, builtin_tools: Sequence[AbstractBuiltinTool] | None = None, - event_stream_handler: EventStreamHandler[AgentDepsT] | None = None, + event_stream_handler: EventStreamHandler[AgentDepsT, CustomEventDataT] | None = None, ) -> AbstractAsyncContextManager[StreamedRunResult[AgentDepsT, OutputDataT]]: ... @overload @@ -492,7 +493,7 @@ def run_stream( self, user_prompt: str | Sequence[_messages.UserContent] | None = None, *, - output_type: OutputSpec[RunOutputDataT], + output_type: OutputSpec[RunOutputDataT, CustomEventDataT], message_history: Sequence[_messages.ModelMessage] | None = None, deferred_tool_results: DeferredToolResults | None = None, model: models.Model | models.KnownModelName | str | None = None, @@ -504,7 +505,7 @@ def run_stream( infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, builtin_tools: Sequence[AbstractBuiltinTool] | None = None, - event_stream_handler: EventStreamHandler[AgentDepsT] | None = None, + event_stream_handler: EventStreamHandler[AgentDepsT, CustomEventDataT] | None = None, ) -> AbstractAsyncContextManager[StreamedRunResult[AgentDepsT, RunOutputDataT]]: ... @asynccontextmanager @@ -512,7 +513,7 @@ async def run_stream( self, user_prompt: str | Sequence[_messages.UserContent] | None = None, *, - output_type: OutputSpec[RunOutputDataT] | None = None, + output_type: OutputSpec[RunOutputDataT, CustomEventDataT] | None = None, message_history: Sequence[_messages.ModelMessage] | None = None, deferred_tool_results: DeferredToolResults | None = None, model: models.Model | models.KnownModelName | str | None = None, @@ -524,7 +525,7 @@ async def run_stream( infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, builtin_tools: Sequence[AbstractBuiltinTool] | None = None, - event_stream_handler: EventStreamHandler[AgentDepsT] | None = None, + event_stream_handler: EventStreamHandler[AgentDepsT, CustomEventDataT] | None = None, **_deprecated_kwargs: Never, ) -> AsyncIterator[StreamedRunResult[AgentDepsT, Any]]: """Run the agent with a user prompt in async mode, returning a streamed response. @@ -603,14 +604,14 @@ def run_stream_events( infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, builtin_tools: Sequence[AbstractBuiltinTool] | None = None, - ) -> AsyncIterator[_messages.AgentStreamEvent | AgentRunResultEvent[OutputDataT]]: ... + ) -> AsyncIterator[_messages.AgentStreamEvent[CustomEventDataT] | AgentRunResultEvent[OutputDataT]]: ... @overload def run_stream_events( self, user_prompt: str | Sequence[_messages.UserContent] | None = None, *, - output_type: OutputSpec[RunOutputDataT], + output_type: OutputSpec[RunOutputDataT, CustomEventDataT], message_history: Sequence[_messages.ModelMessage] | None = None, deferred_tool_results: DeferredToolResults | None = None, model: models.Model | models.KnownModelName | str | None = None, @@ -622,13 +623,13 @@ def run_stream_events( infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, builtin_tools: Sequence[AbstractBuiltinTool] | None = None, - ) -> AsyncIterator[_messages.AgentStreamEvent | AgentRunResultEvent[RunOutputDataT]]: ... + ) -> AsyncIterator[_messages.AgentStreamEvent[CustomEventDataT] | AgentRunResultEvent[RunOutputDataT]]: ... def run_stream_events( self, user_prompt: str | Sequence[_messages.UserContent] | None = None, *, - output_type: OutputSpec[RunOutputDataT] | None = None, + output_type: OutputSpec[RunOutputDataT, CustomEventDataT] | None = None, message_history: Sequence[_messages.ModelMessage] | None = None, deferred_tool_results: DeferredToolResults | None = None, model: models.Model | models.KnownModelName | str | None = None, @@ -640,7 +641,7 @@ def run_stream_events( infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, builtin_tools: Sequence[AbstractBuiltinTool] | None = None, - ) -> AsyncIterator[_messages.AgentStreamEvent | AgentRunResultEvent[Any]]: + ) -> AsyncIterator[_messages.AgentStreamEvent[CustomEventDataT] | AgentRunResultEvent[Any]]: """Run the agent with a user prompt in async mode and stream events from the run. This is a convenience method that wraps [`self.run`][pydantic_ai.agent.AbstractAgent.run] and @@ -725,7 +726,7 @@ def iter( self, user_prompt: str | Sequence[_messages.UserContent] | None = None, *, - output_type: OutputSpec[RunOutputDataT], + output_type: OutputSpec[RunOutputDataT, CustomEventDataT], message_history: Sequence[_messages.ModelMessage] | None = None, deferred_tool_results: DeferredToolResults | None = None, model: models.Model | models.KnownModelName | str | None = None, @@ -745,7 +746,7 @@ async def iter( self, user_prompt: str | Sequence[_messages.UserContent] | None = None, *, - output_type: OutputSpec[RunOutputDataT] | None = None, + output_type: OutputSpec[RunOutputDataT, CustomEventDataT] | None = None, message_history: Sequence[_messages.ModelMessage] | None = None, deferred_tool_results: DeferredToolResults | None = None, model: models.Model | models.KnownModelName | str | None = None, @@ -868,7 +869,8 @@ def override( deps: AgentDepsT | _utils.Unset = _utils.UNSET, model: models.Model | models.KnownModelName | str | _utils.Unset = _utils.UNSET, toolsets: Sequence[AbstractToolset[AgentDepsT]] | _utils.Unset = _utils.UNSET, - tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] | _utils.Unset = _utils.UNSET, + tools: Sequence[Tool[AgentDepsT, CustomEventDataT] | ToolFuncEither[AgentDepsT, ..., CustomEventDataT]] + | _utils.Unset = _utils.UNSET, instructions: Instructions[AgentDepsT] | _utils.Unset = _utils.UNSET, ) -> Iterator[None]: """Context manager to temporarily override agent name, dependencies, model, toolsets, tools, or instructions. diff --git a/pydantic_ai_slim/pydantic_ai/durable_exec/prefect/_agent.py b/pydantic_ai_slim/pydantic_ai/durable_exec/prefect/_agent.py index 0867a60e36..dbe06a57f2 100644 --- a/pydantic_ai_slim/pydantic_ai/durable_exec/prefect/_agent.py +++ b/pydantic_ai_slim/pydantic_ai/durable_exec/prefect/_agent.py @@ -22,6 +22,7 @@ from pydantic_ai.agent.abstract import Instructions, RunOutputDataT from pydantic_ai.builtin_tools import AbstractBuiltinTool from pydantic_ai.exceptions import UserError +from pydantic_ai.messages import CustomEventDataT from pydantic_ai.models import Model from pydantic_ai.output import OutputDataT, OutputSpec from pydantic_ai.result import StreamedRunResult @@ -39,13 +40,13 @@ from ._types import TaskConfig, default_task_config -class PrefectAgent(WrapperAgent[AgentDepsT, OutputDataT]): +class PrefectAgent(WrapperAgent[AgentDepsT, OutputDataT, CustomEventDataT]): def __init__( self, - wrapped: AbstractAgent[AgentDepsT, OutputDataT], + wrapped: AbstractAgent[AgentDepsT, OutputDataT, CustomEventDataT], *, name: str | None = None, - event_stream_handler: EventStreamHandler[AgentDepsT] | None = None, + event_stream_handler: EventStreamHandler[AgentDepsT, CustomEventDataT] | None = None, mcp_task_config: TaskConfig | None = None, model_task_config: TaskConfig | None = None, tool_task_config: TaskConfig | None = None, @@ -133,7 +134,7 @@ def model(self) -> Model: return self._model @property - def event_stream_handler(self) -> EventStreamHandler[AgentDepsT] | None: + def event_stream_handler(self) -> EventStreamHandler[AgentDepsT, CustomEventDataT] | None: handler = self._event_stream_handler or super().event_stream_handler if handler is None: return None @@ -144,14 +145,14 @@ def event_stream_handler(self) -> EventStreamHandler[AgentDepsT] | None: return handler async def _call_event_stream_handler_in_flow( - self, ctx: RunContext[AgentDepsT], stream: AsyncIterable[_messages.AgentStreamEvent] + self, ctx: RunContext[AgentDepsT], stream: AsyncIterable[_messages.AgentStreamEvent[CustomEventDataT]] ) -> None: handler = self._event_stream_handler or super().event_stream_handler assert handler is not None # Create a task to handle each event @task(name='Handle Stream Event', **self._event_stream_handler_task_config) - async def event_stream_handler_task(event: _messages.AgentStreamEvent) -> None: + async def event_stream_handler_task(event: _messages.AgentStreamEvent[CustomEventDataT]) -> None: async def streamed_response(): yield event @@ -188,7 +189,7 @@ async def run( infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, builtin_tools: Sequence[AbstractBuiltinTool] | None = None, - event_stream_handler: EventStreamHandler[AgentDepsT] | None = None, + event_stream_handler: EventStreamHandler[AgentDepsT, CustomEventDataT] | None = None, ) -> AgentRunResult[OutputDataT]: ... @overload @@ -196,7 +197,7 @@ async def run( self, user_prompt: str | Sequence[_messages.UserContent] | None = None, *, - output_type: OutputSpec[RunOutputDataT], + output_type: OutputSpec[RunOutputDataT, CustomEventDataT], message_history: Sequence[_messages.ModelMessage] | None = None, deferred_tool_results: DeferredToolResults | None = None, model: models.Model | models.KnownModelName | str | None = None, @@ -208,14 +209,14 @@ async def run( infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, builtin_tools: Sequence[AbstractBuiltinTool] | None = None, - event_stream_handler: EventStreamHandler[AgentDepsT] | None = None, + event_stream_handler: EventStreamHandler[AgentDepsT, CustomEventDataT] | None = None, ) -> AgentRunResult[RunOutputDataT]: ... async def run( self, user_prompt: str | Sequence[_messages.UserContent] | None = None, *, - output_type: OutputSpec[RunOutputDataT] | None = None, + output_type: OutputSpec[RunOutputDataT, CustomEventDataT] | None = None, message_history: Sequence[_messages.ModelMessage] | None = None, deferred_tool_results: DeferredToolResults | None = None, model: models.Model | models.KnownModelName | str | None = None, @@ -227,7 +228,7 @@ async def run( infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, builtin_tools: Sequence[AbstractBuiltinTool] | None = None, - event_stream_handler: EventStreamHandler[AgentDepsT] | None = None, + event_stream_handler: EventStreamHandler[AgentDepsT, CustomEventDataT] | None = None, **_deprecated_kwargs: Never, ) -> AgentRunResult[Any]: """Run the agent with a user prompt in async mode. @@ -312,7 +313,7 @@ def run_sync( infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, builtin_tools: Sequence[AbstractBuiltinTool] | None = None, - event_stream_handler: EventStreamHandler[AgentDepsT] | None = None, + event_stream_handler: EventStreamHandler[AgentDepsT, CustomEventDataT] | None = None, ) -> AgentRunResult[OutputDataT]: ... @overload @@ -320,7 +321,7 @@ def run_sync( self, user_prompt: str | Sequence[_messages.UserContent] | None = None, *, - output_type: OutputSpec[RunOutputDataT], + output_type: OutputSpec[RunOutputDataT, CustomEventDataT], message_history: Sequence[_messages.ModelMessage] | None = None, deferred_tool_results: DeferredToolResults | None = None, model: models.Model | models.KnownModelName | str | None = None, @@ -332,14 +333,14 @@ def run_sync( infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, builtin_tools: Sequence[AbstractBuiltinTool] | None = None, - event_stream_handler: EventStreamHandler[AgentDepsT] | None = None, + event_stream_handler: EventStreamHandler[AgentDepsT, CustomEventDataT] | None = None, ) -> AgentRunResult[RunOutputDataT]: ... def run_sync( self, user_prompt: str | Sequence[_messages.UserContent] | None = None, *, - output_type: OutputSpec[RunOutputDataT] | None = None, + output_type: OutputSpec[RunOutputDataT, CustomEventDataT] | None = None, message_history: Sequence[_messages.ModelMessage] | None = None, deferred_tool_results: DeferredToolResults | None = None, model: models.Model | models.KnownModelName | str | None = None, @@ -351,7 +352,7 @@ def run_sync( infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, builtin_tools: Sequence[AbstractBuiltinTool] | None = None, - event_stream_handler: EventStreamHandler[AgentDepsT] | None = None, + event_stream_handler: EventStreamHandler[AgentDepsT, CustomEventDataT] | None = None, **_deprecated_kwargs: Never, ) -> AgentRunResult[Any]: """Synchronously run the agent with a user prompt. @@ -438,7 +439,7 @@ def run_stream( infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, builtin_tools: Sequence[AbstractBuiltinTool] | None = None, - event_stream_handler: EventStreamHandler[AgentDepsT] | None = None, + event_stream_handler: EventStreamHandler[AgentDepsT, CustomEventDataT] | None = None, ) -> AbstractAsyncContextManager[StreamedRunResult[AgentDepsT, OutputDataT]]: ... @overload @@ -446,7 +447,7 @@ def run_stream( self, user_prompt: str | Sequence[_messages.UserContent] | None = None, *, - output_type: OutputSpec[RunOutputDataT], + output_type: OutputSpec[RunOutputDataT, CustomEventDataT], message_history: Sequence[_messages.ModelMessage] | None = None, deferred_tool_results: DeferredToolResults | None = None, model: models.Model | models.KnownModelName | str | None = None, @@ -458,7 +459,7 @@ def run_stream( infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, builtin_tools: Sequence[AbstractBuiltinTool] | None = None, - event_stream_handler: EventStreamHandler[AgentDepsT] | None = None, + event_stream_handler: EventStreamHandler[AgentDepsT, CustomEventDataT] | None = None, ) -> AbstractAsyncContextManager[StreamedRunResult[AgentDepsT, RunOutputDataT]]: ... @asynccontextmanager @@ -466,7 +467,7 @@ async def run_stream( self, user_prompt: str | Sequence[_messages.UserContent] | None = None, *, - output_type: OutputSpec[RunOutputDataT] | None = None, + output_type: OutputSpec[RunOutputDataT, CustomEventDataT] | None = None, message_history: Sequence[_messages.ModelMessage] | None = None, deferred_tool_results: DeferredToolResults | None = None, model: models.Model | models.KnownModelName | str | None = None, @@ -478,7 +479,7 @@ async def run_stream( infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, builtin_tools: Sequence[AbstractBuiltinTool] | None = None, - event_stream_handler: EventStreamHandler[AgentDepsT] | None = None, + event_stream_handler: EventStreamHandler[AgentDepsT, CustomEventDataT] | None = None, **_deprecated_kwargs: Never, ) -> AsyncIterator[StreamedRunResult[AgentDepsT, Any]]: """Run the agent with a user prompt in async mode, returning a streamed response. @@ -557,14 +558,14 @@ def run_stream_events( infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, builtin_tools: Sequence[AbstractBuiltinTool] | None = None, - ) -> AsyncIterator[_messages.AgentStreamEvent | AgentRunResultEvent[OutputDataT]]: ... + ) -> AsyncIterator[_messages.AgentStreamEvent[CustomEventDataT] | AgentRunResultEvent[OutputDataT]]: ... @overload def run_stream_events( self, user_prompt: str | Sequence[_messages.UserContent] | None = None, *, - output_type: OutputSpec[RunOutputDataT], + output_type: OutputSpec[RunOutputDataT, CustomEventDataT], message_history: Sequence[_messages.ModelMessage] | None = None, deferred_tool_results: DeferredToolResults | None = None, model: models.Model | models.KnownModelName | str | None = None, @@ -576,13 +577,13 @@ def run_stream_events( infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, builtin_tools: Sequence[AbstractBuiltinTool] | None = None, - ) -> AsyncIterator[_messages.AgentStreamEvent | AgentRunResultEvent[RunOutputDataT]]: ... + ) -> AsyncIterator[_messages.AgentStreamEvent[CustomEventDataT] | AgentRunResultEvent[RunOutputDataT]]: ... def run_stream_events( self, user_prompt: str | Sequence[_messages.UserContent] | None = None, *, - output_type: OutputSpec[RunOutputDataT] | None = None, + output_type: OutputSpec[RunOutputDataT, CustomEventDataT] | None = None, message_history: Sequence[_messages.ModelMessage] | None = None, deferred_tool_results: DeferredToolResults | None = None, model: models.Model | models.KnownModelName | str | None = None, @@ -594,7 +595,7 @@ def run_stream_events( infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, builtin_tools: Sequence[AbstractBuiltinTool] | None = None, - ) -> AsyncIterator[_messages.AgentStreamEvent | AgentRunResultEvent[Any]]: + ) -> AsyncIterator[_messages.AgentStreamEvent[CustomEventDataT] | AgentRunResultEvent[Any]]: """Run the agent with a user prompt in async mode and stream events from the run. This is a convenience method that wraps [`self.run`][pydantic_ai.agent.AbstractAgent.run] and @@ -695,7 +696,7 @@ def iter( self, user_prompt: str | Sequence[_messages.UserContent] | None = None, *, - output_type: OutputSpec[RunOutputDataT], + output_type: OutputSpec[RunOutputDataT, CustomEventDataT], message_history: Sequence[_messages.ModelMessage] | None = None, deferred_tool_results: DeferredToolResults | None = None, model: models.Model | models.KnownModelName | str | None = None, @@ -714,7 +715,7 @@ async def iter( self, user_prompt: str | Sequence[_messages.UserContent] | None = None, *, - output_type: OutputSpec[RunOutputDataT] | None = None, + output_type: OutputSpec[RunOutputDataT, CustomEventDataT] | None = None, message_history: Sequence[_messages.ModelMessage] | None = None, deferred_tool_results: DeferredToolResults | None = None, model: models.Model | models.KnownModelName | str | None = None, @@ -834,7 +835,8 @@ def override( deps: AgentDepsT | _utils.Unset = _utils.UNSET, model: models.Model | models.KnownModelName | str | _utils.Unset = _utils.UNSET, toolsets: Sequence[AbstractToolset[AgentDepsT]] | _utils.Unset = _utils.UNSET, - tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] | _utils.Unset = _utils.UNSET, + tools: Sequence[Tool[AgentDepsT, CustomEventDataT] | ToolFuncEither[AgentDepsT, ..., CustomEventDataT]] + | _utils.Unset = _utils.UNSET, instructions: Instructions[AgentDepsT] | _utils.Unset = _utils.UNSET, ) -> Iterator[None]: """Context manager to temporarily override agent dependencies, model, toolsets, tools, or instructions. diff --git a/pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_agent.py b/pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_agent.py index b6be1e7b9d..b42818dd41 100644 --- a/pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_agent.py +++ b/pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_agent.py @@ -1,11 +1,12 @@ from __future__ import annotations +import inspect from collections.abc import AsyncIterable, AsyncIterator, Callable, Iterator, Sequence from contextlib import AbstractAsyncContextManager, asynccontextmanager, contextmanager from contextvars import ContextVar from dataclasses import dataclass from datetime import timedelta -from typing import Any, Literal, overload +from typing import Any, Generic, Literal, TypeAlias, overload from pydantic import ConfigDict, with_config from pydantic.errors import PydanticUserError @@ -27,6 +28,7 @@ from pydantic_ai.agent.abstract import Instructions, RunOutputDataT from pydantic_ai.builtin_tools import AbstractBuiltinTool from pydantic_ai.exceptions import UserError +from pydantic_ai.messages import CustomEventDataT from pydantic_ai.models import Model from pydantic_ai.output import OutputDataT, OutputSpec from pydantic_ai.result import StreamedRunResult @@ -44,6 +46,50 @@ from ._toolset import TemporalWrapperToolset, temporalize_toolset +@dataclass(kw_only=True) +class TemporalizeToolsetContext(Generic[AgentDepsT]): + """Context object for `temporalize_toolset_func` functions.""" + + activity_name_prefix: str + """Prefix for Temporal activity names.""" + activity_config: ActivityConfig + """The Temporal activity config to use.""" + tool_activity_config: dict[str, ActivityConfig | Literal[False]] + """The Temporal activity config to use for specific tools identified by tool name.""" + deps_type: type[AgentDepsT] + """The type of agent's dependencies object. It needs to be serializable using Pydantic's `TypeAdapter`.""" + run_context_type: type[TemporalRunContext[AgentDepsT]] = TemporalRunContext[AgentDepsT] + """The `TemporalRunContext` (sub)class that's used to serialize and deserialize the run context.""" + event_stream_handler: EventStreamHandler[AgentDepsT, CustomEventDataT] | None = None + """The event stream handler to use for custom events yielded by tools in the toolset.""" + + +TemporalizeToolsetFunc: TypeAlias = ( + Callable[ + [ + AbstractToolset[AgentDepsT], + str, + ActivityConfig, + dict[str, ActivityConfig | Literal[False]], + type[AgentDepsT], + type[TemporalRunContext[AgentDepsT]], + ], + AbstractToolset[AgentDepsT], + ] + | Callable[ + [ + AbstractToolset[AgentDepsT], + TemporalizeToolsetContext[AgentDepsT], + ], + AbstractToolset[AgentDepsT], + ] +) +"""Type of function to use to prepare "leaf" toolsets (i.e. those that implement their own tool listing and calling) for Temporal +by wrapping them in a `TemporalWrapperToolset` that moves methods that require IO to Temporal activities. + +The function takes the toolset and context and returns the temporalized toolset.""" + + @dataclass @with_config(ConfigDict(arbitrary_types_allowed=True)) class _EventStreamHandlerParams: @@ -51,29 +97,19 @@ class _EventStreamHandlerParams: serialized_run_context: Any -class TemporalAgent(WrapperAgent[AgentDepsT, OutputDataT]): +class TemporalAgent(WrapperAgent[AgentDepsT, OutputDataT, CustomEventDataT]): def __init__( self, - wrapped: AbstractAgent[AgentDepsT, OutputDataT], + wrapped: AbstractAgent[AgentDepsT, OutputDataT, CustomEventDataT], *, name: str | None = None, - event_stream_handler: EventStreamHandler[AgentDepsT] | None = None, + event_stream_handler: EventStreamHandler[AgentDepsT, CustomEventDataT] | None = None, activity_config: ActivityConfig | None = None, model_activity_config: ActivityConfig | None = None, toolset_activity_config: dict[str, ActivityConfig] | None = None, tool_activity_config: dict[str, dict[str, ActivityConfig | Literal[False]]] | None = None, run_context_type: type[TemporalRunContext[AgentDepsT]] = TemporalRunContext[AgentDepsT], - temporalize_toolset_func: Callable[ - [ - AbstractToolset[AgentDepsT], - str, - ActivityConfig, - dict[str, ActivityConfig | Literal[False]], - type[AgentDepsT], - type[TemporalRunContext[AgentDepsT]], - ], - AbstractToolset[AgentDepsT], - ] = temporalize_toolset, + temporalize_toolset_func: TemporalizeToolsetFunc[AgentDepsT] = temporalize_toolset, ): """Wrap an agent to enable it to be used inside a Temporal workflow, by automatically offloading model requests, tool calls, and MCP server communication to Temporal activities. @@ -93,9 +129,10 @@ def __init__( run_context_type: The `TemporalRunContext` subclass to use to serialize and deserialize the run context for use inside a Temporal activity. By default, only the `deps`, `retries`, `tool_call_id`, `tool_name`, `retry` and `run_step` attributes will be available. To make another attribute available, create a `TemporalRunContext` subclass with a custom `serialize_run_context` class method that returns a dictionary that includes the attribute. - temporalize_toolset_func: Optional function to use to prepare "leaf" toolsets (i.e. those that implement their own tool listing and calling) for Temporal by wrapping them in a `TemporalWrapperToolset` that moves methods that require IO to Temporal activities. + temporalize_toolset_func: Optional function to use to prepare "leaf" toolsets (i.e. those that implement their own tool listing and calling) for Temporal + by wrapping them in a `TemporalWrapperToolset` that moves methods that require IO to Temporal activities. If not provided, only `FunctionToolset` and `MCPServer` will be prepared for Temporal. - The function takes the toolset, the activity name prefix, the toolset-specific activity config, the tool-specific activity configs and the run context type. + The function takes the toolset and context and returns the temporalized toolset. """ super().__init__(wrapped) @@ -170,14 +207,26 @@ def temporalize_toolset(toolset: AbstractToolset[AgentDepsT]) -> AbstractToolset "Toolsets that are 'leaves' (i.e. those that implement their own tool listing and calling) need to have a unique `id` in order to be used with Temporal. The ID will be used to identify the toolset's activities within the workflow." ) - toolset = temporalize_toolset_func( - toolset, - activity_name_prefix, - activity_config | toolset_activity_config.get(id, {}), - tool_activity_config.get(id, {}), - self.deps_type, - self.run_context_type, + context = TemporalizeToolsetContext( + activity_name_prefix=activity_name_prefix, + activity_config=activity_config | toolset_activity_config.get(id, {}), + tool_activity_config=tool_activity_config.get(id, {}), + deps_type=self.deps_type, + run_context_type=self.run_context_type, + event_stream_handler=self.event_stream_handler, ) + if len(inspect.signature(temporalize_toolset_func).parameters) == 2: + toolset = temporalize_toolset_func(toolset, context) + else: + # Backward compatibility with the original `temporalize_toolset` function signature. + toolset = temporalize_toolset_func( + toolset, + context.activity_name_prefix, + context.activity_config, + context.tool_activity_config, + context.deps_type, + context.run_context_type, + ) if isinstance(toolset, TemporalWrapperToolset): activities.extend(toolset.temporal_activities) return toolset @@ -205,7 +254,7 @@ def model(self) -> Model: return self._model @property - def event_stream_handler(self) -> EventStreamHandler[AgentDepsT] | None: + def event_stream_handler(self) -> EventStreamHandler[AgentDepsT, CustomEventDataT] | None: handler = self._event_stream_handler or super().event_stream_handler if handler is None: return None @@ -215,7 +264,7 @@ def event_stream_handler(self) -> EventStreamHandler[AgentDepsT] | None: return handler async def _call_event_stream_handler_activity( - self, ctx: RunContext[AgentDepsT], stream: AsyncIterable[_messages.AgentStreamEvent] + self, ctx: RunContext[AgentDepsT], stream: AsyncIterable[_messages.AgentStreamEvent[CustomEventDataT]] ) -> None: serialized_run_context = self.run_context_type.serialize_run_context(ctx) async for event in stream: @@ -271,7 +320,7 @@ async def run( infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, builtin_tools: Sequence[AbstractBuiltinTool] | None = None, - event_stream_handler: EventStreamHandler[AgentDepsT] | None = None, + event_stream_handler: EventStreamHandler[AgentDepsT, CustomEventDataT] | None = None, ) -> AgentRunResult[OutputDataT]: ... @overload @@ -279,7 +328,7 @@ async def run( self, user_prompt: str | Sequence[_messages.UserContent] | None = None, *, - output_type: OutputSpec[RunOutputDataT], + output_type: OutputSpec[RunOutputDataT, CustomEventDataT], message_history: Sequence[_messages.ModelMessage] | None = None, deferred_tool_results: DeferredToolResults | None = None, model: models.Model | models.KnownModelName | str | None = None, @@ -291,14 +340,14 @@ async def run( infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, builtin_tools: Sequence[AbstractBuiltinTool] | None = None, - event_stream_handler: EventStreamHandler[AgentDepsT] | None = None, + event_stream_handler: EventStreamHandler[AgentDepsT, CustomEventDataT] | None = None, ) -> AgentRunResult[RunOutputDataT]: ... async def run( self, user_prompt: str | Sequence[_messages.UserContent] | None = None, *, - output_type: OutputSpec[RunOutputDataT] | None = None, + output_type: OutputSpec[RunOutputDataT, CustomEventDataT] | None = None, message_history: Sequence[_messages.ModelMessage] | None = None, deferred_tool_results: DeferredToolResults | None = None, model: models.Model | models.KnownModelName | str | None = None, @@ -310,7 +359,7 @@ async def run( infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, builtin_tools: Sequence[AbstractBuiltinTool] | None = None, - event_stream_handler: EventStreamHandler[AgentDepsT] | None = None, + event_stream_handler: EventStreamHandler[AgentDepsT, CustomEventDataT] | None = None, **_deprecated_kwargs: Never, ) -> AgentRunResult[Any]: """Run the agent with a user prompt in async mode. @@ -391,7 +440,7 @@ def run_sync( infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, builtin_tools: Sequence[AbstractBuiltinTool] | None = None, - event_stream_handler: EventStreamHandler[AgentDepsT] | None = None, + event_stream_handler: EventStreamHandler[AgentDepsT, CustomEventDataT] | None = None, ) -> AgentRunResult[OutputDataT]: ... @overload @@ -399,7 +448,7 @@ def run_sync( self, user_prompt: str | Sequence[_messages.UserContent] | None = None, *, - output_type: OutputSpec[RunOutputDataT], + output_type: OutputSpec[RunOutputDataT, CustomEventDataT], message_history: Sequence[_messages.ModelMessage] | None = None, deferred_tool_results: DeferredToolResults | None = None, model: models.Model | models.KnownModelName | str | None = None, @@ -411,14 +460,14 @@ def run_sync( infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, builtin_tools: Sequence[AbstractBuiltinTool] | None = None, - event_stream_handler: EventStreamHandler[AgentDepsT] | None = None, + event_stream_handler: EventStreamHandler[AgentDepsT, CustomEventDataT] | None = None, ) -> AgentRunResult[RunOutputDataT]: ... def run_sync( self, user_prompt: str | Sequence[_messages.UserContent] | None = None, *, - output_type: OutputSpec[RunOutputDataT] | None = None, + output_type: OutputSpec[RunOutputDataT, CustomEventDataT] | None = None, message_history: Sequence[_messages.ModelMessage] | None = None, deferred_tool_results: DeferredToolResults | None = None, model: models.Model | models.KnownModelName | str | None = None, @@ -430,7 +479,7 @@ def run_sync( infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, builtin_tools: Sequence[AbstractBuiltinTool] | None = None, - event_stream_handler: EventStreamHandler[AgentDepsT] | None = None, + event_stream_handler: EventStreamHandler[AgentDepsT, CustomEventDataT] | None = None, **_deprecated_kwargs: Never, ) -> AgentRunResult[Any]: """Synchronously run the agent with a user prompt. @@ -509,7 +558,7 @@ def run_stream( infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, builtin_tools: Sequence[AbstractBuiltinTool] | None = None, - event_stream_handler: EventStreamHandler[AgentDepsT] | None = None, + event_stream_handler: EventStreamHandler[AgentDepsT, CustomEventDataT] | None = None, ) -> AbstractAsyncContextManager[StreamedRunResult[AgentDepsT, OutputDataT]]: ... @overload @@ -517,7 +566,7 @@ def run_stream( self, user_prompt: str | Sequence[_messages.UserContent] | None = None, *, - output_type: OutputSpec[RunOutputDataT], + output_type: OutputSpec[RunOutputDataT, CustomEventDataT], message_history: Sequence[_messages.ModelMessage] | None = None, deferred_tool_results: DeferredToolResults | None = None, model: models.Model | models.KnownModelName | str | None = None, @@ -529,7 +578,7 @@ def run_stream( infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, builtin_tools: Sequence[AbstractBuiltinTool] | None = None, - event_stream_handler: EventStreamHandler[AgentDepsT] | None = None, + event_stream_handler: EventStreamHandler[AgentDepsT, CustomEventDataT] | None = None, ) -> AbstractAsyncContextManager[StreamedRunResult[AgentDepsT, RunOutputDataT]]: ... @asynccontextmanager @@ -537,7 +586,7 @@ async def run_stream( self, user_prompt: str | Sequence[_messages.UserContent] | None = None, *, - output_type: OutputSpec[RunOutputDataT] | None = None, + output_type: OutputSpec[RunOutputDataT, CustomEventDataT] | None = None, message_history: Sequence[_messages.ModelMessage] | None = None, deferred_tool_results: DeferredToolResults | None = None, model: models.Model | models.KnownModelName | str | None = None, @@ -549,7 +598,7 @@ async def run_stream( infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, builtin_tools: Sequence[AbstractBuiltinTool] | None = None, - event_stream_handler: EventStreamHandler[AgentDepsT] | None = None, + event_stream_handler: EventStreamHandler[AgentDepsT, CustomEventDataT] | None = None, **_deprecated_kwargs: Never, ) -> AsyncIterator[StreamedRunResult[AgentDepsT, Any]]: """Run the agent with a user prompt in async mode, returning a streamed response. @@ -628,14 +677,14 @@ def run_stream_events( infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, builtin_tools: Sequence[AbstractBuiltinTool] | None = None, - ) -> AsyncIterator[_messages.AgentStreamEvent | AgentRunResultEvent[OutputDataT]]: ... + ) -> AsyncIterator[_messages.AgentStreamEvent[CustomEventDataT] | AgentRunResultEvent[OutputDataT]]: ... @overload def run_stream_events( self, user_prompt: str | Sequence[_messages.UserContent] | None = None, *, - output_type: OutputSpec[RunOutputDataT], + output_type: OutputSpec[RunOutputDataT, CustomEventDataT], message_history: Sequence[_messages.ModelMessage] | None = None, deferred_tool_results: DeferredToolResults | None = None, model: models.Model | models.KnownModelName | str | None = None, @@ -647,13 +696,13 @@ def run_stream_events( infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, builtin_tools: Sequence[AbstractBuiltinTool] | None = None, - ) -> AsyncIterator[_messages.AgentStreamEvent | AgentRunResultEvent[RunOutputDataT]]: ... + ) -> AsyncIterator[_messages.AgentStreamEvent[CustomEventDataT] | AgentRunResultEvent[RunOutputDataT]]: ... def run_stream_events( self, user_prompt: str | Sequence[_messages.UserContent] | None = None, *, - output_type: OutputSpec[RunOutputDataT] | None = None, + output_type: OutputSpec[RunOutputDataT, CustomEventDataT] | None = None, message_history: Sequence[_messages.ModelMessage] | None = None, deferred_tool_results: DeferredToolResults | None = None, model: models.Model | models.KnownModelName | str | None = None, @@ -665,7 +714,7 @@ def run_stream_events( infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, builtin_tools: Sequence[AbstractBuiltinTool] | None = None, - ) -> AsyncIterator[_messages.AgentStreamEvent | AgentRunResultEvent[Any]]: + ) -> AsyncIterator[_messages.AgentStreamEvent[CustomEventDataT] | AgentRunResultEvent[Any]]: """Run the agent with a user prompt in async mode and stream events from the run. This is a convenience method that wraps [`self.run`][pydantic_ai.agent.AbstractAgent.run] and @@ -767,7 +816,7 @@ def iter( self, user_prompt: str | Sequence[_messages.UserContent] | None = None, *, - output_type: OutputSpec[RunOutputDataT], + output_type: OutputSpec[RunOutputDataT, CustomEventDataT], message_history: Sequence[_messages.ModelMessage] | None = None, deferred_tool_results: DeferredToolResults | None = None, model: models.Model | models.KnownModelName | str | None = None, @@ -787,7 +836,7 @@ async def iter( self, user_prompt: str | Sequence[_messages.UserContent] | None = None, *, - output_type: OutputSpec[RunOutputDataT] | None = None, + output_type: OutputSpec[RunOutputDataT, CustomEventDataT] | None = None, message_history: Sequence[_messages.ModelMessage] | None = None, deferred_tool_results: DeferredToolResults | None = None, model: models.Model | models.KnownModelName | str | None = None, @@ -920,7 +969,8 @@ def override( deps: AgentDepsT | _utils.Unset = _utils.UNSET, model: models.Model | models.KnownModelName | str | _utils.Unset = _utils.UNSET, toolsets: Sequence[AbstractToolset[AgentDepsT]] | _utils.Unset = _utils.UNSET, - tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] | _utils.Unset = _utils.UNSET, + tools: Sequence[Tool[AgentDepsT, CustomEventDataT] | ToolFuncEither[AgentDepsT, ..., CustomEventDataT]] + | _utils.Unset = _utils.UNSET, instructions: Instructions[AgentDepsT] | _utils.Unset = _utils.UNSET, ) -> Iterator[None]: """Context manager to temporarily override agent name, dependencies, model, toolsets, tools, or instructions. diff --git a/pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_function_toolset.py b/pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_function_toolset.py index dd1f8c1ee3..9f3966d4c6 100644 --- a/pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_function_toolset.py +++ b/pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_function_toolset.py @@ -1,13 +1,18 @@ from __future__ import annotations -from collections.abc import Callable +import asyncio +from collections.abc import AsyncIterator, Callable +from dataclasses import replace from typing import Any, Literal +import anyio from temporalio import activity, workflow from temporalio.workflow import ActivityConfig from pydantic_ai import FunctionToolset, ToolsetTool +from pydantic_ai.agent import EventStreamHandler from pydantic_ai.exceptions import UserError +from pydantic_ai.messages import CustomEventDataT, HandleResponseEvent from pydantic_ai.tools import AgentDepsT, RunContext from pydantic_ai.toolsets.function import FunctionToolsetTool @@ -29,14 +34,17 @@ def __init__( tool_activity_config: dict[str, ActivityConfig | Literal[False]], deps_type: type[AgentDepsT], run_context_type: type[TemporalRunContext[AgentDepsT]] = TemporalRunContext[AgentDepsT], + event_stream_handler: EventStreamHandler[AgentDepsT, CustomEventDataT] | None = None, ): super().__init__(toolset) self.activity_config = activity_config self.tool_activity_config = tool_activity_config self.run_context_type = run_context_type + self.event_stream_handler = event_stream_handler async def call_tool_activity(params: CallToolParams, deps: AgentDepsT) -> CallToolResult: name = params.name + # TODO (DouweM): RunContext.event_stream -> call event_stream_handler directly? ctx = self.run_context_type.deserialize_run_context(params.serialized_run_context, deps=deps) try: tool = (await toolset.get_tools(ctx))[name] @@ -49,7 +57,30 @@ async def call_tool_activity(params: CallToolParams, deps: AgentDepsT) -> CallTo # The tool args will already have been validated into their proper types in the `ToolManager`, # but `execute_activity` would have turned them into simple Python types again, so we need to re-validate them. args_dict = tool.args_validator.validate_python(params.tool_args) - return await self._wrap_call_tool_result(self.wrapped.call_tool(name, args_dict, ctx, tool)) + + async def _call_tool(ctx: RunContext[AgentDepsT]) -> CallToolResult: + return await self._wrap_call_tool_result(self.wrapped.call_tool(name, args_dict, ctx, tool)) + + if self.event_stream_handler is not None: + send_stream, receive_stream = anyio.create_memory_object_stream[HandleResponseEvent[CustomEventDataT]]() + + async def _call_tool_with_stream() -> CallToolResult: + async with send_stream: + ctx_with_stream = replace(ctx, event_stream=send_stream) + return await _call_tool(ctx_with_stream) + + task = asyncio.create_task(_call_tool_with_stream()) + + async def _receive_events() -> AsyncIterator[HandleResponseEvent[CustomEventDataT]]: + async with receive_stream: + async for event in receive_stream: + yield event + + await self.event_stream_handler(ctx, _receive_events()) + + return await task + else: + return await _call_tool(ctx) # Set type hint explicitly so that Temporal can take care of serialization and deserialization call_tool_activity.__annotations__['deps'] = deps_type diff --git a/pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_model.py b/pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_model.py index 8a97253ee1..b31e3ca16f 100644 --- a/pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_model.py +++ b/pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_model.py @@ -17,6 +17,7 @@ ) from pydantic_ai.agent import EventStreamHandler from pydantic_ai.exceptions import UserError +from pydantic_ai.messages import CustomEventDataT from pydantic_ai.models import Model, ModelRequestParameters, StreamedResponse from pydantic_ai.models.wrapper import WrapperModel from pydantic_ai.settings import ModelSettings @@ -74,7 +75,7 @@ def __init__( activity_config: ActivityConfig, deps_type: type[AgentDepsT], run_context_type: type[TemporalRunContext[AgentDepsT]] = TemporalRunContext[AgentDepsT], - event_stream_handler: EventStreamHandler[Any] | None = None, + event_stream_handler: EventStreamHandler[AgentDepsT, CustomEventDataT] | None = None, ): super().__init__(model) self.activity_config = activity_config diff --git a/pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_toolset.py b/pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_toolset.py index d4adb4b6a7..ab5c09857a 100644 --- a/pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_toolset.py +++ b/pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_toolset.py @@ -3,17 +3,18 @@ from abc import ABC, abstractmethod from collections.abc import Awaitable, Callable from dataclasses import dataclass -from typing import Annotated, Any, Literal +from typing import TYPE_CHECKING, Annotated, Any, Literal from pydantic import ConfigDict, Discriminator, with_config -from temporalio.workflow import ActivityConfig from typing_extensions import assert_never from pydantic_ai import AbstractToolset, FunctionToolset, WrapperToolset from pydantic_ai.exceptions import ApprovalRequired, CallDeferred, ModelRetry +from pydantic_ai.messages import Return, ToolReturn from pydantic_ai.tools import AgentDepsT, ToolDefinition -from ._run_context import TemporalRunContext +if TYPE_CHECKING: + from ._agent import TemporalizeToolsetContext @dataclass @@ -43,7 +44,7 @@ class _ModelRetry: @dataclass class _ToolReturn: - result: Any + result: ToolReturn[Any] | Any kind: Literal['tool_return'] = 'tool_return' @@ -74,6 +75,9 @@ def visit_and_replace( async def _wrap_call_tool_result(self, coro: Awaitable[Any]) -> CallToolResult: try: result = await coro + if type(result) is Return: + # We don't use `isinstance` because `ToolReturn` is a subclass of `Return` with additional fields, which should be returned in full. + result = result.return_value return _ToolReturn(result=result) except ApprovalRequired: return _ApprovalRequired() @@ -96,33 +100,20 @@ def _unwrap_call_tool_result(self, result: CallToolResult) -> Any: def temporalize_toolset( - toolset: AbstractToolset[AgentDepsT], - activity_name_prefix: str, - activity_config: ActivityConfig, - tool_activity_config: dict[str, ActivityConfig | Literal[False]], - deps_type: type[AgentDepsT], - run_context_type: type[TemporalRunContext[AgentDepsT]] = TemporalRunContext[AgentDepsT], + toolset: AbstractToolset[AgentDepsT], context: TemporalizeToolsetContext[AgentDepsT] ) -> AbstractToolset[AgentDepsT]: - """Temporalize a toolset. - - Args: - toolset: The toolset to temporalize. - activity_name_prefix: Prefix for Temporal activity names. - activity_config: The Temporal activity config to use. - tool_activity_config: The Temporal activity config to use for specific tools identified by tool name. - deps_type: The type of agent's dependencies object. It needs to be serializable using Pydantic's `TypeAdapter`. - run_context_type: The `TemporalRunContext` (sub)class that's used to serialize and deserialize the run context. - """ + """Temporalize a toolset.""" if isinstance(toolset, FunctionToolset): from ._function_toolset import TemporalFunctionToolset return TemporalFunctionToolset( toolset, - activity_name_prefix=activity_name_prefix, - activity_config=activity_config, - tool_activity_config=tool_activity_config, - deps_type=deps_type, - run_context_type=run_context_type, + activity_name_prefix=context.activity_name_prefix, + activity_config=context.activity_config, + tool_activity_config=context.tool_activity_config, + deps_type=context.deps_type, + run_context_type=context.run_context_type, + event_stream_handler=context.event_stream_handler, ) try: @@ -135,11 +126,11 @@ def temporalize_toolset( if isinstance(toolset, MCPServer): return TemporalMCPServer( toolset, - activity_name_prefix=activity_name_prefix, - activity_config=activity_config, - tool_activity_config=tool_activity_config, - deps_type=deps_type, - run_context_type=run_context_type, + activity_name_prefix=context.activity_name_prefix, + activity_config=context.activity_config, + tool_activity_config=context.tool_activity_config, + deps_type=context.deps_type, + run_context_type=context.run_context_type, ) return toolset diff --git a/pydantic_ai_slim/pydantic_ai/messages.py b/pydantic_ai_slim/pydantic_ai/messages.py index d5aaa5e791..1abec6cb15 100644 --- a/pydantic_ai_slim/pydantic_ai/messages.py +++ b/pydantic_ai_slim/pydantic_ai/messages.py @@ -7,13 +7,13 @@ from dataclasses import KW_ONLY, dataclass, field, replace from datetime import datetime from mimetypes import guess_type -from typing import TYPE_CHECKING, Annotated, Any, Literal, TypeAlias, cast, overload +from typing import TYPE_CHECKING, Annotated, Any, Generic, Literal, TypeAlias, cast, overload import pydantic import pydantic_core from genai_prices import calc_price, types as genai_types from opentelemetry._events import Event # pyright: ignore[reportPrivateImportUsage] -from typing_extensions import deprecated +from typing_extensions import TypeAliasType, TypeVar, deprecated from . import _otel_messages, _utils from ._utils import generate_tool_call_id as _generate_tool_call_id, now_utc as _now_utc @@ -23,6 +23,9 @@ if TYPE_CHECKING: from .models.instrumented import InstrumentationSettings +CustomEventDataT = TypeVar('CustomEventDataT', default=object, covariant=True) +"""Covariant type variable for the data type of a custom event.""" + AudioMediaType: TypeAlias = Literal['audio/wav', 'audio/mpeg', 'audio/ogg', 'audio/flac', 'audio/aiff', 'audio/aac'] ImageMediaType: TypeAlias = Literal['image/jpeg', 'image/png', 'image/gif', 'image/webp'] @@ -615,9 +618,20 @@ def __init__( MultiModalContent = ImageUrl | AudioUrl | DocumentUrl | VideoUrl | BinaryContent UserContent: TypeAlias = str | MultiModalContent +ReturnValueT = TypeVar('ReturnValueT', default=Any, covariant=True) + @dataclass(repr=False) -class ToolReturn: +class Return(Generic[ReturnValueT]): + """TODO (DouweM): Docstring.""" + + return_value: ReturnValueT + + __repr__ = _utils.dataclasses_no_defaults_repr + + +@dataclass(repr=False) +class ToolReturn(Return[ReturnValueT]): """A structured return value for tools that need to provide both a return value and custom content to the model. This class allows tools to return complex responses that include: @@ -626,9 +640,6 @@ class ToolReturn: - Optional metadata for application use """ - return_value: Any - """The return value to be used in the tool response.""" - _: KW_ONLY content: str | Sequence[UserContent] | None = None @@ -1776,14 +1787,44 @@ class BuiltinToolResultEvent: """Event type identifier, used as a discriminator.""" -HandleResponseEvent = Annotated[ - FunctionToolCallEvent - | FunctionToolResultEvent - | BuiltinToolCallEvent # pyright: ignore[reportDeprecated] - | BuiltinToolResultEvent, # pyright: ignore[reportDeprecated] - pydantic.Discriminator('event_kind'), -] +@dataclass(repr=False) +class CustomEvent(Generic[CustomEventDataT]): + """A custom event emitted during the execution of a tool or output function.""" + + data: CustomEventDataT = None # pyright: ignore[reportAssignmentType] + """The data of the custom event.""" + + _: KW_ONLY + + name: str | None = None + """The name of the custom event.""" + + tool_call_id: str | None = None + """The tool call ID, if any, that this event is associated with.""" + + event_kind: Literal['custom'] = 'custom' + """Event type identifier, used as a discriminator.""" + + __repr__ = _utils.dataclasses_no_defaults_repr + + +HandleResponseEvent = TypeAliasType( + 'HandleResponseEvent', + Annotated[ + FunctionToolCallEvent + | FunctionToolResultEvent + | CustomEvent[CustomEventDataT] + | BuiltinToolCallEvent # pyright: ignore[reportDeprecated] + | BuiltinToolResultEvent, # pyright: ignore[reportDeprecated] + pydantic.Discriminator('event_kind'), + ], + type_params=(CustomEventDataT,), +) """An event yielded when handling a model response, indicating tool calls and results.""" -AgentStreamEvent = Annotated[ModelResponseStreamEvent | HandleResponseEvent, pydantic.Discriminator('event_kind')] +AgentStreamEvent = TypeAliasType( + 'AgentStreamEvent', + Annotated[ModelResponseStreamEvent | HandleResponseEvent[CustomEventDataT], pydantic.Discriminator('event_kind')], + type_params=(CustomEventDataT,), +) """An event in the agent stream: model response stream events and response-handling events.""" diff --git a/pydantic_ai_slim/pydantic_ai/output.py b/pydantic_ai_slim/pydantic_ai/output.py index cd5e5865a6..5dda7b91f7 100644 --- a/pydantic_ai_slim/pydantic_ai/output.py +++ b/pydantic_ai_slim/pydantic_ai/output.py @@ -1,6 +1,6 @@ from __future__ import annotations -from collections.abc import Awaitable, Callable, Sequence +from collections.abc import AsyncIterator, Awaitable, Callable, Sequence from dataclasses import dataclass from typing import Any, Generic, Literal @@ -11,7 +11,7 @@ from . import _utils, exceptions from ._json_schema import InlineDefsJsonSchemaTransformer -from .messages import ToolCallPart +from .messages import CustomEvent, Return, ToolCallPart from .tools import DeferredToolRequests, ObjectJsonSchema, RunContext, ToolDefinition __all__ = ( @@ -37,6 +37,9 @@ OutputDataT = TypeVar('OutputDataT', default=str, covariant=True) """Covariant type variable for the output data type of a run.""" +OutputCustomEventDataT = TypeVar('OutputCustomEventDataT', default=object, covariant=True) +"""Covariant type variable for the data type of a custom event.""" + OutputMode = Literal['text', 'tool', 'native', 'prompted', 'tool_or_text', 'image', 'auto'] """All output modes. @@ -46,9 +49,25 @@ StructuredOutputMode = Literal['tool', 'native', 'prompted'] """Output modes that can be used for structured output. Used by ModelProfile.default_structured_output_mode""" +OutputFunction = TypeAliasType( + 'OutputFunction', + Callable[..., AsyncIterator[Return[T_co] | CustomEvent[OutputCustomEventDataT]]] + | Callable[..., AsyncIterator[Return[T_co] | OutputCustomEventDataT]] + | Callable[..., Awaitable[T_co]] + | Callable[..., T_co], + type_params=(T_co, OutputCustomEventDataT), +) +"""Definition of an output function. + +You should not need to import or use this type directly. + +See [output docs](../output.md) for more information. +""" OutputTypeOrFunction = TypeAliasType( - 'OutputTypeOrFunction', type[T_co] | Callable[..., Awaitable[T_co] | T_co], type_params=(T_co,) + 'OutputTypeOrFunction', + OutputFunction[T_co, OutputCustomEventDataT] | type[T_co], + type_params=(T_co, OutputCustomEventDataT), ) """Definition of an output type or function. @@ -60,8 +79,15 @@ TextOutputFunc = TypeAliasType( 'TextOutputFunc', - Callable[[RunContext, str], Awaitable[T_co] | T_co] | Callable[[str], Awaitable[T_co] | T_co], - type_params=(T_co,), + Callable[[RunContext, str], AsyncIterator[Return[T_co] | CustomEvent[OutputCustomEventDataT]]] + | Callable[[RunContext, str], AsyncIterator[Return[T_co] | OutputCustomEventDataT]] + | Callable[[RunContext, str], Awaitable[T_co]] + | Callable[[RunContext, str], T_co] + | Callable[[str], AsyncIterator[Return[T_co] | CustomEvent[OutputCustomEventDataT]]] + | Callable[[str], AsyncIterator[Return[T_co] | OutputCustomEventDataT]] + | Callable[[str], Awaitable[T_co]] + | Callable[[str], T_co], + type_params=(T_co, OutputCustomEventDataT), ) """Definition of a function that will be called to process the model's plain text output. The function must take a single string argument. @@ -72,7 +98,7 @@ @dataclass(init=False) -class ToolOutput(Generic[OutputDataT]): +class ToolOutput(Generic[OutputDataT, OutputCustomEventDataT]): """Marker class to use a tool for output and optionally customize the tool. Example: @@ -105,7 +131,7 @@ class Vehicle(BaseModel): ``` """ - output: OutputTypeOrFunction[OutputDataT] + output: OutputTypeOrFunction[OutputDataT, OutputCustomEventDataT] """An output type or function.""" name: str | None """The name of the tool that will be passed to the model. If not specified and only one output is provided, `final_result` will be used. If multiple outputs are provided, the name of the output type or function will be added to the tool name.""" @@ -118,7 +144,7 @@ class Vehicle(BaseModel): def __init__( self, - type_: OutputTypeOrFunction[OutputDataT], + type_: OutputTypeOrFunction[OutputDataT, OutputCustomEventDataT], *, name: str | None = None, description: str | None = None, @@ -133,7 +159,7 @@ def __init__( @dataclass(init=False) -class NativeOutput(Generic[OutputDataT]): +class NativeOutput(Generic[OutputDataT, OutputCustomEventDataT]): """Marker class to use the model's native structured outputs functionality for outputs and optionally customize the name and description. Example: @@ -156,7 +182,10 @@ class NativeOutput(Generic[OutputDataT]): ``` """ - outputs: OutputTypeOrFunction[OutputDataT] | Sequence[OutputTypeOrFunction[OutputDataT]] + outputs: ( + OutputTypeOrFunction[OutputDataT, OutputCustomEventDataT] + | Sequence[OutputTypeOrFunction[OutputDataT, OutputCustomEventDataT]] + ) """The output types or functions.""" name: str | None """The name of the structured output that will be passed to the model. If not specified and only one output is provided, the name of the output type or function will be used.""" @@ -167,7 +196,8 @@ class NativeOutput(Generic[OutputDataT]): def __init__( self, - outputs: OutputTypeOrFunction[OutputDataT] | Sequence[OutputTypeOrFunction[OutputDataT]], + outputs: OutputTypeOrFunction[OutputDataT, OutputCustomEventDataT] + | Sequence[OutputTypeOrFunction[OutputDataT, OutputCustomEventDataT]], *, name: str | None = None, description: str | None = None, @@ -180,7 +210,7 @@ def __init__( @dataclass(init=False) -class PromptedOutput(Generic[OutputDataT]): +class PromptedOutput(Generic[OutputDataT, OutputCustomEventDataT]): """Marker class to use a prompt to tell the model what to output and optionally customize the prompt. Example: @@ -222,7 +252,10 @@ class Device(BaseModel): ``` """ - outputs: OutputTypeOrFunction[OutputDataT] | Sequence[OutputTypeOrFunction[OutputDataT]] + outputs: ( + OutputTypeOrFunction[OutputDataT, OutputCustomEventDataT] + | Sequence[OutputTypeOrFunction[OutputDataT, OutputCustomEventDataT]] + ) """The output types or functions.""" name: str | None """The name of the structured output that will be passed to the model. If not specified and only one output is provided, the name of the output type or function will be used.""" @@ -236,7 +269,8 @@ class Device(BaseModel): def __init__( self, - outputs: OutputTypeOrFunction[OutputDataT] | Sequence[OutputTypeOrFunction[OutputDataT]], + outputs: OutputTypeOrFunction[OutputDataT, OutputCustomEventDataT] + | Sequence[OutputTypeOrFunction[OutputDataT, OutputCustomEventDataT]], *, name: str | None = None, description: str | None = None, @@ -259,7 +293,7 @@ class OutputObjectDefinition: @dataclass -class TextOutput(Generic[OutputDataT]): +class TextOutput(Generic[OutputDataT, OutputCustomEventDataT]): """Marker class to use text output for an output function taking a string argument. Example: @@ -281,7 +315,7 @@ def split_into_words(text: str) -> list[str]: ``` """ - output_function: TextOutputFunc[OutputDataT] + output_function: TextOutputFunc[OutputDataT, OutputCustomEventDataT] """The function that will be called to process the model's plain text output. The function must take a single string argument.""" @@ -354,14 +388,18 @@ def __get_pydantic_json_schema__( _OutputSpecItem = TypeAliasType( '_OutputSpecItem', - OutputTypeOrFunction[T_co] | ToolOutput[T_co] | NativeOutput[T_co] | PromptedOutput[T_co] | TextOutput[T_co], - type_params=(T_co,), + OutputTypeOrFunction[T_co, OutputCustomEventDataT] + | ToolOutput[T_co, OutputCustomEventDataT] + | NativeOutput[T_co, OutputCustomEventDataT] + | PromptedOutput[T_co, OutputCustomEventDataT] + | TextOutput[T_co, OutputCustomEventDataT], + type_params=(T_co, OutputCustomEventDataT), ) OutputSpec = TypeAliasType( 'OutputSpec', - _OutputSpecItem[T_co] | Sequence['OutputSpec[T_co]'], - type_params=(T_co,), + _OutputSpecItem[T_co, OutputCustomEventDataT] | Sequence['OutputSpec[T_co, OutputCustomEventDataT]'], + type_params=(T_co, OutputCustomEventDataT), ) """Specification of the agent's output data. diff --git a/pydantic_ai_slim/pydantic_ai/tools.py b/pydantic_ai_slim/pydantic_ai/tools.py index da053a5191..8862bc9f85 100644 --- a/pydantic_ai_slim/pydantic_ai/tools.py +++ b/pydantic_ai_slim/pydantic_ai/tools.py @@ -1,18 +1,18 @@ from __future__ import annotations as _annotations -from collections.abc import Awaitable, Callable, Sequence +from collections.abc import AsyncIterator, Awaitable, Callable, Sequence from dataclasses import KW_ONLY, dataclass, field from typing import Annotated, Any, Concatenate, Generic, Literal, TypeAlias, cast from pydantic import Discriminator, Tag from pydantic.json_schema import GenerateJsonSchema, JsonSchemaValue from pydantic_core import SchemaValidator, core_schema -from typing_extensions import ParamSpec, Self, TypeVar +from typing_extensions import ParamSpec, Self, TypeAliasType, TypeVar from . import _function_schema, _utils from ._run_context import AgentDepsT, RunContext from .exceptions import ModelRetry -from .messages import RetryPromptPart, ToolCallPart, ToolReturn +from .messages import CustomEvent, RetryPromptPart, Return, ToolCallPart, ToolReturn __all__ = ( 'AgentDepsT', @@ -38,6 +38,9 @@ ToolParams = ParamSpec('ToolParams', default=...) """Retrieval function param spec.""" +ToolCustomEventDataT = TypeVar('ToolCustomEventDataT', default=object, covariant=True) +"""Covariant type variable for the data type of a custom event for a tool.""" + SystemPromptFunc: TypeAlias = ( Callable[[RunContext[AgentDepsT]], str] | Callable[[RunContext[AgentDepsT]], Awaitable[str]] @@ -49,17 +52,43 @@ Usage `SystemPromptFunc[AgentDepsT]`. """ -ToolFuncContext: TypeAlias = Callable[Concatenate[RunContext[AgentDepsT], ToolParams], Any] +ToolFuncContext = TypeAliasType( + 'ToolFuncContext', + Callable[ + Concatenate[RunContext[AgentDepsT], ToolParams], + AsyncIterator[Return[Any] | CustomEvent[ToolCustomEventDataT]], + ] + | Callable[ + Concatenate[RunContext[AgentDepsT], ToolParams], + AsyncIterator[Return[Any] | ToolCustomEventDataT], + ] + | Callable[Concatenate[RunContext[AgentDepsT], ToolParams], Awaitable[Any]] + | Callable[Concatenate[RunContext[AgentDepsT], ToolParams], Any], + type_params=(AgentDepsT, ToolParams, ToolCustomEventDataT), +) """A tool function that takes `RunContext` as the first argument. Usage `ToolContextFunc[AgentDepsT, ToolParams]`. """ -ToolFuncPlain: TypeAlias = Callable[ToolParams, Any] +ToolFuncPlain = TypeAliasType( + 'ToolFuncPlain', + Callable[ToolParams, AsyncIterator[Return[Any] | CustomEvent[ToolCustomEventDataT]]] + | Callable[ + ToolParams, AsyncIterator[Return[Any] | ToolCustomEventDataT] + ] # TODO (DouweM): Drop one of these, make tool_stream + | Callable[ToolParams, Awaitable[Any]] + | Callable[ToolParams, Any], + type_params=(ToolParams, ToolCustomEventDataT), +) """A tool function that does not take `RunContext` as the first argument. Usage `ToolPlainFunc[ToolParams]`. """ -ToolFuncEither: TypeAlias = ToolFuncContext[AgentDepsT, ToolParams] | ToolFuncPlain[ToolParams] +ToolFuncEither = TypeAliasType( + 'ToolFuncEither', + ToolFuncContext[AgentDepsT, ToolParams, ToolCustomEventDataT] | ToolFuncPlain[ToolParams, ToolCustomEventDataT], + type_params=(AgentDepsT, ToolParams, ToolCustomEventDataT), +) """Either kind of tool function. This is just a union of [`ToolFuncContext`][pydantic_ai.tools.ToolFuncContext] and @@ -243,12 +272,15 @@ def _named_required_fields_schema(self, named_required_fields: Sequence[tuple[st ToolAgentDepsT = TypeVar('ToolAgentDepsT', default=object, contravariant=True) """Type variable for agent dependencies for a tool.""" +ToolCustomEventDataT = TypeVar('ToolCustomEventDataT', default=object, covariant=True) +"""Covariant type variable for the data type of a custom event for a tool.""" + @dataclass(init=False) -class Tool(Generic[ToolAgentDepsT]): +class Tool(Generic[ToolAgentDepsT, ToolCustomEventDataT]): """A tool function for an agent.""" - function: ToolFuncEither[ToolAgentDepsT] + function: ToolFuncEither[ToolAgentDepsT, ..., ToolCustomEventDataT] takes_ctx: bool max_retries: int | None name: str @@ -269,7 +301,7 @@ class Tool(Generic[ToolAgentDepsT]): def __init__( self, - function: ToolFuncEither[ToolAgentDepsT], + function: ToolFuncEither[ToolAgentDepsT, ..., ToolCustomEventDataT], *, takes_ctx: bool | None = None, max_retries: int | None = None, @@ -395,6 +427,7 @@ def from_schema( json_schema=json_schema, takes_ctx=takes_ctx, is_async=_utils.is_async_callable(function), + is_async_iterator=_utils.is_async_iterator_callable(function), ) return cls( diff --git a/pydantic_ai_slim/pydantic_ai/toolsets/function.py b/pydantic_ai_slim/pydantic_ai/toolsets/function.py index e185ed0273..2b4d854e8f 100644 --- a/pydantic_ai_slim/pydantic_ai/toolsets/function.py +++ b/pydantic_ai_slim/pydantic_ai/toolsets/function.py @@ -42,6 +42,7 @@ class FunctionToolset(AbstractToolset[AgentDepsT]): def __init__( self, + # TODO (DouweM): Use CustomEventDataT here? tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] = [], *, max_retries: int = 1, @@ -339,7 +340,7 @@ async def get_tools(self, ctx: RunContext[AgentDepsT]) -> dict[str, ToolsetTool[ max_retries=max_retries, args_validator=tool.function_schema.validator, call_func=tool.function_schema.call, - is_async=tool.function_schema.is_async, + is_async=tool.function_schema.is_async or tool.function_schema.is_async_iterator, ) return tools diff --git a/pydantic_ai_slim/pydantic_ai/ui/_adapter.py b/pydantic_ai_slim/pydantic_ai/ui/_adapter.py index 970f06e6ef..c71e18dda5 100644 --- a/pydantic_ai_slim/pydantic_ai/ui/_adapter.py +++ b/pydantic_ai_slim/pydantic_ai/ui/_adapter.py @@ -12,6 +12,7 @@ Generic, Protocol, TypeVar, + overload, runtime_checkable, ) @@ -19,10 +20,10 @@ from pydantic_ai import DeferredToolRequests, DeferredToolResults from pydantic_ai.agent import AbstractAgent -from pydantic_ai.agent.abstract import Instructions +from pydantic_ai.agent.abstract import Instructions, RunOutputDataT from pydantic_ai.builtin_tools import AbstractBuiltinTool from pydantic_ai.exceptions import UserError -from pydantic_ai.messages import ModelMessage +from pydantic_ai.messages import CustomEventDataT, ModelMessage from pydantic_ai.models import KnownModelName, Model from pydantic_ai.output import OutputDataT, OutputSpec from pydantic_ai.settings import ModelSettings @@ -166,7 +167,7 @@ def state(self) -> dict[str, Any] | None: def transform_stream( self, - stream: AsyncIterator[NativeEvent], + stream: AsyncIterator[NativeEvent[OutputDataT]], on_complete: OnCompleteFunc[EventT] | None = None, ) -> AsyncIterator[EventT]: """Transform a stream of Pydantic AI events into protocol-specific events. @@ -194,10 +195,11 @@ def streaming_response(self, stream: AsyncIterator[EventT]) -> StreamingResponse """ return self.build_event_stream().streaming_response(stream) + @overload def run_stream_native( self, *, - output_type: OutputSpec[Any] | None = None, + output_type: None = None, message_history: Sequence[ModelMessage] | None = None, deferred_tool_results: DeferredToolResults | None = None, model: Model | KnownModelName | str | None = None, @@ -209,7 +211,42 @@ def run_stream_native( infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, builtin_tools: Sequence[AbstractBuiltinTool] | None = None, - ) -> AsyncIterator[NativeEvent]: + ) -> AsyncIterator[NativeEvent[OutputDataT, CustomEventDataT]]: ... + + @overload + def run_stream_native( + self, + *, + output_type: OutputSpec[RunOutputDataT, CustomEventDataT], + message_history: Sequence[ModelMessage] | None = None, + deferred_tool_results: DeferredToolResults | None = None, + model: Model | KnownModelName | str | None = None, + instructions: Instructions[AgentDepsT] = None, + deps: AgentDepsT = None, + model_settings: ModelSettings | None = None, + usage_limits: UsageLimits | None = None, + usage: RunUsage | None = None, + infer_name: bool = True, + toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, + builtin_tools: Sequence[AbstractBuiltinTool] | None = None, + ) -> AsyncIterator[NativeEvent[RunOutputDataT, CustomEventDataT]]: ... + + def run_stream_native( + self, + *, + output_type: OutputSpec[Any, CustomEventDataT] | None = None, + message_history: Sequence[ModelMessage] | None = None, + deferred_tool_results: DeferredToolResults | None = None, + model: Model | KnownModelName | str | None = None, + instructions: Instructions[AgentDepsT] = None, + deps: AgentDepsT = None, + model_settings: ModelSettings | None = None, + usage_limits: UsageLimits | None = None, + usage: RunUsage | None = None, + infer_name: bool = True, + toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, + builtin_tools: Sequence[AbstractBuiltinTool] | None = None, + ) -> AsyncIterator[NativeEvent[Any, CustomEventDataT]]: """Run the agent with the protocol-specific run input and stream Pydantic AI events. Args: @@ -265,7 +302,7 @@ def run_stream_native( def run_stream( self, *, - output_type: OutputSpec[Any] | None = None, + output_type: OutputSpec[Any, Any] | None = None, message_history: Sequence[ModelMessage] | None = None, deferred_tool_results: DeferredToolResults | None = None, model: Model | KnownModelName | str | None = None, @@ -327,7 +364,7 @@ async def dispatch_request( model: Model | KnownModelName | str | None = None, instructions: Instructions[AgentDepsT] = None, deps: AgentDepsT = None, - output_type: OutputSpec[Any] | None = None, + output_type: OutputSpec[Any, Any] | None = None, model_settings: ModelSettings | None = None, usage_limits: UsageLimits | None = None, usage: RunUsage | None = None, diff --git a/pydantic_ai_slim/pydantic_ai/ui/_event_stream.py b/pydantic_ai_slim/pydantic_ai/ui/_event_stream.py index 391cf06f2f..2046554104 100644 --- a/pydantic_ai_slim/pydantic_ai/ui/_event_stream.py +++ b/pydantic_ai_slim/pydantic_ai/ui/_event_stream.py @@ -7,6 +7,8 @@ from typing import TYPE_CHECKING, Any, Generic, Literal, TypeAlias, TypeVar, cast from uuid import uuid4 +from typing_extensions import TypeAliasType + from pydantic_ai import _utils from ..messages import ( @@ -15,6 +17,8 @@ BuiltinToolCallPart, BuiltinToolResultEvent, # pyright: ignore[reportDeprecated] BuiltinToolReturnPart, + CustomEvent, + CustomEventDataT, FilePart, FinalResultEvent, FunctionToolCallEvent, @@ -47,7 +51,11 @@ RunInputT = TypeVar('RunInputT') """Type variable for protocol-specific run input types.""" -NativeEvent: TypeAlias = AgentStreamEvent | AgentRunResultEvent[Any] +NativeEvent = TypeAliasType( + 'NativeEvent', + AgentStreamEvent[CustomEventDataT] | AgentRunResultEvent[OutputDataT], + type_params=(OutputDataT, CustomEventDataT), +) """Type alias for the native event type, which is either an `AgentStreamEvent` or an `AgentRunResultEvent`.""" OnCompleteFunc: TypeAlias = ( @@ -124,7 +132,7 @@ def streaming_response(self, stream: AsyncIterator[EventT]) -> StreamingResponse ) async def transform_stream( # noqa: C901 - self, stream: AsyncIterator[NativeEvent], on_complete: OnCompleteFunc[EventT] | None = None + self, stream: AsyncIterator[NativeEvent[OutputDataT]], on_complete: OnCompleteFunc[EventT] | None = None ) -> AsyncIterator[EventT]: """Transform a stream of Pydantic AI events into protocol-specific events. @@ -229,7 +237,7 @@ async def _turn_to(self, to_turn: Literal['request', 'response'] | None) -> Asyn async for e in self.before_response(): yield e - async def handle_event(self, event: NativeEvent) -> AsyncIterator[EventT]: + async def handle_event(self, event: NativeEvent[OutputDataT]) -> AsyncIterator[EventT]: # noqa: C901 """Transform a Pydantic AI event into one or more protocol-specific events. This method dispatches to specific `handle_*` methods based on event type: @@ -240,6 +248,7 @@ async def handle_event(self, event: NativeEvent) -> AsyncIterator[EventT]: - [`FinalResultEvent`][pydantic_ai.messages.FinalResultEvent] -> `handle_final_result` - [`FunctionToolCallEvent`][pydantic_ai.messages.FunctionToolCallEvent] -> `handle_function_tool_call` - [`FunctionToolResultEvent`][pydantic_ai.messages.FunctionToolResultEvent] -> `handle_function_tool_result` + - [`CustomEvent`][pydantic_ai.messages.CustomEvent] -> `handle_custom_event` - [`AgentRunResultEvent`][pydantic_ai.run.AgentRunResultEvent] -> `handle_run_result` Subclasses are encouraged to override the individual `handle_*` methods rather than this one. @@ -264,6 +273,9 @@ async def handle_event(self, event: NativeEvent) -> AsyncIterator[EventT]: case FunctionToolResultEvent(): async for e in self.handle_function_tool_result(event): yield e + case CustomEvent(): + async for e in self.handle_custom_event(event): + yield e case AgentRunResultEvent(): async for e in self.handle_run_result(event): yield e @@ -581,6 +593,15 @@ async def handle_function_tool_result(self, event: FunctionToolResultEvent) -> A return # pragma: no cover yield # Make this an async generator + async def handle_custom_event(self, event: CustomEvent) -> AsyncIterator[EventT]: + """Handle a `CustomEvent`. + + Args: + event: The custom event. + """ + return + yield # Make this an async generator + async def handle_run_result(self, event: AgentRunResultEvent) -> AsyncIterator[EventT]: """Handle an `AgentRunResultEvent`. diff --git a/pydantic_ai_slim/pydantic_ai/ui/ag_ui/_event_stream.py b/pydantic_ai_slim/pydantic_ai/ui/ag_ui/_event_stream.py index 2b37d36351..6bd737fb8d 100644 --- a/pydantic_ai_slim/pydantic_ai/ui/ag_ui/_event_stream.py +++ b/pydantic_ai_slim/pydantic_ai/ui/ag_ui/_event_stream.py @@ -14,6 +14,7 @@ from ...messages import ( BuiltinToolCallPart, BuiltinToolReturnPart, + CustomEvent, FunctionToolResultEvent, RetryPromptPart, TextPart, @@ -31,6 +32,7 @@ try: from ag_ui.core import ( BaseEvent, + CustomEvent as AGUICustomEvent, EventType, RunAgentInput, RunErrorEvent, @@ -234,3 +236,12 @@ async def handle_function_tool_result(self, event: FunctionToolResultEvent) -> A for item in possible_event: # type: ignore[reportUnknownMemberType] if isinstance(item, BaseEvent): # pragma: no branch yield item + + async def handle_custom_event(self, event: CustomEvent) -> AsyncIterator[BaseEvent]: + if isinstance(event.data, BaseEvent): + yield event.data + elif event.name: + data = event.data + if event.tool_call_id: + data = {'tool_call_id': event.tool_call_id, 'data': data} + yield AGUICustomEvent(name=event.name, value=data) diff --git a/pydantic_ai_slim/pydantic_ai/ui/ag_ui/app.py b/pydantic_ai_slim/pydantic_ai/ui/ag_ui/app.py index 1f0fbe5262..ba89bc528e 100644 --- a/pydantic_ai_slim/pydantic_ai/ui/ag_ui/app.py +++ b/pydantic_ai_slim/pydantic_ai/ui/ag_ui/app.py @@ -11,7 +11,7 @@ from pydantic_ai import DeferredToolResults from pydantic_ai.agent import AbstractAgent from pydantic_ai.builtin_tools import AbstractBuiltinTool -from pydantic_ai.messages import ModelMessage +from pydantic_ai.messages import CustomEventDataT, ModelMessage from pydantic_ai.models import KnownModelName, Model from pydantic_ai.output import OutputDataT, OutputSpec from pydantic_ai.settings import ModelSettings @@ -41,10 +41,10 @@ class AGUIApp(Generic[AgentDepsT, OutputDataT], Starlette): def __init__( self, - agent: AbstractAgent[AgentDepsT, OutputDataT], + agent: AbstractAgent[AgentDepsT, OutputDataT, CustomEventDataT], *, # AGUIAdapter.dispatch_request parameters - output_type: OutputSpec[Any] | None = None, + output_type: OutputSpec[Any, CustomEventDataT] | None = None, message_history: Sequence[ModelMessage] | None = None, deferred_tool_results: DeferredToolResults | None = None, model: Model | KnownModelName | str | None = None, diff --git a/pydantic_ai_slim/pydantic_ai/ui/vercel_ai/_event_stream.py b/pydantic_ai_slim/pydantic_ai/ui/vercel_ai/_event_stream.py index b3a0e79f5c..c4032a9563 100644 --- a/pydantic_ai_slim/pydantic_ai/ui/vercel_ai/_event_stream.py +++ b/pydantic_ai_slim/pydantic_ai/ui/vercel_ai/_event_stream.py @@ -11,6 +11,7 @@ from ...messages import ( BuiltinToolCallPart, BuiltinToolReturnPart, + CustomEvent, FilePart, FunctionToolResultEvent, RetryPromptPart, @@ -27,6 +28,7 @@ from .request_types import RequestData from .response_types import ( BaseChunk, + DataChunk, DoneChunk, ErrorChunk, FileChunk, @@ -185,3 +187,12 @@ async def handle_function_tool_result(self, event: FunctionToolResultEvent) -> A yield ToolOutputAvailableChunk(tool_call_id=result.tool_call_id, output=result.content) # ToolCallResultEvent.content may hold user parts (e.g. text, images) that Vercel AI does not currently have events for + + async def handle_custom_event(self, event: CustomEvent) -> AsyncIterator[BaseChunk]: + if isinstance(event.data, BaseChunk): + yield event.data + elif event.name: + data = event.data + if event.tool_call_id: + data = {'tool_call_id': event.tool_call_id, 'data': data} + yield DataChunk(type=f'data-{event.name}', data=data) diff --git a/tests/cassettes/test_temporal/test_temporal_agent_sync_tool_activity_disabled.yaml b/tests/cassettes/test_temporal/test_temporal_agent_sync_tool_activity_disabled.yaml index 55e74d9003..95ce111556 100644 --- a/tests/cassettes/test_temporal/test_temporal_agent_sync_tool_activity_disabled.yaml +++ b/tests/cassettes/test_temporal/test_temporal_agent_sync_tool_activity_disabled.yaml @@ -8,7 +8,7 @@ interactions: connection: - keep-alive content-length: - - '346' + - '354' content-type: - application/json host: @@ -24,7 +24,7 @@ interactions: tools: - function: description: '' - name: get_weather + name: get_weather_in_city parameters: additionalProperties: false properties: @@ -45,13 +45,13 @@ interactions: connection: - keep-alive content-length: - - '1085' + - '1093' content-type: - application/json openai-organization: - pydantic-28gund openai-processing-ms: - - '806' + - '743' openai-project: - proj_dKobscVY9YJxeEaDJen54e3d openai-version: @@ -73,27 +73,27 @@ interactions: tool_calls: - function: arguments: '{"city":"Mexico City"}' - name: get_weather - id: call_MOtXZsU6lfOmXwoBOtXKpCth + name: get_weather_in_city + id: call_K6neuEWdlMH2GNNSE72aDwBH type: function - created: 1754922581 - id: chatcmpl-C3NorJhLoxFbVB2hq1bUQDmkI1TEf + created: 1762473882 + id: chatcmpl-CZ4G27JasXrHKU1cIU7EygnodNFUQ model: gpt-4o-2024-08-06 object: chat.completion service_tier: default - system_fingerprint: fp_07871e2ad8 + system_fingerprint: fp_cbf1785567 usage: - completion_tokens: 15 + completion_tokens: 17 completion_tokens_details: accepted_prediction_tokens: 0 audio_tokens: 0 reasoning_tokens: 0 rejected_prediction_tokens: 0 - prompt_tokens: 45 + prompt_tokens: 47 prompt_tokens_details: audio_tokens: 0 cached_tokens: 0 - total_tokens: 60 + total_tokens: 64 status: code: 200 message: OK diff --git a/tests/test_ag_ui.py b/tests/test_ag_ui.py index 05071d2259..0662ebe5e5 100644 --- a/tests/test_ag_ui.py +++ b/tests/test_ag_ui.py @@ -199,6 +199,20 @@ async def send_custom() -> ToolReturn: ) +async def yield_custom() -> AsyncIterator[CustomEvent | ToolReturn]: + yield CustomEvent( + type=EventType.CUSTOM, + name='custom_event1', + value={'key1': 'value1'}, + ) + yield CustomEvent( + type=EventType.CUSTOM, + name='custom_event2', + value={'key2': 'value2'}, + ) + yield ToolReturn('Done') + + def uuid_str() -> str: """Generate a random UUID string.""" return uuid.uuid4().hex @@ -833,6 +847,73 @@ async def stream_function( ) +async def test_tool_local_yield_events() -> None: + """Test local tool call that yields multiple events.""" + + async def stream_function( + messages: list[ModelMessage], agent_info: AgentInfo + ) -> AsyncIterator[DeltaToolCalls | str]: + if len(messages) == 1: + # First call - make a tool call + yield {0: DeltaToolCall(name='yield_custom')} + yield {0: DeltaToolCall(json_args='{}')} + else: + # Second call - return text result + yield 'success yield_custom called' + + agent = Agent( + model=FunctionModel(stream_function=stream_function), + tools=[yield_custom], + ) + + run_input = create_input( + UserMessage( + id='msg_1', + content='Please call yield_custom', + ), + ) + events = await run_and_collect_events(agent, run_input) + + assert events == snapshot( + [ + { + 'type': 'RUN_STARTED', + 'threadId': (thread_id := IsSameStr()), + 'runId': (run_id := IsSameStr()), + }, + { + 'type': 'TOOL_CALL_START', + 'toolCallId': (tool_call_id := IsSameStr()), + 'toolCallName': 'yield_custom', + 'parentMessageId': IsStr(), + }, + {'type': 'TOOL_CALL_ARGS', 'toolCallId': tool_call_id, 'delta': '{}'}, + {'type': 'TOOL_CALL_END', 'toolCallId': tool_call_id}, + {'type': 'CUSTOM', 'name': 'custom_event1', 'value': {'key1': 'value1'}}, + {'type': 'CUSTOM', 'name': 'custom_event2', 'value': {'key2': 'value2'}}, + { + 'type': 'TOOL_CALL_RESULT', + 'messageId': IsStr(), + 'toolCallId': tool_call_id, + 'content': 'Done', + 'role': 'tool', + }, + {'type': 'TEXT_MESSAGE_START', 'messageId': (message_id := IsSameStr()), 'role': 'assistant'}, + { + 'type': 'TEXT_MESSAGE_CONTENT', + 'messageId': message_id, + 'delta': 'success yield_custom called', + }, + {'type': 'TEXT_MESSAGE_END', 'messageId': message_id}, + { + 'type': 'RUN_FINISHED', + 'threadId': thread_id, + 'runId': run_id, + }, + ] + ) + + async def test_tool_local_parts() -> None: """Test local tool call with streaming/parts.""" diff --git a/tests/test_agent.py b/tests/test_agent.py index 0a6bf1e325..b528c66944 100644 --- a/tests/test_agent.py +++ b/tests/test_agent.py @@ -19,12 +19,17 @@ from pydantic_ai import ( AbstractToolset, Agent, + AgentRunResultEvent, AgentStreamEvent, AudioUrl, BinaryContent, BinaryImage, CombinedToolset, + CustomEvent, DocumentUrl, + FinalResultEvent, + FunctionToolCallEvent, + FunctionToolResultEvent, FunctionToolset, ImageUrl, IncompleteToolCall, @@ -35,11 +40,16 @@ ModelResponse, ModelResponsePart, ModelRetry, + PartDeltaEvent, + PartEndEvent, + PartStartEvent, PrefixedToolset, RetryPromptPart, + Return, RunContext, SystemPromptPart, TextPart, + TextPartDelta, ToolCallPart, ToolReturn, ToolReturnPart, @@ -3791,7 +3801,6 @@ def get_image() -> BinaryContent: BinaryContent( data=b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x01\x00\x00\x00\x01\x08\x02\x00\x00\x00\x90wS\xde\x00\x00\x00\x0cIDATx\x9cc```\x00\x00\x00\x04\x00\x01\xf6\x178\x00\x00\x00\x00IEND\xaeB`\x82', media_type='image/png', - _identifier='image_id_1', ), ], timestamp=IsNow(tz=timezone.utc), @@ -4369,7 +4378,7 @@ def analyze_data() -> list[Any]: with pytest.raises( UserError, - match="The return value of tool 'analyze_data' contains invalid nested `ToolReturn` objects. `ToolReturn` should be used directly.", + match="The return value of tool 'analyze_data' contains invalid nested `Return` objects. `Return` should be used directly.", ): agent.run_sync('Please analyze the data') @@ -5878,3 +5887,331 @@ def llm(messages: list[ModelMessage], _info: AgentInfo) -> ModelResponse: ] ) assert run.all_messages_json().startswith(b'[{"parts":[{"content":"Hello",') + + +async def test_agent_custom_events(): + agent = Agent('test') + + @agent.tool_plain + async def roll_dice() -> AsyncIterator[str | ToolReturn[int]]: + yield 'Considering 1...' + yield 'Considering 2...' + yield 'Considering 3...' + yield 'Considering 4...' + yield 'Considering 5...' + yield 'Considering 6...' + yield 'Choosing 4...' + yield ToolReturn(return_value=4) + + events = [event async for event in agent.run_stream_events('Roll me a dice.')] + assert events == snapshot( + [ + PartStartEvent( + index=0, + part=ToolCallPart(tool_name='roll_dice', args={}, tool_call_id='pyd_ai_tool_call_id__roll_dice'), + ), + PartEndEvent( + index=0, + part=ToolCallPart(tool_name='roll_dice', args={}, tool_call_id='pyd_ai_tool_call_id__roll_dice'), + ), + FunctionToolCallEvent( + part=ToolCallPart(tool_name='roll_dice', args={}, tool_call_id='pyd_ai_tool_call_id__roll_dice') + ), + CustomEvent(data='Considering 1...', tool_call_id='pyd_ai_tool_call_id__roll_dice'), + CustomEvent(data='Considering 2...', tool_call_id='pyd_ai_tool_call_id__roll_dice'), + CustomEvent(data='Considering 3...', tool_call_id='pyd_ai_tool_call_id__roll_dice'), + CustomEvent(data='Considering 4...', tool_call_id='pyd_ai_tool_call_id__roll_dice'), + CustomEvent(data='Considering 5...', tool_call_id='pyd_ai_tool_call_id__roll_dice'), + CustomEvent(data='Considering 6...', tool_call_id='pyd_ai_tool_call_id__roll_dice'), + CustomEvent(data='Choosing 4...', tool_call_id='pyd_ai_tool_call_id__roll_dice'), + FunctionToolResultEvent( + result=ToolReturnPart( + tool_name='roll_dice', + content=4, + tool_call_id='pyd_ai_tool_call_id__roll_dice', + timestamp=IsDatetime(), + ) + ), + PartStartEvent(index=0, part=TextPart(content='')), + FinalResultEvent(tool_name=None, tool_call_id=None), + PartDeltaEvent(index=0, delta=TextPartDelta(content_delta='{"roll_')), + PartDeltaEvent(index=0, delta=TextPartDelta(content_delta='dice":4}')), + PartEndEvent(index=0, part=TextPart(content='{"roll_dice":4}')), + AgentRunResultEvent(result=AgentRunResult(output='{"roll_dice":4}')), + ] + ) + + +async def test_agent_custom_events_model_retry(): + agent = Agent('test') + + @agent.tool + async def roll_dice(ctx: RunContext) -> AsyncIterator[str | ToolReturn[int]]: + yield 'Considering 1...' + yield 'Considering 2...' + yield 'Considering 3...' + if ctx.retry == 0: + raise ModelRetry(message='Roll again.') + else: + yield ToolReturn(return_value=4) + + events = [event async for event in agent.run_stream_events('Roll me a dice.')] + assert events == snapshot( + [ + PartStartEvent( + index=0, + part=ToolCallPart(tool_name='roll_dice', args={}, tool_call_id='pyd_ai_tool_call_id__roll_dice'), + ), + PartEndEvent( + index=0, + part=ToolCallPart(tool_name='roll_dice', args={}, tool_call_id='pyd_ai_tool_call_id__roll_dice'), + ), + FunctionToolCallEvent( + part=ToolCallPart(tool_name='roll_dice', args={}, tool_call_id='pyd_ai_tool_call_id__roll_dice') + ), + CustomEvent(data='Considering 1...', tool_call_id='pyd_ai_tool_call_id__roll_dice'), + CustomEvent(data='Considering 2...', tool_call_id='pyd_ai_tool_call_id__roll_dice'), + CustomEvent(data='Considering 3...', tool_call_id='pyd_ai_tool_call_id__roll_dice'), + FunctionToolResultEvent( + result=RetryPromptPart( + content='Roll again.', + tool_name='roll_dice', + tool_call_id='pyd_ai_tool_call_id__roll_dice', + timestamp=IsDatetime(), + ) + ), + PartStartEvent( + index=0, + part=ToolCallPart(tool_name='roll_dice', args={}, tool_call_id=IsStr()), + ), + PartEndEvent( + index=0, + part=ToolCallPart(tool_name='roll_dice', args={}, tool_call_id=IsStr()), + ), + FunctionToolCallEvent(part=ToolCallPart(tool_name='roll_dice', args={}, tool_call_id=IsStr())), + CustomEvent(data='Considering 1...', tool_call_id=IsStr()), + CustomEvent(data='Considering 2...', tool_call_id=IsStr()), + CustomEvent(data='Considering 3...', tool_call_id=IsStr()), + FunctionToolResultEvent( + result=ToolReturnPart( + tool_name='roll_dice', + content=4, + tool_call_id=IsStr(), + timestamp=IsDatetime(), + ) + ), + PartStartEvent(index=0, part=TextPart(content='')), + FinalResultEvent(tool_name=None, tool_call_id=None), + PartDeltaEvent(index=0, delta=TextPartDelta(content_delta='{"roll_')), + PartDeltaEvent(index=0, delta=TextPartDelta(content_delta='dice":4}')), + PartEndEvent(index=0, part=TextPart(content='{"roll_dice":4}')), + AgentRunResultEvent(result=AgentRunResult(output='{"roll_dice":4}')), + ] + ) + + +async def test_agent_custom_events_tool_output_function(): + class Weather(BaseModel): + temperature: float + description: str + + async def get_weather(city: str) -> AsyncIterator[str | Return[Weather]]: + yield 'Getting weather...' + yield Return(Weather(temperature=28.7, description='sunny')) + + agent = Agent('test', output_type=ToolOutput(get_weather)) + + events: list[AgentStreamEvent] = [] + result: AgentRunResult[Weather] | None = None + async for event in agent.run_stream_events('What is the weather in Mexico City?'): + if isinstance(event, AgentRunResultEvent): + result = event.result + else: + events.append(event) + + assert result is not None + assert result.output == snapshot(Weather(temperature=28.7, description='sunny')) + assert events == snapshot( + [ + PartStartEvent( + index=0, + part=ToolCallPart( + tool_name='final_result', args={'city': 'a'}, tool_call_id='pyd_ai_tool_call_id__final_result' + ), + ), + FinalResultEvent(tool_name='final_result', tool_call_id='pyd_ai_tool_call_id__final_result'), + PartEndEvent( + index=0, + part=ToolCallPart( + tool_name='final_result', args={'city': 'a'}, tool_call_id='pyd_ai_tool_call_id__final_result' + ), + ), + CustomEvent(data='Getting weather...', tool_call_id='pyd_ai_tool_call_id__final_result'), + ] + ) + + +async def test_agent_custom_events_native_output_function(): + class Weather(BaseModel): + temperature: float + description: str + + async def get_weather(city: str) -> AsyncIterator[str | Return[Weather]]: + yield 'Getting weather...' + yield Return(Weather(temperature=28.7, description='sunny')) + + async def return_city(messages: list[ModelMessage], _info: AgentInfo) -> AsyncIterator[str]: + yield '{"city":' + yield ' "Mexico City"}' + + model = FunctionModel(stream_function=return_city) + + agent = Agent(model, output_type=NativeOutput(get_weather)) + + events: list[AgentStreamEvent] = [] + result: AgentRunResult[Weather] | None = None + async for event in agent.run_stream_events('What is the weather in Mexico City?'): + if isinstance(event, AgentRunResultEvent): + result = event.result + else: + events.append(event) + + assert result is not None + assert result.output == snapshot(Weather(temperature=28.7, description='sunny')) + assert events == snapshot( + [ + PartStartEvent( + index=0, + part=TextPart(content='{"city":'), + ), + FinalResultEvent(tool_name=None, tool_call_id=None), + PartDeltaEvent(index=0, delta=TextPartDelta(content_delta=' "Mexico City"}')), + PartEndEvent(index=0, part=TextPart(content='{"city": "Mexico City"}')), + CustomEvent(data='Getting weather...'), + ] + ) + + +async def test_agent_custom_events_prompted_output_function(): + class Weather(BaseModel): + temperature: float + description: str + + async def get_weather(city: str) -> AsyncIterator[str | Return[Weather]]: + yield 'Getting weather...' + yield Return(Weather(temperature=28.7, description='sunny')) + + async def return_city(messages: list[ModelMessage], _info: AgentInfo) -> AsyncIterator[str]: + yield '{"city":' + yield ' "Mexico City"}' + + model = FunctionModel(stream_function=return_city) + + agent = Agent(model, output_type=PromptedOutput(get_weather)) + + events: list[AgentStreamEvent] = [] + result: AgentRunResult[Weather] | None = None + async for event in agent.run_stream_events('What is the weather in Mexico City?'): + if isinstance(event, AgentRunResultEvent): + result = event.result + else: + events.append(event) + + assert result is not None + assert result.output == snapshot(Weather(temperature=28.7, description='sunny')) + assert events == snapshot( + [ + PartStartEvent( + index=0, + part=TextPart(content='{"city":'), + ), + FinalResultEvent(tool_name=None, tool_call_id=None), + PartDeltaEvent(index=0, delta=TextPartDelta(content_delta=' "Mexico City"}')), + PartEndEvent(index=0, part=TextPart(content='{"city": "Mexico City"}')), + CustomEvent(data='Getting weather...'), + ] + ) + + +async def test_agent_custom_events_text_output_function(): + class Weather(BaseModel): + temperature: float + description: str + + async def get_weather(city: str) -> AsyncIterator[str | Return[Weather]]: + yield 'Getting weather...' + yield Return(Weather(temperature=28.7, description='sunny')) + + async def return_city(messages: list[ModelMessage], _info: AgentInfo) -> AsyncIterator[str]: + yield 'Mexico' + yield ' City' + + model = FunctionModel(stream_function=return_city) + + agent = Agent(model, output_type=TextOutput(get_weather)) + + events: list[AgentStreamEvent] = [] + result: AgentRunResult[Weather] | None = None + async for event in agent.run_stream_events('What is the weather in Mexico City?'): + if isinstance(event, AgentRunResultEvent): + result = event.result + else: + events.append(event) + + assert result is not None + assert result.output == snapshot(Weather(temperature=28.7, description='sunny')) + assert events == snapshot( + [ + PartStartEvent( + index=0, + part=TextPart(content='Mexico'), + ), + FinalResultEvent(tool_name=None, tool_call_id=None), + PartDeltaEvent(index=0, delta=TextPartDelta(content_delta=' City')), + PartEndEvent(index=0, part=TextPart(content='Mexico City')), + CustomEvent(data='Getting weather...'), + ] + ) + + +async def test_agent_custom_events_run_stream(): + class Weather(BaseModel): + temperature: float + description: str + + async def get_weather(city: str) -> AsyncIterator[str | Return[Weather]]: + yield 'Getting weather...' + yield Return(Weather(temperature=28.7, description='sunny')) + + async def return_city(messages: list[ModelMessage], _info: AgentInfo) -> AsyncIterator[str]: + yield '{"city":' + yield ' "Mexico City"}' + + model = FunctionModel(stream_function=return_city) + + agent = Agent(model, output_type=NativeOutput(get_weather)) + + events: list[AgentStreamEvent] = [] + + async def event_stream_handler(ctx: RunContext, stream: AsyncIterable[AgentStreamEvent]): + async for event in stream: + events.append(event) + + outputs: list[Weather] = [] + async with agent.run_stream( + 'What is the weather in Mexico City?', event_stream_handler=event_stream_handler + ) as run: + async for output in run.stream_output(): + outputs.append(output) + + assert await run.get_output() == snapshot(Weather(temperature=28.7, description='sunny')) + assert outputs == snapshot([Weather(temperature=28.7, description='sunny')]) + assert events == snapshot( + [ + PartStartEvent( + index=0, + part=TextPart(content='{"city":'), + ), + FinalResultEvent(tool_name=None, tool_call_id=None), + ] + ) diff --git a/tests/test_dbos.py b/tests/test_dbos.py index 256aba83fb..29281f7ae3 100644 --- a/tests/test_dbos.py +++ b/tests/test_dbos.py @@ -266,7 +266,6 @@ async def test_complex_agent_run_in_workflow(allow_model_requests: None, dbos: D 'complex_agent__model.request_stream', 'event_stream_handler', 'event_stream_handler', - 'event_stream_handler', 'complex_agent__mcp_server__mcp.call_tool', 'event_stream_handler', 'complex_agent__mcp_server__mcp.get_tools', @@ -361,16 +360,9 @@ async def test_complex_agent_run_in_workflow(allow_model_requests: None, dbos: D content='running 2 tools', children=[ BasicSpan(content='running tool: get_country'), + BasicSpan(content='ctx.run_step=1'), BasicSpan( - content='event_stream_handler', - children=[ - BasicSpan(content='ctx.run_step=1'), - BasicSpan( - content=IsStr( - regex=r'{"result":{"tool_name":"get_country","content":"Mexico","tool_call_id":"call_3rqTYrA6H21AYUaRGP4F66oq","metadata":null,"timestamp":".+?","part_kind":"tool-return"},"content":null,"event_kind":"function_tool_result"}' - ) - ), - ], + content='{"result":{"tool_name":"get_country","content":"Mexico","tool_call_id":"call_3rqTYrA6H21AYUaRGP4F66oq","metadata":null,"timestamp":"2025-10-08T14:38:30.370338+00:00","part_kind":"tool-return"},"content":null,"event_kind":"function_tool_result"}' ), BasicSpan( content='running tool: get_product_name', diff --git a/tests/test_temporal.py b/tests/test_temporal.py index b3fe75d911..780951f2b6 100644 --- a/tests/test_temporal.py +++ b/tests/test_temporal.py @@ -17,6 +17,7 @@ AgentRunResultEvent, AgentStreamEvent, BinaryImage, + CustomEvent, ExternalToolset, FinalResultEvent, FunctionToolCallEvent, @@ -30,11 +31,13 @@ PartEndEvent, PartStartEvent, RetryPromptPart, + Return, RunContext, TextPart, TextPartDelta, ToolCallPart, ToolCallPartDelta, + ToolReturn, ToolReturnPart, UserPromptPart, WebSearchTool, @@ -230,11 +233,12 @@ class WeatherArgs(BaseModel): city: str -def get_weather(args: WeatherArgs) -> str: +async def get_weather(args: WeatherArgs) -> AsyncIterator[str | Return[str]]: + yield 'Getting weather...' if args.city == 'Mexico City': - return 'sunny' + yield ToolReturn('sunny', metadata={'temperature': 28.7}) else: - return 'unknown' # pragma: no cover + yield Return('unknown') # pragma: no cover @dataclass @@ -248,10 +252,15 @@ class Response: answers: list[Answer] +async def process_output(output: Response) -> AsyncIterator[str | Return[Response]]: + yield 'Processing output...' + yield Return(output) + + complex_agent = Agent( model, deps_type=Deps, - output_type=Response, + output_type=process_output, toolsets=[ FunctionToolset[Deps](tools=[get_country], id='country'), MCPServerStdio('python', ['-m', 'tests.mcp_server'], timeout=20, id='mcp'), @@ -423,25 +432,25 @@ async def test_complex_agent_run_in_workflow( ], ), BasicSpan( - content='running 2 tools', + content='StartActivity:agent__complex_agent__event_stream_handler', children=[ - BasicSpan(content='running tool: get_country'), BasicSpan( - content='StartActivity:agent__complex_agent__event_stream_handler', + content='RunActivity:agent__complex_agent__event_stream_handler', children=[ + BasicSpan(content='ctx.run_step=1'), BasicSpan( - content='RunActivity:agent__complex_agent__event_stream_handler', - children=[ - BasicSpan(content='ctx.run_step=1'), - BasicSpan( - content=IsStr( - regex=r'{"result":{"tool_name":"get_country","content":"Mexico","tool_call_id":"call_3rqTYrA6H21AYUaRGP4F66oq","metadata":null,"timestamp":".+?","part_kind":"tool-return"},"content":null,"event_kind":"function_tool_result"}' - ) - ), - ], - ) + content=IsStr( + regex=r'{"result":{"tool_name":"get_country","content":"Mexico","tool_call_id":"call_3rqTYrA6H21AYUaRGP4F66oq","metadata":null,"timestamp":".+?","part_kind":"tool-return"},"content":null,"event_kind":"function_tool_result"}' + ) + ), ], - ), + ) + ], + ), + BasicSpan( + content='running 2 tools', + children=[ + BasicSpan(content='running tool: get_country'), BasicSpan( content='running tool: get_product_name', children=[ @@ -455,22 +464,22 @@ async def test_complex_agent_run_in_workflow( ) ], ), + ], + ), + BasicSpan( + content='StartActivity:agent__complex_agent__event_stream_handler', + children=[ BasicSpan( - content='StartActivity:agent__complex_agent__event_stream_handler', + content='RunActivity:agent__complex_agent__event_stream_handler', children=[ + BasicSpan(content='ctx.run_step=1'), BasicSpan( - content='RunActivity:agent__complex_agent__event_stream_handler', - children=[ - BasicSpan(content='ctx.run_step=1'), - BasicSpan( - content=IsStr( - regex=r'{"result":{"tool_name":"get_product_name","content":"Pydantic AI","tool_call_id":"call_Xw9XMKBJU48kAAd78WgIswDx","metadata":null,"timestamp":".+?","part_kind":"tool-return"},"content":null,"event_kind":"function_tool_result"}' - ) - ), - ], - ) + content=IsStr( + regex=r'{"result":{"tool_name":"get_product_name","content":"Pydantic AI","tool_call_id":"call_Xw9XMKBJU48kAAd78WgIswDx","metadata":null,"timestamp":".+?","part_kind":"tool-return"},"content":null,"event_kind":"function_tool_result"}' + ) + ), ], - ), + ) ], ), BasicSpan( @@ -543,28 +552,34 @@ async def test_complex_agent_run_in_workflow( content='StartActivity:agent__complex_agent__toolset____call_tool', children=[ BasicSpan( - content='RunActivity:agent__complex_agent__toolset____call_tool' + content='RunActivity:agent__complex_agent__toolset____call_tool', + children=[ + BasicSpan(content='ctx.run_step=2'), + BasicSpan( + content='{"data":"Getting weather...","name":null,"tool_call_id":"call_Vz0Sie91Ap56nH0ThKGrZXT7","event_kind":"custom"}' + ), + ], ) ], ) ], - ), + ) + ], + ), + BasicSpan( + content='StartActivity:agent__complex_agent__event_stream_handler', + children=[ BasicSpan( - content='StartActivity:agent__complex_agent__event_stream_handler', + content='RunActivity:agent__complex_agent__event_stream_handler', children=[ + BasicSpan(content='ctx.run_step=2'), BasicSpan( - content='RunActivity:agent__complex_agent__event_stream_handler', - children=[ - BasicSpan(content='ctx.run_step=2'), - BasicSpan( - content=IsStr( - regex=r'{"result":{"tool_name":"get_weather","content":"sunny","tool_call_id":"call_Vz0Sie91Ap56nH0ThKGrZXT7","metadata":null,"timestamp":".+?","part_kind":"tool-return"},"content":null,"event_kind":"function_tool_result"}' - ) - ), - ], - ) + content=IsStr( + regex=r'{"result":{"tool_name":"get_weather","content":"sunny","tool_call_id":"call_Vz0Sie91Ap56nH0ThKGrZXT7","metadata":{"temperature":28.7},"timestamp":".+?","part_kind":"tool-return"},"content":null,"event_kind":"function_tool_result"}' + ) + ), ], - ), + ) ], ), BasicSpan( @@ -718,6 +733,21 @@ async def test_complex_agent_run_in_workflow( ) ], ), + BasicSpan(content='running output function: final_result'), + BasicSpan( + content='StartActivity:agent__complex_agent__event_stream_handler', + children=[ + BasicSpan( + content='RunActivity:agent__complex_agent__event_stream_handler', + children=[ + BasicSpan(content='ctx.run_step=3'), + BasicSpan( + content='{"data":"Processing output...","name":null,"tool_call_id":"call_4kc6691zCzjPnOuEtbEGUvz2","event_kind":"custom"}' + ), + ], + ) + ], + ), ], ), BasicSpan(content='CompleteWorkflow:ComplexAgentWorkflow'), @@ -834,11 +864,13 @@ async def event_stream_handler( tool_name='get_weather', args='{"city":"Mexico City"}', tool_call_id='call_LwxJUB9KppVyogRRLQsamRJv' ) ), + CustomEvent(data='Getting weather...', tool_call_id='call_LwxJUB9KppVyogRRLQsamRJv'), FunctionToolResultEvent( result=ToolReturnPart( tool_name='get_weather', content='sunny', tool_call_id='call_LwxJUB9KppVyogRRLQsamRJv', + metadata={'temperature': 28.7}, timestamp=IsDatetime(), ) ), @@ -1014,6 +1046,7 @@ async def event_stream_handler( tool_call_id='call_CCGIWaMeYWmxOQ91orkmTvzn', ), ), + CustomEvent(data='Processing output...', tool_call_id='call_CCGIWaMeYWmxOQ91orkmTvzn'), ] ) @@ -1530,49 +1563,6 @@ async def test_temporal_agent_override_deps_in_workflow(allow_model_requests: No assert output == snapshot('The capital of Mexico is Mexico City.') -agent_with_sync_tool = Agent(model, name='agent_with_sync_tool', tools=[get_weather]) - -# This needs to be done before the `TemporalAgent` is bound to the workflow. -temporal_agent_with_sync_tool_activity_disabled = TemporalAgent( - agent_with_sync_tool, - activity_config=BASE_ACTIVITY_CONFIG, - tool_activity_config={ - '': { - 'get_weather': False, - }, - }, -) - - -@workflow.defn -class AgentWorkflowWithSyncToolActivityDisabled: - @workflow.run - async def run(self, prompt: str) -> str: - result = await temporal_agent_with_sync_tool_activity_disabled.run(prompt) - return result.output # pragma: no cover - - -async def test_temporal_agent_sync_tool_activity_disabled(allow_model_requests: None, client: Client): - async with Worker( - client, - task_queue=TASK_QUEUE, - workflows=[AgentWorkflowWithSyncToolActivityDisabled], - plugins=[AgentPlugin(temporal_agent_with_sync_tool_activity_disabled)], - ): - with workflow_raises( - UserError, - snapshot( - "Temporal activity config for tool 'get_weather' has been explicitly set to `False` (activity disabled), but non-async tools are run in threads which are not supported outside of an activity. Make the tool function async instead." - ), - ): - await client.execute_workflow( # pyright: ignore[reportUnknownMemberType] - AgentWorkflowWithSyncToolActivityDisabled.run, - args=['What is the weather in Mexico City?'], - id=AgentWorkflowWithSyncToolActivityDisabled.__name__, - task_queue=TASK_QUEUE, - ) - - async def test_temporal_agent_mcp_server_activity_disabled(client: Client): with pytest.raises( UserError, @@ -2028,6 +2018,49 @@ async def test_temporal_agent_with_model_retry(allow_model_requests: None, clien ) +agent_with_sync_tool = Agent(model, name='agent_with_sync_tool', tools=[get_weather_in_city]) + +# This needs to be done before the `TemporalAgent` is bound to the workflow. +temporal_agent_with_sync_tool_activity_disabled = TemporalAgent( + agent_with_sync_tool, + activity_config=BASE_ACTIVITY_CONFIG, + tool_activity_config={ + '': { + 'get_weather_in_city': False, + }, + }, +) + + +@workflow.defn +class AgentWorkflowWithSyncToolActivityDisabled: + @workflow.run + async def run(self, prompt: str) -> str: + result = await temporal_agent_with_sync_tool_activity_disabled.run(prompt) + return result.output # pragma: no cover + + +async def test_temporal_agent_sync_tool_activity_disabled(allow_model_requests: None, client: Client): + async with Worker( + client, + task_queue=TASK_QUEUE, + workflows=[AgentWorkflowWithSyncToolActivityDisabled], + plugins=[AgentPlugin(temporal_agent_with_sync_tool_activity_disabled)], + ): + with workflow_raises( + UserError, + snapshot( + "Temporal activity config for tool 'get_weather_in_city' has been explicitly set to `False` (activity disabled), but non-async tools are run in threads which are not supported outside of an activity. Make the tool function async instead." + ), + ): + await client.execute_workflow( # pyright: ignore[reportUnknownMemberType] + AgentWorkflowWithSyncToolActivityDisabled.run, + args=['What is the weather in Mexico City?'], + id=AgentWorkflowWithSyncToolActivityDisabled.__name__, + task_queue=TASK_QUEUE, + ) + + class CustomModelSettings(ModelSettings, total=False): custom_setting: str diff --git a/tests/test_vercel_ai.py b/tests/test_vercel_ai.py index 085cd38631..26b28b080a 100644 --- a/tests/test_vercel_ai.py +++ b/tests/test_vercel_ai.py @@ -25,6 +25,7 @@ PartEndEvent, PartStartEvent, RetryPromptPart, + Return, SystemPromptPart, TextPart, TextPartDelta, @@ -1964,3 +1965,82 @@ async def test_adapter_load_messages(): ), ] ) + + +async def test_custom_events(): + async def stream_function( + messages: list[ModelMessage], agent_info: AgentInfo + ) -> AsyncIterator[DeltaToolCalls | str]: + if len(messages) == 1: + yield { + 0: DeltaToolCall( + name='get_weather', + json_args='{"city":', + tool_call_id='get_weather_1', + ) + } + yield { + 0: DeltaToolCall( + json_args='"Mexico City"}', + tool_call_id='get_weather_1', + ) + } + else: + yield 'The weather in Mexico City is sunny.' + + agent = Agent(model=FunctionModel(stream_function=stream_function)) + + @agent.tool_plain + async def get_weather(city: str) -> AsyncIterator[BaseChunk | Return[str]]: + yield DataChunk(type='data-progress', data='Getting weather...') + yield Return('sunny') + + request = SubmitMessage( + id='foo', + messages=[ + UIMessage( + id='bar', + role='user', + parts=[TextUIPart(text='What is the weather in Mexico City?')], + ), + ], + ) + adapter = VercelAIAdapter(agent, request) + events = [ + '[DONE]' if '[DONE]' in event else json.loads(event.removeprefix('data: ')) + async for event in adapter.encode_stream(adapter.run_stream()) + ] + + assert events == snapshot( + [ + {'type': 'start'}, + {'type': 'start-step'}, + {'type': 'tool-input-start', 'toolCallId': 'get_weather_1', 'toolName': 'get_weather'}, + {'type': 'tool-input-delta', 'toolCallId': 'get_weather_1', 'inputTextDelta': '{"city":'}, + {'type': 'tool-input-delta', 'toolCallId': 'get_weather_1', 'inputTextDelta': '"Mexico City"}'}, + { + 'type': 'tool-input-available', + 'toolCallId': 'get_weather_1', + 'toolName': 'get_weather', + 'input': '{"city":"Mexico City"}', + }, + {'type': 'data-progress', 'data': 'Getting weather...'}, + { + 'type': 'tool-output-available', + 'toolCallId': 'get_weather_1', + 'output': 'sunny', + }, + {'type': 'finish-step'}, + {'type': 'start-step'}, + {'type': 'text-start', 'id': '2ec4f1af-c0be-44a3-9d50-2244deb4716d'}, + { + 'type': 'text-delta', + 'delta': 'The weather in Mexico City is sunny.', + 'id': '2ec4f1af-c0be-44a3-9d50-2244deb4716d', + }, + {'type': 'text-end', 'id': '2ec4f1af-c0be-44a3-9d50-2244deb4716d'}, + {'type': 'finish-step'}, + {'type': 'finish'}, + '[DONE]', + ] + ) diff --git a/tests/typed_agent.py b/tests/typed_agent.py index 83c8f7bc3f..1d2acdf343 100644 --- a/tests/typed_agent.py +++ b/tests/typed_agent.py @@ -2,21 +2,25 @@ # pyright: reportUnnecessaryTypeIgnoreComment=false import re -from collections.abc import Awaitable, Callable +from collections.abc import AsyncIterator, Awaitable, Callable from dataclasses import dataclass from decimal import Decimal from typing import Any, TypeAlias from typing_extensions import assert_type -from pydantic_ai import Agent, ModelRetry, RunContext, Tool +from pydantic_ai import Agent, AgentRunResultEvent, AgentStreamEvent, ModelRetry, Return, RunContext, Tool from pydantic_ai.agent import AgentRunResult +from pydantic_ai.messages import CustomEvent from pydantic_ai.output import StructuredDict, TextOutput, ToolOutput from pydantic_ai.tools import DeferredToolRequests, ToolDefinition # Define here so we can check `if MYPY` below. This will not be executed, MYPY will always set it to True MYPY = False +simple_agent = Agent() +assert_type(simple_agent, Agent[None, str, object]) + @dataclass class MyDeps: @@ -310,3 +314,64 @@ async def prepare_greet(ctx: RunContext[str], tool_def: ToolDefinition) -> ToolD partial_agent: Agent[MyDeps] = Agent(deps_type=MyDeps) assert_type(partial_agent, Agent[MyDeps, str]) assert_type(partial_agent, Agent[MyDeps]) + + +async def custom_str_events() -> AsyncIterator[str | Return[float]]: + yield 'Getting temperature...' + yield Return(28.7) + + +async def custom_decimal_events() -> AsyncIterator[CustomEvent[Decimal] | Return[float]]: + yield CustomEvent(data=Decimal(28.0)) + yield CustomEvent(data=Decimal(0.7)) + yield Return(28.7) + + +async def custom_int_events() -> AsyncIterator[int | Return[float]]: + yield 1 + yield 2 + yield 3 + yield Return(28.7) + + +custom_str_events_tool = Tool(custom_str_events) +assert_type(custom_str_events_tool, Tool[object, str]) + +custom_decimal_events_tool = Tool(custom_decimal_events) +assert_type(custom_decimal_events_tool, Tool[object, Decimal]) + +custom_int_events_tool = Tool(custom_int_events) +assert_type(custom_int_events_tool, Tool[object, int]) + +custom_str_event_agent = Agent(tools=[custom_str_events]) +assert_type(custom_str_event_agent, Agent[None, str, str]) + + +# TODO (DouweM): Is valid, but shouldn't be (Decimal != str), but return type +# `AsyncIterator[CustomEvent[Decimal] | Return[float]]` is matched to +# Requires stream=True; Error if AsyncIterator and no stream=True; overloads?`Any` +custom_str_event_agent.tool_plain(custom_decimal_events, stream=True) + +# --- + +custom_event_agent = Agent(tools=[custom_str_events_tool, custom_decimal_events], output_type=custom_int_events) +# TODO: Require stream or smth here as well? + +# TODO (DouweM): This infers `CustomEventDataT` as `str | decimal` because of `tools`, +# and then `OutputT` as `AsyncIterator[CustomEvent[int] | Return[float]]` because it matches `Any` +# Ideally `CustomEventDataT` would be inferred as `str | Decimal | int` and `OutputT` as `float` +assert_type(custom_event_agent, Agent[None, float, str | int | Decimal]) + +event_stream = custom_event_agent.run_stream_events() +assert_type(event_stream, AsyncIterator[AgentStreamEvent[str | int | Decimal] | AgentRunResultEvent[float]]) + +# --- + +custom_event_agent2 = Agent[None, float, str | int | Decimal]( + tools=[custom_str_events_tool, custom_decimal_events], + output_type=custom_int_events, +) +assert_type(custom_event_agent2, Agent[None, float, str | int | Decimal]) + +event_stream2 = custom_event_agent2.run_stream_events() +assert_type(event_stream2, AsyncIterator[AgentStreamEvent[str | int | Decimal] | AgentRunResultEvent[float]])