Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions hugr-core/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ pub mod extension;
pub mod hugr;
pub mod import;
pub mod macros;
pub mod module_graph;
pub mod ops;
pub mod package;
pub mod std_extensions;
Expand Down
210 changes: 210 additions & 0 deletions hugr-core/src/module_graph.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,210 @@
//! Data structure summarizing static nodes of a Hugr and their uses
use std::collections::HashMap;

use crate::{HugrView, Node, core::HugrNode, ops::OpType};
use petgraph::{Graph, visit::EdgeRef};

/// Weight for an edge in a [`ModuleGraph`]
#[derive(Clone, Debug, PartialEq, Eq)]
#[non_exhaustive]
pub enum StaticEdge<N = Node> {
/// Edge corresponds to a [Call](OpType::Call) node (specified) in the Hugr
Call(N),
/// Edge corresponds to a [`LoadFunction`](OpType::LoadFunction) node (specified) in the Hugr
LoadFunction(N),
/// Edge corresponds to a [LoadConstant](OpType::LoadConstant) node (specified) in the Hugr
LoadConstant(N),
}

/// Weight for a petgraph-node in a [`ModuleGraph`]
#[derive(Clone, Debug, PartialEq, Eq)]
#[non_exhaustive]
pub enum StaticNode<N = Node> {
/// petgraph-node corresponds to a [`FuncDecl`](OpType::FuncDecl) node (specified) in the Hugr
FuncDecl(N),
/// petgraph-node corresponds to a [`FuncDefn`](OpType::FuncDefn) node (specified) in the Hugr
FuncDefn(N),
/// petgraph-node corresponds to the [HugrView::entrypoint], that is not
/// a [`FuncDefn`](OpType::FuncDefn). Note that it will not be a [Module](OpType::Module)
/// either, as such a node could not have edges, so is not represented in the petgraph.
NonFuncEntrypoint,
/// petgraph-node corresponds to a constant; will have no outgoing edges, and incoming
/// edges will be [StaticEdge::LoadConstant]
Const(N),
}

/// Details the [`FuncDefn`]s, [`FuncDecl`]s and module-level [`Const`]s in a Hugr,
/// in a Hugr, along with the [`Call`]s, [`LoadFunction`]s, and [`LoadConstant`]s connecting them.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
/// in a Hugr, along with the [`Call`]s, [`LoadFunction`]s, and [`LoadConstant`]s connecting them.
/// along with the [`Call`]s, [`LoadFunction`]s, and [`LoadConstant`]s connecting them.

///
/// Each node in the `ModuleGraph` corresponds to a module-level function or const;
/// each edge corresponds to a use of the target contained in the edge's source.
///
/// For Hugrs whose entrypoint is neither a [Module](OpType::Module) nor a [`FuncDefn`],
/// the static graph will have an additional [`StaticNode::NonFuncEntrypoint`]
/// corresponding to the Hugr's entrypoint, with no incoming edges.
///
/// [`Call`]: OpType::Call
/// [`Const`]: OpType::Const
/// [`FuncDecl`]: OpType::FuncDecl
/// [`FuncDefn`]: OpType::FuncDefn
/// [`LoadConstant`]: OpType::LoadConstant
/// [`LoadFunction`]: OpType::LoadFunction
pub struct ModuleGraph<N = Node> {
g: Graph<StaticNode<N>, StaticEdge<N>>,
node_to_g: HashMap<N, petgraph::graph::NodeIndex<u32>>,
}

impl<N: HugrNode> ModuleGraph<N> {
/// Makes a new `ModuleGraph` for a Hugr.
pub fn new(hugr: &impl HugrView<Node = N>) -> Self {
let mut g = Graph::default();
let mut node_to_g = hugr
.children(hugr.module_root())
.filter_map(|n| {
let weight = match hugr.get_optype(n) {
OpType::FuncDecl(_) => StaticNode::FuncDecl(n),
OpType::FuncDefn(_) => StaticNode::FuncDefn(n),
OpType::Const(_) => StaticNode::Const(n),
_ => return None,
};
Some((n, g.add_node(weight)))
})
.collect::<HashMap<_, _>>();
if !hugr.entrypoint_optype().is_module() && !node_to_g.contains_key(&hugr.entrypoint()) {
node_to_g.insert(hugr.entrypoint(), g.add_node(StaticNode::NonFuncEntrypoint));
}
for (func, cg_node) in &node_to_g {
traverse(hugr, *cg_node, *func, &mut g, &node_to_g);
}
fn traverse<N: HugrNode>(
h: &impl HugrView<Node = N>,
enclosing_func: petgraph::graph::NodeIndex<u32>,
node: N, // Nonstrict-descendant of `enclosing_func``
g: &mut Graph<StaticNode<N>, StaticEdge<N>>,
node_to_g: &HashMap<N, petgraph::graph::NodeIndex<u32>>,
) {
for ch in h.children(node) {
traverse(h, enclosing_func, ch, g, node_to_g);
let weight = match h.get_optype(ch) {
OpType::Call(_) => StaticEdge::Call(ch),
OpType::LoadFunction(_) => StaticEdge::LoadFunction(ch),
OpType::LoadConstant(_) => StaticEdge::LoadConstant(ch),
_ => continue,
};
if let Some(target) = h.static_source(ch) {
if h.get_parent(target) == Some(h.module_root()) {
g.add_edge(enclosing_func, node_to_g[&target], weight);
} else {
assert!(!node_to_g.contains_key(&target));
assert!(h.get_optype(ch).is_load_constant());
assert!(h.get_optype(target).is_const());
}
}
}
}
ModuleGraph { g, node_to_g }
}

/// Allows access to the petgraph
#[must_use]
pub fn graph(&self) -> &Graph<StaticNode<N>, StaticEdge<N>> {
&self.g
}

/// Convert a Hugr [Node] into a petgraph node index.
/// Result will be `None` if `n` is not a [`FuncDefn`](OpType::FuncDefn),
/// [`FuncDecl`](OpType::FuncDecl) or the [HugrView::entrypoint].
pub fn node_index(&self, n: N) -> Option<petgraph::graph::NodeIndex<u32>> {
self.node_to_g.get(&n).copied()
}

/// Returns an iterator over the out-edges from the given Node, i.e.
/// edges to the functions/constants called/loaded by it.
///
/// If the node is not recognised as a function or the entrypoint,
/// for example if it is a [`Const`](OpType::Const), the iterator will be empty.
pub fn out_edges(&self, n: N) -> impl Iterator<Item = (&StaticEdge<N>, &StaticNode<N>)> {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we also want an in_edges method?

let g = self.graph();
self.node_index(n).into_iter().flat_map(move |n| {
self.graph().edges(n).map(|e| {
(
g.edge_weight(e.id()).unwrap(),
g.node_weight(e.target()).unwrap(),
)
})
})
}

/// Returns an iterator over the in-edges to the given Node, i.e.
/// edges from the (necessarily) functions that call/load it.
///
/// If the node is not recognised as a function or constant,
/// for example if it is a non-function entrypoint, the iterator will be empty.
pub fn in_edges(&self, n: N) -> impl Iterator<Item = (&StaticNode<N>, &StaticEdge<N>)> {
let g = self.graph();
self.node_index(n).into_iter().flat_map(move |n| {
self.graph()
.edges_directed(n, petgraph::Direction::Incoming)
.map(|e| {
(
g.node_weight(e.source()).unwrap(),
g.edge_weight(e.id()).unwrap(),
)
})
})
}
}

#[cfg(test)]
mod test {
use itertools::Itertools as _;

use crate::builder::{
Container, Dataflow, DataflowSubContainer, HugrBuilder, ModuleBuilder, endo_sig, inout_sig,
};
use crate::extension::prelude::{ConstUsize, usize_t};
use crate::ops::{Value, handle::NodeHandle};

use super::*;

#[test]
fn edges() {
let mut mb = ModuleBuilder::new();
let cst = mb.add_constant(Value::from(ConstUsize::new(42)));
let callee = mb.define_function("callee", endo_sig(usize_t())).unwrap();
let ins = callee.input_wires();
let callee = callee.finish_with_outputs(ins).unwrap();
let mut caller = mb
.define_function("caller", inout_sig(vec![], usize_t()))
.unwrap();
let val = caller.load_const(&cst);
let call = caller.call(callee.handle(), &[], vec![val]).unwrap();
let caller = caller.finish_with_outputs(call.outputs()).unwrap();
let h = mb.finish_hugr().unwrap();

let mg = ModuleGraph::new(&h);
let call_edge = StaticEdge::Call(call.node());
let load_const_edge = StaticEdge::LoadConstant(val.node());

assert_eq!(mg.out_edges(callee.node()).next(), None);
assert_eq!(
mg.in_edges(callee.node()).collect_vec(),
[(&StaticNode::FuncDefn(caller.node()), &call_edge,)]
);

assert_eq!(
mg.out_edges(caller.node()).collect_vec(),
[
(&call_edge, &StaticNode::FuncDefn(callee.node()),),
(&load_const_edge, &StaticNode::Const(cst.node()),)
]
);
assert_eq!(mg.in_edges(caller.node()).next(), None);

assert_eq!(mg.out_edges(cst.node()).next(), None);
assert_eq!(
mg.in_edges(cst.node()).collect_vec(),
[(&StaticNode::FuncDefn(caller.node()), &load_const_edge,)]
);
}
}
2 changes: 1 addition & 1 deletion hugr-passes/src/dead_code.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ pub enum PreserveNode {
impl PreserveNode {
/// A conservative default for a given node. Just examines the node's [`OpType`]:
/// * Assumes all Calls must be preserved. (One could scan the called `FuncDefn`, but would
/// also need to check for cycles in the [`CallGraph`](super::call_graph::CallGraph).)
/// also need to check for cycles in the [`ModuleGraph`](hugr_core::module_graph::ModuleGraph).)
/// * Assumes all CFGs must be preserved. (One could, for example, allow acyclic
/// CFGs to be removed.)
/// * Assumes all `TailLoops` must be preserved. (One could, for example, use dataflow
Expand Down
15 changes: 8 additions & 7 deletions hugr-passes/src/dead_funcs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use std::collections::HashSet;
use hugr_core::{
HugrView, Node,
hugr::hugrmut::HugrMut,
module_graph::{ModuleGraph, StaticNode},
ops::{OpTag, OpTrait},
};
use petgraph::visit::{Dfs, Walker};
Expand All @@ -14,8 +15,6 @@ use crate::{
composable::{ValidatePassError, validate_if_test},
};

use super::call_graph::{CallGraph, CallGraphNode};

#[derive(Debug, thiserror::Error)]
#[non_exhaustive]
/// Errors produced by [`RemoveDeadFuncsPass`].
Expand All @@ -31,7 +30,7 @@ pub enum RemoveDeadFuncsError<N = Node> {
}

fn reachable_funcs<'a, H: HugrView>(
cg: &'a CallGraph<H::Node>,
cg: &'a ModuleGraph<H::Node>,
h: &'a H,
entry_points: impl IntoIterator<Item = H::Node>,
) -> impl Iterator<Item = H::Node> + 'a {
Expand All @@ -41,9 +40,11 @@ fn reachable_funcs<'a, H: HugrView>(
for n in entry_points {
d.stack.push(cg.node_index(n).unwrap());
}
d.iter(g).map(|i| match g.node_weight(i).unwrap() {
CallGraphNode::FuncDefn(n) | CallGraphNode::FuncDecl(n) => *n,
CallGraphNode::NonFuncRoot => h.entrypoint(),
d.iter(g).filter_map(|i| match g.node_weight(i).unwrap() {
StaticNode::FuncDefn(n) | StaticNode::FuncDecl(n) => Some(*n),
StaticNode::NonFuncEntrypoint => Some(h.entrypoint()),
StaticNode::Const(_) => None,
_ => unreachable!(),
})
}

Expand Down Expand Up @@ -85,7 +86,7 @@ impl<H: HugrMut<Node = Node>> ComposablePass<H> for RemoveDeadFuncsPass {
}

let mut reachable =
reachable_funcs(&CallGraph::new(hugr), hugr, entry_points).collect::<HashSet<_>>();
reachable_funcs(&ModuleGraph::new(hugr), hugr, entry_points).collect::<HashSet<_>>();
// Also prevent removing the entrypoint itself
let mut n = Some(hugr.entrypoint());
while let Some(n2) = n {
Expand Down
38 changes: 16 additions & 22 deletions hugr-passes/src/inline_funcs.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
//! Contains a pass to inline calls to selected functions in a Hugr.
use std::collections::{HashSet, VecDeque};

use hugr_core::hugr::hugrmut::HugrMut;
use hugr_core::hugr::patch::inline_call::InlineCall;
use itertools::Itertools;
use petgraph::algo::tarjan_scc;

use crate::call_graph::{CallGraph, CallGraphNode};
use hugr_core::hugr::{hugrmut::HugrMut, patch::inline_call::InlineCall};
use hugr_core::module_graph::{ModuleGraph, StaticNode};

/// Error raised by [inline_acyclic]
#[derive(Clone, Debug, thiserror::Error, PartialEq)]
Expand All @@ -26,7 +25,7 @@ pub fn inline_acyclic<H: HugrMut>(
h: &mut H,
call_predicate: impl Fn(&H, H::Node) -> bool,
) -> Result<(), InlineFuncsError> {
let cg = CallGraph::new(&*h);
let cg = ModuleGraph::new(&*h);
let g = cg.graph();
let all_funcs_in_cycles = tarjan_scc(g)
.into_iter()
Expand All @@ -37,7 +36,7 @@ pub fn inline_acyclic<H: HugrMut>(
}
}
ns.into_iter().map(|n| {
let CallGraphNode::FuncDefn(fd) = g.node_weight(n).unwrap() else {
let StaticNode::FuncDefn(fd) = g.node_weight(n).unwrap() else {
panic!("Expected only FuncDefns in sccs")
};
*fd
Expand Down Expand Up @@ -68,18 +67,17 @@ pub fn inline_acyclic<H: HugrMut>(
mod test {
use std::collections::HashSet;

use hugr_core::core::HugrNode;
use hugr_core::ops::OpType;
use itertools::Itertools;
use petgraph::visit::EdgeRef;
use rstest::rstest;

use hugr_core::HugrView;
use hugr_core::builder::{Dataflow, DataflowSubContainer, HugrBuilder, ModuleBuilder};
use hugr_core::core::HugrNode;
use hugr_core::module_graph::{ModuleGraph, StaticNode};
use hugr_core::ops::OpType;
use hugr_core::{Hugr, extension::prelude::qb_t, types::Signature};
use rstest::rstest;

use crate::call_graph::{CallGraph, CallGraphNode};
use crate::inline_funcs::inline_acyclic;
use super::inline_acyclic;

/// /->-\
/// main -> f g -> b -> c
Expand Down Expand Up @@ -156,7 +154,7 @@ mod test {
target_funcs.contains(&tgt)
})
.unwrap();
let cg = CallGraph::new(&h);
let cg = ModuleGraph::new(&h);
for fname in check_not_called {
let fnode = find_func(&h, fname);
let fnode = cg.node_index(fnode).unwrap();
Expand All @@ -180,12 +178,8 @@ mod test {
}
}

fn outgoing_calls<N: HugrNode>(cg: &CallGraph<N>, src: N) -> Vec<N> {
let src = cg.node_index(src).unwrap();
cg.graph()
.edges_directed(src, petgraph::Direction::Outgoing)
.map(|e| func_node(cg.graph().node_weight(e.target()).unwrap()))
.collect()
fn outgoing_calls<N: HugrNode>(cg: &ModuleGraph<N>, src: N) -> Vec<N> {
cg.out_edges(src).map(|(_, tgt)| func_node(tgt)).collect()
}

#[test]
Expand All @@ -205,17 +199,17 @@ mod test {
}
})
.unwrap();
let cg = CallGraph::new(&h);
let cg = ModuleGraph::new(&h);
// b and then c should have been inlined into g, leaving only cyclic call to f
assert_eq!(outgoing_calls(&cg, g), [find_func(&h, "f")]);
// But c should not have been inlined into b:
assert_eq!(outgoing_calls(&cg, b), [c]);
}

fn func_node<N: Copy>(cgn: &CallGraphNode<N>) -> N {
fn func_node<N: Copy>(cgn: &StaticNode<N>) -> N {
match cgn {
CallGraphNode::FuncDecl(n) | CallGraphNode::FuncDefn(n) => *n,
CallGraphNode::NonFuncRoot => panic!(),
StaticNode::FuncDecl(n) | StaticNode::FuncDefn(n) => *n,
_ => panic!(),
}
}

Expand Down
1 change: 1 addition & 0 deletions hugr-passes/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
//! Compilation passes acting on the HUGR program representation.

#[deprecated(note = "Use hugr-core::module_graph::ModuleGraph", since = "0.24.1")]
pub mod call_graph;
pub mod composable;
pub use composable::ComposablePass;
Expand Down
Loading