Skip to content
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@

import importlib.metadata

from agent_framework_durabletask import AgentCallbackContext, AgentResponseCallbackProtocol

from ._app import AgentFunctionApp
from ._callbacks import AgentCallbackContext, AgentResponseCallbackProtocol
from ._orchestration import DurableAIAgent

try:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,13 @@
import re
from collections.abc import Callable, Mapping
from dataclasses import dataclass
from datetime import datetime
from typing import TYPE_CHECKING, Any, TypeVar, cast

import azure.durable_functions as df
import azure.functions as func
from agent_framework import AgentProtocol, get_logger

from ._callbacks import AgentResponseCallbackProtocol
from ._constants import (
from agent_framework_durabletask import (
DEFAULT_MAX_POLL_RETRIES,
DEFAULT_POLL_INTERVAL_SECONDS,
MIMETYPE_APPLICATION_JSON,
Expand All @@ -28,11 +27,14 @@
THREAD_ID_HEADER,
WAIT_FOR_RESPONSE_FIELD,
WAIT_FOR_RESPONSE_HEADER,
AgentResponseCallbackProtocol,
DurableAgentState,
RunRequest,
)
from ._durable_agent_state import DurableAgentState

from ._entities import create_agent_entity
from ._errors import IncomingRequestError
from ._models import AgentSessionId, RunRequest
from ._models import AgentSessionId
from ._orchestration import AgentOrchestrationContextType, DurableAIAgent

logger = get_logger("agent_framework.azurefunctions")
Expand Down Expand Up @@ -850,15 +852,19 @@ def _build_request_data(
enable_tool_calls_value = req_body.get("enable_tool_calls")
enable_tool_calls = True if enable_tool_calls_value is None else self._coerce_to_bool(enable_tool_calls_value)

return RunRequest(
message=message,
role=req_body.get("role"),
request_response_format=request_response_format,
response_format=req_body.get("response_format"),
enable_tool_calls=enable_tool_calls,
thread_id=thread_id,
correlation_id=correlation_id,
).to_dict()
return cast(
dict[str, Any],
RunRequest(
message=message,
role=req_body.get("role"),
request_response_format=request_response_format,
response_format=req_body.get("response_format"),
enable_tool_calls=enable_tool_calls,
thread_id=thread_id,
correlation_id=correlation_id,
created_at=datetime.utcnow(),
).to_dict(),
)

def _build_accepted_response(self, message: str, thread_id: str, correlation_id: str) -> dict[str, Any]:
"""Build the response returned when not waiting for completion."""
Expand Down Expand Up @@ -991,8 +997,8 @@ def _accepts_json_response(headers: dict[str, str]) -> bool:
def _select_request_response_format(body_format: str, prefers_json: bool) -> str:
"""Combine body format and accept preference to determine response format."""
if body_format == REQUEST_RESPONSE_FORMAT_JSON or prefers_json:
return REQUEST_RESPONSE_FORMAT_JSON
return REQUEST_RESPONSE_FORMAT_TEXT
return str(REQUEST_RESPONSE_FORMAT_JSON)
return str(REQUEST_RESPONSE_FORMAT_TEXT)

@staticmethod
def _parse_json_body(req: func.HttpRequest) -> tuple[dict[str, Any], str]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,16 +22,16 @@
Role,
get_logger,
)

from ._callbacks import AgentCallbackContext, AgentResponseCallbackProtocol
from ._durable_agent_state import (
from agent_framework_durabletask import (
AgentCallbackContext,
AgentResponseCallbackProtocol,
DurableAgentState,
DurableAgentStateData,
DurableAgentStateEntry,
DurableAgentStateRequest,
DurableAgentStateResponse,
RunRequest,
)
from ._models import RunRequest

logger = get_logger("agent_framework.azurefunctions.entities")

Expand Down Expand Up @@ -87,7 +87,7 @@ def _is_error_response(self, entry: DurableAgentStateEntry) -> bool:
True if the entry is a response containing error content, False otherwise
"""
if isinstance(entry, DurableAgentStateResponse):
return entry.is_error
return bool(entry.is_error)
return False

async def run_agent(
Expand Down Expand Up @@ -121,11 +121,11 @@ async def run_agent(
response_format = run_request.response_format
enable_tool_calls = run_request.enable_tool_calls

logger.debug(f"[AgentEntity.run_agent] Received Message: {run_request}")

state_request = DurableAgentStateRequest.from_run_request(run_request)
self.state.data.conversation_history.append(state_request)

logger.debug(f"[AgentEntity.run_agent] Received Message: {state_request}")

try:
# Build messages from conversation history, excluding error responses
# Error responses are kept in history for tracking but not sent to the agent
Expand Down
Original file line number Diff line number Diff line change
@@ -1,35 +1,23 @@
# Copyright (c) Microsoft. All rights reserved.

"""Data models for Durable Agent Framework.
"""Azure Functions-specific data models for Durable Agent Framework.

This module defines the request and response models used by the framework.
This module contains Azure Functions-specific models:
- AgentSessionId: Entity ID management for Azure Durable Entities
- DurableAgentThread: Thread implementation that tracks AgentSessionId

Common models like RunRequest have been moved to agent-framework-durabletask.
"""

from __future__ import annotations

import inspect
import uuid
from collections.abc import MutableMapping
from dataclasses import dataclass
from importlib import import_module
from typing import TYPE_CHECKING, Any, cast
from typing import Any

import azure.durable_functions as df
from agent_framework import AgentThread, Role

from ._constants import REQUEST_RESPONSE_FORMAT_TEXT

if TYPE_CHECKING: # pragma: no cover - type checking imports only
from pydantic import BaseModel

_PydanticBaseModel: type[BaseModel] | None

try:
from pydantic import BaseModel as _RuntimeBaseModel
except ImportError: # pragma: no cover - optional dependency
_PydanticBaseModel = None
else:
_PydanticBaseModel = _RuntimeBaseModel
from agent_framework import AgentThread


@dataclass
Expand Down Expand Up @@ -211,161 +199,3 @@ async def deserialize(

thread.attach_session(AgentSessionId.parse(session_id_value))
return thread


def serialize_response_format(response_format: type[BaseModel] | None) -> Any:
"""Serialize response format for transport across durable function boundaries."""
if response_format is None:
return None

if _PydanticBaseModel is None:
raise RuntimeError("pydantic is required to use structured response formats")

if not inspect.isclass(response_format) or not issubclass(response_format, _PydanticBaseModel):
raise TypeError("response_format must be a Pydantic BaseModel type")

return {
"__response_schema_type__": "pydantic_model",
"module": response_format.__module__,
"qualname": response_format.__qualname__,
}


def _deserialize_response_format(response_format: Any) -> type[BaseModel] | None:
"""Deserialize response format back into actionable type if possible."""
if response_format is None:
return None

if (
_PydanticBaseModel is not None
and inspect.isclass(response_format)
and issubclass(response_format, _PydanticBaseModel)
):
return response_format

if not isinstance(response_format, dict):
return None

response_dict = cast(dict[str, Any], response_format)

if response_dict.get("__response_schema_type__") != "pydantic_model":
return None

module_name = response_dict.get("module")
qualname = response_dict.get("qualname")
if not module_name or not qualname:
return None

try:
module = import_module(module_name)
except ImportError: # pragma: no cover - user provided module missing
return None

attr: Any = module
for part in qualname.split("."):
try:
attr = getattr(attr, part)
except AttributeError: # pragma: no cover - invalid qualname
return None

if _PydanticBaseModel is not None and inspect.isclass(attr) and issubclass(attr, _PydanticBaseModel):
return attr

return None


@dataclass
class RunRequest:
"""Represents a request to run an agent with a specific message and configuration.

Attributes:
message: The message to send to the agent
request_response_format: The desired response format (e.g., "text" or "json")
role: The role of the message sender (user, system, or assistant)
response_format: Optional Pydantic BaseModel type describing the structured response format
enable_tool_calls: Whether to enable tool calls for this request
thread_id: Optional thread ID for tracking
correlation_id: Optional correlation ID for tracking the response to this specific request
created_at: Optional timestamp when the request was created
orchestration_id: Optional ID of the orchestration that initiated this request
"""

message: str
request_response_format: str
role: Role = Role.USER
response_format: type[BaseModel] | None = None
enable_tool_calls: bool = True
thread_id: str | None = None
correlation_id: str | None = None
created_at: str | None = None
orchestration_id: str | None = None

def __init__(
self,
message: str,
request_response_format: str = REQUEST_RESPONSE_FORMAT_TEXT,
role: Role | str | None = Role.USER,
response_format: type[BaseModel] | None = None,
enable_tool_calls: bool = True,
thread_id: str | None = None,
correlation_id: str | None = None,
created_at: str | None = None,
orchestration_id: str | None = None,
) -> None:
self.message = message
self.role = self.coerce_role(role)
self.response_format = response_format
self.request_response_format = request_response_format
self.enable_tool_calls = enable_tool_calls
self.thread_id = thread_id
self.correlation_id = correlation_id
self.created_at = created_at
self.orchestration_id = orchestration_id

@staticmethod
def coerce_role(value: Role | str | None) -> Role:
"""Normalize various role representations into a Role instance."""
if isinstance(value, Role):
return value
if isinstance(value, str):
normalized = value.strip()
if not normalized:
return Role.USER
return Role(value=normalized.lower())
return Role.USER

def to_dict(self) -> dict[str, Any]:
"""Convert to dictionary for JSON serialization."""
result = {
"message": self.message,
"enable_tool_calls": self.enable_tool_calls,
"role": self.role.value,
"request_response_format": self.request_response_format,
}
if self.response_format:
result["response_format"] = serialize_response_format(self.response_format)
if self.thread_id:
result["thread_id"] = self.thread_id
if self.correlation_id:
result["correlationId"] = self.correlation_id
if self.created_at:
result["created_at"] = self.created_at
if self.orchestration_id:
result["orchestrationId"] = self.orchestration_id

return result

@classmethod
def from_dict(cls, data: dict[str, Any]) -> RunRequest:
"""Create RunRequest from dictionary."""
return cls(
message=data.get("message", ""),
request_response_format=data.get("request_response_format", REQUEST_RESPONSE_FORMAT_TEXT),
role=cls.coerce_role(data.get("role")),
response_format=_deserialize_response_format(data.get("response_format")),
enable_tool_calls=data.get("enable_tool_calls", True),
thread_id=data.get("thread_id"),
correlation_id=data.get("correlationId"),
created_at=data.get("created_at"),
orchestration_id=data.get("orchestrationId"),
)
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,12 @@
ChatMessage,
get_logger,
)
from agent_framework_durabletask import RunRequest
from azure.durable_functions.models import TaskBase
from azure.durable_functions.models.Task import CompoundTask, TaskState
from pydantic import BaseModel

from ._models import AgentSessionId, DurableAgentThread, RunRequest
from ._models import AgentSessionId, DurableAgentThread

logger = get_logger("agent_framework.azurefunctions.orchestration")

Expand Down
1 change: 1 addition & 0 deletions python/packages/azurefunctions/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ classifiers = [
]
dependencies = [
"agent-framework-core",
"agent-framework-durabletask",
"azure-functions",
"azure-functions-durable",
]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,7 @@
"""

import pytest

from agent_framework_azurefunctions._constants import THREAD_ID_HEADER
from agent_framework_durabletask import THREAD_ID_HEADER

from .testutils import SampleTestHelper, skip_if_azure_functions_integration_tests_disabled

Expand Down
11 changes: 6 additions & 5 deletions python/packages/azurefunctions/tests/test_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,16 @@
import azure.functions as func
import pytest
from agent_framework import AgentRunResponse, ChatMessage, ErrorContent

from agent_framework_azurefunctions import AgentFunctionApp
from agent_framework_azurefunctions._app import WAIT_FOR_RESPONSE_FIELD, WAIT_FOR_RESPONSE_HEADER
from agent_framework_azurefunctions._constants import (
from agent_framework_durabletask import (
MIMETYPE_APPLICATION_JSON,
MIMETYPE_TEXT_PLAIN,
THREAD_ID_HEADER,
WAIT_FOR_RESPONSE_FIELD,
WAIT_FOR_RESPONSE_HEADER,
DurableAgentState,
)
from agent_framework_azurefunctions._durable_agent_state import DurableAgentState

from agent_framework_azurefunctions import AgentFunctionApp
from agent_framework_azurefunctions._entities import AgentEntity, create_agent_entity

TFunc = TypeVar("TFunc", bound=Callable[..., Any])
Expand Down
Loading
Loading