diff --git a/common/src/model/mod.rs b/common/src/model/mod.rs index ba2dd3d..c98fa66 100644 --- a/common/src/model/mod.rs +++ b/common/src/model/mod.rs @@ -45,18 +45,19 @@ where // TODO: for now, we require all nodes to use the same PCS; this might change // in the future -pub struct Model { +pub struct Model { pub input_shape: Vec, pub output_shape: Vec, - pub nodes: Vec>, + pub nodes: Vec>, } -impl Model +impl Model where ST: InnerType + TryFrom, LT: InnerType + From, + F: PrimeField + Absorb + From + From, { - pub fn new(input_shape: Vec, nodes: Vec>) -> Self { + pub fn new(input_shape: Vec, nodes: Vec>) -> Self { // An empty model would cause panics later down the line e.g. when // determining the number of variables needed to commit to it. assert!(!nodes.is_empty(), "A model cannot have no nodes",); @@ -72,12 +73,11 @@ where &self.input_shape } - pub fn setup_keys( + pub fn setup_keys( &self, rng: &mut R, ) -> Result<(PCS::CommitterKey, PCS::VerifierKey), PCS::Error> where - F: PrimeField + Absorb + From + From, S: CryptographicSponge, PCS: PolynomialCommitment, S>, R: RngCore, diff --git a/common/src/model/nodes/bmm.rs b/common/src/model/nodes/bmm.rs index 47be337..9b8382d 100644 --- a/common/src/model/nodes/bmm.rs +++ b/common/src/model/nodes/bmm.rs @@ -1,6 +1,7 @@ use ark_crypto_primitives::sponge::{Absorb, CryptographicSponge}; use ark_ff::PrimeField; -use ark_poly_commit::{LabeledCommitment, PolynomialCommitment}; +use ark_poly::DenseMultilinearExtension; +use ark_poly_commit::{LabeledCommitment, LabeledPolynomial, PolynomialCommitment}; use ark_std::log2; use ark_sumcheck::ml_sumcheck::Proof; @@ -14,8 +15,12 @@ use super::{NodeOpsNative, NodeOpsPadded}; // TODO convention: input, bias and output are rows, the op is vec-by-mat (in that order) /// Start with 2D matrices, and Mat-by-vector multiplication only -pub struct BMMNode { - /// The row-major flattened unpadded vector of weights +pub struct BMMNode { + /// The MLE of the weight matrix as a labeled polynomial + pub weight_mle: LabeledPolynomial>, + /// The MLE of the bias vector as a labeled polynomial + pub bias_mle: LabeledPolynomial>, + /// The row-major flattened padded vector of weights weights: QArray, /// The padded weight vector pub padded_weights: QArray, @@ -103,10 +108,11 @@ pub struct BMMNodeProof< pub bias_opening_value: F, } -impl NodeOpsNative for BMMNode +impl NodeOpsNative for BMMNode where ST: InnerType, LT: InnerType + From, + F: PrimeField, { fn shape(&self) -> Vec { vec![self.dims.1] @@ -151,10 +157,11 @@ where } } -impl NodeOpsPadded for BMMNode +impl NodeOpsPadded for BMMNode where ST: InnerType + TryFrom, LT: InnerType + From, + F: PrimeField, { fn padded_shape_log(&self) -> Vec { vec![self.padded_dims_log.1] @@ -210,10 +217,11 @@ where } } -impl BMMNode +impl BMMNode where ST: InnerType, LT: InnerType, + F: PrimeField + Absorb + From + From, { pub fn new(weights: QArray, bias: QArray, input_zero_point: ST) -> Self { let dims = (weights.shape()[0], weights.shape()[1]); @@ -239,7 +247,29 @@ where .clone() .compact_resize(vec![dims.1.next_power_of_two()], LT::ZERO); + let weight_f = padded_weights + .values() + .iter() + .map(|w| F::from(*w)) + .collect(); + + // Dual of the MLE of the row-major flattening of the weight matrix + let weight_poly = + Poly::from_evaluations_vec(padded_dims_log.0 + padded_dims_log.1, weight_f); + + let weight_mle = + LabeledPolynomial::new("weight_mle".to_string(), weight_poly, Some(1), None); + + let bias_f = padded_bias.values().iter().map(|w| F::from(*w)).collect(); + + // Dual of the MLE of the bias vector + let bias_poly = Poly::from_evaluations_vec(padded_dims_log.1, bias_f); + + let bias_mle = LabeledPolynomial::new("bias_mle".to_string(), bias_poly, Some(1), None); + Self { + weight_mle, + bias_mle, weights, padded_weights, bias, diff --git a/common/src/model/nodes/mod.rs b/common/src/model/nodes/mod.rs index 6d8994b..55ba5f6 100644 --- a/common/src/model/nodes/mod.rs +++ b/common/src/model/nodes/mod.rs @@ -85,8 +85,8 @@ pub trait NodeOpsPadded: NodeOpsNative { fn padded_evaluate(&self, input: &QArray) -> QArray; } -pub enum Node { - BMM(BMMNode), +pub enum Node { + BMM(BMMNode), RequantiseBMM(RequantiseBMMNode), ReLU(ReLUNode), Reshape(ReshapeNode), @@ -130,10 +130,11 @@ where // A lot of this overlaps with the NodeOps trait and could be handled more // elegantly by simply implementing the trait -impl Node +impl Node where I: InnerType + TryFrom, O: InnerType + From, + F: PrimeField, { // Print the type of the node. This cannot be cleantly achieved by deriving // Debug diff --git a/prover/examples/common/lib.rs b/prover/examples/common/lib.rs index 6109ee4..e9d4eca 100644 --- a/prover/examples/common/lib.rs +++ b/prover/examples/common/lib.rs @@ -11,7 +11,7 @@ use ark_std::test_rng; // Auxiliary function fn unpadded_inference( raw_input: QArray, - model: &Model, + model: &Model, qinfo: (f32, u8), ) -> QArray where @@ -34,7 +34,7 @@ where // Auxiliary function fn padded_inference( raw_input: QArray, - model: &Model, + model: &Model, qinfo: (f32, u8), ) -> QArray where @@ -50,7 +50,7 @@ where let input_i8 = (quantised_input.cast::() - 128).cast::(); let output_i8 = - as ProveModel>::padded_evaluate(model, input_i8); + as ProveModel>::padded_evaluate(model, input_i8); (output_i8.cast::() + 128).cast() } @@ -58,7 +58,7 @@ where pub fn run_unpadded( input_path: &str, expected_output_path: &str, - model: &Model, + model: &Model, qinfo: (f32, u8), ) where F: PrimeField + Absorb, @@ -78,7 +78,7 @@ pub fn run_unpadded( pub fn run_padded( input_path: &str, expected_output_path: &str, - model: &Model, + model: &Model, qinfo: (f32, u8), ) where F: PrimeField + Absorb, @@ -98,7 +98,7 @@ pub fn run_padded( pub fn multi_run_unpadded( inputs_path: &str, expected_outputs_path: &str, - model: &Model, + model: &Model, qinfo: (f32, u8), ) where F: PrimeField + Absorb, @@ -121,7 +121,7 @@ pub fn multi_run_unpadded( pub fn multi_run_padded( inputs_path: &str, expected_outputs_path: &str, - model: &Model, + model: &Model, qinfo: (f32, u8), ) where F: PrimeField + Absorb, @@ -144,7 +144,7 @@ pub fn multi_run_padded( pub fn prove_inference( input_path: &str, expected_output_path: &str, - model: &Model, + model: &Model, qinfo: (f32, u8), sponge: S, output_shape: Vec, @@ -166,7 +166,7 @@ pub fn prove_inference( let mut sponge = sponge; let mut rng = test_rng(); - let (ck, _) = model.setup_keys::(&mut rng).unwrap(); + let (ck, _) = model.setup_keys::(&mut rng).unwrap(); let (node_coms, node_com_states): (Vec<_>, Vec<_>) = model.commit(&ck, None).into_iter().unzip(); @@ -194,7 +194,7 @@ pub fn prove_inference( pub fn verify_inference( input_path: &str, expected_output_path: &str, - model: &Model, + model: &Model, qinfo: (f32, u8), sponge: S, output_shape: Vec, @@ -219,7 +219,7 @@ pub fn verify_inference( let mut verification_sponge = sponge; let mut rng = test_rng(); - let (ck, vk) = model.setup_keys::(&mut rng).unwrap(); + let (ck, vk) = model.setup_keys::(&mut rng).unwrap(); let (node_coms, node_com_states): (Vec<_>, Vec<_>) = model.commit(&ck, None).into_iter().unzip(); diff --git a/prover/examples/simple_perceptron_mnist/main.rs b/prover/examples/simple_perceptron_mnist/main.rs index 6cbcdb9..00084dc 100644 --- a/prover/examples/simple_perceptron_mnist/main.rs +++ b/prover/examples/simple_perceptron_mnist/main.rs @@ -26,7 +26,7 @@ macro_rules! PATH { } // TODO this is incorrect now that we have switched to logs -fn build_simple_perceptron_mnist() -> Model +fn build_simple_perceptron_mnist() -> Model where F: PrimeField + Absorb, S: CryptographicSponge, @@ -39,7 +39,7 @@ where let w_array: QArray = QArray::read(&format!(PATH!(), "parameters/weights.json")); let b_array: QArray = QArray::read(&format!(PATH!(), "parameters/bias.json")); - let bmm: BMMNode = BMMNode::new(w_array, b_array, Z_I); + let bmm: BMMNode = BMMNode::new(w_array, b_array, Z_I); let req_bmm: RequantiseBMMNode = RequantiseBMMNode::new(OUTPUT_DIM, S_I, Z_I, S_W, Z_W, S_O, Z_O); diff --git a/prover/examples/two_layer_perceptron_mnist/main.rs b/prover/examples/two_layer_perceptron_mnist/main.rs index 26242fd..bacfe89 100644 --- a/prover/examples/two_layer_perceptron_mnist/main.rs +++ b/prover/examples/two_layer_perceptron_mnist/main.rs @@ -27,7 +27,7 @@ macro_rules! PATH { }; } -fn build_two_layer_perceptron_mnist() -> Model +fn build_two_layer_perceptron_mnist() -> Model where F: PrimeField + Absorb, S: CryptographicSponge, @@ -42,14 +42,14 @@ where let w2_array: QArray = QArray::read(&format!(PATH!(), "parameters/weights_2.json")); let b2_array: QArray = QArray::read(&format!(PATH!(), "parameters/bias_2.json")); - let bmm_1: BMMNode = BMMNode::new(w1_array, b1_array, Z_1_I); + let bmm_1: BMMNode = BMMNode::new(w1_array, b1_array, Z_1_I); let req_bmm_1: RequantiseBMMNode = RequantiseBMMNode::new(INTER_DIM, S_1_I, Z_1_I, S_1_W, Z_1_W, S_1_O, Z_1_O); let relu: ReLUNode = ReLUNode::new(28, Z_1_O); - let bmm_2: BMMNode = BMMNode::new(w2_array, b2_array, Z_2_I); + let bmm_2: BMMNode = BMMNode::new(w2_array, b2_array, Z_2_I); let req_bmm_2: RequantiseBMMNode = RequantiseBMMNode::new(OUTPUT_DIM, S_2_I, Z_2_I, S_2_W, Z_2_W, S_2_O, Z_2_O); diff --git a/prover/src/lib.rs b/prover/src/lib.rs index e807cc5..3cc8f7b 100644 --- a/prover/src/lib.rs +++ b/prover/src/lib.rs @@ -44,7 +44,7 @@ where ) -> (NodeCommitment, NodeCommitmentState); } -impl NodeOpsProve for Node +impl NodeOpsProve for Node where F: PrimeField + Absorb + From + From + From, S: CryptographicSponge, diff --git a/prover/src/model.rs b/prover/src/model.rs index 5a69e6e..cfd11ac 100644 --- a/prover/src/model.rs +++ b/prover/src/model.rs @@ -33,7 +33,7 @@ where ) -> Vec<(NodeCommitment, NodeCommitmentState)>; } -impl ProveModel for Model +impl ProveModel for Model where F: PrimeField + Absorb + From + From, S: CryptographicSponge, diff --git a/prover/src/nodes/bmm.rs b/prover/src/nodes/bmm.rs index bc141ef..a8b126e 100644 --- a/prover/src/nodes/bmm.rs +++ b/prover/src/nodes/bmm.rs @@ -3,18 +3,18 @@ use std::rc::Rc; use ark_crypto_primitives::sponge::{Absorb, CryptographicSponge}; use ark_ff::PrimeField; use ark_poly::{MultilinearExtension, Polynomial}; -use ark_poly_commit::{LabeledCommitment, LabeledPolynomial, PolynomialCommitment}; +use ark_poly_commit::{LabeledCommitment, PolynomialCommitment}; use ark_std::rand::RngCore; use ark_sumcheck::ml_sumcheck::{protocol::ListOfProductsOfPolynomials, MLSumcheck}; use hcs_common::{ BMMNode, BMMNodeCommitment, BMMNodeCommitmentState, BMMNodeProof, InnerType, LabeledPoly, - NodeCommitment, NodeCommitmentState, NodeOpsPadded, NodeProof, Poly, + NodeCommitment, NodeCommitmentState, NodeProof, Poly, }; use crate::NodeOpsProve; -impl NodeOpsProve for BMMNode +impl NodeOpsProve for BMMNode where F: PrimeField + Absorb + From + From, S: CryptographicSponge, @@ -65,33 +65,12 @@ where input.polynomial().iter().map(|x| *x - i_z_p_f).collect(), ); - // TODO consider whether this can be done once and stored - let weights_f = self - .padded_weights - .values() - .iter() - .map(|w| F::from(*w)) - .collect(); - - // Dual of the MLE of the row-major flattening of the weight matrix - let weight_mle = Poly::from_evaluations_vec(self.com_num_vars(), weights_f); - - // TODO consider whether this can be done once and stored - let bias_f = self - .padded_bias - .values() - .iter() - .map(|w| F::from(*w)) - .collect(); - // Dual of the MLE of the bias vector - let bias_mle = Poly::from_evaluations_vec(self.padded_dims_log.1, bias_f); - - let bias_opening_value = bias_mle.evaluate(&r); + let bias_opening_value = self.bias_mle.evaluate(&r); let output_opening_value = output.evaluate(&r); // Constructing the sumcheck polynomial // g(x) = (input - zero_point)^(x) * W^(r, x), - let bound_weight_mle = weight_mle.fix_variables(&r); + let bound_weight_mle = self.weight_mle.polynomial().fix_variables(&r); let mut g = ListOfProductsOfPolynomials::new(self.padded_dims_log.0); // TODO we are cloning the input here, can we do better? @@ -138,12 +117,7 @@ where let weight_opening_proof = PCS::open( ck, - [&LabeledPolynomial::new( - "weight_mle".to_string(), - weight_mle, - Some(1), - None, - )], + [&self.weight_mle], [weight_com], &r.clone() .into_iter() @@ -159,10 +133,7 @@ where // with a single call to PCS::open let output_bias_opening_proof = PCS::open( ck, - [ - output, - &LabeledPolynomial::new("bias_mle".to_string(), bias_mle, Some(1), None), - ], + [output, &self.bias_mle], [output_com, bias_com], &r, sponge, @@ -188,38 +159,7 @@ where ck: &PCS::CommitterKey, rng: Option<&mut dyn RngCore>, ) -> (NodeCommitment, NodeCommitmentState) { - // TODO should we separate the associated commitment type into one with state and one without? - let padded_weights_f: Vec = self - .padded_weights - .values() - .iter() - .map(|w| F::from(*w)) - .collect(); - - // TODO part of this code is duplicated in prove, another hint that this should probs - // be stored - let weight_poly = LabeledPolynomial::new( - "weight_poly".to_string(), - Poly::from_evaluations_vec(self.com_num_vars(), padded_weights_f), - Some(1), - None, - ); - - let padded_bias_f: Vec = self - .padded_bias - .values() - .iter() - .map(|b| F::from(*b)) - .collect(); - - let bias_poly = LabeledPolynomial::new( - "bias_poly".to_string(), - Poly::from_evaluations_vec(self.padded_dims_log.1, padded_bias_f), - Some(1), - None, - ); - - let coms = PCS::commit(ck, vec![&weight_poly, &bias_poly], rng).unwrap(); + let coms = PCS::commit(ck, vec![&self.weight_mle, &self.bias_mle], rng).unwrap(); ( NodeCommitment::BMM(BMMNodeCommitment { diff --git a/verifier/src/lib.rs b/verifier/src/lib.rs index 8b711c4..306ab90 100644 --- a/verifier/src/lib.rs +++ b/verifier/src/lib.rs @@ -26,7 +26,7 @@ where ) -> bool; } -impl NodeOpsVerify for Node +impl NodeOpsVerify for Node where F: PrimeField + Absorb + From, S: CryptographicSponge, @@ -47,7 +47,9 @@ where } } -fn node_as_node_ops_snark(node: &Node) -> &dyn NodeOpsVerify +fn node_as_node_ops_snark( + node: &Node, +) -> &dyn NodeOpsVerify where F: PrimeField + Absorb + From, S: CryptographicSponge, diff --git a/verifier/src/model.rs b/verifier/src/model.rs index 45ea304..24f026d 100644 --- a/verifier/src/model.rs +++ b/verifier/src/model.rs @@ -22,7 +22,7 @@ where ) -> bool; } -impl VerifyModel for Model +impl VerifyModel for Model where F: PrimeField + Absorb + From + From, S: CryptographicSponge, diff --git a/verifier/src/nodes/bmm.rs b/verifier/src/nodes/bmm.rs index 79f2cb1..0be65c7 100644 --- a/verifier/src/nodes/bmm.rs +++ b/verifier/src/nodes/bmm.rs @@ -11,7 +11,7 @@ use hcs_common::{ use crate::NodeOpsVerify; -impl NodeOpsVerify for BMMNode +impl NodeOpsVerify for BMMNode where F: PrimeField + Absorb + From, S: CryptographicSponge,