Skip to content

Commit 96bb1ab

Browse files
committed
feat: enable multiagent session persistent
# Conflicts: # src/strands/multiagent/graph.py # src/strands/multiagent/swarm.py # tests/strands/multiagent/test_graph.py # tests/strands/multiagent/test_swarm.py # tests_integ/test_multiagent_graph.py # tests_integ/test_multiagent_swarm.py
1 parent 111e77c commit 96bb1ab

File tree

7 files changed

+636
-22
lines changed

7 files changed

+636
-22
lines changed

src/strands/multiagent/graph.py

Lines changed: 170 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,13 @@
2626
from .._async import run_async
2727
from ..agent import Agent
2828
from ..agent.state import AgentState
29+
from ..experimental.hooks.multiagent import (
30+
AfterMultiAgentInvocationEvent,
31+
AfterNodeCallEvent,
32+
MultiAgentInitializedEvent,
33+
)
34+
from ..hooks import HookProvider, HookRegistry
35+
from ..session import SessionManager
2936
from ..telemetry import get_tracer
3037
from ..types._events import (
3138
MultiAgentHandoffEvent,
@@ -40,6 +47,8 @@
4047

4148
logger = logging.getLogger(__name__)
4249

50+
_DEFAULT_GRAPH_ID = "default_graph"
51+
4352

4453
@dataclass
4554
class GraphState:
@@ -223,6 +232,9 @@ def __init__(self) -> None:
223232
self._execution_timeout: Optional[float] = None
224233
self._node_timeout: Optional[float] = None
225234
self._reset_on_revisit: bool = False
235+
self._id: str = _DEFAULT_GRAPH_ID
236+
self._session_manager: Optional[SessionManager] = None
237+
self._hooks: Optional[list[HookProvider]] = None
226238

227239
def add_node(self, executor: Agent | MultiAgentBase, node_id: str | None = None) -> GraphNode:
228240
"""Add an Agent or MultiAgentBase instance as a node to the graph."""
@@ -313,6 +325,33 @@ def set_node_timeout(self, timeout: float) -> "GraphBuilder":
313325
self._node_timeout = timeout
314326
return self
315327

328+
def set_graph_id(self, graph_id: str) -> "GraphBuilder":
    """Set the identifier that the built graph will carry.

    Args:
        graph_id: Identifier for the graph. If this method is never called the
            builder keeps the module default ``_DEFAULT_GRAPH_ID``; no random
            id is generated.

    Returns:
        This builder, so configuration calls can be chained.
    """
    self._id = graph_id
    return self
336+
337+
def set_session_manager(self, session_manager: SessionManager) -> "GraphBuilder":
    """Attach the session manager that the built graph will receive.

    Args:
        session_manager: SessionManager instance stored on the builder and
            passed to ``Graph`` by ``build()``.

    Returns:
        This builder, so configuration calls can be chained.
    """
    self._session_manager = session_manager
    return self
345+
346+
def set_hook_providers(self, hooks: list[HookProvider]) -> "GraphBuilder":
    """Record the caller-supplied hook providers for the graph.

    Args:
        hooks: Hook providers supplied by the caller. The list object itself
            is stored (not copied) and passed to ``Graph`` by ``build()``.

    Returns:
        This builder, so configuration calls can be chained.
    """
    self._hooks = hooks
    return self
354+
316355
def build(self) -> "Graph":
317356
"""Build and validate the graph with configured settings."""
318357
if not self.nodes:
@@ -331,13 +370,16 @@ def build(self) -> "Graph":
331370
self._validate_graph()
332371

333372
return Graph(
373+
id=self._id,
334374
nodes=self.nodes.copy(),
335375
edges=self.edges.copy(),
336376
entry_points=self.entry_points.copy(),
337377
max_node_executions=self._max_node_executions,
338378
execution_timeout=self._execution_timeout,
339379
node_timeout=self._node_timeout,
340380
reset_on_revisit=self._reset_on_revisit,
381+
session_manager=self._session_manager,
382+
hooks=self._hooks,
341383
)
342384

343385
def _validate_graph(self) -> None:
@@ -365,6 +407,10 @@ def __init__(
365407
execution_timeout: Optional[float] = None,
366408
node_timeout: Optional[float] = None,
367409
reset_on_revisit: bool = False,
410+
session_manager: Optional[SessionManager] = None,
411+
hooks: Optional[list[HookProvider]] = None,
412+
*,
413+
id: str = _DEFAULT_GRAPH_ID,
368414
) -> None:
369415
"""Initialize Graph with execution limits and reset behavior.
370416
@@ -376,11 +422,15 @@ def __init__(
376422
execution_timeout: Total execution timeout in seconds (default: None - no limit)
377423
node_timeout: Individual node timeout in seconds (default: None - no limit)
378424
reset_on_revisit: Whether to reset node state when revisited (default: False)
425+
session_manager: Session manager for persisting graph state and execution history (default: None)
426+
hooks: List of hook providers for monitoring and extending graph execution behavior (default: None)
427+
id: Unique graph id (default: "default_graph")
379428
"""
380429
super().__init__()
381430

382431
# Validate nodes for duplicate instances
383432
self._validate_graph(nodes)
433+
self.id = id or _DEFAULT_GRAPH_ID
384434

385435
self.nodes = nodes
386436
self.edges = edges
@@ -391,6 +441,18 @@ def __init__(
391441
self.reset_on_revisit = reset_on_revisit
392442
self.state = GraphState()
393443
self.tracer = get_tracer()
444+
self.session_manager = session_manager
445+
self.hooks = HookRegistry()
446+
if self.session_manager:
447+
self.hooks.add_hook(self.session_manager)
448+
if hooks:
449+
for hook in hooks:
450+
self.hooks.add_hook(hook)
451+
452+
self._resume_next_nodes: list[GraphNode] = []
453+
self._resume_from_session = False
454+
455+
self.hooks.invoke_callbacks(MultiAgentInitializedEvent(self))
394456

395457
def __call__(
396458
self, task: str | list[ContentBlock], invocation_state: dict[str, Any] | None = None, **kwargs: Any
@@ -457,14 +519,19 @@ async def stream_async(
457519

458520
# Initialize state
459521
start_time = time.time()
460-
self.state = GraphState(
461-
status=Status.EXECUTING,
462-
task=task,
463-
total_nodes=len(self.nodes),
464-
edges=[(edge.from_node, edge.to_node) for edge in self.edges],
465-
entry_points=list(self.entry_points),
466-
start_time=start_time,
467-
)
522+
if not self._resume_from_session:
523+
# Initialize state
524+
self.state = GraphState(
525+
status=Status.EXECUTING,
526+
task=task,
527+
total_nodes=len(self.nodes),
528+
edges=[(edge.from_node, edge.to_node) for edge in self.edges],
529+
entry_points=list(self.entry_points),
530+
start_time=start_time,
531+
)
532+
else:
533+
self.state.status = Status.EXECUTING
534+
self.state.start_time = start_time
468535

469536
span = self.tracer.start_multiagent_span(task, "graph")
470537
with trace_api.use_span(span, end_on_exit=True):
@@ -499,6 +566,9 @@ async def stream_async(
499566
raise
500567
finally:
501568
self.state.execution_time = round((time.time() - start_time) * 1000)
569+
self.hooks.invoke_callbacks(AfterMultiAgentInvocationEvent(self))
570+
self._resume_from_session = False
571+
self._resume_next_nodes.clear()
502572

503573
def _validate_graph(self, nodes: dict[str, GraphNode]) -> None:
504574
"""Validate graph nodes for duplicate instances."""
@@ -514,7 +584,7 @@ def _validate_graph(self, nodes: dict[str, GraphNode]) -> None:
514584

515585
async def _execute_graph(self, invocation_state: dict[str, Any]) -> AsyncIterator[Any]:
516586
"""Execute graph and yield TypedEvent objects."""
517-
ready_nodes = list(self.entry_points)
587+
ready_nodes = self._resume_next_nodes if self._resume_from_session else list(self.entry_points)
518588

519589
while ready_nodes:
520590
# Check execution limits before continuing
@@ -928,3 +998,94 @@ def _build_result(self) -> GraphResult:
928998
edges=self.state.edges,
929999
entry_points=self.state.entry_points,
9301000
)
1001+
1002+
def serialize_state(self) -> dict[str, Any]:
    """Build a dictionary snapshot of the graph's current execution state.

    Returns:
        A dictionary with the graph id, the status value, the ids of the
        completed and failed nodes, every node result converted through its
        ``to_dict()``, the ids of the nodes that would run next on resume,
        the current task, and the execution order expressed as node ids.
    """
    state = self.state
    status_value = state.status.value
    # Nodes that would be runnable if execution resumed from this snapshot.
    resumable = self._compute_ready_nodes_for_resume()

    def _ids(nodes):
        return [node.node_id for node in nodes]

    snapshot: dict[str, Any] = {"type": "graph", "id": self.id, "status": status_value}
    snapshot["completed_nodes"] = _ids(state.completed_nodes)
    snapshot["failed_nodes"] = _ids(state.failed_nodes)
    snapshot["node_results"] = {key: result.to_dict() for key, result in (state.results or {}).items()}
    snapshot["next_nodes_to_execute"] = _ids(resumable or [])
    snapshot["current_task"] = state.task
    snapshot["execution_order"] = _ids(state.execution_order)
    return snapshot
1018+
1019+
def deserialize_state(self, payload: dict[str, Any]) -> None:
    """Restore graph state from a persisted payload, or reset for a fresh run.

    The choice depends only on ``payload["next_nodes_to_execute"]``; the
    persisted status is not inspected here.

    1. If that entry is missing or empty there is nothing left to resume:
       every node's executor state is reset, ``self.state`` is replaced by a
       fresh ``GraphState`` and the resume flag is cleared.
    2. Otherwise the persisted state is restored through ``_from_dict`` and
       the graph is flagged to resume from the recorded next nodes.

    Args:
        payload: Dictionary containing persisted state data including status,
            completed nodes, results, and next nodes to execute.
    """
    if payload.get("next_nodes_to_execute"):
        self._from_dict(payload)
        self._resume_from_session = True
        return

    # Nothing recorded to run next: discard per-node executor state and start
    # over with an empty graph state.
    for node in self.nodes.values():
        node.reset_executor_state()
    self.state = GraphState()
    self._resume_from_session = False
1043+
1044+
# Helper functions for serialize and deserialize
1045+
def _compute_ready_nodes_for_resume(self) -> list[GraphNode]:
    """Work out which nodes could run next given the current state.

    Returns:
        An empty list while the graph status is ``Status.PENDING``. Otherwise
        every node, in ``self.nodes`` order, that is not completed and whose
        incoming edges all start at a completed node and report
        ``should_traverse(self.state)``. Nodes without incoming edges qualify
        as well.
    """
    if self.state.status == Status.PENDING:
        return []

    done = set(self.state.completed_nodes)

    def _is_ready(candidate):
        if candidate in done:
            return False
        inbound = [edge for edge in self.edges if edge.to_node is candidate]
        # all() over an empty list is True, which covers nodes with no incoming edges.
        return all(edge.from_node in done and edge.should_traverse(self.state) for edge in inbound)

    return [candidate for candidate in self.nodes.values() if _is_ready(candidate)]
1060+
1061+
def _from_dict(self, payload: dict[str, Any]) -> None:
    """Hydrate ``self.state`` and the resume queue from a persisted payload.

    Node ids that are not present in ``self.nodes`` are ignored everywhere,
    so a payload written for a different graph shape still loads.

    Args:
        payload: Dictionary using the keys written by ``serialize_state``.
            ``"status"`` is required; every other key is optional.

    Raises:
        KeyError: If ``payload`` has no ``"status"`` entry.
        Exception: Whatever ``NodeResult.from_dict`` raises for a malformed
            result entry; it is logged and re-raised unchanged.
    """
    self.state.status = Status(payload["status"])

    # Hydrate results for nodes that still exist in this graph.
    results: dict[str, NodeResult] = {}
    for node_id, entry in (payload.get("node_results") or {}).items():
        if node_id not in self.nodes:
            continue
        try:
            results[node_id] = NodeResult.from_dict(entry)
        except Exception:
            # The failure is re-raised, so the entry is not skipped.
            logger.exception("Failed to hydrate NodeResult for node_id=%s.", node_id)
            raise
    self.state.results = results

    # Failed nodes are stored as node objects, like completed nodes:
    # serialize_state reads ``node_id`` from each element, so raw id strings
    # would break the next serialization.
    failed_node_ids = payload.get("failed_nodes") or []
    self.state.failed_nodes = {self.nodes[node_id] for node_id in failed_node_ids if node_id in self.nodes}

    # Restore completed nodes from persisted data.
    completed_node_ids = payload.get("completed_nodes") or []
    self.state.completed_nodes = {self.nodes[node_id] for node_id in completed_node_ids if node_id in self.nodes}

    # Execution order (only nodes that still exist).
    order_node_ids = payload.get("execution_order") or []
    self.state.execution_order = [self.nodes[node_id] for node_id in order_node_ids if node_id in self.nodes]

    # Keep the current task when the payload does not carry one.
    self.state.task = payload.get("current_task", self.state.task)

    # Nodes to run first when execution resumes.
    next_node_ids = payload.get("next_nodes_to_execute") or []
    self._resume_next_nodes = [self.nodes[node_id] for node_id in next_node_ids if node_id in self.nodes]

0 commit comments

Comments
 (0)