Skip to content

Commit

Permalink
feat: update circuit evaluation (#81)
Browse files Browse the repository at this point in the history
  • Loading branch information
brech1 authored Jul 23, 2024
1 parent a697063 commit 5c673e2
Show file tree
Hide file tree
Showing 5 changed files with 418 additions and 263 deletions.
6 changes: 1 addition & 5 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ serde = { version = "1.0.196", features = ["derive"] }
thiserror = "1.0.59"
strum_macros = "0.26.4"
strum = "0.26.2"
sim-circuit = { git = "https://github.com/brech1/sim-circuit" }

# DSL
circom-circom_algebra = { git = "https://github.com/iden3/circom", package = "circom_algebra" }
Expand All @@ -29,8 +30,3 @@ circom-dag = { git = "https://github.com/iden3/circom", package = "dag" }
circom-parser = { git = "https://github.com/iden3/circom", package = "parser" }
circom-program_structure = { git = "https://github.com/iden3/circom", package = "program_structure" }
circom-type_analysis = { git = "https://github.com/iden3/circom", package = "type_analysis" }

# MPZ
mpz-circuits = { git = "https://github.com/privacy-scaling-explorations/mpz", package = "mpz-circuits" }
bmr16-mpz = { git = "https://github.com/tkmct/mpz", package = "mpz-circuits" }
sim-circuit = { git = "https://github.com/brech1/sim-circuit" }
61 changes: 54 additions & 7 deletions src/arithmetic_circuit.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,64 @@
use crate::compiler::{ArithmeticGate, CircuitError};
use circom_program_structure::ast::ExpressionInfixOpcode;
use serde::{Deserialize, Serialize};
use serde_json::{from_str, to_string};
use sim_circuit::arithmetic_circuit::ArithmeticCircuit as SimArithmeticCircuit;
use std::{
collections::HashMap,
io::{BufRead, BufReader, BufWriter, Write},
str::FromStr,
};
use strum_macros::{Display as StrumDisplay, EnumString};

/// The supported Arithmetic gate types.
#[derive(Debug, Serialize, Deserialize, Clone, Copy, PartialEq, Eq, EnumString, StrumDisplay)]
pub enum AGateType {
AAdd,
ADiv,
AEq,
AGEq,
AGt,
ALEq,
ALt,
AMul,
ANeq,
ASub,
AXor,
APow,
AIntDiv,
AMod,
AShiftL,
AShiftR,
ABoolOr,
ABoolAnd,
ABitOr,
ABitAnd,
}

impl From<&ExpressionInfixOpcode> for AGateType {
fn from(opcode: &ExpressionInfixOpcode) -> Self {
match opcode {
ExpressionInfixOpcode::Mul => AGateType::AMul,
ExpressionInfixOpcode::Div => AGateType::ADiv,
ExpressionInfixOpcode::Add => AGateType::AAdd,
ExpressionInfixOpcode::Sub => AGateType::ASub,
ExpressionInfixOpcode::Pow => AGateType::APow,
ExpressionInfixOpcode::IntDiv => AGateType::AIntDiv,
ExpressionInfixOpcode::Mod => AGateType::AMod,
ExpressionInfixOpcode::ShiftL => AGateType::AShiftL,
ExpressionInfixOpcode::ShiftR => AGateType::AShiftR,
ExpressionInfixOpcode::LesserEq => AGateType::ALEq,
ExpressionInfixOpcode::GreaterEq => AGateType::AGEq,
ExpressionInfixOpcode::Lesser => AGateType::ALt,
ExpressionInfixOpcode::Greater => AGateType::AGt,
ExpressionInfixOpcode::Eq => AGateType::AEq,
ExpressionInfixOpcode::NotEq => AGateType::ANeq,
ExpressionInfixOpcode::BoolOr => AGateType::ABoolOr,
ExpressionInfixOpcode::BoolAnd => AGateType::ABoolAnd,
ExpressionInfixOpcode::BitOr => AGateType::ABitOr,
ExpressionInfixOpcode::BitAnd => AGateType::ABitAnd,
ExpressionInfixOpcode::BitXor => AGateType::AXor,
}
}
}

#[derive(Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct ArithmeticCircuit {
Expand All @@ -29,10 +81,6 @@ pub struct ConstantInfo {
}

impl ArithmeticCircuit {
pub fn to_sim(&self) -> SimArithmeticCircuit {
from_str(&to_string(self).unwrap()).unwrap()
}

pub fn get_bristol_string(&self) -> Result<String, CircuitError> {
let mut output = Vec::new();
let mut writer = BufWriter::new(&mut output);
Expand Down Expand Up @@ -216,7 +264,6 @@ impl BristolLine {
#[cfg(test)]
mod tests {
use super::*;
use crate::compiler::AGateType;
use std::io::{BufReader, Cursor};

// Helper function to create a sample ArithmeticCircuit
Expand Down
163 changes: 1 addition & 162 deletions src/compiler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,78 +3,15 @@
//! This module defines the data structures used to represent the arithmetic circuit.
use crate::{
arithmetic_circuit::{ArithmeticCircuit, CircuitInfo, ConstantInfo},
arithmetic_circuit::{AGateType, ArithmeticCircuit, CircuitInfo, ConstantInfo},
program::ProgramError,
topological_sort::topological_sort,
};
use bmr16_mpz::{
arithmetic::{
circuit::ArithmeticCircuit as MpzCircuit,
ops::{add, cmul, mul, sub},
types::CrtRepr,
ArithCircuitError as MpzCircuitError,
},
ArithmeticCircuitBuilder,
};
use circom_program_structure::ast::ExpressionInfixOpcode;
use log::debug;
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet};
use strum_macros::{Display as StrumDisplay, EnumString};
use thiserror::Error;

/// Types of gates that can be used in an arithmetic circuit.
#[derive(Debug, Serialize, Deserialize, Clone, Copy, PartialEq, Eq, EnumString, StrumDisplay)]
pub enum AGateType {
AAdd,
ADiv,
AEq,
AGEq,
AGt,
ALEq,
ALt,
AMul,
ANeq,
ASub,
AXor,
APow,
AIntDiv,
AMod,
AShiftL,
AShiftR,
ABoolOr,
ABoolAnd,
ABitOr,
ABitAnd,
}

impl From<&ExpressionInfixOpcode> for AGateType {
fn from(opcode: &ExpressionInfixOpcode) -> Self {
match opcode {
ExpressionInfixOpcode::Mul => AGateType::AMul,
ExpressionInfixOpcode::Div => AGateType::ADiv,
ExpressionInfixOpcode::Add => AGateType::AAdd,
ExpressionInfixOpcode::Sub => AGateType::ASub,
ExpressionInfixOpcode::Pow => AGateType::APow,
ExpressionInfixOpcode::IntDiv => AGateType::AIntDiv,
ExpressionInfixOpcode::Mod => AGateType::AMod,
ExpressionInfixOpcode::ShiftL => AGateType::AShiftL,
ExpressionInfixOpcode::ShiftR => AGateType::AShiftR,
ExpressionInfixOpcode::LesserEq => AGateType::ALEq,
ExpressionInfixOpcode::GreaterEq => AGateType::AGEq,
ExpressionInfixOpcode::Lesser => AGateType::ALt,
ExpressionInfixOpcode::Greater => AGateType::AGt,
ExpressionInfixOpcode::Eq => AGateType::AEq,
ExpressionInfixOpcode::NotEq => AGateType::ANeq,
ExpressionInfixOpcode::BoolOr => AGateType::ABoolOr,
ExpressionInfixOpcode::BoolAnd => AGateType::ABoolAnd,
ExpressionInfixOpcode::BitOr => AGateType::ABitOr,
ExpressionInfixOpcode::BitAnd => AGateType::ABitAnd,
ExpressionInfixOpcode::BitXor => AGateType::AXor,
}
}
}

/// Represents a signal in the circuit, with a name and an optional value.
#[derive(Debug, Serialize, Deserialize)]
pub struct Signal {
Expand Down Expand Up @@ -541,94 +478,6 @@ impl Compiler {
})
}

/// Builds an arithmetic circuit using the mpz circuit builder.
pub fn build_mpz_circuit(&self, report: &CircuitReport) -> Result<MpzCircuit, CircuitError> {
let builder = ArithmeticCircuitBuilder::new();

// Initialize CRT signals map with the circuit inputs
let mut crt_signals: HashMap<u32, CrtRepr> =
report
.inputs
.iter()
.try_fold(HashMap::new(), |mut acc, signal| {
let input = builder
.add_input::<u32>(signal.names[0].to_string())
.map_err(CircuitError::MPZCircuitError)?;
acc.insert(signal.id, input.repr);
Ok::<_, CircuitError>(acc)
})?;

// Initialize a vec for indices of gates that need processing
let mut to_process = std::collections::VecDeque::new();
to_process.extend(0..self.gates.len());

while let Some(index) = to_process.pop_front() {
let gate = &self.gates[index];

if let (Some(lh_in_repr), Some(rh_in_repr)) =
(crt_signals.get(&gate.lh_in), crt_signals.get(&gate.rh_in))
{
let result_repr = match gate.op {
AGateType::AAdd => {
add(&mut builder.state().borrow_mut(), lh_in_repr, rh_in_repr)
.map_err(|e| e.into())
}
AGateType::AMul => {
// Get the constant value from one of the signals if available
let constant_value = self
.signals
.get(&gate.lh_in)
.and_then(|signal| signal.value.map(|v| v as u64))
.or_else(|| {
self.signals
.get(&gate.rh_in)
.and_then(|signal| signal.value.map(|v| v as u64))
});

// Perform multiplication depending on whether one input is a constant
if let Some(value) = constant_value {
Ok::<_, CircuitError>(cmul(
&mut builder.state().borrow_mut(),
lh_in_repr,
value,
))
} else {
mul(&mut builder.state().borrow_mut(), lh_in_repr, rh_in_repr)
.map_err(|e| e.into())
}
}
AGateType::ASub => {
sub(&mut builder.state().borrow_mut(), lh_in_repr, rh_in_repr)
.map_err(|e| e.into())
}
_ => {
return Err(CircuitError::UnsupportedGateType(format!(
"{:?} not supported by MPZ",
gate.op
)))
}
}?;

crt_signals.insert(gate.out, result_repr);
} else {
// Not ready to process, push back for later attempt.
to_process.push_back(index);
}
}

// Add output signals
for signal in &report.outputs {
let output_repr = crt_signals
.get(&signal.id)
.ok_or_else(|| CircuitError::UnprocessedNode)?;
builder.add_output(output_repr);
}

builder
.build()
.map_err(|_| CircuitError::MPZCircuitBuilderError)
}

/// Returns a node id and increments the count.
fn get_node_id(&mut self) -> u32 {
self.node_count += 1;
Expand Down Expand Up @@ -694,10 +543,6 @@ pub enum CircuitError {
DisconnectedSignal,
#[error(transparent)]
IOError(#[from] std::io::Error),
#[error("MPZ arithmetic circuit error: {0}")]
MPZCircuitError(MpzCircuitError),
#[error("MPZ arithmetic circuit builder error")]
MPZCircuitBuilderError,
#[error(transparent)]
ParseIntError(#[from] std::num::ParseIntError),
#[error("Signal already declared")]
Expand All @@ -720,12 +565,6 @@ impl From<CircuitError> for ProgramError {
}
}

impl From<MpzCircuitError> for CircuitError {
fn from(e: MpzCircuitError) -> Self {
CircuitError::MPZCircuitError(e)
}
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down
4 changes: 3 additions & 1 deletion src/process.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
//!
//! Handles execution of statements and expressions for arithmetic circuit generation within a `Runtime` environment.
use crate::compiler::{AGateType, Compiler};
use crate::arithmetic_circuit::AGateType;
use crate::compiler::Compiler;
use crate::program::ProgramError;
use crate::runtime::{
generate_u32, increment_indices, u32_to_access, Context, DataAccess, DataType, NestedValue,
Expand Down Expand Up @@ -759,6 +760,7 @@ fn to_equivalent_infix(op: &ExpressionPrefixOpcode) -> (u32, ExpressionInfixOpco
ExpressionPrefixOpcode::Complement => (u32::MAX, ExpressionInfixOpcode::BitXor),
}
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down
Loading

0 comments on commit 5c673e2

Please sign in to comment.