Skip to content

Commit d9e4c1e

Browse files
committed
feat(persistent): More efficient HugrView iterators for PersistentHugr
1 parent 12b2601 commit d9e4c1e

File tree

4 files changed

+177
-62
lines changed

4 files changed

+177
-62
lines changed

hugr-persistent/src/persistent_hugr.rs

Lines changed: 54 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
use std::{
2+
cell::RefCell,
23
collections::{BTreeSet, HashMap, VecDeque},
34
vec,
45
};
@@ -14,6 +15,9 @@ use crate::{
1415
Commit, CommitData, CommitId, CommitStateSpace, InvalidCommit, PatchNode, PersistentReplacement,
1516
};
1617

18+
mod cache;
19+
use cache::PersistentHugrCache;
20+
1721
pub mod serial;
1822

1923
/// A HUGR-like object that tracks its mutation history.
@@ -56,7 +60,7 @@ pub mod serial;
5660
/// subgraphs within dataflow regions are supported.
5761
///
5862
/// [`SimpleReplacement`]: hugr_core::SimpleReplacement
59-
#[derive(Clone, Debug)]
63+
#[derive(Clone, derive_more::Debug)]
6064
pub struct PersistentHugr {
6165
/// The state space of all commits.
6266
///
@@ -75,6 +79,9 @@ pub struct PersistentHugr {
7579
/// Invariant: any path from any commit in `self` through ancestors will
7680
/// always lead to this commit.
7781
base_commit_id: CommitId,
82+
/// Cache some properties for performance.
83+
#[debug(skip)]
84+
cache: RefCell<PersistentHugrCache>,
7885
}
7986

8087
impl PersistentHugr {
@@ -146,6 +153,7 @@ impl PersistentHugr {
146153
Ok(Self {
147154
graph,
148155
base_commit_id: base_commit,
156+
cache: RefCell::new(PersistentHugrCache::default()),
149157
})
150158
}
151159

@@ -249,6 +257,11 @@ impl PersistentHugr {
249257
}
250258

251259
self.graph.insert_node(new_commit.clone().into());
260+
261+
// Invalidate cache
262+
for parent in new_commit.parents() {
263+
self.cache.borrow_mut().invalidate_children(parent.id());
264+
}
252265
}
253266

254267
Ok(())
@@ -333,7 +346,11 @@ impl PersistentHugr {
333346
}
334347

335348
pub fn children_commits(&self, commit_id: CommitId) -> impl Iterator<Item = CommitId> + '_ {
336-
self.graph.children(commit_id)
349+
self.cache
350+
.borrow_mut()
351+
.children_or_insert(commit_id, || self.graph.children(commit_id).collect())
352+
.clone()
353+
.into_iter()
337354
}
338355

339356
pub fn parent_commits(&self, commit_id: CommitId) -> impl Iterator<Item = CommitId> + '_ {
@@ -463,6 +480,21 @@ fn get_ancestors_while<'a>(
463480
all_commits
464481
}
465482

483+
/// A node in a commit of a [`PersistentHugr`] is either a valid node of the
484+
/// HUGR, a node deleted by a child commit in that [`PersistentHugr`], or an
485+
/// input or output node in a replacement graph.
486+
#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
487+
pub enum NodeStatus {
488+
/// A node deleted by a child commit in that [`PersistentHugr`].
489+
///
490+
/// The ID of the child commit is stored in the variant.
491+
Deleted(CommitId),
492+
/// An input or output node in the replacement graph of a Commit
493+
ReplacementIO,
494+
/// A valid node in the [`PersistentHugr`]
495+
Valid,
496+
}
497+
466498
// non-public methods
467499
impl PersistentHugr {
468500
/// Convert a node ID specific to a commit HUGR into a patch node in the
@@ -482,6 +514,26 @@ impl PersistentHugr {
482514
child.deleted_parent_nodes().contains(&node)
483515
})
484516
}
517+
518+
/// Whether a node is valid in `self`, is deleted or is an IO node in a
519+
/// replacement graph.
520+
pub(crate) fn node_status(
521+
&self,
522+
per_node @ PatchNode(commit_id, node): PatchNode,
523+
) -> NodeStatus {
524+
debug_assert!(self.contains_id(commit_id), "unknown commit");
525+
if self
526+
.get_commit(commit_id)
527+
.replacement()
528+
.is_some_and(|repl| repl.get_replacement_io().contains(&node))
529+
{
530+
NodeStatus::ReplacementIO
531+
} else if let Some(commit_id) = self.find_deleting_commit(per_node) {
532+
NodeStatus::Deleted(commit_id)
533+
} else {
534+
NodeStatus::Valid
535+
}
536+
}
485537
}
486538

487539
impl<'a> IntoIterator for &'a PersistentHugr {
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
use std::collections::HashMap;
2+
3+
use crate::CommitId;
4+
5+
/// A cache for storing computed properties of a `PersistentHugr`.
6+
#[derive(Debug, Default, Clone)]
7+
pub(super) struct PersistentHugrCache {
8+
children_cache: HashMap<CommitId, Vec<CommitId>>,
9+
}
10+
11+
impl PersistentHugrCache {
12+
pub fn invalidate_children(&mut self, commit: CommitId) {
13+
self.children_cache.remove(&commit);
14+
}
15+
16+
pub fn children_or_insert(
17+
&mut self,
18+
commit: CommitId,
19+
children: impl FnOnce() -> Vec<CommitId>,
20+
) -> &Vec<CommitId> {
21+
self.children_cache.entry(commit).or_insert_with(children)
22+
}
23+
}

hugr-persistent/src/trait_impls.rs

Lines changed: 98 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
1-
use std::collections::HashMap;
1+
use std::{
2+
collections::{BTreeSet, HashMap, VecDeque},
3+
iter::FusedIterator,
4+
};
25

36
use itertools::{Either, Itertools};
47

@@ -16,7 +19,7 @@ use hugr_core::{
1619
ops::{OpTag, OpTrait, OpType},
1720
};
1821

19-
use crate::CommitId;
22+
use crate::{CommitId, persistent_hugr::NodeStatus};
2023

2124
use super::{
2225
InvalidCommit, PatchNode, PersistentHugr, PersistentReplacement, state_space::CommitData,
@@ -100,14 +103,18 @@ impl HugrView for PersistentHugr {
100103
}
101104

102105
fn get_parent(&self, node: Self::Node) -> Option<Self::Node> {
103-
assert!(self.contains_node(node), "invalid node");
104-
let (hugr, node_map) = self.apply_all();
105-
let parent = hugr.get_parent(node_map[&node])?;
106-
let parent_inv = node_map
107-
.iter()
108-
.find_map(|(&k, &v)| (v == parent).then_some(k))
109-
.expect("parent not found in node map");
110-
Some(parent_inv)
106+
debug_assert!(self.contains_node(node), "invalid node");
107+
108+
if node.owner() == self.base() {
109+
self.base_hugr()
110+
.get_parent(node.1)
111+
.map(|n| PatchNode(self.base(), n))
112+
} else {
113+
// all nodes in children commits are applied on the sibling DFG of the
114+
// entrypoint
115+
// TODO: generalise this for the case that commits introduce nested DFGs.
116+
Some(self.entrypoint())
117+
}
111118
}
112119

113120
fn get_optype(&self, PatchNode(commit_id, node): Self::Node) -> &OpType {
@@ -216,24 +223,16 @@ impl HugrView for PersistentHugr {
216223

217224
fn children(&self, node: Self::Node) -> impl DoubleEndedIterator<Item = Self::Node> + Clone {
218225
// Only children of dataflow parents may change
226+
let cm = self.get_commit(node.owner());
227+
let commit_hugr = cm.commit_hugr();
228+
let children = commit_hugr.children(node.1).map(|n| cm.to_patch_node(n));
219229
if OpTag::DataflowParent.is_superset(self.get_optype(node).tag()) {
220-
// TODO: make this more efficient by traversing from inputs to outputs
221-
let (hugr, node_map) = self.apply_all();
222-
let children = hugr.children(node_map[&node]).collect_vec();
223-
let inv_node_map: HashMap<_, _> = node_map.into_iter().map(|(k, v)| (v, k)).collect();
224-
Either::Right(children.into_iter().map(move |child| {
225-
*inv_node_map
226-
.get(&child)
227-
.expect("node not found in node map")
228-
}))
230+
// we must filter out children nodes that are invalidated by later commits, and
231+
// on the other hand add nodes in those commits
232+
Either::Left(IterValidNodes::new(self, children.fuse()))
229233
} else {
230-
// children are children of the commit hugr
231-
let cm = self.get_commit(node.owner());
232-
Either::Left(
233-
cm.commit_hugr()
234-
.children(node.1)
235-
.map(|n| cm.to_patch_node(n)),
236-
)
234+
// children are precisely children of the commit hugr
235+
Either::Right(children)
237236
}
238237
}
239238

@@ -341,6 +340,79 @@ impl HugrView for PersistentHugr {
341340
}
342341
}
343342

343+
/// An iterator over nodes in a `PersistentHugr` that filters out invalid nodes.
344+
///
345+
/// For any invalid node encountered, it will traverse and return the nodes in
346+
/// the commit deleting the node instead.
347+
#[derive(Debug, Clone)]
348+
pub struct IterValidNodes<'a, I> {
349+
/// The original iterator over nodes.
350+
nodes_iter: I,
351+
/// Nodes discovered in commits deleting nodes in the original iterator.
352+
discovered_nodes: VecDeque<PatchNode>,
353+
/// Commits discovered that delete nodes in the original iterator.
354+
discovered_commits: VecDeque<CommitId>,
355+
/// Commits discovered across all time, to make sure we only process each
356+
/// commit once.
357+
processed_commits: BTreeSet<CommitId>,
358+
/// The persistent hugr that the nodes belong to.
359+
hugr: &'a PersistentHugr,
360+
}
361+
362+
impl<'a, I> IterValidNodes<'a, I> {
363+
fn new(hugr: &'a PersistentHugr, nodes_iter: impl IntoIterator<IntoIter = I>) -> Self {
364+
Self {
365+
nodes_iter: nodes_iter.into_iter(),
366+
discovered_nodes: VecDeque::new(),
367+
discovered_commits: VecDeque::new(),
368+
processed_commits: BTreeSet::new(),
369+
hugr,
370+
}
371+
}
372+
}
373+
374+
impl<I: FusedIterator<Item = PatchNode>> Iterator for IterValidNodes<'_, I> {
375+
type Item = PatchNode;
376+
377+
fn next(&mut self) -> Option<Self::Item> {
378+
loop {
379+
let Some(node) = self
380+
.nodes_iter
381+
.next()
382+
.or_else(|| self.discovered_nodes.pop_front())
383+
else {
384+
break;
385+
};
386+
match self.hugr.node_status(node) {
387+
NodeStatus::Deleted(commit_id) => {
388+
if self.processed_commits.insert(commit_id) {
389+
self.discovered_commits.push_back(commit_id);
390+
}
391+
}
392+
NodeStatus::ReplacementIO | NodeStatus::Valid => return Some(node),
393+
}
394+
}
395+
396+
// Add nodes in next commit to queue
397+
let next_commit_id = self.discovered_commits.pop_front()?;
398+
let next_commit = self.hugr.get_commit(next_commit_id);
399+
400+
self.discovered_nodes.extend(
401+
next_commit
402+
.inserted_nodes()
403+
.map(|n| next_commit.to_patch_node(n)),
404+
);
405+
406+
self.next()
407+
}
408+
}
409+
410+
impl<I: FusedIterator<Item = PatchNode>> DoubleEndedIterator for IterValidNodes<'_, I> {
411+
fn next_back(&mut self) -> Option<Self::Item> {
412+
unimplemented!("cannot go backwards")
413+
}
414+
}
415+
344416
#[cfg(test)]
345417
mod tests {
346418
use std::collections::HashSet;

hugr-persistent/src/wire.rs

Lines changed: 2 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ use hugr_core::{
66
};
77
use itertools::Itertools;
88

9-
use crate::{CommitId, PatchNode, PersistentHugr, Walker};
9+
use crate::{CommitId, PatchNode, PersistentHugr, Walker, persistent_hugr::NodeStatus};
1010

1111
/// A wire in a [`PersistentHugr`].
1212
///
@@ -59,43 +59,11 @@ impl CommitWire {
5959
}
6060
}
6161

62-
/// A node in a commit of a [`PersistentHugr`] is either a valid node of the
63-
/// HUGR, a node deleted by a child commit in that [`PersistentHugr`], or an
64-
/// input or output node in a replacement graph.
65-
#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
66-
enum NodeStatus {
67-
/// A node deleted by a child commit in that [`PersistentHugr`].
68-
///
69-
/// The ID of the child commit is stored in the variant.
70-
Deleted(CommitId),
71-
/// An input or output node in the replacement graph of a Commit
72-
ReplacementIO,
73-
/// A valid node in the [`PersistentHugr`]
74-
Valid,
75-
}
76-
7762
impl PersistentHugr {
7863
pub fn get_wire(&self, node: PatchNode, port: impl Into<Port>) -> PersistentWire {
7964
PersistentWire::from_port(node, port, self)
8065
}
8166

82-
/// Whether a node is valid in `self`, is deleted or is an IO node in a
83-
/// replacement graph.
84-
fn node_status(&self, per_node @ PatchNode(commit_id, node): PatchNode) -> NodeStatus {
85-
debug_assert!(self.contains_id(commit_id), "unknown commit");
86-
if self
87-
.get_commit(commit_id)
88-
.replacement()
89-
.is_some_and(|repl| repl.get_replacement_io().contains(&node))
90-
{
91-
NodeStatus::ReplacementIO
92-
} else if let Some(commit_id) = self.find_deleting_commit(per_node) {
93-
NodeStatus::Deleted(commit_id)
94-
} else {
95-
NodeStatus::Valid
96-
}
97-
}
98-
9967
/// The unique outgoing port in `self` that `port` is attached to.
10068
///
10169
/// # Panics
@@ -132,7 +100,7 @@ impl PersistentHugr {
132100
impl PersistentWire {
133101
/// Get the wire connected to a specified port of a pinned node in `hugr`.
134102
fn from_port(node: PatchNode, port: impl Into<Port>, per_hugr: &PersistentHugr) -> Self {
135-
assert!(per_hugr.contains_node(node), "node not in hugr");
103+
debug_assert!(per_hugr.contains_node(node), "node not in hugr");
136104

137105
// Queue of wires within each commit HUGR, that combined will form the
138106
// persistent wire.

0 commit comments

Comments
 (0)