diff --git a/src/model/examples/simple_perceptron_mnist/mod.rs b/src/model/examples/simple_perceptron_mnist/mod.rs index 0b66c93..2d116a9 100644 --- a/src/model/examples/simple_perceptron_mnist/mod.rs +++ b/src/model/examples/simple_perceptron_mnist/mod.rs @@ -1,15 +1,17 @@ use crate::{ model::{ - nodes::{loose_fc::LooseFCNode, reshape::ReshapeNode, Node}, + nodes::{loose_fc::LooseFCNode, reshape::ReshapeNode, Node, NodeType}, qarray::QArray, Model, Poly, - }, pcs_types::Brakedown, quantization::{quantise_f32_u8_nne, QSmallType} + }, + pcs_types::Brakedown, + quantization::{quantise_f32_u8_nne, QSmallType}, }; -use ark_crypto_primitives::sponge::{poseidon::PoseidonSponge, Absorb, CryptographicSponge}; -use ark_poly_commit::PolynomialCommitment; use ark_bn254::Fr; +use ark_crypto_primitives::sponge::{poseidon::PoseidonSponge, Absorb, CryptographicSponge}; use ark_ff::PrimeField; +use ark_poly_commit::PolynomialCommitment; mod input; mod parameters; @@ -27,10 +29,10 @@ where S: CryptographicSponge, PCS: PolynomialCommitment, S>, { - let flat_dim = INPUT_DIMS.iter().product(); - let reshape: ReshapeNode = ReshapeNode::new(INPUT_DIMS.to_vec(), vec![flat_dim]); + let reshape: ReshapeNode = + ReshapeNode::new(INPUT_DIMS.to_vec(), vec![flat_dim], NodeType::Input); let lfc: LooseFCNode = LooseFCNode::new( WEIGHTS.to_vec(), @@ -43,9 +45,13 @@ where Z_W, S_O, Z_O, + NodeType::Output, ); - Model::new(INPUT_DIMS.to_vec(), vec![Node::Reshape(reshape), Node::LooseFC(lfc)]) + Model::new( + INPUT_DIMS.to_vec(), + vec![Node::Reshape(reshape), Node::LooseFC(lfc)], + ) } #[test] diff --git a/src/model/examples/two_layer_perceptron_mnist/mod.rs b/src/model/examples/two_layer_perceptron_mnist/mod.rs index ead1f88..7e3e1e6 100644 --- a/src/model/examples/two_layer_perceptron_mnist/mod.rs +++ b/src/model/examples/two_layer_perceptron_mnist/mod.rs @@ -1,15 +1,19 @@ use crate::{ model::{ - nodes::{fc::FCNode, loose_fc::LooseFCNode, relu::ReLUNode, reshape::ReshapeNode, Node}, + nodes::{ + fc::FCNode, loose_fc::LooseFCNode, relu::ReLUNode, reshape::ReshapeNode, Node, NodeType, + }, qarray::QArray, Model, Poly, - }, pcs_types::Brakedown, quantization::{quantise_f32_u8_nne, QSmallType} + }, + pcs_types::Brakedown, + quantization::{quantise_f32_u8_nne, QSmallType}, }; -use ark_crypto_primitives::sponge::{poseidon::PoseidonSponge, Absorb, CryptographicSponge}; -use ark_poly_commit::PolynomialCommitment; use ark_bn254::Fr; +use ark_crypto_primitives::sponge::{poseidon::PoseidonSponge, Absorb, CryptographicSponge}; use ark_ff::PrimeField; +use ark_poly_commit::PolynomialCommitment; mod input; mod parameters; @@ -28,10 +32,10 @@ where S: CryptographicSponge, PCS: PolynomialCommitment, S>, { - let flat_dim = INPUT_DIMS.iter().product(); - let reshape: ReshapeNode = ReshapeNode::new(INPUT_DIMS.to_vec(), vec![flat_dim]); + let reshape: ReshapeNode = + ReshapeNode::new(INPUT_DIMS.to_vec(), vec![flat_dim], NodeType::Input); let lfc: LooseFCNode = LooseFCNode::new( WEIGHTS_1.to_vec(), @@ -44,9 +48,10 @@ where Z_1_W, S_1_O, Z_1_O, + NodeType::default(), ); - let relu: ReLUNode = ReLUNode::new(28, Z_1_O); + let relu: ReLUNode = ReLUNode::new(28, Z_1_O, NodeType::default()); let fc2: FCNode = FCNode::new( WEIGHTS_2.to_vec(), @@ -58,14 +63,18 @@ where Z_2_W, S_2_O, Z_2_O, + NodeType::Output, ); - Model::new(INPUT_DIMS.to_vec(), vec![ - Node::Reshape(reshape), - Node::LooseFC(lfc), - Node::ReLU(relu), - Node::FC(fc2), - ]) + Model::new( + INPUT_DIMS.to_vec(), + vec![ + Node::Reshape(reshape), + Node::LooseFC(lfc), + Node::ReLU(relu), + Node::FC(fc2), + ], + ) } #[test] diff --git a/src/model/nodes/fc.rs b/src/model/nodes/fc.rs index 8708781..792f327 100644 --- a/src/model/nodes/fc.rs +++ b/src/model/nodes/fc.rs @@ -13,7 +13,7 @@ use crate::quantization::{ }; use crate::{Commitment, CommitmentState}; -use super::{NodeCommitment, NodeCommitmentState, NodeOps, NodeOpsSNARK, NodeProof}; +use super::{NodeCommitment, NodeCommitmentState, NodeOps, NodeOpsSNARK, NodeProof, NodeType}; // TODO convention: input, bias and output are rows, the op is vec-by-mat (in that order) @@ -34,6 +34,8 @@ pub(crate) struct FCNode { /// Quantisation info used for both result computation and requantisation q_info: FCQInfo, + node_type: NodeType, + phantom: PhantomData<(F, S, PCS)>, } @@ -253,6 +255,7 @@ where z_w: QSmallType, s_o: QScaleType, z_o: QSmallType, + node_type: NodeType, ) -> Self { assert_eq!( weights.len(), @@ -308,6 +311,7 @@ where dims, padded_dims_log, q_info, + node_type, phantom: PhantomData, } } diff --git a/src/model/nodes/loose_fc.rs b/src/model/nodes/loose_fc.rs index 31ecac7..315761c 100644 --- a/src/model/nodes/loose_fc.rs +++ b/src/model/nodes/loose_fc.rs @@ -13,7 +13,7 @@ use crate::quantization::{ }; use crate::{Commitment, CommitmentState}; -use super::{NodeCommitment, NodeCommitmentState, NodeOps, NodeOpsSNARK}; +use super::{NodeCommitment, NodeCommitmentState, NodeOps, NodeOpsSNARK, NodeType}; // TODO convention: input, bias and output are rows, the op is vec-by-mat (in that order) @@ -35,6 +35,8 @@ pub(crate) struct LooseFCNode { /// Quantisation info used for both result computation and requantisation q_info: FCQInfo, + node_type: NodeType, + phantom: PhantomData<(F, S, PCS)>, } @@ -256,6 +258,7 @@ where z_w: QSmallType, s_o: QScaleType, z_o: QSmallType, + node_type: NodeType, ) -> Self { assert_eq!( weights.len(), @@ -317,6 +320,7 @@ where dims, padded_dims_log, q_info, + node_type, phantom: PhantomData, } } diff --git a/src/model/nodes/mod.rs b/src/model/nodes/mod.rs index 764a2e8..19d05eb 100644 --- a/src/model/nodes/mod.rs +++ b/src/model/nodes/mod.rs @@ -25,9 +25,13 @@ pub(crate) mod reshape; // mod parser; -// TODO: batched methods (e.g. for multiple evaluations) -// TODO: issue: missing info about size of the next output? Or reduplicate it? -// TODO way to handle generics more elegantly? or perhaps polynomials can be made ML directly? +#[derive(Default, PartialEq, Eq, Debug)] +pub enum NodeType { + Input, + Output, + #[default] + Middle, +} /// A node of the model including its transition function to the next node(s). /// It stores information about the transition (such as a matrix and bias, if diff --git a/src/model/nodes/relu.rs b/src/model/nodes/relu.rs index 2f6b5be..5f8e969 100644 --- a/src/model/nodes/relu.rs +++ b/src/model/nodes/relu.rs @@ -10,7 +10,7 @@ use crate::model::qarray::QArray; use crate::model::Poly; use crate::quantization::QSmallType; -use super::{NodeCommitment, NodeCommitmentState, NodeOps, NodeOpsSNARK}; +use super::{Node, NodeCommitment, NodeCommitmentState, NodeOps, NodeOpsSNARK, NodeType}; // Rectified linear unit node performing x |-> max(0, x). pub(crate) struct ReLUNode @@ -22,6 +22,7 @@ where num_units: usize, log_num_units: usize, zero_point: QSmallType, + node_type: NodeType, phantom: PhantomData<(F, S, PCS)>, } @@ -89,13 +90,14 @@ where S: CryptographicSponge, PCS: PolynomialCommitment, S>, { - pub(crate) fn new(num_units: usize, zero_point: QSmallType) -> Self { + pub(crate) fn new(num_units: usize, zero_point: QSmallType, node_type: NodeType) -> Self { let log_num_units = log2(num_units.next_power_of_two()) as usize; Self { num_units, log_num_units, zero_point, + node_type, phantom: PhantomData, } } diff --git a/src/model/nodes/reshape.rs b/src/model/nodes/reshape.rs index 1087b96..8954fe7 100644 --- a/src/model/nodes/reshape.rs +++ b/src/model/nodes/reshape.rs @@ -10,7 +10,7 @@ use crate::model::qarray::QArray; use crate::model::Poly; use crate::quantization::QSmallType; -use super::{NodeOps, NodeOpsSNARK, NodeProof}; +use super::{NodeOps, NodeOpsSNARK, NodeProof, NodeType}; pub(crate) struct ReshapeNode where @@ -22,6 +22,7 @@ where output_shape: Vec, padded_input_shape_log: Vec, padded_output_shape_log: Vec, + node_type: NodeType, phantom: PhantomData<(F, S, PCS)>, } @@ -122,7 +123,17 @@ where S: CryptographicSponge, PCS: PolynomialCommitment, S>, { - pub(crate) fn new(input_shape: Vec, output_shape: Vec) -> Self { + pub(crate) fn new( + input_shape: Vec, + output_shape: Vec, + node_type: NodeType, + ) -> Self { + assert_eq!( + node_type, + NodeType::Input, + "Currently reshape is supported for input nodes only." + ); + assert_eq!( input_shape.iter().product::(), output_shape.iter().product::(), @@ -144,6 +155,7 @@ where output_shape, padded_input_shape_log, padded_output_shape_log, + node_type, phantom: PhantomData, } }