From 924f92b58790b76c3d552d858356ecb8685a0313 Mon Sep 17 00:00:00 2001 From: guipublic Date: Thu, 10 Oct 2024 11:32:50 +0000 Subject: [PATCH 1/3] try to use big-add gates --- .../opcodes/black_box_function_call.rs | 12 ++ .../compiler/optimizers/merge_expressions.rs | 202 ++++++++++++++++++ acvm-repo/acvm/src/compiler/optimizers/mod.rs | 2 + .../acvm/src/compiler/transformers/mod.rs | 16 +- 4 files changed, 230 insertions(+), 2 deletions(-) create mode 100644 acvm-repo/acvm/src/compiler/optimizers/merge_expressions.rs diff --git a/acvm-repo/acir/src/circuit/opcodes/black_box_function_call.rs b/acvm-repo/acir/src/circuit/opcodes/black_box_function_call.rs index 8bb9a680ea9..763302cab23 100644 --- a/acvm-repo/acir/src/circuit/opcodes/black_box_function_call.rs +++ b/acvm-repo/acir/src/circuit/opcodes/black_box_function_call.rs @@ -1,3 +1,5 @@ +use std::collections::HashSet; + use crate::native_types::Witness; use crate::{AcirField, BlackBoxFunc}; @@ -423,6 +425,16 @@ impl BlackBoxFuncCall { BlackBoxFuncCall::BigIntToLeBytes { outputs, .. } => outputs.to_vec(), } } + + pub fn get_input_witnesses(&self) -> HashSet { + let mut result = HashSet::new(); + for input in self.get_inputs_vec() { + if let ConstantOrWitnessEnum::Witness(w) = input.input() { + result.insert(w); + } + } + result + } } const ABBREVIATION_LIMIT: usize = 5; diff --git a/acvm-repo/acvm/src/compiler/optimizers/merge_expressions.rs b/acvm-repo/acvm/src/compiler/optimizers/merge_expressions.rs new file mode 100644 index 00000000000..4291b997fde --- /dev/null +++ b/acvm-repo/acvm/src/compiler/optimizers/merge_expressions.rs @@ -0,0 +1,202 @@ +use std::collections::{HashMap, HashSet}; + +use acir::{ + circuit::{brillig::BrilligInputs, directives::Directive, opcodes::BlockId, Circuit, Opcode}, + native_types::{Expression, Witness}, + AcirField, +}; + +pub(crate) struct MergeExpressionsOptimizer { + resolved_blocks: HashMap>, +} + +impl MergeExpressionsOptimizer { + pub(crate) fn new() -> Self { + MergeExpressionsOptimizer { resolved_blocks: HashMap::new() } + } + /// This pass analyzes the circuit and identifies intermediate variables that are + /// only used in two gates. It then merges the gate that produces the + /// intermediate variable into the second one that uses it + /// Note: This pass is only relevant for backends that can handle unlimited width + pub(crate) fn eliminate_intermediate_variable( + &mut self, + circuit: &Circuit, + acir_opcode_positions: Vec, + ) -> (Vec>, Vec) { + // Keep track, for each witness, of the gates that use it + let circuit_inputs = circuit.circuit_arguments(); + self.resolved_blocks = HashMap::new(); + let mut used_witness: HashMap> = HashMap::new(); + for (i, opcode) in circuit.opcodes.iter().enumerate() { + let witnesses = self.witness_inputs(opcode); + if let Opcode::MemoryInit { block_id, .. } = opcode { + self.resolved_blocks.insert(*block_id, witnesses.clone()); + } + for w in witnesses { + // We do not simplify circuit inputs + if !circuit_inputs.contains(&w) { + used_witness.entry(w).or_default().insert(i); + } + } + } + + let mut modified_gates: HashMap> = HashMap::new(); + let mut new_circuit = Vec::new(); + let mut new_acir_opcode_positions = Vec::new(); + // For each opcode, try to get a target opcode to merge with + for (i, opcode) in circuit.opcodes.iter().enumerate() { + if !matches!(opcode, Opcode::AssertZero(_)) { + new_circuit.push(opcode.clone()); + new_acir_opcode_positions.push(acir_opcode_positions[i]); + continue; + } + let opcode = modified_gates.get(&i).unwrap_or(opcode).clone(); + let mut to_keep = true; + let input_witnesses = self.witness_inputs(&opcode); + for w in input_witnesses.clone() { + let empty_gates = HashSet::new(); + let gates_using_w = used_witness.get(&w).unwrap_or(&empty_gates); + // We only consider witness which are used in exactly two arithmetic gates + if gates_using_w.len() == 2 { + let gates_using_w: Vec<_> = gates_using_w.iter().collect(); + let mut b = *gates_using_w[1]; + if b == i { + b = *gates_using_w[0]; + } else { + // sanity check + assert!(i == *gates_using_w[0]); + } + let second_gate = modified_gates.get(&b).unwrap_or(&circuit.opcodes[b]).clone(); + if let (Opcode::AssertZero(expr_define), Opcode::AssertZero(expr_use)) = + (opcode.clone(), second_gate) + { + if let Some(expr) = Self::merge(&expr_use, &expr_define, w) { + // sanity check + assert!(i < b); + modified_gates.insert(b, Opcode::AssertZero(expr)); + to_keep = false; + // Update the 'used_witness' map to account for the merge. + for w2 in Self::expr_wit(&expr_define) { + if !circuit_inputs.contains(&w2) { + let mut v = used_witness[&w2].clone(); + v.insert(b); + v.remove(&i); + used_witness.insert(w2, v); + } + } + // We need to stop here and continue with the next opcode + // because the merge invalidate the current opcode + break; + } + } + } + } + + if to_keep { + if modified_gates.contains_key(&i) { + new_circuit.push(modified_gates[&i].clone()); + } else { + new_circuit.push(opcode.clone()); + } + new_acir_opcode_positions.push(acir_opcode_positions[i]); + } + } + (new_circuit, new_acir_opcode_positions) + } + + fn expr_wit(expr: &Expression) -> HashSet { + let mut result = HashSet::new(); + result.extend(expr.mul_terms.iter().flat_map(|i| vec![i.1, i.2])); + result.extend(expr.linear_combinations.iter().map(|i| i.1)); + result + } + + fn brillig_input_wit(&self, input: &BrilligInputs) -> HashSet { + let mut result = HashSet::new(); + match input { + BrilligInputs::Single(expr) => { + result.extend(Self::expr_wit(expr)); + } + BrilligInputs::Array(exprs) => { + for expr in exprs { + result.extend(Self::expr_wit(expr)); + } + } + BrilligInputs::MemoryArray(block_id) => { + let witnesses = self.resolved_blocks.get(block_id).expect("Unknown block id"); + result.extend(witnesses); + } + } + result + } + + // Returns the input witnesses used by the opcode + fn witness_inputs(&self, opcode: &Opcode) -> HashSet { + let mut witnesses = HashSet::new(); + match opcode { + Opcode::AssertZero(expr) => Self::expr_wit(expr), + Opcode::BlackBoxFuncCall(bb_func) => bb_func.get_input_witnesses(), + Opcode::Directive(Directive::ToLeRadix { a, .. }) => Self::expr_wit(a), + Opcode::MemoryOp { block_id: _, op, predicate } => { + //index et value, et predicate + let mut witnesses = HashSet::new(); + witnesses.extend(Self::expr_wit(&op.index)); + witnesses.extend(Self::expr_wit(&op.value)); + if let Some(p) = predicate { + witnesses.extend(Self::expr_wit(p)); + } + witnesses + } + + Opcode::MemoryInit { block_id: _, init, block_type: _ } => { + init.iter().cloned().collect() + } + Opcode::BrilligCall { inputs, .. } => { + for i in inputs { + witnesses.extend(self.brillig_input_wit(i)); + } + witnesses + } + Opcode::Call { id: _, inputs, outputs: _, predicate } => { + for i in inputs { + witnesses.insert(*i); + } + if let Some(p) = predicate { + witnesses.extend(Self::expr_wit(p)); + } + witnesses + } + } + } + + // Merge 'expr' into 'target' via Gaussian elimination on 'w' + // Returns None if the expressions cannot be merged + fn merge( + target: &Expression, + expr: &Expression, + w: Witness, + ) -> Option> { + // Check that the witness is not part of multiplication terms + for m in &target.mul_terms { + if m.1 == w || m.2 == w { + return None; + } + } + for m in &expr.mul_terms { + if m.1 == w || m.2 == w { + return None; + } + } + + for k in &target.linear_combinations { + if k.1 == w { + for i in &expr.linear_combinations { + if i.1 == w { + return Some(target.add_mul(-(k.0 / i.0), expr)); + } + } + } + } + None + } +} diff --git a/acvm-repo/acvm/src/compiler/optimizers/mod.rs b/acvm-repo/acvm/src/compiler/optimizers/mod.rs index e20ad97a108..1947a80dc35 100644 --- a/acvm-repo/acvm/src/compiler/optimizers/mod.rs +++ b/acvm-repo/acvm/src/compiler/optimizers/mod.rs @@ -5,10 +5,12 @@ use acir::{ // mod constant_backpropagation; mod general; +mod merge_expressions; mod redundant_range; mod unused_memory; pub(crate) use general::GeneralOptimizer; +pub(crate) use merge_expressions::MergeExpressionsOptimizer; pub(crate) use redundant_range::RangeOptimizer; use tracing::info; diff --git a/acvm-repo/acvm/src/compiler/transformers/mod.rs b/acvm-repo/acvm/src/compiler/transformers/mod.rs index 4fd8ba7883f..305e71fc1c1 100644 --- a/acvm-repo/acvm/src/compiler/transformers/mod.rs +++ b/acvm-repo/acvm/src/compiler/transformers/mod.rs @@ -9,7 +9,9 @@ mod csat; pub(crate) use csat::CSatTransformer; -use super::{transform_assert_messages, AcirTransformationMap}; +use super::{ + optimizers::MergeExpressionsOptimizer, transform_assert_messages, AcirTransformationMap, +}; /// Applies [`ProofSystemCompiler`][crate::ProofSystemCompiler] specific optimizations to a [`Circuit`]. pub fn transform( @@ -165,6 +167,16 @@ pub(super) fn transform_internal( // The transformer does not add new public inputs ..acir }; - + let mut merge_optimizer = MergeExpressionsOptimizer::new(); + let (opcodes, new_acir_opcode_positions) = + merge_optimizer.eliminate_intermediate_variable(&acir, new_acir_opcode_positions); + // n.b. we do not update current_witness_index after the eliminate_intermediate_variable pass, the real index could be less. + let acir = Circuit { + current_witness_index, + expression_width, + opcodes, + // The optimizer does not add new public inputs + ..acir + }; (acir, new_acir_opcode_positions) } From 6a6b59b1090363b502212fcd15578da307f6dd9e Mon Sep 17 00:00:00 2001 From: guipublic Date: Thu, 24 Oct 2024 17:06:10 +0000 Subject: [PATCH 2/3] acir experimental optimization --- acvm-repo/acvm/src/compiler/mod.rs | 9 +- .../compiler/optimizers/merge_expressions.rs | 242 +++++++++++++++--- .../acvm/src/compiler/transformers/mod.rs | 44 +++- compiler/noirc_driver/src/lib.rs | 5 + compiler/noirc_evaluator/src/ssa.rs | 5 +- .../noirc_evaluator/src/ssa/acir_gen/mod.rs | 21 +- compiler/noirc_evaluator/src/ssa/opt/die.rs | 23 +- .../src/ssa/ssa_gen/program.rs | 4 + 8 files changed, 306 insertions(+), 47 deletions(-) diff --git a/acvm-repo/acvm/src/compiler/mod.rs b/acvm-repo/acvm/src/compiler/mod.rs index 5ece3d19a6e..40b1b585ef1 100644 --- a/acvm-repo/acvm/src/compiler/mod.rs +++ b/acvm-repo/acvm/src/compiler/mod.rs @@ -73,11 +73,16 @@ fn transform_assert_messages( pub fn compile( acir: Circuit, expression_width: ExpressionWidth, + experimental_optimization: bool, ) -> (Circuit, AcirTransformationMap) { let (acir, acir_opcode_positions) = optimize_internal(acir); - let (mut acir, acir_opcode_positions) = - transform_internal(acir, expression_width, acir_opcode_positions); + let (mut acir, acir_opcode_positions) = transform_internal( + acir, + expression_width, + acir_opcode_positions, + experimental_optimization, + ); let transformation_map = AcirTransformationMap::new(acir_opcode_positions); diff --git a/acvm-repo/acvm/src/compiler/optimizers/merge_expressions.rs b/acvm-repo/acvm/src/compiler/optimizers/merge_expressions.rs index 4291b997fde..d30efcd4028 100644 --- a/acvm-repo/acvm/src/compiler/optimizers/merge_expressions.rs +++ b/acvm-repo/acvm/src/compiler/optimizers/merge_expressions.rs @@ -6,23 +6,23 @@ use acir::{ AcirField, }; -pub(crate) struct MergeExpressionsOptimizer { +pub(crate) struct MergeExpressionsOptimizer { resolved_blocks: HashMap>, + + modified_gates: HashMap>, + deleted_gates: HashSet, } -impl MergeExpressionsOptimizer { +impl MergeExpressionsOptimizer { pub(crate) fn new() -> Self { - MergeExpressionsOptimizer { resolved_blocks: HashMap::new() } + MergeExpressionsOptimizer { + resolved_blocks: HashMap::new(), + modified_gates: HashMap::new(), + deleted_gates: HashSet::new(), + } } - /// This pass analyzes the circuit and identifies intermediate variables that are - /// only used in two gates. It then merges the gate that produces the - /// intermediate variable into the second one that uses it - /// Note: This pass is only relevant for backends that can handle unlimited width - pub(crate) fn eliminate_intermediate_variable( - &mut self, - circuit: &Circuit, - acir_opcode_positions: Vec, - ) -> (Vec>, Vec) { + + fn compute_used_witness(&mut self, circuit: &Circuit) -> HashMap> { // Keep track, for each witness, of the gates that use it let circuit_inputs = circuit.circuit_arguments(); self.resolved_blocks = HashMap::new(); @@ -39,6 +39,20 @@ impl MergeExpressionsOptimizer { } } } + used_witness + } + + /// This pass analyzes the circuit and identifies intermediate variables that are + /// only used in two gates. It then merges the gate that produces the + /// intermediate variable into the second one that uses it + /// Note: This pass is only relevant for backends that can handle unlimited width + pub(crate) fn eliminate_intermediate_variable( + &mut self, + circuit: &Circuit, + acir_opcode_positions: Vec, + ) -> (Vec>, Vec) { + let mut used_witness = self.compute_used_witness(circuit); + let circuit_inputs = circuit.circuit_arguments(); let mut modified_gates: HashMap> = HashMap::new(); let mut new_circuit = Vec::new(); @@ -104,14 +118,14 @@ impl MergeExpressionsOptimizer { (new_circuit, new_acir_opcode_positions) } - fn expr_wit(expr: &Expression) -> HashSet { + fn expr_wit(expr: &Expression) -> HashSet { let mut result = HashSet::new(); result.extend(expr.mul_terms.iter().flat_map(|i| vec![i.1, i.2])); result.extend(expr.linear_combinations.iter().map(|i| i.1)); result } - fn brillig_input_wit(&self, input: &BrilligInputs) -> HashSet { + fn brillig_input_wit(&self, input: &BrilligInputs) -> HashSet { let mut result = HashSet::new(); match input { BrilligInputs::Single(expr) => { @@ -131,7 +145,7 @@ impl MergeExpressionsOptimizer { } // Returns the input witnesses used by the opcode - fn witness_inputs(&self, opcode: &Opcode) -> HashSet { + fn witness_inputs(&self, opcode: &Opcode) -> HashSet { let mut witnesses = HashSet::new(); match opcode { Opcode::AssertZero(expr) => Self::expr_wit(expr), @@ -169,34 +183,198 @@ impl MergeExpressionsOptimizer { } } - // Merge 'expr' into 'target' via Gaussian elimination on 'w' - // Returns None if the expressions cannot be merged - fn merge( - target: &Expression, - expr: &Expression, - w: Witness, - ) -> Option> { - // Check that the witness is not part of multiplication terms - for m in &target.mul_terms { + /// Merge 'expr' into 'target' via Gaussian elimination on 'w' + /// It supports the case where w is in a target's multiplication term: + /// - If w is only linear in expr and target, it's just a Gaussian elimination + /// - If w is in a expr's mul term: merge is not allowed + /// - If w is in a target's mul term AND expr has no mul term, then we do the Gaussian elimination in target's linear and mul terms + fn merge(target: &Expression, expr: &Expression, w: Witness) -> Option> { + // Check that the witness is not part of expr multiplication terms + for m in &expr.mul_terms { if m.1 == w || m.2 == w { return None; } } - for m in &expr.mul_terms { - if m.1 == w || m.2 == w { - return None; + // w must be in expr linear terms, we use expr to 'solve w' + let mut solved_w = Expression::zero(); + let w_idx = expr.linear_combinations.iter().position(|x| x.1 == w).unwrap(); + solved_w.linear_combinations.push((F::one(), w)); + solved_w = solved_w.add_mul(-(F::one() / expr.linear_combinations[w_idx].0), expr); + + // Solve w in target multiplication terms + let mut result: Expression = Expression::zero(); + result.linear_combinations = target.linear_combinations.clone(); + result.q_c = target.q_c; + for mul in &target.mul_terms { + if mul.1 == w || mul.2 == w { + if !expr.mul_terms.is_empty() || mul.1 == mul.2 { + // the result will be of degree 3, so this case does not work + return None; + } else { + let x = if mul.1 == w { mul.2 } else { mul.1 }; + + // replace w by solved_w in the mul: x * w = x * solved_w + let mut solved_mul = Expression::zero(); + for lin in &solved_w.linear_combinations { + solved_mul.mul_terms.push((mul.0 * lin.0, x, lin.1)); + } + solved_mul.linear_combinations.push((solved_w.q_c, x)); + solved_mul.sort(); + result = result.add_mul(F::one(), &solved_mul); + } + } else { + result.mul_terms.push(*mul); + result.sort(); } } - for k in &target.linear_combinations { + // Solve w in target linear terms + let mut w_coefficient = F::zero(); + for k in &result.linear_combinations { if k.1 == w { - for i in &expr.linear_combinations { - if i.1 == w { - return Some(target.add_mul(-(k.0 / i.0), expr)); + w_coefficient = -(k.0 / expr.linear_combinations[w_idx].0); + break; + } + } + result = result.add_mul(w_coefficient, expr); + Some(result) + } + + fn is_free(opcode: Opcode, width: usize) -> Option> { + if let Opcode::AssertZero(expr) = opcode { + if expr.mul_terms.len() <= 1 + && expr.linear_combinations.len() < width + && !expr.linear_combinations.is_empty() + { + return Some(expr); + } + } + None + } + + fn get_opcode(&self, g: usize, circuit: &Circuit) -> Option> { + if self.deleted_gates.contains(&g) { + return None; + } + Some(self.modified_gates.get(&g).unwrap_or(&circuit.opcodes[g]).clone()) + } + + fn fits(expr: &Expression, width: usize) -> bool { + if expr.mul_terms.len() > 1 || expr.linear_combinations.len() > width { + return false; + } + if expr.mul_terms.len() == 1 { + let mut used = 2; + let mut contains_a = false; + let mut contains_b = false; + for lin in &expr.linear_combinations { + if lin.1 == expr.mul_terms[0].1 { + contains_a = true; + } + if lin.1 == expr.mul_terms[0].2 { + contains_b = true; + } + if contains_a && contains_b { + break; + } + } + if contains_a { + used -= 1; + } + if (expr.mul_terms[0].1 != expr.mul_terms[0].2) && contains_b { + used -= 1; + } + return expr.linear_combinations.len() + used <= width; + } + true + } + + /// Simplify 'small expression' + /// Small expressions, even if they are re-used several times in other expressions, can still be simplified. + /// for example in the case where we have c=ab and the expressions using c do not have a multiplication term: c = ab; a+b+c =0; d+e-c = 0; + /// Then it can be simplified into two expressions: ab+a+c=0; -ab+d+e=0; + /// + /// If we enforce that ALL results satisfies the width, then we are ensured that it will always be an improvement. + /// However in practice the improvement is very small, so instead we allow for some over-fitting. As a result, optimisation is not guaranteed + /// and in some cases the result can be worse than the original circuit. + pub(crate) fn simply_small_expression( + &mut self, + circuit: &Circuit, + acir_opcode_positions: Vec, + width: usize, + ) -> (Vec>, Vec) { + let mut used_witness = self.compute_used_witness(circuit); + + let mut new_circuit = Vec::new(); + let mut new_acir_opcode_positions = Vec::new(); + self.modified_gates.clear(); + self.deleted_gates.clear(); + + // For each opcode, we try to simplify 'small' expressions + // If it works, we update modified_gates and deleted_gates to store the result of the simplification + for (i, _) in circuit.opcodes.iter().enumerate() { + let mut to_keep = true; + if let Some(opcode) = self.get_opcode(i, circuit) { + let mut merged = Vec::new(); + let empty_gates = HashSet::new(); + + // If the current expression current_expr is a 'small' expression + if let Some(current_expr) = Self::is_free(opcode.clone(), width) { + // we try to simplify it doing Gaussian elimination on one of its linear witness + // We try each witness until a simplification works. + for (_, w) in ¤t_expr.linear_combinations { + let gates_using_w = used_witness.get(w).unwrap_or(&empty_gates).clone(); + let gates: Vec<&usize> = gates_using_w + .iter() + .filter(|g| **g != i && !self.deleted_gates.contains(g)) + .collect(); + merged.clear(); + for g in gates { + if let Some(g_update) = self.get_opcode(*g, circuit) { + if let Opcode::AssertZero(g_expr) = g_update.clone() { + let merged_expr = Self::merge(&g_expr, ¤t_expr, *w); + if merged_expr.is_none() + || !Self::fits(&merged_expr.clone().unwrap(), width * 2) + { + // Do not simplify if merge failed or the result does not fit + to_keep = true; + break; + } + if *g <= i { + // This case is not supported, as it would break gates execution ordering + to_keep = true; + break; + } + merged.push((*g, merged_expr.clone().unwrap())); + } else { + // Do not simplify if w is used in a non-arithmetic opcode + to_keep = true; + break; + } + } + } + } + if !to_keep { + for m in &merged { + self.modified_gates.insert(m.0, Opcode::AssertZero(m.1.clone())); + // Update the used_witness map + let expr_witnesses = Self::expr_wit(&m.1); + for w in expr_witnesses { + used_witness.entry(w).or_default().insert(m.0); + } + } + self.deleted_gates.insert(i); } } } } - None + #[allow(clippy::needless_range_loop)] + for i in 0..circuit.opcodes.len() { + if let Some(op) = self.get_opcode(i, circuit) { + new_circuit.push(op); + new_acir_opcode_positions.push(acir_opcode_positions[i]); + } + } + (new_circuit, new_acir_opcode_positions) } } diff --git a/acvm-repo/acvm/src/compiler/transformers/mod.rs b/acvm-repo/acvm/src/compiler/transformers/mod.rs index 305e71fc1c1..2a76c99c552 100644 --- a/acvm-repo/acvm/src/compiler/transformers/mod.rs +++ b/acvm-repo/acvm/src/compiler/transformers/mod.rs @@ -17,13 +17,18 @@ use super::{ pub fn transform( acir: Circuit, expression_width: ExpressionWidth, + experimental_optimization: bool, ) -> (Circuit, AcirTransformationMap) { // Track original acir opcode positions throughout the transformation passes of the compilation // by applying the modifications done to the circuit opcodes and also to the opcode_positions (delete and insert) let acir_opcode_positions = acir.opcodes.iter().enumerate().map(|(i, _)| i).collect(); - let (mut acir, acir_opcode_positions) = - transform_internal(acir, expression_width, acir_opcode_positions); + let (mut acir, acir_opcode_positions) = transform_internal( + acir, + expression_width, + acir_opcode_positions, + experimental_optimization, + ); let transformation_map = AcirTransformationMap::new(acir_opcode_positions); @@ -40,8 +45,9 @@ pub(super) fn transform_internal( acir: Circuit, expression_width: ExpressionWidth, acir_opcode_positions: Vec, + experimental_optimization: bool, ) -> (Circuit, Vec) { - let mut transformer = match &expression_width { + let (mut transformer, width) = match &expression_width { ExpressionWidth::Unbounded => { return (acir, acir_opcode_positions); } @@ -50,10 +56,38 @@ pub(super) fn transform_internal( for value in acir.circuit_arguments() { csat.mark_solvable(value); } - csat + (csat, width) } }; + let current_witness_index = acir.current_witness_index; + let mut merge_optimizer = MergeExpressionsOptimizer::new(); + + if experimental_optimization { + let (opcodes, new_acir_opcode_positions) = + merge_optimizer.simply_small_expression(&acir, acir_opcode_positions, *width); + let acir = Circuit { + current_witness_index, + expression_width, + opcodes, + // The optimizer does not add new public inputs + ..acir + }; + + let (opcodes, new_acir_opcode_positions) = + merge_optimizer.eliminate_intermediate_variable(&acir, new_acir_opcode_positions); + + // n.b. we do not update current_witness_index after the eliminate_intermediate_variable pass, the real index could be less. + let acir = Circuit { + current_witness_index, + expression_width, + opcodes, + // The optimizer does not add new public inputs + ..acir + }; + return (acir, new_acir_opcode_positions); + } + // TODO: the code below is only for CSAT transformer // TODO it may be possible to refactor it in a way that we do not need to return early from the r1cs // TODO or at the very least, we could put all of it inside of CSatOptimizer pass @@ -167,7 +201,7 @@ pub(super) fn transform_internal( // The transformer does not add new public inputs ..acir }; - let mut merge_optimizer = MergeExpressionsOptimizer::new(); + let (opcodes, new_acir_opcode_positions) = merge_optimizer.eliminate_intermediate_variable(&acir, new_acir_opcode_positions); // n.b. we do not update current_witness_index after the eliminate_intermediate_variable pass, the real index could be less. diff --git a/compiler/noirc_driver/src/lib.rs b/compiler/noirc_driver/src/lib.rs index 2f0122524eb..db5568dad33 100644 --- a/compiler/noirc_driver/src/lib.rs +++ b/compiler/noirc_driver/src/lib.rs @@ -124,6 +124,10 @@ pub struct CompileOptions { /// This check should always be run on production code. #[arg(long)] pub skip_underconstrained_check: bool, + + /// Flag to enable experimental ACIR optimizations + #[arg(long, default_value = "false")] + pub experimental_optimization: bool, } pub fn parse_expression_width(input: &str) -> Result { @@ -580,6 +584,7 @@ pub fn compile_no_check( }, emit_ssa: if options.emit_ssa { Some(context.package_build_path.clone()) } else { None }, skip_underconstrained_check: options.skip_underconstrained_check, + experimental_optimization: options.experimental_optimization, }; let SsaProgramArtifact { program, debug, warnings, names, brillig_names, error_types, .. } = diff --git a/compiler/noirc_evaluator/src/ssa.rs b/compiler/noirc_evaluator/src/ssa.rs index efc7c6018c1..62df52c3806 100644 --- a/compiler/noirc_evaluator/src/ssa.rs +++ b/compiler/noirc_evaluator/src/ssa.rs @@ -67,6 +67,9 @@ pub struct SsaEvaluatorOptions { /// Skip the check for under constrained values pub skip_underconstrained_check: bool, + + /// Enable experimental ACIR optimizations + pub experimental_optimization: bool, } pub(crate) struct ArtifactsAndWarnings(Artifacts, Vec); @@ -137,7 +140,7 @@ pub(crate) fn optimize_into_acir( }); let artifacts = time("SSA to ACIR", options.print_codegen_timings, || { - ssa.into_acir(&brillig, options.expression_width) + ssa.into_acir(&brillig, options.expression_width, options.experimental_optimization) })?; Ok(ArtifactsAndWarnings(artifacts, ssa_level_warnings)) } diff --git a/compiler/noirc_evaluator/src/ssa/acir_gen/mod.rs b/compiler/noirc_evaluator/src/ssa/acir_gen/mod.rs index b560fafd337..9fd50f8aa38 100644 --- a/compiler/noirc_evaluator/src/ssa/acir_gen/mod.rs +++ b/compiler/noirc_evaluator/src/ssa/acir_gen/mod.rs @@ -195,6 +195,11 @@ struct Context<'a> { /// Contains state that is generated and also used across ACIR functions shared_context: &'a mut SharedContext, + + /// Set of values used in multiple expressions + common_used: fxhash::FxHashSet, + + experimental_optimization: bool, } #[derive(Clone)] @@ -292,12 +297,14 @@ impl Ssa { self, brillig: &Brillig, expression_width: ExpressionWidth, + experimental_optimization: bool, ) -> Result { let mut acirs = Vec::new(); // TODO: can we parallelize this? let mut shared_context = SharedContext::default(); for function in self.functions.values() { - let context = Context::new(&mut shared_context, expression_width); + let context = + Context::new(&mut shared_context, expression_width, experimental_optimization); if let Some(mut generated_acir) = context.convert_ssa_function(&self, function, brillig)? { @@ -349,6 +356,7 @@ impl<'a> Context<'a> { fn new( shared_context: &'a mut SharedContext, expression_width: ExpressionWidth, + experimental_optimization: bool, ) -> Context<'a> { let mut acir_context = AcirContext::default(); acir_context.set_expression_width(expression_width); @@ -365,6 +373,8 @@ impl<'a> Context<'a> { max_block_id: 0, data_bus: DataBus::default(), shared_context, + common_used: fxhash::FxHashSet::default(), + experimental_optimization, } } @@ -664,6 +674,7 @@ impl<'a> Context<'a> { let instruction = &dfg[instruction_id]; self.acir_context.set_call_stack(dfg.get_call_stack(instruction_id)); let mut warnings = Vec::new(); + self.common_used = ssa.common_values.to_owned(); match instruction { Instruction::Binary(binary) => { let result_acir_var = self.convert_ssa_binary(binary, dfg)?; @@ -1928,7 +1939,13 @@ impl<'a> Context<'a> { dfg: &DataFlowGraph, ) -> Result { match self.convert_value(value_id, dfg) { - AcirValue::Var(acir_var, _) => Ok(acir_var), + AcirValue::Var(acir_var, _) => { + if self.experimental_optimization && self.common_used.contains(&value_id) { + self.acir_context.get_or_create_witness_var(acir_var) + } else { + Ok(acir_var) + } + } AcirValue::Array(array) => Err(InternalError::Unexpected { expected: "a numeric value".to_string(), found: format!("{array:?}"), diff --git a/compiler/noirc_evaluator/src/ssa/opt/die.rs b/compiler/noirc_evaluator/src/ssa/opt/die.rs index beca7c41e5c..9fe50b1484b 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/die.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/die.rs @@ -25,7 +25,10 @@ impl Ssa { #[tracing::instrument(level = "trace", skip(self))] pub(crate) fn dead_instruction_elimination(mut self) -> Ssa { for function in self.functions.values_mut() { - function.dead_instruction_elimination(true); + let common_values = function.dead_instruction_elimination(true); + if function.id() == self.main_id { + self.common_values = common_values; + } } self } @@ -38,7 +41,10 @@ impl Function { /// instructions that reference results from an instruction in another block are evaluated first. /// If we did not iterate blocks in this order we could not safely say whether or not the results /// of its instructions are needed elsewhere. - pub(crate) fn dead_instruction_elimination(&mut self, insert_out_of_bounds_checks: bool) { + pub(crate) fn dead_instruction_elimination( + &mut self, + insert_out_of_bounds_checks: bool, + ) -> HashSet { let mut context = Context::default(); for call_data in &self.dfg.data_bus.call_data { context.mark_used_instruction_results(&self.dfg, call_data.array_id); @@ -59,11 +65,13 @@ impl Function { // instructions (we don't want to remove those checks, or instructions that are // dependencies of those checks) if inserted_out_of_bounds_checks { + context.common_values.clear(); self.dead_instruction_elimination(false); - return; + return context.common_values; } - + let result = context.common_values.clone(); context.remove_rc_instructions(&mut self.dfg); + result } } @@ -77,6 +85,9 @@ struct Context { /// they technically contain side-effects but we still want to remove them if their /// `value` parameter is not used elsewhere. rc_instructions: Vec<(InstructionId, BasicBlockId)>, + + /// Values that are used by multiple instructions + common_values: HashSet, } impl Context { @@ -198,7 +209,9 @@ impl Context { let value_id = dfg.resolve(value_id); match &dfg[value_id] { Value::Instruction { .. } => { - self.used_values.insert(value_id); + if !self.used_values.insert(value_id) { + self.common_values.insert(value_id); + } } Value::Array { array, .. } => { self.used_values.insert(value_id); diff --git a/compiler/noirc_evaluator/src/ssa/ssa_gen/program.rs b/compiler/noirc_evaluator/src/ssa/ssa_gen/program.rs index fe786da16ca..f4d73d30103 100644 --- a/compiler/noirc_evaluator/src/ssa/ssa_gen/program.rs +++ b/compiler/noirc_evaluator/src/ssa/ssa_gen/program.rs @@ -8,6 +8,7 @@ use serde_with::serde_as; use crate::ssa::ir::{ function::{Function, FunctionId, RuntimeType}, map::AtomicCounter, + value::ValueId, }; use noirc_frontend::hir_def::types::Type as HirType; @@ -30,6 +31,8 @@ pub(crate) struct Ssa { // ABI not the actual SSA IR. #[serde(skip)] pub(crate) error_selector_to_type: BTreeMap, + + pub(crate) common_values: fxhash::FxHashSet, } impl Ssa { @@ -67,6 +70,7 @@ impl Ssa { next_id: AtomicCounter::starting_after(max_id), entry_point_to_generated_index, error_selector_to_type: error_types, + common_values: fxhash::FxHashSet::default(), } } From 6eeeb7b1ef28def1bcc2269f62c4cc16118230d1 Mon Sep 17 00:00:00 2001 From: guipublic Date: Fri, 25 Oct 2024 08:44:01 +0000 Subject: [PATCH 3/3] add changes for the compiler option --- tooling/lsp/src/requests/profile_run.rs | 4 ++-- tooling/nargo/src/ops/transform.rs | 14 +++++++++++--- tooling/nargo_cli/src/cli/compile_cmd.rs | 12 ++++++++++-- tooling/nargo_cli/src/cli/dap_cmd.rs | 2 +- tooling/nargo_cli/src/cli/debug_cmd.rs | 6 +++++- 5 files changed, 29 insertions(+), 9 deletions(-) diff --git a/tooling/lsp/src/requests/profile_run.rs b/tooling/lsp/src/requests/profile_run.rs index a7362300adc..2dc88c4d11b 100644 --- a/tooling/lsp/src/requests/profile_run.rs +++ b/tooling/lsp/src/requests/profile_run.rs @@ -83,7 +83,7 @@ fn on_profile_run_request_inner( let mut file_map: BTreeMap = BTreeMap::new(); for compiled_program in compiled_programs { let compiled_program = - nargo::ops::transform_program(compiled_program, expression_width); + nargo::ops::transform_program(compiled_program, expression_width, false); for function_debug in compiled_program.debug.iter() { let span_opcodes = function_debug.count_span_opcodes(); @@ -95,7 +95,7 @@ fn on_profile_run_request_inner( for compiled_contract in compiled_contracts { let compiled_contract = - nargo::ops::transform_contract(compiled_contract, expression_width); + nargo::ops::transform_contract(compiled_contract, expression_width, false); let function_debug_info = compiled_contract .functions diff --git a/tooling/nargo/src/ops/transform.rs b/tooling/nargo/src/ops/transform.rs index 9255ac3e0ec..4f3d2f29bcb 100644 --- a/tooling/nargo/src/ops/transform.rs +++ b/tooling/nargo/src/ops/transform.rs @@ -9,11 +9,13 @@ use noirc_errors::debug_info::DebugInfo; pub fn transform_program( mut compiled_program: CompiledProgram, expression_width: ExpressionWidth, + experimental_optimization: bool, ) -> CompiledProgram { compiled_program.program = transform_program_internal( compiled_program.program, &mut compiled_program.debug, expression_width, + experimental_optimization, ); compiled_program } @@ -21,10 +23,15 @@ pub fn transform_program( pub fn transform_contract( contract: CompiledContract, expression_width: ExpressionWidth, + experimental_optimization: bool, ) -> CompiledContract { let functions = vecmap(contract.functions, |mut func| { - func.bytecode = - transform_program_internal(func.bytecode, &mut func.debug, expression_width); + func.bytecode = transform_program_internal( + func.bytecode, + &mut func.debug, + expression_width, + experimental_optimization, + ); func }); @@ -36,6 +43,7 @@ fn transform_program_internal( mut program: Program, debug: &mut [DebugInfo], expression_width: ExpressionWidth, + experimental_optimization: bool, ) -> Program { let functions = std::mem::take(&mut program.functions); @@ -44,7 +52,7 @@ fn transform_program_internal( .enumerate() .map(|(i, function)| { let (optimized_circuit, location_map) = - acvm::compiler::compile(function, expression_width); + acvm::compiler::compile(function, expression_width, experimental_optimization); debug[i].update_acir(location_map); optimized_circuit }) diff --git a/tooling/nargo_cli/src/cli/compile_cmd.rs b/tooling/nargo_cli/src/cli/compile_cmd.rs index 0ad07c91ff4..a7c4ecdeba9 100644 --- a/tooling/nargo_cli/src/cli/compile_cmd.rs +++ b/tooling/nargo_cli/src/cli/compile_cmd.rs @@ -191,7 +191,11 @@ fn compile_programs( let target_width = get_target_width(package.expression_width, compile_options.expression_width); - let program = nargo::ops::transform_program(program, target_width); + let program = nargo::ops::transform_program( + program, + target_width, + compile_options.experimental_optimization, + ); save_program_to_file(&program.into(), &package.name, workspace.target_directory_path()); @@ -222,7 +226,11 @@ fn compiled_contracts( compile_contract(file_manager, parsed_files, package, compile_options)?; let target_width = get_target_width(package.expression_width, compile_options.expression_width); - let contract = nargo::ops::transform_contract(contract, target_width); + let contract = nargo::ops::transform_contract( + contract, + target_width, + compile_options.experimental_optimization, + ); save_contract(contract, package, target_dir, compile_options.show_artifact_paths); Ok(((), warnings)) }) diff --git a/tooling/nargo_cli/src/cli/dap_cmd.rs b/tooling/nargo_cli/src/cli/dap_cmd.rs index a84e961cfe7..7032facfbd2 100644 --- a/tooling/nargo_cli/src/cli/dap_cmd.rs +++ b/tooling/nargo_cli/src/cli/dap_cmd.rs @@ -119,7 +119,7 @@ fn load_and_compile_project( ) .map_err(|_| LoadError::Generic("Failed to compile project".into()))?; - let compiled_program = nargo::ops::transform_program(compiled_program, expression_width); + let compiled_program = nargo::ops::transform_program(compiled_program, expression_width, false); let (inputs_map, _) = read_inputs_from_file(&package.root_dir, prover_name, Format::Toml, &compiled_program.abi) diff --git a/tooling/nargo_cli/src/cli/debug_cmd.rs b/tooling/nargo_cli/src/cli/debug_cmd.rs index e837f297475..2d26aeca5ce 100644 --- a/tooling/nargo_cli/src/cli/debug_cmd.rs +++ b/tooling/nargo_cli/src/cli/debug_cmd.rs @@ -83,7 +83,11 @@ pub(crate) fn run(args: DebugCommand, config: NargoConfig) -> Result<(), CliErro let target_width = get_target_width(package.expression_width, args.compile_options.expression_width); - let compiled_program = nargo::ops::transform_program(compiled_program, target_width); + let compiled_program = nargo::ops::transform_program( + compiled_program, + target_width, + args.compile_options.experimental_optimization, + ); run_async(package, compiled_program, &args.prover_name, &args.witness_name, target_dir) }