- 
                Notifications
    You must be signed in to change notification settings 
- Fork 13
Closed
Description
hugr-qir will need to inline all functions to produce valid qir.
We already have a Callgraph in hugr-passes, which is required for inlining.
I suggest that the inlining pass take a set of Call nodes as input, and that if these calls contain a cycle we error. Otherwise inline all those calls. That way we don't dictate the inlining policty, and in particular hugr-qir can easily "inline everything"
Prototype, which is pre our callgraph and perhaps out of date in other ways:
use hugr_core::{
    extension::ExtensionRegistry,
    hugr::{
        hugrmut::HugrMut,
        views::{DescendantsGraph, ExtractHugr as _, HierarchyView},
        HugrError, Rewrite, ValidationError,
    },
    ops::{DataflowOpTrait as _, OpTrait, DFG},
    Direction, HugrView, Node,
};
use itertools::Itertools as _;
use petgraph::visit::EdgeRef as _;
use thiserror::Error;
use crate::validation::ValidationLevel;
#[derive(Debug, Clone, Default)]
/// TODO docs
pub struct InlinePass {
    validation: ValidationLevel,
}
impl InlinePass {
    /// Sets the validation level used before and after the pass is run
    pub fn validation_level(mut self, level: ValidationLevel) -> Self {
        self.validation = level;
        self
    }
    pub fn run(
        &self,
        hugr: &mut impl HugrMut,
        registry: &ExtensionRegistry,
    ) -> Result<(), Box<dyn std::error::Error>> {
        self.validation
            .run_validated_pass_mut(hugr, registry, |hugr, _| {
                let mut calls = {
                    let cg = CallGraph::new(hugr);
                    let Some(calls) = cg.iter_nonrecursive() else {
                        Err("InlinePass: recursion")?
                    };
                    let mut calls = calls.collect_vec();
                    calls.reverse();
                    calls
                };
                // dbg!(&calls);
                let rewrites = calls
                    .iter()
                    .filter_map(|(caller, _)| InlineRewrite::try_new(hugr, *caller, registry).ok())
                    .collect_vec();
                for rewrite in rewrites {
                    hugr.apply_rewrite(rewrite).unwrap();
                }
                calls.reverse();
                for func_node in calls.into_iter().map(|x| x.1).dedup() {
                    let Some(func) = hugr.get_optype(func_node).as_func_defn() else {
                        panic!("impossible")
                    };
                    if hugr.linked_inputs(func_node, 0).count() == 0 && func.name != "main" {
                        // eprintln!("Removing func: {}", func.name);
                        let func_hugr = DescendantsGraph::<Node>::try_new(hugr, func_node).unwrap();
                        let to_delete = func_hugr.nodes().dedup().collect_vec();
                        for n in to_delete {
                            hugr.remove_node(n);
                        }
                    }
                }
                hugr.validate(registry)?;
                Ok(())
            })
    }
}
pub struct CallGraph {
    g: petgraph::graph::Graph<Node, Node>,
}
fn func_of_node(hugr: &impl HugrView, node: Node) -> Option<Node> {
    let mut n = node;
    while let Some(parent) = hugr.get_parent(n) {
        if hugr.get_optype(parent).is_func_defn() {
            return Some(parent);
        }
        n = parent;
    }
    None
}
impl CallGraph {
    pub fn new(hugr: &impl HugrView) -> Self {
        let mut g: petgraph::graph::Graph<Node, Node> = Default::default();
        let node_to_cg: HashMap<_, _> = hugr
            .nodes()
            .filter(|&n| (hugr.get_optype(n).is_func_decl() || hugr.get_optype(n).is_func_defn()))
            .map(|n| (n, g.add_node(n)))
            .collect();
        for n in hugr.nodes() {
            if let Some(call) = hugr.get_optype(n).as_call() {
                if let Some(caller_func) = func_of_node(hugr, n) {
                    if let Some((callee_func, _)) =
                        hugr.single_linked_output(n, call.called_function_port())
                    {
                        g.add_edge(node_to_cg[&caller_func], node_to_cg[&callee_func], n);
                    }
                }
            }
        }
        Self { g }
    }
    pub fn iter_nonrecursive(&self) -> Option<impl Iterator<Item = (Node, Node)> + '_> {
        let funcs = petgraph::algo::toposort(&self.g, None).ok()?;
        Some(funcs.into_iter().flat_map(move |f| {
            self.g
                .edges(f)
                .map(move |e| (*e.weight(), self.g[e.target()]))
        }))
    }
}
pub struct InlineRewrite<'a> {
    call: Node,
    func: Node,
    registry: &'a ExtensionRegistry,
}
impl<'a> InlineRewrite<'a> {
    pub fn try_new(
        hugr: &impl HugrView,
        call: Node,
        registry: &'a ExtensionRegistry,
    ) -> Result<Self, InlineRewriteError> {
        if !hugr.valid_node(call) {
            Err(InlineRewriteError::InvalidCall)?
        }
        let Some(call_ot) = hugr.get_optype(call).as_call() else {
            Err(InlineRewriteError::InvalidCall)?
        };
        let Some((func, _)) = hugr.single_linked_output(call, call_ot.called_function_port())
        else {
            Err(InlineRewriteError::InvalidCall)?
        };
        if !hugr.get_optype(func).is_func_defn() {
            Err(InlineRewriteError::InvalidFunction)?
        }
        let r = Self {
            call,
            func,
            registry,
        };
        debug_assert!(r.verify(hugr).is_ok());
        Ok(r)
    }
}
#[derive(Debug, Clone, Error)]
pub enum InlineRewriteError {
    #[error("Invalid Function")]
    InvalidFunction,
    #[error("Invalid Call")]
    InvalidCall,
    #[error("Call does not target func")]
    Invalid,
    #[error(transparent)]
    HugrError(#[from] HugrError),
    #[error(transparent)]
    Validation(#[from] ValidationError),
}
impl<'a> Rewrite for InlineRewrite<'a> {
    type Error = InlineRewriteError;
    type ApplyResult = ();
    const UNCHANGED_ON_FAILURE: bool = true;
    fn verify(&self, h: &impl HugrView) -> Result<(), Self::Error> {
        let Some(call) = h.get_optype(self.call).as_call() else {
            Err(InlineRewriteError::InvalidCall)?
        };
        if !call.type_args.is_empty() {
            Err(InlineRewriteError::InvalidCall)?
        }
        let Some(_) = h.get_optype(self.func).as_func_defn() else {
            Err(InlineRewriteError::InvalidFunction)?
        };
        if let Some((n, _)) = h.single_linked_output(self.call, call.called_function_port()) {
            if self.func != n {
                Err(InlineRewriteError::Invalid)?
            }
        }
        Ok(())
    }
    fn apply(self, h: &mut impl HugrMut) -> Result<Self::ApplyResult, Self::Error> {
        self.verify(h)?;
        // dbg!(self.call, self.func);
        let func_hugr = DescendantsGraph::<Node>::try_new(h, self.func)
            .map_err(|_| InlineRewriteError::InvalidFunction)?
            .extract_hugr();
        func_hugr.validate(self.registry)?;
        let call = h.get_optype(self.call).as_call().unwrap().to_owned();
        let call_parent = h.get_parent(self.call).unwrap();
        let signature = call.signature();
        let insertion = h.insert_hugr(call_parent, func_hugr);
        let dfg_node = insertion.new_root;
        let dfg = DFG { signature };
        h.set_num_ports(
            dfg_node,
            dfg.signature().input_count() + dfg.non_df_port_count(Direction::Incoming),
            dfg.signature().output_count() + dfg.non_df_port_count(Direction::Outgoing),
        );
        h.replace_op(dfg_node, dfg)?;
        let connections = h
            .node_inputs(self.call)
            .filter(|&x| x != call.called_function_port())
            .flat_map(|in_p| {
                h.linked_outputs(self.call, in_p)
                    .map(move |(out_n, out_p)| (out_n, out_p, dfg_node, in_p))
            })
            .chain(h.node_outputs(self.call).flat_map(|out_p| {
                h.linked_inputs(self.call, out_p)
                    .map(move |(in_n, in_p)| (dfg_node, out_p, in_n, in_p))
            }))
            .collect_vec();
        for (from_n, from_p, to_n, to_p) in connections {
            h.connect(from_n, from_p, to_n, to_p)
        }
        h.remove_node(self.call);
        Ok(())
    }
    fn invalidation_set(&self) -> impl Iterator<Item = Node> {
        [self.call, self.func].into_iter()
    }
}
Metadata
Metadata
Assignees
Labels
No labels