diff --git a/compiler/ast/src/common/identifier.rs b/compiler/ast/src/common/identifier.rs index 3a7f13b8628..138789e13d4 100644 --- a/compiler/ast/src/common/identifier.rs +++ b/compiler/ast/src/common/identifier.rs @@ -14,7 +14,7 @@ // You should have received a copy of the GNU General Public License // along with the Leo library. If not, see . -use crate::{Expression, Node, NodeID, Path, simple_node_impl}; +use crate::{Expression, Node, NodeID, Path, PathKind, simple_node_impl}; use leo_span::{Span, Symbol}; @@ -97,10 +97,10 @@ impl From for Expression { } // Converts an `Identifier` to a `Path` -// Note that this sets the `absolute_path` field in `Path` to `None` and `is_absolute` to `false`. +// Note that this sets the `absolute_path` field in `Path` to `None` and `kind` to `PathKind::Relative`. // It's up to the caller of this method to figure out what to do with `absolute_path`. impl From for Path { fn from(value: Identifier) -> Self { - Path::new(vec![], value, false, None, value.span, value.id) + Path::new(vec![], value, PathKind::Relative, None, value.span, value.id) } } diff --git a/compiler/ast/src/common/path.rs b/compiler/ast/src/common/path.rs index 3b403c5c6e6..db2ac1f7303 100644 --- a/compiler/ast/src/common/path.rs +++ b/compiler/ast/src/common/path.rs @@ -22,8 +22,19 @@ use itertools::Itertools; use serde::{Deserialize, Serialize}; use std::{fmt, hash::Hash}; +/// The kind of a path: local absolute, local relative, or external (i.e. from another program). +#[derive(Clone, Debug, Hash, Eq, PartialEq, Serialize, Deserialize)] +pub enum PathKind { + /// A path like `::foo::bar`. + Absolute, + /// A path like `foo::bar`. + Relative, + /// A path like `program.aleo::foo::bar`. The `Identifier` here is the name of the program without `.aleo`. + External(Identifier), +} + /// A Path in a program. -#[derive(Clone, Default, Hash, Eq, PartialEq, Serialize, Deserialize)] +#[derive(Clone, Hash, Eq, PartialEq, Serialize, Deserialize)] pub struct Path { /// The qualifying namespace segments written by the user, excluding the item itself. /// e.g., in `foo::bar::baz`, this would be `[foo, bar]`. @@ -33,7 +44,7 @@ pub struct Path { identifier: Identifier, /// Is this path an absolute path? e.g. `::foo::bar::baz`. - is_absolute: bool, + kind: PathKind, /// The fully resolved path. We may not know this until the pass PathResolution pass runs. /// For path that refer to global items (structs, consts, functions), `absolute_path` is @@ -54,19 +65,19 @@ impl Path { /// /// - `qualifier`: The namespace segments (e.g., `foo::bar` in `foo::bar::baz`). /// - `identifier`: The final item in the path (e.g., `baz`). - /// - `is_absolute`: Whether the path is absolute (starts with `::`). + /// - `kind`: The kind of the path (absolute, relative, or external). /// - `absolute_path`: Optionally, the fully resolved symbolic path. /// - `span`: The source code span for this path. /// - `id`: The node ID. pub fn new( qualifier: Vec, identifier: Identifier, - is_absolute: bool, + kind: PathKind, absolute_path: Option>, span: Span, id: NodeID, ) -> Self { - Self { qualifier, identifier, is_absolute, absolute_path, span, id } + Self { qualifier, identifier, kind, absolute_path, span, id } } /// Returns the final identifier of the path (e.g., `baz` in `foo::bar::baz`). @@ -79,9 +90,22 @@ impl Path { self.qualifier.as_slice() } + /// Returns the program name if this is an external path, or `None` otherwise (implying current program). + pub fn program(&self) -> Option { + match &self.kind { + PathKind::External(program) => Some(program.name), + _ => None, + } + } + /// Returns `true` if the path is absolute (i.e., starts with `::`). pub fn is_absolute(&self) -> bool { - self.is_absolute + matches!(self.kind, PathKind::Absolute) + } + + /// Returns `true` if the path is external (i.e., starts with `.aleo::`). + pub fn is_external(&self) -> bool { + matches!(self.kind, PathKind::External(_)) } /// Returns a `Vec` representing the full symbolic path: @@ -95,7 +119,10 @@ impl Path { /// Returns an optional vector of `Symbol`s representing the resolved absolute path, /// or `None` if resolution has not yet occurred. pub fn try_absolute_path(&self) -> Option> { - if self.is_absolute { Some(self.as_symbols()) } else { self.absolute_path.clone() } + match self.kind { + PathKind::Absolute | PathKind::External(_) => Some(self.as_symbols()), + PathKind::Relative => self.absolute_path.clone(), + } } /// Returns a vector of `Symbol`s representing the resolved absolute path. @@ -103,19 +130,25 @@ impl Path { /// If the path is not an absolute path, this method panics if the absolute path has not been resolved yet. /// For relative paths, this is expected to be called only after path resolution has occurred. pub fn absolute_path(&self) -> Vec { - if self.is_absolute { - self.as_symbols() - } else { - self.absolute_path.as_ref().expect("absolute path must be known at this stage").to_vec() + match self.kind { + PathKind::Absolute | PathKind::External(_) => self.as_symbols(), + PathKind::Relative => { + self.absolute_path.as_ref().expect("absolute path must be known at this stage").to_vec() + } } } - /// Converts this `Path` into an absolute path by setting its `is_absolute` flag to `true`. + /// Converts this `Path` into an absolute path by setting its `kind` to `Absolute`. /// /// This does not alter the qualifier or identifier, nor does it compute or modify /// the resolved `absolute_path`. pub fn into_absolute(mut self) -> Self { - self.is_absolute = true; + self.kind = PathKind::Absolute; + self + } + + pub fn into_external(mut self, program: Identifier) -> Self { + self.kind = PathKind::External(program); self } @@ -160,9 +193,12 @@ impl Path { impl fmt::Display for Path { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - if self.is_absolute { - write!(f, "::")?; + match &self.kind { + PathKind::Absolute => write!(f, "::")?, + PathKind::External(program) => write!(f, "{}.aleo::", program)?, + PathKind::Relative => {} } + if self.qualifier.is_empty() { write!(f, "{}", self.identifier) } else { @@ -173,15 +209,15 @@ impl fmt::Display for Path { impl fmt::Debug for Path { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - // Print user path (Display impl) + // same as Display write!(f, "{self}")?; - // Print resolved absolute path if available + // append the resolved absolute path, if any if let Some(abs_path) = &self.absolute_path { - write!(f, "(::{})", abs_path.iter().format("::")) - } else { - write!(f, "()") + write!(f, " [abs=::{}]", abs_path.iter().format("::"))?; } + + Ok(()) } } diff --git a/compiler/ast/src/expressions/mod.rs b/compiler/ast/src/expressions/mod.rs index 3c15365ad3d..14176383a50 100644 --- a/compiler/ast/src/expressions/mod.rs +++ b/compiler/ast/src/expressions/mod.rs @@ -14,7 +14,7 @@ // You should have received a copy of the GNU General Public License // along with the Leo library. If not, see . -use crate::{Identifier, IntegerType, Node, NodeBuilder, NodeID, Path, Type}; +use crate::{Identifier, IntegerType, Location, Node, NodeBuilder, NodeID, Path, Type}; use leo_span::{Span, Symbol}; use serde::{Deserialize, Serialize}; @@ -354,8 +354,9 @@ impl Expression { pub fn zero( ty: &Type, span: Span, + program: Symbol, node_builder: &NodeBuilder, - struct_lookup: &dyn Fn(&[Symbol]) -> Vec<(Symbol, Type)>, + struct_lookup: &dyn Fn(&Location) -> Vec<(Symbol, Type)>, ) -> Option { let id = node_builder.next_id(); @@ -399,13 +400,13 @@ impl Expression { // Structs (composite types) Type::Composite(composite_type) => { let path = &composite_type.path; - let members = struct_lookup(&path.absolute_path()); + let members = struct_lookup(&Location::new(program, path.absolute_path())); let struct_members = members .into_iter() .map(|(symbol, member_type)| { let member_id = node_builder.next_id(); - let zero_expr = Self::zero(&member_type, span, node_builder, struct_lookup)?; + let zero_expr = Self::zero(&member_type, span, program, node_builder, struct_lookup)?; Some(StructVariableInitializer { span, @@ -429,7 +430,7 @@ impl Expression { Type::Array(array_type) => { let element_ty = &array_type.element_type; - let element_expr = Self::zero(element_ty, span, node_builder, struct_lookup)?; + let element_expr = Self::zero(element_ty, span, program, node_builder, struct_lookup)?; Some(Expression::Repeat( RepeatExpression { span, id, expr: element_expr, count: *array_type.length.clone() }.into(), diff --git a/compiler/ast/src/interpreter_value/value.rs b/compiler/ast/src/interpreter_value/value.rs index 6d5103ecfd7..f309cdf14b6 100644 --- a/compiler/ast/src/interpreter_value/value.rs +++ b/compiler/ast/src/interpreter_value/value.rs @@ -826,9 +826,10 @@ impl Value { pub fn to_expression( &self, span: Span, + current_program: &Symbol, node_builder: &NodeBuilder, ty: &Type, - struct_lookup: &dyn Fn(&[Symbol]) -> Vec<(Symbol, Type)>, + struct_lookup: &dyn Fn(&Location) -> Vec<(Symbol, Type)>, ) -> Option { use crate::{Literal, TupleExpression, UnitExpression}; @@ -850,7 +851,7 @@ impl Value { elements: vec .iter() .zip(tuple_type.elements()) - .map(|(val, ty)| val.to_expression(span, node_builder, ty, struct_lookup)) + .map(|(val, ty)| val.to_expression(span, current_program, node_builder, ty, struct_lookup)) .collect::>>()?, } .into() @@ -858,7 +859,7 @@ impl Value { ValueVariants::Unsuffixed(s) => Literal::unsuffixed(s.clone(), span, id).into(), ValueVariants::Svm(value) => match value { SvmValueParam::Plaintext(plaintext) => { - plaintext_to_expression(plaintext, span, node_builder, ty, &struct_lookup)? + plaintext_to_expression(plaintext, span, current_program, node_builder, ty, &struct_lookup)? } SvmValueParam::Record(..) => return None, SvmValueParam::Future(..) => return None, @@ -875,9 +876,10 @@ impl Value { fn plaintext_to_expression( plaintext: &SvmPlaintext, span: Span, + current_program: &Symbol, node_builder: &NodeBuilder, ty: &Type, - struct_lookup: &dyn Fn(&[Symbol]) -> Vec<(Symbol, Type)>, + struct_lookup: &dyn Fn(&Location) -> Vec<(Symbol, Type)>, ) -> Option { use crate::{ArrayExpression, Identifier, IntegerType, Literal, StructExpression, StructVariableInitializer}; @@ -924,8 +926,10 @@ fn plaintext_to_expression( let Type::Composite(composite_type) = ty else { return None; }; - let symbols = composite_type.path.as_symbols(); - let iter_members = struct_lookup(&symbols); + let program = composite_type.program.unwrap_or(*current_program); + let location = Location::new(program, composite_type.path.absolute_path().to_vec()); + let iter_members = struct_lookup(&location); + StructExpression { span, id, @@ -945,6 +949,7 @@ fn plaintext_to_expression( expression: Some(plaintext_to_expression( index_map.get(&svm_identifier)?, span, + current_program, node_builder, &ty, &struct_lookup, @@ -964,7 +969,16 @@ fn plaintext_to_expression( id, elements: vec .iter() - .map(|pt| plaintext_to_expression(pt, span, node_builder, &array_ty.element_type, &struct_lookup)) + .map(|pt| { + plaintext_to_expression( + pt, + span, + current_program, + node_builder, + &array_ty.element_type, + &struct_lookup, + ) + }) .collect::>>()?, } .into() diff --git a/compiler/ast/src/passes/reconstructor.rs b/compiler/ast/src/passes/reconstructor.rs index fe7bce82035..2efd540d4f1 100644 --- a/compiler/ast/src/passes/reconstructor.rs +++ b/compiler/ast/src/passes/reconstructor.rs @@ -578,6 +578,7 @@ pub trait ProgramReconstructor: AstReconstructor { .map(|(id, import)| (id, (self.reconstruct_import(import.0), import.1))) .collect(), stubs: input.stubs.into_iter().map(|(id, stub)| (id, self.reconstruct_stub(stub))).collect(), + programs: input.programs.into_iter().map(|(id, program)| (id, self.reconstruct_program(program))).collect(), modules: input.modules.into_iter().map(|(id, module)| (id, self.reconstruct_module(module))).collect(), program_scopes, } diff --git a/compiler/ast/src/passes/visitor.rs b/compiler/ast/src/passes/visitor.rs index e5720df6f93..03f41e480f6 100644 --- a/compiler/ast/src/passes/visitor.rs +++ b/compiler/ast/src/passes/visitor.rs @@ -334,6 +334,7 @@ pub trait ProgramVisitor: AstVisitor { input.modules.values().for_each(|module| self.visit_module(module)); input.imports.values().for_each(|import| self.visit_import(&import.0)); input.stubs.values().for_each(|stub| self.visit_stub(stub)); + input.programs.values().for_each(|program| self.visit_program(program)); } fn visit_program_scope(&mut self, input: &ProgramScope) { diff --git a/compiler/ast/src/program/mod.rs b/compiler/ast/src/program/mod.rs index 88a8d649a28..624ee061b0e 100644 --- a/compiler/ast/src/program/mod.rs +++ b/compiler/ast/src/program/mod.rs @@ -38,6 +38,8 @@ pub struct Program { pub imports: IndexMap, /// A map from program stub names to program stub scopes. pub stubs: IndexMap, + /// A map from program names to program definitions. + pub programs: IndexMap, /// A map from program names to program scopes. pub program_scopes: IndexMap, } @@ -53,6 +55,9 @@ impl fmt::Display for Program { for (_, stub) in self.stubs.iter() { writeln!(f, "{stub}")?; } + for (_, program) in self.programs.iter() { + writeln!(f, "{program}")?; + } for (_, program_scope) in self.program_scopes.iter() { writeln!(f, "{program_scope}")?; } @@ -67,6 +72,7 @@ impl Default for Program { modules: IndexMap::new(), imports: IndexMap::new(), stubs: IndexMap::new(), + programs: IndexMap::new(), program_scopes: IndexMap::new(), } } diff --git a/compiler/ast/src/types/struct_type.rs b/compiler/ast/src/types/composite.rs similarity index 100% rename from compiler/ast/src/types/struct_type.rs rename to compiler/ast/src/types/composite.rs diff --git a/compiler/ast/src/types/mod.rs b/compiler/ast/src/types/mod.rs index 7ac00f53073..e694eebf477 100644 --- a/compiler/ast/src/types/mod.rs +++ b/compiler/ast/src/types/mod.rs @@ -20,6 +20,9 @@ pub use array::*; mod core_constant; pub use core_constant::*; +mod composite; +pub use composite::*; + mod future; pub use future::*; @@ -32,9 +35,6 @@ pub use optional::*; mod mapping; pub use mapping::*; -mod struct_type; -pub use struct_type::*; - mod tuple; pub use tuple::*; diff --git a/compiler/ast/src/types/type_.rs b/compiler/ast/src/types/type_.rs index 4431ebc22cd..d73bc5ddb96 100644 --- a/compiler/ast/src/types/type_.rs +++ b/compiler/ast/src/types/type_.rs @@ -138,7 +138,7 @@ impl Type { } // Two composite types are the same if their programs and their _absolute_ paths match. - (left.program == right.program) + (left.path.program() == right.path.program()) && match (&left.path.try_absolute_path(), &right.path.try_absolute_path()) { (Some(l), Some(r)) => l == r, _ => false, diff --git a/compiler/compiler/src/compiler.rs b/compiler/compiler/src/compiler.rs index d58310cf2a9..758779d8ef5 100644 --- a/compiler/compiler/src/compiler.rs +++ b/compiler/compiler/src/compiler.rs @@ -20,8 +20,8 @@ use crate::{AstSnapshots, CompilerOptions}; -pub use leo_ast::Ast; -use leo_ast::{NetworkName, Stub}; +pub use leo_ast::{Ast, NodeBuilder}; +use leo_ast::{NetworkName, Program, Stub}; use leo_errors::{CompilerError, Handler, Result}; use leo_passes::*; use leo_span::{Symbol, source_map::FileName, with_session_globals}; @@ -30,6 +30,7 @@ use std::{ ffi::OsStr, fs, path::{Path, PathBuf}, + rc::Rc, }; use indexmap::{IndexMap, IndexSet}; @@ -47,6 +48,8 @@ pub struct Compiler { state: CompilerState, /// The stubs for imported programs. import_stubs: IndexMap, + /// The stubs for imported programs. + imported_programs: IndexMap, /// How many statements were in the AST before DCE? pub statements_before_dce: u32, /// How many statements were in the AST after DCE? @@ -54,7 +57,12 @@ pub struct Compiler { } impl Compiler { - pub fn parse(&mut self, source: &str, filename: FileName, modules: &[(&str, FileName)]) -> Result<()> { + pub fn parse( + &mut self, + source: &str, + filename: FileName, + modules: &[(&str, FileName)], + ) -> Result { // Register the source in the source map. let source_file = with_session_globals(|s| s.source_map.new_source(source, filename.clone())); @@ -100,7 +108,7 @@ impl Compiler { self.write_ast("initial.ast")?; } - Ok(()) + Ok(self.state.ast.ast.clone()) } /// Returns a new Leo compiler. @@ -109,17 +117,26 @@ impl Compiler { expected_program_name: Option, is_test: bool, handler: Handler, + node_builder: Rc, output_directory: PathBuf, compiler_options: Option, import_stubs: IndexMap, + imported_programs: IndexMap, network: NetworkName, ) -> Self { Self { - state: CompilerState { handler, is_test, network, ..Default::default() }, + state: CompilerState { + handler, + is_test, + network, + node_builder: Rc::clone(&node_builder), + ..Default::default() + }, output_directory, program_name: expected_program_name, compiler_options: compiler_options.unwrap_or_default(), import_stubs, + imported_programs, statements_before_dce: 0, statements_after_dce: 0, } @@ -143,6 +160,7 @@ impl Compiler { /// Runs the compiler stages. pub fn intermediate_passes(&mut self) -> Result<()> { + println!("{}", self.state.ast.ast); let type_checking_config = TypeCheckingInput::new(self.state.network); self.do_pass::(())?; @@ -188,6 +206,8 @@ impl Compiler { self.statements_before_dce = output.statements_before; self.statements_after_dce = output.statements_after; + println!("{}", self.state.ast.ast); + Ok(()) } @@ -209,7 +229,9 @@ impl Compiler { // Parse the program. self.parse(source, filename, modules)?; // Merge the stubs into the AST. - self.add_import_stubs()?; + // self.add_import_stubs()?; + // Merge the programs into the AST. + self.add_import_programs()?; // Run the intermediate compiler stages. self.intermediate_passes()?; // Run code generation. @@ -277,6 +299,46 @@ impl Compiler { self.compile(&source, FileName::Real(entry_file_path.as_ref().into()), &modules) } + pub fn parse_from_directory( + &mut self, + entry_file_path: impl AsRef, + source_directory: impl AsRef, + ) -> Result { + // Read the contents of the main source file. + let source = fs::read_to_string(&entry_file_path) + .map_err(|e| CompilerError::file_read_error(entry_file_path.as_ref().display().to_string(), e))?; + + // Walk all files under source_directory recursively, excluding the main source file itself. + let files = WalkDir::new(source_directory) + .into_iter() + .filter_map(Result::ok) + .filter(|e| { + e.file_type().is_file() + && e.path() != entry_file_path.as_ref() + && e.path().extension() == Some(OsStr::new("leo")) + }) + .collect::>(); + + let mut module_sources = Vec::new(); // Keep Strings alive for valid borrowing + let mut modules = Vec::new(); // Parsed (source, filename) tuples for compilation + + // Read all module files and store their contents + for file in &files { + let source = fs::read_to_string(file.path()) + .map_err(|e| CompilerError::file_read_error(file.path().display().to_string(), e))?; + module_sources.push(source); // Keep the String alive + } + + // Create tuples of (&str, FileName) for the compiler + for (i, file) in files.iter().enumerate() { + let source = &module_sources[i]; // Borrow from the alive String + modules.push((&source[..], FileName::Real(file.path().into()))); + } + + // Compile the main source along with all collected modules + self.parse(&source, FileName::Real(entry_file_path.as_ref().into()), &modules) + } + /// Writes the AST to a JSON file. fn write_ast_to_json(&self, file_suffix: &str) -> Result<()> { // Remove `Span`s if they are not enabled. @@ -338,4 +400,39 @@ impl Compiler { .collect(); Ok(()) } + + /// Merge the imported programs which are dependencies of the current program into the AST + /// in topological order. + pub fn add_import_programs(&mut self) -> Result<()> { + let mut explored = IndexSet::::new(); + let mut to_explore: Vec = self.state.ast.ast.imports.keys().cloned().collect(); + + while let Some(import) = to_explore.pop() { + explored.insert(import); + if let Some(program) = self.imported_programs.get(&import) { + for new_import_id in program.imports.iter() { + if !explored.contains(new_import_id.0) { + to_explore.push(*new_import_id.0); + } + } + } else { + return Err(CompilerError::imported_program_not_found( + self.program_name.as_ref().unwrap(), + import, + self.state.ast.ast.imports[&import].1, + ) + .into()); + } + } + + // Iterate in the order of `import_programs` to make sure they + // stay topologically sorted. + self.state.ast.ast.programs = self + .imported_programs + .iter() + .filter(|(symbol, _program)| explored.contains(*symbol)) + .map(|(symbol, program)| (*symbol, program.clone())) + .collect(); + Ok(()) + } } diff --git a/compiler/compiler/src/test_compiler.rs b/compiler/compiler/src/test_compiler.rs index 3ad4444188d..c112ec7a64b 100644 --- a/compiler/compiler/src/test_compiler.rs +++ b/compiler/compiler/src/test_compiler.rs @@ -34,12 +34,29 @@ fn run_test(test: &str, handler: &Handler) -> Result { let mut import_stubs = IndexMap::new(); + let mut import_programs = IndexMap::new(); + let mut bytecodes = Vec::::new(); + let node_builder: leo_ast::NodeBuilder = Default::default(); + let node_builder = std::rc::Rc::new(node_builder); + // Compile each source file separately. for source in test.split(super::test_utils::PROGRAM_DELIMITER) { - let (bytecode, program_name) = - handler.extend_if_error(super::test_utils::whole_compile(source, handler, import_stubs.clone()))?; + let (parsed, _) = handler.extend_if_error(super::test_utils::parse( + source, + handler, + &node_builder, + import_stubs.clone(), + import_programs.clone(), + ))?; + let (bytecode, program_name) = handler.extend_if_error(super::test_utils::whole_compile( + source, + handler, + &node_builder, + import_stubs.clone(), + import_programs.clone(), + ))?; // Parse the bytecode as an Aleo program. // Note that this function checks that the bytecode is well-formed. @@ -53,6 +70,7 @@ fn run_test(test: &str, handler: &Handler) -> Result { let stub = handler .extend_if_error(disassemble_from_str::(&program_name, &bytecode).map_err(|err| err.into()))?; import_stubs.insert(Symbol::intern(&program_name), stub); + import_programs.insert(Symbol::intern(&program_name), parsed); // Only error out if there are errors. Warnings are okay but we still want to print them later. if handler.err_count() != 0 { diff --git a/compiler/compiler/src/test_execution.rs b/compiler/compiler/src/test_execution.rs index ef6ae5df52e..c73c250294c 100644 --- a/compiler/compiler/src/test_execution.rs +++ b/compiler/compiler/src/test_execution.rs @@ -45,13 +45,23 @@ impl Default for Config { fn execution_run_test(config: &Config, cases: &[run_with_ledger::Case], handler: &Handler) -> Result { let mut import_stubs = IndexMap::new(); + let mut import_programs = IndexMap::new(); let mut ledger_config = run_with_ledger::Config { seed: config.seed, start_height: config.start_height, programs: Vec::new() }; + let node_builder: leo_ast::NodeBuilder = Default::default(); + let node_builder = std::rc::Rc::new(node_builder); + // Compile each source file. for source in &config.sources { - let (bytecode, name) = super::test_utils::whole_compile(source, handler, import_stubs.clone())?; + let (bytecode, name) = super::test_utils::whole_compile( + source, + handler, + &node_builder, + import_stubs.clone(), + import_programs.clone(), + )?; let stub = disassemble_from_str::(&name, &bytecode)?; import_stubs.insert(Symbol::intern(&name), stub); diff --git a/compiler/compiler/src/test_utils.rs b/compiler/compiler/src/test_utils.rs index b9c9d5d4918..3aeadd8bbce 100644 --- a/compiler/compiler/src/test_utils.rs +++ b/compiler/compiler/src/test_utils.rs @@ -16,11 +16,11 @@ use crate::Compiler; -use leo_ast::{NetworkName, Stub}; +use leo_ast::{NetworkName, NodeBuilder, Program, Stub}; use leo_errors::{Handler, LeoError}; use leo_span::{Symbol, source_map::FileName}; -use std::path::PathBuf; +use std::{path::PathBuf, rc::Rc}; use indexmap::IndexMap; @@ -36,15 +36,19 @@ pub const MODULE_DELIMITER: &str = "// --- Next Module:"; pub fn whole_compile( source: &str, handler: &Handler, + node_builder: &Rc, import_stubs: IndexMap, + import_programs: IndexMap, ) -> Result<(String, String), LeoError> { let mut compiler = Compiler::new( None, /* is_test */ false, handler.clone(), + std::rc::Rc::clone(&node_builder), "/fakedirectory-wont-use".into(), None, import_stubs, + import_programs, NetworkName::TestnetV0, ); @@ -97,3 +101,72 @@ pub fn whole_compile( Ok((bytecode, compiler.program_name.unwrap())) } + +pub fn parse( + source: &str, + handler: &Handler, + node_builder: &Rc, + import_stubs: IndexMap, + import_programs: IndexMap, +) -> Result<(Program, String), LeoError> { + let mut compiler = Compiler::new( + None, + /* is_test */ false, + handler.clone(), + std::rc::Rc::clone(&node_builder), + "/fakedirectory-wont-use".into(), + None, + import_stubs, + import_programs, + NetworkName::TestnetV0, + ); + + if !source.contains(MODULE_DELIMITER) { + // Fast path: no modules + let filename = FileName::Custom("compiler-test".into()); + let parsed = compiler.parse(source, filename.clone(), &Vec::new())?; + return Ok((parsed, compiler.program_name.unwrap())); + } + + let mut main_source = String::new(); + let mut modules: Vec<(String, PathBuf)> = Vec::new(); + + let mut current_module_path: Option = None; + let mut current_module_source = String::new(); + + for line in source.lines() { + if let Some(rest) = line.strip_prefix(MODULE_DELIMITER) { + // Save previous block + if let Some(path) = current_module_path.take() { + modules.push((current_module_source.clone(), path)); + current_module_source.clear(); + } else { + main_source = current_module_source.clone(); + current_module_source.clear(); + } + + // Start new module + let trimmed_path = rest.trim().trim_end_matches(" --- //"); + current_module_path = Some(PathBuf::from(trimmed_path)); + } else { + current_module_source.push_str(line); + current_module_source.push('\n'); + } + } + + // Push the last module or main + if let Some(path) = current_module_path { + modules.push((current_module_source.clone(), path)); + } else { + main_source = current_module_source; + } + + // Prepare module references for compiler + let module_refs: Vec<(&str, FileName)> = + modules.iter().map(|(src, path)| (src.as_str(), FileName::Custom(path.to_string_lossy().into()))).collect(); + + let filename = FileName::Custom("compiler-test".into()); + let parsed = compiler.parse(&main_source, filename, &module_refs)?; + + Ok((parsed, compiler.program_name.unwrap())) +} diff --git a/compiler/parser-lossless/src/grammar.lalrpop b/compiler/parser-lossless/src/grammar.lalrpop index 63ebd2a8734..343f81fef4f 100644 --- a/compiler/parser-lossless/src/grammar.lalrpop +++ b/compiler/parser-lossless/src/grammar.lalrpop @@ -194,9 +194,9 @@ ExprStruct: SyntaxNode<'a> = { // Struct initializer. > > => { // For now the initializer name can't be external. - if name.text.contains(".aleo") { - handler.emit_err(ParserError::cannot_define_external_record(name.span)); - } + // if name.text.contains(".aleo") { + // handler.emit_err(ParserError::cannot_define_external_record(name.span)); + // } SyntaxNode::new(ExpressionKind::Struct, iter::once(name).chain(c).chain([l]).chain(inits).chain([r])) }, } diff --git a/compiler/parser-lossless/src/tokens.rs b/compiler/parser-lossless/src/tokens.rs index 2fe99bcb23b..9cd596b8bf8 100644 --- a/compiler/parser-lossless/src/tokens.rs +++ b/compiler/parser-lossless/src/tokens.rs @@ -29,7 +29,8 @@ pub enum IdVariants { fn id_variant(lex: &mut logos::Lexer) -> IdVariants { // Use LazyLock to not recompile these regexes every time. static REGEX_LOCATOR: LazyLock = - LazyLock::new(|| regex::Regex::new(r"^\.aleo/[a-zA-Z][a-zA-Z0-9_]*").unwrap()); + LazyLock::new(|| regex::Regex::new(r"^\.aleo(?:::[a-zA-Z][a-zA-Z0-9_]*)+").unwrap()); + static REGEX_PROGRAM_ID: LazyLock = LazyLock::new(|| regex::Regex::new(r"^\.aleo\b").unwrap()); static REGEX_PATH: LazyLock = LazyLock::new(|| regex::Regex::new(r"^(?:::[a-zA-Z][a-zA-Z0-9_]*)+").unwrap()); diff --git a/compiler/parser/src/conversions.rs b/compiler/parser/src/conversions.rs index b7db7121e14..148f4b32a69 100644 --- a/compiler/parser/src/conversions.rs +++ b/compiler/parser/src/conversions.rs @@ -78,20 +78,55 @@ fn to_type(node: &SyntaxNode<'_>, builder: &NodeBuilder, handler: &Handler) -> R TypeKind::Boolean => leo_ast::Type::Boolean, TypeKind::Composite => { let name = &node.children[0]; - if let Some((program, name_str)) = name.text.split_once(".aleo/") { + + /*let function = if let Some((first, second)) = name.text.split_once(".aleo::") { + // This is a locator. + let symbol = Symbol::intern(second); + let lo = node.span.lo + first.len() as u32 + ".aleo/".len() as u32; + let second_span = Span { lo, hi: lo + second.len() as u32 }; + let identifier = leo_ast::Identifier { name: symbol, span: second_span, id: builder.next_id() }; + let program = leo_ast::Identifier { + name: Symbol::intern(first), + span: Span { lo: node.span.lo, hi: node.span.lo + first.len() as u32 + "aleo/".len() as u32 }, + id: builder.next_id(), + }; + leo_ast::Path::new( + Vec::new(), + identifier, + leo_ast::PathKind::External(program), + None, + span, + builder.next_id(), + ) + }*/ + + dbg!(&name); + if let Some((first, name_str)) = name.text.split_once(".aleo::") { // This is a locator. let name_id = leo_ast::Identifier { name: Symbol::intern(name_str), span: leo_span::Span { - lo: name.span.lo + program.len() as u32 + 5, + lo: name.span.lo + first.len() as u32 + 5, hi: name.span.lo + name.text.len() as u32, }, id: builder.next_id(), }; + let program = leo_ast::Identifier { + name: Symbol::intern(first), + span: Span { lo: node.span.lo, hi: node.span.lo + first.len() as u32 + "aleo/".len() as u32 }, + id: builder.next_id(), + }; leo_ast::CompositeType { - path: leo_ast::Path::new(Vec::new(), name_id, false, None, name_id.span, builder.next_id()), + path: leo_ast::Path::new( + Vec::new(), + name_id, + leo_ast::PathKind::External(program), + None, + name_id.span, + builder.next_id(), + ), const_arguments: Vec::new(), - program: Some(Symbol::intern(program)), + program: Some(Symbol::intern(first)), } .into() } else { @@ -107,7 +142,14 @@ fn to_type(node: &SyntaxNode<'_>, builder: &NodeBuilder, handler: &Handler) -> R .collect::>>()?; } let identifier = path_components.pop().unwrap(); - let path = leo_ast::Path::new(path_components, identifier, false, None, name.span, builder.next_id()); + let path = leo_ast::Path::new( + path_components, + identifier, + leo_ast::PathKind::Relative, + None, + name.span, + builder.next_id(), + ); leo_ast::CompositeType { path, const_arguments, program: None }.into() } } @@ -519,20 +561,37 @@ pub fn to_expression(node: &SyntaxNode<'_>, builder: &NodeBuilder, handler: &Han .map(|child| to_expression(child, builder, handler)) .collect::>>()?; - let (function, program) = if let Some((first, second)) = name.text.split_once(".aleo/") { + let function = if let Some((first, second)) = name.text.split_once(".aleo::") { // This is a locator. let symbol = Symbol::intern(second); let lo = node.span.lo + first.len() as u32 + ".aleo/".len() as u32; let second_span = Span { lo, hi: lo + second.len() as u32 }; let identifier = leo_ast::Identifier { name: symbol, span: second_span, id: builder.next_id() }; - let function = leo_ast::Path::new(Vec::new(), identifier, false, None, span, builder.next_id()); - (function, Some(Symbol::intern(first))) + let program = leo_ast::Identifier { + name: Symbol::intern(first), + span: Span { lo: node.span.lo, hi: node.span.lo + first.len() as u32 + "aleo/".len() as u32 }, + id: builder.next_id(), + }; + leo_ast::Path::new( + Vec::new(), + identifier, + leo_ast::PathKind::External(program), + None, + span, + builder.next_id(), + ) } else { // It's a path. let mut components = path_to_parts(name, builder); let identifier = components.pop().unwrap(); - let function = leo_ast::Path::new(components, identifier, false, None, name.span, builder.next_id()); - (function, None) + leo_ast::Path::new( + components, + identifier, + leo_ast::PathKind::Relative, + None, + name.span, + builder.next_id(), + ) }; let mut const_arguments = Vec::new(); @@ -557,7 +616,7 @@ pub fn to_expression(node: &SyntaxNode<'_>, builder: &NodeBuilder, handler: &Han .collect::>>()?; } - leo_ast::CallExpression { function, const_arguments, arguments, program, span, id }.into() + leo_ast::CallExpression { function, const_arguments, arguments, program: None, span, id }.into() } ExpressionKind::Cast => { let [expression, _as, type_] = &node.children[..] else { @@ -574,7 +633,7 @@ pub fn to_expression(node: &SyntaxNode<'_>, builder: &NodeBuilder, handler: &Han // lossless tree just has the span of the entire path. let mut identifiers = path_to_parts(&node.children[0], builder); let identifier = identifiers.pop().unwrap(); - leo_ast::Path::new(identifiers, identifier, false, None, span, id).into() + leo_ast::Path::new(identifiers, identifier, leo_ast::PathKind::Relative, None, span, id).into() } ExpressionKind::Literal(literal_kind) => match literal_kind { LiteralKind::Address => { @@ -617,7 +676,7 @@ pub fn to_expression(node: &SyntaxNode<'_>, builder: &NodeBuilder, handler: &Han let text = node.children[0].text; // Parse the locator string in format "some_program.aleo/some_name" - if let Some((program_part, name_part)) = text.split_once(".aleo/") { + if let Some((program_part, name_part)) = text.split_once(".aleo::") { // Create the program identifier let program_name_symbol = Symbol::intern(program_part); let program_name_span = Span { lo: node.span.lo, hi: node.span.lo + program_part.len() as u32 }; @@ -899,9 +958,49 @@ pub fn to_expression(node: &SyntaxNode<'_>, builder: &NodeBuilder, handler: &Han } } - let mut identifiers = path_to_parts(name, builder); + let path = if let Some((first, second)) = name.text.split_once(".aleo::") { + // This is a locator. + let symbol = Symbol::intern(second); + let lo = node.span.lo + first.len() as u32 + ".aleo/".len() as u32; + let second_span = Span { lo, hi: lo + second.len() as u32 }; + let identifier = leo_ast::Identifier { name: symbol, span: second_span, id: builder.next_id() }; + let program = leo_ast::Identifier { + name: Symbol::intern(first), + span: Span { lo: node.span.lo, hi: node.span.lo + first.len() as u32 + "aleo/".len() as u32 }, + id: builder.next_id(), + }; + leo_ast::Path::new( + Vec::new(), + identifier, + leo_ast::PathKind::External(program), + None, + span, + builder.next_id(), + ) + } else { + // It's a path. + let mut components = path_to_parts(name, builder); + let identifier = components.pop().unwrap(); + leo_ast::Path::new( + components, + identifier, + leo_ast::PathKind::Relative, + None, + name.span, + builder.next_id(), + ) + }; + + /*let mut identifiers = path_to_parts(name, builder); let identifier = identifiers.pop().unwrap(); - let path = leo_ast::Path::new(identifiers, identifier, false, None, name.span, builder.next_id()); + let path = leo_ast::Path::new( + identifiers, + identifier, + leo_ast::PathKind::Relative, + None, + name.span, + builder.next_id(), + );*/ leo_ast::StructExpression { path, const_arguments, members, span, id }.into() } @@ -1379,6 +1478,7 @@ pub fn to_main(node: &SyntaxNode<'_>, builder: &NodeBuilder, handler: &Handler) modules: Default::default(), imports, stubs: Default::default(), + programs: Default::default(), program_scopes: std::iter::once((program_name_symbol, program_scope)).collect(), }) } diff --git a/compiler/passes/src/code_generation/expression.rs b/compiler/passes/src/code_generation/expression.rs index a72e8aae063..46b0f872df9 100644 --- a/compiler/passes/src/code_generation/expression.rs +++ b/compiler/passes/src/code_generation/expression.rs @@ -55,6 +55,7 @@ use snarkvm::{ }; use anyhow::bail; +use itertools::Itertools; use std::{borrow::Borrow, fmt::Write as _}; /// Implement the necessary methods to visit nodes in the AST. @@ -199,7 +200,7 @@ impl CodeGeneratingVisitor<'_> { let cast_instruction = format!( " cast {expression_operand} into {destination_register} as {};\n", - Self::visit_type(&input.type_) + self.visit_type(&input.type_) ); // Concatenate the instructions. @@ -224,7 +225,7 @@ impl CodeGeneratingVisitor<'_> { let Some(array_type @ Type::Array(..)) = self.state.type_table.get(&input.id) else { panic!("All types should be known at this phase of compilation"); }; - let array_type: String = Self::visit_type(&array_type); + let array_type: String = self.visit_type(&array_type); let array_instruction = format!(" cast {expression_operands} into {destination_register} as {array_type};\n"); @@ -392,7 +393,7 @@ impl CodeGeneratingVisitor<'_> { let Some(array_type @ Type::Array(..)) = self.state.type_table.get(&input.id) else { panic!("All types should be known at this phase of compilation"); }; - let array_type: String = Self::visit_type(&array_type); + let array_type: String = self.visit_type(&array_type); let array_instruction = format!(" cast {expression_operands} into {destination_register} as {array_type};\n"); @@ -602,7 +603,7 @@ impl CodeGeneratingVisitor<'_> { let instruction = format!( " serialize.{variant} {} ({}) into {destination_register} ({output_array_type});\n", arguments[0], - Self::visit_type(&input_type) + self.visit_type(&input_type) ); (destination_register, instruction) @@ -623,8 +624,8 @@ impl CodeGeneratingVisitor<'_> { let instruction = format!( " deserialize.{variant} {} ({}) into {destination_register} ({});\n", arguments[0], - Self::visit_type(&input_type), - Self::visit_type(&output_type) + self.visit_type(&input_type), + self.visit_type(&output_type) ); (destination_register, instruction) @@ -655,7 +656,7 @@ impl CodeGeneratingVisitor<'_> { fn visit_call(&mut self, input: &CallExpression) -> (String, String) { let caller_program = self.program_id.expect("Calls only appear within programs.").name.name; - let callee_program = input.program.unwrap_or(caller_program); + let callee_program = input.function.program().unwrap_or(caller_program); let func_symbol = self .state .symbol_table @@ -665,11 +666,11 @@ impl CodeGeneratingVisitor<'_> { // Need to determine the program the function originated from as well as if the function has a finalize block. let mut call_instruction = if caller_program != callee_program { // All external functions must be defined as stubs. - assert!( + /*assert!( self.program.stubs.get(&callee_program).is_some(), "Type checking guarantees that imported and stub programs are present." - ); - format!(" call {}.aleo/{}", callee_program, input.function) + );*/ + format!(" call {}.aleo/{}", callee_program, input.function.absolute_path().iter().format("::")) } else if func_symbol.function.variant.is_async() { format!(" async {}", self.current_function.unwrap().identifier) } else { @@ -771,7 +772,7 @@ impl CodeGeneratingVisitor<'_> { for i in 0..array_type.length.as_u32().expect("length should be known at this point") as usize { write!(&mut instruction, "{register}[{i}u32] ").unwrap(); } - writeln!(&mut instruction, "into {new_reg} as {};", Self::visit_type(typ)).unwrap(); + writeln!(&mut instruction, "into {new_reg} as {};", self.visit_type(typ)).unwrap(); (new_reg, instruction) } @@ -783,7 +784,7 @@ impl CodeGeneratingVisitor<'_> { .state .symbol_table .lookup_record(&location) - .or_else(|| self.state.symbol_table.lookup_struct(&comp_ty.path.absolute_path())) + .or_else(|| self.state.symbol_table.lookup_struct(&location)) .unwrap(); let mut instruction = " cast ".to_string(); for member in &comp.members { diff --git a/compiler/passes/src/code_generation/program.rs b/compiler/passes/src/code_generation/program.rs index 7f59cfc0acb..2ab204334bf 100644 --- a/compiler/passes/src/code_generation/program.rs +++ b/compiler/passes/src/code_generation/program.rs @@ -71,7 +71,7 @@ impl<'a> CodeGeneratingVisitor<'a> { let lookup = |name: &[Symbol]| { self.state .symbol_table - .lookup_struct(name) + .lookup_struct(&Location::new(this_program, name.to_vec())) .or_else(|| self.state.symbol_table.lookup_record(&Location::new(this_program, name.to_vec()))) }; @@ -159,7 +159,7 @@ impl<'a> CodeGeneratingVisitor<'a> { // Construct and append the record variables. for var in struct_.members.iter() { - writeln!(output_string, " {} as {};", var.identifier, Self::visit_type(&var.type_),).expect(EXPECT_STR); + writeln!(output_string, " {} as {};", var.identifier, self.visit_type(&var.type_),).expect(EXPECT_STR); } output_string @@ -193,7 +193,7 @@ impl<'a> CodeGeneratingVisitor<'a> { output_string, " {} as {}.{mode};", // todo: CAUTION private record variables only. var.identifier, - Self::visit_type(&var.type_) + self.visit_type(&var.type_) ) .expect(EXPECT_STR); } diff --git a/compiler/passes/src/code_generation/type_.rs b/compiler/passes/src/code_generation/type_.rs index 7665e49acbe..d7f23ec23d9 100644 --- a/compiler/passes/src/code_generation/type_.rs +++ b/compiler/passes/src/code_generation/type_.rs @@ -16,10 +16,10 @@ use super::*; -use leo_ast::{CompositeType, Location, Mode, Type}; +use leo_ast::{Location, Mode, Type}; impl CodeGeneratingVisitor<'_> { - pub fn visit_type(input: &Type) -> String { + pub fn visit_type(&self, input: &Type) -> String { match input { Type::Address | Type::Field @@ -30,8 +30,22 @@ impl CodeGeneratingVisitor<'_> { | Type::Future(..) | Type::Identifier(..) | Type::Integer(..) => format!("{input}"), - Type::Composite(CompositeType { path, .. }) => { - Self::legalize_path(&path.absolute_path()).expect("path format cannot be legalized at this point") + Type::Composite(composite) => { + let name = composite.path.absolute_path(); + let this_program_name = self.program_id.unwrap().name.name; + let program_name = composite.program.unwrap_or(this_program_name); + + if self.state.symbol_table.lookup_struct(&Location::new(program_name, name.to_vec())).is_some() { + let struct_name = + Self::legalize_path(&name).expect("path format cannot be legalized at this point"); + //if program_name == this_program_name { + struct_name.to_string() + //} else { + // format!("{program_name}.aleo/{struct_name}") + // } + } else { + panic!("Type checking prevents this.") + } } Type::Boolean => { // Leo calls this just `bool`, which isn't what we need. @@ -40,7 +54,7 @@ impl CodeGeneratingVisitor<'_> { Type::Array(array_type) => { format!( "[{}; {}u32]", - Self::visit_type(array_type.element_type()), + self.visit_type(array_type.element_type()), array_type.length.as_u32().expect("length should be known at this point") ) } @@ -81,9 +95,9 @@ impl CodeGeneratingVisitor<'_> { } if let Mode::None = visibility { - Self::visit_type(type_) + self.visit_type(type_) } else { - format!("{}.{visibility}", Self::visit_type(type_)) + format!("{}.{visibility}", self.visit_type(type_)) } } } diff --git a/compiler/passes/src/code_generation/visitor.rs b/compiler/passes/src/code_generation/visitor.rs index adcc9b47278..d2d4a644cb8 100644 --- a/compiler/passes/src/code_generation/visitor.rs +++ b/compiler/passes/src/code_generation/visitor.rs @@ -45,6 +45,7 @@ pub struct CodeGeneratingVisitor<'a> { /// The variant of the function we are currently traversing. pub variant: Option, /// A reference to program. This is needed to look up external programs. + #[allow(dead_code)] pub program: &'a Program, /// The program ID of the current program. pub program_id: Option, @@ -170,13 +171,20 @@ impl CodeGeneratingVisitor<'_> { } // === Case 3: Matches special form like `path::to::Name::[3, 4]` === - let re = regex::Regex::new(r#"^([a-zA-Z_][\w]*)(?:::\[.*?\])?$"#).unwrap(); + /*let re = regex::Regex::new(r#"^([a-zA-Z_][\w]*)(?:::\[.*?\])?$"#).unwrap(); if let Some(captures) = re.captures(&last) { let ident = captures.get(1)?.as_str(); // The produced name here will be of the form: `__AYMqiUeJeQN`. return Some(generate_hashed_name(path, ident)); + }*/ + if let Some(idx) = last.rfind("::[") { + // Extract everything before the array index + let base = &last[..idx]; + // Take the last identifier before the brackets + let base_ident = base.rsplit("::").next().filter(|s| is_legal_identifier(s)).unwrap_or("Indexed"); + return Some(generate_hashed_name(path, base_ident)); } // === Case 4: Matches special form like `path::to::Name?` (last always ends with `?`) === diff --git a/compiler/passes/src/common/symbol_table/mod.rs b/compiler/passes/src/common/symbol_table/mod.rs index 2308bef5edd..e7b67dc0ae0 100644 --- a/compiler/passes/src/common/symbol_table/mod.rs +++ b/compiler/passes/src/common/symbol_table/mod.rs @@ -37,7 +37,7 @@ pub struct SymbolTable { records: IndexMap, /// Structs indexed by a path. - structs: IndexMap, Composite>, + structs: IndexMap, /// Consts that have been successfully evaluated. global_consts: IndexMap, @@ -165,7 +165,7 @@ impl SymbolTable { } /// Iterator over all the structs (not records) in this program. - pub fn iter_structs(&self) -> impl Iterator, &Composite)> { + pub fn iter_structs(&self) -> impl Iterator { self.structs.iter() } @@ -180,8 +180,8 @@ impl SymbolTable { } /// Access the struct by this name if it exists. - pub fn lookup_struct(&self, path: &[Symbol]) -> Option<&Composite> { - self.structs.get(path) + pub fn lookup_struct(&self, location: &Location) -> Option<&Composite> { + self.structs.get(location) } /// Access the record at this location if it exists. @@ -357,19 +357,18 @@ impl SymbolTable { /// Insert a struct at this name. /// /// Since structs are indexed only by name, the program is used only to check shadowing. - pub fn insert_struct(&mut self, program: Symbol, path: &[Symbol], composite: Composite) -> Result<()> { - if let Some(old_composite) = self.structs.get(path) { + pub fn insert_struct(&mut self, location: Location, composite: Composite) -> Result<()> { + /*if let Some(old_composite) = self.structs.get(path) { if eq_struct(&composite, old_composite) { Ok(()) } else { Err(AstError::redefining_external_struct(path.iter().format("::"), old_composite.span).into()) } - } else { - let location = Location::new(program, path.to_vec()); - self.check_shadow_global(&location, composite.identifier.span)?; - self.structs.insert(path.to_vec(), composite); - Ok(()) - } + } else {*/ + self.check_shadow_global(&location, composite.identifier.span)?; + self.structs.insert(location, composite); + Ok(()) + // } } /// Insert a record at this location. @@ -414,7 +413,7 @@ impl SymbolTable { .get(location) .map(|f| f.function.identifier.span) .or_else(|| self.records.get(location).map(|r| r.identifier.span)) - .or_else(|| self.structs.get(&location.path).map(|s| s.identifier.span)) + .or_else(|| self.structs.get(location).map(|s| s.identifier.span)) .or_else(|| self.globals.get(location).map(|g| g.span)) .map_or_else(|| Ok(()), |prev_span| Err(Self::emit_shadow_error(*name, span, prev_span))) } @@ -469,7 +468,7 @@ impl SymbolTable { } } -fn eq_struct(new: &Composite, old: &Composite) -> bool { +/*fn eq_struct(new: &Composite, old: &Composite) -> bool { if new.members.len() != old.members.len() { return false; } @@ -478,4 +477,4 @@ fn eq_struct(new: &Composite, old: &Composite) -> bool { .iter() .zip(old.members.iter()) .all(|(member1, member2)| member1.name() == member2.name() && member1.type_.eq_flat_relaxed(&member2.type_)) -} +}*/ diff --git a/compiler/passes/src/common_subexpression_elimination/visitor.rs b/compiler/passes/src/common_subexpression_elimination/visitor.rs index e828450272b..ff80897caf4 100644 --- a/compiler/passes/src/common_subexpression_elimination/visitor.rs +++ b/compiler/passes/src/common_subexpression_elimination/visitor.rs @@ -16,7 +16,7 @@ use crate::CompilerState; -use leo_ast::{BinaryOperation, Expression, Identifier, LiteralVariant, Node as _, Path, UnaryOperation}; +use leo_ast::{BinaryOperation, Expression, Identifier, LiteralVariant, Node as _, Path, PathKind, UnaryOperation}; use leo_span::Symbol; use std::collections::HashMap; @@ -89,7 +89,7 @@ impl CommonSubexpressionEliminatingVisitor<'_> { let p = Path::new( Vec::new(), Identifier::new(*name, self.state.node_builder.next_id()), - true, + PathKind::Absolute, Some(vec![*name]), path.span(), self.state.node_builder.next_id(), diff --git a/compiler/passes/src/const_prop_unroll_and_morphing.rs b/compiler/passes/src/const_prop_unroll_and_morphing.rs index 53bb4907d5d..2896bf5fb42 100644 --- a/compiler/passes/src/const_prop_unroll_and_morphing.rs +++ b/compiler/passes/src/const_prop_unroll_and_morphing.rs @@ -45,11 +45,14 @@ impl Pass for ConstPropUnrollAndMorphing { let const_prop_output = ConstPropagation::do_pass((), state)?; + println!("before mono {}", state.ast.ast); let monomorphization_output = Monomorphization::do_pass((), state)?; + println!("after mono {}", state.ast.ast); // Clear the symbol table and create it again. This is important because after all the passes above run, the // program may have changed significantly (new functions may have been added, some functions may have been - // deleted, etc.) We do want to retain evaluated consts, so that const propagation can tell when it has evaluated a new one. + // deleted, etc.) We do want to retain evaluated consts, so that const propagation can tell when it has + // evaluated a new one. state.symbol_table.reset_but_consts(); SymbolTableCreation::do_pass((), state)?; diff --git a/compiler/passes/src/const_propagation/visitor.rs b/compiler/passes/src/const_propagation/visitor.rs index bdde0c4d4b7..2ab470242cd 100644 --- a/compiler/passes/src/const_propagation/visitor.rs +++ b/compiler/passes/src/const_propagation/visitor.rs @@ -16,7 +16,7 @@ use crate::CompilerState; -use leo_ast::{Expression, Node, NodeID, interpreter_value::Value}; +use leo_ast::{Expression, Location, Node, NodeID, interpreter_value::Value}; use leo_errors::StaticAnalyzerError; use leo_span::{Span, Symbol}; @@ -64,16 +64,16 @@ impl ConstPropagationVisitor<'_> { pub fn value_to_expression(&self, value: &Value, span: Span, id: NodeID) -> Option { let ty = self.state.type_table.get(&id)?; let symbol_table = &self.state.symbol_table; - let struct_lookup = |sym: &[Symbol]| { + let struct_lookup = |loc: &Location| { symbol_table - .lookup_struct(sym) + .lookup_struct(loc) .unwrap() .members .iter() .map(|mem| (mem.identifier.name, mem.type_.clone())) .collect() }; - value.to_expression(span, &self.state.node_builder, &ty, &struct_lookup) + value.to_expression(span, &self.program, &self.state.node_builder, &ty, &struct_lookup) } pub fn value_to_expression_node(&self, value: &Value, previous: &impl Node) -> Option { diff --git a/compiler/passes/src/flattening/ast.rs b/compiler/passes/src/flattening/ast.rs index 1007bf903ef..6356b48f47e 100644 --- a/compiler/passes/src/flattening/ast.rs +++ b/compiler/passes/src/flattening/ast.rs @@ -107,7 +107,7 @@ impl AstReconstructor for FlatteningVisitor<'_> { let if_true_type = self .state .symbol_table - .lookup_struct(&composite_path.absolute_path()) + .lookup_struct(&Location::new(program, composite_path.absolute_path())) .or_else(|| { self.state.symbol_table.lookup_record(&Location::new(program, composite_path.absolute_path())) }) diff --git a/compiler/passes/src/function_inlining/ast.rs b/compiler/passes/src/function_inlining/ast.rs index c2c87b1634f..c03b890160a 100644 --- a/compiler/passes/src/function_inlining/ast.rs +++ b/compiler/passes/src/function_inlining/ast.rs @@ -29,7 +29,7 @@ impl AstReconstructor for FunctionInliningVisitor<'_> { /* Expressions */ fn reconstruct_call(&mut self, input: CallExpression, _additional: &()) -> (Expression, Self::AdditionalOutput) { // Type checking guarantees that only functions local to the program scope can be inlined. - if input.program.is_some_and(|prog| prog != self.program) { + if input.function.program().is_some_and(|prog| prog != self.program) { return (input.into(), Default::default()); } diff --git a/compiler/passes/src/monomorphization/program.rs b/compiler/passes/src/monomorphization/program.rs index 086d29063bd..b7e9179c8d8 100644 --- a/compiler/passes/src/monomorphization/program.rs +++ b/compiler/passes/src/monomorphization/program.rs @@ -15,7 +15,18 @@ // along with the Leo library. If not, see . use super::MonomorphizationVisitor; -use leo_ast::{AstReconstructor, Module, Program, ProgramReconstructor, ProgramScope, Statement, Variant}; +use leo_ast::{ + AstReconstructor, + Composite, + ConstParameter, + Member, + Module, + Program, + ProgramReconstructor, + ProgramScope, + Statement, + Variant, +}; use leo_span::sym; impl ProgramReconstructor for MonomorphizationVisitor<'_> { @@ -33,7 +44,7 @@ impl ProgramReconstructor for MonomorphizationVisitor<'_> { // Perform monomorphization or other reconstruction logic. let reconstructed_struct = self.reconstruct_struct(r#struct); // Store the reconstructed struct for inclusion in the output scope. - self.reconstructed_structs.insert(struct_name.clone(), reconstructed_struct); + self.reconstructed_structs.insert((self.program, struct_name.clone()), reconstructed_struct); } } @@ -84,8 +95,6 @@ impl ProgramReconstructor for MonomorphizationVisitor<'_> { } } - // Get any - // Now reconstruct mappings and storage variables let mappings = input.mappings.into_iter().map(|(id, mapping)| (id, self.reconstruct_mapping(mapping))).collect(); @@ -129,9 +138,11 @@ impl ProgramReconstructor for MonomorphizationVisitor<'_> { structs: self .reconstructed_structs .iter() - .filter_map(|(path, c)| { + .filter_map(|((program, path), c)| { // only consider structs defined at program scope. The rest will be added to their parent module. - path.split_last().filter(|(_, rest)| rest.is_empty()).map(|(last, _)| (*last, c.clone())) + path.split_last() + .filter(|(_, rest)| rest.is_empty() && *program == input.program_id.name.name) + .map(|(last, _)| (*last, c.clone())) }) .collect(), mappings, @@ -149,7 +160,27 @@ impl ProgramReconstructor for MonomorphizationVisitor<'_> { } } + fn reconstruct_struct(&mut self, input: Composite) -> Composite { + Composite { + const_parameters: input + .const_parameters + .iter() + .map(|param| ConstParameter { type_: self.reconstruct_type(param.type_.clone()).0, ..param.clone() }) + .collect(), + members: input + .members + .iter() + .map(|member| Member { type_: self.reconstruct_type(member.type_.clone()).0, ..member.clone() }) + .collect(), + id: self.state.node_builder.next_id(), // I thought this would reset the scopes :hmmmm: + ..input + } + } + fn reconstruct_program(&mut self, input: Program) -> Program { + let mut reconstructed_programs: indexmap::IndexMap = + input.programs.into_iter().map(|(id, program)| (id, self.reconstruct_program(program))).collect(); + // Populate `self.function_map` using the functions in the program scopes and the modules input .modules @@ -188,17 +219,56 @@ impl ProgramReconstructor for MonomorphizationVisitor<'_> { self.struct_map.insert(full_name, f); }); + let program_scopes = + input.program_scopes.into_iter().map(|(id, scope)| (id, self.reconstruct_program_scope(scope))).collect(); + + let modules = input.modules.into_iter().map(|(id, module)| (id, self.reconstruct_module(module))).collect(); + + for (program_id, program) in reconstructed_programs.iter_mut() { + // Collect all top-level structs that belong to this program + let structs_for_program: Vec<_> = self + .reconstructed_structs + .iter() + .filter_map(|((owner_program, path), c)| { + path.split_last() + .filter(|(_, rest)| rest.is_empty() && *owner_program == *program_id) + .map(|(last, _)| (*last, c.clone())) + }) + .collect(); + + // Insert them directly into each program scope + for (_, scope) in program.program_scopes.iter_mut() { + for (struct_id, strukt) in &structs_for_program { + // Avoid duplicates (e.g., if it was already reconstructed earlier) + if !scope.structs.iter().any(|(id, _)| id == struct_id) { + scope.structs.push((*struct_id, strukt.clone())); + } + } + } + + // Handle structs that belong to modules under this program + for (_, module) in program.modules.iter_mut() { + let structs_for_module: Vec<_> = self + .reconstructed_structs + .iter() + .filter_map(|((owner_program, path), c)| { + path.split_last() + .filter(|(_, rest)| *owner_program == *program_id && *rest == module.path) + .map(|(last, _)| (*last, c.clone())) + }) + .collect(); + + for (struct_id, strukt) in structs_for_module { + if !module.structs.iter().any(|(id, _)| *id == struct_id) { + module.structs.push((struct_id, strukt)); + } + } + } + } + // Reconstruct prrogram scopes first then reconstruct the modules after `self.reconstructed_structs` // and `self.reconstructed_functions` have been populated. - Program { - program_scopes: input - .program_scopes - .into_iter() - .map(|(id, scope)| (id, self.reconstruct_program_scope(scope))) - .collect(), - modules: input.modules.into_iter().map(|(id, module)| (id, self.reconstruct_module(module))).collect(), - ..input - } + Program { programs: reconstructed_programs, program_scopes, modules, ..input } } fn reconstruct_module(&mut self, input: Module) -> Module { @@ -208,9 +278,9 @@ impl ProgramReconstructor for MonomorphizationVisitor<'_> { structs: self .reconstructed_structs .iter() - .filter_map(|(path, c)| path.split_last().map(|(last, rest)| (last, rest, c))) - .filter(|&(_, rest, _)| input.path == rest) - .map(|(last, _, c)| (*last, c.clone())) + .filter_map(|((program, path), c)| path.split_last().map(|(last, rest)| (program, last, rest, c))) + .filter(|&(program, _, rest, _)| input.path == rest && *program == input.program_name) + .map(|(_, last, _, c)| (*last, c.clone())) .collect(), functions: self diff --git a/compiler/passes/src/monomorphization/visitor.rs b/compiler/passes/src/monomorphization/visitor.rs index b986aefa5f1..890828842a1 100644 --- a/compiler/passes/src/monomorphization/visitor.rs +++ b/compiler/passes/src/monomorphization/visitor.rs @@ -44,7 +44,7 @@ pub struct MonomorphizationVisitor<'a> { /// the functions not the names of the monomorphized versions. pub monomorphized_functions: IndexSet>, /// A map of reconstructed functions in the current program scope. - pub reconstructed_structs: IndexMap, Composite>, + pub reconstructed_structs: IndexMap<(Symbol, Vec), Composite>, /// A set of all functions that have been monomorphized at least once. This keeps track of the _original_ names of /// the functions not the names of the monomorphized versions. pub monomorphized_structs: IndexSet>, @@ -79,20 +79,30 @@ impl MonomorphizationVisitor<'_> { // valid identifier in the user code. // // Later, we have to legalize these names because they are not valid Aleo identifiers. We do this in codegen. - let new_struct_path = path.clone().with_updated_last_symbol(leo_span::Symbol::intern(&format!( + /*let new_struct_path = path.clone().with_updated_last_symbol(leo_span::Symbol::intern(&format!( "{}::[{}]", path.identifier().name, const_arguments.iter().format(", ") - ))); + )));*/ + + let new_struct_path = { + let new_name = format!("{}::[{}]", path.identifier().name, const_arguments.iter().format(", ")); + let new_path = path.clone().with_updated_last_symbol(leo_span::Symbol::intern(&new_name)); + new_path + }; // Check if the new struct name is not already present in `reconstructed_structs`. This ensures that we do not // add a duplicate definition for the same struct. - if self.reconstructed_structs.get(&new_struct_path.absolute_path()).is_none() { + if self + .reconstructed_structs + .get(&(path.program().unwrap_or(self.program), new_struct_path.absolute_path())) + .is_none() + { let full_name = path.absolute_path(); // Look up the already reconstructed struct by name. let struct_ = self .reconstructed_structs - .get(&full_name) + .get(&(path.program().unwrap_or(self.program), full_name.clone())) .expect("Struct should already be reconstructed (post-order traversal)."); // Build mapping from const parameters to const argument values. @@ -123,12 +133,13 @@ impl MonomorphizationVisitor<'_> { struct_.id = self.state.node_builder.next_id(); // Keep track of the new struct in case other structs need it. - self.reconstructed_structs.insert(new_struct_path.absolute_path(), struct_); + self.reconstructed_structs + .insert((path.program().unwrap_or(self.program), new_struct_path.absolute_path()), struct_); // Now keep track of the struct we just monomorphized self.monomorphized_structs.insert(full_name); } - new_struct_path + dbg!(new_struct_path) } } diff --git a/compiler/passes/src/option_lowering/ast.rs b/compiler/passes/src/option_lowering/ast.rs index 4e9a97becb7..19b3d25ecdb 100644 --- a/compiler/passes/src/option_lowering/ast.rs +++ b/compiler/passes/src/option_lowering/ast.rs @@ -353,7 +353,7 @@ impl leo_ast::AstReconstructor for OptionLoweringVisitor<'_> { mut input: CallExpression, _additional: &Self::AdditionalInput, ) -> (Expression, Self::AdditionalOutput) { - let callee_program = input.program.unwrap_or(self.program); + let callee_program = input.function.program().unwrap_or(self.program); let func_symbol = self .state @@ -419,7 +419,7 @@ impl leo_ast::AstReconstructor for OptionLoweringVisitor<'_> { .state .symbol_table .lookup_record(&location) - .or_else(|| self.state.symbol_table.lookup_struct(&composite.path.absolute_path())) + .or_else(|| self.state.symbol_table.lookup_struct(&location)) .or_else(|| self.new_structs.get(&composite.path.identifier().name)) .expect("guaranteed by type checking"); diff --git a/compiler/passes/src/option_lowering/mod.rs b/compiler/passes/src/option_lowering/mod.rs index f6f652d6776..07fe3113fd7 100644 --- a/compiler/passes/src/option_lowering/mod.rs +++ b/compiler/passes/src/option_lowering/mod.rs @@ -95,6 +95,8 @@ impl Pass for OptionLowering { visitor.state.handler.last_err()?; visitor.state.ast = ast; + println!("{}", visitor.state.ast.ast); + // We need to recreate the symbol table and run type checking again because this pass may introduce new structs // and modify existing ones. visitor.state.symbol_table = SymbolTable::default(); diff --git a/compiler/passes/src/option_lowering/program.rs b/compiler/passes/src/option_lowering/program.rs index c69683d9b6a..838ad1c7231 100644 --- a/compiler/passes/src/option_lowering/program.rs +++ b/compiler/passes/src/option_lowering/program.rs @@ -54,6 +54,7 @@ impl ProgramReconstructor for OptionLoweringVisitor<'_> { .map(|(id, import)| (id, (self.reconstruct_import(import.0), import.1))) .collect(), stubs: input.stubs.into_iter().map(|(id, stub)| (id, self.reconstruct_stub(stub))).collect(), + programs: input.programs.into_iter().map(|(id, program)| (id, self.reconstruct_program(program))).collect(), program_scopes: input .program_scopes .into_iter() diff --git a/compiler/passes/src/option_lowering/visitor.rs b/compiler/passes/src/option_lowering/visitor.rs index 0888f592286..b96e055d102 100644 --- a/compiler/passes/src/option_lowering/visitor.rs +++ b/compiler/passes/src/option_lowering/visitor.rs @@ -127,10 +127,10 @@ impl OptionLoweringVisitor<'_> { // Instead of relying on the symbol table (which does not get updated in this pass), we rely on the set of // reconstructed structs which is produced for all program scopes and all modules before doing anything else. let reconstructed_structs = &self.reconstructed_structs; - let struct_lookup = |sym: &[Symbol]| { + let struct_lookup = |loc: &Location| { reconstructed_structs - .get(sym) // check the new version of existing structs - .or_else(|| self.new_structs.get(sym.last().unwrap())) // check the newly produced structs + .get(&loc.path) // check the new version of existing structs + .or_else(|| self.new_structs.get(loc.path.last().unwrap())) // check the newly produced structs .expect("must exist by construction") .members .iter() @@ -138,8 +138,14 @@ impl OptionLoweringVisitor<'_> { .collect() }; - let zero_val_expr = - Expression::zero(&lowered_inner_type, Span::default(), &self.state.node_builder, &struct_lookup).expect(""); + let zero_val_expr = Expression::zero( + &lowered_inner_type, + Span::default(), + self.program, + &self.state.node_builder, + &struct_lookup, + ) + .expect(""); // Create or get an optional wrapper struct for `lowered_inner_type` let struct_name = self.insert_optional_wrapper_struct(&lowered_inner_type); diff --git a/compiler/passes/src/pass.rs b/compiler/passes/src/pass.rs index 7c3c2533c3d..3009717205a 100644 --- a/compiler/passes/src/pass.rs +++ b/compiler/passes/src/pass.rs @@ -19,7 +19,7 @@ use crate::{Assigner, SymbolTable, TypeTable}; use leo_ast::{Ast, CallGraph, NetworkName, NodeBuilder, StructGraph}; use leo_errors::{Handler, LeoWarning, Result}; -use std::collections::HashSet; +use std::{collections::HashSet, rc::Rc}; /// Contains data shared by many compiler passes. #[derive(Default)] @@ -31,7 +31,7 @@ pub struct CompilerState { /// Maps node IDs to types. pub type_table: TypeTable, /// Creates incrementing node IDs. - pub node_builder: NodeBuilder, + pub node_builder: Rc, /// Creates unique symbols and definitions. pub assigner: Assigner, /// Contains data about the variables and other entities in the program. diff --git a/compiler/passes/src/processing_async/ast.rs b/compiler/passes/src/processing_async/ast.rs index a38ba7748c1..4f03861a5cd 100644 --- a/compiler/passes/src/processing_async/ast.rs +++ b/compiler/passes/src/processing_async/ast.rs @@ -31,6 +31,7 @@ use leo_ast::{ Location, Node, Path, + PathKind, ProgramVisitor, Statement, TupleAccess, @@ -327,7 +328,7 @@ impl AstReconstructor for ProcessingAsyncVisitor<'_> { function: Path::new( vec![], make_identifier(self, finalize_fn_name), - true, + PathKind::Absolute, Some(vec![finalize_fn_name]), // the finalize function lives in the top level program scope Span::default(), self.state.node_builder.next_id(), diff --git a/compiler/passes/src/processing_script/ast.rs b/compiler/passes/src/processing_script/ast.rs index 59b9018a2c6..d1ef8a5cb64 100644 --- a/compiler/passes/src/processing_script/ast.rs +++ b/compiler/passes/src/processing_script/ast.rs @@ -26,7 +26,7 @@ impl AstReconstructor for ProcessingScriptVisitor<'_> { /* Expressions */ fn reconstruct_call(&mut self, input: CallExpression, _additional: &()) -> (Expression, Self::AdditionalOutput) { if !matches!(self.current_variant, Variant::Script) { - let callee_program = input.program.unwrap_or(self.program_name); + let callee_program = input.function.program().unwrap_or(self.program_name); let Some(func_symbol) = self .state diff --git a/compiler/passes/src/static_analysis/visitor.rs b/compiler/passes/src/static_analysis/visitor.rs index 8402307b364..e265b4cb995 100644 --- a/compiler/passes/src/static_analysis/visitor.rs +++ b/compiler/passes/src/static_analysis/visitor.rs @@ -135,7 +135,7 @@ impl AstVisitor for StaticAnalyzingVisitor<'_> { } // Look up the function and check if it is a non-async call. - let function_program = input.program.unwrap_or(self.current_program); + let function_program = input.function.program().unwrap_or(self.current_program); let func_symbol = self .state diff --git a/compiler/passes/src/static_single_assignment/expression.rs b/compiler/passes/src/static_single_assignment/expression.rs index 4892ee9e744..58337fd84a1 100644 --- a/compiler/passes/src/static_single_assignment/expression.rs +++ b/compiler/passes/src/static_single_assignment/expression.rs @@ -176,7 +176,7 @@ impl ExpressionConsumer for SsaFormingVisitor<'_> { .state .symbol_table .lookup_record(&Location::new(self.program, input.path.absolute_path())) - .or_else(|| self.state.symbol_table.lookup_struct(&input.path.absolute_path())) + .or_else(|| self.state.symbol_table.lookup_struct(&Location::new(self.program, input.path.absolute_path()))) .expect("Type checking guarantees this definition exists."); // Initialize the list of reordered members. diff --git a/compiler/passes/src/static_single_assignment/program.rs b/compiler/passes/src/static_single_assignment/program.rs index cb355753e2e..0294c133a85 100644 --- a/compiler/passes/src/static_single_assignment/program.rs +++ b/compiler/passes/src/static_single_assignment/program.rs @@ -151,6 +151,7 @@ impl ProgramConsumer for SsaFormingVisitor<'_> { .map(|(name, (import, span))| (name, (self.consume_program(import), span))) .collect(), stubs: input.stubs, + programs: input.programs.into_iter().map(|(name, prog)| (name, self.consume_program(prog))).collect(), program_scopes: input .program_scopes .into_iter() diff --git a/compiler/passes/src/storage_lowering/visitor.rs b/compiler/passes/src/storage_lowering/visitor.rs index fdf30d97ea5..d46bf91f5ab 100644 --- a/compiler/passes/src/storage_lowering/visitor.rs +++ b/compiler/passes/src/storage_lowering/visitor.rs @@ -144,16 +144,16 @@ impl StorageLoweringVisitor<'_> { pub fn zero(&self, ty: &Type) -> Expression { // zero value for element type (used as default in get_or_use) let symbol_table = &self.state.symbol_table; - let struct_lookup = |sym: &[Symbol]| { + let struct_lookup = |loc: &Location| { symbol_table - .lookup_struct(sym) + .lookup_struct(loc) .unwrap() .members .iter() .map(|mem| (mem.identifier.name, mem.type_.clone())) .collect() }; - Expression::zero(ty, Span::default(), &self.state.node_builder, &struct_lookup) + Expression::zero(ty, Span::default(), self.program, &self.state.node_builder, &struct_lookup) .expect("zero value generation failed") } } diff --git a/compiler/passes/src/symbol_table_creation/mod.rs b/compiler/passes/src/symbol_table_creation/mod.rs index e94e583d9b8..0e76d6e1e3d 100644 --- a/compiler/passes/src/symbol_table_creation/mod.rs +++ b/compiler/passes/src/symbol_table_creation/mod.rs @@ -14,7 +14,7 @@ // You should have received a copy of the GNU General Public License // along with the Leo library. If not, see . -use crate::{CompilerState, Pass, SymbolTable, VariableSymbol, VariableType}; +use crate::{CompilerState, Pass, VariableSymbol, VariableType}; use leo_ast::{ AstVisitor, @@ -36,9 +36,7 @@ use leo_ast::{ Variant, }; use leo_errors::Result; -use leo_span::{Span, Symbol}; - -use indexmap::IndexMap; +use leo_span::Symbol; /// A pass to fill the SymbolTable. /// @@ -53,13 +51,8 @@ impl Pass for SymbolTableCreation { fn do_pass(_input: Self::Input, state: &mut CompilerState) -> Result { let ast = std::mem::take(&mut state.ast); - let mut visitor = SymbolTableCreationVisitor { - state, - structs: IndexMap::new(), - program_name: Symbol::intern(""), - module: vec![], - is_stub: false, - }; + let mut visitor = + SymbolTableCreationVisitor { state, program_name: Symbol::intern(""), module: vec![], is_stub: false }; visitor.visit_program(ast.as_repr()); visitor.state.handler.last_err()?; visitor.state.ast = ast; @@ -76,8 +69,6 @@ struct SymbolTableCreationVisitor<'a> { module: Vec, /// Whether or not traversing stub. is_stub: bool, - /// The set of local structs that have been successfully visited. - structs: IndexMap, Span>, } impl SymbolTableCreationVisitor<'_> { @@ -143,19 +134,6 @@ impl ProgramVisitor for SymbolTableCreationVisitor<'_> { // Allow up to one local redefinition for each external struct. let full_name = self.module.iter().cloned().chain(std::iter::once(input.name())).collect::>(); - if !input.is_record { - if let Some(prev_span) = self.structs.get(&full_name) { - // The struct already existed - return self.state.handler.emit_err(SymbolTable::emit_shadow_error( - input.identifier.name, - input.identifier.span, - *prev_span, - )); - } - - self.structs.insert(full_name.clone(), input.identifier.span); - } - if input.is_record { // While records are not allowed in submodules, we stll use their full name in the records table. // We don't expect the full name to have more than a single Symbol though. @@ -165,8 +143,14 @@ impl ProgramVisitor for SymbolTableCreationVisitor<'_> { { self.state.handler.emit_err(err); } - } else if let Err(err) = self.state.symbol_table.insert_struct(self.program_name, &full_name, input.clone()) { - self.state.handler.emit_err(err); + } else { + let program_name = input.external.unwrap_or(self.program_name); + + if let Err(err) = + self.state.symbol_table.insert_struct(Location::new(program_name, full_name), input.clone()) + { + self.state.handler.emit_err(err); + } } } @@ -260,10 +244,13 @@ impl ProgramVisitor for SymbolTableCreationVisitor<'_> { { self.state.handler.emit_err(err); } - } else if let Err(err) = - self.state.symbol_table.insert_struct(self.program_name, &[input.name()], input.clone()) - { - self.state.handler.emit_err(err); + } else { + let program_name = input.external.unwrap_or(self.program_name); + if let Err(err) = + self.state.symbol_table.insert_struct(Location::new(program_name, vec![input.name()]), input.clone()) + { + self.state.handler.emit_err(err); + } } } } diff --git a/compiler/passes/src/type_checking/ast.rs b/compiler/passes/src/type_checking/ast.rs index 6b2fd8db52e..52c8ed3d48f 100644 --- a/compiler/passes/src/type_checking/ast.rs +++ b/compiler/passes/src/type_checking/ast.rs @@ -1072,7 +1072,7 @@ impl AstVisitor for TypeCheckingVisitor<'_> { } fn visit_call(&mut self, input: &CallExpression, expected: &Self::AdditionalInput) -> Self::Output { - let callee_program = input.program.or(self.scope_state.program_name).unwrap(); + let callee_program = input.function.program().or(self.scope_state.program_name).unwrap(); let callee_path = input.function.absolute_path(); @@ -1093,7 +1093,10 @@ impl AstVisitor for TypeCheckingVisitor<'_> { ), Variant::Transition | Variant::AsyncTransition if matches!(func.variant, Variant::Transition) - && input.program.is_none_or(|program| program == self.scope_state.program_name.unwrap()) => + && input + .function + .program() + .is_none_or(|program| program == self.scope_state.program_name.unwrap()) => { self.emit_err(TypeCheckerError::cannot_invoke_call_to_local_transition_function(input.span)) } @@ -1102,7 +1105,7 @@ impl AstVisitor for TypeCheckingVisitor<'_> { // Check that the call is not to an external `inline` function. if func.variant == Variant::Inline - && input.program.is_some_and(|program| program != self.scope_state.program_name.unwrap()) + && input.function.program().is_some_and(|program| program != self.scope_state.program_name.unwrap()) { self.emit_err(TypeCheckerError::cannot_call_external_inline_function(input.span)); } @@ -1182,8 +1185,19 @@ impl AstVisitor for TypeCheckingVisitor<'_> { let (mut input_futures, mut inferred_finalize_inputs) = (Vec::new(), Vec::new()); for (expected, argument) in func.input.iter().zip(input.arguments.iter()) { + let mut ty = expected.type_().clone(); + match &mut ty { + Type::Composite(comp) => { + comp.path = comp + .path + .clone() + .into_external(Identifier::new(callee_program, self.state.node_builder.next_id())); + } + _ => {} + } + // Get the type of the expression. If the type is not known, do not attempt to attempt any further inference. - let ty = self.visit_expression(argument, &Some(expected.type_().clone())); + let ty = self.visit_expression(argument, &Some(ty.clone())); if ty == Type::Err { return Type::Err; @@ -1359,7 +1373,9 @@ impl AstVisitor for TypeCheckingVisitor<'_> { } fn visit_struct_init(&mut self, input: &StructExpression, additional: &Self::AdditionalInput) -> Self::Output { - let struct_ = self.lookup_struct(self.scope_state.program_name, &input.path.absolute_path()).clone(); + let program = input.path.program().or(self.scope_state.program_name); + + let struct_ = self.lookup_struct(program, &input.path.absolute_path()).clone(); let Some(struct_) = struct_ else { self.emit_err(TypeCheckerError::unknown_sym("struct or record", input.path.clone(), input.path.span())); return Type::Err; @@ -1380,13 +1396,12 @@ impl AstVisitor for TypeCheckingVisitor<'_> { self.visit_expression(argument, &Some(expected.type_().clone())); } - // Note that it is sufficient for the `program` to be `None` as composite types can only be initialized - // in the program in which they are defined. let type_ = Type::Composite(CompositeType { path: input.path.clone(), const_arguments: input.const_arguments.clone(), - program: None, + program: input.path.program(), }); + self.maybe_assert_type(&type_, additional, input.path.span()); // Check number of struct members. diff --git a/compiler/passes/src/type_checking/mod.rs b/compiler/passes/src/type_checking/mod.rs index a88e7bd0c3c..e5bd3382cbe 100644 --- a/compiler/passes/src/type_checking/mod.rs +++ b/compiler/passes/src/type_checking/mod.rs @@ -89,7 +89,7 @@ impl Pass for TypeChecking { .symbol_table .iter_records() .map(|(loc, _)| loc.path.clone()) - .chain(state.symbol_table.iter_structs().map(|(name, _)| name.clone())) + .chain(state.symbol_table.iter_structs().map(|(loc, _)| loc.path.clone())) .collect(); let function_names = state.symbol_table.iter_functions().map(|(loc, _)| loc.clone()).collect(); diff --git a/compiler/passes/src/type_checking/program.rs b/compiler/passes/src/type_checking/program.rs index 836c82b3d15..7e79f0e6023 100644 --- a/compiler/passes/src/type_checking/program.rs +++ b/compiler/passes/src/type_checking/program.rs @@ -41,6 +41,9 @@ impl ProgramVisitor for TypeCheckingVisitor<'_> { }); self.scope_state.is_stub = false; + // Typecheck the modules. + input.programs.values().for_each(|program| self.visit_program(program)); + // Typecheck the modules. input.modules.values().for_each(|module| self.visit_module(module)); diff --git a/compiler/passes/src/type_checking/visitor.rs b/compiler/passes/src/type_checking/visitor.rs index 28a15100d4e..6d2794b0e7a 100644 --- a/compiler/passes/src/type_checking/visitor.rs +++ b/compiler/passes/src/type_checking/visitor.rs @@ -1458,7 +1458,7 @@ impl TypeCheckingVisitor<'_> { } // Check that the type of the input parameter does not contain an optional. - if self.contains_optional_type(table_type) + /*if self.contains_optional_type(table_type) && matches!(function.variant, Variant::Transition | Variant::AsyncTransition | Variant::Function) { self.emit_err(TypeCheckerError::function_cannot_take_option_as_input( @@ -1466,7 +1466,7 @@ impl TypeCheckingVisitor<'_> { table_type, input.span(), )) - } + }*/ // Make sure only transitions can take a record as an input. if let Type::Composite(struct_) = table_type { @@ -1622,7 +1622,10 @@ impl TypeCheckingVisitor<'_> { pub fn lookup_struct(&mut self, program: Option, name: &[Symbol]) -> Option { let record_comp = program.and_then(|prog| self.state.symbol_table.lookup_record(&Location::new(prog, name.to_vec()))); - let comp = record_comp.or_else(|| self.state.symbol_table.lookup_struct(name)); + let comp = program.and_then(|prog| { + record_comp.or_else(|| self.state.symbol_table.lookup_struct(&Location::new(prog, name.to_vec()))) + }); + // Record the usage. if let Some(s) = comp { // If it's a struct or internal record, mark it used. diff --git a/compiler/passes/src/write_transforming/visitor.rs b/compiler/passes/src/write_transforming/visitor.rs index 98e4ad67725..4cd49556838 100644 --- a/compiler/passes/src/write_transforming/visitor.rs +++ b/compiler/passes/src/write_transforming/visitor.rs @@ -309,7 +309,10 @@ impl WriteTransformingFiller<'_> { .0 .state .symbol_table - .lookup_struct(&comp.path.absolute_path()) + .lookup_struct(&Location::new( + comp.program.unwrap_or(self.0.program), + comp.path.absolute_path(), + )) .or_else(|| { self.0.state.symbol_table.lookup_record(&Location::new( comp.program.unwrap_or(self.0.program), diff --git a/interpreter/src/test_interpreter.rs b/interpreter/src/test_interpreter.rs index b90711f368b..08fa1906b9f 100644 --- a/interpreter/src/test_interpreter.rs +++ b/interpreter/src/test_interpreter.rs @@ -44,9 +44,11 @@ fn whole_compile(source: &str, handler: &Handler, import_stubs: IndexMap for CompilerOptions { fn from(options: BuildOptions) -> Self { @@ -136,14 +136,17 @@ fn handle_build(command: &LeoBuild, context: Context) -> Result< = IndexMap::new(); + let stubs: IndexMap = IndexMap::new(); + let mut programs: IndexMap = IndexMap::new(); for program in package.programs.iter() { - let (bytecode, build_path) = match &program.data { + let (ast_program, bytecode, build_path) = match &program.data { leo_package::ProgramData::Bytecode(bytecode) => { // This was a network dependency or local .aleo dependency, and we have its bytecode. - (bytecode.clone(), imports_directory.join(format!("{}.aleo", program.name))) + (None, Some(bytecode.clone()), imports_directory.join(format!("{}.aleo", program.name))) } leo_package::ProgramData::SourcePath { directory, source } => { // This is a local dependency, so we must compile it. @@ -154,6 +157,21 @@ fn handle_build(command: &LeoBuild, context: Context) -> Result< Result< leo_disassembler::disassemble_from_str::(program.name, &bytecode), - NetworkName::TestnetV0 => leo_disassembler::disassemble_from_str::(program.name, &bytecode), - NetworkName::CanaryV0 => leo_disassembler::disassemble_from_str::(program.name, &bytecode), - }?; - stubs.insert(program.name, stub); + // let stub = match network { + // NetworkName::MainnetV0 => leo_disassembler::disassemble_from_str::(program.name, &bytecode), + // NetworkName::TestnetV0 => leo_disassembler::disassemble_from_str::(program.name, &bytecode), + // NetworkName::CanaryV0 => leo_disassembler::disassemble_from_str::(program.name, &bytecode), + // }?; + // stubs.insert(program.name, stub); + if let Some(ast_program) = &ast_program { + programs.insert(program.name, ast_program.clone()); + } } // SnarkVM expects to find a `program.json` file in the build directory, so make @@ -207,8 +233,10 @@ fn compile_leo_source_directory( is_test: bool, output_path: &Path, handler: &Handler, + node_builder: &Rc, options: BuildOptions, stubs: IndexMap, + programs: IndexMap, network: NetworkName, ) -> Result { // Create a new instance of the Leo compiler. @@ -216,9 +244,11 @@ fn compile_leo_source_directory( Some(program_name.to_string()), is_test, handler.clone(), + Rc::clone(node_builder), output_path.to_path_buf(), Some(options.into()), stubs, + programs, network, ); @@ -251,3 +281,35 @@ fn compile_leo_source_directory( tracing::info!("✅ Compiled '{program_name}.aleo' into Aleo instructions."); Ok(bytecode) } + +/// Compiles a Leo file. Writes and returns the compiled bytecode. +#[allow(clippy::too_many_arguments)] +fn parse_leo_source_directory( + entry_file_path: &Path, + source_directory: &Path, + program_name: Symbol, + is_test: bool, + output_path: &Path, + handler: &Handler, + node_builder: &Rc, + options: BuildOptions, + stubs: IndexMap, + programs: IndexMap, + network: NetworkName, +) -> Result { + // Create a new instance of the Leo compiler. + let mut compiler = Compiler::new( + Some(program_name.to_string()), + is_test, + handler.clone(), + Rc::clone(node_builder), + output_path.to_path_buf(), + Some(options.into()), + stubs, + programs, + network, + ); + + // Parse the Leo program into an AST. + compiler.parse_from_directory(entry_file_path, source_directory) +} diff --git a/tests/expectations/compiler/const_generics/external_generic_struct.out b/tests/expectations/compiler/const_generics/external_generic_struct.out index 9e8f937d508..bddce8155ed 100644 --- a/tests/expectations/compiler/const_generics/external_generic_struct.out +++ b/tests/expectations/compiler/const_generics/external_generic_struct.out @@ -17,13 +17,24 @@ function main: cast r3 into r4 as Bar__5Qh5JlRc8cY; output r4 as Bar__5Qh5JlRc8cY.private; // --- Next Program --- // -import child.aleo; program parent.aleo; +struct Bar__An8FVFF3i8s: + arr as [u32; 2u32]; + struct Bar__DI7sPAg0NJ0: arr as [u32; 5u32]; struct Bar__5Qh5JlRc8cY: arr as [u32; 3u32]; +struct Bar__IEGwqtYsUsS: + arr as [u32; 4u32]; + function main: + cast 10u32 20u32 into r0 as [u32; 2u32]; + cast r0 into r1 as Bar__An8FVFF3i8s; + is.eq r1.arr[0u32] 10u32 into r2; + assert.eq r2 true; + is.eq r1.arr[1u32] 20u32 into r3; + assert.eq r3 true; diff --git a/tests/tests/compiler/const_generics/external_generic_struct.leo b/tests/tests/compiler/const_generics/external_generic_struct.leo index 44f2e3f81f4..fb5f91576b4 100644 --- a/tests/tests/compiler/const_generics/external_generic_struct.leo +++ b/tests/tests/compiler/const_generics/external_generic_struct.leo @@ -18,5 +18,14 @@ program child.aleo { import child.aleo; program parent.aleo { transition main() { + // Construct a Bar::[5] from child.aleo + let bar2 = child.aleo::Bar::[2] { + arr: [10u32, 20u32], + }; + + // Since child::main always builds Bar::[3] from c.arr[1], + // and c.arr = [b.arr[0]; 4], we expect all elements == b.arr[0]. + assert(bar2.arr[0] == 10u32); + assert(bar2.arr[1] == 20u32); } }