Skip to content
Empty file.
1 change: 1 addition & 0 deletions docs/install.md
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ pip/uv-add "pydantic-ai-slim[openai]"
* `ag-ui` - installs [AG-UI Event Stream Protocol](ui/ag-ui.md) dependencies `ag-ui-protocol` [PyPI ↗](https://pypi.org/project/ag-ui-protocol){:target="_blank"} and `starlette` [PyPI ↗](https://pypi.org/project/starlette){:target="_blank"}
* `dbos` - installs [DBOS Durable Execution](durable_execution/dbos.md) dependency `dbos` [PyPI ↗](https://pypi.org/project/dbos){:target="_blank"}
* `prefect` - installs [Prefect Durable Execution](durable_execution/prefect.md) dependency `prefect` [PyPI ↗](https://pypi.org/project/prefect){:target="_blank"}
* `restate` - installs [Restate Durable Execution](durable_execution/restate.md) dependency `restate-sdk` [PyPI ↗](https://pypi.org/project/restate-sdk){:target="_blank"}

You can also install dependencies for multiple models and use cases, for example:

Expand Down
11 changes: 11 additions & 0 deletions pydantic_ai_slim/pydantic_ai/durable_exec/restate/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from ._agent import RestateAgent
from ._model import RestateModelWrapper
from ._serde import PydanticTypeAdapter
from ._toolset import RestateContextRunToolSet

# Names re-exported as the public API of the Restate durable-execution package.
__all__ = [
'RestateModelWrapper',
'RestateAgent',
'PydanticTypeAdapter',
'RestateContextRunToolSet',
]
371 changes: 371 additions & 0 deletions pydantic_ai_slim/pydantic_ai/durable_exec/restate/_agent.py

Large diffs are not rendered by default.

99 changes: 99 additions & 0 deletions pydantic_ai_slim/pydantic_ai/durable_exec/restate/_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
from collections.abc import AsyncIterator
from contextlib import asynccontextmanager
from datetime import datetime
from typing import Any

from restate import Context, RunOptions

from pydantic_ai.agent.abstract import EventStreamHandler
from pydantic_ai.durable_exec.restate._serde import PydanticTypeAdapter
from pydantic_ai.exceptions import UserError
from pydantic_ai.messages import ModelMessage, ModelResponse, ModelResponseStreamEvent
from pydantic_ai.models import Model, ModelRequestParameters, StreamedResponse
from pydantic_ai.models.wrapper import WrapperModel
from pydantic_ai.settings import ModelSettings
from pydantic_ai.tools import AgentDepsT, RunContext
from pydantic_ai.usage import RequestUsage

# Shared serde for `ModelResponse` values; passed to `RunOptions` by `RestateModelWrapper`
# so the result of each model-call step can be serialized and deserialized.
MODEL_RESPONSE_SERDE = PydanticTypeAdapter(ModelResponse)


class RestateStreamedResponse(StreamedResponse):
    """A `StreamedResponse` backed by an already-complete `ModelResponse`.

    No stream events are produced; every accessor simply reads from the
    response that was supplied at construction time.
    """

    def __init__(self, model_request_parameters: ModelRequestParameters, response: ModelResponse):
        """Store the finished response.

        Args:
            model_request_parameters: Forwarded unchanged to `StreamedResponse.__init__`.
            response: The complete response this object exposes.
        """
        super().__init__(model_request_parameters)
        self.response = response

    async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
        """Return an async generator that finishes immediately without yielding."""
        # The bare `yield` below is never reached; its presence is what makes
        # this function an async generator rather than a plain coroutine.
        return
        # noinspection PyUnreachableCode
        yield

    def get(self) -> ModelResponse:
        """Return the stored response as-is."""
        return self.response

    def usage(self) -> RequestUsage:
        """Return the usage recorded on the stored response."""
        stored = self.response  # pragma: no cover
        return stored.usage  # pragma: no cover

    @property
    def model_name(self) -> str:
        """Model name of the stored response, or `''` when it is unset/empty."""
        name = self.response.model_name  # pragma: no cover
        return name if name else ''  # pragma: no cover

    @property
    def provider_name(self) -> str:
        """Provider name of the stored response, or `''` when it is unset/empty."""
        provider = self.response.provider_name  # pragma: no cover
        return provider if provider else ''  # pragma: no cover

    @property
    def timestamp(self) -> datetime:
        """Timestamp of the stored response."""
        stored = self.response  # pragma: no cover
        return stored.timestamp  # pragma: no cover


class RestateModelWrapper(WrapperModel):
    """Model wrapper that runs each model request inside a Restate `ctx.run_typed()` step.

    Both `request` and `request_stream` delegate to the wrapped model from within
    `run_typed`, using `MODEL_RESPONSE_SERDE` to (de)serialize the `ModelResponse`.
    """

    def __init__(
        self,
        wrapped: Model,
        context: Context,
        event_stream_handler: EventStreamHandler[AgentDepsT] | None = None,
        max_attempts: int | None = None,
    ):
        """Create the wrapper.

        Args:
            wrapped: The model that actually performs requests.
            context: The Restate context whose `run_typed` is used for every request.
            event_stream_handler: Handler called with the live stream inside `request_stream`.
                Streaming requests fail with `UserError` when this is `None`.
            max_attempts: Passed through to `RunOptions(max_attempts=...)`.
        """
        super().__init__(wrapped)
        self._options = RunOptions(serde=MODEL_RESPONSE_SERDE, max_attempts=max_attempts)
        self._context = context
        self._event_stream_handler = event_stream_handler

    async def request(self, *args: Any, **kwargs: Any) -> ModelResponse:
        """Run the wrapped model's `request` as a Restate step named 'Model call'.

        All positional and keyword arguments are forwarded unchanged.
        """
        return await self._context.run_typed('Model call', self.wrapped.request, self._options, *args, **kwargs)

    @asynccontextmanager
    async def request_stream(
        self,
        messages: list[ModelMessage],
        model_settings: ModelSettings | None,
        model_request_parameters: ModelRequestParameters,
        run_context: RunContext[AgentDepsT] | None = None,
    ) -> AsyncIterator[StreamedResponse]:
        """Stream from the wrapped model inside a single Restate step.

        The whole stream is consumed within the step: the event stream handler is
        called with the live stream, any remaining events are drained, and the final
        `ModelResponse` becomes the step result. The caller receives a
        `RestateStreamedResponse` wrapping that response.

        Raises:
            UserError: If `run_context` is `None`, or if no `event_stream_handler` was configured.
        """
        if run_context is None:
            raise UserError(
                'A model cannot be used with `pydantic_ai.direct.model_request_stream()` as it requires a `run_context`. Set an `event_stream_handler` on the agent and use `agent.run()` instead.'
            )

        # NOTE(review): the original comment here was copied from the Temporal integration and
        # referred to `TemporalAgent`. Presumably `RestateAgent.run_stream` / `RestateAgent.iter`
        # likewise refuse to run, so this is only reached via `run()` with a handler set -- confirm
        # against `_agent.py`.
        # An explicit check is used instead of `assert` because assertions are stripped under
        # `python -O`, which would defer the failure to the `await fn(...)` call inside the run step.
        fn = self._event_stream_handler
        if fn is None:
            raise UserError(
                'A Restate model requires an `event_stream_handler` to stream model requests. Set an `event_stream_handler` on the agent and use `agent.run()` instead.'
            )

        async def request_stream_run():
            async with self.wrapped.request_stream(
                messages,
                model_settings,
                model_request_parameters,
                run_context,
            ) as streamed_response:
                await fn(run_context, streamed_response)

                # Drain whatever the handler did not consume so the response is complete.
                async for _ in streamed_response:
                    pass
            return streamed_response.get()

        response = await self._context.run_typed('Model stream call', request_stream_run, self._options)

        yield RestateStreamedResponse(model_request_parameters, response)
45 changes: 45 additions & 0 deletions pydantic_ai_slim/pydantic_ai/durable_exec/restate/_serde.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
import typing

from pydantic import TypeAdapter
from restate.serde import Serde

T = typing.TypeVar('T')


class PydanticTypeAdapter(Serde[T]):
    """A Restate `Serde` that converts values to and from JSON bytes via a Pydantic `TypeAdapter`."""

    def __init__(self, model_type: type[T]):
        """Build the `TypeAdapter` once so it is reused for every (de)serialization.

        Args:
            model_type: The type to serialize and deserialize.
        """
        self._model_type = TypeAdapter(model_type)

    def deserialize(self, buf: bytes) -> T | None:
        """Deserialize JSON bytes into a value of the adapted type.

        Args:
            buf: UTF-8 encoded JSON. An empty buffer represents `None`.

        Returns:
            The validated value, or `None` when `buf` is empty.
        """
        if not buf:
            return None
        return self._model_type.validate_json(buf.decode('utf-8'))  # raises if invalid

    def serialize(self, obj: T | None) -> bytes:
        """Serialize a value of the adapted type to JSON bytes.

        Args:
            obj: The value to serialize, or `None`.

        Returns:
            The JSON bytes, or `b''` when `obj` is `None` (mirrors `deserialize`).
        """
        if obj is None:
            return b''
        # Reuse the adapter built in `__init__` rather than constructing a new
        # `TypeAdapter(type(obj))` per call: this avoids rebuilding the schema each
        # time and keeps serialization on the same declared type that
        # `deserialize` validates against.
        return self._model_type.dump_json(obj)
144 changes: 144 additions & 0 deletions pydantic_ai_slim/pydantic_ai/durable_exec/restate/_toolset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
from __future__ import annotations

from collections.abc import Callable
from dataclasses import dataclass
from typing import Any, Literal

from restate import Context, RunOptions, TerminalError

from pydantic_ai import ToolDefinition
from pydantic_ai._run_context import AgentDepsT
from pydantic_ai.exceptions import ApprovalRequired, CallDeferred, ModelRetry, UserError
from pydantic_ai.mcp import MCPServer, ToolResult
from pydantic_ai.tools import RunContext
from pydantic_ai.toolsets.abstract import AbstractToolset, ToolsetTool
from pydantic_ai.toolsets.wrapper import WrapperToolset

from ._serde import PydanticTypeAdapter


@dataclass
class RestateContextRunResult:
"""A simple wrapper for tool results to be used with Restate's run_typed."""

kind: Literal['output', 'call_deferred', 'approval_required', 'model_retry']
output: Any
error: str | None = None


# Serde for `RestateContextRunResult`, used by `RestateContextRunToolSet` in its `RunOptions`.
CONTEXT_RUN_SERDE = PydanticTypeAdapter(RestateContextRunResult)


@dataclass
class RestateMCPGetToolsContextRunResult:
    """A simple wrapper for MCP tool definitions to be used with Restate's run_typed."""

    # Tool definitions keyed by tool name.
    output: dict[str, ToolDefinition]


# Serde for `RestateMCPGetToolsContextRunResult`, used by `RestateMCPServer.get_tools`.
MCP_GET_TOOLS_SERDE = PydanticTypeAdapter(RestateMCPGetToolsContextRunResult)


@dataclass
class RestateMCPToolRunResult:
    """A simple wrapper for an MCP tool call result to be used with Restate's run_typed."""

    # The value returned by the MCP server's `call_tool`.
    output: ToolResult


# Serde for `RestateMCPToolRunResult` (the wrapped result of an MCP tool call).
MCP_RUN_SERDE = PydanticTypeAdapter(RestateMCPToolRunResult)


class RestateContextRunToolSet(WrapperToolset[AgentDepsT]):
    """Toolset wrapper that executes every tool call as a Restate `ctx.run_typed()` step."""

    def __init__(self, wrapped: AbstractToolset[AgentDepsT], context: Context):
        """Wrap `wrapped` so its tool calls go through `context.run_typed`."""
        super().__init__(wrapped)
        self._context = context
        self.options = RunOptions[RestateContextRunResult](serde=CONTEXT_RUN_SERDE)

    async def call_tool(
        self, name: str, tool_args: dict[str, Any], ctx: RunContext[AgentDepsT], tool: ToolsetTool[AgentDepsT]
    ) -> Any:
        """Call the wrapped tool inside a Restate step and translate its outcome.

        Returns:
            The tool's output for a normal call.

        Raises:
            ModelRetry, CallDeferred, ApprovalRequired: Re-raised outside the step when the
                tool signalled them inside it.
        """

        async def run_tool() -> RestateContextRunResult:
            # ModelRetry, CallDeferred and ApprovalRequired are control-flow signals, not
            # failures. Restate's ctx.run() would retry them if they escaped, so each one is
            # turned into a return value here and raised again once the step has completed.
            try:
                tool_output = await self.wrapped.call_tool(name, tool_args, ctx, tool)
            except ModelRetry as exc:
                return RestateContextRunResult(kind='model_retry', output=None, error=exc.message)
            except CallDeferred:
                return RestateContextRunResult(kind='call_deferred', output=None, error=None)
            except ApprovalRequired:
                return RestateContextRunResult(kind='approval_required', output=None, error=None)
            except UserError as exc:
                # Surfaced as a TerminalError rather than a regular exception.
                raise TerminalError(str(exc)) from exc
            else:
                return RestateContextRunResult(kind='output', output=tool_output, error=None)

        outcome = await self._context.run_typed(f'Calling {name}', run_tool, self.options)

        if outcome.kind == 'call_deferred':
            raise CallDeferred()
        if outcome.kind == 'approval_required':
            raise ApprovalRequired()
        if outcome.kind == 'model_retry':
            assert outcome.error is not None
            raise ModelRetry(outcome.error)

        assert outcome.kind == 'output'
        return outcome.output

    def visit_and_replace(
        self, visitor: Callable[[AbstractToolset[AgentDepsT]], AbstractToolset[AgentDepsT]]
    ) -> AbstractToolset[AgentDepsT]:
        """Apply `visitor` to this wrapper itself; the wrapped toolset is not visited."""
        replacement = visitor(self)
        return replacement


@dataclass
class _RestateMCPCallToolOutcome:
    """Serializable outcome of one MCP tool call made inside `ctx.run_typed()`."""

    # Which outcome occurred: a normal result or one of the special signals.
    kind: Literal['output', 'call_deferred', 'approval_required', 'model_retry']
    # The wrapped tool result; set only when kind == 'output'.
    result: RestateMCPToolRunResult | None = None
    # The retry message; set only when kind == 'model_retry'.
    error: str | None = None


_MCP_CALL_TOOL_OUTCOME_SERDE = PydanticTypeAdapter(_RestateMCPCallToolOutcome)


class RestateMCPServer(WrapperToolset[AgentDepsT]):
    """A wrapper for MCPServer that integrates with restate.

    Listing tools and calling a tool are each executed inside `ctx.run_typed()`.
    """

    def __init__(self, wrapped: MCPServer, context: Context):
        """Wrap `wrapped` so its operations go through `context.run_typed`."""
        super().__init__(wrapped)
        # Kept separately from `self.wrapped` so the attribute is typed as `MCPServer`.
        self._wrapped = wrapped
        self._context = context

    def visit_and_replace(
        self, visitor: Callable[[AbstractToolset[AgentDepsT]], AbstractToolset[AgentDepsT]]
    ) -> AbstractToolset[AgentDepsT]:
        """Apply `visitor` to this wrapper itself; the wrapped server is not visited."""
        return visitor(self)

    async def get_tools(self, ctx: RunContext[AgentDepsT]) -> dict[str, ToolsetTool[AgentDepsT]]:
        """Fetch the server's tools inside a Restate step and rebuild `ToolsetTool`s from them."""

        async def get_tools_in_context() -> RestateMCPGetToolsContextRunResult:
            res = await self._wrapped.get_tools(ctx)
            # ToolsetTool is not serializable as it holds a SchemaValidator
            # (which is also the same for every MCP tool so unnecessary to pass along the wire every time),
            # so we just return the ToolDefinitions and wrap them in ToolsetTool outside of the run step.
            return RestateMCPGetToolsContextRunResult(output={name: tool.tool_def for name, tool in res.items()})

        options = RunOptions(serde=MCP_GET_TOOLS_SERDE)

        tool_defs = await self._context.run_typed('get mcp tools', get_tools_in_context, options)

        return {name: self.tool_for_tool_def(tool_def) for name, tool_def in tool_defs.output.items()}

    def tool_for_tool_def(self, tool_def: ToolDefinition) -> ToolsetTool[AgentDepsT]:
        """Build a `ToolsetTool` for `tool_def` using the wrapped MCP server."""
        return self._wrapped.tool_for_tool_def(tool_def)

    async def call_tool(
        self,
        name: str,
        tool_args: dict[str, Any],
        ctx: RunContext[AgentDepsT],
        tool: ToolsetTool[AgentDepsT],
    ) -> ToolResult:
        """Call an MCP tool inside a Restate step and translate its outcome.

        Returns:
            The tool result for a normal call.

        Raises:
            ModelRetry, CallDeferred, ApprovalRequired: Re-raised outside the step when the
                call signalled them inside it.
            TerminalError: If the recorded outcome is not one of the expected shapes.
        """

        async def call_tool_in_context() -> _RestateMCPCallToolOutcome:
            # As in `RestateContextRunToolSet.call_tool`: these exceptions are control-flow
            # signals for the agent, and restate ctx.run() would retry them if they escaped
            # the step. Record them as a return value and raise them again afterwards.
            try:
                res = await self._wrapped.call_tool(name, tool_args, ctx, tool)
            except ModelRetry as e:
                return _RestateMCPCallToolOutcome(kind='model_retry', error=e.message)
            except CallDeferred:
                return _RestateMCPCallToolOutcome(kind='call_deferred')
            except ApprovalRequired:
                return _RestateMCPCallToolOutcome(kind='approval_required')
            return _RestateMCPCallToolOutcome(kind='output', result=RestateMCPToolRunResult(output=res))

        options = RunOptions(serde=_MCP_CALL_TOOL_OUTCOME_SERDE)
        outcome = await self._context.run_typed(f'Calling mcp tool {name}', call_tool_in_context, options)

        if outcome.kind == 'call_deferred':
            raise CallDeferred()
        if outcome.kind == 'approval_required':
            raise ApprovalRequired()
        if outcome.kind == 'model_retry':
            raise ModelRetry(outcome.error or '')
        if outcome.kind != 'output' or outcome.result is None:
            # Not produced by `call_tool_in_context`; retrying cannot repair a bad record.
            raise TerminalError(f'Unexpected outcome {outcome.kind!r} recorded for MCP tool {name!r}')

        return outcome.result.output
2 changes: 2 additions & 0 deletions pydantic_ai_slim/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,8 @@ temporal = ["temporalio==1.18.0"]
dbos = ["dbos>=1.14.0"]
# Prefect
prefect = ["prefect>=3.4.21"]
# Restate
restate = ["restate_sdk[serde]==0.10"]

[tool.hatch.metadata]
allow-direct-references = true
Expand Down
Loading
Loading