From 9ef836982d5af0833897c1409f786a05f00779e9 Mon Sep 17 00:00:00 2001 From: Michael J Klein Date: Wed, 6 Nov 2024 10:25:23 -0500 Subject: [PATCH 1/2] chore: proptest for `canonicalize` on infix type expressions (#6269) Co-authored-by: Tom French --- Cargo.lock | 15 +- compiler/noirc_frontend/Cargo.toml | 2 + .../tests/arithmetic_generics.txt | 7 + compiler/noirc_frontend/src/ast/expression.rs | 2 +- compiler/noirc_frontend/src/ast/mod.rs | 4 + compiler/noirc_frontend/src/elaborator/mod.rs | 4 +- .../noirc_frontend/src/elaborator/types.rs | 2 +- .../src/hir/def_collector/dc_mod.rs | 2 +- compiler/noirc_frontend/src/hir_def/types.rs | 15 +- .../src/monomorphization/mod.rs | 4 +- compiler/noirc_frontend/src/tests.rs | 119 +----- .../src/tests/arithmetic_generics.rs | 366 ++++++++++++++++++ 12 files changed, 412 insertions(+), 130 deletions(-) create mode 100644 compiler/noirc_frontend/proptest-regressions/tests/arithmetic_generics.txt create mode 100644 compiler/noirc_frontend/src/tests/arithmetic_generics.rs diff --git a/Cargo.lock b/Cargo.lock index aca7f199bb5..936aa636e86 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2840,7 +2840,7 @@ dependencies = [ "num-bigint", "num-traits", "proptest", - "proptest-derive", + "proptest-derive 0.4.0", "serde", "serde_json", "strum", @@ -2961,6 +2961,8 @@ dependencies = [ "num-bigint", "num-traits", "petgraph", + "proptest", + "proptest-derive 0.5.0", "rangemap", "regex", "rustc-hash", @@ -3511,6 +3513,17 @@ dependencies = [ "syn 1.0.109", ] +[[package]] +name = "proptest-derive" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6ff7ff745a347b87471d859a377a9a404361e7efc2a971d73424a6d183c0fc77" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.64", +] + [[package]] name = "quick-error" version = "1.2.3" diff --git a/compiler/noirc_frontend/Cargo.toml b/compiler/noirc_frontend/Cargo.toml index d729dabcb04..581d7f1b61d 100644 --- a/compiler/noirc_frontend/Cargo.toml +++ b/compiler/noirc_frontend/Cargo.toml @@ -36,6 +36,8 @@ strum_macros = "0.24" [dev-dependencies] base64.workspace = true +proptest.workspace = true +proptest-derive = "0.5.0" [features] experimental_parser = [] diff --git a/compiler/noirc_frontend/proptest-regressions/tests/arithmetic_generics.txt b/compiler/noirc_frontend/proptest-regressions/tests/arithmetic_generics.txt new file mode 100644 index 00000000000..80f5c7f1ead --- /dev/null +++ b/compiler/noirc_frontend/proptest-regressions/tests/arithmetic_generics.txt @@ -0,0 +1,7 @@ +# Seeds for failure cases proptest has generated in the past. It is +# automatically read and these particular cases re-run before any +# novel cases are generated. +# +# It is recommended to check this file in to source control so that +# everyone who runs the test benefits from these saved cases. +cc fc27f4091dfa5f938973048209b5fcf22aefa1cfaffaaa3e349f30e9b1f93f49 # shrinks to infix_and_bindings = (((0: numeric bool) % (Numeric(Shared(RefCell { value: Unbound('2, Numeric(bool)) }): bool) + Numeric(Shared(RefCell { value: Unbound('0, Numeric(bool)) }): bool))), [('0, (0: numeric bool)), ('1, (0: numeric bool)), ('2, (0: numeric bool))]) diff --git a/compiler/noirc_frontend/src/ast/expression.rs b/compiler/noirc_frontend/src/ast/expression.rs index 17fae63bc35..2c8a9b6508d 100644 --- a/compiler/noirc_frontend/src/ast/expression.rs +++ b/compiler/noirc_frontend/src/ast/expression.rs @@ -98,7 +98,7 @@ impl UnresolvedGeneric { UnresolvedGeneric::Variable(_) => Ok(Kind::Normal), UnresolvedGeneric::Numeric { typ, .. } => { let typ = self.resolve_numeric_kind_type(typ)?; - Ok(Kind::Numeric(Box::new(typ))) + Ok(Kind::numeric(typ)) } UnresolvedGeneric::Resolved(..) => { panic!("Don't know the kind of a resolved generic here") diff --git a/compiler/noirc_frontend/src/ast/mod.rs b/compiler/noirc_frontend/src/ast/mod.rs index e85563691ba..3c6664dd569 100644 --- a/compiler/noirc_frontend/src/ast/mod.rs +++ b/compiler/noirc_frontend/src/ast/mod.rs @@ -19,6 +19,9 @@ pub use visitor::Visitor; pub use expression::*; pub use function::*; +#[cfg(test)] +use proptest_derive::Arbitrary; + use acvm::FieldElement; pub use docs::*; use noirc_errors::Span; @@ -37,6 +40,7 @@ use crate::{ use acvm::acir::AcirField; use iter_extended::vecmap; +#[cfg_attr(test, derive(Arbitrary))] #[derive(Debug, PartialEq, Eq, Clone, Copy, Hash, Ord, PartialOrd)] pub enum IntegerBitSize { One, diff --git a/compiler/noirc_frontend/src/elaborator/mod.rs b/compiler/noirc_frontend/src/elaborator/mod.rs index 9d6d04c07b0..ee79b62671f 100644 --- a/compiler/noirc_frontend/src/elaborator/mod.rs +++ b/compiler/noirc_frontend/src/elaborator/mod.rs @@ -641,7 +641,7 @@ impl<'context> Elaborator<'context> { let typ = if unresolved_typ.is_type_expression() { self.resolve_type_inner( unresolved_typ.clone(), - &Kind::Numeric(Box::new(Type::default_int_type())), + &Kind::numeric(Type::default_int_type()), ) } else { self.resolve_type(unresolved_typ.clone()) @@ -654,7 +654,7 @@ impl<'context> Elaborator<'context> { }); self.push_err(unsupported_typ_err); } - Kind::Numeric(Box::new(typ)) + Kind::numeric(typ) } else { Kind::Normal } diff --git a/compiler/noirc_frontend/src/elaborator/types.rs b/compiler/noirc_frontend/src/elaborator/types.rs index 722573bcd38..b296c4f1805 100644 --- a/compiler/noirc_frontend/src/elaborator/types.rs +++ b/compiler/noirc_frontend/src/elaborator/types.rs @@ -414,7 +414,7 @@ impl<'context> Elaborator<'context> { let kind = self .interner .get_global_let_statement(id) - .map(|let_statement| Kind::Numeric(Box::new(let_statement.r#type))) + .map(|let_statement| Kind::numeric(let_statement.r#type)) .unwrap_or(Kind::u32()); // TODO(https://github.com/noir-lang/noir/issues/6238): diff --git a/compiler/noirc_frontend/src/hir/def_collector/dc_mod.rs b/compiler/noirc_frontend/src/hir/def_collector/dc_mod.rs index 2d15d4927c9..a373441b4e0 100644 --- a/compiler/noirc_frontend/src/hir/def_collector/dc_mod.rs +++ b/compiler/noirc_frontend/src/hir/def_collector/dc_mod.rs @@ -541,7 +541,7 @@ impl<'a> ModCollector<'a> { name: Rc::new(name.to_string()), type_var: TypeVariable::unbound( type_variable_id, - Kind::Numeric(Box::new(typ)), + Kind::numeric(typ), ), span: name.span(), }); diff --git a/compiler/noirc_frontend/src/hir_def/types.rs b/compiler/noirc_frontend/src/hir_def/types.rs index 88fc87cc500..a0fea3aa774 100644 --- a/compiler/noirc_frontend/src/hir_def/types.rs +++ b/compiler/noirc_frontend/src/hir_def/types.rs @@ -5,6 +5,9 @@ use std::{ rc::Rc, }; +#[cfg(test)] +use proptest_derive::Arbitrary; + use acvm::{AcirField, FieldElement}; use crate::{ @@ -169,6 +172,11 @@ pub enum Kind { } impl Kind { + // Kind::Numeric constructor helper + pub fn numeric(typ: Type) -> Kind { + Kind::Numeric(Box::new(typ)) + } + pub(crate) fn is_error(&self) -> bool { match self.follow_bindings() { Self::Numeric(typ) => *typ == Type::Error, @@ -196,7 +204,7 @@ impl Kind { } pub(crate) fn u32() -> Self { - Self::Numeric(Box::new(Type::Integer(Signedness::Unsigned, IntegerBitSize::ThirtyTwo))) + Self::numeric(Type::Integer(Signedness::Unsigned, IntegerBitSize::ThirtyTwo)) } pub(crate) fn follow_bindings(&self) -> Self { @@ -205,7 +213,7 @@ impl Kind { Self::Normal => Self::Normal, Self::Integer => Self::Integer, Self::IntegerOrField => Self::IntegerOrField, - Self::Numeric(typ) => Self::Numeric(Box::new(typ.follow_bindings())), + Self::Numeric(typ) => Self::numeric(typ.follow_bindings()), } } @@ -671,6 +679,7 @@ impl Shared { /// A restricted subset of binary operators useable on /// type level integers for use in the array length positions of types. +#[cfg_attr(test, derive(Arbitrary))] #[derive(Debug, Copy, Clone, Hash, PartialEq, Eq, PartialOrd, Ord)] pub enum BinaryTypeOperator { Addition, @@ -1422,7 +1431,7 @@ impl Type { if self_kind.unifies(&other_kind) { self_kind } else { - Kind::Numeric(Box::new(Type::Error)) + Kind::numeric(Type::Error) } } diff --git a/compiler/noirc_frontend/src/monomorphization/mod.rs b/compiler/noirc_frontend/src/monomorphization/mod.rs index df687782afb..ce2c58e71c1 100644 --- a/compiler/noirc_frontend/src/monomorphization/mod.rs +++ b/compiler/noirc_frontend/src/monomorphization/mod.rs @@ -944,9 +944,7 @@ impl<'interner> Monomorphizer<'interner> { }; let location = self.interner.id_location(expr_id); - if !Kind::Numeric(numeric_typ.clone()) - .unifies(&Kind::Numeric(Box::new(typ.clone()))) - { + if !Kind::Numeric(numeric_typ.clone()).unifies(&Kind::numeric(typ.clone())) { let message = "ICE: Generic's kind does not match expected type"; return Err(MonomorphizationError::InternalError { location, message }); } diff --git a/compiler/noirc_frontend/src/tests.rs b/compiler/noirc_frontend/src/tests.rs index f66a82a622f..b709436cccc 100644 --- a/compiler/noirc_frontend/src/tests.rs +++ b/compiler/noirc_frontend/src/tests.rs @@ -1,6 +1,7 @@ #![cfg(test)] mod aliases; +mod arithmetic_generics; mod bound_checks; mod imports; mod metaprogramming; @@ -16,7 +17,6 @@ mod visibility; // A test harness will allow for more expressive and readable tests use std::collections::BTreeMap; -use acvm::{AcirField, FieldElement}; use fm::FileId; use iter_extended::vecmap; @@ -37,7 +37,6 @@ use crate::hir::def_collector::dc_crate::DefCollector; use crate::hir::def_map::{CrateDefMap, LocalModuleId}; use crate::hir_def::expr::HirExpression; use crate::hir_def::stmt::HirStatement; -use crate::hir_def::types::{BinaryTypeOperator, Type}; use crate::monomorphization::ast::Program; use crate::monomorphization::errors::MonomorphizationError; use crate::monomorphization::monomorphize; @@ -3210,122 +3209,6 @@ fn as_trait_path_syntax_no_impl() { assert!(matches!(&errors[0].0, TypeError(TypeCheckError::NoMatchingImplFound { .. }))); } -#[test] -fn arithmetic_generics_canonicalization_deduplication_regression() { - let source = r#" - struct ArrData { - a: [Field; N], - b: [Field; N + N - 1], - } - - fn main() { - let _f: ArrData<5> = ArrData { - a: [0; 5], - b: [0; 9], - }; - } - "#; - let errors = get_program_errors(source); - assert_eq!(errors.len(), 0); -} - -#[test] -fn arithmetic_generics_checked_cast_zeros() { - let source = r#" - struct W {} - - fn foo(_x: W) -> W<(0 * N) / (N % N)> { - W {} - } - - fn bar(_x: W) -> u1 { - N - } - - fn main() -> pub u1 { - let w_0: W<0> = W {}; - let w: W<_> = foo(w_0); - bar(w) - } - "#; - - let errors = get_program_errors(source); - assert_eq!(errors.len(), 0); - - let monomorphization_error = get_monomorphization_error(source); - assert!(monomorphization_error.is_some()); - - // Expect a CheckedCast (0 % 0) failure - let monomorphization_error = monomorphization_error.unwrap(); - if let MonomorphizationError::UnknownArrayLength { ref length, ref err, location: _ } = - monomorphization_error - { - match length { - Type::CheckedCast { from, to } => { - assert!(matches!(*from.clone(), Type::InfixExpr { .. })); - assert!(matches!(*to.clone(), Type::InfixExpr { .. })); - } - _ => panic!("unexpected length: {:?}", length), - } - assert!(matches!( - err, - TypeCheckError::FailingBinaryOp { op: BinaryTypeOperator::Modulo, lhs: 0, rhs: 0, .. } - )); - } else { - panic!("unexpected error: {:?}", monomorphization_error); - } -} - -#[test] -fn arithmetic_generics_checked_cast_indirect_zeros() { - let source = r#" - struct W {} - - fn foo(_x: W) -> W<(N - N) % (N - N)> { - W {} - } - - fn bar(_x: W) -> Field { - N - } - - fn main() { - let w_0: W<0> = W {}; - let w = foo(w_0); - let _ = bar(w); - } - "#; - - let errors = get_program_errors(source); - assert_eq!(errors.len(), 0); - - let monomorphization_error = get_monomorphization_error(source); - assert!(monomorphization_error.is_some()); - - // Expect a CheckedCast (0 % 0) failure - let monomorphization_error = monomorphization_error.unwrap(); - if let MonomorphizationError::UnknownArrayLength { ref length, ref err, location: _ } = - monomorphization_error - { - match length { - Type::CheckedCast { from, to } => { - assert!(matches!(*from.clone(), Type::InfixExpr { .. })); - assert!(matches!(*to.clone(), Type::InfixExpr { .. })); - } - _ => panic!("unexpected length: {:?}", length), - } - match err { - TypeCheckError::ModuloOnFields { lhs, rhs, .. } => { - assert_eq!(lhs.clone(), FieldElement::zero()); - assert_eq!(rhs.clone(), FieldElement::zero()); - } - _ => panic!("expected ModuloOnFields, but found: {:?}", err), - } - } else { - panic!("unexpected error: {:?}", monomorphization_error); - } -} - #[test] fn infer_globals_to_u32_from_type_use() { let src = r#" diff --git a/compiler/noirc_frontend/src/tests/arithmetic_generics.rs b/compiler/noirc_frontend/src/tests/arithmetic_generics.rs new file mode 100644 index 00000000000..b7c4834a84a --- /dev/null +++ b/compiler/noirc_frontend/src/tests/arithmetic_generics.rs @@ -0,0 +1,366 @@ +#![cfg(test)] + +use proptest::arbitrary::any; +use proptest::collection; +use proptest::prelude::*; +use proptest::result::maybe_ok; +use proptest::strategy; + +use acvm::{AcirField, FieldElement}; + +use super::get_program_errors; +use crate::ast::{IntegerBitSize, Signedness}; +use crate::hir::type_check::TypeCheckError; +use crate::hir_def::types::{BinaryTypeOperator, Kind, Type, TypeVariable, TypeVariableId}; +use crate::monomorphization::errors::MonomorphizationError; +use crate::tests::get_monomorphization_error; + +#[test] +fn arithmetic_generics_canonicalization_deduplication_regression() { + let source = r#" + struct ArrData { + a: [Field; N], + b: [Field; N + N - 1], + } + + fn main() { + let _f: ArrData<5> = ArrData { + a: [0; 5], + b: [0; 9], + }; + } + "#; + let errors = get_program_errors(source); + assert_eq!(errors.len(), 0); +} + +#[test] +fn arithmetic_generics_checked_cast_zeros() { + let source = r#" + struct W {} + + fn foo(_x: W) -> W<(0 * N) / (N % N)> { + W {} + } + + fn bar(_x: W) -> u1 { + N + } + + fn main() -> pub u1 { + let w_0: W<0> = W {}; + let w: W<_> = foo(w_0); + bar(w) + } + "#; + + let errors = get_program_errors(source); + assert_eq!(errors.len(), 0); + + let monomorphization_error = get_monomorphization_error(source); + assert!(monomorphization_error.is_some()); + + // Expect a CheckedCast (0 % 0) failure + let monomorphization_error = monomorphization_error.unwrap(); + if let MonomorphizationError::UnknownArrayLength { ref length, ref err, location: _ } = + monomorphization_error + { + match length { + Type::CheckedCast { from, to } => { + assert!(matches!(*from.clone(), Type::InfixExpr { .. })); + assert!(matches!(*to.clone(), Type::InfixExpr { .. })); + } + _ => panic!("unexpected length: {:?}", length), + } + assert!(matches!( + err, + TypeCheckError::FailingBinaryOp { op: BinaryTypeOperator::Modulo, lhs: 0, rhs: 0, .. } + )); + } else { + panic!("unexpected error: {:?}", monomorphization_error); + } +} + +#[test] +fn arithmetic_generics_checked_cast_indirect_zeros() { + let source = r#" + struct W {} + + fn foo(_x: W) -> W<(N - N) % (N - N)> { + W {} + } + + fn bar(_x: W) -> Field { + N + } + + fn main() { + let w_0: W<0> = W {}; + let w = foo(w_0); + let _ = bar(w); + } + "#; + + let errors = get_program_errors(source); + assert_eq!(errors.len(), 0); + + let monomorphization_error = get_monomorphization_error(source); + assert!(monomorphization_error.is_some()); + + // Expect a CheckedCast (0 % 0) failure + let monomorphization_error = monomorphization_error.unwrap(); + if let MonomorphizationError::UnknownArrayLength { ref length, ref err, location: _ } = + monomorphization_error + { + match length { + Type::CheckedCast { from, to } => { + assert!(matches!(*from.clone(), Type::InfixExpr { .. })); + assert!(matches!(*to.clone(), Type::InfixExpr { .. })); + } + _ => panic!("unexpected length: {:?}", length), + } + match err { + TypeCheckError::ModuloOnFields { lhs, rhs, .. } => { + assert_eq!(lhs.clone(), FieldElement::zero()); + assert_eq!(rhs.clone(), FieldElement::zero()); + } + _ => panic!("expected ModuloOnFields, but found: {:?}", err), + } + } else { + panic!("unexpected error: {:?}", monomorphization_error); + } +} + +prop_compose! { + // maximum_size must be non-zero + fn arbitrary_u128_field_element(maximum_size: u128) + (u128_value in any::()) + -> FieldElement + { + assert!(maximum_size != 0); + FieldElement::from(u128_value % maximum_size) + } +} + +// NOTE: this is roughly the same method from acvm/tests/solver +prop_compose! { + // Use both `u128` and hex proptest strategies + fn arbitrary_field_element() + (u128_or_hex in maybe_ok(any::(), "[0-9a-f]{64}")) + -> FieldElement + { + match u128_or_hex { + Ok(number) => FieldElement::from(number), + Err(hex) => FieldElement::from_hex(&hex).expect("should accept any 32 byte hex string"), + } + } +} + +// Generate (arbitrary_unsigned_type, generator for that type) +fn arbitrary_unsigned_type_with_generator() -> BoxedStrategy<(Type, BoxedStrategy)> { + prop_oneof![ + strategy::Just((Type::FieldElement, arbitrary_field_element().boxed())), + any::().prop_map(|bit_size| { + let typ = Type::Integer(Signedness::Unsigned, bit_size); + let maximum_size = typ.integral_maximum_size().unwrap().to_u128(); + (typ, arbitrary_u128_field_element(maximum_size).boxed()) + }), + strategy::Just((Type::Bool, arbitrary_u128_field_element(1).boxed())), + ] + .boxed() +} + +prop_compose! { + fn arbitrary_variable(typ: Type, num_variables: usize) + (variable_index in any::()) + -> Type { + assert!(num_variables != 0); + let id = TypeVariableId(variable_index % num_variables); + let kind = Kind::numeric(typ.clone()); + let var = TypeVariable::unbound(id, kind); + Type::TypeVariable(var) + } +} + +fn first_n_variables(typ: Type, num_variables: usize) -> impl Iterator { + (0..num_variables).map(move |id| { + let id = TypeVariableId(id); + let kind = Kind::numeric(typ.clone()); + TypeVariable::unbound(id, kind) + }) +} + +fn arbitrary_infix_expr( + typ: Type, + arbitrary_value: BoxedStrategy, + num_variables: usize, +) -> impl Strategy { + let leaf = prop_oneof![ + arbitrary_variable(typ.clone(), num_variables), + arbitrary_value.prop_map(move |value| Type::Constant(value, Kind::numeric(typ.clone()))), + ]; + + leaf.prop_recursive( + 8, // 8 levels deep maximum + 256, // Shoot for maximum size of 256 nodes + 10, // We put up to 10 items per collection + |inner| { + (inner.clone(), any::(), inner) + .prop_map(|(lhs, op, rhs)| Type::InfixExpr(Box::new(lhs), op, Box::new(rhs))) + }, + ) +} + +prop_compose! { + // (infix_expr, type, generator) + fn arbitrary_infix_expr_type_gen(num_variables: usize) + (type_and_gen in arbitrary_unsigned_type_with_generator()) + (infix_expr in arbitrary_infix_expr(type_and_gen.clone().0, type_and_gen.clone().1, num_variables), type_and_gen in Just(type_and_gen)) + -> (Type, Type, BoxedStrategy) { + let (typ, value_generator) = type_and_gen; + (infix_expr, typ, value_generator) + } +} + +prop_compose! { + // (Type::InfixExpr, numeric kind, bindings) + fn arbitrary_infix_expr_with_bindings_sized(num_variables: usize) + (infix_type_gen in arbitrary_infix_expr_type_gen(num_variables)) + (values in collection::vec(infix_type_gen.clone().2, num_variables), infix_type_gen in Just(infix_type_gen)) + -> (Type, Type, Vec<(TypeVariable, Type)>) { + let (infix_expr, typ, _value_generator) = infix_type_gen; + let bindings: Vec<_> = first_n_variables(typ.clone(), num_variables) + .zip(values.iter().map(|value| { + Type::Constant(*value, Kind::numeric(typ.clone())) + })) + .collect(); + (infix_expr, typ, bindings) + } +} + +prop_compose! { + // the lint misfires on 'num_variables' + #[allow(unused_variables)] + fn arbitrary_infix_expr_with_bindings(max_num_variables: usize) + (num_variables in any::().prop_map(move |num_variables| (num_variables % max_num_variables).clamp(1, max_num_variables))) + (infix_type_bindings in arbitrary_infix_expr_with_bindings_sized(num_variables), num_variables in Just(num_variables)) + -> (Type, Type, Vec<(TypeVariable, Type)>) { + infix_type_bindings + } +} + +#[test] +fn instantiate_after_canonicalize_smoke_test() { + let field_element_kind = Kind::numeric(Type::FieldElement); + let x_var = TypeVariable::unbound(TypeVariableId(0), field_element_kind.clone()); + let x_type = Type::TypeVariable(x_var.clone()); + let one = Type::Constant(FieldElement::one(), field_element_kind.clone()); + + let lhs = Type::InfixExpr( + Box::new(x_type.clone()), + BinaryTypeOperator::Addition, + Box::new(one.clone()), + ); + let rhs = + Type::InfixExpr(Box::new(one), BinaryTypeOperator::Addition, Box::new(x_type.clone())); + + // canonicalize + let lhs = lhs.canonicalize(); + let rhs = rhs.canonicalize(); + + // bind vars + let two = Type::Constant(FieldElement::one() + FieldElement::one(), field_element_kind.clone()); + x_var.bind(two); + + // canonicalize (expect constant) + let lhs = lhs.canonicalize(); + let rhs = rhs.canonicalize(); + + // ensure we've canonicalized to constants + assert!(matches!(lhs, Type::Constant(..))); + assert!(matches!(rhs, Type::Constant(..))); + + // ensure result kinds are the same as the original kind + assert_eq!(lhs.kind(), field_element_kind); + assert_eq!(rhs.kind(), field_element_kind); + + // ensure results are the same + assert_eq!(lhs, rhs); +} + +proptest! { + #[test] + // Expect cases that don't resolve to constants, e.g. see + // `arithmetic_generics_checked_cast_indirect_zeros` + #[should_panic(expected = "matches!(infix, Type :: Constant(..))")] + fn instantiate_before_or_after_canonicalize(infix_type_bindings in arbitrary_infix_expr_with_bindings(10)) { + let (infix, typ, bindings) = infix_type_bindings; + + // canonicalize + let infix_canonicalized = infix.canonicalize(); + + // bind vars + for (var, binding) in bindings { + var.bind(binding); + } + + // attempt to canonicalize to a constant + let infix = infix.canonicalize(); + let infix_canonicalized = infix_canonicalized.canonicalize(); + + // ensure we've canonicalized to constants + prop_assert!(matches!(infix, Type::Constant(..))); + prop_assert!(matches!(infix_canonicalized, Type::Constant(..))); + + // ensure result kinds are the same as the original kind + let kind = Kind::numeric(typ); + prop_assert_eq!(infix.kind(), kind.clone()); + prop_assert_eq!(infix_canonicalized.kind(), kind); + + // ensure results are the same + prop_assert_eq!(infix, infix_canonicalized); + } + + #[test] + fn instantiate_before_or_after_canonicalize_checked_cast(infix_type_bindings in arbitrary_infix_expr_with_bindings(10)) { + let (infix, typ, bindings) = infix_type_bindings; + + // wrap in CheckedCast + let infix = Type::CheckedCast { + from: Box::new(infix.clone()), + to: Box::new(infix) + }; + + // canonicalize + let infix_canonicalized = infix.canonicalize(); + + // bind vars + for (var, binding) in bindings { + var.bind(binding); + } + + // attempt to canonicalize to a constant + let infix = infix.canonicalize(); + let infix_canonicalized = infix_canonicalized.canonicalize(); + + // ensure result kinds are the same as the original kind + let kind = Kind::numeric(typ); + prop_assert_eq!(infix.kind(), kind.clone()); + prop_assert_eq!(infix_canonicalized.kind(), kind.clone()); + + // ensure the results are still wrapped in CheckedCast's + match (&infix, &infix_canonicalized) { + (Type::CheckedCast { from, to }, Type::CheckedCast { from: from_canonicalized, to: to_canonicalized }) => { + // ensure from's are the same + prop_assert_eq!(from, from_canonicalized); + + // ensure to's have the same kinds + prop_assert_eq!(to.kind(), kind.clone()); + prop_assert_eq!(to_canonicalized.kind(), kind); + } + _ => { + prop_assert!(false, "expected CheckedCast"); + } + } + } +} From b8654f700b218cc09c5381af65df11ead9ffcdaf Mon Sep 17 00:00:00 2001 From: guipublic <47281315+guipublic@users.noreply.github.com> Date: Wed, 6 Nov 2024 21:15:59 +0100 Subject: [PATCH 2/2] fix: Discard optimisation that would change execution ordering or that is related to call outputs (#6461) Co-authored-by: Tom French <15848336+TomAFrench@users.noreply.github.com> --- .../compiler/optimizers/merge_expressions.rs | 194 ++++++++++++++++-- .../regression_6451/Nargo.toml | 5 + .../regression_6451/Prover.toml | 1 + .../regression_6451/src/main.nr | 23 +++ 4 files changed, 205 insertions(+), 18 deletions(-) create mode 100644 test_programs/execution_success/regression_6451/Nargo.toml create mode 100644 test_programs/execution_success/regression_6451/Prover.toml create mode 100644 test_programs/execution_success/regression_6451/src/main.nr diff --git a/acvm-repo/acvm/src/compiler/optimizers/merge_expressions.rs b/acvm-repo/acvm/src/compiler/optimizers/merge_expressions.rs index 7c0be0fc3fe..c3c80bec2ae 100644 --- a/acvm-repo/acvm/src/compiler/optimizers/merge_expressions.rs +++ b/acvm-repo/acvm/src/compiler/optimizers/merge_expressions.rs @@ -1,7 +1,12 @@ use std::collections::{BTreeMap, BTreeSet, HashMap}; use acir::{ - circuit::{brillig::BrilligInputs, directives::Directive, opcodes::BlockId, Circuit, Opcode}, + circuit::{ + brillig::{BrilligInputs, BrilligOutputs}, + directives::Directive, + opcodes::BlockId, + Circuit, Opcode, + }, native_types::{Expression, Witness}, AcirField, }; @@ -72,23 +77,31 @@ impl MergeExpressionsOptimizer { if let (Opcode::AssertZero(expr_define), Opcode::AssertZero(expr_use)) = (opcode.clone(), second_gate) { - if let Some(expr) = Self::merge(&expr_use, &expr_define, w) { - // sanity check - assert!(i < b); - modified_gates.insert(b, Opcode::AssertZero(expr)); - to_keep = false; - // Update the 'used_witness' map to account for the merge. - for w2 in CircuitSimulator::expr_wit(&expr_define) { - if !circuit_inputs.contains(&w2) { - let mut v = used_witness[&w2].clone(); - v.insert(b); - v.remove(&i); - used_witness.insert(w2, v); + // We cannot merge an expression into an earlier opcode, because this + // would break the 'execution ordering' of the opcodes + // This case can happen because a previous merge would change an opcode + // and eliminate a witness from it, giving new opportunities for this + // witness to be used in only two expressions + // TODO: the missed optimization for the i>b case can be handled by + // - doing this pass again until there is no change, or + // - merging 'b' into 'i' instead + if i < b { + if let Some(expr) = Self::merge(&expr_use, &expr_define, w) { + modified_gates.insert(b, Opcode::AssertZero(expr)); + to_keep = false; + // Update the 'used_witness' map to account for the merge. + for w2 in CircuitSimulator::expr_wit(&expr_define) { + if !circuit_inputs.contains(&w2) { + let mut v = used_witness[&w2].clone(); + v.insert(b); + v.remove(&i); + used_witness.insert(w2, v); + } } + // We need to stop here and continue with the next opcode + // because the merge invalidates the current opcode. + break; } - // We need to stop here and continue with the next opcode - // because the merge invalidate the current opcode - break; } } } @@ -125,6 +138,19 @@ impl MergeExpressionsOptimizer { result } + fn brillig_output_wit(&self, output: &BrilligOutputs) -> BTreeSet { + let mut result = BTreeSet::new(); + match output { + BrilligOutputs::Simple(witness) => { + result.insert(*witness); + } + BrilligOutputs::Array(witnesses) => { + result.extend(witnesses); + } + } + result + } + // Returns the input witnesses used by the opcode fn witness_inputs(&self, opcode: &Opcode) -> BTreeSet { let mut witnesses = BTreeSet::new(); @@ -146,16 +172,22 @@ impl MergeExpressionsOptimizer { Opcode::MemoryInit { block_id: _, init, block_type: _ } => { init.iter().cloned().collect() } - Opcode::BrilligCall { inputs, .. } => { + Opcode::BrilligCall { inputs, outputs, .. } => { for i in inputs { witnesses.extend(self.brillig_input_wit(i)); } + for i in outputs { + witnesses.extend(self.brillig_output_wit(i)); + } witnesses } - Opcode::Call { id: _, inputs, outputs: _, predicate } => { + Opcode::Call { id: _, inputs, outputs, predicate } => { for i in inputs { witnesses.insert(*i); } + for i in outputs { + witnesses.insert(*i); + } if let Some(p) = predicate { witnesses.extend(CircuitSimulator::expr_wit(p)); } @@ -195,3 +227,129 @@ impl MergeExpressionsOptimizer { None } } + +#[cfg(test)] +mod tests { + use crate::compiler::{optimizers::MergeExpressionsOptimizer, CircuitSimulator}; + use acir::{ + acir_field::AcirField, + circuit::{ + brillig::{BrilligFunctionId, BrilligOutputs}, + opcodes::FunctionInput, + Circuit, ExpressionWidth, Opcode, PublicInputs, + }, + native_types::{Expression, Witness}, + FieldElement, + }; + use std::collections::BTreeSet; + + fn check_circuit(circuit: Circuit) { + assert!(CircuitSimulator::default().check_circuit(&circuit)); + let mut merge_optimizer = MergeExpressionsOptimizer::new(); + let acir_opcode_positions = vec![0; 20]; + let (opcodes, _) = + merge_optimizer.eliminate_intermediate_variable(&circuit, acir_opcode_positions); + let mut optimized_circuit = circuit; + optimized_circuit.opcodes = opcodes; + // check that the circuit is still valid after optimization + assert!(CircuitSimulator::default().check_circuit(&optimized_circuit)); + } + + #[test] + fn does_not_eliminate_witnesses_returned_from_brillig() { + let opcodes = vec![ + Opcode::BrilligCall { + id: BrilligFunctionId::default(), + inputs: Vec::new(), + outputs: vec![BrilligOutputs::Simple(Witness(1))], + predicate: None, + }, + Opcode::AssertZero(Expression { + mul_terms: Vec::new(), + linear_combinations: vec![ + (FieldElement::from(2_u128), Witness(0)), + (FieldElement::from(3_u128), Witness(1)), + (FieldElement::from(1_u128), Witness(2)), + ], + q_c: FieldElement::one(), + }), + Opcode::AssertZero(Expression { + mul_terms: Vec::new(), + linear_combinations: vec![ + (FieldElement::from(2_u128), Witness(0)), + (FieldElement::from(2_u128), Witness(1)), + (FieldElement::from(1_u128), Witness(5)), + ], + q_c: FieldElement::one(), + }), + ]; + + let mut private_parameters = BTreeSet::new(); + private_parameters.insert(Witness(0)); + + let circuit = Circuit { + current_witness_index: 1, + expression_width: ExpressionWidth::Bounded { width: 4 }, + opcodes, + private_parameters, + public_parameters: PublicInputs::default(), + return_values: PublicInputs::default(), + assert_messages: Default::default(), + recursive: false, + }; + check_circuit(circuit); + } + + #[test] + fn does_not_attempt_to_merge_into_previous_opcodes() { + let opcodes = vec![ + Opcode::AssertZero(Expression { + mul_terms: vec![(FieldElement::one(), Witness(0), Witness(0))], + linear_combinations: vec![(-FieldElement::one(), Witness(4))], + q_c: FieldElement::zero(), + }), + Opcode::AssertZero(Expression { + mul_terms: vec![(FieldElement::one(), Witness(0), Witness(1))], + linear_combinations: vec![(FieldElement::one(), Witness(5))], + q_c: FieldElement::zero(), + }), + Opcode::AssertZero(Expression { + mul_terms: Vec::new(), + linear_combinations: vec![ + (-FieldElement::one(), Witness(2)), + (FieldElement::one(), Witness(4)), + (FieldElement::one(), Witness(5)), + ], + q_c: FieldElement::zero(), + }), + Opcode::AssertZero(Expression { + mul_terms: Vec::new(), + linear_combinations: vec![ + (FieldElement::one(), Witness(2)), + (-FieldElement::one(), Witness(3)), + (FieldElement::one(), Witness(4)), + (FieldElement::one(), Witness(5)), + ], + q_c: FieldElement::zero(), + }), + Opcode::BlackBoxFuncCall(acir::circuit::opcodes::BlackBoxFuncCall::RANGE { + input: FunctionInput::witness(Witness(3), 32), + }), + ]; + + let mut private_parameters = BTreeSet::new(); + private_parameters.insert(Witness(0)); + private_parameters.insert(Witness(1)); + let circuit = Circuit { + current_witness_index: 5, + expression_width: ExpressionWidth::Bounded { width: 4 }, + opcodes, + private_parameters, + public_parameters: PublicInputs::default(), + return_values: PublicInputs::default(), + assert_messages: Default::default(), + recursive: false, + }; + check_circuit(circuit); + } +} diff --git a/test_programs/execution_success/regression_6451/Nargo.toml b/test_programs/execution_success/regression_6451/Nargo.toml new file mode 100644 index 00000000000..e109747c7f6 --- /dev/null +++ b/test_programs/execution_success/regression_6451/Nargo.toml @@ -0,0 +1,5 @@ +[package] +name = "regression_6451" +type = "bin" +authors = [""] +[dependencies] diff --git a/test_programs/execution_success/regression_6451/Prover.toml b/test_programs/execution_success/regression_6451/Prover.toml new file mode 100644 index 00000000000..74213d097ec --- /dev/null +++ b/test_programs/execution_success/regression_6451/Prover.toml @@ -0,0 +1 @@ +x = 0 \ No newline at end of file diff --git a/test_programs/execution_success/regression_6451/src/main.nr b/test_programs/execution_success/regression_6451/src/main.nr new file mode 100644 index 00000000000..fbee6956dfa --- /dev/null +++ b/test_programs/execution_success/regression_6451/src/main.nr @@ -0,0 +1,23 @@ +fn main(x: Field) { + // Regression test for #6451 + let y = unsafe { empty(x) }; + let mut value = 0; + let term1 = x * x - x * y; + std::as_witness(term1); + value += term1; + let term2 = x * x - y * x; + value += term2; + value.assert_max_bit_size::<1>(); + + // Regression test for Aztec Packages issue #6451 + let y = unsafe { empty(x + 1) }; + let z = y + x + 1; + let z1 = z + y; + assert(z + z1 != 3); + let w = y + 2 * x + 3; + assert(w + z1 != z); +} + +unconstrained fn empty(_: Field) -> Field { + 0 +}