Skip to content

Commit

Permalink
fix(sea): mem2reg to treat block input references as alias (#6452)
Browse files Browse the repository at this point in the history
  • Loading branch information
aakoshh authored Nov 5, 2024
1 parent 8b5afec commit 5310064
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 25 deletions.
8 changes: 4 additions & 4 deletions compiler/noirc_evaluator/src/ssa.rs
Original file line number Diff line number Diff line change
Expand Up @@ -99,19 +99,19 @@ pub(crate) fn optimize_into_acir(
.run_pass(Ssa::resolve_is_unconstrained, "After Resolving IsUnconstrained:")
.run_pass(|ssa| ssa.inline_functions(options.inliner_aggressiveness), "After Inlining:")
// Run mem2reg with the CFG separated into blocks
.run_pass(Ssa::mem2reg, "After Mem2Reg:")
.run_pass(Ssa::simplify_cfg, "After Simplifying:")
.run_pass(Ssa::mem2reg, "After Mem2Reg (1st):")
.run_pass(Ssa::simplify_cfg, "After Simplifying (1st):")
.run_pass(Ssa::as_slice_optimization, "After `as_slice` optimization")
.try_run_pass(
Ssa::evaluate_static_assert_and_assert_constant,
"After `static_assert` and `assert_constant`:",
)?
.try_run_pass(Ssa::unroll_loops_iteratively, "After Unrolling:")?
.run_pass(Ssa::simplify_cfg, "After Simplifying:")
.run_pass(Ssa::simplify_cfg, "After Simplifying (2nd):")
.run_pass(Ssa::flatten_cfg, "After Flattening:")
.run_pass(Ssa::remove_bit_shifts, "After Removing Bit Shifts:")
// Run mem2reg once more with the flattened CFG to catch any remaining loads/stores
.run_pass(Ssa::mem2reg, "After Mem2Reg:")
.run_pass(Ssa::mem2reg, "After Mem2Reg (2nd):")
// Run the inlining pass again to handle functions with `InlineType::NoPredicates`.
// Before flattening is run, we treat functions marked with the `InlineType::NoPredicates` as an entry point.
// This pass must come immediately following `mem2reg` as the succeeding passes
Expand Down
56 changes: 39 additions & 17 deletions compiler/noirc_evaluator/src/ssa/opt/mem2reg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -212,37 +212,32 @@ impl<'f> PerFunctionContext<'f> {
all_terminator_values: &HashSet<ValueId>,
per_func_block_params: &HashSet<ValueId>,
) -> bool {
let func_params = self.inserter.function.parameters();
let reference_parameters = func_params
.iter()
.filter(|param| self.inserter.function.dfg.value_is_reference(**param))
.collect::<BTreeSet<_>>();
let reference_parameters = self.reference_parameters();

let mut store_alias_used = false;
if let Some(expression) = block.expressions.get(store_address) {
if let Some(aliases) = block.aliases.get(expression) {
let allocation_aliases_parameter =
aliases.any(|alias| reference_parameters.contains(&alias));
if allocation_aliases_parameter == Some(true) {
store_alias_used = true;
return true;
}

let allocation_aliases_parameter =
aliases.any(|alias| per_func_block_params.contains(&alias));
if allocation_aliases_parameter == Some(true) {
store_alias_used = true;
return true;
}

let allocation_aliases_parameter =
aliases.any(|alias| self.calls_reference_input.contains(&alias));
if allocation_aliases_parameter == Some(true) {
store_alias_used = true;
return true;
}

let allocation_aliases_parameter =
aliases.any(|alias| all_terminator_values.contains(&alias));
if allocation_aliases_parameter == Some(true) {
store_alias_used = true;
return true;
}

let allocation_aliases_parameter = aliases.any(|alias| {
Expand All @@ -252,14 +247,25 @@ impl<'f> PerFunctionContext<'f> {
false
}
});

if allocation_aliases_parameter == Some(true) {
store_alias_used = true;
return true;
}
}
}

store_alias_used
false
}

/// Collect the input parameters of the function which are of reference type.
/// All references are mutable, so these inputs are shared with the function caller
/// and thus stores should not be eliminated, even if the blocks in this function
/// don't use them anywhere.
fn reference_parameters(&self) -> BTreeSet<ValueId> {
let parameters = self.inserter.function.parameters().iter();
parameters
.filter(|param| self.inserter.function.dfg.value_is_reference(**param))
.copied()
.collect()
}

fn recursively_add_values(&self, value: ValueId, set: &mut HashSet<ValueId>) {
Expand Down Expand Up @@ -300,6 +306,8 @@ impl<'f> PerFunctionContext<'f> {
fn analyze_block(&mut self, block: BasicBlockId, mut references: Block) {
let instructions = self.inserter.function.dfg[block].take_instructions();

self.add_aliases_for_reference_parameters(block, &mut references);

for instruction in instructions {
self.analyze_instruction(block, &mut references, instruction);
}
Expand All @@ -316,13 +324,25 @@ impl<'f> PerFunctionContext<'f> {
self.blocks.insert(block, references);
}

/// Add a self-alias for input parameters, similarly to how a newly allocated reference has
/// one alias already - itself. If we don't, then the checks using `reference_parameters()`
/// might find the default (empty) aliases and think the an input reference can be removed.
fn add_aliases_for_reference_parameters(&self, block: BasicBlockId, references: &mut Block) {
let dfg = &self.inserter.function.dfg;
let params = dfg.block_parameters(block);
let params = params.iter().filter(|p| dfg.value_is_reference(**p));

for param in params {
let expression =
references.expressions.entry(*param).or_insert(Expression::Other(*param));
references.aliases.entry(expression.clone()).or_insert_with(|| AliasSet::known(*param));
}
}

/// Add all instructions in `last_stores` to `self.instructions_to_remove` which do not
/// possibly alias any parameters of the given function.
fn remove_stores_that_do_not_alias_parameters(&mut self, references: &Block) {
let parameters = self.inserter.function.parameters().iter();
let reference_parameters = parameters
.filter(|param| self.inserter.function.dfg.value_is_reference(**param))
.collect::<BTreeSet<_>>();
let reference_parameters = self.reference_parameters();

for (allocation, instruction) in &references.last_stores {
if let Some(expression) = references.expressions.get(allocation) {
Expand Down Expand Up @@ -466,6 +486,8 @@ impl<'f> PerFunctionContext<'f> {
}
}

/// If `array` is an array constant that contains reference types, then insert each element
/// as a potential alias to the array itself.
fn check_array_aliasing(&self, references: &mut Block, array: ValueId) {
if let Some((elements, typ)) = self.inserter.function.dfg.get_array_constant(array) {
if Self::contains_references(&typ) {
Expand Down
18 changes: 14 additions & 4 deletions test_programs/execution_success/references/src/main.nr
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
fn main(mut x: Field) {
add1(&mut x);
assert(x == 3);

let mut s = S { y: x };
s.add2();
assert(s.y == 5);
Expand All @@ -21,18 +20,16 @@ fn main(mut x: Field) {
let mut c = C { foo: 0, bar: &mut C2 { array: &mut [1, 2] } };
*c.bar.array = [3, 4];
assert(*c.bar.array == [3, 4]);

regression_1887();
regression_2054();
regression_2030();
regression_2255();

regression_6443();
assert(x == 3);
regression_2218_if_inner_if(x, 10);
regression_2218_if_inner_else(20, x);
regression_2218_else(x, 3);
regression_2218_loop(x, 10);

regression_2560(s_ref);
}

Expand Down Expand Up @@ -106,6 +103,7 @@ fn regression_2030() {
let _ = *array[0];
*array[0] = 1;
}

// The `mut x: &mut ...` caught a bug handling lvalues where a double-dereference would occur internally
// in one step rather than being tracked by two separate steps. This lead to assigning the 1 value to the
// incorrect outer `mut` reference rather than the correct `&mut` reference.
Expand All @@ -119,6 +117,18 @@ fn regression_2255_helper(mut x: &mut Field) {
*x = 1;
}

// Similar to `regression_2255` but without the double-dereferencing.
// The test checks that `mem2reg` does not eliminate storing to a reference passed as a parameter.
fn regression_6443() {
let x = &mut 0;
regression_6443_helper(x);
assert(*x == 1);
}

fn regression_6443_helper(x: &mut Field) {
*x = 1;
}

fn regression_2218(x: Field, y: Field) -> Field {
let q = &mut &mut 0;
let q1 = *q;
Expand Down

0 comments on commit 5310064

Please sign in to comment.