Skip to content

Commit

Permalink
chore: add types to CLI
Browse files Browse the repository at this point in the history
  • Loading branch information
curryrasul committed Sep 11, 2024
1 parent bf5a35e commit 21cb02a
Show file tree
Hide file tree
Showing 4 changed files with 58 additions and 11 deletions.
31 changes: 29 additions & 2 deletions src/cli.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,18 @@
use clap::Parser;
use std::path::{Path, PathBuf};

use clap::{Parser, ValueEnum};
use serde::{Deserialize, Serialize};

#[derive(Clone, Copy, Debug, PartialEq, Eq, ValueEnum, Serialize, Deserialize, Default)]
#[serde(rename_all = "lowercase")]
pub enum ValueType {
#[serde(rename = "sint")]
#[default]
Sint,
#[serde(rename = "sfloat")]
Sfloat,
}

#[derive(Parser)]
#[clap(name = "Arithmetic Circuits Compiler")]
#[command(disable_help_subcommand = true)]
Expand All @@ -23,6 +35,15 @@ pub struct Args {
)]
pub output: PathBuf,

#[arg(
short,
long,
value_enum,
help = "Type that'll be used for values in MPC backend",
default_value_t = ValueType::Sint,
)]
pub value_type: ValueType,

#[arg(
long,
help = "Optional: Convert to a boolean circuit by using integers with this number of bits",
Expand All @@ -32,10 +53,16 @@ pub struct Args {
}

impl Args {
pub fn new(input: PathBuf, output: PathBuf, boolify_width: Option<usize>) -> Self {
pub fn new(
input: PathBuf,
output: PathBuf,
value_type: ValueType,
boolify_width: Option<usize>,
) -> Self {
Self {
input,
output,
value_type,
boolify_width,
}
}
Expand Down
20 changes: 18 additions & 2 deletions src/compiler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,10 @@
//!
//! This module defines the data structures used to represent the arithmetic circuit.
use crate::{a_gate_type::AGateType, program::ProgramError, topological_sort::topological_sort};
use crate::{
a_gate_type::AGateType, cli::ValueType, program::ProgramError,
topological_sort::topological_sort,
};
use bristol_circuit::{BristolCircuit, CircuitInfo, ConstantInfo, Gate};
use log::debug;
use serde::{Deserialize, Serialize};
Expand Down Expand Up @@ -108,6 +111,7 @@ pub struct Compiler {
signals: HashMap<u32, Signal>,
nodes: HashMap<u32, Node>,
gates: Vec<ArithmeticGate>,
value_type: ValueType,
}

impl Compiler {
Expand All @@ -119,6 +123,7 @@ impl Compiler {
signals: HashMap::new(),
nodes: HashMap::new(),
gates: Vec::new(),
value_type: Default::default(),
}
}

Expand Down Expand Up @@ -272,6 +277,12 @@ impl Compiler {
Ok(())
}

pub fn update_type(&mut self, value_type: ValueType) -> Result<(), CircuitError> {
self.value_type = value_type;

Ok(())
}

/// Generates a circuit report with input and output signals information.
pub fn generate_circuit_report(&self) -> Result<CircuitReport, CircuitError> {
// Split input and output nodes
Expand Down Expand Up @@ -300,7 +311,11 @@ impl Compiler {
let inputs = self.generate_signal_reports(&input_nodes);
let outputs = self.generate_signal_reports(&output_nodes);

Ok(CircuitReport { inputs, outputs })
Ok(CircuitReport {
inputs,
outputs,
value_type: self.value_type,
})
}

pub fn build_circuit(&self) -> Result<BristolCircuit, CircuitError> {
Expand Down Expand Up @@ -521,6 +536,7 @@ impl Compiler {
pub struct CircuitReport {
inputs: Vec<SignalReport>,
outputs: Vec<SignalReport>,
value_type: ValueType,
}

/// A single node report, with a list of signal names and an optional value.
Expand Down
2 changes: 2 additions & 0 deletions src/program.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,8 @@ pub fn compile(args: &Args) -> Result<Compiler, ProgramError> {
_ => return Err(ProgramError::MainExpressionNotACall),
}

compiler.update_type(args.value_type)?;

Ok(compiler)
}

Expand Down
16 changes: 9 additions & 7 deletions tests/integration.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#![allow(clippy::upper_case_acronyms)]

use bristol_circuit::BristolCircuit;
use circom_2_arithc::a_gate_type::AGateType;
use circom_2_arithc::{a_gate_type::AGateType, cli::ValueType};
use sim_circuit::{
circuit::{CircuitBuilder, CircuitMemory, GenericCircuit, GenericCircuitExecutor},
model::{Component, Executable, Memory},
Expand Down Expand Up @@ -139,16 +139,16 @@ impl ArithmeticCircuit {
// Get circuit inputs
let inputs = circuit.info.input_name_to_wire_index;
for (label, index) in inputs {
label_to_index.insert(label, index as usize);
input_indices.push(index as usize);
label_to_index.insert(label, index);
input_indices.push(index);
}

// Get circuit constants
let mut constants: HashMap<usize, u32> = HashMap::new();
for (_, constant_info) in circuit.info.constants {
input_indices.push(constant_info.wire_index as usize);
input_indices.push(constant_info.wire_index);
constants.insert(
constant_info.wire_index as usize,
constant_info.wire_index,
constant_info.value.parse().unwrap(),
);
}
Expand All @@ -157,7 +157,6 @@ impl ArithmeticCircuit {
let output_map = circuit.info.output_name_to_wire_index;
let mut output_indices = vec![];
for (label, index) in output_map {
let index = index as usize;
label_to_index.insert(label.clone(), index);
outputs.push(label);
output_indices.push(index);
Expand Down Expand Up @@ -260,7 +259,7 @@ mod integration_tests {
inputs: &[(&str, u32)],
expected_outputs: &[(&str, u32)],
) {
let compiler_input = Args::new(circuit_path.into(), "./".into(), None);
let compiler_input = Args::new(circuit_path.into(), "./".into(), ValueType::Sint, None);
let circuit = compile(&compiler_input).unwrap().build_circuit().unwrap();
let arithmetic_circuit = ArithmeticCircuit::new_from_bristol(circuit).unwrap();

Expand Down Expand Up @@ -379,6 +378,7 @@ mod integration_tests {
let compiler_input = Args::new(
"tests/circuits/integration/indexOutOfBounds.circom".into(),
"./".into(),
ValueType::Sint,
None,
);
let circuit = compile(&compiler_input);
Expand All @@ -395,6 +395,7 @@ mod integration_tests {
let compiler_input = Args::new(
"tests/circuits/integration/constantSum.circom".into(),
"./".into(),
ValueType::Sint,
None,
);
let circuit_res = compile(&compiler_input);
Expand All @@ -418,6 +419,7 @@ mod integration_tests {
let compiler_input = Args::new(
"tests/circuits/integration/directOutput.circom".into(),
"./".into(),
ValueType::Sint,
None,
);
let circuit_res = compile(&compiler_input);
Expand Down

0 comments on commit 21cb02a

Please sign in to comment.