Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions common/src/model/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,18 +45,19 @@ where

// TODO: for now, we require all nodes to use the same PCS; this might change
// in the future
pub struct Model<ST, LT> {
pub struct Model<ST, LT, F: PrimeField> {
pub input_shape: Vec<usize>,
pub output_shape: Vec<usize>,
pub nodes: Vec<Node<ST, LT>>,
pub nodes: Vec<Node<ST, LT, F>>,
}

impl<ST, LT> Model<ST, LT>
impl<ST, LT, F> Model<ST, LT, F>
where
ST: InnerType + TryFrom<LT>,
LT: InnerType + From<ST>,
F: PrimeField + Absorb + From<ST> + From<LT>,
{
pub fn new(input_shape: Vec<usize>, nodes: Vec<Node<ST, LT>>) -> Self {
pub fn new(input_shape: Vec<usize>, nodes: Vec<Node<ST, LT, F>>) -> Self {
// An empty model would cause panics later down the line e.g. when
// determining the number of variables needed to commit to it.
assert!(!nodes.is_empty(), "A model cannot have no nodes",);
Expand All @@ -72,12 +73,11 @@ where
&self.input_shape
}

pub fn setup_keys<F, S, PCS, R>(
pub fn setup_keys<S, PCS, R>(
&self,
rng: &mut R,
) -> Result<(PCS::CommitterKey, PCS::VerifierKey), PCS::Error>
where
F: PrimeField + Absorb + From<ST> + From<LT>,
S: CryptographicSponge,
PCS: PolynomialCommitment<F, Poly<F>, S>,
R: RngCore,
Expand Down
42 changes: 36 additions & 6 deletions common/src/model/nodes/bmm.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use ark_crypto_primitives::sponge::{Absorb, CryptographicSponge};
use ark_ff::PrimeField;
use ark_poly_commit::{LabeledCommitment, PolynomialCommitment};
use ark_poly::DenseMultilinearExtension;
use ark_poly_commit::{LabeledCommitment, LabeledPolynomial, PolynomialCommitment};
use ark_std::log2;

use ark_sumcheck::ml_sumcheck::Proof;
Expand All @@ -14,8 +15,12 @@ use super::{NodeOpsNative, NodeOpsPadded};
// TODO convention: input, bias and output are rows, the op is vec-by-mat (in that order)

/// Start with 2D matrices, and Mat-by-vector multiplication only
pub struct BMMNode<ST, LT> {
/// The row-major flattened unpadded vector of weights
pub struct BMMNode<ST, LT, F: PrimeField> {
/// The MLE of the weight matrix as a labeled polynomial
pub weight_mle: LabeledPolynomial<F, DenseMultilinearExtension<F>>,
/// The MLE of the bias vector as a labeled polynomial
pub bias_mle: LabeledPolynomial<F, DenseMultilinearExtension<F>>,
/// The row-major flattened padded vector of weights
weights: QArray<ST>,
/// The padded weight vector
pub padded_weights: QArray<ST>,
Expand Down Expand Up @@ -103,10 +108,11 @@ pub struct BMMNodeProof<
pub bias_opening_value: F,
}

impl<ST, LT> NodeOpsNative<ST, LT> for BMMNode<ST, LT>
impl<ST, LT, F> NodeOpsNative<ST, LT> for BMMNode<ST, LT, F>
where
ST: InnerType,
LT: InnerType + From<ST>,
F: PrimeField,
{
fn shape(&self) -> Vec<usize> {
vec![self.dims.1]
Expand Down Expand Up @@ -151,10 +157,11 @@ where
}
}

impl<ST, LT> NodeOpsPadded<ST, LT> for BMMNode<ST, LT>
impl<ST, LT, F> NodeOpsPadded<ST, LT> for BMMNode<ST, LT, F>
where
ST: InnerType + TryFrom<LT>,
LT: InnerType + From<ST>,
F: PrimeField,
{
fn padded_shape_log(&self) -> Vec<usize> {
vec![self.padded_dims_log.1]
Expand Down Expand Up @@ -210,10 +217,11 @@ where
}
}

impl<ST, LT> BMMNode<ST, LT>
impl<ST, LT, F> BMMNode<ST, LT, F>
where
ST: InnerType,
LT: InnerType,
F: PrimeField + Absorb + From<ST> + From<LT>,
{
pub fn new(weights: QArray<ST>, bias: QArray<LT>, input_zero_point: ST) -> Self {
let dims = (weights.shape()[0], weights.shape()[1]);
Expand All @@ -239,7 +247,29 @@ where
.clone()
.compact_resize(vec![dims.1.next_power_of_two()], LT::ZERO);

let weight_f = padded_weights
.values()
.iter()
.map(|w| F::from(*w))
.collect();

// Dual of the MLE of the row-major flattening of the weight matrix
let weight_poly =
Poly::from_evaluations_vec(padded_dims_log.0 + padded_dims_log.1, weight_f);

let weight_mle =
LabeledPolynomial::new("weight_mle".to_string(), weight_poly, Some(1), None);

let bias_f = padded_bias.values().iter().map(|w| F::from(*w)).collect();

// Dual of the MLE of the bias vector
let bias_poly = Poly::from_evaluations_vec(padded_dims_log.1, bias_f);

let bias_mle = LabeledPolynomial::new("bias_mle".to_string(), bias_poly, Some(1), None);

Self {
weight_mle,
bias_mle,
weights,
padded_weights,
bias,
Expand Down
7 changes: 4 additions & 3 deletions common/src/model/nodes/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -85,8 +85,8 @@ pub trait NodeOpsPadded<I, O>: NodeOpsNative<I, O> {
fn padded_evaluate(&self, input: &QArray<I>) -> QArray<O>;
}

pub enum Node<ST, LT> {
BMM(BMMNode<ST, LT>),
pub enum Node<ST, LT, F: PrimeField> {
BMM(BMMNode<ST, LT, F>),
RequantiseBMM(RequantiseBMMNode<ST>),
ReLU(ReLUNode<ST>),
Reshape(ReshapeNode),
Expand Down Expand Up @@ -130,10 +130,11 @@ where

// A lot of this overlaps with the NodeOps trait and could be handled more
// elegantly by simply implementing the trait
impl<I, O> Node<I, O>
impl<I, O, F> Node<I, O, F>
where
I: InnerType + TryFrom<O>,
O: InnerType + From<I>,
F: PrimeField,
{
// Print the type of the node. This cannot be cleantly achieved by deriving
// Debug
Expand Down
22 changes: 11 additions & 11 deletions prover/examples/common/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use ark_std::test_rng;
// Auxiliary function
fn unpadded_inference<F, S, PCS>(
raw_input: QArray<f32>,
model: &Model<i8, i32>,
model: &Model<i8, i32, F>,
qinfo: (f32, u8),
) -> QArray<u8>
where
Expand All @@ -34,7 +34,7 @@ where
// Auxiliary function
fn padded_inference<F, S, PCS>(
raw_input: QArray<f32>,
model: &Model<i8, i32>,
model: &Model<i8, i32, F>,
qinfo: (f32, u8),
) -> QArray<u8>
where
Expand All @@ -50,15 +50,15 @@ where
let input_i8 = (quantised_input.cast::<i32>() - 128).cast::<i8>();

let output_i8 =
<Model<i8, i32> as ProveModel<F, S, PCS, i8, i32>>::padded_evaluate(model, input_i8);
<Model<i8, i32, F> as ProveModel<F, S, PCS, i8, i32>>::padded_evaluate(model, input_i8);

(output_i8.cast::<i32>() + 128).cast()
}

pub fn run_unpadded<F, S, PCS>(
input_path: &str,
expected_output_path: &str,
model: &Model<i8, i32>,
model: &Model<i8, i32, F>,
qinfo: (f32, u8),
) where
F: PrimeField + Absorb,
Expand All @@ -78,7 +78,7 @@ pub fn run_unpadded<F, S, PCS>(
pub fn run_padded<F, S, PCS>(
input_path: &str,
expected_output_path: &str,
model: &Model<i8, i32>,
model: &Model<i8, i32, F>,
qinfo: (f32, u8),
) where
F: PrimeField + Absorb,
Expand All @@ -98,7 +98,7 @@ pub fn run_padded<F, S, PCS>(
pub fn multi_run_unpadded<F, S, PCS>(
inputs_path: &str,
expected_outputs_path: &str,
model: &Model<i8, i32>,
model: &Model<i8, i32, F>,
qinfo: (f32, u8),
) where
F: PrimeField + Absorb,
Expand All @@ -121,7 +121,7 @@ pub fn multi_run_unpadded<F, S, PCS>(
pub fn multi_run_padded<F, S, PCS>(
inputs_path: &str,
expected_outputs_path: &str,
model: &Model<i8, i32>,
model: &Model<i8, i32, F>,
qinfo: (f32, u8),
) where
F: PrimeField + Absorb,
Expand All @@ -144,7 +144,7 @@ pub fn multi_run_padded<F, S, PCS>(
pub fn prove_inference<F, S, PCS>(
input_path: &str,
expected_output_path: &str,
model: &Model<i8, i32>,
model: &Model<i8, i32, F>,
qinfo: (f32, u8),
sponge: S,
output_shape: Vec<usize>,
Expand All @@ -166,7 +166,7 @@ pub fn prove_inference<F, S, PCS>(
let mut sponge = sponge;

let mut rng = test_rng();
let (ck, _) = model.setup_keys::<F, S, PCS, _>(&mut rng).unwrap();
let (ck, _) = model.setup_keys::<S, PCS, _>(&mut rng).unwrap();

let (node_coms, node_com_states): (Vec<_>, Vec<_>) =
model.commit(&ck, None).into_iter().unzip();
Expand Down Expand Up @@ -194,7 +194,7 @@ pub fn prove_inference<F, S, PCS>(
pub fn verify_inference<F, S, PCS>(
input_path: &str,
expected_output_path: &str,
model: &Model<i8, i32>,
model: &Model<i8, i32, F>,
qinfo: (f32, u8),
sponge: S,
output_shape: Vec<usize>,
Expand All @@ -219,7 +219,7 @@ pub fn verify_inference<F, S, PCS>(
let mut verification_sponge = sponge;

let mut rng = test_rng();
let (ck, vk) = model.setup_keys::<F, S, PCS, _>(&mut rng).unwrap();
let (ck, vk) = model.setup_keys::<S, PCS, _>(&mut rng).unwrap();

let (node_coms, node_com_states): (Vec<_>, Vec<_>) =
model.commit(&ck, None).into_iter().unzip();
Expand Down
4 changes: 2 additions & 2 deletions prover/examples/simple_perceptron_mnist/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ macro_rules! PATH {
}

// TODO this is incorrect now that we have switched to logs
fn build_simple_perceptron_mnist<F, S, PCS>() -> Model<i8, i32>
fn build_simple_perceptron_mnist<F, S, PCS>() -> Model<i8, i32, F>
where
F: PrimeField + Absorb,
S: CryptographicSponge,
Expand All @@ -39,7 +39,7 @@ where
let w_array: QArray<i8> = QArray::read(&format!(PATH!(), "parameters/weights.json"));
let b_array: QArray<i32> = QArray::read(&format!(PATH!(), "parameters/bias.json"));

let bmm: BMMNode<i8, i32> = BMMNode::new(w_array, b_array, Z_I);
let bmm: BMMNode<i8, i32, F> = BMMNode::new(w_array, b_array, Z_I);

let req_bmm: RequantiseBMMNode<i8> =
RequantiseBMMNode::new(OUTPUT_DIM, S_I, Z_I, S_W, Z_W, S_O, Z_O);
Expand Down
6 changes: 3 additions & 3 deletions prover/examples/two_layer_perceptron_mnist/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ macro_rules! PATH {
};
}

fn build_two_layer_perceptron_mnist<F, S, PCS>() -> Model<i8, i32>
fn build_two_layer_perceptron_mnist<F, S, PCS>() -> Model<i8, i32, F>
where
F: PrimeField + Absorb,
S: CryptographicSponge,
Expand All @@ -42,14 +42,14 @@ where
let w2_array: QArray<i8> = QArray::read(&format!(PATH!(), "parameters/weights_2.json"));
let b2_array: QArray<i32> = QArray::read(&format!(PATH!(), "parameters/bias_2.json"));

let bmm_1: BMMNode<i8, i32> = BMMNode::new(w1_array, b1_array, Z_1_I);
let bmm_1: BMMNode<i8, i32, F> = BMMNode::new(w1_array, b1_array, Z_1_I);

let req_bmm_1: RequantiseBMMNode<i8> =
RequantiseBMMNode::new(INTER_DIM, S_1_I, Z_1_I, S_1_W, Z_1_W, S_1_O, Z_1_O);

let relu: ReLUNode<i8> = ReLUNode::new(28, Z_1_O);

let bmm_2: BMMNode<i8, i32> = BMMNode::new(w2_array, b2_array, Z_2_I);
let bmm_2: BMMNode<i8, i32, F> = BMMNode::new(w2_array, b2_array, Z_2_I);

let req_bmm_2: RequantiseBMMNode<i8> =
RequantiseBMMNode::new(OUTPUT_DIM, S_2_I, Z_2_I, S_2_W, Z_2_W, S_2_O, Z_2_O);
Expand Down
2 changes: 1 addition & 1 deletion prover/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ where
) -> (NodeCommitment<F, S, PCS>, NodeCommitmentState<F, S, PCS>);
}

impl<F, S, PCS, I, O> NodeOpsProve<F, S, PCS, I, O> for Node<I, O>
impl<F, S, PCS, I, O> NodeOpsProve<F, S, PCS, I, O> for Node<I, O, F>
where
F: PrimeField + Absorb + From<I> + From<O> + From<O>,
S: CryptographicSponge,
Expand Down
2 changes: 1 addition & 1 deletion prover/src/model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ where
) -> Vec<(NodeCommitment<F, S, PCS>, NodeCommitmentState<F, S, PCS>)>;
}

impl<F, S, PCS, ST, LT> ProveModel<F, S, PCS, ST, LT> for Model<ST, LT>
impl<F, S, PCS, ST, LT> ProveModel<F, S, PCS, ST, LT> for Model<ST, LT, F>
where
F: PrimeField + Absorb + From<ST> + From<LT>,
S: CryptographicSponge,
Expand Down
Loading