1- use std:: collections:: HashMap ;
1+ use std:: {
2+ collections:: { BTreeSet , HashMap , VecDeque } ,
3+ iter:: FusedIterator ,
4+ } ;
25
36use itertools:: { Either , Itertools } ;
47
@@ -16,7 +19,7 @@ use hugr_core::{
1619 ops:: { OpTag , OpTrait , OpType } ,
1720} ;
1821
19- use crate :: CommitId ;
22+ use crate :: { CommitId , persistent_hugr :: NodeStatus } ;
2023
2124use super :: {
2225 InvalidCommit , PatchNode , PersistentHugr , PersistentReplacement , state_space:: CommitData ,
@@ -100,14 +103,18 @@ impl HugrView for PersistentHugr {
100103 }
101104
102105 fn get_parent ( & self , node : Self :: Node ) -> Option < Self :: Node > {
103- assert ! ( self . contains_node( node) , "invalid node" ) ;
104- let ( hugr, node_map) = self . apply_all ( ) ;
105- let parent = hugr. get_parent ( node_map[ & node] ) ?;
106- let parent_inv = node_map
107- . iter ( )
108- . find_map ( |( & k, & v) | ( v == parent) . then_some ( k) )
109- . expect ( "parent not found in node map" ) ;
110- Some ( parent_inv)
106+ debug_assert ! ( self . contains_node( node) , "invalid node" ) ;
107+
108+ if node. owner ( ) == self . base ( ) {
109+ self . base_hugr ( )
110+ . get_parent ( node. 1 )
111+ . map ( |n| PatchNode ( self . base ( ) , n) )
112+ } else {
113+ // all nodes in children commits are applied on the sibling DFG of the
114+ // entrypoint
115+ // TODO: generalise this for the case that commits introduce nested DFGs.
116+ Some ( self . entrypoint ( ) )
117+ }
111118 }
112119
113120 fn get_optype ( & self , PatchNode ( commit_id, node) : Self :: Node ) -> & OpType {
@@ -216,24 +223,16 @@ impl HugrView for PersistentHugr {
216223
217224 fn children ( & self , node : Self :: Node ) -> impl DoubleEndedIterator < Item = Self :: Node > + Clone {
218225 // Only children of dataflow parents may change
226+ let cm = self . get_commit ( node. owner ( ) ) ;
227+ let commit_hugr = cm. commit_hugr ( ) ;
228+ let children = commit_hugr. children ( node. 1 ) . map ( |n| cm. to_patch_node ( n) ) ;
219229 if OpTag :: DataflowParent . is_superset ( self . get_optype ( node) . tag ( ) ) {
220- // TODO: make this more efficient by traversing from inputs to outputs
221- let ( hugr, node_map) = self . apply_all ( ) ;
222- let children = hugr. children ( node_map[ & node] ) . collect_vec ( ) ;
223- let inv_node_map: HashMap < _ , _ > = node_map. into_iter ( ) . map ( |( k, v) | ( v, k) ) . collect ( ) ;
224- Either :: Right ( children. into_iter ( ) . map ( move |child| {
225- * inv_node_map
226- . get ( & child)
227- . expect ( "node not found in node map" )
228- } ) )
230+ // we must filter out children nodes that are invalidated by later commits, and
231+ // on the other hand add nodes in those commits
232+ Either :: Left ( IterValidNodes :: new ( self , children. fuse ( ) ) )
229233 } else {
230- // children are children of the commit hugr
231- let cm = self . get_commit ( node. owner ( ) ) ;
232- Either :: Left (
233- cm. commit_hugr ( )
234- . children ( node. 1 )
235- . map ( |n| cm. to_patch_node ( n) ) ,
236- )
234+ // children are precisely children of the commit hugr
235+ Either :: Right ( children)
237236 }
238237 }
239238
@@ -341,6 +340,79 @@ impl HugrView for PersistentHugr {
341340 }
342341}
343342
343+ /// An iterator over nodes in a `PersistentHugr` that filters out invalid nodes.
344+ ///
345+ /// For any invalid node encountered, it will traverse and return the nodes in
346+ /// the commit deleting the node instead.
347+ #[ derive( Debug , Clone ) ]
348+ pub struct IterValidNodes < ' a , I > {
349+ /// The original iterator over nodes.
350+ nodes_iter : I ,
351+ /// Nodes discovered in commits deleting nodes in the original iterator.
352+ discovered_nodes : VecDeque < PatchNode > ,
353+ /// Commits discovered that delete nodes in the original iterator.
354+ discovered_commits : VecDeque < CommitId > ,
355+ /// Commits discovered across all time, to make sure we only process each
356+ /// commit once.
357+ processed_commits : BTreeSet < CommitId > ,
358+ /// The persistent hugr that the nodes belong to.
359+ hugr : & ' a PersistentHugr ,
360+ }
361+
362+ impl < ' a , I > IterValidNodes < ' a , I > {
363+ fn new ( hugr : & ' a PersistentHugr , nodes_iter : impl IntoIterator < IntoIter = I > ) -> Self {
364+ Self {
365+ nodes_iter : nodes_iter. into_iter ( ) ,
366+ discovered_nodes : VecDeque :: new ( ) ,
367+ discovered_commits : VecDeque :: new ( ) ,
368+ processed_commits : BTreeSet :: new ( ) ,
369+ hugr,
370+ }
371+ }
372+ }
373+
374+ impl < I : FusedIterator < Item = PatchNode > > Iterator for IterValidNodes < ' _ , I > {
375+ type Item = PatchNode ;
376+
377+ fn next ( & mut self ) -> Option < Self :: Item > {
378+ loop {
379+ let Some ( node) = self
380+ . nodes_iter
381+ . next ( )
382+ . or_else ( || self . discovered_nodes . pop_front ( ) )
383+ else {
384+ break ;
385+ } ;
386+ match self . hugr . node_status ( node) {
387+ NodeStatus :: Deleted ( commit_id) => {
388+ if self . processed_commits . insert ( commit_id) {
389+ self . discovered_commits . push_back ( commit_id) ;
390+ }
391+ }
392+ NodeStatus :: ReplacementIO | NodeStatus :: Valid => return Some ( node) ,
393+ }
394+ }
395+
396+ // Add nodes in next commit to queue
397+ let next_commit_id = self . discovered_commits . pop_front ( ) ?;
398+ let next_commit = self . hugr . get_commit ( next_commit_id) ;
399+
400+ self . discovered_nodes . extend (
401+ next_commit
402+ . inserted_nodes ( )
403+ . map ( |n| next_commit. to_patch_node ( n) ) ,
404+ ) ;
405+
406+ self . next ( )
407+ }
408+ }
409+
410+ impl < I : FusedIterator < Item = PatchNode > > DoubleEndedIterator for IterValidNodes < ' _ , I > {
411+ fn next_back ( & mut self ) -> Option < Self :: Item > {
412+ unimplemented ! ( "cannot go backwards" )
413+ }
414+ }
415+
344416#[ cfg( test) ]
345417mod tests {
346418 use std:: collections:: HashSet ;
0 commit comments