Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
166 changes: 58 additions & 108 deletions pydantic_ai_slim/pydantic_ai/durable_exec/temporal/__init__.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,14 @@
from __future__ import annotations

import warnings
from collections.abc import AsyncIterator, Callable, Sequence
from contextlib import AbstractAsyncContextManager
from dataclasses import replace
from typing import Any

from pydantic.errors import PydanticUserError
from temporalio.client import ClientConfig, Plugin as ClientPlugin, WorkflowHistory
from temporalio.contrib.pydantic import PydanticPayloadConverter, pydantic_data_converter
from temporalio.converter import DataConverter, DefaultPayloadConverter
from temporalio.service import ConnectConfig, ServiceClient
from temporalio.worker import (
Plugin as WorkerPlugin,
Replayer,
ReplayerConfig,
Worker,
WorkerConfig,
WorkflowReplayResult,
)
from temporalio.plugin import SimplePlugin
from temporalio.worker import WorkflowRunner
from temporalio.worker.workflow_sandbox import SandboxedWorkflowRunner

from ...exceptions import UserError
Expand Down Expand Up @@ -48,104 +38,64 @@
pass


class PydanticAIPlugin(ClientPlugin, WorkerPlugin):
def _data_converter(converter: DataConverter | None) -> DataConverter:
if converter and converter.payload_converter_class not in (
DefaultPayloadConverter,
PydanticPayloadConverter,
):
warnings.warn( # pragma: no cover
'A non-default Temporal data converter was used which has been replaced with the Pydantic data converter.'
)

return pydantic_data_converter


def _workflow_runner(runner: WorkflowRunner | None) -> WorkflowRunner:
if not runner:
raise ValueError('No WorkflowRunner provided to the Pydantic AI plugin.') # pragma: no cover

if not isinstance(runner, SandboxedWorkflowRunner):
return runner # pragma: no cover

return replace(
runner,
restrictions=runner.restrictions.with_passthrough_modules(
'pydantic_ai',
'pydantic',
'pydantic_core',
'logfire',
'rich',
'httpx',
'anyio',
'httpcore',
# Used by fastmcp via py-key-value-aio
'beartype',
# Imported inside `logfire._internal.json_encoder` when running `logfire.info` inside an activity with attributes to serialize
'attrs',
# Imported inside `logfire._internal.json_schema` when running `logfire.info` inside an activity with attributes to serialize
'numpy',
'pandas',
),
)


class PydanticAIPlugin(SimplePlugin):
"""Temporal client and worker plugin for Pydantic AI."""

def init_client_plugin(self, next: ClientPlugin) -> None:
self.next_client_plugin = next

def init_worker_plugin(self, next: WorkerPlugin) -> None:
self.next_worker_plugin = next

def configure_client(self, config: ClientConfig) -> ClientConfig:
config['data_converter'] = self._get_new_data_converter(config.get('data_converter'))
return self.next_client_plugin.configure_client(config)

def configure_worker(self, config: WorkerConfig) -> WorkerConfig:
runner = config.get('workflow_runner') # pyright: ignore[reportUnknownMemberType]
if isinstance(runner, SandboxedWorkflowRunner): # pragma: no branch
config['workflow_runner'] = replace(
runner,
restrictions=runner.restrictions.with_passthrough_modules(
'pydantic_ai',
'pydantic',
'pydantic_core',
'logfire',
'rich',
'httpx',
'anyio',
'httpcore',
# Used by fastmcp via py-key-value-aio
'beartype',
# Imported inside `logfire._internal.json_encoder` when running `logfire.info` inside an activity with attributes to serialize
'attrs',
# Imported inside `logfire._internal.json_schema` when running `logfire.info` inside an activity with attributes to serialize
'numpy',
'pandas',
),
)

config['workflow_failure_exception_types'] = [
*config.get('workflow_failure_exception_types', []), # pyright: ignore[reportUnknownMemberType]
UserError,
PydanticUserError,
]

return self.next_worker_plugin.configure_worker(config)

async def connect_service_client(self, config: ConnectConfig) -> ServiceClient:
return await self.next_client_plugin.connect_service_client(config)

async def run_worker(self, worker: Worker) -> None:
await self.next_worker_plugin.run_worker(worker)

def configure_replayer(self, config: ReplayerConfig) -> ReplayerConfig: # pragma: no cover
config['data_converter'] = self._get_new_data_converter(config.get('data_converter')) # pyright: ignore[reportUnknownMemberType]
return self.next_worker_plugin.configure_replayer(config)

def run_replayer(
self,
replayer: Replayer,
histories: AsyncIterator[WorkflowHistory],
) -> AbstractAsyncContextManager[AsyncIterator[WorkflowReplayResult]]: # pragma: no cover
return self.next_worker_plugin.run_replayer(replayer, histories)

def _get_new_data_converter(self, converter: DataConverter | None) -> DataConverter:
if converter and converter.payload_converter_class not in (
DefaultPayloadConverter,
PydanticPayloadConverter,
):
warnings.warn( # pragma: no cover
'A non-default Temporal data converter was used which has been replaced with the Pydantic data converter.'
)

return pydantic_data_converter


class AgentPlugin(WorkerPlugin):
"""Temporal worker plugin for a specific Pydantic AI agent."""

def __init__(self, agent: TemporalAgent[Any, Any]):
self.agent = agent

def init_worker_plugin(self, next: WorkerPlugin) -> None:
self.next_worker_plugin = next
def __init__(self):
super().__init__( # type: ignore[reportUnknownMemberType]
name='PydanticAIPlugin',
data_converter=_data_converter,
workflow_runner=_workflow_runner,
workflow_failure_exception_types=[UserError, PydanticUserError],
)

def configure_worker(self, config: WorkerConfig) -> WorkerConfig:
activities: Sequence[Callable[..., Any]] = config.get('activities', []) # pyright: ignore[reportUnknownMemberType]
# Activities are checked for name conflicts by Temporal.
config['activities'] = [*activities, *self.agent.temporal_activities]
return self.next_worker_plugin.configure_worker(config)

async def run_worker(self, worker: Worker) -> None:
await self.next_worker_plugin.run_worker(worker)

def configure_replayer(self, config: ReplayerConfig) -> ReplayerConfig: # pragma: no cover
return self.next_worker_plugin.configure_replayer(config)
class AgentPlugin(SimplePlugin):
"""Temporal worker plugin for a specific Pydantic AI agent."""

def run_replayer(
self,
replayer: Replayer,
histories: AsyncIterator[WorkflowHistory],
) -> AbstractAsyncContextManager[AsyncIterator[WorkflowReplayResult]]: # pragma: no cover
return self.next_worker_plugin.run_replayer(replayer, histories)
def __init__(self, agent: TemporalAgent[Any, Any]):
super().__init__( # type: ignore[reportUnknownMemberType]
name='AgentPlugin',
activities=agent.temporal_activities,
)
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ async def _call_event_stream_handler_activity(
) -> None:
serialized_run_context = self.run_context_type.serialize_run_context(ctx)
async for event in stream:
await workflow.execute_activity( # pyright: ignore[reportUnknownMemberType]
await workflow.execute_activity(
activity=self.event_stream_handler_activity,
args=[
_EventStreamHandlerParams(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ async def call_tool(
tool_activity_config = self.activity_config | tool_activity_config
serialized_run_context = self.run_context_type.serialize_run_context(ctx)
return self._unwrap_call_tool_result(
await workflow.execute_activity( # pyright: ignore[reportUnknownMemberType]
await workflow.execute_activity(
activity=self.call_tool_activity,
args=[
CallToolParams(
Expand Down
28 changes: 13 additions & 15 deletions pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_logfire.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from __future__ import annotations

from collections.abc import Callable
from collections.abc import Awaitable, Callable
from typing import TYPE_CHECKING

from temporalio.client import ClientConfig, Plugin as ClientPlugin
from temporalio.plugin import SimplePlugin
from temporalio.runtime import OpenTelemetryConfig, Runtime, TelemetryConfig
from temporalio.service import ConnectConfig, ServiceClient

Expand All @@ -19,12 +19,14 @@ def _default_setup_logfire() -> Logfire:
return instance


class LogfirePlugin(ClientPlugin):
class LogfirePlugin(SimplePlugin):
"""Temporal client plugin for Logfire."""

def __init__(self, setup_logfire: Callable[[], Logfire] = _default_setup_logfire, *, metrics: bool = True):
try:
import logfire # noqa: F401 # pyright: ignore[reportUnusedImport]
from opentelemetry.trace import get_tracer
from temporalio.contrib.opentelemetry import TracingInterceptor
except ImportError as _import_error:
raise ImportError(
'Please install the `logfire` package to use the Logfire plugin, '
Expand All @@ -34,18 +36,14 @@ def __init__(self, setup_logfire: Callable[[], Logfire] = _default_setup_logfire
self.setup_logfire = setup_logfire
self.metrics = metrics

def init_client_plugin(self, next: ClientPlugin) -> None:
self.next_client_plugin = next
super().__init__( # type: ignore[reportUnknownMemberType]
name='LogfirePlugin',
client_interceptors=[TracingInterceptor(get_tracer('temporalio'))],
)

def configure_client(self, config: ClientConfig) -> ClientConfig:
from opentelemetry.trace import get_tracer
from temporalio.contrib.opentelemetry import TracingInterceptor

interceptors = config.get('interceptors', [])
config['interceptors'] = [*interceptors, TracingInterceptor(get_tracer('temporalio'))]
return self.next_client_plugin.configure_client(config)

async def connect_service_client(self, config: ConnectConfig) -> ServiceClient:
async def connect_service_client(
self, config: ConnectConfig, next: Callable[[ConnectConfig], Awaitable[ServiceClient]]
) -> ServiceClient:
logfire = self.setup_logfire()

if self.metrics:
Expand All @@ -60,4 +58,4 @@ async def connect_service_client(self, config: ConnectConfig) -> ServiceClient:
telemetry=TelemetryConfig(metrics=OpenTelemetryConfig(url=metrics_url, headers=headers))
)

return await self.next_client_plugin.connect_service_client(config)
return await next(config)
4 changes: 2 additions & 2 deletions pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_mcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ async def get_tools(self, ctx: RunContext[AgentDepsT]) -> dict[str, ToolsetTool[
return await super().get_tools(ctx)

serialized_run_context = self.run_context_type.serialize_run_context(ctx)
tool_defs = await workflow.execute_activity( # pyright: ignore[reportUnknownMemberType]
tool_defs = await workflow.execute_activity(
activity=self.get_tools_activity,
args=[
_GetToolsParams(serialized_run_context=serialized_run_context),
Expand All @@ -131,7 +131,7 @@ async def call_tool(
tool_activity_config = self.activity_config | self.tool_activity_config.get(name, {})
serialized_run_context = self.run_context_type.serialize_run_context(ctx)
return self._unwrap_call_tool_result(
await workflow.execute_activity( # pyright: ignore[reportUnknownMemberType]
await workflow.execute_activity(
activity=self.call_tool_activity,
args=[
CallToolParams(
Expand Down
4 changes: 2 additions & 2 deletions pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ async def request(

self._validate_model_request_parameters(model_request_parameters)

return await workflow.execute_activity( # pyright: ignore[reportUnknownMemberType]
return await workflow.execute_activity(
activity=self.request_activity,
arg=_RequestParams(
messages=messages,
Expand Down Expand Up @@ -168,7 +168,7 @@ async def request_stream(
self._validate_model_request_parameters(model_request_parameters)

serialized_run_context = self.run_context_type.serialize_run_context(run_context)
response = await workflow.execute_activity( # pyright: ignore[reportUnknownMemberType]
response = await workflow.execute_activity(
activity=self.request_stream_activity,
args=[
_RequestParams(
Expand Down
2 changes: 1 addition & 1 deletion pydantic_ai_slim/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ ag-ui = ["ag-ui-protocol>=0.1.8", "starlette>=0.45.3"]
# Retries
retries = ["tenacity>=8.2.3"]
# Temporal
temporal = ["temporalio==1.18.2"]
temporal = ["temporalio==1.19.0"]
# DBOS
dbos = ["dbos>=1.14.0"]
# Prefect
Expand Down
Loading