Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions hugr-core/src/export.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1036,6 +1036,7 @@ impl<'a> Context<'a> {
self.make_term(table::Term::Apply(symbol, args))
}

#[allow(deprecated)] // Remove when Value::Function removed
Value::Function { hugr } => {
let outer_hugr = std::mem::replace(&mut self.hugr, hugr);
let outer_node_to_id = std::mem::take(&mut self.node_to_id);
Expand Down
5 changes: 4 additions & 1 deletion hugr-core/src/extension.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,10 @@ pub mod resolution;
pub mod simple_op;
mod type_def;

pub use const_fold::{ConstFold, ConstFoldResult, Folder, fold_out_row};
#[deprecated(note = "Use ConstFolder")]
#[allow(deprecated)] // Remove when ConstFold removed
pub use const_fold::{ConstFold, Folder};
pub use const_fold::{ConstFoldResult, ConstFolder, FoldVal, fold_out_row};
pub use op_def::{
CustomSignatureFunc, CustomValidator, LowerFunc, OpDef, SignatureFromArgs, SignatureFunc,
ValidateJustArgs, ValidateTypeArgs,
Expand Down
192 changes: 182 additions & 10 deletions hugr-core/src/extension/const_fold.rs
Original file line number Diff line number Diff line change
@@ -1,19 +1,134 @@
use std::fmt::Formatter;

use std::fmt::Debug;
use std::fmt::{Debug, Formatter};

use crate::ops::Value;
use crate::types::TypeArg;
use crate::ops::constant::{CustomConst, OpaqueValue, Sum};
use crate::types::{SumType, TypeArg};
use crate::{IncomingPort, Node, OutgoingPort, PortIndex};

/// Representation of values used for constant folding.
/// See [ConstFold], which is used as `dyn` so we cannot parametrize by
/// [HugrNode](crate::core::HugrNode).
// Should we be non-exhaustive??
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Don't think so, "Extension" should be covering the points of extension

#[derive(Clone, Debug, PartialEq, Default)]
pub enum FoldVal {
/// Value is unknown, must assume that it could be anything
#[default]
Unknown,
/// A variant of a [SumType]
Sum {
/// Which variant of the sum type this value is.
tag: usize,
/// Describes the type of the whole value.
// Can we deprecate this immediately? It is only for converting to Value
sum_type: SumType,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could move these fields to an inner struct and mark them private i.e. Sum(SumVal)

/// A value for each element (type) within the variant
items: Vec<FoldVal>,
},
/// A constant value defined by an extension
Extension(OpaqueValue),
/// A function pointer loaded from a [FuncDefn](crate::ops::FuncDefn) or `FuncDecl`
LoadedFunction(Node, Vec<TypeArg>), // Deliberately skipping Function(Box<Hugr>) ATM
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do we need non_exhaustive for Function(Box<Hugr>), or a stub implementation that errors?

}

impl<T> From<T> for FoldVal
where
T: CustomConst,
{
fn from(value: T) -> Self {
Self::Extension(value.into())
}
}

impl FoldVal {
/// Returns a constant "false" value, i.e. the first variant of Sum((), ()).
pub const fn false_val() -> Self {
Self::Sum {
tag: 0,
sum_type: SumType::Unit { size: 2 },
items: vec![],
}
}

/// Returns a constant "true" value, i.e. the second variant of Sum((), ()).
pub const fn true_val() -> Self {
Self::Sum {
tag: 1,
sum_type: SumType::Unit { size: 2 },
items: vec![],
}
}

/// Returns a constant boolean - either [Self::false_val] or [Self::true_val]
pub const fn from_bool(b: bool) -> Self {
if b {
Self::true_val()
} else {
Self::false_val()
}
}

/// Extract the specified type of [CustomConst] fro this instance, if it is one
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
/// Extract the specified type of [CustomConst] fro this instance, if it is one
/// Extract the specified type of [CustomConst] for this instance, if it is one

pub fn get_custom_value<T: CustomConst>(&self) -> Option<&T> {
let Self::Extension(e) = self else {
return None;
};
e.value().downcast_ref()
}
}

use crate::IncomingPort;
use crate::OutgoingPort;
impl TryFrom<FoldVal> for Value {
type Error = Option<Node>;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

please add docstring to explain error type


use crate::ops;
fn try_from(value: FoldVal) -> Result<Self, Self::Error> {
match value {
FoldVal::Unknown => Err(None),
FoldVal::Sum {
tag,
sum_type,
items,
} => {
let values = items
.into_iter()
.map(Value::try_from)
.collect::<Result<Vec<_>, _>>()?;
Ok(Value::Sum(Sum {
tag,
values,
sum_type,
}))
}
FoldVal::Extension(e) => Ok(Value::Extension { e }),
FoldVal::LoadedFunction(node, _) => Err(Some(node)),
}
}
}

impl From<Value> for FoldVal {
fn from(value: Value) -> Self {
match value {
Value::Extension { e } => FoldVal::Extension(e),
#[allow(deprecated)] // remove when Value::Function removed
Value::Function { .. } => FoldVal::Unknown,
Value::Sum(Sum {
tag,
values,
sum_type,
}) => {
let items = values.into_iter().map(FoldVal::from).collect();
FoldVal::Sum {
tag,
sum_type,
items,
}
}
}
}
}

/// Output of constant folding an operation, None indicates folding was either
/// not possible or unsuccessful. An empty vector indicates folding was
/// successful and no values are output.
pub type ConstFoldResult = Option<Vec<(OutgoingPort, ops::Value)>>;
pub type ConstFoldResult = Option<Vec<(OutgoingPort, Value)>>;

/// Tag some output constants with [`OutgoingPort`] inferred from the ordering.
pub fn fold_out_row(consts: impl IntoIterator<Item = Value>) -> ConstFoldResult {
Expand All @@ -25,7 +140,9 @@ pub fn fold_out_row(consts: impl IntoIterator<Item = Value>) -> ConstFoldResult
Some(vec)
}

/// Trait implemented by extension operations that can perform constant folding.
#[deprecated(note = "Use ConstFolder")]
/// Old trait implemented by extension operations that can perform constant folding.
/// Deprecated: see [ConstFolder]
pub trait ConstFold: Send + Sync {
/// Given type arguments `type_args` and
/// [`crate::ops::Const`] values for inputs at [`crate::IncomingPort`]s,
Expand All @@ -37,14 +154,58 @@ pub trait ConstFold: Send + Sync {
) -> ConstFoldResult;
}

/// Trait implemented by extension operations that can perform constant folding.
pub trait ConstFolder: Send + Sync {
/// Given type arguments `type_args` and [`FoldVal`]s for each input,
/// update the outputs (these will be initialized to [FoldVal::Unknown]).
fn fold(&self, type_args: &[TypeArg], inputs: &[FoldVal], outputs: &mut [FoldVal]);
}

pub(super) fn fold_vals_to_indexed_vals<P: From<usize>>(fvs: &[FoldVal]) -> Vec<(P, Value)> {
fvs.iter()
.cloned()
.enumerate()
.filter_map(|(p, fv)| Some((p.into(), fv.try_into().ok()?)))
.collect::<Vec<_>>()
}

#[allow(deprecated)] // Legacy conversion routine, remove when ConstFold removed
fn do_fold(
old_fold: &impl ConstFold,
type_args: &[TypeArg],
inputs: &[FoldVal],
outputs: &mut [FoldVal],
) {
let consts = fold_vals_to_indexed_vals(inputs);
let outs = old_fold.fold(type_args, &consts);
for (p, v) in outs.unwrap_or_default() {
outputs[p.index()] = v.into();
}
}

#[allow(deprecated)] // Remove when ConstFold removed
impl<T: ConstFold> ConstFolder for T {
fn fold(&self, type_args: &[TypeArg], inputs: &[FoldVal], outputs: &mut [FoldVal]) {
do_fold(self, type_args, inputs, outputs)
}
}

#[allow(deprecated)] // Remove when ConstFold removed
impl Debug for Box<dyn ConstFold> {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
write!(f, "<custom constant folding>")
}
}

impl Debug for Box<dyn ConstFolder> {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
write!(f, "<custom constant folding>")
}
}

/// Blanket implementation for functions that only require the constants to
/// evaluate - type arguments are not relevant.
#[allow(deprecated)] // Remove when ConstFold removed
impl<T> ConstFold for T
where
T: Fn(&[(crate::IncomingPort, crate::ops::Value)]) -> ConstFoldResult + Send + Sync,
Expand All @@ -60,14 +221,25 @@ where

type FoldFn = dyn Fn(&[TypeArg], &[(IncomingPort, Value)]) -> ConstFoldResult + Send + Sync;

/// Type holding a boxed const-folding function.
/// Legacy type holding a boxed const-folding function.
/// Deprecated: use [BoxedFolder] instead.
#[deprecated(note = "Use BoxedFolder")]
pub struct Folder {
/// Const-folding function.
pub folder: Box<FoldFn>,
}

#[allow(deprecated)] // Remove when ConstFold removed
impl ConstFold for Folder {
fn fold(&self, type_args: &[TypeArg], consts: &[(IncomingPort, Value)]) -> ConstFoldResult {
(self.folder)(type_args, consts)
}
}

pub struct BoxedFolder(Box<dyn Fn(&[TypeArg], &[FoldVal], &mut [FoldVal]) + Send + Sync>);

impl ConstFolder for BoxedFolder {
fn fold(&self, type_args: &[TypeArg], inputs: &[FoldVal], outputs: &mut [FoldVal]) {
self.0(type_args, inputs, outputs)
}
}
41 changes: 31 additions & 10 deletions hugr-core/src/extension/op_def.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,19 @@ use std::sync::{Arc, Weak};

use serde_with::serde_as;

use crate::envelope::serde_with::AsStringEnvelope;
use crate::extension::const_fold::fold_vals_to_indexed_vals;
use crate::ops::{OpName, OpNameRef, Value};
use crate::types::type_param::{TypeArg, TypeParam, check_type_args};
use crate::types::{FuncValueType, PolyFuncType, PolyFuncTypeRV, Signature};
use crate::{Hugr, IncomingPort, PortIndex};

use super::const_fold::FoldVal;
use super::{
ConstFold, ConstFoldResult, Extension, ExtensionBuildError, ExtensionId, ExtensionSet,
ConstFoldResult, ConstFolder, Extension, ExtensionBuildError, ExtensionId, ExtensionSet,
SignatureError,
};

use crate::Hugr;
use crate::envelope::serde_with::AsStringEnvelope;
use crate::ops::{OpName, OpNameRef};
use crate::types::type_param::{TypeArg, TypeParam, check_type_args};
use crate::types::{FuncValueType, PolyFuncType, PolyFuncTypeRV, Signature};
mod serialize_signature_func;

/// Trait necessary for binary computations of `OpDef` signature
Expand Down Expand Up @@ -327,7 +330,7 @@ pub struct OpDef {

/// Operations can optionally implement [`ConstFold`] to implement constant folding.
#[serde(skip)]
constant_folder: Option<Box<dyn ConstFold>>,
constant_folder: Option<Box<dyn ConstFolder>>,
}

impl OpDef {
Expand Down Expand Up @@ -457,19 +460,37 @@ impl OpDef {

/// Set the constant folding function for this Op, which can evaluate it
/// given constant inputs.
pub fn set_constant_folder(&mut self, fold: impl ConstFold + 'static) {
pub fn set_constant_folder(&mut self, fold: impl ConstFolder + 'static) {
self.constant_folder = Some(Box::new(fold));
}

/// Evaluate an instance of this [`OpDef`] defined by the `type_args`, given
/// [`crate::ops::Const`] values for inputs at [`crate::IncomingPort`]s.
#[must_use]
#[deprecated(note = "use const_fold")]
pub fn constant_fold(
&self,
type_args: &[TypeArg],
consts: &[(crate::IncomingPort, crate::ops::Value)],
consts: &[(IncomingPort, Value)],
) -> ConstFoldResult {
(self.constant_folder.as_ref())?.fold(type_args, consts)
let folder = self.constant_folder.as_ref()?;
let sig = self.compute_signature(type_args).unwrap();
let mut inputs = vec![FoldVal::Unknown; sig.input_count()];
for (p, v) in consts {
inputs[p.index()] = v.clone().into();
}
let mut outputs = vec![FoldVal::Unknown; sig.output_count()];
folder.fold(type_args, &inputs, &mut outputs);
Some(fold_vals_to_indexed_vals(&outputs))
}

/// Evaluate an instance of this [`OpDef`] defined by the `type_args`, given
/// [FoldVal] values for each input, and update the outputs, which should be
/// initialised to [FoldVal::Unknown].
pub fn const_fold(&self, type_args: &[TypeArg], inputs: &[FoldVal], outputs: &mut [FoldVal]) {
if let Some(cf) = self.constant_folder.as_ref() {
cf.fold(type_args, inputs, outputs)
}
}

/// Returns a reference to the signature function of this [`OpDef`].
Expand Down
15 changes: 6 additions & 9 deletions hugr-core/src/extension/prelude.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use crate::extension::simple_op::{
MakeExtensionOp, MakeOpDef, MakeRegisteredOp, OpLoadError, try_from_name,
};
use crate::extension::{
ConstFold, ExtensionId, OpDef, SignatureError, SignatureFunc, TypeDefBound,
ConstFolder, ExtensionId, FoldVal, OpDef, SignatureError, SignatureFunc, TypeDefBound,
};
use crate::ops::OpName;
use crate::ops::constant::{CustomCheckFailure, CustomConst, ValueName};
Expand Down Expand Up @@ -585,7 +585,8 @@ pub enum TupleOpDef {
UnpackTuple,
}

impl ConstFold for TupleOpDef {
#[allow(deprecated)] // TODO: need a way to handle types of tuples. Or drop that SumType...
impl super::ConstFold for TupleOpDef {
fn fold(
&self,
_type_args: &[TypeArg],
Expand Down Expand Up @@ -823,13 +824,9 @@ impl MakeOpDef for NoopDef {
}
}

impl ConstFold for NoopDef {
fn fold(
&self,
_type_args: &[TypeArg],
consts: &[(crate::IncomingPort, Value)],
) -> crate::extension::ConstFoldResult {
fold_out_row([consts.first()?.1.clone()])
impl ConstFolder for NoopDef {
fn fold(&self, _type_args: &[TypeArg], inputs: &[FoldVal], outputs: &mut [FoldVal]) {
outputs[0] = inputs[0].clone()
}
}

Expand Down
Loading
Loading