2626from .._async import run_async
2727from ..agent import Agent
2828from ..agent .state import AgentState
29+ from ..experimental .hooks .multiagent import (
30+ AfterMultiAgentInvocationEvent ,
31+ AfterNodeCallEvent ,
32+ MultiAgentInitializedEvent ,
33+ )
34+ from ..hooks import HookProvider , HookRegistry
35+ from ..session import SessionManager
2936from ..telemetry import get_tracer
3037from ..types ._events import (
3138 MultiAgentHandoffEvent ,
4047
4148logger = logging .getLogger (__name__ )
4249
50+ _DEFAULT_GRAPH_ID = "default_graph"
51+
4352
4453@dataclass
4554class GraphState :
@@ -223,6 +232,9 @@ def __init__(self) -> None:
223232 self ._execution_timeout : Optional [float ] = None
224233 self ._node_timeout : Optional [float ] = None
225234 self ._reset_on_revisit : bool = False
235+ self ._id : str = _DEFAULT_GRAPH_ID
236+ self ._session_manager : Optional [SessionManager ] = None
237+ self ._hooks : Optional [list [HookProvider ]] = None
226238
227239 def add_node (self , executor : Agent | MultiAgentBase , node_id : str | None = None ) -> GraphNode :
228240 """Add an Agent or MultiAgentBase instance as a node to the graph."""
@@ -313,6 +325,33 @@ def set_node_timeout(self, timeout: float) -> "GraphBuilder":
313325 self ._node_timeout = timeout
314326 return self
315327
def set_graph_id(self, graph_id: str) -> "GraphBuilder":
    """Record the identifier to give the graph produced by this builder.

    Args:
        graph_id: Identifier stored on the builder; ``build()`` forwards it
            to the ``Graph`` constructor as ``id``.

    Returns:
        This builder, so that configuration calls can be chained.
    """
    builder = self
    builder._id = graph_id
    return builder
336+
def set_session_manager(self, session_manager: SessionManager) -> "GraphBuilder":
    """Attach the session manager that the built graph should use.

    Args:
        session_manager: SessionManager instance stored on the builder;
            ``build()`` forwards it to the ``Graph`` constructor.

    Returns:
        This builder, so that configuration calls can be chained.
    """
    builder = self
    builder._session_manager = session_manager
    return builder
345+
def set_hook_providers(self, hooks: list[HookProvider]) -> "GraphBuilder":
    """Attach caller-supplied hook providers for the built graph.

    Args:
        hooks: Hook providers supplied by the caller. The list object itself
            is stored (no copy is made) and ``build()`` forwards it to the
            ``Graph`` constructor.

    Returns:
        This builder, so that configuration calls can be chained.
    """
    builder = self
    builder._hooks = hooks
    return builder
354+
316355 def build (self ) -> "Graph" :
317356 """Build and validate the graph with configured settings."""
318357 if not self .nodes :
@@ -331,13 +370,16 @@ def build(self) -> "Graph":
331370 self ._validate_graph ()
332371
333372 return Graph (
373+ id = self ._id ,
334374 nodes = self .nodes .copy (),
335375 edges = self .edges .copy (),
336376 entry_points = self .entry_points .copy (),
337377 max_node_executions = self ._max_node_executions ,
338378 execution_timeout = self ._execution_timeout ,
339379 node_timeout = self ._node_timeout ,
340380 reset_on_revisit = self ._reset_on_revisit ,
381+ session_manager = self ._session_manager ,
382+ hooks = self ._hooks ,
341383 )
342384
343385 def _validate_graph (self ) -> None :
@@ -365,6 +407,10 @@ def __init__(
365407 execution_timeout : Optional [float ] = None ,
366408 node_timeout : Optional [float ] = None ,
367409 reset_on_revisit : bool = False ,
410+ session_manager : Optional [SessionManager ] = None ,
411+ hooks : Optional [list [HookProvider ]] = None ,
412+ * ,
413+ id : str = _DEFAULT_GRAPH_ID ,
368414 ) -> None :
369415 """Initialize Graph with execution limits and reset behavior.
370416
@@ -376,11 +422,15 @@ def __init__(
376422 execution_timeout: Total execution timeout in seconds (default: None - no limit)
377423 node_timeout: Individual node timeout in seconds (default: None - no limit)
378424 reset_on_revisit: Whether to reset node state when revisited (default: False)
425+ session_manager: Session manager for persisting graph state and execution history (default: None)
426+ hooks: List of hook providers for monitoring and extending graph execution behavior (default: None)
427+ id: Unique graph id (default: None)
379428 """
380429 super ().__init__ ()
381430
382431 # Validate nodes for duplicate instances
383432 self ._validate_graph (nodes )
433+ self .id = id or _DEFAULT_GRAPH_ID
384434
385435 self .nodes = nodes
386436 self .edges = edges
@@ -391,6 +441,18 @@ def __init__(
391441 self .reset_on_revisit = reset_on_revisit
392442 self .state = GraphState ()
393443 self .tracer = get_tracer ()
444+ self .session_manager = session_manager
445+ self .hooks = HookRegistry ()
446+ if self .session_manager :
447+ self .hooks .add_hook (self .session_manager )
448+ if hooks :
449+ for hook in hooks :
450+ self .hooks .add_hook (hook )
451+
452+ self ._resume_next_nodes : list [GraphNode ] = []
453+ self ._resume_from_session = False
454+
455+ self .hooks .invoke_callbacks (MultiAgentInitializedEvent (self ))
394456
395457 def __call__ (
396458 self , task : str | list [ContentBlock ], invocation_state : dict [str , Any ] | None = None , ** kwargs : Any
@@ -457,14 +519,19 @@ async def stream_async(
457519
458520 # Initialize state
459521 start_time = time .time ()
460- self .state = GraphState (
461- status = Status .EXECUTING ,
462- task = task ,
463- total_nodes = len (self .nodes ),
464- edges = [(edge .from_node , edge .to_node ) for edge in self .edges ],
465- entry_points = list (self .entry_points ),
466- start_time = start_time ,
467- )
522+ if not self ._resume_from_session :
523+ # Initialize state
524+ self .state = GraphState (
525+ status = Status .EXECUTING ,
526+ task = task ,
527+ total_nodes = len (self .nodes ),
528+ edges = [(edge .from_node , edge .to_node ) for edge in self .edges ],
529+ entry_points = list (self .entry_points ),
530+ start_time = start_time ,
531+ )
532+ else :
533+ self .state .status = Status .EXECUTING
534+ self .state .start_time = start_time
468535
469536 span = self .tracer .start_multiagent_span (task , "graph" )
470537 with trace_api .use_span (span , end_on_exit = True ):
@@ -499,6 +566,9 @@ async def stream_async(
499566 raise
500567 finally :
501568 self .state .execution_time = round ((time .time () - start_time ) * 1000 )
569+ self .hooks .invoke_callbacks (AfterMultiAgentInvocationEvent (self ))
570+ self ._resume_from_session = False
571+ self ._resume_next_nodes .clear ()
502572
503573 def _validate_graph (self , nodes : dict [str , GraphNode ]) -> None :
504574 """Validate graph nodes for duplicate instances."""
@@ -514,7 +584,7 @@ def _validate_graph(self, nodes: dict[str, GraphNode]) -> None:
514584
515585 async def _execute_graph (self , invocation_state : dict [str , Any ]) -> AsyncIterator [Any ]:
516586 """Execute graph and yield TypedEvent objects."""
517- ready_nodes = list (self .entry_points )
587+ ready_nodes = self . _resume_next_nodes if self . _resume_from_session else list (self .entry_points )
518588
519589 while ready_nodes :
520590 # Check execution limits before continuing
@@ -928,3 +998,94 @@ def _build_result(self) -> GraphResult:
928998 edges = self .state .edges ,
929999 entry_points = self .state .entry_points ,
9301000 )
1001+
def serialize_state(self) -> dict[str, Any]:
    """Build a dictionary snapshot of the current graph state.

    Nodes are represented by their ``node_id``; node results are converted
    with their own ``to_dict()``.

    Returns:
        A dict with the keys ``type`` (always ``"graph"``), ``id``,
        ``status``, ``completed_nodes``, ``failed_nodes``, ``node_results``,
        ``next_nodes_to_execute``, ``current_task`` and ``execution_order``.
    """
    graph_state = self.state
    status_value = graph_state.status.value

    # Nodes that would run next if execution were resumed from this snapshot.
    ready = self._compute_ready_nodes_for_resume()
    pending_ids = [ready_node.node_id for ready_node in (ready or [])]

    serialized_results = {}
    for result_key, node_result in (graph_state.results or {}).items():
        serialized_results[result_key] = node_result.to_dict()

    return {
        "type": "graph",
        "id": self.id,
        "status": status_value,
        "completed_nodes": [done.node_id for done in graph_state.completed_nodes],
        "failed_nodes": [failed.node_id for failed in graph_state.failed_nodes],
        "node_results": serialized_results,
        "next_nodes_to_execute": pending_ids,
        "current_task": graph_state.task,
        "execution_order": [ran.node_id for ran in graph_state.execution_order],
    }
1018+
def deserialize_state(self, payload: dict[str, Any]) -> None:
    """Load a persisted snapshot and decide between resuming and restarting.

    The decision is driven by the payload's ``"next_nodes_to_execute"`` entry:

    1. Missing or empty: nothing is left to run, so every node's executor
       state is reset, the graph state is replaced with a fresh
       ``GraphState`` and the resume flag is cleared.
    2. Non-empty: the persisted state is restored via ``_from_dict`` and the
       resume flag is set.

    Args:
        payload: Persisted state dictionary (status, completed nodes,
            results, next nodes to execute, ...).
    """
    if payload.get("next_nodes_to_execute"):
        # The flag is only raised once hydration has succeeded.
        self._from_dict(payload)
        self._resume_from_session = True
        return

    # No pending work recorded: start over from a clean slate.
    for graph_node in self.nodes.values():
        graph_node.reset_executor_state()
    self.state = GraphState()
    self._resume_from_session = False
1043+
1044+ # Helper functions for serialize and deserialize
def _compute_ready_nodes_for_resume(self) -> list[GraphNode]:
    """Work out which nodes could run next given the current state.

    A node qualifies when it is not in the completed set and every edge
    pointing at it comes from a completed node and reports
    ``should_traverse(self.state)``. Nodes without incoming edges qualify
    as soon as they are not completed.

    Returns:
        The qualifying nodes in ``self.nodes`` iteration order; an empty
        list while the graph status is ``Status.PENDING``.
    """
    if self.state.status == Status.PENDING:
        return []

    finished = set(self.state.completed_nodes)
    runnable: list[GraphNode] = []
    for candidate in self.nodes.values():
        if candidate in finished:
            continue
        inbound = [edge for edge in self.edges if edge.to_node is candidate]
        # all() over an empty sequence is True, which covers entry-style
        # nodes that have no incoming edges.
        if all(edge.from_node in finished and edge.should_traverse(self.state) for edge in inbound):
            runnable.append(candidate)
    return runnable
1060+
def _from_dict(self, payload: dict[str, Any]) -> None:
    """Hydrate ``self.state`` and the resume queue from a serialized payload.

    Node ids found in the payload are mapped back to the node objects in
    ``self.nodes``; ids that no longer exist in the graph are dropped.
    Failed nodes are restored as node objects (not raw id strings) so the
    restored state has the same shape ``serialize_state`` expects when it
    reads ``node_id`` from each failed node.

    Args:
        payload: Serialized state. ``"status"`` is required; the entries
            ``"node_results"``, ``"failed_nodes"``, ``"completed_nodes"``,
            ``"execution_order"``, ``"current_task"`` and
            ``"next_nodes_to_execute"`` are optional.

    Raises:
        KeyError: If ``payload`` has no ``"status"`` entry.
        Exception: Any error raised by ``NodeResult.from_dict`` is logged
            and re-raised.
    """

    def known_nodes(key: str) -> list[GraphNode]:
        # Resolve the id list stored under ``key``, skipping ids that are
        # not part of this graph any more.
        return [self.nodes[node_id] for node_id in (payload.get(key) or []) if node_id in self.nodes]

    self.state.status = Status(payload["status"])

    # Hydrate results, ignoring entries for nodes that no longer exist.
    results: dict[str, NodeResult] = {}
    for node_id, entry in (payload.get("node_results") or {}).items():
        if node_id not in self.nodes:
            continue
        try:
            results[node_id] = NodeResult.from_dict(entry)
        except Exception:
            # Log with the node id for context, then propagate the error.
            logger.exception("Failed to hydrate NodeResult for node_id=%s; skipping.", node_id)
            raise
    self.state.results = results

    self.state.failed_nodes = set(known_nodes("failed_nodes"))
    self.state.completed_nodes = set(known_nodes("completed_nodes"))
    self.state.execution_order = known_nodes("execution_order")

    # Keep the current task when the payload does not carry one.
    self.state.task = payload.get("current_task", self.state.task)

    # Nodes to start from when execution resumes.
    self._resume_next_nodes = known_nodes("next_nodes_to_execute")
0 commit comments