Skip to content

Commit d02aa6c

Browse files
committed
feat(persistent): More efficient HugrView iterators for PersistentHugr
1 parent 1e1bfcc commit d02aa6c

File tree

4 files changed

+177
-62
lines changed

4 files changed

+177
-62
lines changed

hugr-persistent/src/persistent_hugr.rs

Lines changed: 54 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
use std::{
2+
cell::RefCell,
23
collections::{BTreeSet, HashMap, VecDeque},
34
vec,
45
};
@@ -14,6 +15,9 @@ use crate::{
1415
Commit, CommitData, CommitId, CommitStateSpace, InvalidCommit, PatchNode, PersistentReplacement,
1516
};
1617

18+
mod cache;
19+
use cache::PersistentHugrCache;
20+
1721
pub mod serial;
1822

1923
/// A HUGR-like object that tracks its mutation history.
@@ -56,7 +60,7 @@ pub mod serial;
5660
/// subgraphs within dataflow regions are supported.
5761
///
5862
/// [`SimpleReplacement`]: hugr_core::SimpleReplacement
59-
#[derive(Clone, Debug)]
63+
#[derive(Clone, derive_more::Debug)]
6064
pub struct PersistentHugr {
6165
/// The state space of all commits.
6266
///
@@ -72,6 +76,9 @@ pub struct PersistentHugr {
7276
/// other commits are [`CommitData::Replacement`]s, and are descendants
7377
/// of this.
7478
base_commit_id: CommitId,
79+
/// Cache some properties for performance.
80+
#[debug(skip)]
81+
cache: RefCell<PersistentHugrCache>,
7582
}
7683

7784
impl PersistentHugr {
@@ -144,6 +151,7 @@ impl PersistentHugr {
144151
Ok(Self {
145152
graph,
146153
base_commit_id: base_commit,
154+
cache: RefCell::new(PersistentHugrCache::default()),
147155
})
148156
}
149157

@@ -247,6 +255,11 @@ impl PersistentHugr {
247255
}
248256

249257
self.graph.insert_node(new_commit.clone().into());
258+
259+
// Invalidate cache
260+
for parent in new_commit.parents() {
261+
self.cache.borrow_mut().invalidate_children(parent.id());
262+
}
250263
}
251264

252265
Ok(())
@@ -331,7 +344,11 @@ impl PersistentHugr {
331344
}
332345

333346
pub fn children_commits(&self, commit_id: CommitId) -> impl Iterator<Item = CommitId> + '_ {
334-
self.graph.children(commit_id)
347+
self.cache
348+
.borrow_mut()
349+
.children_or_insert(commit_id, || self.graph.children(commit_id).collect())
350+
.clone()
351+
.into_iter()
335352
}
336353

337354
pub fn parent_commits(&self, commit_id: CommitId) -> impl Iterator<Item = CommitId> + '_ {
@@ -461,6 +478,21 @@ fn get_ancestors_while<'a>(
461478
all_commits
462479
}
463480

481+
/// A node in a commit of a [`PersistentHugr`] is either a valid node of the
482+
/// HUGR, a node deleted by a child commit in that [`PersistentHugr`], or an
483+
/// input or output node in a replacement graph.
484+
#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
485+
pub enum NodeStatus {
486+
/// A node deleted by a child commit in that [`PersistentHugr`].
487+
///
488+
/// The ID of the child commit is stored in the variant.
489+
Deleted(CommitId),
490+
/// An input or output node in the replacement graph of a Commit
491+
ReplacementIO,
492+
/// A valid node in the [`PersistentHugr`]
493+
Valid,
494+
}
495+
464496
// non-public methods
465497
impl PersistentHugr {
466498
/// Convert a node ID specific to a commit HUGR into a patch node in the
@@ -480,6 +512,26 @@ impl PersistentHugr {
480512
child.deleted_parent_nodes().contains(&node)
481513
})
482514
}
515+
516+
/// Whether a node is valid in `self`, is deleted or is an IO node in a
517+
/// replacement graph.
518+
pub(crate) fn node_status(
519+
&self,
520+
per_node @ PatchNode(commit_id, node): PatchNode,
521+
) -> NodeStatus {
522+
debug_assert!(self.contains_id(commit_id), "unknown commit");
523+
if self
524+
.get_commit(commit_id)
525+
.replacement()
526+
.is_some_and(|repl| repl.get_replacement_io().contains(&node))
527+
{
528+
NodeStatus::ReplacementIO
529+
} else if let Some(commit_id) = self.find_deleting_commit(per_node) {
530+
NodeStatus::Deleted(commit_id)
531+
} else {
532+
NodeStatus::Valid
533+
}
534+
}
483535
}
484536

485537
impl<'a> IntoIterator for &'a PersistentHugr {
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
use std::collections::HashMap;
2+
3+
use crate::CommitId;
4+
5+
/// A cache for storing computed properties of a `PersistentHugr`.
6+
#[derive(Debug, Default, Clone)]
7+
pub(super) struct PersistentHugrCache {
8+
children_cache: HashMap<CommitId, Vec<CommitId>>,
9+
}
10+
11+
impl PersistentHugrCache {
12+
pub fn invalidate_children(&mut self, commit: CommitId) {
13+
self.children_cache.remove(&commit);
14+
}
15+
16+
pub fn children_or_insert(
17+
&mut self,
18+
commit: CommitId,
19+
children: impl FnOnce() -> Vec<CommitId>,
20+
) -> &Vec<CommitId> {
21+
self.children_cache.entry(commit).or_insert_with(children)
22+
}
23+
}

hugr-persistent/src/trait_impls.rs

Lines changed: 98 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
1-
use std::collections::HashMap;
1+
use std::{
2+
collections::{BTreeSet, HashMap, VecDeque},
3+
iter::FusedIterator,
4+
};
25

36
use itertools::{Either, Itertools};
47

@@ -16,7 +19,7 @@ use hugr_core::{
1619
ops::{OpTag, OpTrait, OpType},
1720
};
1821

19-
use crate::CommitId;
22+
use crate::{CommitId, persistent_hugr::NodeStatus};
2023

2124
use super::{
2225
InvalidCommit, PatchNode, PersistentHugr, PersistentReplacement, state_space::CommitData,
@@ -100,14 +103,18 @@ impl HugrView for PersistentHugr {
100103
}
101104

102105
fn get_parent(&self, node: Self::Node) -> Option<Self::Node> {
103-
assert!(self.contains_node(node), "invalid node");
104-
let (hugr, node_map) = self.apply_all();
105-
let parent = hugr.get_parent(node_map[&node])?;
106-
let parent_inv = node_map
107-
.iter()
108-
.find_map(|(&k, &v)| (v == parent).then_some(k))
109-
.expect("parent not found in node map");
110-
Some(parent_inv)
106+
debug_assert!(self.contains_node(node), "invalid node");
107+
108+
if node.owner() == self.base() {
109+
self.base_hugr()
110+
.get_parent(node.1)
111+
.map(|n| PatchNode(self.base(), n))
112+
} else {
113+
// all nodes in children commits are applied on the sibling DFG of the
114+
// entrypoint
115+
// TODO: generalise this for the case that commits introduce nested DFGs.
116+
Some(self.entrypoint())
117+
}
111118
}
112119

113120
fn get_optype(&self, PatchNode(commit_id, node): Self::Node) -> &OpType {
@@ -216,24 +223,16 @@ impl HugrView for PersistentHugr {
216223

217224
fn children(&self, node: Self::Node) -> impl DoubleEndedIterator<Item = Self::Node> + Clone {
218225
// Only children of dataflow parents may change
226+
let cm = self.get_commit(node.owner());
227+
let commit_hugr = cm.commit_hugr();
228+
let children = commit_hugr.children(node.1).map(|n| cm.to_patch_node(n));
219229
if OpTag::DataflowParent.is_superset(self.get_optype(node).tag()) {
220-
// TODO: make this more efficient by traversing from inputs to outputs
221-
let (hugr, node_map) = self.apply_all();
222-
let children = hugr.children(node_map[&node]).collect_vec();
223-
let inv_node_map: HashMap<_, _> = node_map.into_iter().map(|(k, v)| (v, k)).collect();
224-
Either::Right(children.into_iter().map(move |child| {
225-
*inv_node_map
226-
.get(&child)
227-
.expect("node not found in node map")
228-
}))
230+
// we must filter out children nodes that are invalidated by later commits, and
231+
// on the other hand add nodes in those commits
232+
Either::Left(IterValidNodes::new(self, children.fuse()))
229233
} else {
230-
// children are children of the commit hugr
231-
let cm = self.get_commit(node.owner());
232-
Either::Left(
233-
cm.commit_hugr()
234-
.children(node.1)
235-
.map(|n| cm.to_patch_node(n)),
236-
)
234+
// children are precisely children of the commit hugr
235+
Either::Right(children)
237236
}
238237
}
239238

@@ -341,6 +340,79 @@ impl HugrView for PersistentHugr {
341340
}
342341
}
343342

343+
/// An iterator over nodes in a `PersistentHugr` that filters out invalid nodes.
344+
///
345+
/// For any invalid node encountered, it will traverse and return the nodes in
346+
/// the commit deleting the node instead.
347+
#[derive(Debug, Clone)]
348+
pub struct IterValidNodes<'a, I> {
349+
/// The original iterator over nodes.
350+
nodes_iter: I,
351+
/// Nodes discovered in commits deleting nodes in the original iterator.
352+
discovered_nodes: VecDeque<PatchNode>,
353+
/// Commits discovered that delete nodes in the original iterator.
354+
discovered_commits: VecDeque<CommitId>,
355+
/// Commits discovered across all time, to make sure we only process each
356+
/// commit once.
357+
processed_commits: BTreeSet<CommitId>,
358+
/// The persistent hugr that the nodes belong to.
359+
hugr: &'a PersistentHugr,
360+
}
361+
362+
impl<'a, I> IterValidNodes<'a, I> {
363+
fn new(hugr: &'a PersistentHugr, nodes_iter: impl IntoIterator<IntoIter = I>) -> Self {
364+
Self {
365+
nodes_iter: nodes_iter.into_iter(),
366+
discovered_nodes: VecDeque::new(),
367+
discovered_commits: VecDeque::new(),
368+
processed_commits: BTreeSet::new(),
369+
hugr,
370+
}
371+
}
372+
}
373+
374+
impl<I: FusedIterator<Item = PatchNode>> Iterator for IterValidNodes<'_, I> {
375+
type Item = PatchNode;
376+
377+
fn next(&mut self) -> Option<Self::Item> {
378+
loop {
379+
let Some(node) = self
380+
.nodes_iter
381+
.next()
382+
.or_else(|| self.discovered_nodes.pop_front())
383+
else {
384+
break;
385+
};
386+
match self.hugr.node_status(node) {
387+
NodeStatus::Deleted(commit_id) => {
388+
if self.processed_commits.insert(commit_id) {
389+
self.discovered_commits.push_back(commit_id);
390+
}
391+
}
392+
NodeStatus::ReplacementIO | NodeStatus::Valid => return Some(node),
393+
}
394+
}
395+
396+
// Add nodes in next commit to queue
397+
let next_commit_id = self.discovered_commits.pop_front()?;
398+
let next_commit = self.hugr.get_commit(next_commit_id);
399+
400+
self.discovered_nodes.extend(
401+
next_commit
402+
.inserted_nodes()
403+
.map(|n| next_commit.to_patch_node(n)),
404+
);
405+
406+
self.next()
407+
}
408+
}
409+
410+
impl<I: FusedIterator<Item = PatchNode>> DoubleEndedIterator for IterValidNodes<'_, I> {
411+
fn next_back(&mut self) -> Option<Self::Item> {
412+
unimplemented!("cannot go backwards")
413+
}
414+
}
415+
344416
#[cfg(test)]
345417
mod tests {
346418
use std::collections::HashSet;

hugr-persistent/src/wire.rs

Lines changed: 2 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ use hugr_core::{
66
};
77
use itertools::Itertools;
88

9-
use crate::{CommitId, PatchNode, PersistentHugr, Walker};
9+
use crate::{CommitId, PatchNode, PersistentHugr, Walker, persistent_hugr::NodeStatus};
1010

1111
/// A wire in a [`PersistentHugr`].
1212
///
@@ -59,43 +59,11 @@ impl CommitWire {
5959
}
6060
}
6161

62-
/// A node in a commit of a [`PersistentHugr`] is either a valid node of the
63-
/// HUGR, a node deleted by a child commit in that [`PersistentHugr`], or an
64-
/// input or output node in a replacement graph.
65-
#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
66-
enum NodeStatus {
67-
/// A node deleted by a child commit in that [`PersistentHugr`].
68-
///
69-
/// The ID of the child commit is stored in the variant.
70-
Deleted(CommitId),
71-
/// An input or output node in the replacement graph of a Commit
72-
ReplacementIO,
73-
/// A valid node in the [`PersistentHugr`]
74-
Valid,
75-
}
76-
7762
impl PersistentHugr {
7863
pub fn get_wire(&self, node: PatchNode, port: impl Into<Port>) -> PersistentWire {
7964
PersistentWire::from_port(node, port, self)
8065
}
8166

82-
/// Whether a node is valid in `self`, is deleted or is an IO node in a
83-
/// replacement graph.
84-
fn node_status(&self, per_node @ PatchNode(commit_id, node): PatchNode) -> NodeStatus {
85-
debug_assert!(self.contains_id(commit_id), "unknown commit");
86-
if self
87-
.get_commit(commit_id)
88-
.replacement()
89-
.is_some_and(|repl| repl.get_replacement_io().contains(&node))
90-
{
91-
NodeStatus::ReplacementIO
92-
} else if let Some(commit_id) = self.find_deleting_commit(per_node) {
93-
NodeStatus::Deleted(commit_id)
94-
} else {
95-
NodeStatus::Valid
96-
}
97-
}
98-
9967
/// The unique outgoing port in `self` that `port` is attached to.
10068
///
10169
/// # Panics
@@ -132,7 +100,7 @@ impl PersistentHugr {
132100
impl PersistentWire {
133101
/// Get the wire connected to a specified port of a pinned node in `hugr`.
134102
fn from_port(node: PatchNode, port: impl Into<Port>, per_hugr: &PersistentHugr) -> Self {
135-
assert!(per_hugr.contains_node(node), "node not in hugr");
103+
debug_assert!(per_hugr.contains_node(node), "node not in hugr");
136104

137105
// Queue of wires within each commit HUGR, that combined will form the
138106
// persistent wire.

0 commit comments

Comments
 (0)