diff --git a/README.md b/README.md index a0e15d9..948c48b 100644 --- a/README.md +++ b/README.md @@ -18,13 +18,13 @@ where `` is one of the following: In order to run any tests involving python code, such as compatibility tests with TF Lite, the feature `python` must be activated (which automatically enables `test-types`). -## From `ndarray` to `QArray` +## From `ndarray` to `Tensor` -In order to save a `numpy` `ndarray` (python side) as a serialised JSON which can be directly read into a `QArray` of ours (Rust side), +In order to save a `numpy` `ndarray` (python side) as a serialised JSON which can be directly read into a `Tensor` of ours (Rust side), - Convert the `ndarray` into an `OrderedDict` using our custom python function `tensor_to_dict` (available in several of the python notebooks) - Pass the resulting `OrderedDict` together with the destination path to `json.dump`. -The saved JSON file can be deserialised over in Rust with `QArray::read(path: &str) -> QArray`. If instead of a single `OrderedDict`, a python list of `OrderedDict`s is passed to `json.dump`, the resulting file can be deserialised with `QArray::read_list(path: &str) -> Vec `. +The saved JSON file can be deserialised over in Rust with `Tensor::read(path: &str) -> Tensor`. If instead of a single `OrderedDict`, a python list of `OrderedDict`s is passed to `json.dump`, the resulting file can be deserialised with `Tensor::read_list(path: &str) -> Vec `. Cf. `exploring_tf_lite/training_two_layer_perceptron.ipynb` for example usage. 
diff --git a/common/examples/common/lib.rs b/common/examples/common/lib.rs index 8945c28..4ee8a52 100644 --- a/common/examples/common/lib.rs +++ b/common/examples/common/lib.rs @@ -1,4 +1,4 @@ -use hcs_common::{quantise_f32_u8_nne, Model, Poly, QArray}; +use hcs_common::{quantise_f32_u8_nne, Model, Poly, Tensor}; use ark_crypto_primitives::sponge::{Absorb, CryptographicSponge}; use ark_ff::PrimeField; @@ -6,16 +6,16 @@ use ark_poly_commit::PolynomialCommitment; // Auxiliary function fn unpadded_inference( - raw_input: QArray, - model: &Model, + raw_input: Tensor, + model: &Model, qinfo: (f32, u8), -) -> QArray +) -> Tensor where F: PrimeField + Absorb, S: CryptographicSponge, PCS: PolynomialCommitment, S>, { - let quantised_input: QArray = QArray::new( + let quantised_input: Tensor = Tensor::new( quantise_f32_u8_nne(raw_input.values(), qinfo.0, qinfo.1), raw_input.shape().clone(), ); @@ -31,16 +31,16 @@ where // If padded inference is left on the prover side, move this to the prover /* // Auxiliary function fn padded_inference( - raw_input: QArray, - model: &Model, + raw_input: Tensor, + model: &Model, qinfo: (f32, u8), -) -> QArray +) -> Tensor where F: PrimeField + Absorb, S: CryptographicSponge, PCS: PolynomialCommitment, S>, { - let quantised_input: QArray = QArray::new( + let quantised_input: Tensor = Tensor::new( quantise_f32_u8_nne(raw_input.values(), qinfo.0, qinfo.1), raw_input.shape().clone(), ); @@ -48,7 +48,7 @@ where let input_i8 = (quantised_input.cast::() - 128).cast::(); let output_i8 = - as ProveModel>::padded_evaluate(model, input_i8); + as ProveModel>::padded_evaluate(model, input_i8); (output_i8.cast::() + 128).cast() } */ @@ -56,15 +56,15 @@ where pub fn run_unpadded( input_path: &str, expected_output_path: &str, - model: &Model, + model: &Model, qinfo: (f32, u8), ) where F: PrimeField + Absorb, S: CryptographicSponge, PCS: PolynomialCommitment, S>, { - let raw_input: QArray = QArray::read(input_path); - let expected_output: QArray = 
QArray::read(expected_output_path); + let raw_input: Tensor = Tensor::read(input_path); + let expected_output: Tensor = Tensor::read(expected_output_path); let output_u8 = unpadded_inference::(raw_input, model, qinfo); @@ -78,15 +78,15 @@ pub fn run_unpadded( /* pub fn run_padded( input_path: &str, expected_output_path: &str, - model: &Model, + model: &Model, qinfo: (f32, u8), ) where F: PrimeField + Absorb, S: CryptographicSponge, PCS: PolynomialCommitment, S>, { - let raw_input: QArray = QArray::read(input_path); - let expected_output: QArray = QArray::read(expected_output_path); + let raw_input: Tensor = Tensor::read(input_path); + let expected_output: Tensor = Tensor::read(expected_output_path); let output_u8 = padded_inference::(raw_input, model, qinfo); @@ -98,15 +98,15 @@ pub fn run_unpadded( pub fn multi_run_unpadded( inputs_path: &str, expected_outputs_path: &str, - model: &Model, + model: &Model, qinfo: (f32, u8), ) where F: PrimeField + Absorb, S: CryptographicSponge, PCS: PolynomialCommitment, S>, { - let raw_inputs: Vec> = QArray::read_list(inputs_path); - let expected_outputs: Vec> = QArray::read_list(expected_outputs_path); + let raw_inputs: Vec> = Tensor::read_list(inputs_path); + let expected_outputs: Vec> = Tensor::read_list(expected_outputs_path); for (raw_input, expected_output) in raw_inputs.into_iter().zip(expected_outputs.into_iter()) { assert_eq!( @@ -131,8 +131,8 @@ pub fn multi_run_padded( S: CryptographicSponge, PCS: PolynomialCommitment, S>, { - let raw_inputs: Vec> = QArray::read_list(inputs_path); - let expected_outputs: Vec> = QArray::read_list(expected_outputs_path); + let raw_inputs: Vec> = Tensor::read_list(inputs_path); + let expected_outputs: Vec> = Tensor::read_list(expected_outputs_path); for (raw_input, expected_output) in raw_inputs.into_iter().zip(expected_outputs.into_iter()) { assert_eq!( diff --git a/common/examples/simple_perceptron_mnist/main.rs b/common/examples/simple_perceptron_mnist/main.rs index 5b5806a..68ae93a 
100644 --- a/common/examples/simple_perceptron_mnist/main.rs +++ b/common/examples/simple_perceptron_mnist/main.rs @@ -1,6 +1,6 @@ use hcs_common::{ simple_perceptron_mnist::{build_simple_perceptron_mnist, parameters::*}, - Ligero, + BMMRequantizationStrategy, Ligero, }; use ark_bn254::Fr; @@ -20,7 +20,9 @@ macro_rules! PATH { } fn main() { - let simple_perceptron = build_simple_perceptron_mnist::, Ligero>(); + let simple_perceptron = build_simple_perceptron_mnist::, Ligero>( + BMMRequantizationStrategy::Floating, + ); // Right now this can't be QInfo because the latter is always a pair // (f32, i8), which indeed matches in-model quantisation, but not diff --git a/common/examples/two_layer_perceptron_mnist/main.rs b/common/examples/two_layer_perceptron_mnist/main.rs index 5340902..cbb6e80 100644 --- a/common/examples/two_layer_perceptron_mnist/main.rs +++ b/common/examples/two_layer_perceptron_mnist/main.rs @@ -1,6 +1,6 @@ use hcs_common::{ two_layer_perceptron_mnist::{build_two_layer_perceptron_mnist, parameters::*}, - Ligero, + BMMRequantizationStrategy, Ligero, }; use ark_bn254::Fr; @@ -20,8 +20,9 @@ macro_rules! 
PATH { } fn main() { - let two_layer_perceptron = - build_two_layer_perceptron_mnist::, Ligero>(); + let two_layer_perceptron = build_two_layer_perceptron_mnist::, Ligero>( + BMMRequantizationStrategy::Floating, + ); // Right now this can't be QInfo because the latter is always a pair // (f32, i8), which indeed matches in-model quantisation, but not diff --git a/common/src/compatibility/example_models/simple_perceptron_mnist/mod.rs b/common/src/compatibility/example_models/simple_perceptron_mnist/mod.rs index bafb77e..48d8ca6 100644 --- a/common/src/compatibility/example_models/simple_perceptron_mnist/mod.rs +++ b/common/src/compatibility/example_models/simple_perceptron_mnist/mod.rs @@ -1,4 +1,7 @@ -use crate::{BMMNode, Model, Node, Poly, QArray, RequantiseBMMNode, ReshapeNode}; +use crate::{ + quantization::BMMRequantizationStrategy, utils::req_bmm_from_strategy, BMMNode, Model, Node, + Poly, ReshapeNode, Tensor, +}; use ark_crypto_primitives::sponge::{Absorb, CryptographicSponge}; use ark_ff::PrimeField; @@ -22,7 +25,9 @@ macro_rules! 
PATH { } // TODO this is incorrect now that we have switched to logs -pub fn build_simple_perceptron_mnist() -> Model +pub fn build_simple_perceptron_mnist( + req_strategy: BMMRequantizationStrategy, +) -> Model where F: PrimeField + Absorb, S: CryptographicSponge, @@ -32,20 +37,15 @@ where let reshape: ReshapeNode = ReshapeNode::new(INPUT_DIMS.to_vec(), vec![flat_dim]); - let w_array: QArray = QArray::read(&format!(PATH!(), "weights.json")); - let b_array: QArray = QArray::read(&format!(PATH!(), "bias.json")); + let w_array: Tensor = Tensor::read(&format!(PATH!(), "weights.json")); + let b_array: Tensor = Tensor::read(&format!(PATH!(), "bias.json")); - let bmm: BMMNode = BMMNode::new(w_array, b_array, Z_I); + let bmm: BMMNode = BMMNode::new(w_array, b_array, Z_I); - let req_bmm: RequantiseBMMNode = - RequantiseBMMNode::new(OUTPUT_DIM, S_I, Z_I, S_W, Z_W, S_O, Z_O); + let req_bmm = req_bmm_from_strategy(req_strategy, OUTPUT_DIM, S_I, Z_I, S_W, Z_W, S_O, Z_O); Model::new( INPUT_DIMS.to_vec(), - vec![ - Node::Reshape(reshape), - Node::BMM(bmm), - Node::RequantiseBMM(req_bmm), - ], + vec![Node::Reshape(reshape), Node::BMM(bmm), req_bmm], ) } diff --git a/common/src/compatibility/example_models/two_layer_perceptron_mnist/mod.rs b/common/src/compatibility/example_models/two_layer_perceptron_mnist/mod.rs index 74a11c4..aa89b20 100644 --- a/common/src/compatibility/example_models/two_layer_perceptron_mnist/mod.rs +++ b/common/src/compatibility/example_models/two_layer_perceptron_mnist/mod.rs @@ -5,7 +5,10 @@ use ark_poly_commit::PolynomialCommitment; pub mod parameters; use parameters::*; -use crate::{BMMNode, Model, Node, Poly, QArray, ReLUNode, RequantiseBMMNode, ReshapeNode}; +use crate::{ + quantization::BMMRequantizationStrategy, utils::req_bmm_from_strategy, BMMNode, Model, Node, + Poly, ReLUNode, ReshapeNode, Tensor, +}; pub const INPUT_DIMS: &[usize] = &[28, 28]; pub const INTER_DIM: usize = 28; @@ -22,7 +25,9 @@ macro_rules! 
PATH { }; } -pub fn build_two_layer_perceptron_mnist() -> Model +pub fn build_two_layer_perceptron_mnist( + req_strategy: BMMRequantizationStrategy, +) -> Model where F: PrimeField + Absorb, S: CryptographicSponge, @@ -32,32 +37,48 @@ where let reshape: ReshapeNode = ReshapeNode::new(INPUT_DIMS.to_vec(), vec![flat_dim]); - let w1_array: QArray = QArray::read(&format!(PATH!(), "weights_1.json")); - let b1_array: QArray = QArray::read(&format!(PATH!(), "bias_1.json")); - let w2_array: QArray = QArray::read(&format!(PATH!(), "weights_2.json")); - let b2_array: QArray = QArray::read(&format!(PATH!(), "bias_2.json")); + let w1_array: Tensor = Tensor::read(&format!(PATH!(), "weights_1.json")); + let b1_array: Tensor = Tensor::read(&format!(PATH!(), "bias_1.json")); + let w2_array: Tensor = Tensor::read(&format!(PATH!(), "weights_2.json")); + let b2_array: Tensor = Tensor::read(&format!(PATH!(), "bias_2.json")); - let bmm_1: BMMNode = BMMNode::new(w1_array, b1_array, Z_1_I); + let bmm_1: BMMNode = BMMNode::new(w1_array, b1_array, Z_1_I); - let req_bmm_1: RequantiseBMMNode = - RequantiseBMMNode::new(INTER_DIM, S_1_I, Z_1_I, S_1_W, Z_1_W, S_1_O, Z_1_O); + let req_bmm_1 = req_bmm_from_strategy( + req_strategy, + INTER_DIM, + S_1_I, + Z_1_I, + S_1_W, + Z_1_W, + S_1_O, + Z_1_O, + ); let relu: ReLUNode = ReLUNode::new(28, Z_1_O); - let bmm_2: BMMNode = BMMNode::new(w2_array, b2_array, Z_2_I); + let bmm_2: BMMNode = BMMNode::new(w2_array, b2_array, Z_2_I); - let req_bmm_2: RequantiseBMMNode = - RequantiseBMMNode::new(OUTPUT_DIM, S_2_I, Z_2_I, S_2_W, Z_2_W, S_2_O, Z_2_O); + let req_bmm_2 = req_bmm_from_strategy( + req_strategy, + OUTPUT_DIM, + S_2_I, + Z_2_I, + S_2_W, + Z_2_W, + S_2_O, + Z_2_O, + ); Model::new( INPUT_DIMS.to_vec(), vec![ Node::Reshape(reshape), Node::BMM(bmm_1), - Node::RequantiseBMM(req_bmm_1), + req_bmm_1, Node::ReLU(relu), Node::BMM(bmm_2), - Node::RequantiseBMM(req_bmm_2), + req_bmm_2, ], ) } diff --git a/common/src/compatibility/python/mod.rs 
b/common/src/compatibility/python/mod.rs index c872d71..240a813 100644 --- a/common/src/compatibility/python/mod.rs +++ b/common/src/compatibility/python/mod.rs @@ -5,7 +5,7 @@ use std::{fs::create_dir_all, path::Path}; use pyo3::{prelude::*, PyAny}; -use crate::QArray; +use crate::Tensor; const PERCEPTRON_PATH: &str = include_str!(concat!( env!("CARGO_MANIFEST_DIR"), @@ -21,7 +21,7 @@ pub fn get_model(py: Python, model_name: &str, args: Option>) func.call1(py, (model_name, args)).unwrap() } -pub fn save_model_parameters_as_qarray(py: Python, model: &Py, path: &str) { +pub fn save_model_parameters_as_tensor(py: Python, model: &Py, path: &str) { let path = Path::new(path); if !path.exists() { @@ -30,13 +30,13 @@ pub fn save_model_parameters_as_qarray(py: Python, model: &Py, path: &str } model - .call_method1(py, "save_params_as_qarray", (path,)) + .call_method1(py, "save_params_as_verifiaml_tensor", (path,)) .unwrap(); } -pub fn get_model_input<'py, T>(python: Python<'py>, model: &Py, index: usize) -> QArray +pub fn get_model_input<'py, T>(python: Python<'py>, model: &Py, index: usize) -> Tensor where - T: Into> + FromPyObject<'py> + Clone, + T: Into> + FromPyObject<'py> + Clone, { let result = model.call_method1(python, "get_input", (index,)); @@ -46,10 +46,10 @@ where model_input.into() } -pub fn get_model_output(py: Python, model: &Py, index: usize) -> QArray { +pub fn get_model_output(py: Python, model: &Py, index: usize) -> Tensor { let result = model.call_method1(py, "get_output", (index,)); // Downcast the result to the expected type let model_output = result.unwrap().extract::>(py).unwrap(); - QArray::from(model_output) + Tensor::from(model_output) } diff --git a/common/src/compatibility/python/tests.rs b/common/src/compatibility/python/tests.rs index 7fc7322..defa7fe 100644 --- a/common/src/compatibility/python/tests.rs +++ b/common/src/compatibility/python/tests.rs @@ -17,7 +17,7 @@ use crate::{ }, }, }, - quantise_f32_u8_nne, Ligero, Model, QArray, + 
quantise_f32_u8_nne, BMMRequantizationStrategy, Ligero, Model, Tensor, }; use ark_bn254::Fr; use ark_crypto_primitives::sponge::poseidon::PoseidonSponge; @@ -30,12 +30,8 @@ const NB_OUTPUTS: usize = 10000; // within the allowed error. const ALLOWED_ERROR_MARGIN: f32 = 0.1; -fn unpadded_inference( - raw_input: QArray, - model: &Model, - qinfo: (f32, u8), -) -> QArray { - let quantised_input: QArray = QArray::new( +fn unpadded_inference(raw_input: Tensor, model: &Model, qinfo: (f32, u8)) -> Tensor { + let quantised_input: Tensor = Tensor::new( quantise_f32_u8_nne(raw_input.values(), qinfo.0, qinfo.1), raw_input.shape().clone(), ); @@ -47,9 +43,52 @@ fn unpadded_inference( (output_i8.cast::() + 128).cast() } +fn run_model_all_outputs(model_name: &str, req_strategy: BMMRequantizationStrategy) { + let correct_samples: usize = Python::with_gil(|py| { + let (s_input, z_input, rust_model) = match model_name { + "QSimplePerceptron" => ( + S_INPUT_SIMPLE_PERCEPTRON_MNIST, + Z_INPUT_SIMPLE_PERCEPTRON_MNIST, + build_simple_perceptron_mnist::, Ligero>(req_strategy), + ), + "QTwoLayerPerceptron" => ( + S_INPUT_TWO_LAYER_PERCEPTRON_MNIST, + Z_INPUT_TWO_LAYER_PERCEPTRON_MNIST, + build_two_layer_perceptron_mnist::, Ligero>( + req_strategy, + ), + ), + _ => panic!("Model not found"), + }; + let tf_lite_model = get_model(py, model_name, None); + (0..NB_OUTPUTS) + .into_iter() + .map(|i| { + let raw_input = get_model_input::>>(py, &tf_lite_model, i); + let expected_output = get_model_output(py, &tf_lite_model, i); + let output = unpadded_inference(raw_input, &rust_model, (s_input, z_input)); + (output == expected_output) as usize + }) + .sum() + }); + + println!( + "{} with requantization strategy {:?}, discrepancies: {} out of {}", + model_name, + req_strategy, + NB_OUTPUTS - correct_samples, + NB_OUTPUTS + ); + + assert_ge!( + correct_samples as f32 / NB_OUTPUTS as f32, + 1.0 - ALLOWED_ERROR_MARGIN + ); +} + #[test] fn test_get_model_input() { - let expected_input = 
QArray::read("examples/simple_perceptron_mnist/data/input_test_150.json"); + let expected_input = Tensor::read("examples/simple_perceptron_mnist/data/input_test_150.json"); assert_eq!( Python::with_gil(|py| get_model_input::>>( py, @@ -63,7 +102,7 @@ fn test_get_model_input() { #[test] fn test_simple_perceptron_mnist_single_output() { let expected_output = - QArray::read("examples/simple_perceptron_mnist/data/output_test_150.json"); + Tensor::read("examples/simple_perceptron_mnist/data/output_test_150.json"); assert_eq!( Python::with_gil(|py| get_model_output(py, &get_model(py, "QSimplePerceptron", None), 150)), expected_output @@ -73,7 +112,7 @@ fn test_simple_perceptron_mnist_single_output() { #[test] fn test_two_layer_perceptron_mnist_single_output() { let expected_output = - QArray::read("examples/two_layer_perceptron_mnist/data/output_test_150.json"); + Tensor::read("examples/two_layer_perceptron_mnist/data/output_test_150.json"); assert_eq!( Python::with_gil(|py| get_model_output( py, @@ -85,79 +124,34 @@ fn test_two_layer_perceptron_mnist_single_output() { } #[test] -fn test_simple_perceptron_mnist_all_outputs() { - let simple_perceptron_mnist = - build_simple_perceptron_mnist::, Ligero>(); - - let correct_samples: usize = Python::with_gil(|py| { - let tf_lite_model = get_model(py, "QSimplePerceptron", None); - (0..NB_OUTPUTS) - .into_iter() - .map(|i| { - let raw_input = get_model_input::>>(py, &tf_lite_model, i); - let expected_output = get_model_output(py, &tf_lite_model, i); - - let output = unpadded_inference( - raw_input, - &simple_perceptron_mnist, - ( - S_INPUT_SIMPLE_PERCEPTRON_MNIST, - Z_INPUT_SIMPLE_PERCEPTRON_MNIST, - ), - ); - - (output == expected_output) as usize - }) - .sum() - }); - - println!( - "Simple perceptron discrepancies: {} out of {}", - NB_OUTPUTS - correct_samples, - NB_OUTPUTS - ); - - assert_ge!( - correct_samples as f32 / NB_OUTPUTS as f32, - 1.0 - ALLOWED_ERROR_MARGIN - ); +fn test_simple_perceptron_req_float() { + 
run_model_all_outputs("QSimplePerceptron", BMMRequantizationStrategy::Floating); } #[test] -fn test_two_layer_perceptron_mnist_all_outputs() { - let two_layer_perceptron_mnist = - build_two_layer_perceptron_mnist::, Ligero>(); - - let correct_samples: usize = Python::with_gil(|py| { - let tf_lite_model = get_model(py, "QTwoLayerPerceptron", None); - (0..NB_OUTPUTS) - .into_iter() - .map(|i| { - let raw_input = get_model_input::>>(py, &tf_lite_model, i); - let expected_output = get_model_output(py, &tf_lite_model, i); +fn test_two_layer_perceptron_req_float() { + run_model_all_outputs("QTwoLayerPerceptron", BMMRequantizationStrategy::Floating); +} - let output = unpadded_inference( - raw_input, - &two_layer_perceptron_mnist, - ( - S_INPUT_TWO_LAYER_PERCEPTRON_MNIST, - Z_INPUT_TWO_LAYER_PERCEPTRON_MNIST, - ), - ); +#[test] +fn test_simple_perceptron_req_ref() { + run_model_all_outputs("QSimplePerceptron", BMMRequantizationStrategy::Reference); +} - (output == expected_output) as usize - }) - .sum() - }); +#[test] +fn test_two_layer_perceptron_req_ref() { + run_model_all_outputs("QTwoLayerPerceptron", BMMRequantizationStrategy::Reference); +} - println!( - "Two-layer perceptron discrepancies: {} out of {}", - NB_OUTPUTS - correct_samples, - NB_OUTPUTS - ); +#[test] +fn test_simple_perceptron_req_single() { + run_model_all_outputs("QSimplePerceptron", BMMRequantizationStrategy::SingleRound); +} - assert_ge!( - correct_samples as f32 / NB_OUTPUTS as f32, - 1.0 - ALLOWED_ERROR_MARGIN +#[test] +fn test_two_layer_perceptron_req_single() { + run_model_all_outputs( + "QTwoLayerPerceptron", + BMMRequantizationStrategy::SingleRound, ); } diff --git a/common/src/lib.rs b/common/src/lib.rs index a30e5cb..8afefa6 100644 --- a/common/src/lib.rs +++ b/common/src/lib.rs @@ -13,17 +13,19 @@ trait Proof {} pub use model::nodes::{ bmm::{BMMNode, BMMNodeCommitment, BMMNodeCommitmentState, BMMNodeProof}, relu::ReLUNode, - requantise_bmm::{ - RequantiseBMMNode, 
RequantiseBMMNodeCommitment, RequantiseBMMNodeCommitmentState, - RequantiseBMMNodeProof, + requantize_bmm_float::{ + RequantizeBMMFloatNode, RequantizeBMMNodeCommitment, RequantizeBMMNodeCommitmentState, + RequantizeBMMNodeProof, }, reshape::ReshapeNode, Node, NodeCommitment, NodeCommitmentState, NodeOpsPadded, NodeProof, }; -pub use model::qarray::{InnerType, QArray, QTypeArray}; +pub use model::tensor::{Integral, NIOTensor, SmallNIO, Tensor}; pub use model::{InferenceProof, Model}; pub use model::{LabeledPoly, Poly}; -pub use quantization::{quantise_f32_u8_nne, requantise_fc, BMMQInfo, QInfo, RoundingScheme}; +pub use quantization::{ + quantise_f32_u8_nne, requantize_fc, BMMQInfo, BMMRequantizationStrategy, QInfo, RoundingScheme, +}; #[cfg(feature = "test-types")] pub use utils::{pcs_types::Ligero, test_sponge::test_sponge}; diff --git a/common/src/model/mod.rs b/common/src/model/mod.rs index ba2dd3d..9cad5b8 100644 --- a/common/src/model/mod.rs +++ b/common/src/model/mod.rs @@ -1,3 +1,6 @@ +pub mod nodes; +pub mod tensor; + use ark_crypto_primitives::sponge::{Absorb, CryptographicSponge}; use ark_ff::PrimeField; use ark_poly::DenseMultilinearExtension; @@ -6,27 +9,24 @@ use ark_std::rand::RngCore; use crate::model::nodes::Node; -use self::qarray::InnerType; -use self::qarray::QTypeArray; -use self::{nodes::NodeProof, qarray::QArray}; - -pub mod nodes; -pub mod qarray; +use self::tensor::{NIOTensor, SmallNIO}; +use self::{nodes::NodeProof, tensor::Tensor}; pub type Poly = DenseMultilinearExtension; pub type LabeledPoly = LabeledPolynomial>; -pub struct InferenceProof +pub struct InferenceProof where F: PrimeField + Absorb, S: CryptographicSponge, PCS: PolynomialCommitment, S>, + ST: SmallNIO, { // Model input tensors in plain - pub inputs: Vec>, + pub inputs: Vec>, // Model output tensors in plain - pub outputs: Vec>, + pub outputs: Vec>, // Commitments to each of the node values pub node_value_commitments: Vec>, @@ -45,18 +45,14 @@ where // TODO: for now, we 
require all nodes to use the same PCS; this might change // in the future -pub struct Model { +pub struct Model { pub input_shape: Vec, pub output_shape: Vec, - pub nodes: Vec>, + pub nodes: Vec>, } -impl Model -where - ST: InnerType + TryFrom, - LT: InnerType + From, -{ - pub fn new(input_shape: Vec, nodes: Vec>) -> Self { +impl Model { + pub fn new(input_shape: Vec, nodes: Vec>) -> Self { // An empty model would cause panics later down the line e.g. when // determining the number of variables needed to commit to it. assert!(!nodes.is_empty(), "A model cannot have no nodes",); @@ -77,7 +73,7 @@ where rng: &mut R, ) -> Result<(PCS::CommitterKey, PCS::VerifierKey), PCS::Error> where - F: PrimeField + Absorb + From + From, + F: PrimeField + Absorb + From + From, S: CryptographicSponge, PCS: PolynomialCommitment, S>, R: RngCore, @@ -92,8 +88,8 @@ where PCS::trim(&pp, 0, 0, None) } - pub fn evaluate(&self, input: QArray) -> QArray { - let mut output = QTypeArray::S(input); + pub fn evaluate(&self, input: Tensor) -> Tensor { + let mut output = NIOTensor::S(input); for node in &self.nodes { output = node.evaluate(&output); diff --git a/common/src/model/nodes/bmm.rs b/common/src/model/nodes/bmm.rs index 47be337..bd24246 100644 --- a/common/src/model/nodes/bmm.rs +++ b/common/src/model/nodes/bmm.rs @@ -5,7 +5,7 @@ use ark_std::log2; use ark_sumcheck::ml_sumcheck::Proof; -use crate::model::qarray::{InnerType, QArray}; +use crate::model::tensor::{Integral, SmallNIO, Tensor}; use crate::model::Poly; use crate::{Commitment, CommitmentState}; @@ -14,15 +14,15 @@ use super::{NodeOpsNative, NodeOpsPadded}; // TODO convention: input, bias and output are rows, the op is vec-by-mat (in that order) /// Start with 2D matrices, and Mat-by-vector multiplication only -pub struct BMMNode { +pub struct BMMNode { /// The row-major flattened unpadded vector of weights - weights: QArray, + weights: Tensor, /// The padded weight vector - pub padded_weights: QArray, + pub padded_weights: 
Tensor, /// The unpadded vector of biases - bias: QArray, + bias: Tensor, /// The padded bias vector - pub padded_bias: QArray, + pub padded_bias: Tensor, /// Unpadded imensions (rows, columns) dims: (usize, usize), /// The logarithm of the padded dimensions (rows, columns) @@ -103,16 +103,15 @@ pub struct BMMNodeProof< pub bias_opening_value: F, } -impl NodeOpsNative for BMMNode +impl NodeOpsNative for BMMNode where - ST: InnerType, - LT: InnerType + From, + ST: SmallNIO, { fn shape(&self) -> Vec { vec![self.dims.1] } - fn evaluate(&self, input: &QArray) -> QArray { + fn evaluate(&self, input: &Tensor) -> Tensor { // Sanity checks // TODO systematise assert_eq!( @@ -128,33 +127,32 @@ where input.len() ); - let input: QArray = input.cast(); + let input: Tensor = input.cast(); // TODO this is a bigger question: can this overflow an i8? Supposedly the point of quantisation // is that input-by-weight products can be computed in i8. To be safe, let us use the large type here - let shifted_input = input - LT::from(self.input_zero_point); + let shifted_input = input - self.input_zero_point.into(); let mut accumulators = self.bias.values().clone(); - // TODO this can be made more elegant (efficient?) using addition of QArrays after defining suitable operators + // TODO this can be made more elegant (efficient?) using addition of Tensors after defining suitable operators // TODO since we have acumulators, this can be done more efficiently going row-wise to avoid re-caching the input for col in 0..self.dims.1 { // TODO does the compiler realise it doesn't need to access accumulators[col] on every iteration of the inner loop? 
ow change for row in 0..self.dims.0 { accumulators[col] += - shifted_input[row] * LT::from(self.weights[row * self.dims.1 + col]) + shifted_input[row] * self.weights[row * self.dims.1 + col].into(); } } - QArray::new(accumulators, vec![self.dims.1]) + Tensor::new(accumulators, vec![self.dims.1]) } } -impl NodeOpsPadded for BMMNode +impl NodeOpsPadded for BMMNode where - ST: InnerType + TryFrom, - LT: InnerType + From, + ST: SmallNIO, { fn padded_shape_log(&self) -> Vec { vec![self.padded_dims_log.1] @@ -166,9 +164,9 @@ where // This function naively computes entries which are known to be zero. It is // meant to exactly mirror the proof-system multiplication proved by the - // sumcheck argument. Requantisation and shifting are also applied to these + // sumcheck argument. Requantization and shifting are also applied to these // trivial entries, as the proof system does. - fn padded_evaluate(&self, input: &QArray) -> QArray { + fn padded_evaluate(&self, input: &Tensor) -> Tensor { let padded_dims = (1 << self.padded_dims_log.0, 1 << self.padded_dims_log.1); // Sanity checks @@ -182,40 +180,39 @@ where assert_eq!( padded_dims.0, input.len(), - "Length mismatch: Padded fully connected node expected input with {} elements, got {} elements instead", + "Length mismatch: Padded BMM node expected input with {} elements, got {} elements instead", padded_dims.0, input.len() ); - let input: QArray = input.cast(); + let input: Tensor = input.cast(); // TODO this is a bigger question: can this overflow an i8? Supposedly the point of quantisation // is that input-by-weight products can be computed in i8. To be safe, let us use the large type here - let shifted_input = input - LT::from(self.input_zero_point); + let shifted_input = input - self.input_zero_point.into(); let mut accumulators = self.padded_bias.values().clone(); - // TODO this can be made more elegant (efficient?) 
using addition of QArrays after defining suitable operators + // TODO this can be made more elegant (efficient?) using addition of Tensors after defining suitable operators // TODO since we have acumulators, this can be done more efficiently going row-wise to avoid re-caching the input for col in 0..padded_dims.1 { // TODO does the compiler realise it doesn't need to access accumulators[col] on every iteration of the inner loop? ow change for row in 0..padded_dims.0 { accumulators[col] += - shifted_input[row] * LT::from(self.padded_weights[row * padded_dims.1 + col]) + shifted_input[row] * self.padded_weights[row * padded_dims.1 + col].into(); } } - QArray::new(accumulators, vec![padded_dims.1]) + Tensor::new(accumulators, vec![padded_dims.1]) } } -impl BMMNode +impl BMMNode where - ST: InnerType, - LT: InnerType, + ST: SmallNIO, { - pub fn new(weights: QArray, bias: QArray, input_zero_point: ST) -> Self { + pub fn new(weights: Tensor, bias: Tensor, input_zero_point: ST) -> Self { let dims = (weights.shape()[0], weights.shape()[1]); assert_eq!( @@ -237,7 +234,7 @@ where let padded_bias = bias .clone() - .compact_resize(vec![dims.1.next_power_of_two()], LT::ZERO); + .compact_resize(vec![dims.1.next_power_of_two()], ST::LT::ZERO); Self { weights, diff --git a/common/src/model/nodes/mod.rs b/common/src/model/nodes/mod.rs index 6d8994b..e8978fd 100644 --- a/common/src/model/nodes/mod.rs +++ b/common/src/model/nodes/mod.rs @@ -1,3 +1,10 @@ +pub(crate) mod bmm; +pub(crate) mod relu; +pub(crate) mod requantize_bmm_float; +pub(crate) mod requantize_bmm_ref; +pub(crate) mod requantize_bmm_single; +pub(crate) mod reshape; + use ark_crypto_primitives::sponge::Absorb; use ark_ff::PrimeField; use ark_poly_commit::PolynomialCommitment; @@ -7,24 +14,18 @@ use crate::{ nodes::{bmm::BMMNode, relu::ReLUNode}, CryptographicSponge, Poly, }, - QArray, + Tensor, }; use self::{ bmm::{BMMNodeCommitment, BMMNodeCommitmentState, BMMNodeProof}, - requantise_bmm::{ - RequantiseBMMNode, 
RequantiseBMMNodeCommitment, RequantiseBMMNodeCommitmentState, - RequantiseBMMNodeProof, - }, + requantize_bmm_float::*, + requantize_bmm_ref::*, + requantize_bmm_single::*, reshape::ReshapeNode, }; -use super::qarray::{InnerType, QTypeArray}; - -pub(crate) mod bmm; -pub(crate) mod relu; -pub(crate) mod requantise_bmm; -pub(crate) mod reshape; +use super::tensor::{NIOTensor, SmallNIO}; // mod parser; @@ -47,7 +48,7 @@ pub trait NodeOpsNative { /// Evaluate the node natively (without padding) /// TODO decide whether this method should stay on `NodeOps`, or maybe go to `NodeOpsSNARKVerify` - fn evaluate(&self, input: &QArray) -> QArray; + fn evaluate(&self, input: &Tensor) -> Tensor; } pub trait NodeOpsPadded: NodeOpsNative { @@ -82,12 +83,14 @@ pub trait NodeOpsPadded: NodeOpsNative { fn com_num_vars(&self) -> usize; /// Evaluate the padded node natively - fn padded_evaluate(&self, input: &QArray) -> QArray; + fn padded_evaluate(&self, input: &Tensor) -> Tensor; } -pub enum Node { - BMM(BMMNode), - RequantiseBMM(RequantiseBMMNode), +pub enum Node { + BMM(BMMNode), + RequantizeBMMFloat(RequantizeBMMFloatNode), + RequantizeBMMRef(RequantizeBMMRefNode), + RequantizeBMMSingle(RequantizeBMMSingleNode), ReLU(ReLUNode), Reshape(ReshapeNode), } @@ -99,7 +102,9 @@ where PCS: PolynomialCommitment, S>, { BMM(BMMNodeProof), - RequantiseBMM(RequantiseBMMNodeProof), + RequantizeBMM(RequantizeBMMNodeProof), + RequantizeBMMRef(RequantizeBMMRefNodeProof), + RequantizeBMMSingle(RequantizeBMMSingleNodeProof), ReLU(()), Reshape(()), } @@ -111,7 +116,9 @@ where PCS: PolynomialCommitment, S>, { BMM(BMMNodeCommitment), - RequantiseBMM(RequantiseBMMNodeCommitment), + RequantizeBMM(RequantizeBMMNodeCommitment), + RequantizeBMMRef(RequantizeBMMRefNodeCommitment), + RequantizeBMMSingle(RequantizeBMMSingleNodeCommitment), ReLU(()), Reshape(()), } @@ -123,24 +130,27 @@ where PCS: PolynomialCommitment, S>, { BMM(BMMNodeCommitmentState), -
RequantizeBMM(RequantizeBMMNodeCommitmentState), + RequantizeBMMRef(RequantizeBMMRefNodeCommitmentState), + RequantizeBMMSingle(RequantizeBMMSingleNodeCommitmentState), ReLU(()), Reshape(()), } // A lot of this overlaps with the NodeOps trait and could be handled more // elegantly by simply implementing the trait -impl Node +impl Node where - I: InnerType + TryFrom, - O: InnerType + From, + ST: SmallNIO, { // Print the type of the node. This cannot be cleantly achieved by deriving // Debug pub fn type_name(&self) -> &'static str { match self { Node::BMM(_) => "BMM", - Node::RequantiseBMM(_r) => "RequantiseBMM", + Node::RequantizeBMMFloat(_r) => "RequantizeBMMFloat", + Node::RequantizeBMMRef(_r) => "RequantizeBMMRef", + Node::RequantizeBMMSingle(_r) => "RequantizeBMMSingle", Node::ReLU(_) => "ReLU", Node::Reshape(_) => "Reshape", } @@ -152,12 +162,14 @@ where } /// Evaluate the node natively (without padding) - pub fn evaluate(&self, input: &QTypeArray) -> QTypeArray { + pub fn evaluate(&self, input: &NIOTensor) -> NIOTensor { match (self, input) { - (Node::BMM(fc), QTypeArray::S(input)) => QTypeArray::L(fc.evaluate(input)), - (Node::RequantiseBMM(r), QTypeArray::L(input)) => QTypeArray::S(r.evaluate(input)), - (Node::ReLU(r), QTypeArray::S(input)) => QTypeArray::S(r.evaluate(input)), - (Node::Reshape(r), QTypeArray::S(input)) => QTypeArray::S(r.evaluate(input)), + (Node::BMM(fc), NIOTensor::S(input)) => NIOTensor::L(fc.evaluate(input)), + (Node::RequantizeBMMFloat(r), NIOTensor::L(input)) => NIOTensor::S(r.evaluate(input)), + (Node::RequantizeBMMRef(r), NIOTensor::L(input)) => NIOTensor::S(r.evaluate(input)), + (Node::RequantizeBMMSingle(r), NIOTensor::L(input)) => NIOTensor::S(r.evaluate(input)), + (Node::ReLU(r), NIOTensor::S(input)) => NIOTensor::S(r.evaluate(input)), + (Node::Reshape(r), NIOTensor::S(input)) => NIOTensor::S(r.evaluate(input)), _ => panic!( "Type mismatch: node of type {} received input of type {}", self.type_name(), @@ -173,14 +185,20 @@ where 
/// Here we perform matching without sanity checks. By design, the input type of the /// next node in the model is the same as the output type of the current node, /// so hiccups should never occur. - pub fn padded_evaluate(&self, input: &QTypeArray) -> QTypeArray { + pub fn padded_evaluate(&self, input: &NIOTensor) -> NIOTensor { match (self, input) { - (Node::BMM(fc), QTypeArray::S(input)) => QTypeArray::L(fc.padded_evaluate(input)), - (Node::RequantiseBMM(r), QTypeArray::L(input)) => { - QTypeArray::S(r.padded_evaluate(input)) + (Node::BMM(fc), NIOTensor::S(input)) => NIOTensor::L(fc.padded_evaluate(input)), + (Node::RequantizeBMMFloat(r), NIOTensor::L(input)) => { + NIOTensor::S(r.padded_evaluate(input)) + } + (Node::RequantizeBMMRef(r), NIOTensor::L(input)) => { + NIOTensor::S(r.padded_evaluate(input)) + } + (Node::RequantizeBMMSingle(r), NIOTensor::L(input)) => { + NIOTensor::S(r.padded_evaluate(input)) } - (Node::ReLU(r), QTypeArray::S(input)) => QTypeArray::S(r.padded_evaluate(input)), - (Node::Reshape(r), QTypeArray::S(input)) => QTypeArray::S(r.padded_evaluate(input)), + (Node::ReLU(r), NIOTensor::S(input)) => NIOTensor::S(r.padded_evaluate(input)), + (Node::Reshape(r), NIOTensor::S(input)) => NIOTensor::S(r.padded_evaluate(input)), _ => panic!("Invalid input type for node"), } } diff --git a/common/src/model/nodes/relu.rs b/common/src/model/nodes/relu.rs index 28a007b..a96181a 100644 --- a/common/src/model/nodes/relu.rs +++ b/common/src/model/nodes/relu.rs @@ -1,6 +1,6 @@ use ark_std::log2; -use crate::{model::qarray::InnerType, QArray}; +use crate::{model::tensor::SmallNIO, Tensor}; use super::{NodeOpsNative, NodeOpsPadded}; @@ -13,13 +13,13 @@ pub struct ReLUNode { impl NodeOpsNative for ReLUNode where - ST: InnerType, + ST: SmallNIO, { fn shape(&self) -> Vec { vec![self.num_units] } - fn evaluate(&self, input: &QArray) -> QArray { + fn evaluate(&self, input: &Tensor) -> Tensor { // TODO sanity checks (cf. 
BMM); systematise input.maximum(self.zero_point) } @@ -28,7 +28,7 @@ where // impl NodeOpsSnark impl NodeOpsPadded for ReLUNode where - ST: InnerType, + ST: SmallNIO, { fn padded_shape_log(&self) -> Vec { vec![self.log_num_units] @@ -40,7 +40,7 @@ where // TODO this is the same as evaluate() for now; the two will likely differ // if/when we introduce input size checks - fn padded_evaluate(&self, input: &QArray) -> QArray { + fn padded_evaluate(&self, input: &Tensor) -> Tensor { // TODO sanity checks (cf. BMM); systematise input.maximum(self.zero_point) } diff --git a/common/src/model/nodes/requantise_bmm.rs b/common/src/model/nodes/requantize_bmm_float.rs similarity index 66% rename from common/src/model/nodes/requantise_bmm.rs rename to common/src/model/nodes/requantize_bmm_float.rs index cd7a5b3..c0da5d5 100644 --- a/common/src/model/nodes/requantise_bmm.rs +++ b/common/src/model/nodes/requantize_bmm_float.rs @@ -1,15 +1,15 @@ use ark_std::log2; -use crate::model::qarray::{InnerType, QArray}; -use crate::quantization::{requantise_fc, BMMQInfo, QInfo, QScaleType, RoundingScheme}; +use crate::model::tensor::{SmallNIO, Tensor}; +use crate::quantization::{requantize_fc, BMMQInfo, QInfo, QScaleType, RoundingScheme}; use crate::{Commitment, CommitmentState}; use super::{NodeOpsNative, NodeOpsPadded}; // TODO convention: input, bias and output are rows, the op is vec-by-mat (in that order) -/// Apply requantisation after a BMM argument -pub struct RequantiseBMMNode { +/// Apply requantization after a BMM argument +pub struct RequantizeBMMFloatNode { // Number of units size: usize, @@ -20,44 +20,43 @@ pub struct RequantiseBMMNode { pub q_info: BMMQInfo, } -pub struct RequantiseBMMNodeCommitment(); +pub struct RequantizeBMMNodeCommitment(); -impl Commitment for RequantiseBMMNodeCommitment {} +impl Commitment for RequantizeBMMNodeCommitment {} -pub struct RequantiseBMMNodeCommitmentState(); +pub struct RequantizeBMMNodeCommitmentState(); -impl CommitmentState for 
RequantiseBMMNodeCommitmentState {} +impl CommitmentState for RequantizeBMMNodeCommitmentState {} -pub struct RequantiseBMMNodeProof { +pub struct RequantizeBMMNodeProof { // this will be the sumcheck proof } -impl NodeOpsNative for RequantiseBMMNode +impl NodeOpsNative for RequantizeBMMFloatNode where - ST: InnerType + TryFrom, - LT: InnerType + From, + ST: SmallNIO, { fn shape(&self) -> Vec { vec![self.size] } - fn evaluate(&self, input: &QArray) -> QArray { + fn evaluate(&self, input: &Tensor) -> Tensor { // Sanity checks // TODO systematise assert_eq!( input.num_dims(), 1, - "Incorrect shape: RequantiseBMM node expects a 1-dimensional input array" + "Incorrect shape: RequantizeBMM node expects a 1-dimensional input array" ); assert_eq!( self.size, input.len(), - "Length mismatch: RequantiseBMM node expects input with {} elements, got {} elements instead", + "Length mismatch: RequantizeBMM node expects input with {} elements, got {} elements instead", self.size, input.len() ); - let output: QArray = requantise_fc( + let output: Tensor = requantize_fc::( input.values(), &self.q_info, RoundingScheme::NearestTiesEven, @@ -68,10 +67,9 @@ where } } -impl NodeOpsPadded for RequantiseBMMNode +impl NodeOpsPadded for RequantizeBMMFloatNode where - ST: InnerType + TryFrom, - LT: InnerType + From, + ST: SmallNIO, { fn padded_shape_log(&self) -> Vec { vec![self.padded_size_log] @@ -81,7 +79,7 @@ where self.padded_size_log } - fn padded_evaluate(&self, input: &QArray) -> QArray { + fn padded_evaluate(&self, input: &Tensor) -> Tensor { let padded_size = 1 << self.padded_size_log; // Sanity checks @@ -89,18 +87,18 @@ where assert_eq!( input.num_dims(), 1, - "Incorrect shape: RequantiseBMM node expects a 1-dimensional input array" + "Incorrect shape: RequantizeBMMFloat node expects a 1-dimensional input array" ); assert_eq!( padded_size, input.len(), - "Length mismatch: Padded fully connected node expected input with {} elements, got {} elements instead", + "Length mismatch: 
Padded RequantizeBMMFloat node expected input with {} elements, got {} elements instead", padded_size, input.len() ); - let output: QArray = requantise_fc::( + let output: Tensor = requantize_fc::( input.values(), &self.q_info, RoundingScheme::NearestTiesEven, @@ -110,7 +108,7 @@ where } } -impl RequantiseBMMNode { +impl RequantizeBMMFloatNode { pub fn new( size: usize, s_i: QScaleType, diff --git a/common/src/model/nodes/requantize_bmm_ref.rs b/common/src/model/nodes/requantize_bmm_ref.rs new file mode 100644 index 0000000..5f9ac0c --- /dev/null +++ b/common/src/model/nodes/requantize_bmm_ref.rs @@ -0,0 +1,140 @@ +use ark_std::log2; + +use crate::model::tensor::{SmallNIO, Tensor}; +use crate::quantization::{quantize_multiplier, requantize_ref}; +use crate::{Commitment, CommitmentState}; + +use super::{NodeOpsNative, NodeOpsPadded}; + +// TODO convention: input, bias and output are rows, the op is vec-by-mat (in that order) + +/// Apply requantization after a BMM argument +pub struct RequantizeBMMRefNode { + // Number of units + size: usize, + + // log2 of the number of units + pub padded_size_log: usize, + + // Represents a non-negative right shift reduced by the implicit + // fixed-point-arithmetic shift in DoublingHighRoundMultiply + effective_shift: usize, + + // + effective_multiplier: ST::LT, + + // + output_zero_point: ST, +} + +pub struct RequantizeBMMRefNodeCommitment(); + +impl Commitment for RequantizeBMMRefNodeCommitment {} + +pub struct RequantizeBMMRefNodeCommitmentState(); + +impl CommitmentState for RequantizeBMMRefNodeCommitmentState {} + +pub struct RequantizeBMMRefNodeProof { + // this will be the sumcheck proof +} + +impl NodeOpsNative for RequantizeBMMRefNode +where + ST: SmallNIO, +{ + fn shape(&self) -> Vec { + vec![self.size] + } + + fn evaluate(&self, input: &Tensor) -> Tensor { + // Sanity checks + // TODO systematise + assert_eq!( + input.num_dims(), + 1, + "Incorrect shape: RequantizeBMMRef node expects a 1-dimensional input array" + ); 
+ assert_eq!( + self.size, + input.len(), + "Length mismatch: RequantizeBMMRef node expects input with {} elements, got {} elements instead", + self.size, + input.len() + ); + + let output: Tensor = requantize_ref::( + input.values(), + self.effective_multiplier, + self.effective_shift, + self.output_zero_point, + ) + .into(); + + output + } +} + +impl NodeOpsPadded for RequantizeBMMRefNode +where + ST: SmallNIO, +{ + fn padded_shape_log(&self) -> Vec { + vec![self.padded_size_log] + } + + fn com_num_vars(&self) -> usize { + self.padded_size_log + } + + fn padded_evaluate(&self, input: &Tensor) -> Tensor { + let padded_size = 1 << self.padded_size_log; + + // Sanity checks + // TODO systematise + assert_eq!( + input.num_dims(), + 1, + "Incorrect shape: RequantizeBMMRef node expects a 1-dimensional input array" + ); + + assert_eq!( + padded_size, + input.len(), + "Length mismatch: Padded RequantizeBMMRef node expected input with {} elements, got {} elements instead", + padded_size, + input.len() + ); + + let output: Tensor = requantize_ref::( + input.values(), + self.effective_multiplier, + self.effective_shift, + self.output_zero_point, + ) + .into(); + output + } +} + +impl RequantizeBMMRefNode { + pub fn new(size: usize, s_i: f32, s_w: f32, s_o: f32, z_o: i8) -> Self { + let padded_size_log = log2(size.next_power_of_two()) as usize; + + // cast scales to a type with higher precision + let (s_i, s_w, s_o) = (s_i as f64, s_w as f64, s_o as f64); + let double_multiplier = s_i * s_w / s_o; + + // compute full shift and effective multiplier + let (effective_multiplier, effective_shift) = quantize_multiplier(double_multiplier); + + Self { + size, + padded_size_log, + effective_shift, + effective_multiplier, + output_zero_point: z_o, + } + } +} +// TODO in constructor, add quantisation information checks? (e.g. z_weight = 0, etc.) 
diff --git a/common/src/model/nodes/requantize_bmm_single.rs b/common/src/model/nodes/requantize_bmm_single.rs new file mode 100644 index 0000000..101d783 --- /dev/null +++ b/common/src/model/nodes/requantize_bmm_single.rs @@ -0,0 +1,139 @@ +use ark_std::log2; + +use crate::model::tensor::{SmallNIO, Tensor}; +use crate::quantization::{quantize_multiplier, requantize_single_round}; +use crate::{Commitment, CommitmentState}; + +use super::{NodeOpsNative, NodeOpsPadded}; + +// TODO convention: input, bias and output are rows, the op is vec-by-mat (in that order) + +/// Apply requantization after a BMM argument +pub struct RequantizeBMMSingleNode { + // Number of units + size: usize, + + // log2 of the number of units + pub padded_size_log: usize, + + // Represents a non-negative effective right shift + full_shift: usize, + + // + effective_multiplier: ST::LT, + + // + output_zero_point: ST, +} + +pub struct RequantizeBMMSingleNodeCommitment(); + +impl Commitment for RequantizeBMMSingleNodeCommitment {} + +pub struct RequantizeBMMSingleNodeCommitmentState(); + +impl CommitmentState for RequantizeBMMSingleNodeCommitmentState {} + +pub struct RequantizeBMMSingleNodeProof { + // this will be the sumcheck proof +} + +impl NodeOpsNative for RequantizeBMMSingleNode +where + ST: SmallNIO, +{ + fn shape(&self) -> Vec { + vec![self.size] + } + + fn evaluate(&self, input: &Tensor) -> Tensor { + // Sanity checks + // TODO systematise + assert_eq!( + input.num_dims(), + 1, + "Incorrect shape: RequantizeBMMSingle node expects a 1-dimensional input array" + ); + assert_eq!( + self.size, + input.len(), + "Length mismatch: RequantizeBMMSingle node expects input with {} elements, got {} elements instead", + self.size, + input.len() + ); + + let output: Tensor = requantize_single_round::( + input.values(), + self.effective_multiplier, + self.full_shift, + self.output_zero_point, + ) + .into(); + + output + } +} + +impl NodeOpsPadded for RequantizeBMMSingleNode +where + ST: SmallNIO, +{ 
+ fn padded_shape_log(&self) -> Vec { + vec![self.padded_size_log] + } + + fn com_num_vars(&self) -> usize { + self.padded_size_log + } + + fn padded_evaluate(&self, input: &Tensor) -> Tensor { + let padded_size = 1 << self.padded_size_log; + + // Sanity checks + // TODO systematise + assert_eq!( + input.num_dims(), + 1, + "Incorrect shape: RequantizeBMMSingle node expects a 1-dimensional input array" + ); + + assert_eq!( + padded_size, + input.len(), + "Length mismatch: Padded RequantizeBMMSingle node expected input with {} elements, got {} elements instead", + padded_size, + input.len() + ); + + let output: Tensor = requantize_single_round::( + input.values(), + self.effective_multiplier, + self.full_shift, + self.output_zero_point, + ) + .into(); + output + } +} + +impl RequantizeBMMSingleNode { + pub fn new(size: usize, s_i: f32, s_w: f32, s_o: f32, z_o: i8) -> Self { + let padded_size_log = log2(size.next_power_of_two()) as usize; + + // cast scales to a type with higher precision + let (s_i, s_w, s_o) = (s_i as f64, s_w as f64, s_o as f64); + let double_multiplier = s_i * s_w / s_o; + + // compute full shift and effective multiplier + let (effective_multiplier, effective_shift) = quantize_multiplier(double_multiplier); + + Self { + size, + padded_size_log, + full_shift: effective_shift + (i32::BITS - 1) as usize, + effective_multiplier, + output_zero_point: z_o, + } + } +} +// TODO in constructor, add quantisation information checks? (e.g. z_weight = 0, etc.) 
diff --git a/common/src/model/nodes/reshape.rs b/common/src/model/nodes/reshape.rs index b532135..691867d 100644 --- a/common/src/model/nodes/reshape.rs +++ b/common/src/model/nodes/reshape.rs @@ -1,6 +1,6 @@ use ark_std::log2; -use crate::{model::qarray::InnerType, QArray}; +use crate::{model::tensor::SmallNIO, Tensor}; use super::{NodeOpsNative, NodeOpsPadded}; @@ -13,13 +13,13 @@ pub struct ReshapeNode { impl NodeOpsNative for ReshapeNode where - ST: InnerType, + ST: SmallNIO, { fn shape(&self) -> Vec { self.output_shape.clone() } - fn evaluate(&self, input: &QArray) -> QArray { + fn evaluate(&self, input: &Tensor) -> Tensor { // Sanity checks // TODO systematise @@ -37,7 +37,7 @@ where impl NodeOpsPadded for ReshapeNode where - ST: InnerType, + ST: SmallNIO, { fn padded_shape_log(&self) -> Vec { self.padded_output_shape_log.clone() @@ -49,7 +49,7 @@ where // TODO I think this might be broken due to the failure of commutativity // between product and and nearest-geq-power-of-two - fn padded_evaluate(&self, input: &QArray) -> QArray { + fn padded_evaluate(&self, input: &Tensor) -> Tensor { let padded_input_shape: Vec = self .padded_input_shape_log .iter() diff --git a/common/src/model/qarray/mod.rs b/common/src/model/tensor/mod.rs similarity index 71% rename from common/src/model/qarray/mod.rs rename to common/src/model/tensor/mod.rs index e49d50d..a141306 100644 --- a/common/src/model/qarray/mod.rs +++ b/common/src/model/tensor/mod.rs @@ -1,11 +1,11 @@ -use std::ops::{AddAssign, DivAssign, MulAssign, SubAssign}; - -use ark_std::ops::Index; +use std::ops::{AddAssign, BitAnd, DivAssign, MulAssign, Shl, Shr, SubAssign}; use ark_std::any::type_name; use ark_std::cmp::PartialOrd; use ark_std::fmt; use ark_std::fmt::Debug; +use ark_std::mem; +use ark_std::ops::Index; use ark_std::ops::{Add, Div, Mul, Sub}; use ark_std::vec; use ark_std::vec::Vec; @@ -18,10 +18,12 @@ use crate::quantization::QScaleType; #[cfg(test)] mod tests; -const QARRAY_NESTED_TAB: &str = " "; 
+const TENSOR_NESTED_TAB: &str = " "; -pub trait InnerType: +pub trait Integral: Copy + + Serialize + + DeserializeOwned + Debug + PartialEq + PartialOrd @@ -33,75 +35,72 @@ pub trait InnerType: + SubAssign + MulAssign + DivAssign - + Serialize - + DeserializeOwned + + Shl + + Shr + + BitAnd + + Into { + // We can't simply require Double: Integral, as that would create an + // infinite chain + type Double: Copy + + Debug + + TryInto + + Mul + + Div + + Add + + Sub + + Shl + + Shr; + const ZERO: Self; + const ONE: Self; + const ONE_DOUBLE: Self::Double; const MIN: Self; const MAX: Self; + const BITS: usize; - // TODO if we decide to make the model generic on the quantisation process - // types, this will change + // TODO this should be removed once floating requantisation is made generic fn from_qscaletype(x: QScaleType) -> Self; fn to_qscaletype(&self) -> QScaleType; } -impl InnerType for i8 { - const ZERO: Self = 0; - const MIN: Self = Self::MIN; - const MAX: Self = Self::MAX; - - fn from_qscaletype(x: QScaleType) -> Self { - x as Self - } - - fn to_qscaletype(&self) -> QScaleType { - *self as QScaleType - } +#[macro_export] +macro_rules! 
impl_integral { + ( $t1:ty, $t2:ty ) => { + impl Integral for $t1 { + type Double = $t2; + + const ZERO: Self = 0; + const ONE: Self = 1; + const ONE_DOUBLE: Self::Double = 1; + const MIN: Self = Self::MIN; + const MAX: Self = Self::MAX; + const BITS: usize = 8 * mem::size_of::(); + + fn from_qscaletype(x: QScaleType) -> Self { + x as Self + } + + fn to_qscaletype(&self) -> QScaleType { + *self as QScaleType + } + } + }; } -impl InnerType for i32 { - const ZERO: Self = 0; - const MIN: Self = Self::MIN; - const MAX: Self = Self::MAX; - fn from_qscaletype(x: QScaleType) -> Self { - x as Self - } +impl_integral!(i8, i16); +impl_integral!(i32, i64); - fn to_qscaletype(&self) -> QScaleType { - *self as QScaleType - } +pub trait SmallNIO: Integral + Into { + type LT: Integral + TryInto; } -impl InnerType for u8 { - const ZERO: Self = 0; - const MIN: Self = Self::MIN; - const MAX: Self = Self::MAX; - - fn from_qscaletype(x: QScaleType) -> Self { - x as Self - } - fn to_qscaletype(&self) -> QScaleType { - *self as QScaleType - } -} - -impl InnerType for f32 { - const ZERO: Self = 0.0; - const MIN: Self = Self::MIN; - const MAX: Self = Self::MAX; - - fn from_qscaletype(x: QScaleType) -> Self { - x as Self - } - - fn to_qscaletype(&self) -> QScaleType { - *self as QScaleType - } +impl SmallNIO for i8 { + type LT = i32; } #[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq)] -pub struct QArray { +pub struct Tensor { #[serde(rename = "f")] flattened: Vec, #[serde(rename = "s")] @@ -111,13 +110,13 @@ pub struct QArray { } #[derive(Clone)] -pub enum QTypeArray { - S(QArray), - L(QArray), +pub enum NIOTensor { + S(Tensor), + L(Tensor), } -// impl indexing into the QArray -impl Index for QArray { +// indexing syntax tensor[idx] for Tensor +impl Index for Tensor { type Output = T; fn index(&self, index: usize) -> &Self::Output { @@ -125,7 +124,7 @@ impl Index for QArray { } } -impl QArray { +impl Tensor { pub fn check_dimensions(&self) -> bool { self.flattened.len() == 
self.shape.iter().product::() } @@ -152,24 +151,7 @@ impl QArray { self.flattened } - // TODO in the future, if necessary, we can remove the bound - // >::Error: Debug - // and replace unwrap() by unwrap_or(), possibly panicking or propagating - // the error - pub fn cast(&self) -> QArray - where - T: TryInto, - >::Error: Debug, - { - let flattened = self - .flattened - .iter() - .map(|x| TryInto::::try_into(*x).unwrap()) - .collect(); - QArray::new(flattened, self.shape.clone()) - } - - // Reshapes the QArray in-place + // Reshapes the Tensor in-place pub fn reshape(&mut self, new_shape: Vec) { assert_eq!( self.len(), @@ -240,16 +222,46 @@ impl QArray { .map(|(i, d)| i * d) .sum() } +} - #[allow(dead_code)] - pub(crate) fn get(&self, index: Vec) -> T { - self.flattened[self.flatten_index(index)] +/********************** Serialization **********************/ +impl Tensor { + pub fn write(&self, path: &str) { + let mut writer = std::fs::File::create(path).unwrap(); + serde_json::to_writer(&mut writer, self).unwrap(); + } + + pub fn read(path: &str) -> Tensor { + let reader = std::fs::File::open(path).unwrap(); + serde_json::from_reader(reader).unwrap() } - /// For each dimension of self.shape, either pad the QArray with `value` + pub fn write_multiple(tensors: &[&Tensor], paths: &[&str]) { + for (tensor, path) in tensors.iter().zip(paths.iter()) { + tensor.write(path); + } + } + + pub fn read_multiple(paths: &[&str]) -> Vec> { + paths.iter().map(|path| Tensor::read(path)).collect() + } + + pub fn write_list(tensors: &[&Tensor], path: &str) { + let mut writer = std::fs::File::create(path).unwrap(); + serde_json::to_writer(&mut writer, tensors).unwrap(); + } + + pub fn read_list(path: &str) -> Vec> { + let reader = std::fs::File::open(path).unwrap(); + serde_json::from_reader(reader).unwrap() + } +} + +impl Tensor { + /// For each dimension of self.shape, either pad the Tensor with `value` /// (if the new size is larger than the original one) or truncate it (if /// 
the new size is smaller than or equal to the original one). - pub fn compact_resize(&self, new_shape: Vec, value: T) -> QArray { + pub fn compact_resize(&self, new_shape: Vec, value: T) -> Tensor { let old_shape = &self.shape; assert_eq!( @@ -260,7 +272,7 @@ impl QArray { old_shape.len(), ); - // compute cumulative dimensions of the qarray + // compute cumulative dimensions of the tensor let mut new_cumulative_dimensions = Vec::with_capacity(new_shape.len()); let mut acc = 1; @@ -281,37 +293,25 @@ impl QArray { value, ); - QArray::new(flattened, new_shape) - } - - pub fn write(&self, path: &str) { - let mut writer = std::fs::File::create(path).unwrap(); - serde_json::to_writer(&mut writer, self).unwrap(); - } - - pub fn read(path: &str) -> QArray { - let reader = std::fs::File::open(path).unwrap(); - serde_json::from_reader(reader).unwrap() + Tensor::new(flattened, new_shape) } - pub fn write_multiple(qarrays: &[&QArray], paths: &[&str]) { - for (qarray, path) in qarrays.iter().zip(paths.iter()) { - qarray.write(path); - } - } - - pub fn read_multiple(paths: &[&str]) -> Vec> { - paths.iter().map(|path| QArray::read(path)).collect() - } - - pub fn write_list(qarrays: &[&QArray], path: &str) { - let mut writer = std::fs::File::create(path).unwrap(); - serde_json::to_writer(&mut writer, qarrays).unwrap(); + #[allow(dead_code)] + pub(crate) fn get(&self, index: Vec) -> T { + self.flattened[self.flatten_index(index)] } - pub fn read_list(path: &str) -> Vec> { - let reader = std::fs::File::open(path).unwrap(); - serde_json::from_reader(reader).unwrap() + pub fn cast(&self) -> Tensor + where + T: TryInto, + >::Error: Debug, + { + let flattened = self + .flattened + .iter() + .map(|x| TryInto::::try_into(*x).unwrap()) + .collect(); + Tensor::new(flattened, self.shape.clone()) } } @@ -380,74 +380,72 @@ fn compact_resize_internal( } /************************ Operators ************************/ - -// Since numerical type control is essential, we implement only QArray + T -// 
insead of the more general QArray + S for any S which can be added to T, +// Since numerical type control is essential, we implement only Tensor + T +// insead of the more general Tensor + S for any S which can be added to T, // thus forcing the programmer to make intentional casts. The same applies to // other operator implementations below. -impl Add for QArray +impl Add for Tensor where T: Add, { - type Output = QArray; + type Output = Tensor; - fn add(self, rhs: T) -> QArray { + fn add(self, rhs: T) -> Tensor { let flattened = self.flattened.into_iter().map(|x| x + rhs).collect(); - QArray::new(flattened, self.shape) + Tensor::new(flattened, self.shape) } } // Addition in the other direction cannot be implemented in the same way, cf. // https://stackoverflow.com/questions/70220168/how-to-implement-mul-trait-for-a-custom-struct-type-to-work-in-both-ways // There is a workaround, but it is not necessary for now -// impl ops::Add> for T where T: ops::Add +// impl ops::Add> for T where T: ops::Add -impl Sub for QArray +impl Sub for Tensor where T: Sub, { - type Output = QArray; + type Output = Tensor; - fn sub(self, rhs: T) -> QArray { + fn sub(self, rhs: T) -> Tensor { let flattened = self.flattened.into_iter().map(|x| x - rhs).collect(); - QArray::new(flattened, self.shape) + Tensor::new(flattened, self.shape) } } -impl Mul for QArray +impl Mul for Tensor where T: Mul, { - type Output = QArray; + type Output = Tensor; - fn mul(self, rhs: T) -> QArray { + fn mul(self, rhs: T) -> Tensor { let flattened = self.flattened.into_iter().map(|x| x * rhs).collect(); - QArray::new(flattened, self.shape) + Tensor::new(flattened, self.shape) } } -impl Div for QArray +impl Div for Tensor where T: Div, { - type Output = QArray; + type Output = Tensor; - fn div(self, rhs: T) -> QArray { + fn div(self, rhs: T) -> Tensor { let flattened = self.flattened.into_iter().map(|x| x / rhs).collect(); - QArray::new(flattened, self.shape) + Tensor::new(flattened, self.shape) } } 
/******************* Conversion from Vec *******************/ - -impl From> for QArray { +impl From> for Tensor { fn from(values: Vec) -> Self { let l = values.len(); - QArray::new(values, vec![l]) + Tensor::new(values, vec![l]) } } -impl From>> for QArray { +impl From>> for Tensor { fn from(values: Vec>) -> Self { assert!( values.iter().all(|x| x.len() == values[0].len()), @@ -457,17 +455,16 @@ impl From>> for QArray { let shape = vec![values.len(), values[0].len()]; let flattened = values.into_iter().flatten().collect(); - QArray::new(flattened, shape) + Tensor::new(flattened, shape) } } /************************* Display *************************/ - -impl fmt::Display for QArray { +impl fmt::Display for Tensor { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!( f, - "QArray ({}). Shape: {:?}. Data:", + "Tensor ({}). Shape: {:?}. Data:", type_name::(), self.shape )?; @@ -488,7 +485,7 @@ impl fmt::Display for QArray { } } -fn print_flat_data( +fn print_flat_data( f: &mut fmt::Formatter, data: &[T], cumulative_dimensions: &[usize], @@ -504,13 +501,13 @@ fn print_flat_data( return writeln!( f, "{}{:?}", - QARRAY_NESTED_TAB.repeat(original_len - 1), + TENSOR_NESTED_TAB.repeat(original_len - 1), data ); } if len != original_len { - writeln!(f, "{}[", QARRAY_NESTED_TAB.repeat(original_len - len))?; + writeln!(f, "{}[", TENSOR_NESTED_TAB.repeat(original_len - len))?; } let subarrays = data.chunks_exact(cumulative_dimensions[0]); @@ -526,44 +523,43 @@ fn print_flat_data( } if len != original_len { - writeln!(f, "{}]", QARRAY_NESTED_TAB.repeat(original_len - len))?; + writeln!(f, "{}]", TENSOR_NESTED_TAB.repeat(original_len - len))?; } Ok(()) } /*********************** Comparisons ***********************/ - // We follow the convention (e.g. in numpy) that `maximum` and `minimum` // compare an array to a single element (element-wise); whereas `max` and `min` // (not implemented) compare two equally sized arrays element-wise. 
-impl QArray { - pub fn maximum(&self, x: T) -> QArray { +impl Tensor { + pub fn maximum(&self, x: T) -> Tensor { let flattened_max: Vec = self .flattened .iter() .map(|y| if *y >= x { *y } else { x }) .collect(); - // Construct the new QArray directly to avoid recomputation of + // Construct the new Tensor directly to avoid recomputation of // cumulative dimensions - QArray { + Tensor { flattened: flattened_max, shape: self.shape.clone(), cumulative_dimensions: self.cumulative_dimensions.clone(), } } - pub fn minimum(&self, x: T) -> QArray { + pub fn minimum(&self, x: T) -> Tensor { let flattened_min: Vec = self .flattened .iter() .map(|y| if *y <= x { *y } else { x }) .collect(); - // Construct the new QArray directly to avoid recomputation of + // Construct the new Tensor directly to avoid recomputation of // cumulative dimensions - QArray { + Tensor { flattened: flattened_min, shape: self.shape.clone(), cumulative_dimensions: self.cumulative_dimensions.clone(), @@ -571,37 +567,36 @@ impl QArray { } } -/************************ QTypeArray ***********************/ - -impl QTypeArray { +/************************ NIOTensor ***********************/ +impl NIOTensor { #[inline] - pub fn unwrap_small(self) -> QArray { + pub fn unwrap_small(self) -> Tensor { match self { - QTypeArray::S(s) => s, + NIOTensor::S(s) => s, _ => panic!("Expected S variant"), } } #[inline] - pub fn unwrap_large(self) -> QArray { + pub fn unwrap_large(self) -> Tensor { match self { - QTypeArray::L(l) => l, + NIOTensor::L(l) => l, _ => panic!("Expected L variant"), } } #[inline] - pub fn ref_small(&self) -> &QArray { + pub fn ref_small(&self) -> &Tensor { match self { - QTypeArray::S(s) => s, + NIOTensor::S(s) => s, _ => panic!("Expected S variant"), } } #[inline] - pub fn ref_large(&self) -> &QArray { + pub fn ref_large(&self) -> &Tensor { match self { - QTypeArray::L(l) => l, + NIOTensor::L(l) => l, _ => panic!("Expected L variant"), } } @@ -609,8 +604,8 @@ impl QTypeArray { #[inline] pub fn 
variant_name(&self) -> &'static str { match self { - QTypeArray::S(_) => "QTypeArray::S", - QTypeArray::L(_) => "QTypeArray::L", + NIOTensor::S(_) => "NIOTensor::S", + NIOTensor::L(_) => "NIOTensor::L", } } } diff --git a/common/src/model/qarray/tests.rs b/common/src/model/tensor/tests.rs similarity index 73% rename from common/src/model/qarray/tests.rs rename to common/src/model/tensor/tests.rs index ac17a5b..ef97a9f 100644 --- a/common/src/model/qarray/tests.rs +++ b/common/src/model/tensor/tests.rs @@ -2,7 +2,7 @@ use super::*; #[test] fn test_flatten_index_trivial() { - let q = QArray::from(vec![1, 2, 3, 4, 5, 6, 7, 8, 9]); + let q = Tensor::from(vec![1, 2, 3, 4, 5, 6, 7, 8, 9]); for i in 0..9 { assert_eq!(q.flatten_index(vec![i]), i); } @@ -12,7 +12,7 @@ fn test_flatten_index_trivial() { fn test_flatten_index_2d() { let shape = vec![3, 3]; let flattened: Vec = (1..=9).collect(); - let q = QArray::new(flattened, shape); + let q = Tensor::new(flattened, shape); assert_eq!(q.flatten_index(vec![0, 0]), 0); assert_eq!(q.flatten_index(vec![0, 1]), 1); assert_eq!(q.flatten_index(vec![0, 2]), 2); @@ -26,7 +26,7 @@ fn test_flatten_index_2d() { fn test_flatten_index_3d() { let shape = vec![2, 3, 4]; let flattened: Vec = (1..=24).collect(); - let q = QArray::new(flattened, shape); + let q = Tensor::new(flattened, shape); assert_eq!(q.flatten_index(vec![0, 0, 0]), 0); assert_eq!(q.flatten_index(vec![0, 0, 1]), 1); assert_eq!(q.flatten_index(vec![0, 0, 2]), 2); @@ -39,9 +39,9 @@ fn test_flatten_index_3d() { fn test_resize_1d_pad() { let shape = vec![5]; let flattened: Vec = (1..=5).collect(); - let qarray = QArray::new(flattened, shape); + let tensor = Tensor::new(flattened, shape); - let padded = qarray.compact_resize(vec![7], 0); + let padded = tensor.compact_resize(vec![7], 0); assert_eq!(padded.shape, vec![7]); assert_eq!(padded.flattened, vec![1, 2, 3, 4, 5, 0, 0]); @@ -51,9 +51,9 @@ fn test_resize_1d_pad() { fn test_resize_2d_pad() { let shape = vec![2, 3]; let 
flattened: Vec = (1..=6).collect(); - let qarray = QArray::new(flattened, shape); + let tensor = Tensor::new(flattened, shape); - let padded = qarray.compact_resize(vec![5, 4], 0); + let padded = tensor.compact_resize(vec![5, 4], 0); let expected = vec![1, 2, 3, 0, 4, 5, 6, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]; @@ -65,9 +65,9 @@ fn test_resize_2d_pad() { fn test_resize_3d_pad() { let shape = vec![2, 3, 4]; let flattened: Vec = (1..=24).collect(); - let qarray = QArray::new(flattened, shape); + let tensor = Tensor::new(flattened, shape); - let padded = qarray.compact_resize(vec![3, 5, 7], 0); + let padded = tensor.compact_resize(vec![3, 5, 7], 0); let expected = vec![ 1, 2, 3, 4, 0, 0, 0, 5, 6, 7, 8, 0, 0, 0, 9, 10, 11, 12, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, @@ -84,11 +84,11 @@ fn test_resize_3d_pad() { fn test_resize_1d_truncate() { let shape = vec![7]; let original = vec![1, 2, 3, 4, 5, 0, 0]; - let qarray = QArray::new(original, shape); + let tensor = Tensor::new(original, shape); let expected: Vec = (1..=5).collect(); - let padded = qarray.compact_resize(vec![5], 0); + let padded = tensor.compact_resize(vec![5], 0); assert_eq!(padded.shape, vec![5]); assert_eq!(padded.flattened, expected); @@ -98,9 +98,9 @@ fn test_resize_1d_truncate() { fn test_resize_2d_truncate() { let shape = vec![5, 4]; let original = vec![1, 2, 3, 0, 4, 5, 6, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]; - let qarray = QArray::new(original, shape); + let tensor = Tensor::new(original, shape); - let padded = qarray.compact_resize(vec![2, 3], 0); + let padded = tensor.compact_resize(vec![2, 3], 0); let expected: Vec = (1..=6).collect(); @@ -117,9 +117,9 @@ fn test_resize_3d_truncate() { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ]; - let qarray = QArray::new(original, shape); + let tensor = Tensor::new(original, shape); - let padded = qarray.compact_resize(vec![2, 3, 4], 0); + let padded = 
tensor.compact_resize(vec![2, 3, 4], 0); let expected: Vec = (1..=24).collect(); @@ -131,9 +131,9 @@ fn test_resize_3d_truncate() { fn test_resize_3d_mixed() { let shape = vec![2, 2, 3]; let flattened: Vec = (1..=12).collect(); - let qarray = QArray::new(flattened, shape); + let tensor = Tensor::new(flattened, shape); - let padded = qarray.compact_resize(vec![3, 1, 5], 0); + let padded = tensor.compact_resize(vec![3, 1, 5], 0); let expected = vec![1, 2, 3, 0, 0, 7, 8, 9, 0, 0, 0, 0, 0, 0, 0]; @@ -143,48 +143,48 @@ fn test_resize_3d_mixed() { #[test] fn test_print_1d() { - println!("{}", QArray::from(vec![1, 2, 3, 4, 5])); + println!("{}", Tensor::from(vec![1, 2, 3, 4, 5])); } #[test] fn test_print_2d_1() { - println!("{}", QArray::new((1..=6).collect(), vec![2, 3])); + println!("{}", Tensor::new((1..=6).collect(), vec![2, 3])); } #[test] fn test_print_2d_2() { - println!("{}", QArray::new((1..=6).collect(), vec![3, 2])); + println!("{}", Tensor::new((1..=6).collect(), vec![3, 2])); } #[test] fn test_print_3d_1() { - println!("{}", QArray::new((1..=24).collect(), vec![2, 3, 4])); + println!("{}", Tensor::new((1..=24).collect(), vec![2, 3, 4])); } #[test] fn test_print_3d_2() { - println!("{}", QArray::new((1..=24).collect(), vec![3, 2, 4])); + println!("{}", Tensor::new((1..=24).collect(), vec![3, 2, 4])); } #[test] fn test_print_4d() { - println!("{}", QArray::new((1..=24).collect(), vec![2, 3, 2, 2])); + println!("{}", Tensor::new((1..=24).collect(), vec![2, 3, 2, 2])); } #[test] fn test_maximum() { - let qarray: QArray = QArray::new(vec![-1, 2, 3, 4, -5, 6], vec![2, 3]); + let tensor: Tensor = Tensor::new(vec![-1, 2, 3, 4, -5, 6], vec![2, 3]); // TODO the call to move_values will change once other branches are merged // Do it here and elsewhere - assert_eq!(qarray.maximum(3).move_values(), vec![3, 3, 3, 4, 3, 6]); + assert_eq!(tensor.maximum(3).move_values(), vec![3, 3, 3, 4, 3, 6]); } #[test] fn test_minimum() { - let qarray: QArray = QArray::new(vec![-1, 2, 3, 
4, -5, 6], vec![2, 3]); + let tensor: Tensor = Tensor::new(vec![-1, 2, 3, 4, -5, 6], vec![2, 3]); // TODO the call to move_values will change once other branches are merged // Do it here and elsewhere - assert_eq!(qarray.minimum(3).move_values(), vec![-1, 2, 3, 3, -5, 3]); + assert_eq!(tensor.minimum(3).move_values(), vec![-1, 2, 3, 3, -5, 3]); } diff --git a/common/src/quantization.rs b/common/src/quantization.rs deleted file mode 100644 index 2b9e3a7..0000000 --- a/common/src/quantization.rs +++ /dev/null @@ -1,189 +0,0 @@ -use crate::model::qarray::InnerType; - -// TODO if we decide to make the model generic on the quantisation process -// types (which is probably correct now that the qtypes are generics), these -// will go away -// Type for quantisation scales -pub(crate) type QScaleType = f32; -// Larger precision type to compute the requantisation scale in some schemes -pub(crate) type QScaleComputationType = f64; - -pub struct QInfo { - pub scale: QScaleType, - pub zero_point: ST, -} - -// TODO: this will probably change to inference-ready requantisation info -// Even what is being done now could be optimised by precomputing outside the -// evaluate function -pub struct BMMQInfo { - pub input_info: QInfo, - pub weight_info: QInfo, - // Bias requantisation information is not used (and is indeed directly - // computable from the two above) - pub output_info: QInfo, -} - -pub enum RoundingScheme { - NearestTiesAwayFromZero, - NearestTiesEven, -} - -pub fn requantise_fc( - output: &[LT], - q_info: &BMMQInfo, - scheme: RoundingScheme, -) -> Vec -where - ST: InnerType + TryFrom, - LT: InnerType + From, -{ - match scheme { - RoundingScheme::NearestTiesAwayFromZero => requantise_fc_ntafz::(output, q_info), - RoundingScheme::NearestTiesEven => requantise_fc_nte::(output, q_info), - } -} - -fn requantise_fc_ntafz(output: &[LT], q_info: &BMMQInfo) -> Vec -where - ST: InnerType + TryFrom, - LT: InnerType + From, -{ - // 1. 
Computing scale - // TODO In actual schemes, this will be decomposed as (int, shift) - let (s_i, s_w, s_o) = ( - q_info.input_info.scale, - q_info.weight_info.scale, - q_info.output_info.scale, - ); - let (s_i, s_w, s_o) = ( - s_i as QScaleComputationType, - s_w as QScaleComputationType, - s_o as QScaleComputationType, - ); - let s = (s_i * s_w / s_o) as QScaleType; - - // 2. Requantise - // TODO add rayon for parallelisation? - output - .iter() - .map(|x| { - let x = LT::to_qscaletype(x) * s; - let mut x = LT::from_qscaletype(x.round()); - x += LT::from(q_info.output_info.zero_point); - ST::try_from(partial_ord_clamp(x, LT::from(ST::MIN), LT::from(ST::MAX))) - .map_err(|_| "Unable to convert Large Type to Small Type") - .unwrap() - }) - .collect() -} - -// The (unstable) method clamp comes from the trait Ord, which we cannot -// restrict InnerType to as we need f32 to implement the latter. Note that this -// method is not meaningfully defined for classes that genuinely do not -// implement Ord (total order relation) but only PartialOrd (partial order -// relation). -fn partial_ord_clamp(x: T, min: T, max: T) -> T { - if x <= min { - min - } else if x >= max { - max - } else { - x - } -} - -fn requantise_fc_nte(output: &[LT], q_info: &BMMQInfo) -> Vec -where - ST: InnerType + TryFrom, - LT: InnerType + From, -{ - // 1. Computing scale - // TODO In actual schemes, this will be decomposed as (int, shift) - let (s_i, s_w, s_o) = ( - q_info.input_info.scale, - q_info.weight_info.scale, - q_info.output_info.scale, - ); - let (s_i, s_w, s_o) = ( - s_i as QScaleComputationType, - s_w as QScaleComputationType, - s_o as QScaleComputationType, - ); - let s = (s_i * s_w / s_o) as QScaleType; - - // 2. Requantise - // TODO add rayon for parallelisation? - output - .iter() - .map(|x| { - let x = LT::to_qscaletype(x) * s; - let mut x = LT::from_qscaletype(x.round_ties_even()); // TODO which type to pick here? Should we check for overflows? 
- x += LT::from(q_info.output_info.zero_point); - ST::try_from(partial_ord_clamp(x, LT::from(ST::MIN), LT::from(ST::MAX))) - .map_err(|_| "Unable to convert Large Type to Small Type") - .unwrap() - }) - .collect() -} - -// This function is used to quantise model model inputs and its types are fixed -pub fn quantise_f32_u8_nne(values: &[f32], scale: QScaleType, zero: u8) -> Vec { - values - .iter() - .map(|x| { - ((((*x as QScaleType) / scale) + (zero as f32)).round_ties_even() as i32) - .clamp(u8::MIN as i32, u8::MAX as i32) as u8 - }) - .collect() -} - -#[cfg(test)] -mod tests { - - use super::*; - #[test] - fn test_nnafz_noop() { - let output = vec![0, 1, 2, 3, 4, 5, 6, 7]; - let q_info = BMMQInfo { - input_info: QInfo { - scale: 1.0, - zero_point: 0, - }, - weight_info: QInfo { - scale: 1.0, - zero_point: 0, - }, - output_info: QInfo { - scale: 1.0, - zero_point: 0, - }, - }; - let expected = vec![0, 1, 2, 3, 4, 5, 6, 7]; - let actual = requantise_fc(&output, &q_info, RoundingScheme::NearestTiesAwayFromZero); - assert_eq!(expected, actual); - } - - #[test] - fn test_nnafz_halves() { - // test when the output lands at .5 intervals - let output = vec![-3, -2, -1, 0, 1, 2, 3]; - let q_info = BMMQInfo { - input_info: QInfo { - scale: 0.5, - zero_point: 0, - }, - weight_info: QInfo { - scale: 1.0, - zero_point: 0, - }, - output_info: QInfo { - scale: 1.0, - zero_point: 0, - }, - }; - let expected = vec![-2, -1, -1, 0, 1, 1, 2]; - let actual = requantise_fc(&output, &q_info, RoundingScheme::NearestTiesAwayFromZero); - assert_eq!(expected, actual); - } -} diff --git a/common/src/quantization/mod.rs b/common/src/quantization/mod.rs new file mode 100644 index 0000000..19842ed --- /dev/null +++ b/common/src/quantization/mod.rs @@ -0,0 +1,363 @@ +#[cfg(test)] +pub mod tests; + +use ark_std::Zero; + +use crate::model::tensor::Integral; + +const F64_EXPONENT_SHIFT: u64 = 52; +const F64_EXPONENT_BIAS: i32 = 1023; +const F64_EXPONENT_MASK: u64 = 0x7ff0000000000000; +const 
F64_FRACTION_MASK: u64 = 0x000fffffffffffff; + +// TODO if we decide to make the model generic on the quantisation process +// types (which is probably correct now that the qtypes are generics), these +// will go away +// Type for quantisation scales +pub(crate) type QScaleType = f32; +// Larger precision type to compute the requantization scale in some schemes +pub(crate) type QScaleComputationType = f64; + +pub struct QInfo { + pub scale: QScaleType, + pub zero_point: ST, +} + +// TODO: this will probably change to inference-ready requantization info +// Even what is being done now could be optimised by precomputing outside the +// evaluate function +pub struct BMMQInfo { + pub input_info: QInfo, + pub weight_info: QInfo, + // Bias requantization information is not used (and is indeed directly + // computable from the two above) + pub output_info: QInfo, +} + +// Strategies to requantize the output of a BMM node +#[derive(Debug, Clone, Copy)] +pub enum BMMRequantizationStrategy { + Floating, // Core: multiply the input by the floating-point scale + Reference, // Core: fixed-point-multiply the input by the quantised + // scale, then right-shift further (and round) + SingleRound, // Core: integer-multiply the input by the quantised scale, + // then apply a single right shift (and round) +} + +pub enum RoundingScheme { + NearestTiesAwayFromZero, + NearestTiesEven, +} + +pub fn requantize_fc( + output: &[LT], + q_info: &BMMQInfo, + scheme: RoundingScheme, +) -> Vec +where + ST: Integral + Into, + LT: Integral + TryInto, +{ + match scheme { + RoundingScheme::NearestTiesAwayFromZero => requantize_fc_ntafz::(output, q_info), + RoundingScheme::NearestTiesEven => requantize_fc_nte::(output, q_info), + } +} + +fn requantize_fc_ntafz(output: &[LT], q_info: &BMMQInfo) -> Vec +where + ST: Integral + Into, + LT: Integral + TryInto, +{ + // 1. 
Computing scale + // TODO In actual schemes, this will be decomposed as (int, shift) + let (s_i, s_w, s_o) = ( + q_info.input_info.scale, + q_info.weight_info.scale, + q_info.output_info.scale, + ); + let (s_i, s_w, s_o) = ( + s_i as QScaleComputationType, + s_w as QScaleComputationType, + s_o as QScaleComputationType, + ); + let s = (s_i * s_w / s_o) as QScaleType; + + // 2. Requantize + // TODO add rayon for parallelisation? + output + .iter() + .map(|x| { + let x = LT::to_qscaletype(x) * s; + let mut x = LT::from_qscaletype(x.round()); + x += q_info.output_info.zero_point.into(); + partial_ord_clamp(x, ST::MIN.into(), ST::MAX.into()) + .try_into() + .map_err(|_| "Unable to convert Large Type to Small Type") + .unwrap() + }) + .collect() +} + +// The (unstable) method clamp comes from the trait Ord, which we cannot +// restrict Integral to as we need f32 to implement the latter. Note that this +// method is not meaningfully defined for classes that genuinely do not +// implement Ord (total order relation) but only PartialOrd (partial order +// relation). +fn partial_ord_clamp(x: T, min: T, max: T) -> T { + if x <= min { + min + } else if x >= max { + max + } else { + x + } +} + +fn requantize_fc_nte(output: &[LT], q_info: &BMMQInfo) -> Vec +where + ST: Integral + Into, + LT: Integral + TryInto, +{ + // 1. Computing scale + // TODO In actual schemes, this will be decomposed as (int, shift) + let (s_i, s_w, s_o) = ( + q_info.input_info.scale, + q_info.weight_info.scale, + q_info.output_info.scale, + ); + let (s_i, s_w, s_o) = ( + s_i as QScaleComputationType, + s_w as QScaleComputationType, + s_o as QScaleComputationType, + ); + let s = (s_i * s_w / s_o) as QScaleType; + + // 2. Requantize + // TODO add rayon for parallelisation? + output + .iter() + .map(|x| { + let x = LT::to_qscaletype(x) * s; + let mut x = LT::from_qscaletype(x.round_ties_even()); // TODO which type to pick here? Should we check for overflows? 
+ x += q_info.output_info.zero_point.into(); + partial_ord_clamp(x, ST::MIN.into(), ST::MAX.into()) + .try_into() + .map_err(|_| "Unable to convert Large Type to Small Type") + .unwrap() + }) + .collect() +} + +// Implementation of TF Lite's reference requantization. +pub fn requantize_ref( + // TODO Think whether we can afford to pass ownership here and change the iter() below by into_iter() + output: &[LT], + effective_multiplier: LT, + effective_shift: usize, + output_zero_point: ST, +) -> Vec +where + ST: Integral + Into, + LT: Integral + TryInto, +{ + // Computing auxiliary constants used for every input + let effective_multiplier: LT::Double = effective_multiplier.into(); + let output_zero_point: LT = output_zero_point.into(); + + // TODO: Add associated constant MAX_PLUS_ONE to Integral. + let pow2_bits_minus_one = LT::ONE_DOUBLE << (LT::BITS - 1); + + // NOTE: Notice that they are independent of the input. Perhaps it is meaningful to turn: + // xt_pow2_bits_minus_one, non_neg_nudge, and neg_nudge + // into associated constants of type LT in order to avoid their recomputation per call? + + // Mask consists of full_shift ones + let mask = (LT::ONE << effective_shift) - LT::ONE; // TODO: may overflow for some exponents + let mask_div2 = mask >> 1; + + // Constants used during nudging + let non_neg_nudge = LT::ONE_DOUBLE << (LT::BITS - 2); + let neg_nudge = LT::ONE_DOUBLE - non_neg_nudge; + + // Requantize + // TODO add rayon for parallelization? 
+ output + .iter() + .map(|x| { + let (is_negative, nudge) = if *x >= LT::ZERO { + (LT::ZERO, non_neg_nudge) + } else { + (LT::ONE, neg_nudge) + }; + + let product = (*x).into() * effective_multiplier; + + let product_high: LT = ((product + nudge) / pow2_bits_minus_one) + .try_into() + .map_err(|_| "Error trying to convert LT::Double to LT") + .unwrap(); + + // assert(full_shift <= 31); + let remainder = product_high & mask; + let threshold = mask_div2 + is_negative; + + let core = (product_high >> effective_shift) + + if remainder > threshold { + LT::ONE + } else { + LT::ZERO + }; + + let shifted = core + output_zero_point; + + partial_ord_clamp(shifted, ST::MIN.into(), ST::MAX.into()) + .try_into() + .map_err(|_| "Unable to convert Large Type to Small Type") + .unwrap() + }) + .collect() +} + +// Implementation of single-rounding requantisation with quantised parameters +pub fn requantize_single_round( + // TODO Think whether we can afford to pass ownership here and change the iter() below by into_iter() + output: &[LT], + effective_multiplier: LT, + full_shift: usize, + output_zero_point: ST, +) -> Vec +where + ST: Integral + Into, + LT: Integral + TryInto, +{ + // Although these parameters could be directly saved into the node (with + // their final desired types), this computation/conversion is essentially + // free. For the first two, storing them in the node with their actual types + // (instead of the types needed for inference) makes the node more + // transparent as far as definitions and proof system goes. TF Lite does the + // same. For the other two, computing them here makes the function more + // usable. + let effective_multiplier: LT::Double = effective_multiplier.into(); + let output_zero_point: LT = output_zero_point.into(); + let non_neg_nudge = LT::ONE_DOUBLE << (full_shift - 1); + let neg_nudge = non_neg_nudge - LT::ONE_DOUBLE; + + // Requantize + // TODO add rayon for parallelization? 
+ output + .iter() + .map(|x| { + let nudge = if *x >= LT::ZERO { + non_neg_nudge + } else { + neg_nudge + }; + + let core: LT = (((*x).into() * effective_multiplier + nudge) >> full_shift) + .try_into() + .map_err(|_| "Error trying to convert LT::Double to LT") + .unwrap(); + + let shifted = core + output_zero_point; + + partial_ord_clamp(shifted, ST::MIN.into(), ST::MAX.into()) + .try_into() + .map_err(|_| "Unable to convert Large Type to Small Type") + .unwrap() + }) + .collect() +} + +// https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/kernels/internal/quantization_util.cc#L53-L104 +pub(crate) fn quantize_multiplier(double_multiplier: f64) -> (i32, usize) { + if double_multiplier.is_zero() { + return (0, 0); + } + + let (q, expon) = frexp(double_multiplier); + + assert!( + expon <= 0, + "expon should be non-positive. Got: {} instead.", + expon + ); + + // Negate expon to obtain the number of right-shift bits + let mut shift = -expon as usize; + + // TF Lite uses C++'s round function under the hood as can be seen here: + // https://github.com/tensorflow/tensorflow/blob/46f028f94dcd974705cd14e8abf05b9bd8f20bf0/tensorflow/lite/kernels/internal/cppmath.h#L35 + // The function rounds to the nearest integer, breaking ties away from zero. + // The same strategy is implemented in Rust's round method: + // https://doc.rust-lang.org/std/primitive.f64.html#method.round + // See also: https://en.cppreference.com/w/c/numeric/fenv/FE_round + let mut q_fixed = (q * ((1_i64 << (i32::BITS - 1)) as f64)).round() as i64; + + // TFLITE_CHECK(q_fixed <= (1LL << 31)); + assert!( + q_fixed <= 1_i64 << (i32::BITS - 1), + "q_fixed must not exceed 2^{}. Got: {} instead.", + i32::BITS - 1, + q_fixed + ); + + if q_fixed == 1_i64 << (i32::BITS - 1) { + q_fixed /= 2; + shift += 1; + } + + // TFLITE_CHECK_LE(q_fixed, std::numeric_limits::max()); + assert!( + q_fixed <= i32::MAX as i64, + "q_fixed must not exceed {}. 
Got: {} instead.", + i32::MAX, + q_fixed + ); + + // If exponent is too small. + if expon < -((i32::BITS - 1) as isize) { + shift = 0; + q_fixed = 0; + } + + (q_fixed as i32, shift) +} + +// This function returns the normalized fraction and exponent of a double-precision number x. +// If the argument x is not zero, the normalized fraction is x times a power of two, and its +// absolute value is always in the range 0.5 (inclusive) to 1 (exclusive). If x is zero, then +// the normalized fraction and exponent should be zero. However, for our purposes, x should +// always be positive. +fn frexp(x: f64) -> (f64, isize) { + assert!(x > 0.0); + + let x_bits: u64 = x.to_bits(); + + // truncate low order bits to compute the biased exponent + let mut expon = ((x_bits & F64_EXPONENT_MASK) / (1 << F64_EXPONENT_SHIFT)) as i32; + + // assert 0 < expon < 1023<<1 + assert!(expon > 0 && expon < (F64_EXPONENT_BIAS << 1)); + + // unbias exponent + expon = expon - F64_EXPONENT_BIAS + 1; + + let mantissa = x_bits & F64_FRACTION_MASK; + + let q = ((mantissa + (1 << F64_EXPONENT_SHIFT)) as f64) + / ((1_u64 << (F64_EXPONENT_SHIFT + 1)) as f64); + + (q, expon as isize) +} + +// This function is used to quantise model inputs and its types are fixed +pub fn quantise_f32_u8_nne(values: &[f32], scale: QScaleType, zero: u8) -> Vec { + values + .iter() + .map(|x| { + ((((*x as QScaleType) / scale) + (zero as f32)).round_ties_even() as i32) + .clamp(u8::MIN as i32, u8::MAX as i32) as u8 + }) + .collect() +} diff --git a/common/src/quantization/tests.rs b/common/src/quantization/tests.rs new file mode 100644 index 0000000..f63feeb --- /dev/null +++ b/common/src/quantization/tests.rs @@ -0,0 +1,219 @@ +use super::*; + +#[test] +fn test_nnafz_noop() { + let output = vec![0, 1, 2, 3, 4, 5, 6, 7]; + let q_info = BMMQInfo { + input_info: QInfo { + scale: 1.0, + zero_point: 0, + }, + weight_info: QInfo { + scale: 1.0, + zero_point: 0, + }, + output_info: QInfo { + scale: 1.0, + zero_point: 0, + }, + }; 
+ let expected = vec![0, 1, 2, 3, 4, 5, 6, 7]; + let actual = requantize_fc(&output, &q_info, RoundingScheme::NearestTiesAwayFromZero); + assert_eq!(expected, actual); +} + +#[test] +fn test_nnafz_halves() { + // test when the output lands at .5 intervals + let output = vec![-3, -2, -1, 0, 1, 2, 3]; + let q_info = BMMQInfo { + input_info: QInfo { + scale: 0.5, + zero_point: 0, + }, + weight_info: QInfo { + scale: 1.0, + zero_point: 0, + }, + output_info: QInfo { + scale: 1.0, + zero_point: 0, + }, + }; + let expected = vec![-2, -1, -1, 0, 1, 1, 2]; + let actual = requantize_fc(&output, &q_info, RoundingScheme::NearestTiesAwayFromZero); + assert_eq!(expected, actual); +} + +#[test] +fn test_frexp_positive_expon() { + let num = 1.0; + let expected = (0.5, 1); + let actual = frexp(num); + assert_eq!(expected, actual); +} + +#[test] +fn test_frexp_negative_expon() { + let num = 0.1; + let expected = (0.8, -3); + let actual = frexp(num); + assert_eq!(expected, actual); +} + +#[test] +fn test_frexp_zero_expon() { + let num = 1.0 / 2.0 + 1.0 / ((1_i64 << 53) as f64); + let expected = 0; + let actual = frexp(num).1; + assert_eq!(expected, actual); +} + +#[test] +fn test_quantize_multiplier_zero() { + let double_multiplier = 0_f64; + let expected = (0, 0); + let actual = quantize_multiplier(double_multiplier); + assert_eq!(actual, expected); +} + +#[test] +fn test_quantize_multiplier_zero_expon() { + let double_multiplier = 1.0 / 2.0 + 1.0 / ((1_i64 << 53) as f64); + let expected = (1_073_741_824, 0); + let actual = quantize_multiplier(double_multiplier); + assert_eq!(expected, actual); +} + +#[test] +fn test_quantize_multiplier_negative_expon() { + let double_multiplier = 0.1; + let expected = (1_717_986_918, 3); + let actual = quantize_multiplier(double_multiplier); + assert_eq!(expected, actual); +} + +#[test] +#[should_panic(expected = "expon should be non-positive.")] +fn test_ref_noop() { + let (s_i, s_w, s_o) = (1.0, 1.0, 1.0); + let double_mul = s_i * s_w / s_o; + + 
// panics because 1.0 = 0.5 * 2^1 and the exponent is positive. + let _ = quantize_multiplier(double_mul); +} + +#[test] +fn test_ref_specific() { + let double_mul = 0.0003099559683924777; + let (effective_mul, effective_shift) = quantize_multiplier(double_mul); + let output_zero_point = 0; + let output = vec![0, 1, 2, 3, 4, 5, 6, i32::MAX]; + + let expected = vec![0, 0, 0, 0, 0, 0, 0, 665625]; + let actual = requantize_ref(&output, effective_mul, effective_shift, output_zero_point); + assert_eq!(expected, actual); +} + +#[test] +fn test_single_specific() { + let double_mul = 0.0003099559683924777; + let (effective_mul, effective_shift) = quantize_multiplier(double_mul); + let output_zero_point = 0; + let output = vec![0, 1, 2, 3, 4, 5, 6, i32::MAX]; + + let full_shift = effective_shift + (i32::BITS - 1) as usize; + + let expected = vec![0, 0, 0, 0, 0, 0, 0, 665625]; + let actual = requantize_single_round(&output, effective_mul, full_shift, output_zero_point); + assert_eq!(expected, actual); +} + +#[test] +fn compare_three_requantizations() { + let x = &[0, 1234, -1234, 12345, -12345, 123456, -123456]; + + let bmm_req_info: BMMQInfo = BMMQInfo { + input_info: QInfo { + scale: 0.003921568859368563, + zero_point: -128, + }, + weight_info: QInfo { + scale: 0.012436429969966412, + zero_point: 0, + }, + output_info: QInfo { + scale: 0.1573459506034851, + zero_point: 47, + }, + }; + + let double_mul = (bmm_req_info.input_info.scale as f64) + * (bmm_req_info.weight_info.scale as f64) + / (bmm_req_info.output_info.scale as f64); + let (effective_mul, effective_shift) = quantize_multiplier(double_mul); + let full_shift = effective_shift + (i32::BITS - 1) as usize; + + let req_float = requantize_fc_ntafz::(x, &bmm_req_info); + let req_ref = requantize_ref::( + x, + effective_mul, + effective_shift, + bmm_req_info.output_info.zero_point, + ); + let req_single = requantize_single_round::( + x, + effective_mul, + full_shift, + bmm_req_info.output_info.zero_point, + ); + + 
assert_eq!(req_float, req_ref); + assert_eq!(req_float, req_single); +} + +#[test] +fn compare_req_float_and_single() { + let x = &(-8..=8).collect::>(); + + let bmm_req_info: BMMQInfo = BMMQInfo { + input_info: QInfo { + scale: 0.25, + zero_point: 0, + }, + weight_info: QInfo { + scale: 1.0, + zero_point: 0, + }, + output_info: QInfo { + scale: 1.0, + zero_point: 0, + }, + }; + + let double_mul = (bmm_req_info.input_info.scale as f64) + * (bmm_req_info.weight_info.scale as f64) + / (bmm_req_info.output_info.scale as f64); + let (effective_mul, effective_shift) = quantize_multiplier(double_mul); + let full_shift = effective_shift + (i32::BITS - 1) as usize; + + let req_float = requantize_fc_ntafz::(x, &bmm_req_info); + let req_single = requantize_single_round::( + x, + effective_mul, + full_shift, + bmm_req_info.output_info.zero_point, + ); + + assert_eq!(req_float, req_single); + + // N.B.: The following fails in one pathological case, as expected: the + // composition of two nudges inflates the final result by one too much + // let req_ref = requantize_ref::( + // x, + // effective_mul, + // effective_shift, + // bmm_req_info.output_info.zero_point, + // ); + // assert_eq!(req_float, req_ref); +} diff --git a/common/src/utils/mod.rs b/common/src/utils/mod.rs index 6993090..818e218 100644 --- a/common/src/utils/mod.rs +++ b/common/src/utils/mod.rs @@ -1,16 +1,53 @@ +use crate::{ + model::nodes::{ + requantize_bmm_float::RequantizeBMMFloatNode, requantize_bmm_ref::RequantizeBMMRefNode, + requantize_bmm_single::RequantizeBMMSingleNode, + }, + BMMRequantizationStrategy, Node, +}; + #[cfg(feature = "test-types")] pub mod pcs_types; #[cfg(feature = "test-types")] pub mod test_sponge; +// Convenience function to create a requantization Node variant depending +// on a chosen strategy. 
Only implemented for ST = i8, since the +// constructor of the reference implmementation (and therefore of single-round +// too) is only defined in this case +pub(crate) fn req_bmm_from_strategy( + req_strategy: BMMRequantizationStrategy, + inter_dim: usize, + s_i: f32, + z_i: i8, + s_w: f32, + z_w: i8, + s_o: f32, + z_o: i8, +) -> Node { + match req_strategy { + BMMRequantizationStrategy::Floating => Node::RequantizeBMMFloat( + RequantizeBMMFloatNode::new(inter_dim, s_i, z_i, s_w, z_w, s_o, z_o), + ), + BMMRequantizationStrategy::Reference => { + Node::RequantizeBMMRef(RequantizeBMMRefNode::new(inter_dim, s_i, s_w, s_o, z_o)) + } + BMMRequantizationStrategy::SingleRound => { + Node::RequantizeBMMSingle(RequantizeBMMSingleNode::new(inter_dim, s_i, s_w, s_o, z_o)) + } + } +} + macro_rules! node_op { ($self:expr, $method:ident, $trait:ident) => { match $self { Node::BMM(node) => node.$method(), - Node::RequantiseBMM(node) => node.$method(), + Node::RequantizeBMMFloat(node) => node.$method(), + Node::RequantizeBMMRef(node) => node.$method(), + Node::RequantizeBMMSingle(node) => node.$method(), Node::ReLU(node) => node.$method(), - Node::Reshape(node) => $trait::::$method(node), + Node::Reshape(node) => $trait::::$method(node), } }; } diff --git a/exploring_tf_lite/example_models/q_model_wrapper.py b/exploring_tf_lite/example_models/q_model_wrapper.py index b0a9b9d..50fa8d6 100644 --- a/exploring_tf_lite/example_models/q_model_wrapper.py +++ b/exploring_tf_lite/example_models/q_model_wrapper.py @@ -101,19 +101,19 @@ def get_output(self, input_data: np.ndarray) -> np.ndarray: self.quantized_model.get_output_details()[0]['index'] ) - def save_params_as_qarray(self, path: str) -> None: + def save_params_as_verifiaml_tensor(self, path: str) -> None: print(f"Saving quantized model parameters to {path}") for key, value in self.get_model_parameters().items(): with open(path + f'/{key}.json', 'w') as f: - json.dump(QModelWrapper.as_qarray(value), f) + 
json.dump(QModelWrapper.as_verifiaml_tensor(value), f) @staticmethod - def as_qarray(param: np.ndarray) -> Dict[str, Any]: + def as_verifiaml_tensor(param: np.ndarray) -> Dict[str, Any]: flattened_param = QModelWrapper.multi_flatten(param.tolist()) - qarray_shape = list(reversed(param.shape)) - cumulative_dims = [prod(qarray_shape[i+1:]) for i in range(len(qarray_shape))] + tensor_shape = list(reversed(param.shape)) + cumulative_dims = [prod(tensor_shape[i+1:]) for i in range(len(tensor_shape))] - return OrderedDict([("f", flattened_param), ("s", qarray_shape), ("c", cumulative_dims)]) + return OrderedDict([("f", flattened_param), ("s", tensor_shape), ("c", cumulative_dims)]) @staticmethod def multi_flatten(x: Any) -> List[Union[int, float]]: diff --git a/exploring_tf_lite/exploring.ipynb b/exploring_tf_lite/exploring.ipynb index 6178fd0..00670e8 100644 --- a/exploring_tf_lite/exploring.ipynb +++ b/exploring_tf_lite/exploring.ipynb @@ -681,7 +681,7 @@ "# TODO one could wrap this in type checks for good measure (one per tensor, not per element)\n", "ROUNDING = round_nearest_half_up\n", "\n", - "# def requantise_half_away_from_zero(x):\n", + "# def requantize_half_away_from_zero(x):\n", "# # TODO control overflows here?\n", "# abs_a_s_int = np.abs(x) * S_UINT\n", "# rounding_bit = (abs_a_s_int >> (S_SHIFT - 1)) & 1\n", @@ -690,11 +690,11 @@ "# return np.sign(x) * (sh + rounding_bit)\n", "\n", "# TODO there's probably a more elegant way to do this\n", - "def requantise(x):\n", + "def requantize(x):\n", " # TODO control overflows here or in the ROUNDING function?\n", " return ROUNDING(x * S_UINT, S_SHIFT)\n", "\n", - "# requantise_tensor = np.vectorize(requantise)" + "# requantize_tensor = np.vectorize(requantize)" ] }, { @@ -816,7 +816,7 @@ " # TODO handle S_EXPONENT == 0\n", " return (nudged + (1 << (-S_REL_SHIFT - 1))) >> -S_REL_SHIFT\n", "\n", - "requantise_tensor = np.vectorize(arm_requantize)" + "requantize_tensor = np.vectorize(arm_requantize)" ] }, { @@ -831,7 
+831,7 @@ " x = np.matmul(x, W_32) + B_32\n", "\n", " # this is the correct, specification-exact way to do it; in the 10000 sample images, it always coincides with np.rint(x * S)\n", - " x = requantise_tensor(x)\n", + " x = requantize_tensor(x)\n", " \n", " x = np.clip(x + Z_O, -128, 127)\n", " x = (x + 128).astype(np.uint8)\n", @@ -964,7 +964,7 @@ "x1 = (ip.reshape(RESHAPE).astype(np.int32) - 128).astype(np.int8)\n", "x2 = x1.astype(np.int32) - Z_I\n", "x3 = np.matmul(x2, W_32) + B_32\n", - "x4 = requantise_tensor(x3).astype(np.int32)\n", + "x4 = requantize_tensor(x3).astype(np.int32)\n", "x5 = np.clip(x4 + Z_O, -128, 127)\n", "x6 = (x5 + 128).astype(np.uint8)" ] diff --git a/exploring_tf_lite/two_layer_perceptron.ipynb b/exploring_tf_lite/two_layer_perceptron.ipynb index 5fdba48..f74fc7a 100644 --- a/exploring_tf_lite/two_layer_perceptron.ipynb +++ b/exploring_tf_lite/two_layer_perceptron.ipynb @@ -2,9 +2,20 @@ "cells": [ { "cell_type": "code", - "execution_count": null, + "execution_count": 1, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "{'divide': 'warn', 'over': 'warn', 'under': 'ignore', 'invalid': 'warn'}" + ] + }, + "execution_count": 1, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "import os\n", "from math import log2, ceil, floor\n", @@ -96,9 +107,18 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 2, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz\n", + "\u001b[1m11490434/11490434\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 0us/step\n" + ] + } + ], "source": [ "mnist = tf.keras.datasets.mnist\n", "\n", @@ -358,7 +378,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.6" + "version": "3.12.2" } }, "nbformat": 
4, diff --git a/prover/benches/bmm.rs b/prover/benches/bmm.rs index daa112d..1984d5d 100644 --- a/prover/benches/bmm.rs +++ b/prover/benches/bmm.rs @@ -8,7 +8,7 @@ use ark_std::test_rng; use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion}; use hcs_common::{ python::*, quantise_f32_u8_nne, test_sponge, BMMNode, Ligero, Model, Node, NodeCommitment, - NodeCommitmentState, Poly, QArray, RequantiseBMMNode, + NodeCommitmentState, Poly, RequantizeBMMFloatNode, Tensor, }; use hcs_prover::ProveModel; use hcs_verifier::VerifyModel; @@ -35,27 +35,28 @@ macro_rules! PATH { }; } -fn build_fully_connected_layer_mnist(resize_factor: usize) -> Model +fn build_fully_connected_layer_mnist(resize_factor: usize) -> Model where F: PrimeField + Absorb, S: CryptographicSponge, PCS: PolynomialCommitment, S>, { - let w_array: QArray = QArray::read(&format!(PATH!(), "weights.json")); - let b_array: QArray = QArray::read(&format!(PATH!(), "bias.json")); + let w_array: Tensor = Tensor::read(&format!(PATH!(), "weights.json")); + let b_array: Tensor = Tensor::read(&format!(PATH!(), "bias.json")); - let bmm: BMMNode = BMMNode::new(w_array, b_array, Z_I); + let bmm: BMMNode = BMMNode::new(w_array, b_array, Z_I); - let req_bmm: RequantiseBMMNode = RequantiseBMMNode::new(10, S_I, Z_I, S_W, Z_W, S_O, Z_O); + let req_bmm: RequantizeBMMFloatNode = + RequantizeBMMFloatNode::new(10, S_I, Z_I, S_W, Z_W, S_O, Z_O); Model::new( vec![resize_factor as usize * 28 * 28], - vec![Node::BMM(bmm), Node::RequantiseBMM(req_bmm)], + vec![Node::BMM(bmm), Node::RequantizeBMMFloat(req_bmm)], ) } -fn quantise_input(raw_input: &QArray) -> QArray { - let quantised_input: QArray = QArray::new( +fn quantise_input(raw_input: &Tensor) -> Tensor { + let quantised_input: Tensor = Tensor::new( quantise_f32_u8_nne(raw_input.values(), S_INPUT, Z_INPUT), raw_input.shape().clone(), ); @@ -114,7 +115,7 @@ fn bench_tf_inference(c: &mut Criterion, resize_factor: usize, args: Vec<(&str, Python::with_gil(|py| { let 
model = get_model(py, "QFullyConnectedLayer", Some(args.clone())); - save_model_parameters_as_qarray(py, &model, &format!(PATH!(), "")); + save_model_parameters_as_tensor(py, &model, &format!(PATH!(), "")); group.bench_function( BenchmarkId::new( "inference", @@ -127,8 +128,8 @@ fn bench_tf_inference(c: &mut Criterion, resize_factor: usize, args: Vec<(&str, fn bench_verifiaml_inference( c: &mut Criterion, - model: &Model, - raw_input: &QArray, + model: &Model, + raw_input: &Tensor, resize_factor: usize, ) { let mut group = c.benchmark_group("verifiaml"); @@ -147,8 +148,8 @@ fn bench_verifiaml_inference( fn bench_verifiaml_proof( c: &mut Criterion, - model: &Model, - raw_input: &QArray, + model: &Model, + raw_input: &Tensor, ck: &PCS::CommitterKey, sponge: &mut S, resize_factor: usize, @@ -196,12 +197,12 @@ where fn bench_verifiaml_verification( c: &mut Criterion, - model: &Model, + model: &Model, ck: &PCS::CommitterKey, vk: &PCS::VerifierKey, node_coms: &Vec>, node_com_states: &Vec>, - raw_input: &QArray, + raw_input: &Tensor, sponge: &mut S, resize_factor: usize, ) where diff --git a/prover/examples/common/lib.rs b/prover/examples/common/lib.rs index c16075f..c2ace30 100644 --- a/prover/examples/common/lib.rs +++ b/prover/examples/common/lib.rs @@ -1,4 +1,4 @@ -use hcs_common::{quantise_f32_u8_nne, InferenceProof, Model, Poly, QArray}; +use hcs_common::{quantise_f32_u8_nne, InferenceProof, Model, Poly, Tensor}; use hcs_prover::ProveModel; use hcs_verifier::VerifyModel; @@ -11,7 +11,7 @@ use ark_std::test_rng; pub fn prove_inference( input_path: &str, expected_output_path: &str, - model: &Model, + model: &Model, qinfo: (f32, u8), sponge: S, output_shape: Vec, @@ -20,10 +20,10 @@ pub fn prove_inference( S: CryptographicSponge, PCS: PolynomialCommitment, S>, { - let input: QArray = QArray::read(input_path); - let expected_output: QArray = QArray::read(expected_output_path); + let input: Tensor = Tensor::read(input_path); + let expected_output: Tensor = 
Tensor::read(expected_output_path); - let quantised_input: QArray = QArray::new( + let quantised_input: Tensor = Tensor::new( quantise_f32_u8_nne(input.values(), qinfo.0, qinfo.1), input.shape().clone(), ); @@ -38,7 +38,7 @@ pub fn prove_inference( let (node_coms, node_com_states): (Vec<_>, Vec<_>) = model.commit(&ck, None).into_iter().unzip(); - let inference_proof: InferenceProof = model.prove_inference( + let inference_proof: InferenceProof = model.prove_inference( &ck, Some(&mut rng), &mut sponge, @@ -51,7 +51,7 @@ pub fn prove_inference( let output_i8 = output_qtypearray.unwrap_small(); - let output_u8: QArray = (output_i8.cast::() + 128).cast(); + let output_u8: Tensor = (output_i8.cast::() + 128).cast(); assert_eq!(output_u8.compact_resize(output_shape, 0), expected_output); @@ -61,7 +61,7 @@ pub fn prove_inference( pub fn verify_inference( input_path: &str, expected_output_path: &str, - model: &Model, + model: &Model, qinfo: (f32, u8), sponge: S, output_shape: Vec, @@ -70,10 +70,10 @@ pub fn verify_inference( S: CryptographicSponge, PCS: PolynomialCommitment, S>, { - let input: QArray = QArray::read(input_path); - let expected_output: QArray = QArray::read(expected_output_path); + let input: Tensor = Tensor::read(input_path); + let expected_output: Tensor = Tensor::read(expected_output_path); - let quantised_input: QArray = QArray::new( + let quantised_input: Tensor = Tensor::new( quantise_f32_u8_nne(input.values(), qinfo.0, qinfo.1), input.shape().clone(), ); @@ -91,7 +91,7 @@ pub fn verify_inference( let (node_coms, node_com_states): (Vec<_>, Vec<_>) = model.commit(&ck, None).into_iter().unzip(); - let inference_proof: InferenceProof = model.prove_inference( + let inference_proof: InferenceProof = model.prove_inference( &ck, Some(&mut rng), &mut proving_sponge, diff --git a/prover/examples/simple_perceptron_mnist/main.rs b/prover/examples/simple_perceptron_mnist/main.rs index 01d0395..6a1e88e 100644 --- a/prover/examples/simple_perceptron_mnist/main.rs 
+++ b/prover/examples/simple_perceptron_mnist/main.rs @@ -1,6 +1,6 @@ use hcs_common::{ simple_perceptron_mnist::{build_simple_perceptron_mnist, parameters::*, OUTPUT_DIM}, - test_sponge, Ligero, + test_sponge, BMMRequantizationStrategy, Ligero, }; use ark_bn254::Fr; @@ -17,7 +17,9 @@ macro_rules! PATH { } fn main() { - let simple_perceptron = build_simple_perceptron_mnist::, Ligero>(); + let simple_perceptron = build_simple_perceptron_mnist::, Ligero>( + BMMRequantizationStrategy::Floating, + ); // Right now this can't be QInfo because the latter is always a pair // (f32, i8), which indeed matches in-model quantisation, but not diff --git a/prover/examples/two_layer_perceptron_mnist/main.rs b/prover/examples/two_layer_perceptron_mnist/main.rs index 2e424bb..76efc0f 100644 --- a/prover/examples/two_layer_perceptron_mnist/main.rs +++ b/prover/examples/two_layer_perceptron_mnist/main.rs @@ -1,7 +1,7 @@ use hcs_common::{ test_sponge, two_layer_perceptron_mnist::{build_two_layer_perceptron_mnist, parameters::*, OUTPUT_DIM}, - Ligero, + BMMRequantizationStrategy, Ligero, }; use ark_bn254::Fr; @@ -18,8 +18,9 @@ macro_rules! 
PATH { } fn main() { - let two_layer_perceptron = - build_two_layer_perceptron_mnist::, Ligero>(); + let two_layer_perceptron = build_two_layer_perceptron_mnist::, Ligero>( + BMMRequantizationStrategy::Floating, + ); // Right now this can't be QInfo because the latter is always a pair // (f32, i8), which indeed matches in-model quantisation, but not diff --git a/prover/src/lib.rs b/prover/src/lib.rs index e807cc5..07d3588 100644 --- a/prover/src/lib.rs +++ b/prover/src/lib.rs @@ -4,7 +4,7 @@ use ark_poly_commit::{LabeledCommitment, PolynomialCommitment}; use ark_std::rand::RngCore; use hcs_common::{ - InnerType, LabeledPoly, Node, NodeCommitment, NodeCommitmentState, NodeProof, Poly, + LabeledPoly, Node, NodeCommitment, NodeCommitmentState, NodeProof, Poly, SmallNIO, }; mod model; @@ -15,7 +15,7 @@ mod util; pub use model::ProveModel; /// SNARK-specific operations that each node must implement. -pub trait NodeOpsProve +pub trait NodeOpsProve where F: PrimeField + Absorb, S: CryptographicSponge, @@ -44,13 +44,12 @@ where ) -> (NodeCommitment, NodeCommitmentState); } -impl NodeOpsProve for Node +impl NodeOpsProve for Node where - F: PrimeField + Absorb + From + From + From, + F: PrimeField + Absorb + From + From, S: CryptographicSponge, PCS: PolynomialCommitment, S>, - I: InnerType + TryFrom, - O: InnerType + From, + ST: SmallNIO, { fn prove( &self, diff --git a/prover/src/model.rs b/prover/src/model.rs index ffe7542..c32270e 100644 --- a/prover/src/model.rs +++ b/prover/src/model.rs @@ -4,17 +4,19 @@ use ark_crypto_primitives::sponge::{Absorb, CryptographicSponge}; use ark_ff::PrimeField; use ark_poly::MultilinearExtension; use ark_poly_commit::{LabeledPolynomial, PolynomialCommitment}; -use hcs_common::{InferenceProof, InnerType, Model}; -use hcs_common::{NodeCommitment, NodeCommitmentState, Poly, QArray, QTypeArray}; +use hcs_common::{ + InferenceProof, Model, NIOTensor, NodeCommitment, NodeCommitmentState, Poly, SmallNIO, Tensor, +}; use crate::NodeOpsProve; -pub 
trait ProveModel +pub trait ProveModel where F: PrimeField + Absorb, S: CryptographicSponge, PCS: PolynomialCommitment, S>, + ST: SmallNIO, { - fn padded_evaluate(&self, input: QArray) -> QArray; + fn padded_evaluate(&self, input: Tensor) -> Tensor; fn prove_inference( &self, @@ -23,8 +25,8 @@ where sponge: &mut S, node_coms: &Vec>, node_com_states: &Vec>, - input: QArray, - ) -> InferenceProof; + input: Tensor, + ) -> InferenceProof; fn commit( &self, @@ -33,17 +35,16 @@ where ) -> Vec<(NodeCommitment, NodeCommitmentState)>; } -impl ProveModel for Model +impl ProveModel for Model where - F: PrimeField + Absorb + From + From, + F: PrimeField + Absorb + From + From, S: CryptographicSponge, PCS: PolynomialCommitment, S>, - ST: InnerType + TryFrom, - LT: InnerType + From, + ST: SmallNIO, { /// Unlike the node's `padded_evaluate`, the model's `padded_evaluate` accepts unpadded input /// and first re-sizes it before running inference. - fn padded_evaluate(&self, input: QArray) -> QArray { + fn padded_evaluate(&self, input: Tensor) -> Tensor { // TODO sanity check: input shape matches model input shape let input = input.compact_resize( @@ -55,7 +56,7 @@ where ST::ZERO, ); - let mut output = QTypeArray::S(input); + let mut output = NIOTensor::S(input); for node in &self.nodes { output = node.padded_evaluate(&output); @@ -74,8 +75,8 @@ where sponge: &mut S, node_coms: &Vec>, node_com_states: &Vec>, - input: QArray, - ) -> InferenceProof { + input: Tensor, + ) -> InferenceProof { // TODO Absorb public parameters into s (to be determined what exactly) let output = input.compact_resize( @@ -88,7 +89,7 @@ where let output_f: Vec = output.values().iter().map(|x| F::from(*x)).collect(); - let mut output = QTypeArray::S(output); + let mut output = NIOTensor::S(output); // First pass: computing node values // TODO handling F and QSmallType is inelegant; we might want to switch @@ -103,8 +104,8 @@ where output = node.padded_evaluate(&output); let output_f: Vec = match &output { - 
QTypeArray::S(o) => o.values().iter().map(|x| F::from(*x)).collect(), - QTypeArray::L(o) => o.values().iter().map(|x| F::from(*x)).collect(), + NIOTensor::S(o) => o.values().iter().map(|x| F::from(*x)).collect(), + NIOTensor::L(o) => o.values().iter().map(|x| F::from(*x)).collect(), }; node_outputs.push(output.clone()); diff --git a/prover/src/nodes/bmm.rs b/prover/src/nodes/bmm.rs index ea59343..1ebdcf1 100644 --- a/prover/src/nodes/bmm.rs +++ b/prover/src/nodes/bmm.rs @@ -8,19 +8,18 @@ use ark_std::rand::RngCore; use ark_sumcheck::ml_sumcheck::{protocol::ListOfProductsOfPolynomials, MLSumcheck}; use hcs_common::{ - BMMNode, BMMNodeCommitment, BMMNodeCommitmentState, BMMNodeProof, InnerType, LabeledPoly, - NodeCommitment, NodeCommitmentState, NodeOpsPadded, NodeProof, Poly, + BMMNode, BMMNodeCommitment, BMMNodeCommitmentState, BMMNodeProof, LabeledPoly, NodeCommitment, + NodeCommitmentState, NodeOpsPadded, NodeProof, Poly, SmallNIO, }; use crate::NodeOpsProve; -impl NodeOpsProve for BMMNode +impl NodeOpsProve for BMMNode where - F: PrimeField + Absorb + From + From, + F: PrimeField + Absorb + From + From, S: CryptographicSponge, PCS: PolynomialCommitment, S>, - ST: InnerType + TryFrom, - LT: InnerType + From, + ST: SmallNIO, { fn prove( &self, diff --git a/prover/src/nodes/mod.rs b/prover/src/nodes/mod.rs index 8082742..3cef590 100644 --- a/prover/src/nodes/mod.rs +++ b/prover/src/nodes/mod.rs @@ -1,4 +1,4 @@ mod bmm; mod relu; -mod requantise_bmm; +mod requantize_bmm_float; mod reshape; diff --git a/prover/src/nodes/relu.rs b/prover/src/nodes/relu.rs index 523ed8d..450c446 100644 --- a/prover/src/nodes/relu.rs +++ b/prover/src/nodes/relu.rs @@ -4,17 +4,17 @@ use ark_poly_commit::{LabeledCommitment, PolynomialCommitment}; use ark_std::rand::RngCore; use hcs_common::{ - InnerType, LabeledPoly, NodeCommitment, NodeCommitmentState, NodeProof, Poly, ReLUNode, + LabeledPoly, NodeCommitment, NodeCommitmentState, NodeProof, Poly, ReLUNode, SmallNIO, }; use 
crate::NodeOpsProve; -impl NodeOpsProve for ReLUNode +impl NodeOpsProve for ReLUNode where F: PrimeField + Absorb, S: CryptographicSponge, PCS: PolynomialCommitment, S>, - ST: InnerType, + ST: SmallNIO, { fn prove( &self, diff --git a/prover/src/nodes/requantise_bmm.rs b/prover/src/nodes/requantize_bmm_float.rs similarity index 67% rename from prover/src/nodes/requantise_bmm.rs rename to prover/src/nodes/requantize_bmm_float.rs index f26b2bd..043d5fa 100644 --- a/prover/src/nodes/requantise_bmm.rs +++ b/prover/src/nodes/requantize_bmm_float.rs @@ -4,20 +4,19 @@ use ark_poly_commit::{LabeledCommitment, PolynomialCommitment}; use ark_std::rand::RngCore; use hcs_common::{ - InnerType, LabeledPoly, NodeCommitment, NodeCommitmentState, NodeProof, Poly, - RequantiseBMMNode, RequantiseBMMNodeCommitment, RequantiseBMMNodeCommitmentState, - RequantiseBMMNodeProof, + LabeledPoly, NodeCommitment, NodeCommitmentState, NodeProof, Poly, RequantizeBMMFloatNode, + RequantizeBMMNodeCommitment, RequantizeBMMNodeCommitmentState, RequantizeBMMNodeProof, + SmallNIO, }; use crate::NodeOpsProve; -impl NodeOpsProve for RequantiseBMMNode +impl NodeOpsProve for RequantizeBMMFloatNode where F: PrimeField + Absorb, S: CryptographicSponge, PCS: PolynomialCommitment, S>, - ST: InnerType + TryFrom, - LT: InnerType + From, + ST: SmallNIO, { fn prove( &self, @@ -32,7 +31,7 @@ where _output_com: &LabeledCommitment, _output_com_state: &PCS::CommitmentState, ) -> NodeProof { - NodeProof::RequantiseBMM(RequantiseBMMNodeProof {}) + NodeProof::RequantizeBMM(RequantizeBMMNodeProof {}) } fn commit( @@ -41,8 +40,8 @@ where _rng: Option<&mut dyn RngCore>, ) -> (NodeCommitment, NodeCommitmentState) { ( - NodeCommitment::RequantiseBMM(RequantiseBMMNodeCommitment()), - NodeCommitmentState::RequantiseBMM(RequantiseBMMNodeCommitmentState()), + NodeCommitment::RequantizeBMM(RequantizeBMMNodeCommitment()), + NodeCommitmentState::RequantizeBMM(RequantizeBMMNodeCommitmentState()), ) } } diff --git 
a/prover/src/nodes/reshape.rs b/prover/src/nodes/reshape.rs index 6321488..4d2ebc1 100644 --- a/prover/src/nodes/reshape.rs +++ b/prover/src/nodes/reshape.rs @@ -3,18 +3,15 @@ use ark_ff::PrimeField; use ark_poly_commit::{LabeledCommitment, PolynomialCommitment}; use ark_std::rand::RngCore; -use hcs_common::{ - InnerType, LabeledPoly, NodeCommitment, NodeCommitmentState, NodeProof, Poly, ReshapeNode, -}; +use hcs_common::{LabeledPoly, NodeCommitment, NodeCommitmentState, NodeProof, Poly, ReshapeNode}; use crate::NodeOpsProve; -impl NodeOpsProve for ReshapeNode +impl NodeOpsProve for ReshapeNode where F: PrimeField + Absorb, S: CryptographicSponge, PCS: PolynomialCommitment, S>, - ST: InnerType, { fn prove( &self, diff --git a/prover/src/util.rs b/prover/src/util.rs index 57ed93c..89f76cf 100644 --- a/prover/src/util.rs +++ b/prover/src/util.rs @@ -2,9 +2,12 @@ macro_rules! node_operation { ($self:expr, $method:ident, $($arg:expr),*) => { match $self { Node::BMM(node) => node.$method($($arg),*), - Node::RequantiseBMM(node) => node.$method($($arg),*), + Node::RequantizeBMMFloat(node) => node.$method($($arg),*), + // TODO add Node::RequantizeBMMRef(node) => node.$method($($arg),*), once the latter implements commit, proof + Node::RequantizeBMMRef(_) => unimplemented!(), + Node::RequantizeBMMSingle(_) => unimplemented!(), Node::ReLU(node) => node.$method($($arg),*), - Node::Reshape(node) => NodeOpsProve::<_, _, _, I, _>::$method(node, $($arg),*), + Node::Reshape(node) => NodeOpsProve::<_, _, _>::$method(node, $($arg),*), } }; } diff --git a/verifier/src/lib.rs b/verifier/src/lib.rs index 8b711c4..f76f8fa 100644 --- a/verifier/src/lib.rs +++ b/verifier/src/lib.rs @@ -2,7 +2,7 @@ use ark_crypto_primitives::sponge::{Absorb, CryptographicSponge}; use ark_ff::PrimeField; use ark_poly_commit::{LabeledCommitment, PolynomialCommitment}; -use hcs_common::{InnerType, Node, NodeCommitment, NodeProof, Poly}; +use hcs_common::{Node, NodeCommitment, NodeProof, Poly, SmallNIO}; mod 
model; mod nodes; @@ -26,13 +26,12 @@ where ) -> bool; } -impl NodeOpsVerify for Node +impl NodeOpsVerify for Node where F: PrimeField + Absorb + From, S: CryptographicSponge, PCS: PolynomialCommitment, S>, - ST: InnerType + TryFrom, - LT: InnerType + From, + ST: SmallNIO, { fn verify( &self, @@ -47,17 +46,20 @@ where } } -fn node_as_node_ops_snark(node: &Node) -> &dyn NodeOpsVerify +fn node_as_node_ops_snark(node: &Node) -> &dyn NodeOpsVerify where F: PrimeField + Absorb + From, S: CryptographicSponge, PCS: PolynomialCommitment, S>, - ST: InnerType, + ST: SmallNIO, { match node { Node::BMM(fc) => fc, - Node::RequantiseBMM(r) => r, + Node::RequantizeBMMFloat(r) => r, Node::ReLU(r) => r, Node::Reshape(r) => r, + // TODO add Node::RequantizeBMMRef(r) => r, once the latter implements NodeOpsVerify + Node::RequantizeBMMRef(_) => unimplemented!(), + Node::RequantizeBMMSingle(_) => unimplemented!(), } } diff --git a/verifier/src/model.rs b/verifier/src/model.rs index 45ea304..1a1c4e3 100644 --- a/verifier/src/model.rs +++ b/verifier/src/model.rs @@ -5,37 +5,37 @@ use ark_poly::Polynomial; use ark_poly_commit::PolynomialCommitment; use ark_std::log2; -use hcs_common::{InferenceProof, InnerType, Model, NodeCommitment, Poly}; +use hcs_common::{InferenceProof, Model, NodeCommitment, Poly, SmallNIO}; -pub trait VerifyModel +pub trait VerifyModel where F: PrimeField + Absorb, S: CryptographicSponge, PCS: PolynomialCommitment, S>, + ST: SmallNIO, { fn verify_inference( &self, vk: &PCS::VerifierKey, sponge: &mut S, node_commitments: &Vec>, - inference_proof: InferenceProof, + inference_proof: InferenceProof, ) -> bool; } -impl VerifyModel for Model +impl VerifyModel for Model where - F: PrimeField + Absorb + From + From, + F: PrimeField + Absorb + From + From, S: CryptographicSponge, PCS: PolynomialCommitment, S>, - ST: InnerType + TryFrom, - LT: InnerType + From, + ST: SmallNIO, { fn verify_inference( &self, vk: &PCS::VerifierKey, sponge: &mut S, node_commitments: &Vec>, - 
inference_proof: InferenceProof, + inference_proof: InferenceProof, ) -> bool { let InferenceProof { inputs, @@ -70,8 +70,8 @@ where // output nodes and instead working witht their plain values all along, // but that would require messy node-by-node handling let input_node_com = node_value_commitments.first().unwrap(); - let input_node_qarray = inputs[0].ref_small(); - let input_node_f: Vec = input_node_qarray + let input_node_tensor = inputs[0].ref_small(); + let input_node_f: Vec = input_node_tensor .values() .iter() .map(|x| F::from(*x)) @@ -96,12 +96,12 @@ where sponge.squeeze_field_elements(log2(output_node_f.len()) as usize); // Verifying that the actual input was honestly padded with zeros - let padded_input_shape = input_node_qarray.shape().clone(); - let honestly_padded_input = input_node_qarray + let padded_input_shape = input_node_tensor.shape().clone(); + let honestly_padded_input = input_node_tensor .compact_resize(self.input_shape().clone(), ST::ZERO) .compact_resize(padded_input_shape, ST::ZERO); - if honestly_padded_input.values() != input_node_qarray.values() { + if honestly_padded_input.values() != input_node_tensor.values() { return false; } diff --git a/verifier/src/nodes/bmm.rs b/verifier/src/nodes/bmm.rs index 79f2cb1..15c81e5 100644 --- a/verifier/src/nodes/bmm.rs +++ b/verifier/src/nodes/bmm.rs @@ -6,17 +6,17 @@ use ark_sumcheck::ml_sumcheck::{ MLSumcheck, }; use hcs_common::{ - BMMNode, BMMNodeCommitment, BMMNodeProof, InnerType, NodeCommitment, NodeProof, Poly, + BMMNode, BMMNodeCommitment, BMMNodeProof, NodeCommitment, NodeProof, Poly, SmallNIO, }; use crate::NodeOpsVerify; -impl NodeOpsVerify for BMMNode +impl NodeOpsVerify for BMMNode where F: PrimeField + Absorb + From, S: CryptographicSponge, PCS: PolynomialCommitment, S>, - ST: InnerType, + ST: SmallNIO, { fn verify( &self, diff --git a/verifier/src/nodes/mod.rs b/verifier/src/nodes/mod.rs index 8082742..3cef590 100644 --- a/verifier/src/nodes/mod.rs +++ b/verifier/src/nodes/mod.rs 
@@ -1,4 +1,4 @@ mod bmm; mod relu; -mod requantise_bmm; +mod requantize_bmm_float; mod reshape; diff --git a/verifier/src/nodes/requantise_bmm.rs b/verifier/src/nodes/requantize_bmm_float.rs similarity index 80% rename from verifier/src/nodes/requantise_bmm.rs rename to verifier/src/nodes/requantize_bmm_float.rs index faefe06..eff1d5a 100644 --- a/verifier/src/nodes/requantise_bmm.rs +++ b/verifier/src/nodes/requantize_bmm_float.rs @@ -1,11 +1,11 @@ use ark_crypto_primitives::sponge::{Absorb, CryptographicSponge}; use ark_ff::PrimeField; use ark_poly_commit::{LabeledCommitment, PolynomialCommitment}; -use hcs_common::{NodeCommitment, NodeProof, Poly, RequantiseBMMNode}; +use hcs_common::{NodeCommitment, NodeProof, Poly, RequantizeBMMFloatNode}; use crate::NodeOpsVerify; -impl NodeOpsVerify for RequantiseBMMNode +impl NodeOpsVerify for RequantizeBMMFloatNode where F: PrimeField + Absorb, S: CryptographicSponge,