Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: experimental acir optimisation #6341

Draft
wants to merge 7 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions acvm-repo/acvm/src/compiler/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,11 +75,16 @@ fn transform_assert_messages<F: Clone>(
pub fn compile<F: AcirField>(
acir: Circuit<F>,
expression_width: ExpressionWidth,
experimental_optimization: bool,
) -> (Circuit<F>, AcirTransformationMap) {
let (acir, acir_opcode_positions) = optimize_internal(acir);

let (mut acir, acir_opcode_positions) =
transform_internal(acir, expression_width, acir_opcode_positions);
let (mut acir, acir_opcode_positions) = transform_internal(
acir,
expression_width,
acir_opcode_positions,
experimental_optimization,
);

let transformation_map = AcirTransformationMap::new(acir_opcode_positions);

Expand Down
247 changes: 216 additions & 31 deletions acvm-repo/acvm/src/compiler/optimizers/merge_expressions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,23 +8,23 @@

use crate::compiler::CircuitSimulator;

pub(crate) struct MergeExpressionsOptimizer {
pub(crate) struct MergeExpressionsOptimizer<F> {
resolved_blocks: HashMap<BlockId, BTreeSet<Witness>>,

modified_gates: HashMap<usize, Opcode<F>>,
deleted_gates: BTreeSet<usize>,
}

impl MergeExpressionsOptimizer {
impl<F: AcirField> MergeExpressionsOptimizer<F> {
pub(crate) fn new() -> Self {
MergeExpressionsOptimizer { resolved_blocks: HashMap::new() }
MergeExpressionsOptimizer {
resolved_blocks: HashMap::new(),
modified_gates: HashMap::new(),
deleted_gates: BTreeSet::new(),
}
}
/// This pass analyzes the circuit and identifies intermediate variables that are
/// only used in two gates. It then merges the gate that produces the
/// intermediate variable into the second one that uses it
/// Note: This pass is only relevant for backends that can handle unlimited width
pub(crate) fn eliminate_intermediate_variable<F: AcirField>(
&mut self,
circuit: &Circuit<F>,
acir_opcode_positions: Vec<usize>,
) -> (Vec<Opcode<F>>, Vec<usize>) {

fn compute_used_witness(&mut self, circuit: &Circuit<F>) -> BTreeMap<Witness, BTreeSet<usize>> {
// Keep track, for each witness, of the gates that use it
let circuit_inputs = circuit.circuit_arguments();
self.resolved_blocks = HashMap::new();
Expand All @@ -41,6 +41,20 @@
}
}
}
used_witness
}

/// This pass analyzes the circuit and identifies intermediate variables that are
/// only used in two gates. It then merges the gate that produces the
/// intermediate variable into the second one that uses it
/// Note: This pass is only relevant for backends that can handle unlimited width
pub(crate) fn eliminate_intermediate_variable(
&mut self,
circuit: &Circuit<F>,
acir_opcode_positions: Vec<usize>,
) -> (Vec<Opcode<F>>, Vec<usize>) {
let mut used_witness = self.compute_used_witness(circuit);
let circuit_inputs = circuit.circuit_arguments();

let mut modified_gates: HashMap<usize, Opcode<F>> = HashMap::new();
let mut new_circuit = Vec::new();
Expand Down Expand Up @@ -106,7 +120,14 @@
(new_circuit, new_acir_opcode_positions)
}

fn brillig_input_wit<F>(&self, input: &BrilligInputs<F>) -> BTreeSet<Witness> {
/// Collects every witness referenced by `expr`.
///
/// The result contains both operands of each multiplication term and the
/// witness of each linear term; the constant term carries no witness.
/// Duplicates collapse because the result is a set.
fn expr_wit(expr: &Expression<F>) -> BTreeSet<Witness> {
    // A fixed-size array is yielded per multiplication term, which avoids
    // allocating a temporary `Vec` for each of them.
    let mul_witnesses = expr.mul_terms.iter().flat_map(|(_, lhs, rhs)| [*lhs, *rhs]);
    let linear_witnesses = expr.linear_combinations.iter().map(|(_, witness)| *witness);
    mul_witnesses.chain(linear_witnesses).collect()
}

fn brillig_input_wit(&self, input: &BrilligInputs<F>) -> BTreeSet<Witness> {
let mut result = BTreeSet::new();
match input {
BrilligInputs::Single(expr) => {
Expand All @@ -126,7 +147,7 @@
}

// Returns the input witnesses used by the opcode
fn witness_inputs<F: AcirField>(&self, opcode: &Opcode<F>) -> BTreeSet<Witness> {
fn witness_inputs(&self, opcode: &Opcode<F>) -> BTreeSet<Witness> {
let mut witnesses = BTreeSet::new();
match opcode {
Opcode::AssertZero(expr) => CircuitSimulator::expr_wit(expr),
Expand Down Expand Up @@ -164,34 +185,198 @@
}
}

// Merge 'expr' into 'target' via Gaussian elimination on 'w'
// Returns None if the expressions cannot be merged
fn merge<F: AcirField>(
target: &Expression<F>,
expr: &Expression<F>,
w: Witness,
) -> Option<Expression<F>> {
// Check that the witness is not part of multiplication terms
for m in &target.mul_terms {
/// Merge 'expr' into 'target' via Gaussian elimination on 'w'
/// It supports the case where w is in a target's multiplication term:
/// - If w is only linear in expr and target, it's just a Gaussian elimination
/// - If w is in an expr's mul term: merge is not allowed
/// - If w is in a target's mul term AND expr has no mul term, then we do the Gaussian elimination in target's linear and mul terms
fn merge(target: &Expression<F>, expr: &Expression<F>, w: Witness) -> Option<Expression<F>> {
// Check that the witness is not part of expr multiplication terms
for m in &expr.mul_terms {
if m.1 == w || m.2 == w {
return None;
}
}
for m in &expr.mul_terms {
if m.1 == w || m.2 == w {
return None;
// w must be in expr linear terms, we use expr to 'solve w'
let mut solved_w = Expression::zero();
let w_idx = expr.linear_combinations.iter().position(|x| x.1 == w).unwrap();
solved_w.linear_combinations.push((F::one(), w));
solved_w = solved_w.add_mul(-(F::one() / expr.linear_combinations[w_idx].0), expr);

// Solve w in target multiplication terms
let mut result: Expression<F> = Expression::zero();
result.linear_combinations = target.linear_combinations.clone();
result.q_c = target.q_c;
for mul in &target.mul_terms {
if mul.1 == w || mul.2 == w {
if !expr.mul_terms.is_empty() || mul.1 == mul.2 {
// the result will be of degree 3, so this case does not work
return None;
} else {
let x = if mul.1 == w { mul.2 } else { mul.1 };

// replace w by solved_w in the mul: x * w = x * solved_w
let mut solved_mul = Expression::zero();
for lin in &solved_w.linear_combinations {
solved_mul.mul_terms.push((mul.0 * lin.0, x, lin.1));
}
solved_mul.linear_combinations.push((solved_w.q_c, x));
solved_mul.sort();
result = result.add_mul(F::one(), &solved_mul);
}
} else {
result.mul_terms.push(*mul);
result.sort();
}
}

for k in &target.linear_combinations {
// Solve w in target linear terms
let mut w_coefficient = F::zero();
for k in &result.linear_combinations {
if k.1 == w {
for i in &expr.linear_combinations {
if i.1 == w {
return Some(target.add_mul(-(k.0 / i.0), expr));
w_coefficient = -(k.0 / expr.linear_combinations[w_idx].0);
break;
}
}
result = result.add_mul(w_coefficient, expr);
Some(result)
}

/// Extracts the expression of a 'small' arithmetic opcode.
///
/// Returns `Some(expr)` when `opcode` is an `AssertZero` whose expression has
/// at most one multiplication term and at least one, but strictly fewer than
/// `width`, linear terms. Any other opcode yields `None`.
fn is_free(opcode: Opcode<F>, width: usize) -> Option<Expression<F>> {
    match opcode {
        Opcode::AssertZero(expr)
            if expr.mul_terms.len() <= 1
                && expr.linear_combinations.len() < width
                && !expr.linear_combinations.is_empty() =>
        {
            Some(expr)
        }
        _ => None,
    }
}

/// Returns the current version of the opcode at index `g`.
///
/// Yields `None` when `g` is recorded in `deleted_gates`. Otherwise returns a
/// clone of the entry in `modified_gates` if there is one, falling back to the
/// opcode stored in `circuit`.
///
/// # Panics
///
/// Panics if `g` is not deleted and is out of bounds for `circuit.opcodes`.
fn get_opcode(&self, g: usize, circuit: &Circuit<F>) -> Option<Opcode<F>> {
    let is_deleted = self.deleted_gates.contains(&g);
    (!is_deleted).then(|| {
        // The circuit is indexed up front so that an invalid index is reported
        // even when a modified version of the gate exists.
        let original = &circuit.opcodes[g];
        self.modified_gates.get(&g).unwrap_or(original).clone()
    })
}

/// Tells whether `expr` fits within `width`.
///
/// - More than one multiplication term never fits.
/// - Without a multiplication term, the number of linear terms must not
///   exceed `width`.
/// - With exactly one multiplication term, each operand of the product costs
///   one extra slot on top of the linear terms, unless that operand already
///   appears as a linear term. For a square (both operands equal), at most one
///   slot is saved.
fn fits(expr: &Expression<F>, width: usize) -> bool {
    let linear_count = expr.linear_combinations.len();
    match expr.mul_terms.as_slice() {
        [] => linear_count <= width,
        [(_, lhs, rhs)] => {
            let lhs_is_linear = expr.linear_combinations.iter().any(|lin| lin.1 == *lhs);
            let rhs_is_linear = expr.linear_combinations.iter().any(|lin| lin.1 == *rhs);

            // Slots needed by the product that are not shared with a linear term.
            let mut extra_slots = 2;
            if lhs_is_linear {
                extra_slots -= 1;
            }
            if lhs != rhs && rhs_is_linear {
                extra_slots -= 1;
            }
            linear_count + extra_slots <= width
        }
        _ => false,
    }
}

/// Simplify 'small expression'
/// Small expressions, even if they are re-used several times in other expressions, can still be simplified.
/// for example in the case where we have c=ab and the expressions using c do not have a multiplication term: c = ab; a+b+c =0; d+e-c = 0;
/// Then it can be simplified into two expressions: ab+a+c=0; -ab+d+e=0;
///
/// If we enforce that ALL results satisfy the width, then we are guaranteed that it will always be an improvement.
/// However, in practice the improvement is very small, so instead we allow for some over-fitting. As a result, optimization is not guaranteed

Check warning on line 300 in acvm-repo/acvm/src/compiler/optimizers/merge_expressions.rs

View workflow job for this annotation

GitHub Actions / Code

Unknown word (optimisation)
/// and in some cases the result can be worse than the original circuit.
pub(crate) fn simply_small_expression(
&mut self,
circuit: &Circuit<F>,
acir_opcode_positions: Vec<usize>,
width: usize,
) -> (Vec<Opcode<F>>, Vec<usize>) {
let mut used_witness = self.compute_used_witness(circuit);

let mut new_circuit = Vec::new();
let mut new_acir_opcode_positions = Vec::new();
self.modified_gates.clear();
self.deleted_gates.clear();

// For each opcode, we try to simplify 'small' expressions
// If it works, we update modified_gates and deleted_gates to store the result of the simplification
for (i, _) in circuit.opcodes.iter().enumerate() {
let mut to_keep = true;
if let Some(opcode) = self.get_opcode(i, circuit) {
let mut merged = Vec::new();
let empty_gates = BTreeSet::new();

// If the current expression current_expr is a 'small' expression
if let Some(current_expr) = Self::is_free(opcode.clone(), width) {
// we try to simplify it doing Gaussian elimination on one of its linear witness
// We try each witness until a simplification works.
for (_, w) in &current_expr.linear_combinations {
let gates_using_w = used_witness.get(w).unwrap_or(&empty_gates).clone();
let gates: Vec<&usize> = gates_using_w
.iter()
.filter(|g| **g != i && !self.deleted_gates.contains(g))
.collect();
merged.clear();
for g in gates {
if let Some(g_update) = self.get_opcode(*g, circuit) {
if let Opcode::AssertZero(g_expr) = g_update.clone() {
let merged_expr = Self::merge(&g_expr, &current_expr, *w);
if merged_expr.is_none()
|| !Self::fits(&merged_expr.clone().unwrap(), width * 2)
{
// Do not simplify if merge failed or the result does not fit
to_keep = true;
break;
}
if *g <= i {
// This case is not supported, as it would break gates execution ordering
to_keep = true;
break;
}
merged.push((*g, merged_expr.clone().unwrap()));
} else {
// Do not simplify if w is used in a non-arithmetic opcode
to_keep = true;
break;
}
}
}
}
if !to_keep {
for m in &merged {
self.modified_gates.insert(m.0, Opcode::AssertZero(m.1.clone()));
// Update the used_witness map
let expr_witnesses = Self::expr_wit(&m.1);
for w in expr_witnesses {
used_witness.entry(w).or_default().insert(m.0);
}
}
self.deleted_gates.insert(i);
}
}
}
}
None
#[allow(clippy::needless_range_loop)]
for i in 0..circuit.opcodes.len() {
if let Some(op) = self.get_opcode(i, circuit) {
new_circuit.push(op);
new_acir_opcode_positions.push(acir_opcode_positions[i]);
}
}
(new_circuit, new_acir_opcode_positions)
}
}
43 changes: 38 additions & 5 deletions acvm-repo/acvm/src/compiler/transformers/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,18 @@ use super::{
pub fn transform<F: AcirField>(
acir: Circuit<F>,
expression_width: ExpressionWidth,
experimental_optimization: bool,
) -> (Circuit<F>, AcirTransformationMap) {
// Track original acir opcode positions throughout the transformation passes of the compilation
// by applying the modifications done to the circuit opcodes and also to the opcode_positions (delete and insert)
let acir_opcode_positions = acir.opcodes.iter().enumerate().map(|(i, _)| i).collect();

let (mut acir, acir_opcode_positions) =
transform_internal(acir, expression_width, acir_opcode_positions);
let (mut acir, acir_opcode_positions) = transform_internal(
acir,
expression_width,
acir_opcode_positions,
experimental_optimization,
);

let transformation_map = AcirTransformationMap::new(acir_opcode_positions);

Expand All @@ -41,8 +46,9 @@ pub(super) fn transform_internal<F: AcirField>(
acir: Circuit<F>,
expression_width: ExpressionWidth,
acir_opcode_positions: Vec<usize>,
experimental_optimization: bool,
) -> (Circuit<F>, Vec<usize>) {
let mut transformer = match &expression_width {
let (mut transformer, width) = match &expression_width {
ExpressionWidth::Unbounded => {
return (acir, acir_opcode_positions);
}
Expand All @@ -51,10 +57,38 @@ pub(super) fn transform_internal<F: AcirField>(
for value in acir.circuit_arguments() {
csat.mark_solvable(value);
}
csat
(csat, width)
}
};

let current_witness_index = acir.current_witness_index;
let mut merge_optimizer = MergeExpressionsOptimizer::new();

if experimental_optimization {
let (opcodes, new_acir_opcode_positions) =
merge_optimizer.simply_small_expression(&acir, acir_opcode_positions, *width);
let acir = Circuit {
current_witness_index,
expression_width,
opcodes,
// The optimizer does not add new public inputs
..acir
};

let (opcodes, new_acir_opcode_positions) =
merge_optimizer.eliminate_intermediate_variable(&acir, new_acir_opcode_positions);

// n.b. we do not update current_witness_index after the eliminate_intermediate_variable pass, the real index could be less.
let acir = Circuit {
current_witness_index,
expression_width,
opcodes,
// The optimizer does not add new public inputs
..acir
};
return (acir, new_acir_opcode_positions);
}

// TODO: the code below is only for CSAT transformer
// TODO it may be possible to refactor it in a way that we do not need to return early from the r1cs
// TODO or at the very least, we could put all of it inside of CSatOptimizer pass
Expand Down Expand Up @@ -168,7 +202,6 @@ pub(super) fn transform_internal<F: AcirField>(
// The transformer does not add new public inputs
..acir
};
let mut merge_optimizer = MergeExpressionsOptimizer::new();
let (opcodes, new_acir_opcode_positions) =
merge_optimizer.eliminate_intermediate_variable(&acir, new_acir_opcode_positions);
// n.b. we do not update current_witness_index after the eliminate_intermediate_variable pass, the real index could be less.
Expand Down
Loading
Loading