Skip to content

Commit dbf6200

Browse files
authored
hooks - before tool call event - interrupt (strands-agents#987)
1 parent f7931c5 commit dbf6200

31 files changed

+1401
-44
lines changed

src/strands/agent/agent.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,13 +55,15 @@
5555
from ..types.agent import AgentInput
5656
from ..types.content import ContentBlock, Message, Messages
5757
from ..types.exceptions import ContextWindowOverflowException
58+
from ..types.interrupt import InterruptResponseContent
5859
from ..types.tools import ToolResult, ToolUse
5960
from ..types.traces import AttributeValue
6061
from .agent_result import AgentResult
6162
from .conversation_manager import (
6263
ConversationManager,
6364
SlidingWindowConversationManager,
6465
)
66+
from .interrupt import InterruptState
6567
from .state import AgentState
6668

6769
logger = logging.getLogger(__name__)
@@ -143,6 +145,9 @@ def caller(
143145
Raises:
144146
AttributeError: If the tool doesn't exist.
145147
"""
148+
if self._agent._interrupt_state.activated:
149+
raise RuntimeError("cannot directly call tool during interrupt")
150+
146151
normalized_name = self._find_normalized_tool_name(name)
147152

148153
# Create unique tool ID and set up the tool request
@@ -338,6 +343,8 @@ def __init__(
338343

339344
self.hooks = HookRegistry()
340345

346+
self._interrupt_state = InterruptState()
347+
341348
# Initialize session management functionality
342349
self._session_manager = session_manager
343350
if self._session_manager:
@@ -491,6 +498,9 @@ async def structured_output_async(self, output_model: Type[T], prompt: AgentInpu
491498
Raises:
492499
ValueError: If no conversation history or prompt is provided.
493500
"""
501+
if self._interrupt_state.activated:
502+
raise RuntimeError("cannot call structured output during interrupt")
503+
494504
self.hooks.invoke_callbacks(BeforeInvocationEvent(agent=self))
495505
with self.tracer.tracer.start_as_current_span(
496506
"execute_structured_output", kind=trace_api.SpanKind.CLIENT
@@ -573,6 +583,8 @@ async def stream_async(
573583
yield event["data"]
574584
```
575585
"""
586+
self._resume_interrupt(prompt)
587+
576588
merged_state = {}
577589
if kwargs:
578590
warnings.warn("`**kwargs` parameter is deprecating, use `invocation_state` instead.", stacklevel=2)
@@ -614,6 +626,38 @@ async def stream_async(
614626
self._end_agent_trace_span(error=e)
615627
raise
616628

629+
def _resume_interrupt(self, prompt: AgentInput) -> None:
630+
"""Configure the interrupt state if resuming from an interrupt event.
631+
632+
Args:
633+
prompt: User responses if resuming from interrupt.
634+
635+
Raises:
636+
TypeError: If in interrupt state but user did not provide responses.
637+
"""
638+
if not self._interrupt_state.activated:
639+
return
640+
641+
if not isinstance(prompt, list):
642+
raise TypeError(f"prompt_type={type(prompt)} | must resume from interrupt with list of interruptResponse's")
643+
644+
invalid_types = [
645+
content_type for content in prompt for content_type in content if content_type != "interruptResponse"
646+
]
647+
if invalid_types:
648+
raise TypeError(
649+
f"content_types=<{invalid_types}> | must resume from interrupt with list of interruptResponse's"
650+
)
651+
652+
for content in cast(list[InterruptResponseContent], prompt):
653+
interrupt_id = content["interruptResponse"]["interruptId"]
654+
interrupt_response = content["interruptResponse"]["response"]
655+
656+
if interrupt_id not in self._interrupt_state.interrupts:
657+
raise KeyError(f"interrupt_id=<{interrupt_id}> | no interrupt found")
658+
659+
self._interrupt_state.interrupts[interrupt_id].response = interrupt_response
660+
617661
async def _run_loop(self, messages: Messages, invocation_state: dict[str, Any]) -> AsyncGenerator[TypedEvent, None]:
618662
"""Execute the agent's event loop with the given message and parameters.
619663
@@ -689,6 +733,9 @@ async def _execute_event_loop_cycle(self, invocation_state: dict[str, Any]) -> A
689733
yield event
690734

691735
def _convert_prompt_to_messages(self, prompt: AgentInput) -> Messages:
736+
if self._interrupt_state.activated:
737+
return []
738+
692739
messages: Messages | None = None
693740
if prompt is not None:
694741
if isinstance(prompt, str):

src/strands/agent/agent_result.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,9 @@
44
"""
55

66
from dataclasses import dataclass
7-
from typing import Any
7+
from typing import Any, Sequence
88

9+
from ..interrupt import Interrupt
910
from ..telemetry.metrics import EventLoopMetrics
1011
from ..types.content import Message
1112
from ..types.streaming import StopReason
@@ -20,12 +21,14 @@ class AgentResult:
2021
message: The last message generated by the agent.
2122
metrics: Performance metrics collected during processing.
2223
state: Additional state information from the event loop.
24+
interrupts: List of interrupts if raised by user.
2325
"""
2426

2527
stop_reason: StopReason
2628
message: Message
2729
metrics: EventLoopMetrics
2830
state: Any
31+
interrupts: Sequence[Interrupt] | None = None
2932

3033
def __str__(self) -> str:
3134
"""Get the agent's last message as a string.

src/strands/agent/interrupt.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
"""Track the state of interrupt events raised by the user for human-in-the-loop workflows."""
2+
3+
from dataclasses import asdict, dataclass, field
4+
from typing import Any
5+
6+
from ..interrupt import Interrupt
7+
8+
9+
@dataclass
10+
class InterruptState:
11+
"""Track the state of interrupt events raised by the user.
12+
13+
Note, interrupt state is cleared after resuming.
14+
15+
Attributes:
16+
interrupts: Interrupts raised by the user.
17+
context: Additional context associated with an interrupt event.
18+
activated: True if agent is in an interrupt state, False otherwise.
19+
"""
20+
21+
interrupts: dict[str, Interrupt] = field(default_factory=dict)
22+
context: dict[str, Any] = field(default_factory=dict)
23+
activated: bool = False
24+
25+
def activate(self, context: dict[str, Any] | None = None) -> None:
26+
"""Activate the interrupt state.
27+
28+
Args:
29+
context: Context associated with the interrupt event.
30+
"""
31+
self.context = context or {}
32+
self.activated = True
33+
34+
def deactivate(self) -> None:
35+
"""Deacitvate the interrupt state.
36+
37+
Interrupts and context are cleared.
38+
"""
39+
self.interrupts = {}
40+
self.context = {}
41+
self.activated = False
42+
43+
def to_dict(self) -> dict[str, Any]:
44+
"""Serialize to dict for session management."""
45+
return asdict(self)
46+
47+
@classmethod
48+
def from_dict(cls, data: dict[str, Any]) -> "InterruptState":
49+
"""Initiailize interrupt state from serialized interrupt state.
50+
51+
Interrupt state can be serialized with the `to_dict` method.
52+
"""
53+
return cls(
54+
interrupts={
55+
interrupt_id: Interrupt(**interrupt_data) for interrupt_id, interrupt_data in data["interrupts"].items()
56+
},
57+
context=data["context"],
58+
activated=data["activated"],
59+
)

src/strands/event_loop/event_loop.py

Lines changed: 46 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
ModelStopReason,
2828
StartEvent,
2929
StartEventLoopEvent,
30+
ToolInterruptEvent,
3031
ToolResultMessageEvent,
3132
TypedEvent,
3233
)
@@ -106,13 +107,19 @@ async def event_loop_cycle(agent: "Agent", invocation_state: dict[str, Any]) ->
106107
)
107108
invocation_state["event_loop_cycle_span"] = cycle_span
108109

109-
model_events = _handle_model_execution(agent, cycle_span, cycle_trace, invocation_state, tracer)
110-
async for model_event in model_events:
111-
if not isinstance(model_event, ModelStopReason):
112-
yield model_event
110+
# Skipping model invocation if in interrupt state as interrupts are currently only supported for tool calls.
111+
if agent._interrupt_state.activated:
112+
stop_reason: StopReason = "tool_use"
113+
message = agent._interrupt_state.context["tool_use_message"]
113114

114-
stop_reason, message, *_ = model_event["stop"]
115-
yield ModelMessageEvent(message=message)
115+
else:
116+
model_events = _handle_model_execution(agent, cycle_span, cycle_trace, invocation_state, tracer)
117+
async for model_event in model_events:
118+
if not isinstance(model_event, ModelStopReason):
119+
yield model_event
120+
121+
stop_reason, message, *_ = model_event["stop"]
122+
yield ModelMessageEvent(message=message)
116123

117124
try:
118125
if stop_reason == "max_tokens":
@@ -142,6 +149,7 @@ async def event_loop_cycle(agent: "Agent", invocation_state: dict[str, Any]) ->
142149
cycle_span=cycle_span,
143150
cycle_start_time=cycle_start_time,
144151
invocation_state=invocation_state,
152+
tracer=tracer,
145153
)
146154
async for tool_event in tool_events:
147155
yield tool_event
@@ -345,6 +353,7 @@ async def _handle_tool_execution(
345353
cycle_span: Any,
346354
cycle_start_time: float,
347355
invocation_state: dict[str, Any],
356+
tracer: Tracer,
348357
) -> AsyncGenerator[TypedEvent, None]:
349358
"""Handles the execution of tools requested by the model during an event loop cycle.
350359
@@ -356,6 +365,7 @@ async def _handle_tool_execution(
356365
cycle_span: Span object for tracing the cycle (type may vary).
357366
cycle_start_time: Start time of the current cycle.
358367
invocation_state: Additional keyword arguments, including request state.
368+
tracer: Tracer instance for span management.
359369
360370
Yields:
361371
Tool stream events along with events yielded from a recursive call to the event loop. The last event is a tuple
@@ -375,15 +385,45 @@ async def _handle_tool_execution(
375385
yield EventLoopStopEvent(stop_reason, message, agent.event_loop_metrics, invocation_state["request_state"])
376386
return
377387

388+
if agent._interrupt_state.activated:
389+
tool_results.extend(agent._interrupt_state.context["tool_results"])
390+
391+
# Filter to only the interrupted tools when resuming from interrupt (tool uses without results)
392+
tool_use_ids = {tool_result["toolUseId"] for tool_result in tool_results}
393+
tool_uses = [tool_use for tool_use in tool_uses if tool_use["toolUseId"] not in tool_use_ids]
394+
395+
interrupts = []
378396
tool_events = agent.tool_executor._execute(
379397
agent, tool_uses, tool_results, cycle_trace, cycle_span, invocation_state
380398
)
381399
async for tool_event in tool_events:
400+
if isinstance(tool_event, ToolInterruptEvent):
401+
interrupts.extend(tool_event["tool_interrupt_event"]["interrupts"])
402+
382403
yield tool_event
383404

384405
# Store parent cycle ID for the next cycle
385406
invocation_state["event_loop_parent_cycle_id"] = invocation_state["event_loop_cycle_id"]
386407

408+
if interrupts:
409+
# Session state stored on AfterInvocationEvent.
410+
agent._interrupt_state.activate(context={"tool_use_message": message, "tool_results": tool_results})
411+
412+
agent.event_loop_metrics.end_cycle(cycle_start_time, cycle_trace)
413+
yield EventLoopStopEvent(
414+
"interrupt",
415+
message,
416+
agent.event_loop_metrics,
417+
invocation_state["request_state"],
418+
interrupts,
419+
)
420+
if cycle_span:
421+
tracer.end_event_loop_cycle_span(span=cycle_span, message=message)
422+
423+
return
424+
425+
agent._interrupt_state.deactivate()
426+
387427
tool_result_message: Message = {
388428
"role": "user",
389429
"content": [{"toolResult": result} for result in tool_results],
@@ -394,7 +434,6 @@ async def _handle_tool_execution(
394434
yield ToolResultMessageEvent(message=tool_result_message)
395435

396436
if cycle_span:
397-
tracer = get_tracer()
398437
tracer.end_event_loop_cycle_span(span=cycle_span, message=message, tool_result_message=tool_result_message)
399438

400439
if invocation_state["request_state"].get("stop_event_loop", False):

src/strands/hooks/events.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,14 @@
33
This module defines the events that are emitted as Agents run through the lifecycle of a request.
44
"""
55

6+
import uuid
67
from dataclasses import dataclass
78
from typing import Any, Optional
89

10+
from typing_extensions import override
11+
912
from ..types.content import Message
13+
from ..types.interrupt import InterruptHookEvent
1014
from ..types.streaming import StopReason
1115
from ..types.tools import AgentTool, ToolResult, ToolUse
1216
from .registry import HookEvent
@@ -84,7 +88,7 @@ class MessageAddedEvent(HookEvent):
8488

8589

8690
@dataclass
87-
class BeforeToolCallEvent(HookEvent):
91+
class BeforeToolCallEvent(HookEvent, InterruptHookEvent):
8892
"""Event triggered before a tool is invoked.
8993
9094
This event is fired just before the agent executes a tool, allowing hook
@@ -110,6 +114,18 @@ class BeforeToolCallEvent(HookEvent):
110114
def _can_write(self, name: str) -> bool:
111115
return name in ["cancel_tool", "selected_tool", "tool_use"]
112116

117+
@override
118+
def _interrupt_id(self, name: str) -> str:
119+
"""Unique id for the interrupt.
120+
121+
Args:
122+
name: User defined name for the interrupt.
123+
124+
Returns:
125+
Interrupt id.
126+
"""
127+
return f"v1:{self.tool_use['toolUseId']}:{uuid.uuid5(uuid.NAMESPACE_OID, name)}"
128+
113129

114130
@dataclass
115131
class AfterToolCallEvent(HookEvent):

0 commit comments

Comments
 (0)