Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 13 additions & 7 deletions src/model/examples/simple_perceptron_mnist/mod.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,17 @@
use crate::{
model::{
nodes::{loose_fc::LooseFCNode, reshape::ReshapeNode, Node},
nodes::{loose_fc::LooseFCNode, reshape::ReshapeNode, Node, NodeType},
qarray::QArray,
Model, Poly,
}, pcs_types::Brakedown, quantization::{quantise_f32_u8_nne, QSmallType}
},
pcs_types::Brakedown,
quantization::{quantise_f32_u8_nne, QSmallType},
};

use ark_crypto_primitives::sponge::{poseidon::PoseidonSponge, Absorb, CryptographicSponge};
use ark_poly_commit::PolynomialCommitment;
use ark_bn254::Fr;
use ark_crypto_primitives::sponge::{poseidon::PoseidonSponge, Absorb, CryptographicSponge};
use ark_ff::PrimeField;
use ark_poly_commit::PolynomialCommitment;

mod input;
mod parameters;
Expand All @@ -27,10 +29,10 @@ where
S: CryptographicSponge,
PCS: PolynomialCommitment<F, Poly<F>, S>,
{

let flat_dim = INPUT_DIMS.iter().product();

let reshape: ReshapeNode<F, S, PCS> = ReshapeNode::new(INPUT_DIMS.to_vec(), vec![flat_dim]);
let reshape: ReshapeNode<F, S, PCS> =
ReshapeNode::new(INPUT_DIMS.to_vec(), vec![flat_dim], NodeType::Input);

let lfc: LooseFCNode<F, S, PCS> = LooseFCNode::new(
WEIGHTS.to_vec(),
Expand All @@ -43,9 +45,13 @@ where
Z_W,
S_O,
Z_O,
NodeType::Output,
);

Model::new(INPUT_DIMS.to_vec(), vec![Node::Reshape(reshape), Node::LooseFC(lfc)])
Model::new(
INPUT_DIMS.to_vec(),
vec![Node::Reshape(reshape), Node::LooseFC(lfc)],
)
}

#[test]
Expand Down
35 changes: 22 additions & 13 deletions src/model/examples/two_layer_perceptron_mnist/mod.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,19 @@
use crate::{
model::{
nodes::{fc::FCNode, loose_fc::LooseFCNode, relu::ReLUNode, reshape::ReshapeNode, Node},
nodes::{
fc::FCNode, loose_fc::LooseFCNode, relu::ReLUNode, reshape::ReshapeNode, Node, NodeType,
},
qarray::QArray,
Model, Poly,
}, pcs_types::Brakedown, quantization::{quantise_f32_u8_nne, QSmallType}
},
pcs_types::Brakedown,
quantization::{quantise_f32_u8_nne, QSmallType},
};

use ark_crypto_primitives::sponge::{poseidon::PoseidonSponge, Absorb, CryptographicSponge};
use ark_poly_commit::PolynomialCommitment;
use ark_bn254::Fr;
use ark_crypto_primitives::sponge::{poseidon::PoseidonSponge, Absorb, CryptographicSponge};
use ark_ff::PrimeField;
use ark_poly_commit::PolynomialCommitment;

mod input;
mod parameters;
Expand All @@ -28,10 +32,10 @@ where
S: CryptographicSponge,
PCS: PolynomialCommitment<F, Poly<F>, S>,
{

let flat_dim = INPUT_DIMS.iter().product();

let reshape: ReshapeNode<F, S, PCS> = ReshapeNode::new(INPUT_DIMS.to_vec(), vec![flat_dim]);
let reshape: ReshapeNode<F, S, PCS> =
ReshapeNode::new(INPUT_DIMS.to_vec(), vec![flat_dim], NodeType::Input);

let lfc: LooseFCNode<F, S, PCS> = LooseFCNode::new(
WEIGHTS_1.to_vec(),
Expand All @@ -44,9 +48,10 @@ where
Z_1_W,
S_1_O,
Z_1_O,
NodeType::default(),
);

let relu: ReLUNode<F, S, PCS> = ReLUNode::new(28, Z_1_O);
let relu: ReLUNode<F, S, PCS> = ReLUNode::new(28, Z_1_O, NodeType::default());

let fc2: FCNode<F, S, PCS> = FCNode::new(
WEIGHTS_2.to_vec(),
Expand All @@ -58,14 +63,18 @@ where
Z_2_W,
S_2_O,
Z_2_O,
NodeType::Output,
);

Model::new(INPUT_DIMS.to_vec(), vec![
Node::Reshape(reshape),
Node::LooseFC(lfc),
Node::ReLU(relu),
Node::FC(fc2),
])
Model::new(
INPUT_DIMS.to_vec(),
vec![
Node::Reshape(reshape),
Node::LooseFC(lfc),
Node::ReLU(relu),
Node::FC(fc2),
],
)
}

#[test]
Expand Down
6 changes: 5 additions & 1 deletion src/model/nodes/fc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ use crate::quantization::{
};
use crate::{Commitment, CommitmentState};

use super::{NodeCommitment, NodeCommitmentState, NodeOps, NodeOpsSNARK, NodeProof};
use super::{NodeCommitment, NodeCommitmentState, NodeOps, NodeOpsSNARK, NodeProof, NodeType};

// TODO convention: input, bias and output are rows, the op is vec-by-mat (in that order)

Expand All @@ -34,6 +34,8 @@ pub(crate) struct FCNode<F, S, PCS> {
/// Quantisation info used for both result computation and requantisation
q_info: FCQInfo,

node_type: NodeType,

phantom: PhantomData<(F, S, PCS)>,
}

Expand Down Expand Up @@ -253,6 +255,7 @@ where
z_w: QSmallType,
s_o: QScaleType,
z_o: QSmallType,
node_type: NodeType,
) -> Self {
assert_eq!(
weights.len(),
Expand Down Expand Up @@ -308,6 +311,7 @@ where
dims,
padded_dims_log,
q_info,
node_type,
phantom: PhantomData,
}
}
Expand Down
6 changes: 5 additions & 1 deletion src/model/nodes/loose_fc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ use crate::quantization::{
};
use crate::{Commitment, CommitmentState};

use super::{NodeCommitment, NodeCommitmentState, NodeOps, NodeOpsSNARK};
use super::{NodeCommitment, NodeCommitmentState, NodeOps, NodeOpsSNARK, NodeType};

// TODO convention: input, bias and output are rows, the op is vec-by-mat (in that order)

Expand All @@ -35,6 +35,8 @@ pub(crate) struct LooseFCNode<F, S, PCS> {
/// Quantisation info used for both result computation and requantisation
q_info: FCQInfo,

node_type: NodeType,

phantom: PhantomData<(F, S, PCS)>,
}

Expand Down Expand Up @@ -256,6 +258,7 @@ where
z_w: QSmallType,
s_o: QScaleType,
z_o: QSmallType,
node_type: NodeType,
) -> Self {
assert_eq!(
weights.len(),
Expand Down Expand Up @@ -317,6 +320,7 @@ where
dims,
padded_dims_log,
q_info,
node_type,
phantom: PhantomData,
}
}
Expand Down
10 changes: 7 additions & 3 deletions src/model/nodes/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,13 @@ pub(crate) mod reshape;

// mod parser;

// TODO: batched methods (e.g. for multiple evaluations)
// TODO: issue: missing info about size of the next output? Or reduplicate it?
// TODO way to handle generics more elegantly? or perhaps polynomials can be made ML directly?
#[derive(Default, PartialEq, Eq, Debug)]
pub enum NodeType {
Input,
Output,
#[default]
Middle,
}

/// A node of the model including its transition function to the next node(s).
/// It stores information about the transition (such as a matrix and bias, if
Expand Down
6 changes: 4 additions & 2 deletions src/model/nodes/relu.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use crate::model::qarray::QArray;
use crate::model::Poly;
use crate::quantization::QSmallType;

use super::{NodeCommitment, NodeCommitmentState, NodeOps, NodeOpsSNARK};
use super::{Node, NodeCommitment, NodeCommitmentState, NodeOps, NodeOpsSNARK, NodeType};

// Rectified linear unit node performing x |-> max(0, x).
pub(crate) struct ReLUNode<F, S, PCS>
Expand All @@ -22,6 +22,7 @@ where
num_units: usize,
log_num_units: usize,
zero_point: QSmallType,
node_type: NodeType,
phantom: PhantomData<(F, S, PCS)>,
}

Expand Down Expand Up @@ -89,13 +90,14 @@ where
S: CryptographicSponge,
PCS: PolynomialCommitment<F, Poly<F>, S>,
{
pub(crate) fn new(num_units: usize, zero_point: QSmallType) -> Self {
pub(crate) fn new(num_units: usize, zero_point: QSmallType, node_type: NodeType) -> Self {
let log_num_units = log2(num_units.next_power_of_two()) as usize;

Self {
num_units,
log_num_units,
zero_point,
node_type,
phantom: PhantomData,
}
}
Expand Down
16 changes: 14 additions & 2 deletions src/model/nodes/reshape.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use crate::model::qarray::QArray;
use crate::model::Poly;
use crate::quantization::QSmallType;

use super::{NodeOps, NodeOpsSNARK, NodeProof};
use super::{NodeOps, NodeOpsSNARK, NodeProof, NodeType};

pub(crate) struct ReshapeNode<F, S, PCS>
where
Expand All @@ -22,6 +22,7 @@ where
output_shape: Vec<usize>,
padded_input_shape_log: Vec<usize>,
padded_output_shape_log: Vec<usize>,
node_type: NodeType,
phantom: PhantomData<(F, S, PCS)>,
}

Expand Down Expand Up @@ -122,7 +123,17 @@ where
S: CryptographicSponge,
PCS: PolynomialCommitment<F, Poly<F>, S>,
{
pub(crate) fn new(input_shape: Vec<usize>, output_shape: Vec<usize>) -> Self {
pub(crate) fn new(
input_shape: Vec<usize>,
output_shape: Vec<usize>,
node_type: NodeType,
) -> Self {
assert_eq!(
node_type,
NodeType::Input,
"Currently reshape is supported for input nodes only."
);

assert_eq!(
input_shape.iter().product::<usize>(),
output_shape.iter().product::<usize>(),
Expand All @@ -144,6 +155,7 @@ where
output_shape,
padded_input_shape_log,
padded_output_shape_log,
node_type,
phantom: PhantomData,
}
}
Expand Down