diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index 9eaf63841..50b121575 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -20,6 +20,7 @@ from pydantic import BaseModel from ..event_loop.event_loop import event_loop_cycle +from ..experimental.hooks import AgentInitializedEvent, EndRequestEvent, HookRegistry, StartRequestEvent from ..handlers.callback_handler import PrintingCallbackHandler, null_callback_handler from ..handlers.tool_handler import AgentToolHandler from ..models.bedrock import BedrockModel @@ -308,6 +309,10 @@ def __init__( self.name = name self.description = description + self._hooks = HookRegistry() + # Register built-in hook providers (like ConversationManager) here + self._hooks.invoke_callbacks(AgentInitializedEvent(agent=self)) + @property def tool(self) -> ToolCaller: """Call tool as a function. @@ -405,21 +410,26 @@ def structured_output(self, output_model: Type[T], prompt: Optional[str] = None) that the agent will use when responding. prompt: The prompt to use for the agent. """ - messages = self.messages - if not messages and not prompt: - raise ValueError("No conversation history or prompt provided") + self._hooks.invoke_callbacks(StartRequestEvent(agent=self)) - # add the prompt as the last message - if prompt: - messages.append({"role": "user", "content": [{"text": prompt}]}) + try: + messages = self.messages + if not messages and not prompt: + raise ValueError("No conversation history or prompt provided") - # get the structured output from the model - events = self.model.structured_output(output_model, messages) - for event in events: - if "callback" in event: - self.callback_handler(**cast(dict, event["callback"])) + # add the prompt as the last message + if prompt: + messages.append({"role": "user", "content": [{"text": prompt}]}) - return event["output"] + # get the structured output from the model + events = self.model.structured_output(output_model, messages) + for event in events: + if "callback" in event: + self.callback_handler(**cast(dict, event["callback"])) + + return event["output"] + finally: + self._hooks.invoke_callbacks(EndRequestEvent(agent=self)) async def stream_async(self, prompt: str, **kwargs: Any) -> AsyncIterator[Any]: """Process a natural language prompt and yield events as an async iterator. @@ -473,6 +483,8 @@ async def stream_async(self, prompt: str, **kwargs: Any) -> AsyncIterator[Any]: def _run_loop(self, prompt: str, kwargs: dict[str, Any]) -> Generator[dict[str, Any], None, None]: """Execute the agent's event loop with the given prompt and parameters.""" + self._hooks.invoke_callbacks(StartRequestEvent(agent=self)) + try: # Extract key parameters yield {"callback": {"init_event_loop": True, **kwargs}} @@ -487,6 +499,7 @@ def _run_loop(self, prompt: str, kwargs: dict[str, Any]) -> Generator[dict[str, finally: self.conversation_manager.apply_management(self) + self._hooks.invoke_callbacks(EndRequestEvent(agent=self)) def _execute_event_loop_cycle(self, kwargs: dict[str, Any]) -> Generator[dict[str, Any], None, None]: """Execute the event loop cycle with retry logic for context window limits. diff --git a/src/strands/experimental/__init__.py b/src/strands/experimental/__init__.py new file mode 100644 index 000000000..c40d0fcec --- /dev/null +++ b/src/strands/experimental/__init__.py @@ -0,0 +1,4 @@ +"""Experimental features. + +This module implements experimental features that are subject to change in future revisions without notice. +""" diff --git a/src/strands/experimental/hooks/__init__.py b/src/strands/experimental/hooks/__init__.py new file mode 100644 index 000000000..3ec805137 --- /dev/null +++ b/src/strands/experimental/hooks/__init__.py @@ -0,0 +1,43 @@ +"""Typed hook system for extending agent functionality. + +This module provides a composable mechanism for building objects that can hook +into specific events during the agent lifecycle. The hook system enables both +built-in SDK components and user code to react to or modify agent behavior +through strongly-typed event callbacks. + +Example Usage: + ```python + from strands.hooks import HookProvider, HookRegistry + from strands.hooks.events import StartRequestEvent, EndRequestEvent + + class LoggingHooks(HookProvider): + def register_hooks(self, registry: HookRegistry) -> None: + registry.add_callback(StartRequestEvent, self.log_start) + registry.add_callback(EndRequestEvent, self.log_end) + + def log_start(self, event: StartRequestEvent) -> None: + print(f"Request started for {event.agent.name}") + + def log_end(self, event: EndRequestEvent) -> None: + print(f"Request completed for {event.agent.name}") + + # Use with agent + agent = Agent(hooks=[LoggingHooks()]) + ``` + +This replaces the older callback_handler approach with a more composable, +type-safe system that supports multiple subscribers per event type. +""" + +from .events import AgentInitializedEvent, EndRequestEvent, StartRequestEvent +from .registry import HookCallback, HookEvent, HookProvider, HookRegistry + +__all__ = [ + "AgentInitializedEvent", + "StartRequestEvent", + "EndRequestEvent", + "HookEvent", + "HookProvider", + "HookCallback", + "HookRegistry", +] diff --git a/src/strands/experimental/hooks/events.py b/src/strands/experimental/hooks/events.py new file mode 100644 index 000000000..c42b82d54 --- /dev/null +++ b/src/strands/experimental/hooks/events.py @@ -0,0 +1,64 @@ +"""Hook events emitted as part of invoking Agents. + +This module defines the events that are emitted as Agents run through the lifecycle of a request. +""" + +from dataclasses import dataclass + +from .registry import HookEvent + + +@dataclass +class AgentInitializedEvent(HookEvent): + """Event triggered when an agent has finished initialization. + + This event is fired after the agent has been fully constructed and all + built-in components have been initialized. Hook providers can use this + event to perform setup tasks that require a fully initialized agent. + """ + + pass + + +@dataclass +class StartRequestEvent(HookEvent): + """Event triggered at the beginning of a new agent request. + + This event is fired when the agent begins processing a new user request, + before any model inference or tool execution occurs. Hook providers can + use this event to perform request-level setup, logging, or validation. + + This event is triggered at the beginning of the following api calls: + - Agent.__call__ + - Agent.stream_async + - Agent.structured_output + """ + + pass + + +@dataclass +class EndRequestEvent(HookEvent): + """Event triggered at the end of an agent request. + + This event is fired after the agent has completed processing a request, + regardless of whether it completed successfully or encountered an error. + Hook providers can use this event for cleanup, logging, or state persistence. + + Note: This event uses reverse callback ordering, meaning callbacks registered + later will be invoked first during cleanup. + + This event is triggered at the end of the following api calls: + - Agent.__call__ + - Agent.stream_async + - Agent.structured_output + """ + + @property + def should_reverse_callbacks(self) -> bool: + """Return True to invoke callbacks in reverse order for proper cleanup. + + Returns: + True, indicating callbacks should be invoked in reverse order. + """ + return True diff --git a/src/strands/experimental/hooks/registry.py b/src/strands/experimental/hooks/registry.py new file mode 100644 index 000000000..4b3eceb4b --- /dev/null +++ b/src/strands/experimental/hooks/registry.py @@ -0,0 +1,195 @@ +"""Hook registry system for managing event callbacks in the Strands Agent SDK. + +This module provides the core infrastructure for the typed hook system, enabling +composable extension of agent functionality through strongly-typed event callbacks. +The registry manages the mapping between event types and their associated callback +functions, supporting both individual callback registration and bulk registration +via hook provider objects. +""" + +from dataclasses import dataclass +from typing import TYPE_CHECKING, Callable, Generator, Generic, Protocol, Type, TypeVar + +if TYPE_CHECKING: + from ...agent import Agent + + +@dataclass +class HookEvent: + """Base class for all hook events. + + Attributes: + agent: The agent instance that triggered this event. + """ + + agent: "Agent" + + @property + def should_reverse_callbacks(self) -> bool: + """Determine if callbacks for this event should be invoked in reverse order. + + Returns: + False by default. Override to return True for events that should + invoke callbacks in reverse order (e.g., cleanup/teardown events). + """ + return False + + +T = TypeVar("T", bound=Callable) +TEvent = TypeVar("TEvent", bound=HookEvent, contravariant=True) + + +class HookProvider(Protocol): + """Protocol for objects that provide hook callbacks to an agent. + + Hook providers offer a composable way to extend agent functionality by + subscribing to various events in the agent lifecycle. This protocol enables + building reusable components that can hook into agent events. + + Example: + ```python + class MyHookProvider(HookProvider): + def register_hooks(self, registry: HookRegistry) -> None: + hooks.add_callback(StartRequestEvent, self.on_request_start) + hooks.add_callback(EndRequestEvent, self.on_request_end) + + agent = Agent(hooks=[MyHookProvider()]) + ``` + """ + + def register_hooks(self, registry: "HookRegistry") -> None: + """Register callback functions for specific event types. + + Args: + registry: The hook registry to register callbacks with. + """ + ... + + +class HookCallback(Protocol, Generic[TEvent]): + """Protocol for callback functions that handle hook events. + + Hook callbacks are functions that receive a single strongly-typed event + argument and perform some action in response. They should not return + values and any exceptions they raise will propagate to the caller. + + Example: + ```python + def my_callback(event: StartRequestEvent) -> None: + print(f"Request started for agent: {event.agent.name}") + ``` + """ + + def __call__(self, event: TEvent) -> None: + """Handle a hook event. + + Args: + event: The strongly-typed event to handle. + """ + ... + + +class HookRegistry: + """Registry for managing hook callbacks associated with event types. + + The HookRegistry maintains a mapping of event types to callback functions + and provides methods for registering callbacks and invoking them when + events occur. + + The registry handles callback ordering, including reverse ordering for + cleanup events, and provides type-safe event dispatching. + """ + + def __init__(self) -> None: + """Initialize an empty hook registry.""" + self._registered_callbacks: dict[Type, list[HookCallback]] = {} + + def add_callback(self, event_type: Type[TEvent], callback: HookCallback[TEvent]) -> None: + """Register a callback function for a specific event type. + + Args: + event_type: The class type of events this callback should handle. + callback: The callback function to invoke when events of this type occur. + + Example: + ```python + def my_handler(event: StartRequestEvent): + print("Request started") + + registry.add_callback(StartRequestEvent, my_handler) + ``` + """ + callbacks = self._registered_callbacks.setdefault(event_type, []) + callbacks.append(callback) + + def add_hook(self, hook: HookProvider) -> None: + """Register all callbacks from a hook provider. + + This method allows bulk registration of callbacks by delegating to + the hook provider's register_hooks method. This is the preferred + way to register multiple related callbacks. + + Args: + hook: The hook provider containing callbacks to register. + + Example: + ```python + class MyHooks(HookProvider): + def register_hooks(self, registry: HookRegistry): + registry.add_callback(StartRequestEvent, self.on_start) + registry.add_callback(EndRequestEvent, self.on_end) + + registry.add_hook(MyHooks()) + ``` + """ + hook.register_hooks(self) + + def invoke_callbacks(self, event: TEvent) -> None: + """Invoke all registered callbacks for the given event. + + This method finds all callbacks registered for the event's type and + invokes them in the appropriate order. For events with is_after_callback=True, + callbacks are invoked in reverse registration order. + + Args: + event: The event to dispatch to registered callbacks. + + Raises: + Any exceptions raised by callback functions will propagate to the caller. + + Example: + ```python + event = StartRequestEvent(agent=my_agent) + registry.invoke_callbacks(event) + ``` + """ + for callback in self.get_callbacks_for(event): + callback(event) + + def get_callbacks_for(self, event: TEvent) -> Generator[HookCallback[TEvent], None, None]: + """Get callbacks registered for the given event in the appropriate order. + + This method returns callbacks in registration order for normal events, + or reverse registration order for events that have is_after_callback=True. + This enables proper cleanup ordering for teardown events. + + Args: + event: The event to get callbacks for. + + Yields: + Callback functions registered for this event type, in the appropriate order. + + Example: + ```python + event = EndRequestEvent(agent=my_agent) + for callback in registry.get_callbacks_for(event): + callback(event) + ``` + """ + event_type = type(event) + + callbacks = self._registered_callbacks.get(event_type, []) + if event.should_reverse_callbacks: + yield from reversed(callbacks) + else: + yield from callbacks diff --git a/tests/fixtures/mock_hook_provider.py b/tests/fixtures/mock_hook_provider.py new file mode 100644 index 000000000..a21770a56 --- /dev/null +++ b/tests/fixtures/mock_hook_provider.py @@ -0,0 +1,16 @@ +from typing import Type + +from strands.experimental.hooks import HookEvent, HookProvider, HookRegistry + + +class MockHookProvider(HookProvider): + def __init__(self, event_types: list[Type]): + self.events_received = [] + self.events_types = event_types + + def register_hooks(self, registry: HookRegistry) -> None: + for event_type in self.events_types: + registry.add_callback(event_type, self._add_event) + + def _add_event(self, event: HookEvent) -> None: + self.events_received.append(event) diff --git a/tests/strands/agent/test_agent.py b/tests/strands/agent/test_agent.py index 0c644b044..7681194c7 100644 --- a/tests/strands/agent/test_agent.py +++ b/tests/strands/agent/test_agent.py @@ -4,6 +4,7 @@ import os import textwrap import unittest.mock +from unittest.mock import call import pytest from pydantic import BaseModel @@ -13,10 +14,12 @@ from strands.agent import AgentResult from strands.agent.conversation_manager.null_conversation_manager import NullConversationManager from strands.agent.conversation_manager.sliding_window_conversation_manager import SlidingWindowConversationManager +from strands.experimental.hooks import AgentInitializedEvent, EndRequestEvent, StartRequestEvent from strands.handlers.callback_handler import PrintingCallbackHandler, null_callback_handler from strands.models.bedrock import DEFAULT_BEDROCK_MODEL_ID, BedrockModel from strands.types.content import Messages from strands.types.exceptions import ContextWindowOverflowException, EventLoopException +from tests.fixtures.mock_hook_provider import MockHookProvider @pytest.fixture @@ -37,6 +40,34 @@ def converse(*args, **kwargs): return mock +@pytest.fixture +def mock_hook_messages(mock_model, tool): + """Fixture which returns a standard set of events for verifying hooks.""" + mock_model.mock_converse.side_effect = [ + [ + { + "contentBlockStart": { + "start": { + "toolUse": { + "toolUseId": "t1", + "name": tool.tool_spec["name"], + }, + }, + }, + }, + {"contentBlockDelta": {"delta": {"toolUse": {"input": '{"random_string": "abcdEfghI123"}'}}}}, + {"contentBlockStop": {}}, + {"messageStop": {"stopReason": "tool_use"}}, + ], + [ + {"contentBlockDelta": {"delta": {"text": "test text"}}}, + {"contentBlockStop": {}}, + ], + ] + + return mock_model.mock_converse + + @pytest.fixture def system_prompt(request): return request.param if hasattr(request, "param") else "You are a helpful assistant." @@ -131,6 +162,11 @@ def tools(request, tool): return request.param if hasattr(request, "param") else [tool_decorated] +@pytest.fixture +def hook_provider(): + return MockHookProvider([AgentInitializedEvent, StartRequestEvent, EndRequestEvent]) + + @pytest.fixture def agent( mock_model, @@ -142,6 +178,7 @@ def agent( tool_registry, tool_decorated, request, + hook_provider, ): agent = Agent( model=mock_model, @@ -151,6 +188,9 @@ def agent( tools=tools, ) + # for now, hooks are private + agent._hooks.add_hook(hook_provider) + # Only register the tool directly if tools wasn't parameterized if not hasattr(request, "param") or request.param is None: # Create a new function tool directly from the decorated function @@ -683,6 +723,48 @@ def test_agent__call__callback(mock_model, agent, callback_handler): ) +@unittest.mock.patch("strands.experimental.hooks.registry.HookRegistry.invoke_callbacks") +def test_agent_hooks__init__(mock_invoke_callbacks): + """Verify that the AgentInitializedEvent is emitted on Agent construction.""" + agent = Agent() + + # Verify AgentInitialized event was invoked + mock_invoke_callbacks.assert_called_once() + assert mock_invoke_callbacks.call_args == call(AgentInitializedEvent(agent=agent)) + + +def test_agent_hooks__call__(agent, mock_hook_messages, hook_provider): + """Verify that the correct hook events are emitted as part of __call__.""" + + agent("test message") + + assert hook_provider.events_received == [StartRequestEvent(agent=agent), EndRequestEvent(agent=agent)] + + +@pytest.mark.asyncio +async def test_agent_hooks_stream_async(agent, mock_hook_messages, hook_provider): + """Verify that the correct hook events are emitted as part of stream_async.""" + iterator = agent.stream_async("test message") + await anext(iterator) + assert hook_provider.events_received == [StartRequestEvent(agent=agent)] + + # iterate the rest + async for _ in iterator: + pass + + assert hook_provider.events_received == [StartRequestEvent(agent=agent), EndRequestEvent(agent=agent)] + + +def test_agent_hooks_structured_output(agent, mock_hook_messages, hook_provider): + """Verify that the correct hook events are emitted as part of structured_output.""" + + expected_user = User(name="Jane Doe", age=30, email="jane@doe.com") + agent.model.structured_output = unittest.mock.Mock(return_value=[{"output": expected_user}]) + agent.structured_output(User, "example prompt") + + assert hook_provider.events_received == [StartRequestEvent(agent=agent), EndRequestEvent(agent=agent)] + + def test_agent_tool(mock_randint, agent): conversation_manager_spy = unittest.mock.Mock(wraps=agent.conversation_manager) agent.conversation_manager = conversation_manager_spy diff --git a/tests/strands/experimental/__init__.py b/tests/strands/experimental/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/strands/experimental/hooks/__init__.py b/tests/strands/experimental/hooks/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/strands/experimental/hooks/test_hook_registry.py b/tests/strands/experimental/hooks/test_hook_registry.py new file mode 100644 index 000000000..0bed07add --- /dev/null +++ b/tests/strands/experimental/hooks/test_hook_registry.py @@ -0,0 +1,152 @@ +import unittest.mock +from dataclasses import dataclass +from typing import List +from unittest.mock import MagicMock, Mock + +import pytest + +from strands.experimental.hooks import HookEvent, HookProvider, HookRegistry + + +@dataclass +class TestEvent(HookEvent): + @property + def should_reverse_callbacks(self) -> bool: + return False + + +@dataclass +class TestAfterEvent(HookEvent): + @property + def should_reverse_callbacks(self) -> bool: + return True + + +class TestHookProvider(HookProvider): + """Test hook provider for testing hook registry.""" + + def __init__(self): + self.registered = False + + def register_hooks(self, registry: HookRegistry) -> None: + self.registered = True + + +@pytest.fixture +def hook_registry(): + return HookRegistry() + + +@pytest.fixture +def test_event(): + return TestEvent(agent=Mock()) + + +@pytest.fixture +def test_after_event(): + return TestAfterEvent(agent=Mock()) + + +def test_hook_registry_init(): + """Test that HookRegistry initializes with an empty callbacks dictionary.""" + registry = HookRegistry() + assert registry._registered_callbacks == {} + + +def test_add_callback(hook_registry, test_event): + """Test that callbacks can be added to the registry.""" + callback = unittest.mock.Mock() + hook_registry.add_callback(TestEvent, callback) + + assert TestEvent in hook_registry._registered_callbacks + assert callback in hook_registry._registered_callbacks[TestEvent] + + +def test_add_multiple_callbacks_same_event(hook_registry, test_event): + """Test that multiple callbacks can be added for the same event type.""" + callback1 = unittest.mock.Mock() + callback2 = unittest.mock.Mock() + + hook_registry.add_callback(TestEvent, callback1) + hook_registry.add_callback(TestEvent, callback2) + + assert len(hook_registry._registered_callbacks[TestEvent]) == 2 + assert callback1 in hook_registry._registered_callbacks[TestEvent] + assert callback2 in hook_registry._registered_callbacks[TestEvent] + + +def test_add_hook(hook_registry): + """Test that hooks can be added to the registry.""" + hook_provider = MagicMock() + hook_registry.add_hook(hook_provider) + + assert hook_provider.register_hooks.call_count == 1 + + +def test_get_callbacks_for_normal_event(hook_registry, test_event): + """Test that get_callbacks_for returns callbacks in the correct order for normal events.""" + callback1 = unittest.mock.Mock() + callback2 = unittest.mock.Mock() + + hook_registry.add_callback(TestEvent, callback1) + hook_registry.add_callback(TestEvent, callback2) + + callbacks = list(hook_registry.get_callbacks_for(test_event)) + + assert len(callbacks) == 2 + assert callbacks[0] == callback1 + assert callbacks[1] == callback2 + + +def test_get_callbacks_for_after_event(hook_registry, test_after_event): + """Test that get_callbacks_for returns callbacks in reverse order for after events.""" + callback1 = Mock() + callback2 = Mock() + + hook_registry.add_callback(TestAfterEvent, callback1) + hook_registry.add_callback(TestAfterEvent, callback2) + + callbacks = list(hook_registry.get_callbacks_for(test_after_event)) + + assert len(callbacks) == 2 + assert callbacks[0] == callback2 # Reverse order + assert callbacks[1] == callback1 # Reverse order + + +def test_invoke_callbacks(hook_registry, test_event): + """Test that invoke_callbacks calls all registered callbacks for an event.""" + callback1 = Mock() + callback2 = Mock() + + hook_registry.add_callback(TestEvent, callback1) + hook_registry.add_callback(TestEvent, callback2) + + hook_registry.invoke_callbacks(test_event) + + callback1.assert_called_once_with(test_event) + callback2.assert_called_once_with(test_event) + + +def test_invoke_callbacks_no_registered_callbacks(hook_registry, test_event): + """Test that invoke_callbacks doesn't fail when there are no registered callbacks.""" + # No callbacks registered + hook_registry.invoke_callbacks(test_event) + # Test passes if no exception is raised + + +def test_invoke_callbacks_after_event(hook_registry, test_after_event): + """Test that invoke_callbacks calls callbacks in reverse order for after events.""" + call_order: List[str] = [] + + def callback1(_event): + call_order.append("callback1") + + def callback2(_event): + call_order.append("callback2") + + hook_registry.add_callback(TestAfterEvent, callback1) + hook_registry.add_callback(TestAfterEvent, callback2) + + hook_registry.invoke_callbacks(test_after_event) + + assert call_order == ["callback2", "callback1"] # Reverse order