From 937629729a2912987b5de2439a8306c479f8005a Mon Sep 17 00:00:00 2001 From: unknown Date: Wed, 29 Oct 2025 15:06:37 +0530 Subject: [PATCH 1/2] Allow constants to contain optional types. --- .../src/interpreter_value/core_function.rs | 34 +++++++++++++++-- compiler/ast/src/statement/definition/mod.rs | 2 +- compiler/passes/src/const_propagation/ast.rs | 29 ++++++++++++--- compiler/passes/src/option_lowering/mod.rs | 12 +++++- compiler/passes/src/storage_lowering/ast.rs | 11 ++++-- compiler/passes/src/type_checking/ast.rs | 6 --- .../compiler/option/bad_types_fail.out | 5 --- tests/expectations/compiler/option/consts.out | 37 +++++++++++++++++++ .../compiler/option/consts_fail.out | 35 ------------------ .../option/{consts_fail.leo => consts.leo} | 0 10 files changed, 109 insertions(+), 62 deletions(-) create mode 100644 tests/expectations/compiler/option/consts.out delete mode 100644 tests/expectations/compiler/option/consts_fail.out rename tests/tests/compiler/option/{consts_fail.leo => consts.leo} (100%) diff --git a/compiler/ast/src/interpreter_value/core_function.rs b/compiler/ast/src/interpreter_value/core_function.rs index b362327d5c2..835854f53ac 100644 --- a/compiler/ast/src/interpreter_value/core_function.rs +++ b/compiler/ast/src/interpreter_value/core_function.rs @@ -28,6 +28,7 @@ use crate::{ CoreFunction, Expression, Type, + halt2, interpreter_value::{ExpectTc, Value}, tc_fail2, }; @@ -300,12 +301,22 @@ pub fn evaluate_core_function( helper.mapping_get(program, name, &key).is_some().into() } CoreFunction::OptionalUnwrap => { - // TODO - return Ok(None); + let (is_some, val) = unpack_option_struct(helper.pop_value().expect_tc(span)?.contents, span)?; + + if is_some { + Value { id: None, contents: ValueVariants::Svm(val.into()) } + } else { + halt2!(span, "called unwrap on a none value") + } } CoreFunction::OptionalUnwrapOr => { - // TODO - return Ok(None); + let (is_some, val) = unpack_option_struct(helper.pop_value().expect_tc(span)?.contents, span)?; + + let ValueVariants::Svm(SvmValue::Plaintext(or)) = helper.pop_value().expect_tc(span)?.contents else { + tc_fail2!() + }; + + Value { id: None, contents: ValueVariants::Svm(if is_some { val.into() } else { or.into() }) } } CoreFunction::VectorPush | CoreFunction::VectorLen @@ -333,3 +344,18 @@ pub fn evaluate_core_function( Ok(Some(value)) } + +fn unpack_option_struct(v: ValueVariants, span: Span) -> Result<(bool, SvmPlaintext)> { + let ValueVariants::Svm(SvmValue::Plaintext(Plaintext::Struct(members, _))) = v else { tc_fail2!() }; + + let get = |name: &str| { + let id = Symbol::intern(name).to_string().parse::().expect("type checker failure"); + members.get(&id).expect_tc(span) + }; + + let SvmPlaintext::Literal(SvmLiteral::Boolean(is_some), _) = get("is_some")? else { tc_fail2!() }; + + let val = get("val")?.clone(); + + Ok((**is_some, val)) +} diff --git a/compiler/ast/src/statement/definition/mod.rs b/compiler/ast/src/statement/definition/mod.rs index 4e6d7676f55..35f6c6047d2 100644 --- a/compiler/ast/src/statement/definition/mod.rs +++ b/compiler/ast/src/statement/definition/mod.rs @@ -22,7 +22,7 @@ use itertools::Itertools as _; use serde::{Deserialize, Serialize}; use std::fmt; -/// A `let` or `const` declaration statement. +/// A `let` declaration statement. #[derive(Clone, PartialEq, Eq, Serialize, Deserialize, Debug)] pub struct DefinitionStatement { /// The bindings / variable names to declare. diff --git a/compiler/passes/src/const_propagation/ast.rs b/compiler/passes/src/const_propagation/ast.rs index cc3881b0c7b..888dd9fc867 100644 --- a/compiler/passes/src/const_propagation/ast.rs +++ b/compiler/passes/src/const_propagation/ast.rs @@ -405,11 +405,9 @@ impl AstReconstructor for ConstPropagationVisitor<'_> { ) -> (Expression, Self::AdditionalOutput) { let type_info = self.state.type_table.get(&input.id()); - // If this is an optional, then unwrap it first. - let type_info = type_info.as_ref().map(|ty| match ty { - Type::Optional(opt) => *opt.inner.clone(), - _ => ty.clone(), - }); + if let Some(Type::Optional(_)) = type_info { + return (input.into(), None); + } if let Ok(value) = interpreter_value::literal_to_value(&input, &type_info) { // If we know the type of an unsuffixed literal, might as well change it to a suffixed literal. This way, we @@ -536,7 +534,26 @@ impl AstReconstructor for ConstPropagationVisitor<'_> { } fn reconstruct_const(&mut self, mut input: ConstDeclaration) -> (Statement, Self::AdditionalOutput) { - if matches!(input.type_, Type::Optional(_)) { + // If there is any optional in the type definition, leave it for now. + fn recursive_optional(type_: &Type, slf: &ConstPropagationVisitor) -> bool { + match type_ { + Type::Array(array_type) => recursive_optional(array_type.element_type(), slf), + Type::Composite(composite_type) => { + if let Some(cmp) = + slf.state.symbol_table.lookup_struct(composite_type.path.absolute_path().as_ref()) + { + cmp.members.iter().any(|mbr| recursive_optional(&mbr.type_, slf)) + } else { + false + } + } + Type::Optional(_) => true, + Type::Tuple(tuple_type) => tuple_type.elements.iter().any(|element| recursive_optional(element, slf)), + _ => false, + } + } + + if recursive_optional(&input.type_, self) { return (input.into(), None); } diff --git a/compiler/passes/src/option_lowering/mod.rs b/compiler/passes/src/option_lowering/mod.rs index f6f652d6776..4e5b534240c 100644 --- a/compiler/passes/src/option_lowering/mod.rs +++ b/compiler/passes/src/option_lowering/mod.rs @@ -57,7 +57,15 @@ //! After this pass, no `T?` types remain in the program: all optional values are represented explicitly //! as structs with `is_some` and `val` fields. -use crate::{Pass, PathResolution, SymbolTable, SymbolTableCreation, TypeChecking, TypeCheckingInput}; +use crate::{ + ConstPropagation, + Pass, + PathResolution, + SymbolTable, + SymbolTableCreation, + TypeChecking, + TypeCheckingInput, +}; use leo_ast::{ArrayType, CompositeType, ProgramReconstructor as _, Type}; use leo_errors::Result; @@ -101,6 +109,8 @@ impl Pass for OptionLowering { PathResolution::do_pass((), state)?; SymbolTableCreation::do_pass((), state)?; TypeChecking::do_pass(input.clone(), state)?; + // Now there are no more optionals, we can now evaluate the core unwrap functions of const optionals in the interpreter. + ConstPropagation::do_pass((), state)?; Ok(()) } diff --git a/compiler/passes/src/storage_lowering/ast.rs b/compiler/passes/src/storage_lowering/ast.rs index 0b6a3315637..2ec331248bc 100644 --- a/compiler/passes/src/storage_lowering/ast.rs +++ b/compiler/passes/src/storage_lowering/ast.rs @@ -15,6 +15,7 @@ // along with the Leo library. If not, see . use super::StorageLoweringVisitor; +use crate::VariableType; use leo_ast::*; use leo_span::{Span, Symbol, sym}; @@ -654,10 +655,12 @@ impl leo_ast::AstReconstructor for StorageLoweringVisitor<'_> { fn reconstruct_path(&mut self, input: Path, _additional: &()) -> (Expression, Self::AdditionalOutput) { // Check if this path corresponds to a global symbol. - let Some(var) = self.state.symbol_table.lookup_global(&Location::new(self.program, input.absolute_path())) - else { - // Nothing to do - return (input.into(), vec![]); + let var = match self.state.symbol_table.lookup_global(&Location::new(self.program, input.absolute_path())) { + Some(var) if var.declaration == VariableType::Storage => var, + _ => { + // Nothing to do + return (input.into(), vec![]); + } }; match &var.type_ { diff --git a/compiler/passes/src/type_checking/ast.rs b/compiler/passes/src/type_checking/ast.rs index 6b2fd8db52e..04713e97355 100644 --- a/compiler/passes/src/type_checking/ast.rs +++ b/compiler/passes/src/type_checking/ast.rs @@ -1971,12 +1971,6 @@ impl AstVisitor for TypeCheckingVisitor<'_> { fn visit_const(&mut self, input: &ConstDeclaration) { self.visit_type(&input.type_); - // For now, consts that contain optional types are not supported. - // TODO: remove this restriction by supporting const evaluation of optionals including `None`. - if self.contains_optional_type(&input.type_) { - self.emit_err(TypeCheckerError::const_cannot_be_optional(input.span)); - } - // Check that the type of the definition is not a unit type, singleton tuple type, or nested tuple type. match &input.type_ { // If the type is an empty tuple, return an error. diff --git a/tests/expectations/compiler/option/bad_types_fail.out b/tests/expectations/compiler/option/bad_types_fail.out index 200df1e889e..40a1c9b0079 100644 --- a/tests/expectations/compiler/option/bad_types_fail.out +++ b/tests/expectations/compiler/option/bad_types_fail.out @@ -1,8 +1,3 @@ -Error [ETYC0372163]: Constants cannot have an optional type or a type that contains an optional - --> compiler-test:16:5 - | - 16 | const BAD_CONST_ARRAY: [u8?; 2] = [1u8, 2i8]; // ERROR - | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ Error [ETYC0372117]: Expected type `u8?` but type `i8` was found. --> compiler-test:16:45 | diff --git a/tests/expectations/compiler/option/consts.out b/tests/expectations/compiler/option/consts.out new file mode 100644 index 00000000000..a1fcd30dc00 --- /dev/null +++ b/tests/expectations/compiler/option/consts.out @@ -0,0 +1,37 @@ +program const_optionals.aleo; + +struct Optional__8hhrPm4c3KB: + is_some as boolean; + val as u8; + +struct MyStruct: + x as Optional__8hhrPm4c3KB; + +function const_single_optional: + assert.eq true true; + output 10u8 as u8.private; + +function const_array_of_optionals: + assert.eq true true; + output 2u8 as u8.private; + +function const_tuple_with_optional: + assert.eq true true; + output 99u8 as u8.private; + +function const_nested_array: + assert.eq true true; + assert.eq true true; + output 7u8 as u8.private; + +function const_struct_with_optional: + assert.eq true true; + output 88u8 as u8.private; + +function const_array_of_structs: + assert.eq true true; + assert.eq true true; + output 33u8 as u8.private; + +constructor: + assert.eq edition 0u16; diff --git a/tests/expectations/compiler/option/consts_fail.out b/tests/expectations/compiler/option/consts_fail.out deleted file mode 100644 index d950964d0b4..00000000000 --- a/tests/expectations/compiler/option/consts_fail.out +++ /dev/null @@ -1,35 +0,0 @@ -Error [ETYC0372163]: Constants cannot have an optional type or a type that contains an optional - --> compiler-test:6:5 - | - 6 | const A: u8? = 10u8; - | ^^^^^^^^^^^^^^^^^^^^ -Error [ETYC0372163]: Constants cannot have an optional type or a type that contains an optional - --> compiler-test:7:5 - | - 7 | const B: [u8?; 3] = [1u8, 2u8, 3u8]; - | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -Error [ETYC0372163]: Constants cannot have an optional type or a type that contains an optional - --> compiler-test:8:5 - | - 8 | const C: (u8, u8?) = (42u8, 99u8); - | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -Error [ETYC0372163]: Constants cannot have an optional type or a type that contains an optional - --> compiler-test:9:5 - | - 9 | const D: [[u8?; 2]; 2] = [[4u8, 5u8], [6u8, 7u8]]; - | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -Error [ETYC0372163]: Constants cannot have an optional type or a type that contains an optional - --> compiler-test:15:5 - | - 15 | const E: MyStruct = MyStruct { x: 88u8 }; - | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -Error [ETYC0372163]: Constants cannot have an optional type or a type that contains an optional - --> compiler-test:16:5 - | - 16 | const F: [MyStruct; 2] = [MyStruct { x: 11u8 }, MyStruct { x: 22u8 }]; - | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -Error [ETYC0372163]: Constants cannot have an optional type or a type that contains an optional - --> compiler-test:33:9 - | - 33 | const D_local: [[u8?; 2]; 2] = [[1u8, 2u8], [3u8, 4u8]]; - | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ diff --git a/tests/tests/compiler/option/consts_fail.leo b/tests/tests/compiler/option/consts.leo similarity index 100% rename from tests/tests/compiler/option/consts_fail.leo rename to tests/tests/compiler/option/consts.leo From e02eda5da30f7e5ea381a66780f67969ffa28c00 Mon Sep 17 00:00:00 2001 From: unknown Date: Wed, 29 Oct 2025 18:12:58 +0530 Subject: [PATCH 2/2] Add the support for option types in the leo interpreter. --- compiler/ast/src/expressions/mod.rs | 75 +++++ compiler/ast/src/interpreter_value/value.rs | 26 +- compiler/ast/src/types/optional.rs | 41 ++- compiler/passes/src/option_lowering/mod.rs | 40 +-- .../passes/src/option_lowering/visitor.rs | 2 +- interpreter/src/cursor.rs | 257 +++++++++++++++--- interpreter/src/interpreter.rs | 13 +- 7 files changed, 369 insertions(+), 85 deletions(-) diff --git a/compiler/ast/src/expressions/mod.rs b/compiler/ast/src/expressions/mod.rs index 3c15365ad3d..e2736438d88 100644 --- a/compiler/ast/src/expressions/mod.rs +++ b/compiler/ast/src/expressions/mod.rs @@ -129,6 +129,81 @@ impl Default for Expression { } } +#[allow(clippy::type_complexity)] +pub fn zero_value_expression( + ty: &Type, + span: Span, + node_builder: &NodeBuilder, + struct_lookup: &dyn Fn(&[Symbol]) -> Vec<(Symbol, Type)>, +) -> Option { + let id = node_builder.next_id(); + + match ty { + // Numeric types + Type::Integer(IntegerType::I8) => Some(Literal::integer(IntegerType::I8, "0".to_string(), span, id).into()), + Type::Integer(IntegerType::I16) => Some(Literal::integer(IntegerType::I16, "0".to_string(), span, id).into()), + Type::Integer(IntegerType::I32) => Some(Literal::integer(IntegerType::I32, "0".to_string(), span, id).into()), + Type::Integer(IntegerType::I64) => Some(Literal::integer(IntegerType::I64, "0".to_string(), span, id).into()), + Type::Integer(IntegerType::I128) => Some(Literal::integer(IntegerType::I128, "0".to_string(), span, id).into()), + Type::Integer(IntegerType::U8) => Some(Literal::integer(IntegerType::U8, "0".to_string(), span, id).into()), + Type::Integer(IntegerType::U16) => Some(Literal::integer(IntegerType::U16, "0".to_string(), span, id).into()), + Type::Integer(IntegerType::U32) => Some(Literal::integer(IntegerType::U32, "0".to_string(), span, id).into()), + Type::Integer(IntegerType::U64) => Some(Literal::integer(IntegerType::U64, "0".to_string(), span, id).into()), + Type::Integer(IntegerType::U128) => Some(Literal::integer(IntegerType::U128, "0".to_string(), span, id).into()), + + // Boolean + Type::Boolean => Some(Literal::boolean(false, span, id).into()), + + // Field, Group, Scalar + Type::Field => Some(Literal::field("0".to_string(), span, id).into()), + Type::Group => Some(Literal::group("0".to_string(), span, id).into()), + Type::Scalar => Some(Literal::scalar("0".to_string(), span, id).into()), + + // Structs (composite types) + Type::Composite(composite_type) => { + let path = &composite_type.path; + let members = struct_lookup(&path.absolute_path()); + + let struct_members = members + .into_iter() + .map(|(symbol, member_type)| { + let member_id = node_builder.next_id(); + let zero_expr = zero_value_expression(&member_type, span, node_builder, struct_lookup)?; + + Some(StructVariableInitializer { + span, + id: member_id, + identifier: crate::Identifier::new(symbol, node_builder.next_id()), + expression: Some(zero_expr), + }) + }) + .collect::>>()?; + + Some(Expression::Struct(StructExpression { + span, + id, + path: path.clone(), + const_arguments: composite_type.const_arguments.clone(), + members: struct_members, + })) + } + + // Arrays + Type::Array(array_type) => { + let element_ty = &array_type.element_type; + + let element_expr = zero_value_expression(element_ty, span, node_builder, struct_lookup)?; + + Some(Expression::Repeat( + RepeatExpression { span, id, expr: element_expr, count: *array_type.length.clone() }.into(), + )) + } + + // Other types are not expected or supported just yet + _ => None, + } +} + impl Node for Expression { fn span(&self) -> Span { use Expression::*; diff --git a/compiler/ast/src/interpreter_value/value.rs b/compiler/ast/src/interpreter_value/value.rs index 6d5103ecfd7..979e363fae2 100644 --- a/compiler/ast/src/interpreter_value/value.rs +++ b/compiler/ast/src/interpreter_value/value.rs @@ -77,13 +77,13 @@ pub struct StructContents { #[derive(Clone, Debug, Default, Eq, PartialEq, Hash)] pub struct Value { pub id: Option, - pub(crate) contents: ValueVariants, + pub contents: ValueVariants, } #[derive(Clone, Default, Debug, Eq, PartialEq)] // SnarkVM's Value is large, but that's okay. #[allow(clippy::large_enum_variant)] -pub(crate) enum ValueVariants { +pub enum ValueVariants { #[default] Unit, Svm(SvmValue), @@ -93,7 +93,7 @@ pub(crate) enum ValueVariants { String(String), } -#[derive(Clone, Debug, Eq, PartialEq, Hash)] +#[derive(Clone, Debug, Eq, PartialEq)] pub enum AsyncExecution { AsyncFunctionCall { function: Location, @@ -102,10 +102,28 @@ pub enum AsyncExecution { AsyncBlock { containing_function: Location, // The function that contains the async block. block: crate::NodeID, - names: BTreeMap, Value>, // Use a `BTreeMap` here because `HashMap` does not implement `Hash`. + names: BTreeMap, (Value, Option)>, // Use a `BTreeMap` here because `HashMap` does not implement `Hash`. }, } +impl Hash for AsyncExecution { + fn hash(&self, state: &mut H) { + match self { + Self::AsyncBlock { containing_function, block, names } => { + 0u8.hash(state); + containing_function.hash(state); + block.hash(state); + names.iter().map(|(k, v)| (k.clone(), v.0.clone())).collect::>().hash(state); + } + Self::AsyncFunctionCall { function, arguments } => { + 1u8.hash(state); + function.hash(state); + arguments.hash(state); + } + } + } +} + impl fmt::Display for AsyncExecution { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!(f, " async call to ")?; diff --git a/compiler/ast/src/types/optional.rs b/compiler/ast/src/types/optional.rs index 5158564a890..d446ad94224 100644 --- a/compiler/ast/src/types/optional.rs +++ b/compiler/ast/src/types/optional.rs @@ -14,8 +14,10 @@ // You should have received a copy of the GNU General Public License // along with the Leo library. If not, see . -use crate::Type; +use crate::{ArrayType, CompositeType, Type}; +use itertools::Itertools; +use leo_span::Symbol; use serde::{Deserialize, Serialize}; use std::fmt; @@ -25,6 +27,43 @@ pub struct OptionalType { pub inner: Box, } +pub fn make_optional_struct_symbol(ty: &Type) -> Symbol { + // Step 1: Extract a usable type name + fn display_type(ty: &Type) -> String { + match ty { + Type::Address + | Type::Field + | Type::Group + | Type::Scalar + | Type::Signature + | Type::Boolean + | Type::Integer(..) => format!("{ty}"), + Type::Array(ArrayType { element_type, length }) => { + format!("[{}; {length}]", display_type(element_type)) + } + Type::Composite(CompositeType { path, .. }) => { + format!("::{}", path.absolute_path().iter().format("::")) + } + + Type::Tuple(_) + | Type::Optional(_) + | Type::Mapping(_) + | Type::Numeric + | Type::Identifier(_) + | Type::Future(_) + | Type::Vector(_) + | Type::String + | Type::Err + | Type::Unit => { + panic!("unexpected inner type in optional struct name") + } + } + } + + // Step 3: Build symbol that ends with `?`. + Symbol::intern(&format!("\"{}?\"", display_type(ty))) +} + impl fmt::Display for OptionalType { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!(f, "{}?", self.inner) diff --git a/compiler/passes/src/option_lowering/mod.rs b/compiler/passes/src/option_lowering/mod.rs index 4e5b534240c..1a4c40c4a87 100644 --- a/compiler/passes/src/option_lowering/mod.rs +++ b/compiler/passes/src/option_lowering/mod.rs @@ -67,12 +67,11 @@ use crate::{ TypeCheckingInput, }; -use leo_ast::{ArrayType, CompositeType, ProgramReconstructor as _, Type}; +use leo_ast::ProgramReconstructor as _; use leo_errors::Result; use leo_span::Symbol; use indexmap::IndexMap; -use itertools::Itertools; mod ast; @@ -115,40 +114,3 @@ impl Pass for OptionLowering { Ok(()) } } - -pub fn make_optional_struct_symbol(ty: &Type) -> Symbol { - // Step 1: Extract a usable type name - fn display_type(ty: &Type) -> String { - match ty { - Type::Address - | Type::Field - | Type::Group - | Type::Scalar - | Type::Signature - | Type::Boolean - | Type::Integer(..) => format!("{ty}"), - Type::Array(ArrayType { element_type, length }) => { - format!("[{}; {length}]", display_type(element_type)) - } - Type::Composite(CompositeType { path, .. }) => { - format!("::{}", path.absolute_path().iter().format("::")) - } - - Type::Tuple(_) - | Type::Optional(_) - | Type::Mapping(_) - | Type::Numeric - | Type::Identifier(_) - | Type::Future(_) - | Type::Vector(_) - | Type::String - | Type::Err - | Type::Unit => { - panic!("unexpected inner type in optional struct name") - } - } - } - - // Step 3: Build symbol that ends with `?`. - Symbol::intern(&format!("\"{}?\"", display_type(ty))) -} diff --git a/compiler/passes/src/option_lowering/visitor.rs b/compiler/passes/src/option_lowering/visitor.rs index 0888f592286..a8ebc14c610 100644 --- a/compiler/passes/src/option_lowering/visitor.rs +++ b/compiler/passes/src/option_lowering/visitor.rs @@ -177,7 +177,7 @@ impl OptionLoweringVisitor<'_> { /// If the struct for this type already exists, it’s reused; otherwise, a new one is created. /// Returns the `Symbol` for the struct name. pub fn insert_optional_wrapper_struct(&mut self, ty: &Type) -> Symbol { - let struct_name = crate::make_optional_struct_symbol(ty); + let struct_name = make_optional_struct_symbol(ty); self.new_structs.entry(struct_name).or_insert_with(|| Composite { identifier: Identifier::new(struct_name, self.state.node_builder.next_id()), diff --git a/interpreter/src/cursor.rs b/interpreter/src/cursor.rs index 18d93a0a5b2..889ded985b4 100644 --- a/interpreter/src/cursor.rs +++ b/interpreter/src/cursor.rs @@ -38,14 +38,21 @@ use leo_ast::{ AsyncExecution, CoreFunctionHelper, Value, + ValueVariants, evaluate_binary, evaluate_core_function, evaluate_unary, literal_to_value, }, + make_optional_struct_symbol, + zero_value_expression, }; use leo_errors::{InterpreterHalt, Result}; -use leo_span::{Span, Symbol, sym}; +use leo_span::{ + Span, + Symbol, + sym::{self}, +}; use snarkvm::prelude::{ Address, @@ -75,7 +82,7 @@ pub struct FunctionContext { path: Vec, program: Symbol, pub caller: Value, - names: HashMap, Value>, + names: HashMap, (Value, Option)>, accumulated_futures: Vec, is_async: bool, } @@ -98,7 +105,7 @@ impl ContextStack { program: Symbol, caller: Value, is_async: bool, - names: HashMap, Value>, // a map of variable names that are already known + names: HashMap, (Value, Option)>, // a map of variable names that are already known ) { if self.current_len == self.contexts.len() { self.contexts.push(FunctionContext { @@ -138,9 +145,9 @@ impl ContextStack { mem::take(&mut self.contexts[self.current_len - 1].accumulated_futures) } - fn set(&mut self, path: &[Symbol], value: Value) { + fn set(&mut self, path: &[Symbol], value: Value, type_: Option) { assert!(self.current_len > 0); - self.last_mut().unwrap().names.insert(path.to_vec(), value); + self.last_mut().unwrap().names.insert(path.to_vec(), (value, type_)); } pub fn add_future(&mut self, future: Vec) { @@ -207,7 +214,7 @@ pub enum Element { DelayedAsyncBlock { program: Symbol, block: NodeID, - names: HashMap, Value>, + names: HashMap, (Value, Option)>, }, } @@ -257,9 +264,9 @@ pub struct Cursor { pub async_blocks: HashMap, /// Consts are stored here. - pub globals: HashMap, + pub globals: HashMap, - pub user_values: HashMap, Value>, + pub user_values: HashMap, (Value, Option)>, pub mappings: HashMap>, @@ -446,10 +453,10 @@ impl Cursor { let full_name = self.to_absolute_path(&path.as_symbols()); - let mut leo_value = self.lookup(&full_name).unwrap_or(Value::make_unit()); + let mut leo_value = self.lookup(&full_name).unwrap_or((Value::make_unit(), None)); // Do an ad hoc evaluation of the lhs of the assignment to determine its type. - let mut temp_value = leo_value.clone(); + let (mut temp_value, mut type_) = leo_value.clone(); let mut indices_iter = indices.iter(); for place in places.iter().rev() { @@ -457,12 +464,25 @@ impl Cursor { Expression::ArrayAccess(_access) => { let next_index = indices_iter.next().unwrap(); temp_value = temp_value.array_index(next_index.as_u32().unwrap() as usize).unwrap(); + type_ = type_ + .and_then(|t| if let Type::Array(at) = t { Some(at.element_type().clone()) } else { None }); } Expression::TupleAccess(access) => { temp_value = temp_value.tuple_index(access.index.value()).unwrap(); + type_ = type_.and_then(|t| { + if let Type::Tuple(tt) = t { Some(tt.elements()[access.index.value()].clone()) } else { None } + }); } Expression::MemberAccess(access) => { temp_value = temp_value.member_access(access.name.name).unwrap(); + type_ = type_.and_then(|t| { + if let Type::Composite(ct) = t { + let members = self.structs.get(&self.to_absolute_path(ct.path.as_symbols().as_slice())); + members.and_then(|m| m.get(&access.name.name).cloned()) + } else { + None + } + }); } Expression::Path(_path) => // temp_value is already set to leo_value @@ -474,8 +494,8 @@ impl Cursor { let ty = temp_value.get_numeric_type(); let value = value.resolve_if_unsuffixed(&ty, place.span())?; - Self::set_place(value, &mut leo_value, &mut places.into_iter().rev(), &mut indices.into_iter())?; - self.set_variable(&full_name, leo_value); + Self::set_place(value, &mut leo_value.0, &mut places.into_iter().rev(), &mut indices.into_iter())?; + self.set_variable(&full_name, leo_value.0, type_); Ok(()) } @@ -515,12 +535,15 @@ impl Cursor { } } - fn lookup(&self, name: &[Symbol]) -> Option { + fn lookup(&self, name: &[Symbol]) -> Option<(Value, Option)> { if let Some(context) = self.contexts.last() { - let option_value = - context.names.get(name).or_else(|| self.globals.get(&Location::new(context.program, name.to_vec()))); + let option_value = context.names.get(name).cloned().or_else(|| { + self.globals + .get(&Location::new(context.program, name.to_vec())) + .map(|(v, t)| (v.clone(), Some(t.clone()))) + }); if option_value.is_some() { - return option_value.cloned(); + return option_value; } }; @@ -547,11 +570,11 @@ impl Cursor { self.functions.get(&Location::new(program, name.to_vec())).cloned() } - fn set_variable(&mut self, path: &[Symbol], value: Value) { + fn set_variable(&mut self, path: &[Symbol], value: Value, type_: Option) { if self.contexts.len() > 0 { - self.contexts.set(path, value); + self.contexts.set(path, value, type_); } else { - self.user_values.insert(path.to_vec(), value); + self.user_values.insert(path.to_vec(), (value, type_)); } } @@ -694,24 +717,78 @@ impl Cursor { } Statement::Assign(assign) if step == 0 => { // Step 0: push the expression frame and any array index expression frames. - push(&assign.value, &None); - let mut place = &assign.place; - loop { + let place = &assign.place; + let mut to_push = Vec::new(); + + // Reverse extract the expected type of the assignee. + fn recurse<'a, 'b: 'a>( + lookup: &'a impl Fn(&[Symbol]) -> Option, + mut place: &'b Expression, + to_push: &'a mut Vec<&'b Expression>, + ) -> Option { match place { leo_ast::Expression::ArrayAccess(access) => { - push(&access.index, &None); place = &access.array; + to_push.push(&access.index); + let type_ = recurse(lookup, place, to_push); + type_.map(|t| { + let Type::Array(at) = t else { panic!("Can't happen") }; + at.element_type().clone() + }) } - leo_ast::Expression::Path(..) => break, + leo_ast::Expression::Path(p) => lookup(&p.as_symbols()), leo_ast::Expression::MemberAccess(access) => { place = &access.inner; + let type_ = recurse(lookup, place, to_push); + type_.and_then(|t| { + let Type::Composite(ct) = t else { panic!("Can't happen") }; + lookup(&ct.path.as_symbols()) + }) } leo_ast::Expression::TupleAccess(access) => { place = &access.tuple; + let type_ = recurse(lookup, place, to_push); + type_.map(|t| { + let Type::Tuple(tt) = t else { panic!("Can't happen") }; + tt.elements()[access.index.value()].clone() + }) } _ => panic!("Can't happen"), } } + + let abs = |name: &[Symbol]| { + if let Some(context) = self.contexts.last() { + let mut full_name = context.path.clone(); + full_name.pop(); // This pops the function name, keeping only the module prefix + full_name.extend(name); + full_name + } else { + name.to_vec() + } + }; + + let lookup = |name: &[Symbol]| { + if let Some(context) = self.contexts.last() { + let option_value = context.names.get(name).cloned().or_else(|| { + self.globals + .get(&Location::new(context.program, name.to_vec())) + .map(|(v, t)| (v.clone(), Some(t.clone()))) + }); + if let Some(optv) = option_value { + return optv.1; + } + }; + + self.user_values.get(name).cloned().unwrap().1 + }; + + let type_ = recurse(&|sym| lookup(abs(sym).as_slice()), place, &mut to_push); + push(&assign.value, &type_); + for tp in to_push.into_iter() { + push(tp, &None) + } + false } Statement::Assign(assign) if step == 1 => { @@ -778,7 +855,7 @@ impl Cursor { } Statement::Const(const_) if step == 1 => { let value = self.pop_value()?; - self.set_variable(&self.to_absolute_path(&[const_.place.name]), value); + self.set_variable(&self.to_absolute_path(&[const_.place.name]), value, Some(const_.type_.clone())); true } Statement::Definition(definition) if step == 0 => { @@ -788,12 +865,24 @@ impl Cursor { Statement::Definition(definition) if step == 1 => { let value = self.pop_value()?; match &definition.place { - DefinitionPlace::Single(id) => self.set_variable(&self.to_absolute_path(&[id.name]), value), + DefinitionPlace::Single(id) => { + self.set_variable(&self.to_absolute_path(&[id.name]), value, definition.type_.clone()) + } DefinitionPlace::Multiple(ids) => { + let maybe_types = definition.type_.clone().map(|t| { + if let Type::Tuple(tt) = t + && tt.length().eq(&ids.len()) + { + tt + } else { + panic!("") + } + }); for (i, id) in ids.iter().enumerate() { self.set_variable( &self.to_absolute_path(&[id.name]), value.tuple_index(i).expect("Place for definition should be a tuple."), + maybe_types.as_ref().map(|tt| tt.elements()[i].clone()), ); } } @@ -822,7 +911,11 @@ impl Cursor { true } else { let new_start = start.inc_wrapping().expect_tc(iteration.span())?; - self.set_variable(&self.to_absolute_path(&[iteration.variable.name]), start); + self.set_variable( + &self.to_absolute_path(&[iteration.variable.name]), + start, + iteration.type_.clone(), + ); self.frames.push(Frame { step: 0, element: Element::Block { block: iteration.block.clone(), function_body: false }, @@ -896,6 +989,8 @@ impl Cursor { }; } + let mut is_some = true; + if let Some(value) = match expression { Expression::ArrayAccess(array) if step == 0 => { push!()(&array.index, &None); @@ -1252,9 +1347,17 @@ impl Cursor { } Expression::Err(_) => todo!(), Expression::Path(path) if step == 0 => { - Some(self.lookup(&self.to_absolute_path(&path.as_symbols())).expect_tc(path.span())?) + Some(self.lookup(&self.to_absolute_path(&path.as_symbols())).expect_tc(path.span())?.0) + } + Expression::Literal(literal) if step == 0 => { + if literal.variant == leo_ast::LiteralVariant::None { + is_some = false; + let Type::Optional(_) = expected_ty.as_ref().expect_tc(literal.span)? else { tc_fail!() }; + Some(self.resolve_none(expected_ty.as_ref().unwrap())?) + } else { + Some(literal_to_value(literal, expected_ty)?) + } } - Expression::Literal(literal) if step == 0 => Some(literal_to_value(literal, expected_ty)?), Expression::Locator(_locator) => todo!(), Expression::Struct(struct_) if step == 0 => { let members = @@ -1337,6 +1440,15 @@ impl Cursor { } { assert_eq!(self.frames.len(), len); self.frames.pop(); + + let value = match (expected_ty, &value.contents, is_some) { + (Some(Type::Optional(otp)), ValueVariants::Svm(snarkvm::prelude::Value::Plaintext(_)), true) => { + let wrapper = self.register_optional(&otp.inner); + self.wrap_optional_value(value, wrapper, true) + } + _ => value, + }; + self.values.push(value); Ok(true) } else { @@ -1345,6 +1457,78 @@ impl Cursor { } } + // Try to convert the none value to a reasonable typed value that should be represented + // as the in memory value for the snarkVM. + fn resolve_none(&mut self, expected_ty: &Type) -> Result { + match expected_ty { + Type::Address + | Type::Boolean + | Type::Field + | Type::Group + | Type::Scalar + | Type::Signature + | Type::Integer(_) => { + let struct_lookup = |symbols: &[Symbol]| { + self.structs.get(symbols).unwrap().iter().map(|st| (*st.0, st.1.clone())).collect::>() + }; + let Expression::Literal(lit) = + zero_value_expression(expected_ty, Span::dummy(), &NodeBuilder::default(), &struct_lookup).unwrap() + else { + panic!("Can't happen") + }; + literal_to_value(&lit, &Some(expected_ty.clone())) + } + + Type::Array(array_type) => { + let val = self.resolve_none(array_type.element_type())?; + Ok(Value::make_array(std::iter::repeat_n(val, array_type.length.as_u32().unwrap() as usize))) + } + + Type::Composite(composite_type) => { + let contents: Result> = self + .structs + .get(&composite_type.path.as_symbols()) + .unwrap() + .clone() + .into_iter() + .map(|(s, t)| Ok((s, self.resolve_none(&t)?))) + .collect(); + Ok(Value::make_struct( + contents?.into_iter(), + self.current_program().unwrap_or_default(), + composite_type.path.as_symbols(), + )) + } + + Type::Optional(optional_type) => { + let val = self.resolve_none(&optional_type.inner)?; + let opt = self.register_optional(&optional_type.inner); + Ok(self.wrap_optional_value(val, opt, false)) + } + + _ => todo!(), + } + } + + fn register_optional(&mut self, type_: &Type) -> Symbol { + let is_some = Symbol::intern("is_some"); + let val = Symbol::intern("val"); + let sym = make_optional_struct_symbol(type_); + self.structs.insert(vec![sym], IndexMap::from([(is_some, Type::Boolean), (val, type_.clone())])); + sym + } + + fn wrap_optional_value(&self, value: Value, path: Symbol, is_some: bool) -> Value { + let is_some_symbol = Symbol::intern("is_some"); + let val_symbol = Symbol::intern("val"); + + Value::make_struct( + vec![(is_some_symbol, is_some.into()), (val_symbol, value)].into_iter(), + self.current_program().unwrap_or_default(), + vec![path], + ) + } + /// Execute one step of the current element. /// /// Many Leo constructs require multiple steps. For instance, when executing a conditional, @@ -1406,9 +1590,10 @@ impl Cursor { true, // is_async HashMap::new(), ); - let param_names = function.input.iter().map(|input| input.identifier.name); - for (name, value) in param_names.zip(values) { - self.set_variable(&self.to_absolute_path(&[name]), value); + let param_names = + function.input.iter().map(|input| (input.identifier.name, input.type_.clone())); + for ((name, type_), value) in param_names.zip(values) { + self.set_variable(&self.to_absolute_path(&[name]), value, Some(type_)); } self.frames.last_mut().unwrap().step = 1; self.frames.push(Frame { @@ -1512,10 +1697,10 @@ impl Cursor { let param_names = function .const_parameters .iter() - .map(|param| param.identifier.name) - .chain(function.input.iter().map(|input| input.identifier.name)); - for (name, value) in param_names.zip(arguments) { - self.set_variable(&self.to_absolute_path(&[name]), value); + .map(|param| (param.identifier.name, param.type_.clone())) + .chain(function.input.iter().map(|input| (input.identifier.name, input.type_.clone()))); + for ((name, type_), value) in param_names.zip(arguments) { + self.set_variable(&self.to_absolute_path(&[name]), value, Some(type_)); } self.frames.push(Frame { step: 0, diff --git a/interpreter/src/interpreter.rs b/interpreter/src/interpreter.rs index a817cc15219..303e912c764 100644 --- a/interpreter/src/interpreter.rs +++ b/interpreter/src/interpreter.rs @@ -169,7 +169,9 @@ impl Interpreter { }); cursor.over()?; let value = cursor.values.pop().unwrap(); - cursor.globals.insert(Location::new(program, vec![*name]), value); + cursor + .globals + .insert(Location::new(program, vec![*name]), (value, const_declaration.type_.clone())); } } @@ -209,7 +211,10 @@ impl Interpreter { }); cursor.over()?; let value = cursor.values.pop().unwrap(); - cursor.globals.insert(Location::new(program, to_absolute_path(*name)), value); + cursor.globals.insert( + Location::new(program, to_absolute_path(*name)), + (value, const_declaration.type_.clone()), + ); } } } @@ -374,8 +379,8 @@ impl Interpreter { source_file.absolute_start, self.network, ) - .map_err(|_e| { - LeoError::InterpreterHalt(InterpreterHalt::new("failed to parse statement".into())) + .map_err(|e| { + LeoError::InterpreterHalt(InterpreterHalt::new(format!("failed to parse statement: {e}"))) })?; // The spans of the code the user wrote at the REPL are meaningless, so get rid of them. self.cursor.frames.push(Frame {