Skip to content
Draft
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions hugr-core/src/export.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1036,6 +1036,7 @@ impl<'a> Context<'a> {
self.make_term(table::Term::Apply(symbol, args))
}

#[allow(deprecated)] // Remove when Value::Function removed
Value::Function { hugr } => {
let outer_hugr = std::mem::replace(&mut self.hugr, hugr);
let outer_node_to_id = std::mem::take(&mut self.node_to_id);
Expand Down
2 changes: 1 addition & 1 deletion hugr-core/src/extension.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ pub mod resolution;
pub mod simple_op;
mod type_def;

pub use const_fold::{ConstFold, ConstFoldResult, Folder, fold_out_row};
pub use const_fold::{ConstFold, ConstFoldResult, FoldVal, Folder, fold_out_row};
pub use op_def::{
CustomSignatureFunc, CustomValidator, LowerFunc, OpDef, SignatureFromArgs, SignatureFunc,
ValidateJustArgs, ValidateTypeArgs,
Expand Down
151 changes: 143 additions & 8 deletions hugr-core/src/extension/const_fold.rs
Original file line number Diff line number Diff line change
@@ -1,19 +1,134 @@
use std::fmt::Formatter;

use std::fmt::Debug;
use std::fmt::{Debug, Formatter};

use crate::ops::Value;
use crate::types::TypeArg;
use crate::ops::constant::{CustomConst, OpaqueValue, Sum};
use crate::types::{SumType, TypeArg};
use crate::{IncomingPort, Node, OutgoingPort, PortIndex};

/// Representation of values used for constant folding.
/// See [ConstFold], which is used as `dyn` so we cannot parametrize by
/// [HugrNode](crate::core::HugrNode).
// Should we be non-exhaustive??
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Don't think so, "Extension" should be covering the points of extension

#[derive(Clone, Debug, PartialEq, Default)]
pub enum FoldVal {
/// Value is unknown, must assume that it could be anything
#[default]
Unknown,
/// A variant of a [SumType]
Sum {
/// Which variant of the sum type this value is.
tag: usize,
/// Describes the type of the whole value.
// Can we deprecate this immediately? It is only for converting to Value
sum_type: SumType,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could move these fields to an inner struct and mark them private i.e. Sum(SumVal)

/// A value for each element (type) within the variant
items: Vec<FoldVal>,
},
/// A constant value defined by an extension
Extension(OpaqueValue),
/// A function pointer loaded from a [FuncDefn](crate::ops::FuncDefn) or `FuncDecl`
LoadedFunction(Node, Vec<TypeArg>), // Deliberately skipping Function(Box<Hugr>) ATM
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do we need non_exhaustive for Function(Box<Hugr>), or a stub implementation that errors?

}

impl<T> From<T> for FoldVal
where
T: CustomConst,
{
fn from(value: T) -> Self {
Self::Extension(value.into())
}
}

impl FoldVal {
/// Returns a constant "false" value, i.e. the first variant of Sum((), ()).
pub const fn false_val() -> Self {
Self::Sum {
tag: 0,
sum_type: SumType::Unit { size: 2 },
items: vec![],
}
}

/// Returns a constant "true" value, i.e. the second variant of Sum((), ()).
pub const fn true_val() -> Self {
Self::Sum {
tag: 1,
sum_type: SumType::Unit { size: 2 },
items: vec![],
}
}

use crate::IncomingPort;
use crate::OutgoingPort;
/// Returns a constant boolean - either [Self::false_val] or [Self::true_val]
pub const fn from_bool(b: bool) -> Self {
if b {
Self::true_val()
} else {
Self::false_val()
}
}

/// Extract the specified type of [CustomConst] fro this instance, if it is one
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
/// Extract the specified type of [CustomConst] fro this instance, if it is one
/// Extract the specified type of [CustomConst] for this instance, if it is one

pub fn get_custom_value<T: CustomConst>(&self) -> Option<&T> {
let Self::Extension(e) = self else {
return None;
};
e.value().downcast_ref()
}
}

impl TryFrom<FoldVal> for Value {
type Error = Option<Node>;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

please add docstring to explain error type


fn try_from(value: FoldVal) -> Result<Self, Self::Error> {
match value {
FoldVal::Unknown => Err(None),
FoldVal::Sum {
tag,
sum_type,
items,
} => {
let values = items
.into_iter()
.map(Value::try_from)
.collect::<Result<Vec<_>, _>>()?;
Ok(Value::Sum(Sum {
tag,
values,
sum_type,
}))
}
FoldVal::Extension(e) => Ok(Value::Extension { e }),
FoldVal::LoadedFunction(node, _) => Err(Some(node)),
}
}
}

use crate::ops;
impl From<Value> for FoldVal {
fn from(value: Value) -> Self {
match value {
Value::Extension { e } => FoldVal::Extension(e),
#[allow(deprecated)] // remove when Value::Function removed
Value::Function { .. } => FoldVal::Unknown,
Value::Sum(Sum {
tag,
values,
sum_type,
}) => {
let items = values.into_iter().map(FoldVal::from).collect();
FoldVal::Sum {
tag,
sum_type,
items,
}
}
}
}
}

/// Output of constant folding an operation, None indicates folding was either
/// not possible or unsuccessful. An empty vector indicates folding was
/// successful and no values are output.
pub type ConstFoldResult = Option<Vec<(OutgoingPort, ops::Value)>>;
pub type ConstFoldResult = Option<Vec<(OutgoingPort, Value)>>;

/// Tag some output constants with [`OutgoingPort`] inferred from the ordering.
pub fn fold_out_row(consts: impl IntoIterator<Item = Value>) -> ConstFoldResult {
Expand All @@ -27,9 +142,29 @@ pub fn fold_out_row(consts: impl IntoIterator<Item = Value>) -> ConstFoldResult

/// Trait implemented by extension operations that can perform constant folding.
pub trait ConstFold: Send + Sync {
/// Given type arguments `type_args` and [`FoldVal`]s for each input,
/// update the outputs (these will be initialized to [FoldVal::Unknown]).
///
/// Defaults to calling [Self::fold] with those arguments that can be converted ---
/// [FoldVal::LoadedFunction]s will be lost as these are not representable as [Value]s.
fn fold2(&self, type_args: &[TypeArg], inputs: &[FoldVal], outputs: &mut [FoldVal]) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why does this follow a mutable output pattern? Would some other process have set some values already that this one would leave unchanged?

let consts = inputs
.iter()
.cloned()
.enumerate()
.filter_map(|(p, fv)| Some((p.into(), fv.try_into().ok()?)))
.collect::<Vec<_>>();
#[allow(deprecated)] // remove this when fold is removed
let outs = self.fold(type_args, &consts);
for (p, v) in outs.unwrap_or_default() {
outputs[p.index()] = v.into();
}
}

/// Given type arguments `type_args` and
/// [`crate::ops::Const`] values for inputs at [`crate::IncomingPort`]s,
/// try to evaluate the operation.
#[deprecated(note = "Use fold2")]
fn fold(
&self,
type_args: &[TypeArg],
Expand Down
30 changes: 24 additions & 6 deletions hugr-core/src/extension/op_def.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,18 @@ use std::sync::{Arc, Weak};

use serde_with::serde_as;

use crate::envelope::serde_with::AsStringEnvelope;
use crate::ops::{OpName, OpNameRef, Value};
use crate::types::type_param::{TypeArg, TypeParam, check_type_args};
use crate::types::{FuncValueType, PolyFuncType, PolyFuncTypeRV, Signature};
use crate::{Hugr, IncomingPort};

use super::const_fold::FoldVal;
use super::{
ConstFold, ConstFoldResult, Extension, ExtensionBuildError, ExtensionId, ExtensionSet,
SignatureError,
};

use crate::Hugr;
use crate::envelope::serde_with::AsStringEnvelope;
use crate::ops::{OpName, OpNameRef};
use crate::types::type_param::{TypeArg, TypeParam, check_type_args};
use crate::types::{FuncValueType, PolyFuncType, PolyFuncTypeRV, Signature};
mod serialize_signature_func;

/// Trait necessary for binary computations of `OpDef` signature
Expand Down Expand Up @@ -464,14 +466,30 @@ impl OpDef {
/// Evaluate an instance of this [`OpDef`] defined by the `type_args`, given
/// [`crate::ops::Const`] values for inputs at [`crate::IncomingPort`]s.
#[must_use]
#[deprecated(note = "use constant_fold2")]
pub fn constant_fold(
&self,
type_args: &[TypeArg],
consts: &[(crate::IncomingPort, crate::ops::Value)],
consts: &[(IncomingPort, Value)],
) -> ConstFoldResult {
#[allow(deprecated)] // we are in deprecated function, remove at same time
(self.constant_folder.as_ref())?.fold(type_args, consts)
}

/// Evaluate an instance of this [`OpDef`] defined by the `type_args`, given
/// [FoldVal] values for each input, and update the outputs, which should be
/// initialised to [FoldVal::Unknown].
pub fn constant_fold2(
&self,
type_args: &[TypeArg],
inputs: &[FoldVal],
outputs: &mut [FoldVal],
) {
if let Some(cf) = self.constant_folder.as_ref() {
cf.fold2(type_args, inputs, outputs)
}
}

/// Returns a reference to the signature function of this [`OpDef`].
#[must_use]
pub fn signature_func(&self) -> &SignatureFunc {
Expand Down
34 changes: 12 additions & 22 deletions hugr-core/src/extension/prelude/generic.rs
Original file line number Diff line number Diff line change
@@ -1,30 +1,19 @@
use std::str::FromStr;
use std::sync::{Arc, Weak};

use crate::extension::OpDef;
use crate::extension::SignatureFunc;
use crate::Extension;
use crate::extension::prelude::usize_custom_t;
use crate::extension::simple_op::{
HasConcrete, HasDef, MakeExtensionOp, MakeOpDef, MakeRegisteredOp, OpLoadError,
};
use crate::extension::{ConstFold, ExtensionId};
use crate::ops::ExtensionOp;
use crate::ops::OpName;
use crate::extension::{ConstFold, ExtensionId, OpDef, SignatureError, SignatureFunc};
use crate::ops::{ExtensionOp, OpName};
use crate::type_row;
use crate::types::FuncValueType;

use crate::types::Type;

use crate::extension::SignatureError;

use crate::types::PolyFuncTypeRV;

use crate::Extension;
use crate::types::type_param::TypeArg;
use crate::types::type_param::{TypeArg, TypeParam};
use crate::types::{FuncValueType, PolyFuncTypeRV, Type};

use super::PRELUDE;
use super::{ConstUsize, PRELUDE_ID};
use crate::types::type_param::TypeParam;

/// Name of the operation for loading generic `BoundedNat` parameters.
pub static LOAD_NAT_OP_ID: OpName = OpName::new_inline("load_nat");
Expand Down Expand Up @@ -161,10 +150,11 @@ impl HasConcrete for LoadNatDef {
#[cfg(test)]
mod tests {
use crate::{
HugrView, OutgoingPort,
HugrView,
builder::{DFGBuilder, Dataflow, DataflowHugr, inout_sig},
extension::FoldVal,
extension::prelude::{ConstUsize, usize_t},
ops::{OpType, constant},
ops::OpType,
type_row,
types::TypeArg,
};
Expand Down Expand Up @@ -201,10 +191,10 @@ mod tests {
let optype: OpType = op.into();

if let OpType::ExtensionOp(ext_op) = optype {
let result = ext_op.constant_fold(&[]);
let exp_port: OutgoingPort = 0.into();
let exp_val: constant::Value = ConstUsize::new(5).into();
assert_eq!(result, Some(vec![(exp_port, exp_val)]));
let mut out = [FoldVal::Unknown];
ext_op.constant_fold2(&[], &mut out);
let exp_val: FoldVal = ConstUsize::new(5).into();
assert_eq!(out, [exp_val])
} else {
panic!()
}
Expand Down
1 change: 1 addition & 0 deletions hugr-core/src/extension/resolution/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,7 @@ fn collect_value_exts(
let typ = e.get_type();
collect_type_exts(&typ, used_extensions, missing_extensions);
}
#[allow(deprecated)] // remove when Value::Function removed
Value::Function { hugr: _ } => {
// The extensions used by nested hugrs do not need to be counted for the root hugr.
}
Expand Down
1 change: 1 addition & 0 deletions hugr-core/src/extension/resolution/types_mut.rs
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,7 @@ pub(super) fn resolve_value_exts(
});
}
}
#[allow(deprecated)] // remove when Value::Function removed
Value::Function { hugr } => {
// We don't need to add the nested hugr's extensions to the main one here,
// but we run resolution on it independently.
Expand Down
2 changes: 2 additions & 0 deletions hugr-core/src/hugr/serialize/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -451,6 +451,7 @@ fn roundtrip_sumtype(#[case] sum_type: SumType) {
#[case(Value::extension(ConstInt::new_u(2,1).unwrap()))]
#[case(Value::sum(1,[Value::extension(ConstInt::new_u(2,1).unwrap())], SumType::new([vec![], vec![INT_TYPES[2].clone()]])).unwrap())]
#[case(Value::tuple([Value::false_val(), Value::extension(ConstInt::new_s(2,1).unwrap())]))]
#[allow(deprecated)] // remove when Value::Function removed
#[case(Value::function(crate::builder::test::simple_dfg_hugr()).unwrap())]
fn roundtrip_value(#[case] value: Value) {
check_testing_roundtrip(value);
Expand Down Expand Up @@ -511,6 +512,7 @@ fn roundtrip_polyfunctype_varlen(#[case] poly_func_type: PolyFuncTypeRV) {
#[case(ops::AliasDefn { name: "aliasdefn".into(), definition: Type::new_unit_sum(4)})]
#[case(ops::AliasDecl { name: "aliasdecl".into(), bound: TypeBound::Any})]
#[case(ops::Const::new(Value::false_val()))]
#[allow(deprecated)] // remove when Value::Function removed
#[case(ops::Const::new(Value::function(crate::builder::test::simple_dfg_hugr()).unwrap()))]
#[case(ops::Input::new(vec![Type::new_var_use(3,TypeBound::Copyable)]))]
#[case(ops::Output::new(vec![Type::new_function(FuncValueType::new_endo(type_row![]))]))]
Expand Down
Loading
Loading