Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
221 changes: 207 additions & 14 deletions compiler/passes/src/processing_async/ast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,22 +17,29 @@
use super::ProcessingAsyncVisitor;
use crate::{CompilerState, Replacer};
use indexmap::{IndexMap, IndexSet};
use itertools::Itertools;
use leo_ast::{
AstReconstructor,
AstVisitor,
AsyncExpression,
Block,
CallExpression,
Composite,
Expression,
Function,
Identifier,
Input,
IterationStatement,
Location,
Member,
MemberAccess,
Mode,
Node,
Path,
ProgramVisitor,
Statement,
StructExpression,
StructVariableInitializer,
TupleAccess,
TupleExpression,
TupleType,
Expand Down Expand Up @@ -77,6 +84,166 @@ impl AstVisitor for SymbolAccessCollector<'_> {

impl ProgramVisitor for SymbolAccessCollector<'_> {}

/// Bundle inputs into structs when they exceed max_inputs.
/// Returns: (new_inputs, new_arguments, synthetic_structs, replacements_for_bundled)
/// where replacements_for_bundled maps (Symbol, Option<usize>) to the path to access the field in the bundled struct.
fn bundle_inputs_into_structs(
inputs: Vec<Input>,
arguments: Vec<Expression>,
input_metadata: Vec<(Symbol, Option<usize>)>,
max_inputs: usize,
function_name: Symbol,
node_builder: &mut leo_ast::NodeBuilder,
assigner: &mut crate::Assigner,
) -> (Vec<Input>, Vec<Expression>, Vec<Composite>, IndexMap<(Symbol, Option<usize>), Expression>) {
if inputs.len() <= max_inputs {
return (inputs, arguments, vec![], IndexMap::new());
}

let mut new_inputs = Vec::new();
let mut new_arguments = Vec::new();
let mut synthetic_structs = Vec::new();
let mut replacements = IndexMap::new();

let total_inputs = inputs.len();
let bundle_capacity = max_inputs; // 16 is max
let num_bundles = (total_inputs - max_inputs + bundle_capacity - 2) / (bundle_capacity - 1);
let unbundled_count = max_inputs - num_bundles;

// Keep the first unbundled_count inputs as-is
for (input, arg) in inputs.iter().zip(arguments.iter()).take(unbundled_count) {
new_inputs.push(input.clone());
new_arguments.push(arg.clone());
}

// Bundle remaining inputs into struct(s)
let remaining_inputs: Vec<_> = inputs.into_iter().skip(unbundled_count).collect();
let remaining_arguments: Vec<_> = arguments.into_iter().skip(unbundled_count).collect();
let remaining_metadata: Vec<_> = input_metadata.into_iter().skip(unbundled_count).collect();

for ((chunk_inputs, chunk_args), chunk_metadata) in remaining_inputs
.chunks(bundle_capacity)
.zip(remaining_arguments.chunks(bundle_capacity))
.zip(remaining_metadata.chunks(bundle_capacity))
{
// Generate a unique struct name
let sanitized_fn_name = Symbol::intern(&function_name.to_string().replace('$', "_"));
let struct_name = assigner.unique_symbol(sanitized_fn_name, "_bundle");

let struct_identifier = Identifier {
name: struct_name,
span: Span::default(),
id: node_builder.next_id(),
};

// Create structt members from inputs
let members: Vec<Member> = chunk_inputs
.iter()
.enumerate()
.map(|(i, input)| Member {
mode: Mode::None,
identifier: Identifier {
name: Symbol::intern(&format!("field_{}", i)),
span: Span::default(),
id: node_builder.next_id(),
},
type_: input.type_.clone(),
span: Span::default(),
id: node_builder.next_id(),
})
.collect();

// Create the synthetic struct definition
let composite = Composite {
identifier: struct_identifier,
const_parameters: vec![],
members: members.clone(),
external: None,
is_record: false,
span: Span::default(),
id: node_builder.next_id(),
};

synthetic_structs.push(composite);

// Create an input parameter for this bundled struct
// Note: program should be None for local structs - the program context comes from scope during type checking
let composite_type = Type::Composite(leo_ast::CompositeType {
path: Path::from(struct_identifier).into_absolute(),
const_arguments: vec![],
program: None,
});

let param_name = assigner.unique_symbol(struct_name, "_param");

let bundle_input = Input {
identifier: Identifier {
name: param_name,
span: Span::default(),
id: node_builder.next_id(),
},
mode: Mode::None,
type_: composite_type,
span: Span::default(),
id: node_builder.next_id(),
};

new_inputs.push(bundle_input);

// Create a struct initialization expression as the argument
let struct_members: Vec<StructVariableInitializer> = chunk_inputs
.iter()
.zip(chunk_args.iter())
.enumerate()
.map(|(i, (_, arg))| StructVariableInitializer {
identifier: Identifier {
name: Symbol::intern(&format!("field_{}", i)),
span: Span::default(),
id: node_builder.next_id(),
},
expression: Some(arg.clone()),
span: Span::default(),
id: node_builder.next_id(),
})
.collect();

let struct_init = StructExpression {
path: Path::from(struct_identifier).into_absolute(),
const_arguments: vec![],
members: struct_members,
span: Span::default(),
id: node_builder.next_id(),
};

new_arguments.push(struct_init.into());

// Track replacements, the original input names should now be accessed via struct.field_N
for (i, (original_symbol, original_index)) in chunk_metadata.iter().enumerate() {
// Create a member access expression: param_name.field_i
let member_access = MemberAccess {
inner: Path::from(Identifier {
name: param_name,
span: Span::default(),
id: node_builder.next_id(),
})
.into(),
name: Identifier {
name: Symbol::intern(&format!("field_{}", i)),
span: Span::default(),
id: node_builder.next_id(),
},
span: Span::default(),
id: node_builder.next_id(),
}
.into();

replacements.insert((*original_symbol, *original_index), member_access);
}
}

(new_inputs, new_arguments, synthetic_structs, replacements)
}

impl AstReconstructor for ProcessingAsyncVisitor<'_> {
type AdditionalInput = ();
type AdditionalOutput = ();
Expand Down Expand Up @@ -245,7 +412,7 @@ impl AstReconstructor for ProcessingAsyncVisitor<'_> {
};

// Step 3: Resolve symbol accesses into inputs and call arguments
let (inputs, arguments): (Vec<_>, Vec<_>) = access_collector
let inputs_and_args_with_metadata: Vec<_> = access_collector
.symbol_accesses
.iter()
.filter_map(|(path, index)| {
Expand All @@ -265,12 +432,42 @@ impl AstReconstructor for ProcessingAsyncVisitor<'_> {

// All other variables become parameters to the async function being built.
let var = self.state.symbol_table.lookup_local(local_var_name)?;
Some(make_inputs_and_arguments(self, local_var_name, &var.type_, *index))
let inputs_and_args = make_inputs_and_arguments(self, local_var_name, &var.type_, *index);

// For each (Input, Expression) pair, attach metadata (symbol, index)
Some(
inputs_and_args
.into_iter()
.map(|(input, arg)| (input, arg, (local_var_name, *index)))
.collect::<Vec<_>>()
)
})
.flatten()
.unzip();
.collect();

// Separate into parallel vectors
let (inputs, arguments, input_metadata): (Vec<_>, Vec<_>, Vec<_>) =
inputs_and_args_with_metadata.into_iter()
.map(|(input, arg, metadata)| (input, arg, metadata))
.multiunzip();

// Step 4: Replacement logic used to patch the async block
// Step 4: Bundle inputs if necessary
let (final_inputs, final_arguments, synthetic_structs, bundle_replacements) = bundle_inputs_into_structs(
inputs,
arguments,
input_metadata,
self.max_inputs,
finalize_fn_name,
&mut self.state.node_builder,
&mut self.state.assigner,
);

// If there are bundle replacements, merge them into the main replacements map
for (key, expr) in bundle_replacements {
replacements.insert(key, expr);
}

// Step 5: Reconstruct the block with replaced references
let replace_expr = |expr: &Expression| -> Expression {
match expr {
Expression::Path(path) => {
Expand All @@ -292,17 +489,13 @@ impl AstReconstructor for ProcessingAsyncVisitor<'_> {
}
};

// Step 5: Reconstruct the block with replaced references
let mut replacer = Replacer::new(replace_expr, true /* refresh IDs */, self.state);
let new_block = replacer.reconstruct_block(input.block.clone()).0;

// Ensure we're not trying to capture too many variables
if inputs.len() > self.max_inputs {
self.state.handler.emit_err(leo_errors::StaticAnalyzerError::async_block_capturing_too_many_vars(
inputs.len(),
self.max_inputs,
input.span,
));
// Register synthetic structs
for composite in synthetic_structs {
let struct_name = composite.name();
self.synthetic_structs.push((struct_name, composite));
}

// Step 6: Define the new async function
Expand All @@ -311,7 +504,7 @@ impl AstReconstructor for ProcessingAsyncVisitor<'_> {
variant: Variant::AsyncFunction,
identifier: make_identifier(self, finalize_fn_name),
const_parameters: vec![],
input: inputs,
input: final_inputs.clone(),
output: vec![], // `async function`s can't have returns
output_type: Type::Unit, // Always the case for `async function`s
block: new_block,
Expand All @@ -333,7 +526,7 @@ impl AstReconstructor for ProcessingAsyncVisitor<'_> {
self.state.node_builder.next_id(),
),
const_arguments: vec![],
arguments,
arguments: final_arguments,
program: Some(self.current_program),
span: input.span,
id: self.state.node_builder.next_id(),
Expand Down
1 change: 1 addition & 0 deletions compiler/passes/src/processing_async/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ impl Pass for ProcessingAsync {
current_program: Symbol::intern(""),
current_function: Symbol::intern(""),
new_async_functions: Vec::new(),
synthetic_structs: Vec::new(),
modified: false,
};
ast.ast = visitor.reconstruct_program(ast.ast);
Expand Down
8 changes: 7 additions & 1 deletion compiler/passes/src/processing_async/program.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,15 @@ impl ProgramReconstructor for ProcessingAsyncVisitor<'_> {
// Now append all newly created `async` functions. This ensures transition functions still show up before all other functions.
reconstructed_functions.append(&mut self.new_async_functions);

// Reconstruct existing structs and append synthetic structs created for bundling captured variables
let mut all_structs: Vec<_> =
input.structs.into_iter().map(|(id, def)| (id, self.reconstruct_struct(def))).collect();

all_structs.append(&mut self.synthetic_structs);

ProgramScope {
program_id: input.program_id,
structs: input.structs.into_iter().map(|(id, def)| (id, self.reconstruct_struct(def))).collect(),
structs: all_structs,
mappings: input.mappings.into_iter().map(|(id, mapping)| (id, self.reconstruct_mapping(mapping))).collect(),
functions: reconstructed_functions,
constructor: input.constructor,
Expand Down
4 changes: 3 additions & 1 deletion compiler/passes/src/processing_async/visitor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

use crate::CompilerState;

use leo_ast::{Function, NodeID};
use leo_ast::{Composite, Function, NodeID};
use leo_span::Symbol;

pub struct ProcessingAsyncVisitor<'a> {
Expand All @@ -30,6 +30,8 @@ pub struct ProcessingAsyncVisitor<'a> {
pub current_function: Symbol,
/// A map of reconstructed functions in the current program scope.
pub new_async_functions: Vec<(Symbol, Function)>,
/// Synthetic structs created to bundle captured variables when an async block captures more than MAX_INPUTS.
pub synthetic_structs: Vec<(Symbol, Composite)>,
/// Indicates whether this pass actually processed any async blocks.
pub modified: bool,
}
Expand Down
Loading