diff --git a/.gitignore b/.gitignore index 8b0fd989c..0b1375b50 100644 --- a/.gitignore +++ b/.gitignore @@ -14,3 +14,4 @@ repl_state .kiro uv.lock .audio_cache +CLAUDE.md diff --git a/src/strands/experimental/steering/__init__.py b/src/strands/experimental/steering/__init__.py index 4d0775873..c9e1a470d 100644 --- a/src/strands/experimental/steering/__init__.py +++ b/src/strands/experimental/steering/__init__.py @@ -23,7 +23,7 @@ LedgerBeforeToolCall, LedgerProvider, ) -from .core.action import Guide, Interrupt, Proceed, SteeringAction +from .core.action import Guide, Interrupt, ModelSteeringAction, Proceed, SteeringAction, ToolSteeringAction from .core.context import SteeringContextCallback, SteeringContextProvider from .core.handler import SteeringHandler @@ -32,6 +32,8 @@ __all__ = [ "SteeringAction", + "ToolSteeringAction", + "ModelSteeringAction", "Proceed", "Guide", "Interrupt", diff --git a/src/strands/experimental/steering/core/action.py b/src/strands/experimental/steering/core/action.py index 8b4ec141d..57b059e87 100644 --- a/src/strands/experimental/steering/core/action.py +++ b/src/strands/experimental/steering/core/action.py @@ -1,18 +1,18 @@ """SteeringAction types for steering evaluation results. -Defines structured outcomes from steering handlers that determine how tool calls +Defines structured outcomes from steering handlers that determine how agent actions should be handled. SteeringActions enable modular prompting by providing just-in-time feedback rather than front-loading all instructions in monolithic prompts. 
Flow: - SteeringHandler.steer() → SteeringAction → BeforeToolCallEvent handling - ↓ ↓ ↓ - Evaluate context Action type Tool execution modified + SteeringHandler.steer_*() → SteeringAction → Event handling + ↓ ↓ ↓ + Evaluate context Action type Execution modified SteeringAction types: - Proceed: Tool executes immediately (no intervention needed) - Guide: Tool cancelled, agent receives contextual feedback to explore alternatives - Interrupt: Tool execution paused for human input via interrupt system + Proceed: Allow execution to continue without intervention + Guide: Provide contextual guidance to redirect the agent + Interrupt: Pause execution for human input Extensibility: New action types can be added to the union. Always handle the default @@ -25,9 +25,9 @@ class Proceed(BaseModel): - """Allow tool to execute immediately without intervention. + """Allow execution to continue without intervention. - The tool call proceeds as planned. The reason provides context + The action proceeds as planned. The reason provides context for logging and debugging purposes. """ @@ -36,11 +36,11 @@ class Proceed(BaseModel): class Guide(BaseModel): - """Cancel tool and provide contextual feedback for agent to explore alternatives. + """Provide contextual guidance to redirect the agent. - The tool call is cancelled and the agent receives the reason as contextual - feedback to help them consider alternative approaches while maintaining - adaptive reasoning capabilities. + The agent receives the reason as contextual feedback to help guide + its behavior. The specific handling depends on the steering context + (e.g., tool call vs. model response). """ type: Literal["guide"] = "guide" @@ -48,18 +48,38 @@ class Guide(BaseModel): class Interrupt(BaseModel): - """Pause tool execution for human input via interrupt system. + """Pause execution for human input via interrupt system. 
- The tool call is paused and human input is requested through Strands' + Execution is paused and human input is requested through Strands' interrupt system. The human can approve or deny the operation, and their - decision determines whether the tool executes or is cancelled. + decision determines whether execution continues or is cancelled. """ type: Literal["interrupt"] = "interrupt" reason: str -# SteeringAction union - extensible for future action types +# Context-specific steering action types +ToolSteeringAction = Annotated[Proceed | Guide | Interrupt, Field(discriminator="type")] +"""Steering actions valid for tool steering (steer_before_tool). + +- Proceed: Allow tool execution to continue +- Guide: Cancel tool and provide feedback for alternative approaches +- Interrupt: Pause for human input before tool execution +""" + +ModelSteeringAction = Annotated[Proceed | Guide, Field(discriminator="type")] +"""Steering actions valid for model steering (steer_after_model). + +- Proceed: Accept model response without modification +- Guide: Discard model response and retry with guidance +""" + +# Generic SteeringAction union for backward compatibility # IMPORTANT: Always handle the default case when pattern matching # to maintain backward compatibility as new action types are added SteeringAction = Annotated[Proceed | Guide | Interrupt, Field(discriminator="type")] +"""Generic steering action type for backward compatibility. + +Use ToolSteeringAction or ModelSteeringAction for type-safe context-specific steering. +""" diff --git a/src/strands/experimental/steering/core/handler.py b/src/strands/experimental/steering/core/handler.py index 4a0bcaa6a..30287242f 100644 --- a/src/strands/experimental/steering/core/handler.py +++ b/src/strands/experimental/steering/core/handler.py @@ -2,38 +2,48 @@ Provides modular prompting through contextual guidance that appears when relevant, rather than front-loading all instructions. 
Handlers integrate with the Strands hook -system to intercept tool calls and provide just-in-time feedback based on local context. +system to intercept actions and provide just-in-time feedback based on local context. Architecture: - BeforeToolCallEvent → Context Callbacks → Update steering_context → steer() → SteeringAction - ↓ ↓ ↓ ↓ ↓ - Hook triggered Populate context Handler evaluates Handler decides Action taken + Hook Event → Context Callbacks → Update steering_context → steer_*() → SteeringAction + ↓ ↓ ↓ ↓ ↓ + Hook triggered Populate context Handler evaluates Handler decides Action taken Lifecycle: 1. Context callbacks update handler's steering_context on hook events - 2. BeforeToolCallEvent triggers steering evaluation via steer() method - 3. Handler accesses self.steering_context for guidance decisions - 4. SteeringAction determines tool execution: Proceed/Guide/Interrupt + 2. BeforeToolCallEvent triggers steer_before_tool() for tool steering + 3. AfterModelCallEvent triggers steer_after_model() for model steering + 4. Handler accesses self.steering_context for guidance decisions + 5. SteeringAction determines execution flow Implementation: - Subclass SteeringHandler and implement steer() method. - Pass context_callbacks in constructor to register context update functions. + Subclass SteeringHandler and override steer_before_tool() and/or steer_after_model(). + Both methods have default implementations that return Proceed, so you only need to + override the methods you want to customize. + Pass context_providers in constructor to register context update functions. Each handler maintains isolated steering_context that persists across calls. 
-SteeringAction handling: +SteeringAction handling for steer_before_tool: Proceed: Tool executes immediately Guide: Tool cancelled, agent receives contextual feedback to explore alternatives Interrupt: Tool execution paused for human input via interrupt system + +SteeringAction handling for steer_after_model: + Proceed: Model response accepted without modification + Guide: Discard model response and retry (message is dropped, model is called again) + Interrupt: Not supported for model steering (raises ValueError; use Proceed or Guide) """ import logging -from abc import ABC, abstractmethod +from abc import ABC from typing import TYPE_CHECKING, Any -from ....hooks.events import BeforeToolCallEvent +from ....hooks.events import AfterModelCallEvent, BeforeToolCallEvent from ....hooks.registry import HookProvider, HookRegistry +from ....types.content import Message +from ....types.streaming import StopReason from ....types.tools import ToolUse -from .action import Guide, Interrupt, Proceed, SteeringAction +from .action import Guide, Interrupt, ModelSteeringAction, Proceed, SteeringAction, ToolSteeringAction from .context import SteeringContext, SteeringContextProvider if TYPE_CHECKING: @@ -73,24 +83,27 @@ def register_hooks(self, registry: HookRegistry, **kwargs: Any) -> None: callback.event_type, lambda event, callback=callback: callback(event, self.steering_context) ) - # Register steering guidance - registry.add_callback(BeforeToolCallEvent, self._provide_steering_guidance) + # Register tool steering guidance + registry.add_callback(BeforeToolCallEvent, self._provide_tool_steering_guidance) + + # Register model steering guidance + registry.add_callback(AfterModelCallEvent, self._provide_model_steering_guidance) - async def _provide_steering_guidance(self, event: BeforeToolCallEvent) -> None: + async def _provide_tool_steering_guidance(self, event: BeforeToolCallEvent) -> None: """Provide steering guidance for tool call.""" tool_name = event.tool_use["name"] - 
logger.debug("tool_name=<%s> | providing steering guidance", tool_name) + logger.debug("tool_name=<%s> | providing tool steering guidance", tool_name) try: - action = await self.steer(event.agent, event.tool_use) + action = await self.steer_before_tool(agent=event.agent, tool_use=event.tool_use) except Exception as e: - logger.debug("tool_name=<%s>, error=<%s> | steering handler guidance failed", tool_name, e) + logger.debug("tool_name=<%s>, error=<%s> | tool steering handler guidance failed", tool_name, e) return - self._handle_steering_action(action, event, tool_name) + self._handle_tool_steering_action(action, event, tool_name) - def _handle_steering_action(self, action: SteeringAction, event: BeforeToolCallEvent, tool_name: str) -> None: - """Handle the steering action by modifying tool execution flow. + def _handle_tool_steering_action(self, action: SteeringAction, event: BeforeToolCallEvent, tool_name: str) -> None: + """Handle the steering action for tool calls by modifying tool execution flow. 
Proceed: Tool executes normally Guide: Tool cancelled with contextual feedback for agent to consider alternatives @@ -114,11 +127,52 @@ def _handle_steering_action(self, action: SteeringAction, event: BeforeToolCallE else: logger.debug("tool_name=<%s> | tool call approved manually", tool_name) else: - raise ValueError(f"Unknown steering action type: {action}") + raise ValueError(f"Unknown steering action type for tool call: {action}") + + async def _provide_model_steering_guidance(self, event: AfterModelCallEvent) -> None: + """Provide steering guidance for model response.""" + logger.debug("providing model steering guidance") + + # Only steer on successful model responses + if event.stop_response is None: + logger.debug("no stop response available | skipping model steering") + return + + try: + action = await self.steer_after_model( + agent=event.agent, message=event.stop_response.message, stop_reason=event.stop_response.stop_reason + ) + except Exception as e: + logger.debug("error=<%s> | model steering handler guidance failed", e) + return + + await self._handle_model_steering_action(action, event) + + async def _handle_model_steering_action(self, action: ModelSteeringAction, event: AfterModelCallEvent) -> None: + """Handle the steering action for model responses by modifying response handling flow. - @abstractmethod - async def steer(self, agent: "Agent", tool_use: ToolUse, **kwargs: Any) -> SteeringAction: - """Provide contextual guidance to help agent navigate complex workflows. 
+ Proceed: Model response accepted without modification + Guide: Discard model response and retry with guidance message added to conversation + """ + if isinstance(action, Proceed): + logger.debug("model response proceeding") + elif isinstance(action, Guide): + logger.debug("model response guided (retrying): %s", action.reason) + # Set retry flag to discard current response + event.retry = True + # Add guidance message to agent's conversation so model sees it on retry + await event.agent._append_messages({"role": "user", "content": [{"text": action.reason}]}) + logger.debug("added guidance message to conversation for model retry") + else: + raise ValueError(f"Unknown steering action type for model response: {action}") + + async def steer_before_tool(self, *, agent: "Agent", tool_use: ToolUse, **kwargs: Any) -> ToolSteeringAction: + """Provide contextual guidance before tool execution. + + This method is called before a tool is executed, allowing the handler to: + - Proceed: Allow tool execution to continue + - Guide: Cancel tool and provide feedback for alternative approaches + - Interrupt: Pause for human input before tool execution Args: agent: The agent instance @@ -126,9 +180,38 @@ async def steer(self, agent: "Agent", tool_use: ToolUse, **kwargs: Any) -> Steer **kwargs: Additional keyword arguments for guidance evaluation Returns: - SteeringAction indicating how to guide the agent's next action + ToolSteeringAction indicating how to guide the tool execution + + Note: + Access steering context via self.steering_context + Default implementation returns Proceed (allow tool execution) + Override this method to implement custom tool steering logic + """ + return Proceed(reason="Default implementation: allowing tool execution") + + async def steer_after_model( + self, *, agent: "Agent", message: Message, stop_reason: StopReason, **kwargs: Any + ) -> ModelSteeringAction: + """Provide contextual guidance after model response. 
+ + This method is called after the model generates a response, allowing the handler to: + - Proceed: Accept the model response without modification + - Guide: Discard the response and retry (message is dropped, model is called again) + + Note: Interrupt is not supported for model steering as the model has already responded. + + Args: + agent: The agent instance + message: The model's generated message + stop_reason: The reason the model stopped generating + **kwargs: Additional keyword arguments for guidance evaluation + + Returns: + ModelSteeringAction indicating how to handle the model response Note: Access steering context via self.steering_context + Default implementation returns Proceed (accept response as-is) + Override this method to implement custom model steering logic """ - ... + return Proceed(reason="Default implementation: accepting model response") diff --git a/src/strands/experimental/steering/handlers/llm/llm_handler.py b/src/strands/experimental/steering/handlers/llm/llm_handler.py index 9d9b34911..8b0630a0b 100644 --- a/src/strands/experimental/steering/handlers/llm/llm_handler.py +++ b/src/strands/experimental/steering/handlers/llm/llm_handler.py @@ -10,7 +10,7 @@ from .....models import Model from .....types.tools import ToolUse from ...context_providers.ledger_provider import LedgerProvider -from ...core.action import Guide, Interrupt, Proceed, SteeringAction +from ...core.action import Guide, Interrupt, Proceed, ToolSteeringAction from ...core.context import SteeringContextProvider from ...core.handler import SteeringHandler from .mappers import DefaultPromptMapper, LLMPromptMapper @@ -58,7 +58,7 @@ def __init__( self.prompt_mapper = prompt_mapper or DefaultPromptMapper() self.model = model - async def steer(self, agent: "Agent", tool_use: ToolUse, **kwargs: Any) -> SteeringAction: + async def steer_before_tool(self, *, agent: "Agent", tool_use: ToolUse, **kwargs: Any) -> ToolSteeringAction: """Provide contextual guidance for tool usage. 
Args: @@ -67,7 +67,7 @@ async def steer(self, agent: "Agent", tool_use: ToolUse, **kwargs: Any) -> Steer **kwargs: Additional keyword arguments for steering evaluation Returns: - SteeringAction indicating how to guide the agent's next action + SteeringAction indicating how to guide the tool execution """ # Generate steering prompt prompt = self.prompt_mapper.create_steering_prompt(self.steering_context, tool_use=tool_use) @@ -91,5 +91,5 @@ async def steer(self, agent: "Agent", tool_use: ToolUse, **kwargs: Any) -> Steer case "interrupt": return Interrupt(reason=llm_result.reason) case _: - logger.warning("decision=<%s> | uŹknown llm decision, defaulting to proceed", llm_result.decision) # type: ignore[unreachable] + logger.warning("decision=<%s> | unknown llm decision, defaulting to proceed", llm_result.decision) # type: ignore[unreachable] return Proceed(reason="Unknown LLM decision, defaulting to proceed") diff --git a/tests/strands/experimental/steering/core/test_handler.py b/tests/strands/experimental/steering/core/test_handler.py index 8d5ef6884..6080d0a06 100644 --- a/tests/strands/experimental/steering/core/test_handler.py +++ b/tests/strands/experimental/steering/core/test_handler.py @@ -1,20 +1,20 @@ """Unit tests for steering handler base class.""" -from unittest.mock import Mock +from unittest.mock import AsyncMock, Mock import pytest from strands.experimental.steering.core.action import Guide, Interrupt, Proceed from strands.experimental.steering.core.context import SteeringContext, SteeringContextCallback, SteeringContextProvider from strands.experimental.steering.core.handler import SteeringHandler -from strands.hooks.events import BeforeToolCallEvent +from strands.hooks.events import AfterModelCallEvent, BeforeToolCallEvent from strands.hooks.registry import HookRegistry class TestSteeringHandler(SteeringHandler): """Test implementation of SteeringHandler.""" - async def steer(self, agent, tool_use, **kwargs): + async def steer_before_tool(self, *, 
agent, tool_use, **kwargs): return Proceed(reason="Test proceed") @@ -31,9 +31,9 @@ def test_register_hooks(): handler.register_hooks(registry) - # Verify hooks were registered - assert registry.add_callback.call_count >= 1 - registry.add_callback.assert_any_call(BeforeToolCallEvent, handler._provide_steering_guidance) + # Verify hooks were registered (tool and model steering hooks) + assert registry.add_callback.call_count >= 2 + registry.add_callback.assert_any_call(BeforeToolCallEvent, handler._provide_tool_steering_guidance) def test_steering_context_initialization(): @@ -65,7 +65,7 @@ async def test_proceed_action_flow(): """Test complete flow with Proceed action.""" class ProceedHandler(SteeringHandler): - async def steer(self, agent, tool_use, **kwargs): + async def steer_before_tool(self, *, agent, tool_use, **kwargs): return Proceed(reason="Test proceed") handler = ProceedHandler() @@ -73,7 +73,7 @@ async def steer(self, agent, tool_use, **kwargs): tool_use = {"name": "test_tool"} event = BeforeToolCallEvent(agent=agent, selected_tool=None, tool_use=tool_use, invocation_state={}) - await handler._provide_steering_guidance(event) + await handler._provide_tool_steering_guidance(event) # Should not modify event for Proceed assert not event.cancel_tool @@ -84,7 +84,7 @@ async def test_guide_action_flow(): """Test complete flow with Guide action.""" class GuideHandler(SteeringHandler): - async def steer(self, agent, tool_use, **kwargs): + async def steer_before_tool(self, *, agent, tool_use, **kwargs): return Guide(reason="Test guidance") handler = GuideHandler() @@ -92,7 +92,7 @@ async def steer(self, agent, tool_use, **kwargs): tool_use = {"name": "test_tool"} event = BeforeToolCallEvent(agent=agent, selected_tool=None, tool_use=tool_use, invocation_state={}) - await handler._provide_steering_guidance(event) + await handler._provide_tool_steering_guidance(event) # Should set cancel_tool with guidance message expected_message = "Tool call cancelled given new 
guidance. Test guidance. Consider this approach and continue" @@ -104,7 +104,7 @@ async def test_interrupt_action_approved_flow(): """Test complete flow with Interrupt action when approved.""" class InterruptHandler(SteeringHandler): - async def steer(self, agent, tool_use, **kwargs): + async def steer_before_tool(self, *, agent, tool_use, **kwargs): return Interrupt(reason="Need approval") handler = InterruptHandler() @@ -113,7 +113,7 @@ async def steer(self, agent, tool_use, **kwargs): event.tool_use = tool_use event.interrupt = Mock(return_value=True) # Approved - await handler._provide_steering_guidance(event) + await handler._provide_tool_steering_guidance(event) event.interrupt.assert_called_once() @@ -123,7 +123,7 @@ async def test_interrupt_action_denied_flow(): """Test complete flow with Interrupt action when denied.""" class InterruptHandler(SteeringHandler): - async def steer(self, agent, tool_use, **kwargs): + async def steer_before_tool(self, *, agent, tool_use, **kwargs): return Interrupt(reason="Need approval") handler = InterruptHandler() @@ -132,7 +132,7 @@ async def steer(self, agent, tool_use, **kwargs): event.tool_use = tool_use event.interrupt = Mock(return_value=False) # Denied - await handler._provide_steering_guidance(event) + await handler._provide_tool_steering_guidance(event) event.interrupt.assert_called_once() assert event.cancel_tool.startswith("Manual approval denied:") @@ -143,7 +143,7 @@ async def test_unknown_action_flow(): """Test complete flow with unknown action type raises error.""" class UnknownActionHandler(SteeringHandler): - async def steer(self, agent, tool_use, **kwargs): + async def steer_before_tool(self, *, agent, tool_use, **kwargs): return Mock() # Not a valid SteeringAction handler = UnknownActionHandler() @@ -152,14 +152,14 @@ async def steer(self, agent, tool_use, **kwargs): event = BeforeToolCallEvent(agent=agent, selected_tool=None, tool_use=tool_use, invocation_state={}) with pytest.raises(ValueError, 
match="Unknown steering action type"): - await handler._provide_steering_guidance(event) + await handler._provide_tool_steering_guidance(event) def test_register_steering_hooks_override(): """Test that _register_steering_hooks can be overridden.""" class CustomHandler(SteeringHandler): - async def steer(self, agent, tool_use, **kwargs): + async def steer_before_tool(self, *, agent, tool_use, **kwargs): return Proceed(reason="Custom") def register_hooks(self, registry, **kwargs): @@ -200,7 +200,7 @@ def __init__(self, context_callbacks=None): providers = [MockContextProvider(context_callbacks)] if context_callbacks else None super().__init__(context_providers=providers) - async def steer(self, agent, tool_use, **kwargs): + async def steer_before_tool(self, *, agent, tool_use, **kwargs): return Proceed(reason="Test proceed") @@ -260,8 +260,8 @@ def test_multiple_context_callbacks_registered(): handler.register_hooks(registry) - # Should register one callback for each context provider plus steering guidance - expected_calls = 2 + 1 # 2 callbacks + 1 for steering guidance + # Should register one callback for each context provider plus tool and model steering guidance + expected_calls = 2 + 2 # 2 callbacks + 2 for steering guidance (tool and model) assert registry.add_callback.call_count >= expected_calls @@ -276,3 +276,187 @@ def test_handler_initialization_with_callbacks(): assert len(handler._context_callbacks) == 2 assert callback1 in handler._context_callbacks assert callback2 in handler._context_callbacks + + +# Model steering tests +@pytest.mark.asyncio +async def test_model_steering_proceed_action_flow(): + """Test model steering with Proceed action.""" + + class ModelProceedHandler(SteeringHandler): + async def steer_after_model(self, *, agent, message, stop_reason, **kwargs): + return Proceed(reason="Model response accepted") + + handler = ModelProceedHandler() + agent = Mock() + stop_response = Mock() + stop_response.message = {"role": "assistant", "content": 
[{"text": "Hello"}]} + stop_response.stop_reason = "end_turn" + event = Mock(spec=AfterModelCallEvent) + event.agent = agent + event.stop_response = stop_response + event.retry = False + + await handler._provide_model_steering_guidance(event) + + # Should not set retry for Proceed + assert event.retry is False + + +@pytest.mark.asyncio +async def test_model_steering_guide_action_flow(): + """Test model steering with Guide action sets retry and adds message.""" + + class ModelGuideHandler(SteeringHandler): + async def steer_after_model(self, *, agent, message, stop_reason, **kwargs): + return Guide(reason="Please improve your response") + + handler = ModelGuideHandler() + agent = AsyncMock() + stop_response = Mock() + stop_response.message = {"role": "assistant", "content": [{"text": "Hello"}]} + stop_response.stop_reason = "end_turn" + event = Mock(spec=AfterModelCallEvent) + event.agent = agent + event.stop_response = stop_response + event.retry = False + + await handler._provide_model_steering_guidance(event) + + # Should set retry flag + assert event.retry is True + # Should add guidance message to conversation + agent._append_messages.assert_called_once() + call_args = agent._append_messages.call_args[0][0] + assert call_args["role"] == "user" + assert "Please improve your response" in call_args["content"][0]["text"] + + +@pytest.mark.asyncio +async def test_model_steering_skips_when_no_stop_response(): + """Test model steering skips when stop_response is None.""" + + class ModelProceedHandler(SteeringHandler): + def __init__(self): + super().__init__() + self.steer_called = False + + async def steer_after_model(self, *, agent, message, stop_reason, **kwargs): + self.steer_called = True + return Proceed(reason="Should not be called") + + handler = ModelProceedHandler() + event = Mock(spec=AfterModelCallEvent) + event.stop_response = None + + await handler._provide_model_steering_guidance(event) + + # steer_after_model should not have been called + assert 
handler.steer_called is False + + +@pytest.mark.asyncio +async def test_model_steering_unknown_action_raises_error(): + """Test model steering with unknown action type raises error.""" + + class UnknownModelActionHandler(SteeringHandler): + async def steer_after_model(self, *, agent, message, stop_reason, **kwargs): + return Mock() # Not a valid ModelSteeringAction + + handler = UnknownModelActionHandler() + agent = Mock() + stop_response = Mock() + stop_response.message = {"role": "assistant", "content": [{"text": "Hello"}]} + stop_response.stop_reason = "end_turn" + event = Mock(spec=AfterModelCallEvent) + event.agent = agent + event.stop_response = stop_response + + with pytest.raises(ValueError, match="Unknown steering action type for model response"): + await handler._provide_model_steering_guidance(event) + + +@pytest.mark.asyncio +async def test_model_steering_exception_handling(): + """Test model steering handles exceptions gracefully.""" + + class ExceptionModelHandler(SteeringHandler): + async def steer_after_model(self, *, agent, message, stop_reason, **kwargs): + raise RuntimeError("Test exception") + + handler = ExceptionModelHandler() + agent = Mock() + stop_response = Mock() + stop_response.message = {"role": "assistant", "content": [{"text": "Hello"}]} + stop_response.stop_reason = "end_turn" + event = Mock(spec=AfterModelCallEvent) + event.agent = agent + event.stop_response = stop_response + event.retry = False + + # Should not raise, just return early + await handler._provide_model_steering_guidance(event) + + # retry should not be set since exception occurred + assert event.retry is False + + +@pytest.mark.asyncio +async def test_tool_steering_exception_handling(): + """Test tool steering handles exceptions gracefully.""" + + class ExceptionToolHandler(SteeringHandler): + async def steer_before_tool(self, *, agent, tool_use, **kwargs): + raise RuntimeError("Test exception") + + handler = ExceptionToolHandler() + agent = Mock() + tool_use = 
{"name": "test_tool"} + event = BeforeToolCallEvent(agent=agent, selected_tool=None, tool_use=tool_use, invocation_state={}) + + # Should not raise, just return early + await handler._provide_tool_steering_guidance(event) + + # cancel_tool should not be set since exception occurred + assert not event.cancel_tool + + +# Default implementation tests +@pytest.mark.asyncio +async def test_default_steer_before_tool_returns_proceed(): + """Test default steer_before_tool returns Proceed.""" + handler = TestSteeringHandler() + agent = Mock() + tool_use = {"name": "test_tool"} + + # Call the parent's default implementation + result = await SteeringHandler.steer_before_tool(handler, agent=agent, tool_use=tool_use) + + assert isinstance(result, Proceed) + assert "Default implementation" in result.reason + + +@pytest.mark.asyncio +async def test_default_steer_after_model_returns_proceed(): + """Test default steer_after_model returns Proceed.""" + handler = TestSteeringHandler() + agent = Mock() + message = {"role": "assistant", "content": [{"text": "Hello"}]} + stop_reason = "end_turn" + + # Call the parent's default implementation + result = await SteeringHandler.steer_after_model(handler, agent=agent, message=message, stop_reason=stop_reason) + + assert isinstance(result, Proceed) + assert "Default implementation" in result.reason + + +def test_register_hooks_registers_model_steering(): + """Test that register_hooks registers model steering callback.""" + handler = TestSteeringHandler() + registry = Mock(spec=HookRegistry) + + handler.register_hooks(registry) + + # Verify model steering hook was registered + registry.add_callback.assert_any_call(AfterModelCallEvent, handler._provide_model_steering_guidance) diff --git a/tests/strands/experimental/steering/handlers/llm/test_llm_handler.py b/tests/strands/experimental/steering/handlers/llm/test_llm_handler.py index f780088b5..f10254e50 100644 --- a/tests/strands/experimental/steering/handlers/llm/test_llm_handler.py +++ 
b/tests/strands/experimental/steering/handlers/llm/test_llm_handler.py @@ -59,7 +59,7 @@ async def test_steer_proceed_decision(mock_agent_class): agent = Mock() tool_use = {"name": "test_tool", "input": {"param": "value"}} - result = await handler.steer(agent, tool_use) + result = await handler.steer_before_tool(agent=agent, tool_use=tool_use) assert isinstance(result, Proceed) assert result.reason == "Tool call is safe" @@ -82,7 +82,7 @@ async def test_steer_guide_decision(mock_agent_class): agent = Mock() tool_use = {"name": "test_tool", "input": {"param": "value"}} - result = await handler.steer(agent, tool_use) + result = await handler.steer_before_tool(agent=agent, tool_use=tool_use) assert isinstance(result, Guide) assert result.reason == "Consider security implications" @@ -105,7 +105,7 @@ async def test_steer_interrupt_decision(mock_agent_class): agent = Mock() tool_use = {"name": "test_tool", "input": {"param": "value"}} - result = await handler.steer(agent, tool_use) + result = await handler.steer_before_tool(agent=agent, tool_use=tool_use) assert isinstance(result, Interrupt) assert result.reason == "Human approval required" @@ -133,7 +133,7 @@ async def test_steer_unknown_decision(mock_agent_class): agent = Mock() tool_use = {"name": "test_tool", "input": {"param": "value"}} - result = await handler.steer(agent, tool_use) + result = await handler.steer_before_tool(agent=agent, tool_use=tool_use) assert isinstance(result, Proceed) assert "Unknown LLM decision, defaulting to proceed" in result.reason @@ -158,7 +158,7 @@ async def test_steer_uses_custom_model(mock_agent_class): agent.model = Mock() tool_use = {"name": "test_tool", "input": {"param": "value"}} - await handler.steer(agent, tool_use) + await handler.steer_before_tool(agent=agent, tool_use=tool_use) mock_agent_class.assert_called_once_with(system_prompt=system_prompt, model=custom_model, callback_handler=None) @@ -181,7 +181,7 @@ async def 
test_steer_uses_agent_model_when_no_custom_model(mock_agent_class): agent.model = Mock() tool_use = {"name": "test_tool", "input": {"param": "value"}} - await handler.steer(agent, tool_use) + await handler.steer_before_tool(agent=agent, tool_use=tool_use) mock_agent_class.assert_called_once_with(system_prompt=system_prompt, model=agent.model, callback_handler=None) diff --git a/tests_integ/steering/test_model_steering.py b/tests_integ/steering/test_model_steering.py new file mode 100644 index 000000000..e867ea033 --- /dev/null +++ b/tests_integ/steering/test_model_steering.py @@ -0,0 +1,204 @@ +"""Integration tests for model steering (steer_after_model).""" + +from strands import Agent, tool +from strands.experimental.steering.core.action import Guide, ModelSteeringAction, Proceed +from strands.experimental.steering.core.handler import SteeringHandler +from strands.types.content import Message +from strands.types.streaming import StopReason + + +class SimpleModelSteeringHandler(SteeringHandler): + """Simple handler that steers only on model responses.""" + + def __init__(self, should_guide: bool = False, guidance_message: str = ""): + """Initialize handler. 
+ + Args: + should_guide: If True, guide (retry) on first model response + guidance_message: The guidance message to provide on retry + """ + super().__init__() + self.should_guide = should_guide + self.guidance_message = guidance_message + self.call_count = 0 + + async def steer_after_model( + self, *, agent: Agent, message: Message, stop_reason: StopReason, **kwargs + ) -> ModelSteeringAction: + """Steer after model response.""" + self.call_count += 1 + + # On first call, guide to retry if configured + if self.should_guide and self.call_count == 1: + return Guide(reason=self.guidance_message) + + return Proceed(reason="Model response accepted") + + +def test_model_steering_proceeds_without_intervention(): + """Test that model steering can accept responses without modification.""" + handler = SimpleModelSteeringHandler(should_guide=False) + agent = Agent(hooks=[handler]) + + response = agent("What is 2+2?") + + # Handler should have been called at least once + assert handler.call_count >= 1 + # Response should be generated successfully + response_text = str(response) + assert response_text is not None + assert len(response_text) > 0 + + +def test_model_steering_guide_triggers_retry(): + """Test that Guide action triggers model retry.""" + handler = SimpleModelSteeringHandler(should_guide=True, guidance_message="Please provide a more detailed response.") + agent = Agent(hooks=[handler]) + + response = agent("What is the capital of France?") + + # Handler should have been called at least twice (first response + retry) + assert handler.call_count >= 2, "Handler should be called on initial response and retry" + + # Response should be generated successfully after retry + response_text = str(response) + assert response_text is not None + assert len(response_text) > 0 + + +def test_model_steering_guide_influences_retry_response(): + """Test that guidance message influences the retry response.""" + + class SpecificGuidanceHandler(SteeringHandler): + def __init__(self): + 
super().__init__() + self.retry_done = False + + async def steer_after_model( + self, *, agent: Agent, message: Message, stop_reason: StopReason, **kwargs + ) -> ModelSteeringAction: + if not self.retry_done: + self.retry_done = True + # Provide very specific guidance that should appear in retry + return Guide(reason="Please mention that Paris is also known as the 'City of Light'.") + return Proceed(reason="Response is good now") + + handler = SpecificGuidanceHandler() + agent = Agent(hooks=[handler]) + + response = agent("What is the capital of France?") + + # Verify retry happened + assert handler.retry_done, "Retry should have occurred" + + # Check that the response likely incorporated the guidance + output = str(response).lower() + assert "paris" in output, "Response should mention Paris" + + # The guidance should have influenced the retry (check for "light" or that retry happened) + # We can't guarantee the model will include it, but we verify the mechanism worked + assert handler.retry_done, "Guidance mechanism should have executed" + + +def test_model_steering_multiple_retries(): + """Test that model steering can guide multiple times before proceeding.""" + + class MultiRetryHandler(SteeringHandler): + def __init__(self): + super().__init__() + self.call_count = 0 + + async def steer_after_model( + self, *, agent: Agent, message: Message, stop_reason: StopReason, **kwargs + ) -> ModelSteeringAction: + self.call_count += 1 + + # Retry twice + if self.call_count == 1: + return Guide(reason="Please provide more context.") + if self.call_count == 2: + return Guide(reason="Please add specific examples.") + return Proceed(reason="Response is good now") + + handler = MultiRetryHandler() + agent = Agent(hooks=[handler]) + + response = agent("Explain machine learning.") + + # Should have been called at least 3 times (2 guides + 1 proceed) + assert handler.call_count >= 3, "Handler should be called multiple times for multiple retries" + + # Response should still complete 
successfully + assert str(response) is not None + assert len(str(response)) > 0 + + +@tool +def log_activity(activity: str) -> str: + """Log an activity for audit purposes.""" + return f"Activity logged: {activity}" + + +def test_model_steering_forces_tool_usage_on_unrelated_prompt(): + """Test that steering forces tool usage even when prompt doesn't need the tool. + + This test verifies the flow: + 1. Agent has a logging tool available + 2. User asks an unrelated question (math problem) + 3. Model tries to answer directly without using the tool + 4. Steering intercepts and forces tool usage before termination + 5. Model uses the tool and then completes + """ + + class ForceToolUsageHandler(SteeringHandler): + """Handler that forces a specific tool to be used before allowing termination.""" + + def __init__(self, required_tool: str): + super().__init__() + self.required_tool = required_tool + self.tool_was_used = False + self.guidance_given = False + + async def steer_after_model( + self, *, agent: Agent, message: Message, stop_reason: StopReason, **kwargs + ) -> ModelSteeringAction: + # Only check when model is trying to end the turn + if stop_reason != "end_turn": + return Proceed(reason="Model still processing") + + # Check if the required tool was used in this message + content_blocks = message.get("content", []) + for block in content_blocks: + if "toolUse" in block and block["toolUse"].get("name") == self.required_tool: + self.tool_was_used = True + return Proceed(reason="Required tool was used") + + # If tool wasn't used and we haven't guided yet, force its usage + if not self.tool_was_used and not self.guidance_given: + self.guidance_given = True + return Guide( + reason=f"Before completing your response, you MUST use the {self.required_tool} tool " + "to log this interaction. Call the tool with a brief description of what you did." 
+ ) + + # Allow completion after guidance was given (model may have used tool in retry) + return Proceed(reason="Guidance was provided") + + handler = ForceToolUsageHandler(required_tool="log_activity") + agent = Agent(tools=[log_activity], hooks=[handler]) + + # Ask a question that clearly doesn't need the logging tool + response = agent("What is 2 + 2?") + + # Verify the steering mechanism worked + assert handler.guidance_given, "Handler should have provided guidance to use the tool" + + # Verify tool was actually called by checking metrics + tool_metrics = response.metrics.tool_metrics + assert "log_activity" in tool_metrics, "log_activity tool should have been called" + assert tool_metrics["log_activity"].call_count >= 1, "log_activity should have been called at least once" + assert tool_metrics["log_activity"].success_count >= 1, "log_activity should have succeeded" + + # Verify the response still answers the original question + output = str(response).lower() + assert "4" in output, "Response should contain the answer to 2+2" diff --git a/tests_integ/steering/test_llm_handler.py b/tests_integ/steering/test_tool_steering.py similarity index 91% rename from tests_integ/steering/test_llm_handler.py rename to tests_integ/steering/test_tool_steering.py index 8a8cebea2..eced94ba0 100644 --- a/tests_integ/steering/test_llm_handler.py +++ b/tests_integ/steering/test_tool_steering.py @@ -1,4 +1,4 @@ -"""Integration tests for LLM steering handler.""" +"""Integration tests for tool steering (steer_before_tool).""" import pytest @@ -30,7 +30,7 @@ async def test_llm_steering_handler_proceed(): agent = Agent(tools=[send_notification]) tool_use = {"name": "send_notification", "input": {"recipient": "user", "message": "hello"}} - effect = await handler.steer(agent, tool_use) + effect = await handler.steer_before_tool(agent=agent, tool_use=tool_use) assert isinstance(effect, Proceed) @@ -48,7 +48,7 @@ async def test_llm_steering_handler_guide(): agent = 
Agent(tools=[send_email, send_notification]) tool_use = {"name": "send_email", "input": {"recipient": "user", "message": "hello"}} - effect = await handler.steer(agent, tool_use) + effect = await handler.steer_before_tool(agent=agent, tool_use=tool_use) assert isinstance(effect, Guide) @@ -64,12 +64,12 @@ async def test_llm_steering_handler_interrupt(): agent = Agent(tools=[send_email]) tool_use = {"name": "send_email", "input": {"recipient": "user", "message": "hello"}} - effect = await handler.steer(agent, tool_use) + effect = await handler.steer_before_tool(agent=agent, tool_use=tool_use) assert isinstance(effect, Interrupt) -def test_agent_with_steering_e2e(): +def test_agent_with_tool_steering_e2e(): """End-to-end test of agent with steering handler guiding tool choice.""" handler = LLMSteeringHandler( system_prompt=(