Dimitris/requantize bmm ref #78
base: main
Changes from 51 commits
```diff
@@ -1 +1 @@
-/target
+/target
```
```diff
@@ -1,4 +1,10 @@
-use crate::{BMMNode, Model, Node, Poly, QArray, RequantiseBMMNode, ReshapeNode};
+use crate::{
+    model::nodes::{
+        requantize_bmm_ref::RequantizeBMMRefNode, requantize_bmm_single::RequantizeBMMSingleNode,
+    },
+    quantization::BMMRequantizationStrategy,
+    BMMNode, Model, Node, Poly, RequantizeBMMFloatNode, ReshapeNode, Tensor,
+};

 use ark_crypto_primitives::sponge::{Absorb, CryptographicSponge};
 use ark_ff::PrimeField;
@@ -22,7 +28,9 @@ macro_rules! PATH {
 }

 // TODO this is incorrect now that we have switched to logs
-pub fn build_simple_perceptron_mnist<F, S, PCS>() -> Model<i8, i32>
+pub fn build_simple_perceptron_mnist<F, S, PCS>(
+    req_strategy: BMMRequantizationStrategy,
```
**Contributor:** So this means there's one requantization strategy per model, yes?

**Contributor:** Yes. These are just auxiliary functions that construct our example models by filling the vector of nodes with the nodes each model should have. One of the models has a single requantisation node and the other has two; in the latter case we have no need to mix requantisation strategies (which sounds a bit unlikely anyway). Still, it should be stressed that this is just code to build examples, not library functionality code, so the "lack of generality" shouldn't be a problem. Incidentally, the reason we added this argument was so that we could test how much the reference, single-round and floating-point-based implementations of requantisation each differ from TF Lite execution.
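The dispatch described above (one strategy value per model, matched once per requantization node) can be sketched as follows. This is a minimal stand-in, not the library's actual types: the real nodes carry scale and zero-point parameters that are omitted here.

```rust
// Hypothetical stand-ins mirroring the PR's shape: a single strategy enum is
// chosen per model, and each builder matches on it once per requantization node.

#[derive(Clone, Copy, Debug, PartialEq)]
enum BMMRequantizationStrategy {
    Floating,
    Reference,
    SingleRound,
}

// Simplified node variants; the real nodes also carry scales and zero-points.
#[derive(Debug, PartialEq)]
enum Node {
    RequantizeBMMFloat { output_dim: usize },
    RequantizeBMMRef { output_dim: usize },
    RequantizeBMMSingle { output_dim: usize },
}

// Mirrors the match in the model builders: one strategy value selects which
// requantization node variant is pushed into the model's node vector.
fn requantize_node(strategy: BMMRequantizationStrategy, output_dim: usize) -> Node {
    match strategy {
        BMMRequantizationStrategy::Floating => Node::RequantizeBMMFloat { output_dim },
        BMMRequantizationStrategy::Reference => Node::RequantizeBMMRef { output_dim },
        BMMRequantizationStrategy::SingleRound => Node::RequantizeBMMSingle { output_dim },
    }
}

fn main() {
    // A model built with the Reference strategy gets the reference node variant.
    let node = requantize_node(BMMRequantizationStrategy::Reference, 10);
    println!("{:?}", node);
}
```

Because the builders take the strategy by value and match on it once per node, both requantization nodes in the two-layer model necessarily use the same strategy, which is exactly the "one strategy per model" behavior discussed above.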
```diff
+) -> Model<i8>
 where
     F: PrimeField + Absorb,
     S: CryptographicSponge,
@@ -32,20 +40,25 @@ where
     let reshape: ReshapeNode = ReshapeNode::new(INPUT_DIMS.to_vec(), vec![flat_dim]);

-    let w_array: QArray<i8> = QArray::read(&format!(PATH!(), "weights.json"));
-    let b_array: QArray<i32> = QArray::read(&format!(PATH!(), "bias.json"));
+    let w_array: Tensor<i8> = Tensor::read(&format!(PATH!(), "weights.json"));
+    let b_array: Tensor<i32> = Tensor::read(&format!(PATH!(), "bias.json"));

-    let bmm: BMMNode<i8, i32> = BMMNode::new(w_array, b_array, Z_I);
+    let bmm: BMMNode<i8> = BMMNode::new(w_array, b_array, Z_I);

-    let req_bmm: RequantiseBMMNode<i8> =
-        RequantiseBMMNode::new(OUTPUT_DIM, S_I, Z_I, S_W, Z_W, S_O, Z_O);
+    let req_bmm = match req_strategy {
+        BMMRequantizationStrategy::Floating => Node::RequantizeBMMFloat(
+            RequantizeBMMFloatNode::new(OUTPUT_DIM, S_I, Z_I, S_W, Z_W, S_O, Z_O),
+        ),
+        BMMRequantizationStrategy::Reference => {
+            Node::RequantizeBMMRef(RequantizeBMMRefNode::new(OUTPUT_DIM, S_I, S_W, S_O, Z_O))
+        }
+        BMMRequantizationStrategy::SingleRound => {
+            Node::RequantizeBMMSingle(RequantizeBMMSingleNode::new(OUTPUT_DIM, S_I, S_W, S_O, Z_O))
+        }
+    };

     Model::new(
         INPUT_DIMS.to_vec(),
-        vec![
-            Node::Reshape(reshape),
-            Node::BMM(bmm),
-            Node::RequantiseBMM(req_bmm),
-        ],
+        vec![Node::Reshape(reshape), Node::BMM(bmm), req_bmm],
     )
 }
```
```diff
@@ -5,7 +5,13 @@ use ark_poly_commit::PolynomialCommitment;
 pub mod parameters;
 use parameters::*;

-use crate::{BMMNode, Model, Node, Poly, QArray, ReLUNode, RequantiseBMMNode, ReshapeNode};
+use crate::{
+    model::nodes::{
+        requantize_bmm_ref::RequantizeBMMRefNode, requantize_bmm_single::RequantizeBMMSingleNode,
+    },
+    quantization::BMMRequantizationStrategy,
+    BMMNode, Model, Node, Poly, ReLUNode, RequantizeBMMFloatNode, ReshapeNode, Tensor,
+};

 pub const INPUT_DIMS: &[usize] = &[28, 28];
 pub const INTER_DIM: usize = 28;
@@ -22,7 +28,9 @@ macro_rules! PATH {
     };
 }

-pub fn build_two_layer_perceptron_mnist<F, S, PCS>() -> Model<i8, i32>
+pub fn build_two_layer_perceptron_mnist<F, S, PCS>(
+    req_strategy: BMMRequantizationStrategy,
+) -> Model<i8>
 where
     F: PrimeField + Absorb,
     S: CryptographicSponge,
@@ -32,32 +40,50 @@ where
     let reshape: ReshapeNode = ReshapeNode::new(INPUT_DIMS.to_vec(), vec![flat_dim]);

-    let w1_array: QArray<i8> = QArray::read(&format!(PATH!(), "weights_1.json"));
-    let b1_array: QArray<i32> = QArray::read(&format!(PATH!(), "bias_1.json"));
-    let w2_array: QArray<i8> = QArray::read(&format!(PATH!(), "weights_2.json"));
-    let b2_array: QArray<i32> = QArray::read(&format!(PATH!(), "bias_2.json"));
+    let w1_array: Tensor<i8> = Tensor::read(&format!(PATH!(), "weights_1.json"));
+    let b1_array: Tensor<i32> = Tensor::read(&format!(PATH!(), "bias_1.json"));
+    let w2_array: Tensor<i8> = Tensor::read(&format!(PATH!(), "weights_2.json"));
+    let b2_array: Tensor<i32> = Tensor::read(&format!(PATH!(), "bias_2.json"));

-    let bmm_1: BMMNode<i8, i32> = BMMNode::new(w1_array, b1_array, Z_1_I);
+    let bmm_1: BMMNode<i8> = BMMNode::new(w1_array, b1_array, Z_1_I);

-    let req_bmm_1: RequantiseBMMNode<i8> =
-        RequantiseBMMNode::new(INTER_DIM, S_1_I, Z_1_I, S_1_W, Z_1_W, S_1_O, Z_1_O);
+    let req_bmm_1 = match req_strategy {
+        BMMRequantizationStrategy::Floating => Node::RequantizeBMMFloat(
+            RequantizeBMMFloatNode::new(INTER_DIM, S_1_I, Z_1_I, S_1_W, Z_1_W, S_1_O, Z_1_O),
+        ),
+        BMMRequantizationStrategy::Reference => Node::RequantizeBMMRef(RequantizeBMMRefNode::new(
+            INTER_DIM, S_1_I, S_1_W, S_1_O, Z_1_O,
+        )),
+        BMMRequantizationStrategy::SingleRound => Node::RequantizeBMMSingle(
+            RequantizeBMMSingleNode::new(INTER_DIM, S_1_I, S_1_W, S_1_O, Z_1_O),
+        ),
+    };

     let relu: ReLUNode<i8> = ReLUNode::new(28, Z_1_O);

-    let bmm_2: BMMNode<i8, i32> = BMMNode::new(w2_array, b2_array, Z_2_I);
+    let bmm_2: BMMNode<i8> = BMMNode::new(w2_array, b2_array, Z_2_I);

-    let req_bmm_2: RequantiseBMMNode<i8> =
-        RequantiseBMMNode::new(OUTPUT_DIM, S_2_I, Z_2_I, S_2_W, Z_2_W, S_2_O, Z_2_O);
+    let req_bmm_2 = match req_strategy {
+        BMMRequantizationStrategy::Floating => Node::RequantizeBMMFloat(
+            RequantizeBMMFloatNode::new(OUTPUT_DIM, S_2_I, Z_2_I, S_2_W, Z_2_W, S_2_O, Z_2_O),
+        ),
+        BMMRequantizationStrategy::Reference => Node::RequantizeBMMRef(RequantizeBMMRefNode::new(
+            OUTPUT_DIM, S_2_I, S_2_W, S_2_O, Z_2_O,
+        )),
+        BMMRequantizationStrategy::SingleRound => Node::RequantizeBMMSingle(
+            RequantizeBMMSingleNode::new(OUTPUT_DIM, S_2_I, S_2_W, S_2_O, Z_2_O),
+        ),
+    };

     Model::new(
         INPUT_DIMS.to_vec(),
         vec![
             Node::Reshape(reshape),
             Node::BMM(bmm_1),
-            Node::RequantiseBMM(req_bmm_1),
+            req_bmm_1,
             Node::ReLU(relu),
             Node::BMM(bmm_2),
-            Node::RequantiseBMM(req_bmm_2),
+            req_bmm_2,
         ],
     )
 }
```
**Contributor:** This revision is already in master, and I think we should stick to revs, no?

**Contributor:** Hey @mmagician, I think you're right. In any case, you can regard the `Cargo.toml` here as a temporary thing: the one in `main` was recently cleaned up, pinned to specific revs and made much more consistent (by @Cesar199999, I think). Things will be much cleaner when `main` is merged into this PR (which, if I recall correctly, @Cesar199999 will take care of?). In principle that will happen very soon, once the remaining tiny open discussions are concluded.

**Contributor:** Yes, on it!
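For context, pinning a git dependency to a specific rev in `Cargo.toml`, as the discussion above advocates, takes the following shape. The crate name, URL, and commit hash below are illustrative placeholders, not the values from this PR's manifest.

```toml
# Illustrative fragment: pinning a git dependency to an exact commit makes the
# build reproducible regardless of later pushes to the upstream branch.
# The crate, URL, and rev here are placeholders.
[dependencies]
some-crate = { git = "https://example.com/org/some-crate", rev = "0123abc" }
```

With `rev` set, `cargo` resolves the dependency to that exact commit, whereas a bare `branch` or `git` reference floats with upstream, which is the inconsistency the reviewer is flagging.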