diff --git a/hugr-core/src/hugr/linking.rs b/hugr-core/src/hugr/linking.rs index f20c006e6..4c84cb5fc 100644 --- a/hugr-core/src/hugr/linking.rs +++ b/hugr-core/src/hugr/linking.rs @@ -1,14 +1,12 @@ //! Directives and errors relating to linking Hugrs. -use std::{collections::HashMap, fmt::Display}; +use std::collections::{BTreeMap, HashMap}; +use std::fmt::Display; -use itertools::Either; +use itertools::{Either, Itertools}; -use crate::{ - Hugr, HugrView, Node, - core::HugrNode, - hugr::{HugrMut, hugrmut::InsertedForest, internal::HugrMutInternals}, -}; +use crate::hugr::{HugrMut, hugrmut::InsertedForest, internal::HugrMutInternals}; +use crate::{Hugr, HugrView, Node, Visibility, core::HugrNode, ops::OpType, types::PolyFuncType}; /// Methods that merge Hugrs, adding static edges between old and inserted nodes. /// @@ -105,6 +103,70 @@ pub trait HugrLinking: HugrMut { link_by_node(self, transfers, &mut inserted.node_map); Ok(inserted) } + + /// Insert module-children from another Hugr into this one according to a [NameLinkingPolicy]. + /// + /// All [Visibility::Public] module-children are inserted, or linked, according to the + /// specified policy; private children will also be inserted, at least including all those + /// used by the copied public children. (At present all module-children are inserted, + /// but this is expected to change in the future.) + /// + /// The entrypoints of both `self` and `other` are ignored. + /// + /// # Errors + /// + /// If [NameLinkingPolicy::on_signature_conflict] or [NameLinkingPolicy::on_multiple_defn] + /// are set to [OnNewFunc::RaiseError], and the respective conflict occurs between + /// `self` and `other`. + /// + /// [Visibility::Public]: crate::Visibility::Public + /// [FuncDefn]: crate::ops::FuncDefn + fn link_module( + &mut self, + other: Hugr, + policy: &NameLinkingPolicy, + ) -> Result, NameLinkingError> { + let actions = policy.link_actions(self, &other)?; + let directives = actions + .into_iter() + .map(|(k, LinkAction::LinkNode(d))| (k, d)) + .collect(); + Ok(self + .insert_link_hugr_by_node(None, other, directives) + .expect("NodeLinkingPolicy was constructed to avoid any error")) + } + + /// Copy module-children from another Hugr into this one according to a [NameLinkingPolicy]. + /// + /// All [Visibility::Public] module-children are copied, or linked, according to the + /// specified policy; private children will also be copied, at least including all those + /// used by the copied public children. (At present all module-children are inserted, + /// but this is expected to change in the future.) + /// + /// The entrypoints of both `self` and `other` are ignored. + /// + /// # Errors + /// + /// If [NameLinkingPolicy::on_signature_conflict] or [NameLinkingPolicy::on_multiple_defn] + /// are set to [OnNewFunc::RaiseError], and the respective conflict occurs between + /// `self` and `other`. + /// + /// [Visibility::Public]: crate::Visibility::Public + /// [FuncDefn]: crate::ops::FuncDefn + fn link_module_view( + &mut self, + other: &impl HugrView, + policy: &NameLinkingPolicy, + ) -> Result, NameLinkingError> { + let actions = policy.link_actions(self, &other)?; + let directives = actions + .into_iter() + .map(|(k, LinkAction::LinkNode(d))| (k, d)) + .collect(); + Ok(self + .insert_link_view_by_node(None, other, directives) + .expect("NodeLinkingPolicy was constructed to avoid any error")) + } } impl HugrLinking for T {} @@ -185,14 +247,290 @@ impl NodeLinkingDirective { } } +/// Describes how to link two Hugrs (a source Hugr bing inserted into a target Hugr), +/// abstracted from any specific Hugrs. +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct NameLinkingPolicy { + sig_conflict: OnNewFunc, + multi_defn: OnMultiDefn, +} + +/// Specifies what to do with a function in some situation - used in +/// * [NameLinkingPolicy::on_signature_conflict] +/// * [OnMultiDefn::NewFunc] +#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq)] +#[non_exhaustive] // could consider e.g. disconnections +pub enum OnNewFunc { + /// Do not link the Hugrs together; fail with a [NameLinkingError] instead. + RaiseError, + /// Add the new function alongside the existing one in the target Hugr, + /// preserving (separately) uses of both. (The Hugr will be invalid because + /// of [duplicate names](crate::hugr::ValidationError::DuplicateExport).) + Add, +} + +/// What to do when both target and inserted Hugr +/// have a [Public] [FuncDefn] with the same name and signature. +/// +/// [FuncDefn]: crate::ops::FuncDefn +/// [Public]: crate::Visibility::Public +#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq, derive_more::From)] +#[non_exhaustive] // could consider e.g. disconnections +pub enum OnMultiDefn { + /// Keep the implementation already in the target Hugr. (Edges in the source + /// Hugr will be redirected to use the function from the target.) + UseTarget, + /// Keep the implementation in the source Hugr. (Edges in the target Hugr + /// will be redirected to use the function from the source; the previously-existing + /// function in the target Hugr will be removed.) + UseSource, + /// Proceed as per the specified [OnNewFunc]. + NewFunc(#[from] OnNewFunc), +} + +/// An error in using names to determine how to link functions in source and target Hugrs. +/// (SN = Source Node, TN = Target Node) +#[derive(Clone, Debug, thiserror::Error, PartialEq)] +#[non_exhaustive] +pub enum NameLinkingError { + /// Both source and target contained a [FuncDefn] (public and with same name + /// and signature). + /// + /// [FuncDefn]: crate::ops::FuncDefn + #[error("Source ({_1}) and target ({_2}) both contained FuncDefn with same public name {_0}")] + MultipleDefn(String, SN, TN), + /// Source and target containing public functions with conflicting signatures + #[error( + "Conflicting signatures for name {name} - Source ({src_node}) has {src_sig}, Target ({tgt_node}) has ({tgt_sig})" + )] + #[allow(missing_docs)] + SignatureConflict { + name: String, + src_node: SN, + src_sig: Box, + tgt_node: TN, + tgt_sig: Box, + }, + /// A [Visibility::Public] function in the source, whose body is being added + /// to the target, contained the entrypoint (which needs to be added + /// in a different place). + /// + /// [Visibility::Public]: crate::Visibility::Public + #[error("The entrypoint is contained within function {_0} which will be added as {_1:?}")] + AddFunctionContainingEntrypoint(SN, NodeLinkingDirective), +} + +impl NameLinkingPolicy { + /// Makes a new instance that specifies to keep both decls/defns when (for the same name) + /// they have different signatures or when both are defns. Thus, an error is never raised; + /// a (potentially-invalid) Hugr is always produced. + pub fn new_keep_both_invalid() -> Self { + Self { + multi_defn: OnMultiDefn::NewFunc(OnNewFunc::Add), + sig_conflict: OnNewFunc::Add, + } + } + + /// Sets how to behave when both target and inserted Hugr have a + /// ([Visibility::Public]) function with the same name but different signatures. + /// + /// See [Self::get_on_signature_conflict]. + pub fn on_signature_conflict(mut self, sc: OnNewFunc) -> Self { + self.sig_conflict = sc; + self + } + + /// Returns how this policy will behave when both target and inserted Hugr have a + /// ([Visibility::Public]) function with the same name but different signatures. + /// + /// Can be changed via [Self::on_signature_conflict]. + pub fn get_on_signature_conflict(&self) -> OnNewFunc { + self.sig_conflict + } + + /// Sets how to behave when both target and inserted Hugr have a + /// [FuncDefn](crate::ops::FuncDefn) with the same name and signature. + /// + /// See [Self::get_on_multiple_defn]. + pub fn on_multiple_defn(mut self, multi_defn: OnMultiDefn) -> Self { + self.multi_defn = multi_defn; + self + } + + /// Returns how this policy will behave when both target and inserted Hugr have a + /// [FuncDefn](crate::ops::FuncDefn) with the same name and signature. + /// + /// Can be changed via [Self::on_multiple_defn]. + pub fn get_on_multiple_defn(&self) -> OnMultiDefn { + self.multi_defn + } + + /// Computes concrete actions to link a specific source (inserted) and target + /// (host) Hugr according to this policy. + pub fn link_actions( + &self, + target: &(impl HugrView + ?Sized), + source: &impl HugrView, + ) -> Result, NameLinkingError> { + let existing = gather_existing(target); + let mut res = LinkActions::new(); + + let NameLinkingPolicy { + sig_conflict, + multi_defn, + } = self; + for n in source.children(source.module_root()) { + let dirv = match link_sig(source, n) { + None => continue, + Some(LinkSig::Private) => NodeLinkingDirective::add(), + Some(LinkSig::Public { name, is_defn, sig }) => { + if let Some((ex_ns, ex_sig)) = existing.get(name) { + match *sig_conflict { + _ if sig == *ex_sig => directive(name, n, is_defn, ex_ns, multi_defn)?, + OnNewFunc::RaiseError => { + return Err(NameLinkingError::SignatureConflict { + name: name.to_string(), + src_node: n, + src_sig: Box::new(sig.clone()), + tgt_node: target_node(ex_ns), + tgt_sig: Box::new((*ex_sig).clone()), + }); + } + OnNewFunc::Add => NodeLinkingDirective::add(), + } + } else { + NodeLinkingDirective::add() + } + } + }; + res.insert(n, dirv.into()); + } + + Ok(res) + } +} + +impl Default for NameLinkingPolicy { + fn default() -> Self { + Self { + sig_conflict: OnNewFunc::RaiseError, + multi_defn: OnNewFunc::RaiseError.into(), + } + } +} + +fn directive( + name: &str, + new_n: SN, + new_defn: bool, + ex_ns: &Either)>, + multi_defn: &OnMultiDefn, +) -> Result, NameLinkingError> { + Ok(match (new_defn, ex_ns) { + (false, Either::Right(_)) => NodeLinkingDirective::add(), // another alias + (false, Either::Left(defn)) => NodeLinkingDirective::UseExisting(*defn), // resolve decl + (true, Either::Right((decl, decls))) => { + NodeLinkingDirective::replace(std::iter::once(decl).chain(decls).cloned()) + } + (true, &Either::Left(defn)) => match multi_defn { + OnMultiDefn::UseTarget => NodeLinkingDirective::UseExisting(defn), + OnMultiDefn::UseSource => NodeLinkingDirective::replace([defn]), + OnMultiDefn::NewFunc(OnNewFunc::RaiseError) => { + return Err(NameLinkingError::MultipleDefn(name.to_owned(), new_n, defn)); + } + OnMultiDefn::NewFunc(OnNewFunc::Add) => NodeLinkingDirective::add(), + }, + }) +} + +type PubFuncs<'a, N> = (Either)>, &'a PolyFuncType); + +fn target_node(ns: &Either)>) -> N { + *ns.as_ref().left_or_else(|(n, _)| n) +} + +enum LinkSig<'a> { + Private, + Public { + name: &'a str, + is_defn: bool, + sig: &'a PolyFuncType, + }, +} + +fn link_sig(h: &H, n: H::Node) -> Option> { + let (name, is_defn, vis, sig) = match h.get_optype(n) { + OpType::FuncDecl(fd) => (fd.func_name(), false, fd.visibility(), fd.signature()), + OpType::FuncDefn(fd) => (fd.func_name(), true, fd.visibility(), fd.signature()), + OpType::Const(_) => return Some(LinkSig::Private), + _ => return None, + }; + Some(match vis { + Visibility::Public => LinkSig::Public { name, is_defn, sig }, + Visibility::Private => LinkSig::Private, + }) +} + +fn gather_existing<'a, H: HugrView + ?Sized>(h: &'a H) -> HashMap<&'a str, PubFuncs<'a, H::Node>> { + let left_if = |b| if b { Either::Left } else { Either::Right }; + h.children(h.module_root()) + .filter_map(|n| { + link_sig(h, n).and_then(|link_sig| match link_sig { + LinkSig::Public { name, is_defn, sig } => Some((name, (left_if(is_defn)(n), sig))), + LinkSig::Private => None, + }) + }) + .into_grouping_map() + .aggregate(|acc: Option>, name, (new, sig2)| { + let Some((mut acc, sig1)) = acc else { + return Some((new.map_right(|n| (n, vec![])), sig2)); + }; + assert_eq!( + sig1, sig2, + "Invalid Hugr: different signatures for {}", + name + ); + match (&mut acc, new) { + (Either::Right((_, decls)), Either::Right(ndecl)) => decls.push(ndecl), + (Either::Left(_), Either::Left(_)) => { + panic!("Invalid Hugr: Multiple FuncDefns for {name}") + } + _ => panic!("Invalid Hugr: FuncDefn and FuncDecl(s) for {name}"), + }; + Some((acc, sig2)) + }) +} + /// Details, node-by-node, how module-children of a source Hugr should be inserted into a /// target Hugr. /// /// For use with [HugrLinking::insert_link_hugr_by_node] and [HugrLinking::insert_link_view_by_node]. pub type NodeLinkingDirectives = HashMap>; -/// Invariant: no SourceNode can be in both maps (by type of [NodeLinkingDirective]) -/// TargetNodes can be (in RHS of multiple directives) +/// Details a concrete action to link a specific node from source Hugr into a specific target Hugr. +/// +/// A separate enum from [NodeLinkingDirective] to allow [NameLinkingPolicy::link_actions] +/// to (eventually) specify a greater range of actions than that supported by +/// [HugrLinking::insert_link_hugr_by_node] and [HugrLinking::insert_link_view_by_node]. +/// (For example, to add a function but change it to private.) +#[derive(Clone, Debug, Hash, PartialEq, Eq, derive_more::From)] +#[non_exhaustive] +pub enum LinkAction { + /// Just apply the specified [NodeLinkingDirective]. + LinkNode(#[from] NodeLinkingDirective), +} + +/// Details the concrete actions to link a specific source Hugr into a specific target Hugr. +/// +/// Computed from a [NameLinkingPolicy] and contains all actions required to implement +/// that policy for those specific Hugrs. +/// +/// [BTreeMap] is used to give deterministic ordering of printing for debugging; the order +/// of actions (that arise from any [NameLinkingPolicy]) should not affect the linking itself. +pub type LinkActions = BTreeMap>; + +/// Invariant: no `SourceNode` can be in both maps (by type of [NodeLinkingDirective]) +/// `TargetNode`s can be (in RHS of multiple directives) struct Transfers { use_existing: HashMap, replace: HashMap, @@ -287,12 +625,25 @@ mod test { use cool_asserts::assert_matches; use itertools::Itertools; + use rstest::rstest; - use super::{HugrLinking, NodeLinkingDirective, NodeLinkingError}; + use super::{ + HugrLinking, NameLinkingError, NameLinkingPolicy, NodeLinkingDirective, NodeLinkingError, + OnMultiDefn, OnNewFunc, + }; use crate::builder::test::{dfg_calling_defn_decl, simple_dfg_hugr}; - use crate::hugr::hugrmut::test::check_calls_defn_decl; - use crate::ops::{FuncDecl, OpTag, OpTrait, handle::NodeHandle}; - use crate::{HugrView, hugr::HugrMut, types::Signature}; + use crate::builder::{ + Container, Dataflow, DataflowHugr, DataflowSubContainer, FunctionBuilder, HugrBuilder, + ModuleBuilder, + }; + use crate::core::HugrNode; + use crate::extension::prelude::{ConstUsize, usize_t}; + use crate::hugr::ValidationError; + use crate::hugr::hugrmut::{HugrMut, test::check_calls_defn_decl}; + use crate::ops::{FuncDecl, OpTag, OpTrait, OpType, Value, handle::NodeHandle}; + use crate::std_extensions::arithmetic::int_ops::IntOpDef; + use crate::std_extensions::arithmetic::int_types::{ConstInt, INT_TYPES}; + use crate::{Hugr, HugrView, Visibility, types::Signature}; #[test] fn test_insert_link_nodes_add() { @@ -522,4 +873,220 @@ mod test { assert_eq!(h.static_source(call), Some(defn)); } } + + fn list_decls_defns( + h: &impl HugrView, + ) -> (HashMap, HashMap) { + let mut decls = HashMap::new(); + let mut defns = HashMap::new(); + for n in h.children(h.module_root()) { + match h.get_optype(n) { + OpType::FuncDecl(fd) => decls.insert(n, fd.func_name().as_str()), + OpType::FuncDefn(fd) => defns.insert(n, fd.func_name().as_str()), + _ => None, + }; + } + (decls, defns) + } + + fn call_targets(h: &H) -> HashMap { + h.nodes() + .filter(|n| h.get_optype(*n).is_call()) + .map(|n| (n, h.static_source(n).unwrap())) + .collect() + } + + #[rstest] + fn combines_decls_defn( + #[values(OnNewFunc::RaiseError, OnNewFunc::Add)] sig_conflict: OnNewFunc, + #[values( + OnNewFunc::RaiseError.into(), + OnMultiDefn::UseSource, + OnMultiDefn::UseTarget, + OnNewFunc::Add.into() + )] + multi_defn: OnMultiDefn, + ) { + let i64_t = || INT_TYPES[6].to_owned(); + let foo_sig = Signature::new_endo(i64_t()); + let bar_sig = Signature::new(vec![i64_t(); 2], i64_t()); + let mut target = { + let mut fb = + FunctionBuilder::new_vis("foo", foo_sig.clone(), Visibility::Public).unwrap(); + let mut mb = fb.module_root_builder(); + let bar1 = mb.declare("bar", bar_sig.clone().into()).unwrap(); + let bar2 = mb.declare("bar", bar_sig.clone().into()).unwrap(); // alias + let [i] = fb.input_wires_arr(); + let [c] = fb.call(&bar1, &[], [i, i]).unwrap().outputs_arr(); + let r = fb.call(&bar2, &[], [i, c]).unwrap(); + let h = fb.finish_hugr_with_outputs(r.outputs()).unwrap(); + assert_eq!( + list_decls_defns(&h), + ( + HashMap::from([(bar1.node(), "bar"), (bar2.node(), "bar")]), + HashMap::from([(h.entrypoint(), "foo")]) + ) + ); + h + }; + + let inserted = { + let mut main_b = FunctionBuilder::new("main", Signature::new(vec![], i64_t())).unwrap(); + let mut mb = main_b.module_root_builder(); + let foo1 = mb.declare("foo", foo_sig.clone().into()).unwrap(); + let foo2 = mb.declare("foo", foo_sig.clone().into()).unwrap(); + let mut bar = mb + .define_function_vis("bar", bar_sig.clone(), Visibility::Public) + .unwrap(); + let res = bar + .add_dataflow_op(IntOpDef::iadd.with_log_width(6), bar.input_wires()) + .unwrap(); + let bar = bar.finish_with_outputs(res.outputs()).unwrap(); + let i = main_b.add_load_value(ConstInt::new_u(6, 257).unwrap()); + let c = main_b.call(&foo1, &[], [i]).unwrap(); + let r = main_b.call(&foo2, &[], c.outputs()).unwrap(); + let h = main_b.finish_hugr_with_outputs(r.outputs()).unwrap(); + assert_eq!( + list_decls_defns(&h), + ( + HashMap::from([(foo1.node(), "foo"), (foo2.node(), "foo")]), + HashMap::from([(h.entrypoint(), "main"), (bar.node(), "bar")]) + ) + ); + h + }; + + let pol = NameLinkingPolicy { + sig_conflict, + multi_defn, + }; + let mut target2 = target.clone(); + + target.link_module_view(&inserted, &pol).unwrap(); + target2.link_module(inserted, &pol).unwrap(); + for tgt in [target, target2] { + tgt.validate().unwrap(); + let (decls, defns) = list_decls_defns(&tgt); + assert_eq!(decls, HashMap::new()); + assert_eq!( + defns.values().copied().sorted().collect_vec(), + ["bar", "foo", "main"] + ); + let call_tgts = call_targets(&tgt); + for (defn, name) in defns { + if name != "main" { + // Defns now have two calls each (was one to each alias) + assert_eq!(call_tgts.values().filter(|tgt| **tgt == defn).count(), 2); + } + } + } + } + + #[rstest] + fn sig_conflict( + #[values(false, true)] host_defn: bool, + #[values(false, true)] inserted_defn: bool, + ) { + let mk_def_or_decl = |n, sig: Signature, defn| { + let mut mb = ModuleBuilder::new(); + let node = if defn { + let fb = mb.define_function_vis(n, sig, Visibility::Public).unwrap(); + let ins = fb.input_wires(); + fb.finish_with_outputs(ins).unwrap().node() + } else { + mb.declare(n, sig.into()).unwrap().node() + }; + (mb.finish_hugr().unwrap(), node) + }; + + let old_sig = Signature::new_endo(usize_t()); + let (orig_host, orig_fn) = mk_def_or_decl("foo", old_sig.clone(), host_defn); + let new_sig = Signature::new_endo(INT_TYPES[3].clone()); + let (inserted, inserted_fn) = mk_def_or_decl("foo", new_sig.clone(), inserted_defn); + + let pol = NameLinkingPolicy::default(); + let mut host = orig_host.clone(); + let res = host.link_module_view(&inserted, &pol); + assert_eq!(host, orig_host); // Did nothing + assert_eq!( + res.err(), + Some(NameLinkingError::SignatureConflict { + name: "foo".to_string(), + src_node: inserted_fn, + src_sig: Box::new(new_sig.into()), + tgt_node: orig_fn, + tgt_sig: Box::new(old_sig.into()) + }) + ); + + let pol = pol.on_signature_conflict(OnNewFunc::Add); + let node_map = host.link_module(inserted, &pol).unwrap().node_map; + assert_eq!( + host.validate(), + Err(ValidationError::DuplicateExport { + link_name: "foo".to_string(), + children: [orig_fn, node_map[&inserted_fn]] + }) + ); + } + + #[rstest] + #[case(OnMultiDefn::UseSource, vec![11])] + #[case(OnMultiDefn::UseTarget, vec![5])] + #[case(OnNewFunc::Add.into(), vec![5, 11])] + #[case(OnNewFunc::RaiseError.into(), vec![])] + fn impl_conflict(#[case] multi_defn: OnMultiDefn, #[case] expected: Vec) { + fn build_hugr(cst: u64) -> Hugr { + let mut mb = ModuleBuilder::new(); + let cst = mb.add_constant(Value::from(ConstUsize::new(cst))); + let mut fb = mb + .define_function_vis("foo", Signature::new(vec![], usize_t()), Visibility::Public) + .unwrap(); + let c = fb.load_const(&cst); + fb.finish_with_outputs([c]).unwrap(); + mb.finish_hugr().unwrap() + } + let backup = build_hugr(5); + let mut host = backup.clone(); + let inserted = build_hugr(11); + + let pol = NameLinkingPolicy::new_keep_both_invalid().on_multiple_defn(multi_defn); + let res = host.link_module(inserted, &pol); + if multi_defn == OnNewFunc::RaiseError.into() { + assert!(matches!(res, Err(NameLinkingError::MultipleDefn(n, _, _)) if n == "foo")); + assert_eq!(host, backup); + return; + } + res.unwrap(); + let val_res = host.validate(); + if multi_defn == OnNewFunc::Add.into() { + assert!( + matches!(val_res, Err(ValidationError::DuplicateExport { link_name, .. }) if link_name == "foo") + ); + } else { + val_res.unwrap(); + } + let func_consts = host + .children(host.module_root()) + .filter(|n| host.get_optype(*n).is_func_defn()) + .map(|n| { + host.children(n) + .filter_map(|ch| host.static_source(ch)) // LoadConstant's + .map(|c| host.get_optype(c).as_const().unwrap()) + .map(|c| c.get_custom_value::().unwrap().value()) + .exactly_one() + .ok() + .unwrap() + }) + .collect_vec(); + assert_eq!(func_consts, expected); + // At the moment we copy all the constants regardless of whether they are used: + let all_consts: Vec<_> = host + .children(host.module_root()) + .filter_map(|ch| host.get_optype(ch).as_const()) + .map(|c| c.get_custom_value::().unwrap().value()) + .sorted() + .collect(); + assert_eq!(all_consts, [5, 11]); + } }