-
Notifications
You must be signed in to change notification settings - Fork 14
feat: Introduce new enum for constant folding; deprecate Value::Function #2060
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: acl/insert_hugr_defns
Are you sure you want to change the base?
Changes from all commits
06e6041
d2ff96b
3872d70
be807c7
a22f2fd
4782317
8279835
80fddb7
d5d9564
87d3497
405927d
24db393
0531f76
48653e4
d16cb9f
8c02f9e
fe56b4d
9474c8b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
@@ -1,19 +1,134 @@ | ||||||
use std::fmt::Formatter; | ||||||
|
||||||
use std::fmt::Debug; | ||||||
use std::fmt::{Debug, Formatter}; | ||||||
|
||||||
use crate::ops::Value; | ||||||
use crate::types::TypeArg; | ||||||
use crate::ops::constant::{CustomConst, OpaqueValue, Sum}; | ||||||
use crate::types::{SumType, TypeArg}; | ||||||
use crate::{IncomingPort, Node, OutgoingPort, PortIndex}; | ||||||
|
||||||
/// Representation of values used for constant folding. | ||||||
/// See [ConstFold], which is used as `dyn` so we cannot parametrize by | ||||||
/// [HugrNode](crate::core::HugrNode). | ||||||
// Should we be non-exhaustive?? | ||||||
#[derive(Clone, Debug, PartialEq, Default)] | ||||||
pub enum FoldVal { | ||||||
/// Value is unknown, must assume that it could be anything | ||||||
#[default] | ||||||
Unknown, | ||||||
/// A variant of a [SumType] | ||||||
Sum { | ||||||
/// Which variant of the sum type this value is. | ||||||
tag: usize, | ||||||
/// Describes the type of the whole value. | ||||||
// Can we deprecate this immediately? It is only for converting to Value | ||||||
sum_type: SumType, | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could move these fields to an inner struct and mark them private i.e. |
||||||
/// A value for each element (type) within the variant | ||||||
items: Vec<FoldVal>, | ||||||
}, | ||||||
/// A constant value defined by an extension | ||||||
Extension(OpaqueValue), | ||||||
/// A function pointer loaded from a [FuncDefn](crate::ops::FuncDefn) or `FuncDecl` | ||||||
LoadedFunction(Node, Vec<TypeArg>), // Deliberately skipping Function(Box<Hugr>) ATM | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. do we need non_exhaustive for |
||||||
} | ||||||
|
||||||
impl<T> From<T> for FoldVal | ||||||
where | ||||||
T: CustomConst, | ||||||
{ | ||||||
fn from(value: T) -> Self { | ||||||
Self::Extension(value.into()) | ||||||
} | ||||||
} | ||||||
|
||||||
impl FoldVal { | ||||||
/// Returns a constant "false" value, i.e. the first variant of Sum((), ()). | ||||||
pub const fn false_val() -> Self { | ||||||
Self::Sum { | ||||||
tag: 0, | ||||||
sum_type: SumType::Unit { size: 2 }, | ||||||
items: vec![], | ||||||
} | ||||||
} | ||||||
|
||||||
/// Returns a constant "true" value, i.e. the second variant of Sum((), ()). | ||||||
pub const fn true_val() -> Self { | ||||||
Self::Sum { | ||||||
tag: 1, | ||||||
sum_type: SumType::Unit { size: 2 }, | ||||||
items: vec![], | ||||||
} | ||||||
} | ||||||
|
||||||
/// Returns a constant boolean - either [Self::false_val] or [Self::true_val] | ||||||
pub const fn from_bool(b: bool) -> Self { | ||||||
if b { | ||||||
Self::true_val() | ||||||
} else { | ||||||
Self::false_val() | ||||||
} | ||||||
} | ||||||
|
||||||
/// Extract the specified type of [CustomConst] fro this instance, if it is one | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
pub fn get_custom_value<T: CustomConst>(&self) -> Option<&T> { | ||||||
let Self::Extension(e) = self else { | ||||||
return None; | ||||||
}; | ||||||
e.value().downcast_ref() | ||||||
} | ||||||
} | ||||||
|
||||||
use crate::IncomingPort; | ||||||
use crate::OutgoingPort; | ||||||
impl TryFrom<FoldVal> for Value { | ||||||
type Error = Option<Node>; | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. please add docstring to explain error type |
||||||
|
||||||
use crate::ops; | ||||||
fn try_from(value: FoldVal) -> Result<Self, Self::Error> { | ||||||
match value { | ||||||
FoldVal::Unknown => Err(None), | ||||||
FoldVal::Sum { | ||||||
tag, | ||||||
sum_type, | ||||||
items, | ||||||
} => { | ||||||
let values = items | ||||||
.into_iter() | ||||||
.map(Value::try_from) | ||||||
.collect::<Result<Vec<_>, _>>()?; | ||||||
Ok(Value::Sum(Sum { | ||||||
tag, | ||||||
values, | ||||||
sum_type, | ||||||
})) | ||||||
} | ||||||
FoldVal::Extension(e) => Ok(Value::Extension { e }), | ||||||
FoldVal::LoadedFunction(node, _) => Err(Some(node)), | ||||||
} | ||||||
} | ||||||
} | ||||||
|
||||||
impl From<Value> for FoldVal { | ||||||
fn from(value: Value) -> Self { | ||||||
match value { | ||||||
Value::Extension { e } => FoldVal::Extension(e), | ||||||
#[allow(deprecated)] // remove when Value::Function removed | ||||||
Value::Function { .. } => FoldVal::Unknown, | ||||||
Value::Sum(Sum { | ||||||
tag, | ||||||
values, | ||||||
sum_type, | ||||||
}) => { | ||||||
let items = values.into_iter().map(FoldVal::from).collect(); | ||||||
FoldVal::Sum { | ||||||
tag, | ||||||
sum_type, | ||||||
items, | ||||||
} | ||||||
} | ||||||
} | ||||||
} | ||||||
} | ||||||
|
||||||
/// Output of constant folding an operation, None indicates folding was either | ||||||
/// not possible or unsuccessful. An empty vector indicates folding was | ||||||
/// successful and no values are output. | ||||||
pub type ConstFoldResult = Option<Vec<(OutgoingPort, ops::Value)>>; | ||||||
pub type ConstFoldResult = Option<Vec<(OutgoingPort, Value)>>; | ||||||
|
||||||
/// Tag some output constants with [`OutgoingPort`] inferred from the ordering. | ||||||
pub fn fold_out_row(consts: impl IntoIterator<Item = Value>) -> ConstFoldResult { | ||||||
|
@@ -25,7 +140,9 @@ pub fn fold_out_row(consts: impl IntoIterator<Item = Value>) -> ConstFoldResult | |||||
Some(vec) | ||||||
} | ||||||
|
||||||
/// Trait implemented by extension operations that can perform constant folding. | ||||||
#[deprecated(note = "Use ConstFolder")] | ||||||
/// Old trait implemented by extension operations that can perform constant folding. | ||||||
/// Deprecated: see [ConstFolder] | ||||||
pub trait ConstFold: Send + Sync { | ||||||
/// Given type arguments `type_args` and | ||||||
/// [`crate::ops::Const`] values for inputs at [`crate::IncomingPort`]s, | ||||||
|
@@ -37,14 +154,58 @@ pub trait ConstFold: Send + Sync { | |||||
) -> ConstFoldResult; | ||||||
} | ||||||
|
||||||
/// Trait implemented by extension operations that can perform constant folding. | ||||||
pub trait ConstFolder: Send + Sync { | ||||||
/// Given type arguments `type_args` and [`FoldVal`]s for each input, | ||||||
/// update the outputs (these will be initialized to [FoldVal::Unknown]). | ||||||
fn fold(&self, type_args: &[TypeArg], inputs: &[FoldVal], outputs: &mut [FoldVal]); | ||||||
} | ||||||
|
||||||
pub(super) fn fold_vals_to_indexed_vals<P: From<usize>>(fvs: &[FoldVal]) -> Vec<(P, Value)> { | ||||||
fvs.iter() | ||||||
.cloned() | ||||||
.enumerate() | ||||||
.filter_map(|(p, fv)| Some((p.into(), fv.try_into().ok()?))) | ||||||
.collect::<Vec<_>>() | ||||||
} | ||||||
|
||||||
#[allow(deprecated)] // Legacy conversion routine, remove when ConstFold removed | ||||||
fn do_fold( | ||||||
old_fold: &impl ConstFold, | ||||||
type_args: &[TypeArg], | ||||||
inputs: &[FoldVal], | ||||||
outputs: &mut [FoldVal], | ||||||
) { | ||||||
let consts = fold_vals_to_indexed_vals(inputs); | ||||||
let outs = old_fold.fold(type_args, &consts); | ||||||
for (p, v) in outs.unwrap_or_default() { | ||||||
outputs[p.index()] = v.into(); | ||||||
} | ||||||
} | ||||||
|
||||||
#[allow(deprecated)] // Remove when ConstFold removed | ||||||
impl<T: ConstFold> ConstFolder for T { | ||||||
fn fold(&self, type_args: &[TypeArg], inputs: &[FoldVal], outputs: &mut [FoldVal]) { | ||||||
do_fold(self, type_args, inputs, outputs) | ||||||
} | ||||||
} | ||||||
|
||||||
#[allow(deprecated)] // Remove when ConstFold removed | ||||||
impl Debug for Box<dyn ConstFold> { | ||||||
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { | ||||||
write!(f, "<custom constant folding>") | ||||||
} | ||||||
} | ||||||
|
||||||
impl Debug for Box<dyn ConstFolder> { | ||||||
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { | ||||||
write!(f, "<custom constant folding>") | ||||||
} | ||||||
} | ||||||
|
||||||
/// Blanket implementation for functions that only require the constants to | ||||||
/// evaluate - type arguments are not relevant. | ||||||
#[allow(deprecated)] // Remove when ConstFold removed | ||||||
impl<T> ConstFold for T | ||||||
where | ||||||
T: Fn(&[(crate::IncomingPort, crate::ops::Value)]) -> ConstFoldResult + Send + Sync, | ||||||
|
@@ -60,14 +221,25 @@ where | |||||
|
||||||
type FoldFn = dyn Fn(&[TypeArg], &[(IncomingPort, Value)]) -> ConstFoldResult + Send + Sync; | ||||||
|
||||||
/// Type holding a boxed const-folding function. | ||||||
/// Legacy type holding a boxed const-folding function. | ||||||
/// Deprecated: use [BoxedFolder] instead. | ||||||
#[deprecated(note = "Use BoxedFolder")] | ||||||
pub struct Folder { | ||||||
/// Const-folding function. | ||||||
pub folder: Box<FoldFn>, | ||||||
} | ||||||
|
||||||
#[allow(deprecated)] // Remove when ConstFold removed | ||||||
impl ConstFold for Folder { | ||||||
fn fold(&self, type_args: &[TypeArg], consts: &[(IncomingPort, Value)]) -> ConstFoldResult { | ||||||
(self.folder)(type_args, consts) | ||||||
} | ||||||
} | ||||||
|
||||||
pub struct BoxedFolder(Box<dyn Fn(&[TypeArg], &[FoldVal], &mut [FoldVal]) + Send + Sync>); | ||||||
|
||||||
impl ConstFolder for BoxedFolder { | ||||||
fn fold(&self, type_args: &[TypeArg], inputs: &[FoldVal], outputs: &mut [FoldVal]) { | ||||||
self.0(type_args, inputs, outputs) | ||||||
} | ||||||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Don't think so, "Extension" should be covering the points of extension